From cfaef8574ec795079872b24fce2d60b0fa5926d7 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Thu, 15 Dec 2022 08:19:04 -0800 Subject: [PATCH 001/381] Enabled {Conv2D, DepthwiseConv2D}+BiasAdd+_FusedHardSwish fusion for FP32/BF16 --- .../core/common_runtime/mkl_layout_pass.cc | 4 +- .../grappler/optimizers/mkl_remapper_test.cc | 136 +++++++++ .../core/grappler/optimizers/remapper.cc | 259 ++++++++++++++++-- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 6 + 4 files changed, 379 insertions(+), 26 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index d9d5540e1cf17d..5cd2968fb253bb 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -1780,6 +1780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { fused_ops == std::vector{"BiasAdd", "Relu"} || fused_ops == std::vector{"BiasAdd", "Relu6"} || fused_ops == std::vector{"BiasAdd", "Elu"} || + fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"} || fused_ops == std::vector{"BiasAdd", "Add"} || fused_ops == std::vector{"BiasAdd", "Add", "Relu"} || fused_ops == std::vector{"BiasAdd", "Add", "Relu6"} || @@ -1812,7 +1813,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return (fused_ops == std::vector{"BiasAdd"} || fused_ops == std::vector{"BiasAdd", "Relu"} || fused_ops == std::vector{"BiasAdd", "Relu6"} || - fused_ops == std::vector{"BiasAdd", "Elu"}); + fused_ops == std::vector{"BiasAdd", "Elu"} || + fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"}); } // Rewrites input node to a new node specified by its matching rewrite info. diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc index eba284ec86792b..91dabbca011efb 100644 --- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc @@ -925,6 +925,142 @@ class MklRemapperSwishTest : public GrapplerTest { TEST_F(MklRemapperSwishTest, F32) { RunTest(); } TEST_F(MklRemapperSwishTest, BF16) { RunTest(); } +class FusedConvBiasAddAndHardSwishTest : public GrapplerTest { + public: + const string kAddOp = "Add"; + const string kAddV2Op = "AddV2"; + + template + void RunTest(const string& add_op, const bool is_depthwise) { + using ::tensorflow::ops::Placeholder; + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); + auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128}); + auto bias_shape = ops::Placeholder::Shape({is_depthwise ? 384 : 128}); + + auto input = Placeholder(s.WithOpName("input"), DType, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DType, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DType, bias_shape); + const DataType const_dt = with_cast_op ? DT_FLOAT : DType; + typedef typename EnumToDataType::Type DT; + Tensor three(const_dt, TensorShape({})); + Tensor one_sixth(const_dt, TensorShape({})); + three.scalar
<DT>()() = static_cast<DT>
(3.0f); + one_sixth.scalar
<DT>()() = static_cast<DT>
(1.0f / 6.0f); + auto three_op = + with_cast_op + ? ops::Cast(s.WithOpName("three"), Input::Initializer(three), + DT_BFLOAT16) + : ops::Const(s.WithOpName("three"), Input::Initializer(three)); + auto one_sixth_op = + with_cast_op ? ops::Cast(s.WithOpName("one_sixth"), + Input::Initializer(one_sixth), DT_BFLOAT16) + : ops::Const(s.WithOpName("one_sixth"), + Input::Initializer(one_sixth)); + + std::vector strides = {1, 1, 1, 1}; + Output conv; + if (is_depthwise) { + conv = ops::DepthwiseConv2dNative( + s.WithOpName("conv"), input, filter, strides, "SAME", + ops::DepthwiseConv2dNative::Attrs().DataFormat("NHWC")); + } else { + conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME", + ops::Conv2D::Attrs().DataFormat("NHWC")); + } + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias, + ops::BiasAdd::Attrs().DataFormat("NHWC")); + + Output add; + if (add_op == kAddV2Op) { + add = ops::AddV2(s.WithOpName(add_op), three_op, bias_add); + } else { + add = ops::Add(s.WithOpName(add_op), three_op, bias_add); + } + + auto relu6 = ops::Relu6(s.WithOpName("relu_6"), add); + auto mul_one_sixth = + ops::Mul(s.WithOpName("mul_one_sixth"), one_sixth_op, bias_add); + auto mul_output = ops::Mul(s.WithOpName("output"), mul_one_sixth, relu6); + + auto fetch = ops::Identity(s.WithOpName("fetch"), mul_output); + + auto input_tensor = GenerateTensorWithSetRandom( + TensorShape(input_shape.shape_.dim_sizes())); + auto filter_tensor = GenerateTensorWithSetRandom( + TensorShape(filter_shape.shape_.dim_sizes())); + auto bias_tensor = GenerateTensorWithSetRandom( + TensorShape(bias_shape.shape_.dim_sizes())); + + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_tensor}, + {"filter", filter_tensor}, + {"bias", bias_tensor}}; + + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. 
+ for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + Remapper optimizer(RewriterConfig::ON); + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "output") { + if (is_depthwise) { + EXPECT_EQ("_FusedDepthwiseConv2dNative", node.op()); + } else { + EXPECT_EQ("_FusedConv2D", node.op()); + } + EXPECT_EQ("input", node.input(0)); + EXPECT_EQ("filter", node.input(1)); + EXPECT_EQ("bias", node.input(2)); + EXPECT_EQ(1, node.attr().at("num_args").i()); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + EXPECT_EQ(2, fused_ops.size()); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ("_FusedHardSwish", fused_ops[1]); + found++; + } + } + EXPECT_EQ(1, found); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectClose(tensors_expected[0], tensors[0], 1e-6); + } +}; + +TEST_F(FusedConvBiasAddAndHardSwishTest, Float32Conv2DBiasHardSwish) { + RunTest("AddV2", false); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Float32DWConv2DBiasHardSwish) { + RunTest("AddV2", true); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Bfloat16Conv2DBiasHardSwish) { + RunTest("Add", false); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Bfloat16DWConv2DBiasHardSwish) { + RunTest("Add", true); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Bfloat16Conv2DBiasHardSwishWithCast) { + RunTest("Add", false); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, + Bfloat16DWConv2DBiasHardSwishWithCast) { + RunTest("Add", true); +} + } // namespace grappler } // namespace tensorflow #endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 082f15f243be85..8b252650a02ea8 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -951,7 +951,8 @@ bool FindConv2DWithBatchNormAndActivation( bool FindContractionWithBiasInPort(const RemapperContext& ctx, const utils::MutableNodeView& add_node_view, const NodeDef& add_node_def, int port_id, - ContractionWithBiasAdd* base) { + ContractionWithBiasAdd* base, + const int allowed_fanouts = 1) { // Input to AddN must match ContractionWithBiasAdd pattern. 
if (add_node_view.NumRegularFanins() < port_id + 1) return false; const auto& bias_add_node_view = @@ -962,7 +963,7 @@ bool FindContractionWithBiasInPort(const RemapperContext& ctx, if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), base, /*check_device_compatible=*/false)) return false; - if (!HasAtMostOneFanoutAtPort0(*bias_add_node_view) || + if (bias_add_node_view->GetRegularFanout(0).size() > allowed_fanouts || !HaveSameDataType(&add_node_def, bias_add_node_def) || IsInPreserveSet(ctx, bias_add_node_def)) return false; @@ -1121,6 +1122,37 @@ bool FindContractionWithBiasAndAddActivation( return true; } +inline bool VerifySingleConstant(utils::MutableNodeView* node_view, + const float value) { + NodeDef* node_def = node_view->node(); + Tensor const_tensor; + if (node_def != nullptr && node_def->op() == "Const" && + const_tensor.FromProto(node_def->attr().at("value").tensor())) { + if (const_tensor.NumElements() == 1) { + DataType dtype = const_tensor.dtype(); + float const_value; + if (dtype == DT_FLOAT) { + const_value = const_tensor.flat()(0); + } else if (dtype == DT_BFLOAT16) { + const_value = const_tensor.flat()(0); + } else if (dtype == DT_HALF) { + const_value = const_tensor.flat()(0); + } else { + return false; + } + if (std::abs(const_value - value) > 1e-2) { + return false; + } else { + return true; + } + } else { + return false; + } + } else { + return false; + } +} + inline bool VerifyConstants(RemapperContext* ctx, std::map* nodes_map, std::map* values_map) { @@ -1128,29 +1160,7 @@ inline bool VerifyConstants(RemapperContext* ctx, for (auto it = values_map->begin(); it != values_map->end(); ++it) { int node_idx = nodes_map->at(it->first); MutableNodeView* node_view = ctx->graph_view.GetNode(node_idx); - NodeDef* node_def = node_view->node(); - Tensor const_tensor; - if (node_def != nullptr && node_def->op() == "Const" && - const_tensor.FromProto(node_def->attr().at("value").tensor())) { - if (const_tensor.NumElements() == 1) { - DataType dtype = const_tensor.dtype(); - float const_value; - if (dtype == DT_FLOAT) { - const_value = const_tensor.flat()(0); - } else if (dtype == DT_BFLOAT16) { - const_value = const_tensor.flat()(0); - } else if (dtype == DT_HALF) { - const_value = const_tensor.flat()(0); - } else { - return false; - } - if (std::abs(const_value - it->second) > 1e-2) return false; - } else { - return false; - } - } else { - return false; - } + if (!VerifySingleConstant(node_view, it->second)) return false; } return true; } @@ -2067,6 +2077,155 @@ bool FindTensorToHashBucket(const RemapperContext& ctx, int node_index, return true; } +// clang-format off +// HardSwish pattern +// input Const (value: 3) +// | \ / +// | Add or AddV2 +// | | +// | Relu6 +// | / +// | / +// Const (value: 0.1666) | / +// \ | / +// Mul / +// \ / +// Mul +// clang-format on +bool FindHardSwish(RemapperContext& ctx, int node_index, + std::map* matched_nodes_map, + std::set* remove_node_indices) { + if (!IsMKLEnabled()) return false; + + using utils::MatchingDirection; + using utils::NodeStatus; + // clang-format off + utils::OpTypePattern pattern {"Mul", "output", NodeStatus::kReplace, + { + {"Mul", "mul_one_sixth", NodeStatus::kRemove, + { + {"Const|Cast", "one_sixth", NodeStatus::kRemain}, + {"*", "input", NodeStatus::kRemain} + } + }, + {"Relu6", "relu6", NodeStatus::kRemove, + { + {"Add|AddV2", "add", NodeStatus::kRemove, + { + {"*", "input", NodeStatus::kRemain}, + {"Const|Cast", "three", NodeStatus::kRemain} + + } + } + } + }, + } + }; + // clang-format on + bool 
found_match = false; + utils::SubGraphMatcher graph_matcher( + &(ctx.graph_view)); + + matched_nodes_map->clear(); + remove_node_indices->clear(); + + found_match = graph_matcher.GetMatchedNodes( + pattern, ctx.nodes_to_preserve, ctx.graph_view.GetNode(node_index), + matched_nodes_map, remove_node_indices); + + if (found_match) { + // Check if the values of Const nodes are as expected + std::map values_map = {{"three", 3.0}, + {"one_sixth", 0.16666}}; + for (auto it = values_map.begin(); it != values_map.end(); ++it) { + auto* const_or_cast_view = + ctx.graph_view.GetNode(matched_nodes_map->at(it->first)); + if (const_or_cast_view->node()->op() == "Const") { + if (!VerifySingleConstant(const_or_cast_view, it->second)) return false; + } else { + // There is Cast op after Const + auto* const_node_view = + const_or_cast_view->GetRegularFanin(0).node_view(); + if (const_node_view->node()->op() != "Const") return false; + if (!VerifySingleConstant(const_node_view, it->second)) return false; + } + } + } + + return found_match; +} + +// clang-format off +// Contraction + BiasAdd + _FusedHardSwish activation +// input filter +// \ / +// Contraction bias +// | / +// BiasAdd +// | +// _FusedHardSwish +// clang-format on +bool FindContractionWithBiasAddAndHardSwish( + RemapperContext& ctx, int node_index, + std::map* matched_nodes_map, + std::set* remove_node_indices) { + if (!IsMKLEnabled()) return false; + + const auto* node_view = ctx.graph_view.GetNode(node_index); + if (HasControlFaninOrFanout(*node_view)) return false; + + const auto* node_def = node_view->node(); + + // Check if HardSwish pattern is available + if (!FindHardSwish(ctx, node_index, matched_nodes_map, remove_node_indices)) + return false; + // Get handle of Add|AddV2 op that is the root of HardSwish pattern. + const auto* add_node_view = + ctx.graph_view.GetNode(matched_nodes_map->at("add")); + const auto* add_node_def = add_node_view->node(); + + // Check if ContractionWithBias pattern is feeding HardSwish + ContractionWithBiasAdd base; + int port_id = 0; + // BiasAdd node is expected to have 2 fanouts feeding the HardSwish pattern. + if (!FindContractionWithBiasInPort(ctx, *add_node_view, *add_node_def, + port_id, &base, /*allowed_fanouts*/ 2)) { + port_id = 1; + if (!FindContractionWithBiasInPort(ctx, *add_node_view, *add_node_def, + port_id, &base, /*allowed_fanouts*/ 2)) { + VLOG(2) << "Contraction + BiasAdd pattern was not found although" + << " HardSwish pattern was found, so fusion failed."; + return false; + } + } + + // Get the BiasAdd node + const auto* bias_node_def = ctx.graph_view.GetNode(base.bias_add)->node(); + if (!HaveSameDataType(add_node_def, bias_node_def)) return false; + + // Get the contraction node + const auto* contraction_node_view = ctx.graph_view.GetNode(base.contraction); + const auto* contraction_node_def = contraction_node_view->node(); + + // Currently only Conv2D and DepthwiseConv2D contraction ops are supported + if (!IsConv2D(*contraction_node_def) && + !IsDepthwiseConv2dNative(*contraction_node_def)) + return false; + + // Check if contraction is compatible with CPU + if (!IsCpuCompatibleConv2D(ctx, contraction_node_def) && + !IsCpuCompatibleDepthwiseConv2dNative(contraction_node_def)) + return false; + + // We found a {Conv2D, DepthwiseConv2D}+BiasAdd+_FusedHardSwish pattern. 
+ matched_nodes_map->insert({"contraction", base.contraction}); + matched_nodes_map->insert({"bias_add", base.bias_add}); + + remove_node_indices->insert(base.contraction); + remove_node_indices->insert(base.bias_add); + return true; +} + bool FindFusedBatchMatMul(RemapperContext* ctx, int node_index, std::map* matched_nodes_map, std::set* remove_node_indices) { @@ -2695,6 +2854,47 @@ Status AddFusedContractionNode( return OkStatus(); } +Status FuseContractionWithBiasAddAndHardSwish( + RemapperContext* ctx, std::map* matched_nodes_map, + std::set* remove_node_indices, std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { + auto* output_node = + ctx->graph_view.GetNode(matched_nodes_map->at("output"))->node(); + auto* contraction_node = + ctx->graph_view.GetNode(matched_nodes_map->at("contraction"))->node(); + auto* bias_add_node = + ctx->graph_view.GetNode(matched_nodes_map->at("bias_add"))->node(); + + bool is_conv2d = IsConv2D(*contraction_node); + + NodeDef fused_node; + fused_node.set_name(output_node->name()); + fused_node.set_op(is_conv2d ? kFusedConv2D : kFusedDepthwiseConv2dNative); + fused_node.set_device(contraction_node->device()); + fused_node.add_input(contraction_node->input(0)); + fused_node.add_input(contraction_node->input(1)); + fused_node.add_input(bias_add_node->input(1)); + + if (is_conv2d) { + CopyConv2DAttributes(*contraction_node, &fused_node); + } else { + CopyDepthwiseConv2dNativeAttributes(*contraction_node, &fused_node); + } + SetFusedOpAttributes(&fused_node, {"BiasAdd", "_FusedHardSwish"}); + + utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + Status status; + mutation->AddNode(std::move(fused_node), &status); + TF_RETURN_IF_ERROR(status); + TF_RETURN_IF_ERROR(mutation->Apply()); + (*invalidated_nodes)[matched_nodes_map->at("output")] = true; + + for (const auto& node_idx : *remove_node_indices) { + (*nodes_to_delete)[node_idx] = true; + } + return OkStatus(); +} + Status AddFusedMatMulBiasAddAndGelu( RemapperContext* ctx, const std::map& matched_nodes_map, const std::set& remove_node_indices, @@ -3669,6 +3869,15 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, std::map matched_nodes_map; std::set remove_node_indices; + // Remap {Conv2D|DepthwiseConv2D} + BiasAdd + HardSwish subgraph + if (FindContractionWithBiasAddAndHardSwish(ctx, i, &matched_nodes_map, + &remove_node_indices)) { + TF_RETURN_IF_ERROR(FuseContractionWithBiasAddAndHardSwish( + &ctx, &matched_nodes_map, &remove_node_indices, &invalidated_nodes, + &nodes_to_delete)); + continue; + } + // Softplus + Tanh + Mul to Mish conversion matched_nodes_map.clear(); remove_node_indices.clear(); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index e7a14ee9706c5b..d543ab4f827846 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -1440,6 +1440,9 @@ class MklFusedConvOp OP_REQUIRES(context, num_args == 1, errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); + } else if (fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"}) { + this->set_fuse_biasadd(true); + this->set_fuse_activation(true, dnnl::algorithm::eltwise_hardswish); } else if (fused_ops == std::vector{"BiasAdd", "Add"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); @@ -1591,6 +1594,9 @@ class MklFusedDepthwiseConvOp } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, 
dnnl::algorithm::eltwise_elu, 1.0); + } else if (fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"}) { + this->set_fuse_biasadd(true); + this->set_fuse_activation(true, dnnl::algorithm::eltwise_hardswish); } else { OP_REQUIRES(context, false, errors::Unimplemented("Fusion is not implemented: [", From f964d4e025d352ccdad10714f9b1b58a3af292a7 Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Tue, 31 Oct 2023 13:02:12 -0700 Subject: [PATCH 002/381] Disable oneDNN rewrite for small Matmul. --- .../xla/xla/service/cpu/cpu_compiler.cc | 3 +-- .../xla/xla/service/cpu/onednn_rewriter.cc | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 0c8dc69ddd8562..4a7d95b3cba7b1 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -699,8 +699,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) // AOT compiled code runs in single thread. if (!is_aot_compile) { - // Temporarily disabling oneDNN rewriter because it causes JAX regression. - // pipeline.AddPass(); + pipeline.AddPass(); } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cpu/onednn_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_rewriter.cc index 45b38c59d76a61..8f6e092e8cf067 100644 --- a/third_party/xla/xla/service/cpu/onednn_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_rewriter.cc @@ -17,13 +17,13 @@ limitations under the License. #include "xla/service/cpu/onednn_rewriter.h" +#include "tsl/platform/cpu_info.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/pattern_matcher.h" #include "xla/status_macros.h" -#include "tsl/platform/cpu_info.h" namespace xla { namespace cpu { @@ -46,8 +46,8 @@ Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { } bool IsSupportedType(xla::PrimitiveType dtype) { - using tsl::port::TestCPUFeature; using tsl::port::CPUFeature; + using tsl::port::TestCPUFeature; switch (dtype) { case F32: return true; @@ -118,6 +118,21 @@ class OneDnnRewriterVisitor : public DfsHloRewriteVisitor { (dot_dim_numbers.rhs_contracting_dimensions(0) == rhs_shape.rank() - 2); if (!should_rewrite) return OkStatus(); + // OneDNN matmul has scratch allocation and copy overheads. The overheads + // can be amortized if there is sufficient MAC (multiply-accumulate) + // operations. We don't rewrite for small cases (determined empirically). + // TODO(intel-tf): Relax the condition when more optimizations in oneDNN + // matmul is achieved. 
+ auto rank = lhs_shape.rank(); + auto rhs_dims = rhs_shape.dimensions(); + int64_t num_mac_ops = ShapeUtil::ElementsIn(lhs_shape) * rhs_dims.back(); + if (rank == 2) { + should_rewrite &= num_mac_ops >= (1 << 23); + } else { + should_rewrite &= num_mac_ops >= (1 << 18); + } + if (!should_rewrite) return OkStatus(); + HloInstruction* matmul_call = dot_instr->AddInstruction(HloInstruction::CreateCustomCall( output_shape, From 1a8f12ee0387557bfbcd2a521af45003e403c56c Mon Sep 17 00:00:00 2001 From: sushreebarsa <84765720+sushreebarsa@users.noreply.github.com> Date: Wed, 15 Nov 2023 11:54:52 +0530 Subject: [PATCH 003/381] Updated the description of max_delta parameter The documentation currently states that the max_delta parameter controls the maximum absolute change in brightness. However, the max_delta parameter actually controls the maximum relative change in brightness. This means that the actual change in brightness will depend on the range of values in the input image. So, I have updated the description for the max_delta parameter. Could you please have a look and do the needful. Thank you! --- tensorflow/python/ops/image_ops_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index fd8995f6250a81..acc16d5bf2b69c 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -2027,10 +2027,11 @@ def random_brightness(image, max_delta, seed=None): with `tf.image.random_*` ops, `tf.image.stateless_random_*` ops guarantee the same results given the same seed independent of how many times the function is called, and independent of global seed settings (e.g. tf.random.set_seed). + Args: image: An image or images to adjust. - max_delta: float, must be non-negative. + max_delta: float, must be non-negative. The max_delta parameter controls the maximum relative change in brightness. This means that the actual change in brightness will depend on the range of values in the input image. seed: A Python integer. Used to create a random seed. See `tf.compat.v1.set_random_seed` for behavior. From 0116f163307a8f17b61eb0c39277488c10351772 Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Wed, 15 Nov 2023 10:11:22 +0000 Subject: [PATCH 004/381] Prevent OSV scanner running in forks The OSV scanner just fails anyway, so stop it running in forks to stop the failures being logged and emailed. 
--- .github/workflows/osv-scanner-scheduled.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index bb39d60168e08d..fb7366768436c5 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -27,6 +27,7 @@ permissions: jobs: scan-scheduled: + if: github.repository == 'tensorflow/tensorflow' uses: "google/osv-scanner/.github/workflows/osv-scanner-reusable.yml@main" with: scan-args: |- @@ -36,4 +37,4 @@ jobs: --lockfile=requirements.txt:./requirements_lock_3_12.txt --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.requirements.txt --lockfile=requirements.txt:./ci/official/containers/linux_arm64/jax.requirements.txt - --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.usertools/test.requirements.txt \ No newline at end of file + --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.usertools/test.requirements.txt From 7507a9d9228624ba93f9b3a20ee51dce282679e0 Mon Sep 17 00:00:00 2001 From: Raunak Date: Sat, 18 Nov 2023 22:37:28 -0800 Subject: [PATCH 005/381] Implement portserver on Windows --- .../ci_build/windows/bazel/cpu_win_test.sh | 256 ++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh diff --git a/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh new file mode 100644 index 00000000000000..544b1dc2133717 --- /dev/null +++ b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh @@ -0,0 +1,256 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script is a CI script for invoking 'bazel test ... ...' +# It assumes the standard setup on tensorflow Jenkins Windows machines. +# Update the flags/variables below to make it work on your local system. + +# REQUIREMENTS: +# * All installed in standard locations: +# - JDK8, and JAVA_HOME set. +# - Microsoft Visual Studio 2015 Community Edition +# - Msys2 +# - Python 3.x (with pip, setuptools, venv) +# * Bazel Windows executable copied as "bazel.exe" and included in PATH. + + +# All commands should be visible (-x). 
+set -x + +POSITIONAL_ARGS=() +XBF_ARGS="" +XTF_ARGS="" +while [[ $# -gt 0 ]]; do + case "$1" in + --extra_build_flags) + XBF_ARGS="$2" + shift # past argument + shift # past value + ;; + --extra_test_flags) + XTF_ARGS="$2" + shift # past argument + shift # past value + ;; + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + +# Bazelisk (renamed as bazel) is kept in C:\Tools +export PATH=/c/ProgramData/chocolatey/bin:/c/Tools/bazel:/c/Program\ Files/Git:/c/Program\ \ +Files/Git/cmd:/c/msys64:/c/msys64/usr/bin:/c/Windows/system32:/c/Windows:/c/Windows/System32/Wbem + +# Environment variables to be set by Jenkins before calling this script + +export PYTHON_VERSION=${PYTHON_VERSION:-"310"} +export TF_PYTHON_VERSION=${PYTHON_VERSION:0:1}.${PYTHON_VERSION:1} +# keep the tensorflow git repo clone under here as tensorflow subdir +MYTFWS_ROOT=${WORKSPACE:-"C:/Users/mlp_admin"} +MYTFWS_ROOT=`cygpath -m $MYTFWS_ROOT` +export MYTFWS_ROOT="$MYTFWS_ROOT" +export MYTFWS_NAME="tensorflow" +export MYTFWS="${MYTFWS_ROOT}/${MYTFWS_NAME}" +export MYTFWS_ARTIFACT="${MYTFWS_ROOT}/artifact" + + +# Import General Test Target +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Environment variables specific to the system where this job is running, are to +# be set by a script for the specific system. This needs to be set here by sourcing a file. + +export TMP=${TMP:-"${MYTFWS_ROOT}/tmp"} +export TEMP="$TMP" +export TMPDIR=${TMPDIR:-"${MYTFWS}-build"} # used internally by TF build +export TEST_TARGET=${TEST_TARGET:-"${DEFAULT_BAZEL_TARGETS}"} +export MSYS_LOCATION='C:/msys64' +export GIT_LOCATION='C:/Program Files/Git' +export JAVA_LOCATION='C:/Program Files/Eclipse Adoptium/jdk-11.0.14.101-hotspot' +export VS_LOCATION='C:/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools' +export NATIVE_PYTHON_LOCATION="C:/Python${PYTHON_VERSION}" +export PORTSERVER_LOCATION='C:/Program Files/python_portpicker/src/portserver.py' + + +echo "*** *** hostname is $(hostname) *** ***" +which bazel +which git +[[ -e "$NATIVE_PYTHON_LOCATION/python.exe" ]] || \ +{ echo "Specified Python path is incorrect: $NATIVE_PYTHON_LOCATION"; exit 1;} +[[ -e "$NATIVE_PYTHON_LOCATION/Scripts/pip.exe" ]] || \ +{ echo "Specified Python path has no pip: $NATIVE_PYTHON_LOCATION"; exit 1;} +[[ -e "$NATIVE_PYTHON_LOCATION/Lib/venv" ]] || \ +{ echo "Specified Python path has no venv: $NATIVE_PYTHON_LOCATION"; exit 1;} + +$NATIVE_PYTHON_LOCATION/python.exe -m pip list + +# =========================== Start of actual script ========================= +# This script sets necessary environment variables and runs TF-Windows build & unit tests +# We also assume a few Software components are also installed in the machine: MS VC++, +# MINGW SYS64, Python 3.x, JAVA, Git, Bazelisk etc. + +# Asuumptions +# 1) TF repo cloned into to %WORKSPACE%\tensorflow (aka %TF_LOCATION%) +# 2) Bazelisk is installed in "C:\Tools\Bazel" +# 3) The following jobs-specific env vars will be exported by the caller +# WORKSPACE (ex. C:\Jenkins\workspace\tensorflow-eigen-test-win) +# PYTHON_VERSION (ex. 38) +# PIP_MODULES (if set will contain any additional pip packages) +# 4) System-specific env variables for the location of different software +# components needed for building. 
+ +# Create Python virtual env +cd ${MYTFWS_ROOT} +export PYTHON_DIRECTORY="${MYTFWS_ROOT}"/venv_py${PYTHON_VERSION} +"${NATIVE_PYTHON_LOCATION}"/python.exe -mvenv --clear "${PYTHON_DIRECTORY}" + +#activate virtual env +source "${PYTHON_DIRECTORY}"/Scripts/activate + +which python +python --version + +# Install pip modules specs from tensorflow/tools/ci_build/release/requirements_common.txt +python -m pip install -r $MYTFWS/tensorflow/tools/ci_build/release/requirements_common.txt + +# set up other Variables required by Bazel. +export PYTHON_BIN_PATH="${PYTHON_DIRECTORY}"/Scripts/python.exe +export PYTHON_LIB_PATH="${PYTHON_DIRECTORY}"/Lib/site-packages +export BAZEL_VS=${VS_LOCATION} +export BAZEL_VC=${VS_LOCATION}/VC +export JAVA_HOME=${JAVA_LOCATION} +export BAZEL_SH="${MSYS_LOCATION}"/usr/bin/bash.exe + +cd ${MYTFWS_ROOT} +mkdir -p "$TMP" +mv summary.log summary.log.bak +mv test_failures.log test_failures.log.bak +mv test_run.log test_run.log.bak +rm -rf ${MYTFWS_ARTIFACT} +mkdir -p ${MYTFWS_ARTIFACT} + +cd $MYTFWS + +# All commands shall pass +set -e + +# Setting up the environment variables Bazel and ./configure needs +source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \ + || { echo "Failed to source common_env.sh" >&2; exit 1; } + +# load bazel_test_lib.sh +source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \ + || { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; } + +# Recreate an empty bazelrc file under source root +export TMP_BAZELRC=.tmp.bazelrc +rm -f "${TMP_BAZELRC}" +touch "${TMP_BAZELRC}" + +function cleanup { + # Remove all options in .tmp.bazelrc + echo "" > "${TMP_BAZELRC}" +} +trap cleanup EXIT + +# Enable short object file path to avoid long path issues on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi + +run_configure_for_cpu_build + +# Unset so the script continues even if commands fail, needed to correctly process the logs +set +e + +# start the port server before testing so that each invocation of +# portpicker will defer to the single instance of portserver +# Define the batch script content +BATCH_SCRIPT_START=" +@echo off +set SCRIPT_PATH="${PORTSERVER_LOCATION}" +echo Starting the server... +start \"PORTSERVER\" \"%PYTHON_BIN_PATH%\" \"%SCRIPT_PATH%\" +echo Server started. +" +# Save the batch script content to a temporary batch file +BATCH_SCRIPT_FILE="temp_script.bat" +echo "$BATCH_SCRIPT_START" > "$BATCH_SCRIPT_FILE" + +# Run the batch script +cmd.exe /C "$BATCH_SCRIPT_FILE" + +# NUMBER_OF_PROCESSORS is predefined on Windows +N_JOBS="${NUMBER_OF_PROCESSORS}" +bazel --windows_enable_symlinks test \ + --action_env=TEMP=${TMP} --action_env=TMP=${TMP} ${XTF_ARGS} \ + --experimental_cc_shared_library --enable_runfiles --nodistinct_host_configuration \ + --build_tag_filters=-no_pip,-no_windows,-no_oss,-gpu,-tpu \ + --test_tag_filters=-no_windows,-no_oss,-gpu,-tpu \ + --build_tests_only --config=monolithic \ + --dynamic_mode=off --config=xla --config=opt \ + --build_tests_only -k \ + --test_env=PORTSERVER_ADDRESS=@unittest-portserver \ + --repo_env=TF_PYTHON_VERSION=${TF_PYTHON_VERSION} \ + --test_size_filters=small,medium --jobs="${N_JOBS}" --test_timeout=300,450,1200,3600 \ + --flaky_test_attempts=3 --verbose_failures \ + ${POSITIONAL_ARGS[@]} \ + -- ${TEST_TARGET} \ + > run.log 2>&1 + +build_ret_val=$? # Store the ret value + +BATCH_SCRIPT_STOP=" +echo Killing the server... 
+taskkill /FI \"WindowTitle eq PORTSERVER*\" /F /T +echo Server killed. +" +BATCH_SCRIPT_FILEl="temp_script.bat" +echo "$BATCH_SCRIPT_STOP" > "$BATCH_SCRIPT_FILEl" +cmd.exe /C "$BATCH_SCRIPT_FILEl" + +# Removing the temporary batch script +rm -f "$BATCH_SCRIPT_FILE" +rm -f "$BATCH_SCRIPT_FILEl" + +# process results +cd $MYTFWS_ROOT + +# Check to make sure the log was created +[ ! -f "${MYTFWS}"/run.log ] && exit 1 + +# handle logs for unit test +cd ${MYTFWS_ARTIFACT} +cp "${MYTFWS}"/run.log ./test_run.log + +fgrep "FAILED: Build did NOT complete" test_run.log > summary.log +fgrep "Executed" test_run.log >> summary.log + +[ $build_ret_val -eq 0 ] && exit 0 + +echo "FAILED TESTS:" > test_failures.log +fgrep "FAILED" test_run.log | grep " ms)" | sed -e 's/^.*\] //' -e 's/ .*$//' | sort | \ +uniq >> test_failures.log +echo >> test_failures.log +echo "SKIPPED TESTS:" >> test_failures.log +fgrep "SKIPPED" test_run.log | grep -v "listed below:" | sed -e 's/^.*\] //' | sort | \ +uniq >> test_failures.log + +exit 1 \ No newline at end of file From 6ad02267b9d0534f9979b345f9887a11986a5730 Mon Sep 17 00:00:00 2001 From: mraunak <83710963+mraunak@users.noreply.github.com> Date: Sat, 18 Nov 2023 22:43:53 -0800 Subject: [PATCH 006/381] Update cpu_win_test.sh --- tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh index 544b1dc2133717..484680aadb81b0 100644 --- a/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh +++ b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh @@ -253,4 +253,4 @@ echo "SKIPPED TESTS:" >> test_failures.log fgrep "SKIPPED" test_run.log | grep -v "listed below:" | sed -e 's/^.*\] //' | sort | \ uniq >> test_failures.log -exit 1 \ No newline at end of file +exit 1 From 116f328474860517b9856cdb920c05f711221cfc Mon Sep 17 00:00:00 2001 From: Malik Shahzad Muzaffar Date: Wed, 22 Nov 2023 11:03:44 +0100 Subject: [PATCH 007/381] add hh_vsx deps for highwayhash_dynamic --- third_party/highwayhash/highwayhash.BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/highwayhash/highwayhash.BUILD b/third_party/highwayhash/highwayhash.BUILD index 76f0c962ef8b8a..c24c987a276acd 100644 --- a/third_party/highwayhash/highwayhash.BUILD +++ b/third_party/highwayhash/highwayhash.BUILD @@ -286,6 +286,7 @@ cc_library( ":hh_portable", ":hh_types", ] + select({ + ":cpu_ppc": [":hh_vsx"], ":cpu_aarch64": [":hh_neon"], "//conditions:default": [ ":hh_avx2", From ee2a0ce3396b13424b18759f98a00fdafa980016 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Wed, 22 Nov 2023 10:04:46 -0800 Subject: [PATCH 008/381] Migrate references and remove legacy target tpu:tpu_estimator. 
PiperOrigin-RevId: 584658931 --- tensorflow/python/BUILD | 1 - tensorflow/python/tpu/BUILD | 37 ------------------- .../python/tpu/async_checkpoint_test.py | 4 +- tensorflow/python/tpu/error_handling.py | 19 ---------- tensorflow/python/tpu/tpu_config.py | 19 ---------- tensorflow/python/tpu/tpu_context.py | 19 ---------- tensorflow/python/tpu/tpu_estimator.py | 29 --------------- tensorflow/python/tpu/util.py | 19 ---------- 8 files changed, 2 insertions(+), 145 deletions(-) delete mode 100644 tensorflow/python/tpu/error_handling.py delete mode 100644 tensorflow/python/tpu/tpu_config.py delete mode 100644 tensorflow/python/tpu/tpu_context.py delete mode 100644 tensorflow/python/tpu/tpu_estimator.py delete mode 100644 tensorflow/python/tpu/util.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index cd256c59759520..b9b176ebb997f6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -83,7 +83,6 @@ py_strict_library( "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:stateful_random_ops", "//tensorflow/python/ops/structured:structured_ops", - "//tensorflow/python/tpu:tpu_estimator", "//tensorflow/python/tpu:tpu_noestimator", ], ) diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index d3d06e05502fdc..751078f59d56b9 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -99,7 +99,6 @@ tpu_py_strict_test( disable_mlir_bridge = False, deps = [ ":async_checkpoint", - ":tpu_estimator", ":tpu_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/compat:v2_compat", @@ -169,42 +168,6 @@ py_strict_library( ], ) -py_strict_library( - name = "tpu_estimator", - srcs = [ - "error_handling.py", - "tpu_config.py", - "tpu_context.py", - "tpu_estimator.py", - "util.py", - ], - srcs_version = "PY3", - deps = [ - ":async_checkpoint", - ":feature_column", - ":feature_column_v2", - ":functional", - ":preempted_hook_py", - ":tpu_embedding", - ":tpu_lib", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/client:session", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:util", - "//tensorflow/python/framework:for_generated_wrappers", - "//tensorflow/python/framework:function", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:control_flow_ops", - "//tensorflow/python/ops:init_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:state_ops", - "//tensorflow/python/ops:summary_ops_v2", - "//tensorflow/python/ops:variable_scope", - "//tensorflow/python/ops:variables", - "//tensorflow/python/training", - ], -) - py_strict_library( name = "functional", srcs = ["functional.py"], diff --git a/tensorflow/python/tpu/async_checkpoint_test.py b/tensorflow/python/tpu/async_checkpoint_test.py index 070eff0e20c60e..3601c5fad6cc0d 100644 --- a/tensorflow/python/tpu/async_checkpoint_test.py +++ b/tensorflow/python/tpu/async_checkpoint_test.py @@ -33,13 +33,13 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.tpu import async_checkpoint -from tensorflow.python.tpu import tpu_config -from tensorflow.python.tpu import tpu_estimator from tensorflow.python.tpu import tpu_optimizer from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import training from tensorflow_estimator.python.estimator import estimator as estimator_lib from tensorflow_estimator.python.estimator import model_fn as model_fn_lib +from 
tensorflow_estimator.python.estimator.tpu import tpu_config +from tensorflow_estimator.python.estimator.tpu import tpu_estimator FLAGS = flags.FLAGS flags.DEFINE_string('tpu', '', 'TPU to use in this test.') diff --git a/tensorflow/python/tpu/error_handling.py b/tensorflow/python/tpu/error_handling.py deleted file mode 100644 index 1e6660af511bc1..00000000000000 --- a/tensorflow/python/tpu/error_handling.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.error_handling import * -# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/tpu/tpu_config.py b/tensorflow/python/tpu/tpu_config.py deleted file mode 100644 index eda3717520f7a8..00000000000000 --- a/tensorflow/python/tpu/tpu_config.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.tpu_config import * -# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/tpu/tpu_context.py b/tensorflow/python/tpu/tpu_context.py deleted file mode 100644 index d1f3ee55723df3..00000000000000 --- a/tensorflow/python/tpu/tpu_context.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.tpu_context import * -# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/tpu/tpu_estimator.py b/tensorflow/python/tpu/tpu_estimator.py deleted file mode 100644 index f28db848e56252..00000000000000 --- a/tensorflow/python/tpu/tpu_estimator.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import,redefined-builtin -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import * -# used by tests -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _clone_export_output_with_tensors -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _create_global_step -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _export_output_to_tensors -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _get_scaffold -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _Inputs -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ENQUEUE_OPS -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ESTIMATOR -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_TRAIN_OP -# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/python/tpu/util.py b/tensorflow/python/tpu/util.py deleted file mode 100644 index c5b8964b20a6e2..00000000000000 --- a/tensorflow/python/tpu/util.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.util import * -# pylint: enable=wildcard-import,unused-import From 11a8573128bea46f58d6ff48fdb665387069fd23 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Wed, 22 Nov 2023 11:00:30 -0800 Subject: [PATCH 009/381] Internal change - attach buildcleaner tags. PiperOrigin-RevId: 584673172 --- tensorflow/compiler/mlir/lite/BUILD | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 8117705b0fac2b..ce124ea8e56f38 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1314,18 +1314,18 @@ cc_library( ":common", ":fake_quant_utils", ":tensorflow_lite_d2s", - ":tensorflow_lite_legalize_tf", - ":tensorflow_lite_optimize", - ":tensorflow_lite_optimize_batch_matmul", - ":tensorflow_lite_quantize", + ":tensorflow_lite_legalize_tf", # buildcleaner: keep + ":tensorflow_lite_optimize", # buildcleaner: keep + ":tensorflow_lite_optimize_batch_matmul", # buildcleaner: keep + ":tensorflow_lite_quantize", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", - "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", - "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", + "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep + "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo:uniform_quantized_stablehlo_to_tfl_pass", "//tensorflow/compiler/mlir/tensorflow", From 9f32f00a944ccd9ac56d36ef91ec82adb6685b09 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Wed, 22 Nov 2023 11:19:38 -0800 Subject: [PATCH 010/381] Add DCN collective stat information as stat metadata to XPlane PiperOrigin-RevId: 584678221 --- tensorflow/core/profiler/protobuf/BUILD | 7 +++ .../protobuf/dcn_collective_info.proto | 55 +++++++++++++++++++ .../tsl/tsl/profiler/utils/xplane_schema.cc | 1 + .../tsl/tsl/profiler/utils/xplane_schema.h | 1 + 4 files changed, 64 insertions(+) create mode 100644 tensorflow/core/profiler/protobuf/dcn_collective_info.proto diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 0e0ffa72b6ce71..c521c9ad89d392 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -270,3 +270,10 @@ tf_proto_library( cc_api_version = 2, visibility = [":friends"], ) + +tf_proto_library( + name = "dcn_collective_info_proto", + srcs = ["dcn_collective_info.proto"], + cc_api_version = 2, + visibility = [":friends"], +) diff --git a/tensorflow/core/profiler/protobuf/dcn_collective_info.proto b/tensorflow/core/profiler/protobuf/dcn_collective_info.proto new file mode 100644 index 00000000000000..5359a3dd54c1c6 --- /dev/null +++ b/tensorflow/core/profiler/protobuf/dcn_collective_info.proto @@ 
-0,0 +1,55 @@ +syntax = "proto3"; + +package tensorflow.profiler; + +// This proto is based on MegaScaleInfoProto and should be consistent with it. +message DcnCollectiveInfoProto { + enum TransferType { + UNKNOWN_TRANSFER_TYPE = 0; + + // XLA AllToAll transfer. + // Needs `endpoint_groups`. + ALL_TO_ALL = 1; + + // Peer-To-Peer DCN transfer from source to one destination. + // Needs one_to_one_groups. + ONE_TO_ONE = 2; + + // XLA reduce-scatter transfer. + // Needs `endpoint_groups`. + REDUCE_SCATTER = 3; + + // XLA AllGather transfer. + // Needs `endpoint_groups`. + ALL_GATHER = 4; + + // XLA all-reduce transfer. + // Needs `endpoint_groups`. + ALL_REDUCE = 5; + } + + message Endpoint { + int32 slice_id = 1; + int32 device_id = 2; + } + + message EndpointGroup { + repeated Endpoint endpoints = 1; + } + + message OneToOneGroup { + Endpoint source = 1; + Endpoint destination = 2; + } + + // The type of DCN transfer. + TransferType transfer_type = 1; + + // Groups of endpoints (in the form of slice id and device id) involved in + // `ALL_TO_ALL`, `REDUCE_SCATTER`, `ALL_REDUCE` and `ALL_GATHER` transfer. + repeated EndpointGroup endpoint_groups = 2; + + // Groups of endpoints (in the form of slice id and device id) involved in + // `ONE_TO_ONE` transfer. + repeated OneToOneGroup one_to_one_groups = 3; +} diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 2f7eb630aa324a..62b69f2910b334 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -272,6 +272,7 @@ const StatTypeMap& GetStatTypeMap() { {"model_version", kModelVersion}, {"bytes_transferred", kBytesTransferred}, {"queue", kDmaQueue}, + {"dcn_collective_info", kDcnCollectiveInfo}, // Performance counter related. {"Raw Value", kRawValue}, {"Scaled Value", kScaledValue}, diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 8fa320791f0ee5..7bbd052f815eb9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -259,6 +259,7 @@ enum StatType { kModelVersion, kBytesTransferred, kDmaQueue, + kDcnCollectiveInfo, // Performance counter related. kRawValue, kScaledValue, From a5614421de9f91c4b82dbcfb79c98e668af52407 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 22 Nov 2023 11:25:09 -0800 Subject: [PATCH 011/381] Disable copy/move constructors. PiperOrigin-RevId: 584679465 --- tensorflow/lite/simple_memory_arena.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index bf3efd03d518ed..c72b3595919a88 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -86,6 +86,11 @@ class ResizableAlignedBuffer { size_t GetAlignment() const { return alignment_; } private: + ResizableAlignedBuffer(const ResizableAlignedBuffer&) = delete; + ResizableAlignedBuffer& operator=(const ResizableAlignedBuffer&) = delete; + ResizableAlignedBuffer(ResizableAlignedBuffer&&) = delete; + ResizableAlignedBuffer& operator=(ResizableAlignedBuffer&&) = delete; + size_t RequiredAllocationSize(size_t data_array_size) const { return data_array_size + alignment_ - 1; } From 2d963d224d460759c411d3b5abaff061d4f97c6a Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 22 Nov 2023 11:28:31 -0800 Subject: [PATCH 012/381] Disable some layering checks to fix broken TAP test. PiperOrigin-RevId: 584680176 --- tensorflow/compiler/tf2tensorrt/BUILD | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 0571309466823e..b91fb494667c5f 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -3,9 +3,9 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. -# Placeholder: load py_proto_library load("//tensorflow:strict.default.bzl", "py_strict_library") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +# Placeholder: load py_proto_library load( "//tensorflow:tensorflow.bzl", "VERSION", @@ -21,17 +21,18 @@ load( "tf_additional_all_protos", "tf_proto_library", ) -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "cuda_rpath_flags", -) -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") # Platform specific build config load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -421,6 +422,7 @@ tf_cuda_library( "utils/trt_execution_context.h", "utils/trt_shape_optimization_profiles.h", ], + features = ["-layering_check"], deps = [ ":common_utils", ":trt_allocator", @@ -441,6 +443,7 @@ tf_cuda_library( name = "trt_logging", srcs = ["utils/trt_logger.cc"], hdrs = ["utils/trt_logger.h"], + features = ["-layering_check"], visibility = ["//visibility:public"], deps = [ ":common_utils", @@ -525,6 +528,7 @@ tf_cuda_library( name = "trt_allocator", srcs = ["utils/trt_allocator.cc"], hdrs = ["utils/trt_allocator.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", @@ -588,6 +592,7 @@ tf_cuda_library( "convert/logger_registry.h", ], copts = tf_copts(), + features = ["-layering_check"], deps = [ "//tensorflow/core:lib", "@com_google_absl//absl/strings", From 193accc13e133fc50a5a558a8f0bdbe98c7d1884 Mon Sep 17 00:00:00 2001 From: mraunak <83710963+mraunak@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:42:38 -0800 Subject: [PATCH 013/381] Added comments --- tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh index 484680aadb81b0..25a2f1d4cb44f7 100644 --- a/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh +++ b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh @@ -14,7 +14,8 @@ # limitations under the License. # ============================================================================== -# This script is a CI script for invoking 'bazel test ... ...' +# This script is a CI script maintained by Intel and is used to launch the nightly CI test +# build on the Windows platform. # It assumes the standard setup on tensorflow Jenkins Windows machines. # Update the flags/variables below to make it work on your local system. From ca87bdeaf2e732500276b0a5a1bec1c5db0073fd Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 22 Nov 2023 12:35:43 -0800 Subject: [PATCH 014/381] Only do a roll if the shift is nonzero. PiperOrigin-RevId: 584695897 --- tensorflow/python/tpu/tpu_embedding_v3_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/tpu/tpu_embedding_v3_utils.py b/tensorflow/python/tpu/tpu_embedding_v3_utils.py index ed30d9947c8842..08ed796c54b492 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v3_utils.py @@ -73,7 +73,8 @@ def unshuffle_from_sc_to_cpu( shards = shards_t[:, offset_in_shard : offset_in_shard + size_in_shard, :] # This table's shards were rotated by `shard_rotation`, so we need to rotate # the same amount in opposite direction - shards = manip_ops.roll(shards, -shard_rotation, axis=0) + if shard_rotation: + shards = manip_ops.roll(shards, -shard_rotation, axis=0) # Re-arrange (transpose and reshape) the shards to get the queried embedding # table. intermediate_tensor = array_ops.transpose(shards, (1, 0, 2)) From 0e447935cfc9025338bff139b36fef62a9c93048 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 22 Nov 2023 12:59:55 -0800 Subject: [PATCH 015/381] Use malloc/free instead of new/delete. PiperOrigin-RevId: 584701062 --- tensorflow/lite/simple_memory_arena.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index eb97c777931439..9a849e11a3d08a 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -54,7 +54,7 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), new_allocation_size); #endif - char* new_buffer = new char[new_allocation_size]; + char* new_buffer = reinterpret_cast(std::malloc(new_allocation_size)); #if defined(__clang__) #if __has_feature(memory_sanitizer) memset(new_buffer, 0, new_allocation_size); @@ -67,7 +67,7 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { const size_t copy_amount = std::min(new_size, data_size_); std::memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); } - delete[] buffer_; + std::free(buffer_); buffer_ = new_buffer; aligned_ptr_ = new_aligned_ptr; #ifdef TF_LITE_TENSORFLOW_PROFILER @@ -92,7 +92,7 @@ void ResizableAlignedBuffer::Release() { OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), RequiredAllocationSize(data_size_)); #endif - delete[] buffer_; + std::free(buffer_); buffer_ = nullptr; data_size_ = 0; aligned_ptr_ = nullptr; From 245b0f6bdba2dc3eeba534ec8a17e2bb9611e855 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 22 Nov 2023 13:01:21 -0800 Subject: [PATCH 016/381] Disable `//tests:lax_test_gpu` in JAX build script because it is currently broken PiperOrigin-RevId: 584701413 --- third_party/xla/.kokoro/jax/build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index 18a4c388ac4b37..ed95d1413c7f0b 100644 --- a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -89,11 +89,11 @@ build_and_test_on_rbe_gpu() { --test_output=errors \ --test_env=JAX_SKIP_SLOW_TESTS=1 \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ - --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow|LaxTest.testBitcastConvertType1" \ + --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow" \ --test_tag_filters=-multiaccelerator \ 
--remote_instance_name=projects/tensorflow-testing/instances/default_instance \ --bes_instance_name="tensorflow-testing" \ - -- //tests:gpu_tests //tests:backend_independent_tests + -- //tests:gpu_tests //tests:backend_independent_tests -//tests:lax_test_gpu } # Generate a templated results file to make output accessible to everyone From 4b4e10413f1f0caecf05ae8ac622a3ead4d73406 Mon Sep 17 00:00:00 2001 From: Robert David Date: Wed, 22 Nov 2023 13:30:12 -0800 Subject: [PATCH 017/381] Checking if offsets are under the allocation size allows for a buffer overflow, as "zero" offset is already an up-aligned index in the buffer. Use the data size instead. PiperOrigin-RevId: 584708037 --- tensorflow/lite/simple_memory_arena.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 9a849e11a3d08a..a61f68c7cdb97e 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -209,8 +209,8 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - TF_LITE_ENSURE(context, underlying_buffer_.GetAllocationSize() >= - (alloc.offset + alloc.size)); + TF_LITE_ENSURE( + context, underlying_buffer_.GetDataSize() >= (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { From 8f5ad8699457782c9e0a2295df9e6dbdd756e587 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 22 Nov 2023 13:32:14 -0800 Subject: [PATCH 018/381] [xla:gpu] Track command buffer state explicitly for each executor PiperOrigin-RevId: 584708400 --- .../xla/xla/service/gpu/runtime3/BUILD | 3 +- .../gpu/runtime3/command_buffer_cmd.cc | 30 +--------- .../service/gpu/runtime3/command_buffer_cmd.h | 21 +++---- .../gpu/runtime3/command_buffer_thunk.cc | 55 +++++++++++++++---- .../gpu/runtime3/command_buffer_thunk.h | 45 +++++++++++---- 5 files changed, 92 insertions(+), 62 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index c5d7e682ba52ff..28cadbf3d4f503 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -34,6 +34,7 @@ cc_library( "//xla/service/gpu:thunk", "//xla/stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -106,7 +107,7 @@ cc_library( "//xla/service/gpu:thunk", "//xla/stream_executor", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 4e1dbfbcc24402..55649a4aa9a6fd 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include #include #include #include @@ -46,6 +45,9 @@ namespace xla::gpu { //===----------------------------------------------------------------------===// void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { + for (BufferAllocation::Slice& slice : cmd->slices()) { + allocs_indices_.insert(slice.index()); + } commands_.push_back(std::move(cmd)); } @@ -63,38 +65,12 @@ Status CommandBufferCmdSequence::Record( if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { TF_RETURN_IF_ERROR(command_buffer->Update()); } - // Returns if no cmd requires update. - if (!ShouldUpdateCmd(params)) { - return OkStatus(); - } for (auto& cmd : commands_) { TF_RETURN_IF_ERROR(cmd->Record(params, command_buffer)); } return command_buffer->Finalize(); } -bool CommandBufferCmdSequence::ShouldUpdateCmd( - const CommandBufferCmd::RecordParams& params) { - bool should_update = false; - const BufferAllocations* allocs = params.buffer_allocations; - size_t size = allocs->size(); - if (prev_allocs_.size() < size) { - prev_allocs_.resize(size); - should_update = true; - } - // Traversing all allocations from `params` using the index alone (no need for - // offset) is enough because every time `BufferAllocation` remapped to a new - // physical memory location all commands reading from any slice from that - // allocation must be invalidated. - for (unsigned i = 0; i < size; ++i) { - se::DeviceMemoryBase new_alloc = allocs->GetDeviceAddress(i); - se::DeviceMemoryBase& prev_alloc = prev_allocs_[i]; - should_update |= !new_alloc.IsSameAs(prev_alloc); - prev_alloc = new_alloc; - } - return should_update; -} - //===----------------------------------------------------------------------===// // LaunchCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index a80150bfd31589..d1e06d054e5af9 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" @@ -30,6 +32,7 @@ limitations under the License. #include "xla/service/gpu/thunk.h" #include "xla/status.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_pimpl.h" @@ -100,18 +103,16 @@ class CommandBufferCmdSequence { Status Record(const CommandBufferCmd::RecordParams& params, se::CommandBuffer* command_buffer); - private: - // Traverse the list of commands and figures out if any of them requires an - // update. Also updates `prev_allocs_` with new allocations from `params`. - bool ShouldUpdateCmd(const CommandBufferCmd::RecordParams& params); + // Returns buffer allocations indices referenced by commands in this sequence. 
+ const absl::flat_hash_set& allocs_indices() const { + return allocs_indices_; + } + private: std::vector> commands_; - // Mapping from buffer slice index to device memory passed at that index via - // the `CommandBufferCmd::RecordParams` in previous invocation of `Record`. - // We can just use a vector instead of map because `BufferAllocation` has a - // unique identifier assigned contiguously and thus can be used as array - // index. - std::vector prev_allocs_; + + // Buffer allocations indices referenced by commands in this sequence. + absl::flat_hash_set allocs_indices_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc index a845dce067a5ef..bc7a25dd3824bd 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc @@ -15,22 +15,25 @@ limitations under the License. #include "xla/service/gpu/runtime3/command_buffer_thunk.h" -#include #include #include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/runtime3/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" #include "xla/statusor.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor_pimpl.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla::gpu { -CommandBufferThunk::State::State(se::CommandBuffer command_buffer) +CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( + se::CommandBuffer command_buffer) : command_buffer(std::move(command_buffer)) {} CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, @@ -43,33 +46,61 @@ Status CommandBufferThunk::Initialize(se::StreamExecutor* executor, return commands_.Initialize(executor, executable_source); } +bool CommandBufferThunk::ExecutorCommandBuffer::ShouldUpdateCommandBuffer( + const CommandBufferCmdSequence& commands, + const CommandBufferCmd::RecordParams& params) { + bool should_update = false; + const BufferAllocations* allocs = params.buffer_allocations; + + // We check only allocations referenced by commands in a cmd sequence, and + // leave every other entry default initialized (nullptr device memory). 
+ for (BufferAllocation::Index index : commands.allocs_indices()) { + se::DeviceMemoryBase alloc = allocs->GetDeviceAddress(index); + + if (recorded_allocs.size() <= index) { + recorded_allocs.resize(index + 1); + } + + if (!recorded_allocs[index].IsSameAs(alloc)) { + recorded_allocs[index] = alloc; + should_update = true; + } + } + + return should_update; +} + Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { se::StreamExecutor* executor = params.stream->parent(); - TF_ASSIGN_OR_RETURN(State * state, GetOrCreateCommandBuffer(executor)); - - absl::MutexLock lock(&state->mutex); + TF_ASSIGN_OR_RETURN(ExecutorCommandBuffer * cmd_buffer, + GetOrCreateCommandBuffer(executor)); CommandBufferCmd::RecordParams record_params = {params.buffer_allocations}; - TF_RETURN_IF_ERROR(commands_.Record(record_params, &state->command_buffer)); - return executor->Submit(params.stream, state->command_buffer); + absl::MutexLock lock(&cmd_buffer->mutex); + + if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, record_params)) { + TF_RETURN_IF_ERROR( + commands_.Record(record_params, &cmd_buffer->command_buffer)); + } + + return executor->Submit(params.stream, cmd_buffer->command_buffer); } -StatusOr +StatusOr CommandBufferThunk::GetOrCreateCommandBuffer(se::StreamExecutor* executor) { absl::MutexLock lock(&mutex_); // Check if command buffer already exists if (auto it = command_buffers_.find(executor); it != command_buffers_.end()) { - return it->second.get(); + return &it->second; } // Create a new empty command buffer. TF_ASSIGN_OR_RETURN(auto command_buffer, se::CommandBuffer::Create(executor)); - auto emplaced = command_buffers_.emplace( - executor, std::make_unique(std::move(command_buffer))); + auto emplaced = command_buffers_.emplace(executor, std::move(command_buffer)); - return emplaced.first->second.get(); + return &emplaced.first->second; } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h index a66f47f1c65863..dada9c27fbb2ec 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h @@ -16,17 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ #define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ -#include +#include #include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" #include "absl/synchronization/mutex.h" #include "xla/service/gpu/runtime3/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" #include "xla/statusor.h" #include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" namespace xla::gpu { @@ -39,26 +40,46 @@ class CommandBufferThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - // se::CommandBuffer is not thread safe, and we guard it with a mutex to - // guarantee that we do not mutate it concurrently. - struct State { - explicit State(se::CommandBuffer command_buffer); - + // Command buffer instantiated on a `se::StreamExecutor` instance, and + // auxiliary state required for efficient command buffer updates. 
+ struct ExecutorCommandBuffer { + explicit ExecutorCommandBuffer(se::CommandBuffer command_buffer); + + // Returns true if `commands` cmd sequence has to be recorded into + // `command_buffer` to update it (see `recorded_allocs` below). + bool ShouldUpdateCommandBuffer(const CommandBufferCmdSequence& commands, + const CommandBufferCmd::RecordParams& params) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex); + + // se::CommandBuffer is not thread safe, and we guard it with a mutex to + // guarantee that we do not mutate it concurrently. absl::Mutex mutex; se::CommandBuffer command_buffer ABSL_GUARDED_BY(mutex); + + // Mapping from buffer allocation index to the device memory passed at that + // index to the last call of `commands_.Record(...)` for `command_buffer`. + // We can just use a vector instead of map because `BufferAllocation::Index` + // is a unique identifier assigned contiguously and thus can be used as + // array index. + // + // If no device memory addresses changed from a previous call to `Record`, + // we can skip command buffer update and simply submit it for execution on a + // stream. All other pieces of information (like thread and block sizes) + // captured by commands at construction time and do not change. + std::vector recorded_allocs ABSL_GUARDED_BY(mutex); }; - using OwnedCommandBuffer = std::unique_ptr; // Returns a command buffer instantiated for `executor` or creates new one. - StatusOr GetOrCreateCommandBuffer(se::StreamExecutor* executor); + StatusOr GetOrCreateCommandBuffer( + se::StreamExecutor* executor); // Command sequence that initializes command buffers on each executor. CommandBufferCmdSequence commands_; // Command buffer sequence instantiates command buffers on all executors. absl::Mutex mutex_; - absl::flat_hash_map command_buffers_ - ABSL_GUARDED_BY(mutex_); + absl::node_hash_map + command_buffers_ ABSL_GUARDED_BY(mutex_); }; } // namespace xla::gpu From 1f40e95809bbc3d01cfb06431ac96b819df4ee12 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 22 Nov 2023 14:22:23 -0800 Subject: [PATCH 019/381] [xla:gpu] Do not keep StreamExecutor pointer in CommandBuffer Nothing guarantees that se::StreamExecutor* will point to alive memory address as StreamExecutor is movable type. Always explicitly pass stream executor in arguments. 
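Not from the patch itself: a minimal C++ sketch, with invented names (FakeExecutor, PointerCachingBuffer, ExplicitArgBuffer) rather than the real StreamExecutor/CommandBuffer classes, of the hazard this message describes. Caching a raw pointer to an object that may later be moved or relocated can leave the pointer dangling; taking the executor as an explicit argument on every call, as this change does for CommandBuffer::If and the command recording path, keeps the lifetime responsibility with the caller.

    #include <vector>

    // Invented stand-in for a movable executor-like type; the real
    // se::StreamExecutor is only analogous.
    struct FakeExecutor {
      int device_ordinal;
    };

    // Caches a raw pointer at construction time, similar in spirit to the
    // removed CommandBuffer::executor_ field.
    class PointerCachingBuffer {
     public:
      explicit PointerCachingBuffer(FakeExecutor* executor)
          : executor_(executor) {}
      // Reads through the cached pointer; dangles if the executor relocated.
      int device_ordinal() const { return executor_->device_ordinal; }

     private:
      FakeExecutor* executor_;
    };

    // Takes the executor explicitly on every call, so nothing can go stale.
    class ExplicitArgBuffer {
     public:
      int device_ordinal(const FakeExecutor& executor) const {
        return executor.device_ordinal;
      }
    };

    int main() {
      std::vector<FakeExecutor> executors;
      executors.push_back(FakeExecutor{0});

      PointerCachingBuffer cached(&executors[0]);
      executors.push_back(FakeExecutor{1});  // may reallocate; element 0 moves

      // cached.device_ordinal() would now read through a possibly dangling
      // pointer (undefined behavior if the vector reallocated).

      ExplicitArgBuffer safe;
      return safe.device_ordinal(executors[0]);  // always uses a live object
    }
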
PiperOrigin-RevId: 584719063 --- .../xla/service/gpu/runtime3/command_buffer_cmd.cc | 13 +++++++------ .../xla/service/gpu/runtime3/command_buffer_cmd.h | 1 + .../service/gpu/runtime3/command_buffer_cmd_test.cc | 4 ++-- .../service/gpu/runtime3/command_buffer_thunk.cc | 3 ++- .../xla/xla/stream_executor/command_buffer.cc | 13 ++++++------- .../xla/xla/stream_executor/command_buffer.h | 10 +++------- .../cuda/cuda_command_buffer_test.cc | 4 ++-- .../xla/stream_executor/gpu/gpu_command_buffer.cc | 6 +++--- 8 files changed, 26 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 55649a4aa9a6fd..bfd1450f0d5dc5 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -100,7 +100,7 @@ Status LaunchCmd::Record(const RecordParams& params, VLOG(5) << "LaunchCmd: kernel=" << kernel_name_ << ", shmem_bytes=" << shmem_bytes_; - se::Kernel* kernel = kernels_[command_buffer->executor()].get(); + se::Kernel* kernel = kernels_[params.executor].get(); if (kernel == nullptr) { return absl::InternalError( "Kernel not loaded on a command buffer executor"); @@ -187,13 +187,14 @@ Status GemmCmd::Record(const RecordParams& params, params.buffer_allocations->GetDeviceAddress(rhs_buffer_); se::DeviceMemoryBase out = params.buffer_allocations->GetDeviceAddress(output_buffer_); + TF_ASSIGN_OR_RETURN( auto nested_buffer, - stream_executor::CommandBuffer::Trace( - command_buffer->executor(), [&](stream_executor::Stream* stream) { - return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, - stream); - })); + se::CommandBuffer::Trace(params.executor, [&](se::Stream* stream) { + return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, + stream); + })); + return command_buffer->AddNestedCommandBuffer(nested_buffer); } diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index d1e06d054e5af9..82658a38a4a296 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -55,6 +55,7 @@ class CommandBufferCmd { // module, we only know the buffer slices required for HLO operations, but the // concrete device pointers become available only at run time. struct RecordParams { + se::StreamExecutor* executor; const BufferAllocations* buffer_allocations; }; diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc index 0d2ebb60fcb40b..61474a5ba05e62 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_test.cc @@ -68,7 +68,7 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); auto command_buffer = se::CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(commands.Record({&allocations}, &command_buffer)); + TF_ASSERT_OK(commands.Record({executor, &allocations}, &command_buffer)); // Execute command buffer and verify that it copied the memory. 
TF_ASSERT_OK(executor->Submit(&stream, command_buffer)); @@ -119,7 +119,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); auto command_buffer = se::CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(commands.Record({&allocations}, &command_buffer)); + TF_ASSERT_OK(commands.Record({executor, &allocations}, &command_buffer)); // Execute command buffer and verify that it copied the memory. TF_ASSERT_OK(executor->Submit(&stream, command_buffer)); diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc index bc7a25dd3824bd..08d5cdb9ca74d4 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc @@ -75,7 +75,8 @@ Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(ExecutorCommandBuffer * cmd_buffer, GetOrCreateCommandBuffer(executor)); - CommandBufferCmd::RecordParams record_params = {params.buffer_allocations}; + CommandBufferCmd::RecordParams record_params = {executor, + params.buffer_allocations}; absl::MutexLock lock(&cmd_buffer->mutex); diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index f04889419bd350..82861a15c5fd0b 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -42,7 +42,7 @@ CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default; std::unique_ptr command_buffer, executor->implementation()->GetCommandBufferImplementation(mode)); - CommandBuffer cmd(executor, std::move(command_buffer)); + CommandBuffer cmd(std::move(command_buffer)); return cmd; } @@ -80,15 +80,13 @@ internal::CommandBufferInterface* CommandBuffer::implementation() { } /*static*/ CommandBuffer CommandBuffer::Wrap( - StreamExecutor* executor, std::unique_ptr implementation) { - return CommandBuffer(executor, std::move(implementation)); + return CommandBuffer(std::move(implementation)); } CommandBuffer::CommandBuffer( - StreamExecutor* executor, std::unique_ptr implementation) - : executor_(executor), implementation_(std::move(implementation)) {} + : implementation_(std::move(implementation)) {} tsl::Status CommandBuffer::Launch(const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, @@ -106,8 +104,9 @@ tsl::Status CommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return implementation_->MemcpyDeviceToDevice(dst, src, size); } -tsl::Status CommandBuffer::If(DeviceMemory pred, Builder then_builder) { - return implementation_->If(executor_, pred, std::move(then_builder)); +tsl::Status CommandBuffer::If(StreamExecutor* executor, DeviceMemory pred, + Builder then_builder) { + return implementation_->If(executor, pred, std::move(then_builder)); } CommandBuffer::Mode CommandBuffer::mode() const { diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 5a28b409646395..ecbed0e872ba0f 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -117,7 +117,8 @@ class CommandBuffer { // by `then_builder` if predicate is true. Builder should not call `Update` or // `Finalize` on command buffer argument, parent command buffer is responsible // for updating and finalizing conditional command buffers. 
- tsl::Status If(DeviceMemory pred, Builder then_builder); + tsl::Status If(StreamExecutor* executor, DeviceMemory pred, + Builder then_builder); // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. @@ -140,8 +141,6 @@ class CommandBuffer { // Returns command buffer state. State state() const; - StreamExecutor* executor() const { return executor_; } - //===--------------------------------------------------------------------===// // Semi-internal APIs //===--------------------------------------------------------------------===// @@ -154,15 +153,12 @@ class CommandBuffer { // Wraps platform-specific command buffer implementation into a top-level // StreamExecutor command buffer. static CommandBuffer Wrap( - StreamExecutor* executor, std::unique_ptr implementation); private: - CommandBuffer( - StreamExecutor* executor, + explicit CommandBuffer( std::unique_ptr implementation); - StreamExecutor* executor_; std::unique_ptr implementation_; CommandBuffer(const CommandBuffer&) = delete; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 585df1fc8f6b1b..5c4faa4824201b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -263,7 +263,7 @@ TEST(CudaCommandBufferTest, ConditionalIf) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.If(pred, then_builder)); + TF_ASSERT_OK(cmd_buffer.If(executor, pred, then_builder)); TF_ASSERT_OK(cmd_buffer.Finalize()); TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); @@ -302,7 +302,7 @@ TEST(CudaCommandBufferTest, ConditionalIf) { // Update command buffer with a conditional to use new builder. TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.If(pred, then_builder)); + TF_ASSERT_OK(cmd_buffer.If(executor, pred, then_builder)); TF_ASSERT_OK(cmd_buffer.Finalize()); TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index bba03c350e783d..a9d0e7fc907b17 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -313,9 +313,9 @@ tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, GpuGraphHandle then_graph = std::get(result).graph; // Wrap conditional graph into command buffer and pass it to the builder. 
- auto then_cmd_buffer = CommandBuffer::Wrap( - executor, parent_->GetCommandBufferImplementation( - nested, then_graph, /*is_owned_graph=*/false)); + auto then_cmd_buffer = + CommandBuffer::Wrap(parent_->GetCommandBufferImplementation( + nested, then_graph, /*is_owned_graph=*/false)); TF_RETURN_IF_ERROR(then_builder(&then_cmd_buffer)); TF_RETURN_IF_ERROR(then_cmd_buffer.Finalize()); From 74bb48736e9a55225d1f7f2a2a2189c09b4369b6 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Wed, 22 Nov 2023 14:22:28 -0800 Subject: [PATCH 020/381] Integrate StableHLO at openxla/stablehlo@e3276cd PiperOrigin-RevId: 584719081 --- third_party/stablehlo/workspace.bzl | 4 ++-- third_party/xla/third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 857daf51d1ef98..80ab0e479b0ca8 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "2e78612f34a7e43cb94169e446204913bf353457" - STABLEHLO_SHA256 = "3ffe41a3c8683878bf088ef1caf36cd36e8ff8d6b714cf0e2e08b2e9a2a89471" + STABLEHLO_COMMIT = "e3276cd896751bfbebd7b18112f81547fbc2bc9c" + STABLEHLO_SHA256 = "948d0265a9ea4214ecfa854564793b08161d693e9cfbd73686f2df9e38034ada" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 857daf51d1ef98..80ab0e479b0ca8 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "2e78612f34a7e43cb94169e446204913bf353457" - STABLEHLO_SHA256 = "3ffe41a3c8683878bf088ef1caf36cd36e8ff8d6b714cf0e2e08b2e9a2a89471" + STABLEHLO_COMMIT = "e3276cd896751bfbebd7b18112f81547fbc2bc9c" + STABLEHLO_SHA256 = "948d0265a9ea4214ecfa854564793b08161d693e9cfbd73686f2df9e38034ada" # LINT.ThenChange(Google-internal path) tf_http_archive( From 81b244f2d54494cfff14f28f2efe95d9b613fb4d Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 22 Nov 2023 14:22:37 -0800 Subject: [PATCH 021/381] Use appropriate RBE config in JAX GPU build script PiperOrigin-RevId: 584719118 --- third_party/xla/.kokoro/jax/build.sh | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index ed95d1413c7f0b..e6a00626fcc43e 100644 --- a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -74,9 +74,6 @@ build_and_test_on_rbe_gpu() { # Runs non-multiaccelerator tests with one GPU apiece. # It appears --run_under needs an absolute path. - # we need to add `--remote_instance_name` and `--bes_instance_name`. Why this - # is only needed for gpu is still a mystery. 
- # TODO(ddunleavy): reenable `LaxTest.testBitcastConvertType1` bazel \ test \ @@ -85,14 +82,13 @@ build_and_test_on_rbe_gpu() { --config=avx_posix \ --config=mkl_open_source_only \ --config="rbe_linux_cuda12.2_nvcc_py3.9" \ + --config=tensorflow_testing_rbe_linux \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ --test_env=JAX_SKIP_SLOW_TESTS=1 \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow" \ --test_tag_filters=-multiaccelerator \ - --remote_instance_name=projects/tensorflow-testing/instances/default_instance \ - --bes_instance_name="tensorflow-testing" \ -- //tests:gpu_tests //tests:backend_independent_tests -//tests:lax_test_gpu } From 0ea2ee3ebe16f938ea2ee43d1da3504f2194ec39 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 22 Nov 2023 15:48:59 -0800 Subject: [PATCH 022/381] Reenable `//tests:lax_test_gpu` in JAX build script, disable specific failing test via `JAX_EXCLUDE_TEST_TARGETS` PiperOrigin-RevId: 584736891 --- third_party/xla/.kokoro/jax/build.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index e6a00626fcc43e..5affd26e0f74bd 100644 --- a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -74,7 +74,7 @@ build_and_test_on_rbe_gpu() { # Runs non-multiaccelerator tests with one GPU apiece. # It appears --run_under needs an absolute path. - # TODO(ddunleavy): reenable `LaxTest.testBitcastConvertType1` + # TODO(ddunleavy): reenable `LaxTest.testBitcastConvertType` bazel \ test \ --verbose_failures=true \ @@ -87,9 +87,9 @@ build_and_test_on_rbe_gpu() { --test_output=errors \ --test_env=JAX_SKIP_SLOW_TESTS=1 \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ - --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow" \ + --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow|LaxTest.testBitcastConvertType" \ --test_tag_filters=-multiaccelerator \ - -- //tests:gpu_tests //tests:backend_independent_tests -//tests:lax_test_gpu + -- //tests:gpu_tests //tests:backend_independent_tests } # Generate a templated results file to make output accessible to everyone From 57e873eb0d70e8edcda8f95d99d416ae1ce4efb7 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Wed, 22 Nov 2023 16:29:22 -0800 Subject: [PATCH 023/381] [xla:gpu] Add AotCompilationResult for thunk runtime PiperOrigin-RevId: 584744534 --- third_party/xla/xla/service/compiler.h | 2 +- .../xla/xla/service/cpu/cpu_compiler.cc | 2 +- .../xla/xla/service/cpu/cpu_compiler.h | 2 +- third_party/xla/xla/service/gpu/BUILD | 5 + .../service/gpu/compile_module_to_llvm_ir.cc | 4 +- .../service/gpu/compile_module_to_llvm_ir.h | 5 + .../xla/xla/service/gpu/gpu_compiler.cc | 114 +++++++++++++++++- 7 files changed, 127 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/compiler.h b/third_party/xla/xla/service/compiler.h index 036afa1e045d10..f6cc2d222e9567 100644 --- a/third_party/xla/xla/service/compiler.h +++ b/third_party/xla/xla/service/compiler.h @@ -68,7 +68,7 @@ class AotCompilationResult { } virtual StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) const { + Compiler* compiler, se::StreamExecutor* executor) { return Unimplemented("LoadExecutable unimplemented."); } diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 7b51d877fc4b2f..383437acf4eeef 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ 
b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -435,7 +435,7 @@ runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( StatusOr> CpuXlaRuntimeAotCompilationResult::LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) const { + Compiler* compiler, se::StreamExecutor* executor) { XlaRuntimeExecutableProto xla_runtime_executable = xla_runtime_cpu_executable_.xla_runtime_executable(); TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.h b/third_party/xla/xla/service/cpu/cpu_compiler.h index a6bce23d574b2f..0d7ce6eb53862c 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.h +++ b/third_party/xla/xla/service/cpu/cpu_compiler.h @@ -122,7 +122,7 @@ class CpuXlaRuntimeAotCompilationResult : public AotCompilationResult { } StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) const override; + Compiler* compiler, se::StreamExecutor* executor) override; private: XlaRuntimeCpuExecutableProto xla_runtime_cpu_executable_; diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 73f77151d1e6f7..c6615c857e1dca 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2622,6 +2622,7 @@ cc_library( ":ir_emitter_unnested", ":metrics", ":runtime_intrinsics", + ":thunk", "//xla:shape_util", "//xla:status", "//xla:status_macros", @@ -3007,7 +3008,11 @@ cc_library( ]) + xla_export_hlo_deps() + [ ":command_buffer_scheduling", ":fusion_pipeline", + ":ir_emitter_context", + ":ir_emitter_unnested", ":prepare_hlo_for_ir_emitting_pipeline", + ":thunk", + "@llvm-project//mlir:FuncDialect", "@local_tsl//tsl/lib/monitoring:counter", ], ) diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index f911846b85c324..2a64ef43f41d57 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -177,6 +177,8 @@ static Status LowerToXlaGpuRuntime( return OkStatus(); } +} // namespace + void ForAllThunks(const std::function& fn, ThunkSequence* thunk_sequence) { for (std::unique_ptr& thunk : *thunk_sequence) { @@ -203,8 +205,6 @@ void ForAllThunks(const std::function& fn, } } -} // namespace - static void ForwardCollectiveAttrs(mlir::ModuleOp module, llvm::StringRef entry_function_name, const HloModuleConfig& config) { diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index 1205968c1719db..03f93a3482975b 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_COMPILE_MODULE_TO_LLVM_IR_H_ #define XLA_SERVICE_GPU_COMPILE_MODULE_TO_LLVM_IR_H_ +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. 
#include "xla/service/buffer_value.h" #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/thunk.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/statusor.h" @@ -54,6 +56,9 @@ struct CompileModuleResults { bool use_original_allocations; }; +void ForAllThunks(const std::function& fn, + ThunkSequence* thunk_sequence); + // Removes all globals from the given module that are both uninitialized and // have no uses within that module. void RemoveUnusedAndUninitializedGlobals( diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0d24a3e0a71ef7..4eb572aa76a4e9 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -54,7 +54,13 @@ limitations under the License. #include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/SplitModule.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -141,6 +147,8 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_stats.h" #include "xla/service/gpu/horizontal_loop_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/loop_double_buffer_transformer.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/metrics.h" @@ -157,6 +165,7 @@ limitations under the License. #include "xla/service/gpu/runtime_intrinsics.h" #include "xla/service/gpu/scatter_slice_simplifier.h" #include "xla/service/gpu/softmax_rewriter_triton.h" +#include "xla/service/gpu/thunk.h" #include "xla/service/gpu/topk_specializer.h" #include "xla/service/gpu/topk_splitter.h" #include "xla/service/gpu/tree_reduction_rewriter.h" @@ -226,6 +235,7 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -314,16 +324,33 @@ class GpuAotCompilationResult : public AotCompilationResult { } StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) const override; + Compiler* compiler, se::StreamExecutor* executor) override; private: XlaRuntimeGpuExecutableProto xla_runtime_gpu_executable_; }; +class GpuThunkAotCompilationResult : public AotCompilationResult { + public: + // TODO(anlunx): Add SerializeAsString(). + StatusOr> LoadExecutable( + Compiler* compiler, se::StreamExecutor* stream_exec) override; + + private: + std::unique_ptr hlo_module_; + std::unique_ptr buffer_assignment_; + std::string asm_text_; + std::vector binary_; + + // We can call LoadExecutable only once because buffer_assignment_ is + // moved to GpuExecutable when LoadExecutable is called. 
+ bool loadable_ = true; +}; + } // end anonymous namespace StatusOr> GpuAotCompilationResult::LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) const { + Compiler* compiler, se::StreamExecutor* executor) { XlaRuntimeExecutableProto xla_runtime_executable = xla_runtime_gpu_executable_.xla_runtime_executable(); TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, @@ -352,6 +379,89 @@ StatusOr> GpuAotCompilationResult::LoadExecutable( GetGpuVersion(executor), executor); } +StatusOr> +GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, + se::StreamExecutor* stream_exec) { + if (!loadable_) { + return InternalError("The AOT compilation result is not loadable."); + } + loadable_ = false; + + // Build the executable, which should be a thunk sequence. + TF_ASSIGN_OR_RETURN( + se::Platform * platform, + se::MultiPlatformManager::PlatformWithId(compiler->PlatformId())); + std::string platform_name = platform->Name(); + se::DeviceDescription gpu_device_info = stream_exec->GetDeviceDescription(); + mlir::DialectRegistry registry; + IrEmitterUnnested::GetDependentDialects(registry); + auto mlir_context = std::make_unique(registry); + llvm::LLVMContext llvm_context; + auto llvm_module = std::make_unique("", llvm_context); + IrEmitterContext ir_emitter_context( + hlo_module_.get(), buffer_assignment_.get(), platform_name, + gpu_device_info, mlir_context.get(), llvm_module.get(), + /*emit_ir_from_hlo=*/true); + mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( + mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module_->name()); + std::vector ordered_allocations; + absl::flat_hash_map + operation_map; + TF_RETURN_IF_ERROR(HloToLhloModule(*buffer_assignment_, *hlo_module_, + *mlir_module, &ordered_allocations, + &operation_map)); + ir_emitter_context.set_allocations(ordered_allocations); + auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); + auto entry_function = mlir::cast( + mlir_module->lookupSymbol(hlo_module_->entry_computation()->name())); + // TODO(anlunx): EmitLmhloRegion emits fusion kernels. We need to make sure + // ptx and cubin already contain emission results and disable kernel emission + // here. + TF_RETURN_IF_ERROR( + ir_emitter->EmitLmhloRegion(&entry_function.getBody(), operation_map)); + std::unique_ptr thunk_sequence = + ir_emitter->ConsumeThunkSequence(); + ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, + thunk_sequence.get()); + + // Get all other fields required by GpuExecutable. 
+ std::vector constants = + std::move(ir_emitter_context.constants()); + TF_ASSIGN_OR_RETURN(auto output_info, + GetOutputInfo(*hlo_module_, *buffer_assignment_)); + const Shape& output_shape = hlo_module_->result_shape(); + std::function buffer_assignment_dumper = [] { + return std::string(); + }; + bool enable_persistent_temp_buffers = + hlo_module_->config() + .debug_options() + .xla_gpu_enable_persistent_temp_buffers(); + int64_t debug_buffer_assignment_show_max = + hlo_module_->config() + .debug_options() + .xla_debug_buffer_assignment_show_max(); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + GpuExecutable::Create(GpuExecutable::Params{ + /*asm_text=*/asm_text_, + /*binary=*/binary_, + /*gpu_version=*/gpu_device_info.gpu_compute_capability(), + /*executable=*/std::move(thunk_sequence), + /*constants=*/std::move(constants), + /*output_info=*/std::move(output_info), + /*module_name=*/std::move(hlo_module_->name()), + /*output_shape=*/std::move(output_shape), + /*mlir_allocations=*/std::nullopt, + /*buffer_assignment=*/std::move(buffer_assignment_), + /*enable_persistent_temp_buffers=*/enable_persistent_temp_buffers, + /*debug_buffer_assignment_show_max=*/debug_buffer_assignment_show_max, + /*debug_module=*/std::unique_ptr(), + /*enable_debug_info_manager=*/true})); + return executable; +} + GpuCompiler::GpuCompiler(se::Platform::Id platform_id, const char* target_triple, const char* data_layout) : platform_id_(platform_id), From 1b5255293e1f0e706bcb47059627b5b0c4a22889 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Wed, 22 Nov 2023 16:33:00 -0800 Subject: [PATCH 024/381] Remove the deprecated PJRT GPU client creation. PiperOrigin-RevId: 584745165 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 22 ------------------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 12 ---------- .../pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc | 15 +++++-------- .../xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc | 20 +++++++---------- 4 files changed, 14 insertions(+), 55 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 84844600eee99f..ad450b5ee94990 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -868,28 +868,6 @@ absl::StatusOr StreamExecutorGpuDevice::GetAllocatorStats() return stats.value(); } -StatusOr> GetStreamExecutorGpuClient( - bool asynchronous, const GpuAllocatorConfig& allocator_config, int node_id, - int num_nodes, const std::optional>& allowed_devices, - std::optional platform_name, - bool should_stage_host_to_device_transfers, - PjRtClient::KeyValueGetCallback kv_get, - PjRtClient::KeyValuePutCallback kv_put, bool enable_mock_nccl) { - GpuClientOptions options; - options.allocator_config = allocator_config; - options.node_id = node_id; - options.num_nodes = num_nodes; - options.allowed_devices = allowed_devices; - options.platform_name = platform_name; - options.should_stage_host_to_device_transfers = - should_stage_host_to_device_transfers; - options.kv_get = kv_get; - options.kv_put = kv_put; - options.enable_mock_nccl = enable_mock_nccl; - - return GetStreamExecutorGpuClient(options); -} - StatusOr> GetStreamExecutorGpuClient( const GpuClientOptions& options) { #if TENSORFLOW_USE_ROCM diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index c66888b2bc09e5..b4fcb20f5278be 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ 
b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -246,18 +246,6 @@ struct GpuClientOptions { StatusOr> GetStreamExecutorGpuClient( const GpuClientOptions& options); -// TODO(b/311119497): Remove this function after all callsites are updated. -ABSL_DEPRECATED("Use the the above function that takes GpuClientOptions.") -StatusOr> GetStreamExecutorGpuClient( - bool asynchronous, const GpuAllocatorConfig& allocator_config, int node_id, - int num_nodes = 1, - const std::optional>& allowed_devices = std::nullopt, - std::optional platform_name = std::nullopt, - bool should_stage_host_to_device_transfers = true, - PjRtClient::KeyValueGetCallback kv_get = nullptr, - PjRtClient::KeyValuePutCallback kv_put = nullptr, - bool enable_mock_nccl = false); - } // namespace xla #endif // XLA_PJRT_GPU_SE_GPU_PJRT_CLIENT_H_ diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 9b94c2f67134bf..7daa467f8e6d41 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -83,9 +83,8 @@ void ValidateResult( } TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast(client.release())); Compiler::TargetConfig gpu_target_config = xla::Compiler::TargetConfig( @@ -113,9 +112,8 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) { } TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast(client.release())); auto gpu_compiler = gpu::NVPTXCompiler(); @@ -143,9 +141,8 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { } TEST(StreamExecutorGpuCompilerTest, SuccessLoadFromSerializedExecutable) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast(client.release())); StreamExecutorGpuCompiler compiler; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index d395a46f93b52e..69013b441b9f7d 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -76,9 +76,8 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(), "Fake_device", {0, 1}); - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram)); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), computation, topology, client.get()), @@ -88,9 +87,8 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { TEST(StreamExecutorGpuCompilerTest, SuccessXla) { 
StreamExecutorGpuCompiler compiler; - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram)); TF_ASSERT_OK_AND_ASSIGN(auto topology, client->GetTopologyDescription()); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -142,9 +140,8 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) { StreamExecutorGpuTopologyDescription topology(CudaId(), CudaName(), "Fake_device", {0, 1}); - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology, client.get()), StatusIs(absl::StatusCode::kUnimplemented)); @@ -159,9 +156,8 @@ TEST(StreamExecutorGpuCompilerTest, SuccessMlir) { auto mlir_module = mlir::parseSourceString(mlir_str, &context); - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto topology, client->GetTopologyDescription()); TF_ASSERT_OK_AND_ASSIGN( auto executable, From ad7cc3358b1aeecb165f0976a93f8ec81150fa9e Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Wed, 22 Nov 2023 17:14:21 -0800 Subject: [PATCH 025/381] Eliminate TensorBoard's re-export of tf.summary at TF import time. Explicitly export tf.summary.audio, tf.summary.histogram, tf.summary.image, tf.summary.scalar, and tf.summary.text as wrappers to their TensorBoard implementations. This has a small effect on error messages - previously if TensorBoard was not installed and one of the five APIs above were called, there would be an error stating that tf.summary had no such attribute. Now, these APIs always exist and are always callable, but if TensorBoard is not installed there will be an error that their implementations are missing. PiperOrigin-RevId: 584752641 --- tensorflow/api_template.__init__.py | 13 - tensorflow/compat_template.__init__.py | 10 - tensorflow/python/BUILD | 1 + tensorflow/python/modules_with_exports.py | 1 + tensorflow/python/summary/BUILD | 8 + tensorflow/python/summary/summary.py | 17 +- tensorflow/python/summary/summary_v2_test.py | 56 ++- tensorflow/python/summary/tb_summary.py | 374 +++++++++++++++++++ tensorflow/tools/docs/generate2.py | 2 +- 9 files changed, 438 insertions(+), 44 deletions(-) create mode 100644 tensorflow/python/summary/tb_summary.py diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 321738084016a7..1ccf2fe07f0af9 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -30,7 +30,6 @@ import distutils as _distutils import importlib import inspect as _inspect -import logging as _logging import os as _os import site as _site import sys as _sys @@ -62,18 +61,6 @@ __path__.append(_tf_api_dir) # Hook external TensorFlow modules. -# Import compat before trying to import summary from tensorboard, so that -# reexport_tf_summary can get compat from sys.modules. Only needed if using -# lazy loading. 
-_current_module.compat.v2 # pylint: disable=pointless-statement -try: - from tensorboard.summary._tf import summary - _current_module.__path__ = ( - [_module_util.get_parent_dir(summary)] + _current_module.__path__) - setattr(_current_module, "summary", summary) -except ImportError: - _logging.warning( - "Limited tf.summary API due to missing TensorBoard installation.") # Load tensorflow-io-gcs-filesystem if enabled if (_os.getenv("TF_USE_MODULAR_FILESYSTEM", "0") == "true" or diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index 9d2f954293eddc..701623c328081e 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -16,7 +16,6 @@ # pylint: disable=g-bad-import-order,g-import-not-at-top,protected-access -import logging as _logging import os as _os import sys as _sys import typing as _typing @@ -31,15 +30,6 @@ # Hook external TensorFlow modules. _current_module = _sys.modules[__name__] -try: - from tensorboard.summary._tf import summary - _current_module.__path__ = ( - [_module_util.get_parent_dir(summary)] + _current_module.__path__) - setattr(_current_module, "summary", summary) -except ImportError: - _logging.warning( - "Limited tf.compat.v2.summary API due to missing TensorBoard " - "installation.") # Lazy-load estimator. _estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator" diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b9b176ebb997f6..775233f6ef3c5a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -447,6 +447,7 @@ py_strict_library( "//tensorflow/python/profiler:trace", "//tensorflow/python/saved_model", "//tensorflow/python/summary:summary_py", + "//tensorflow/python/summary:tb_summary", "//tensorflow/python/tpu:tpu_noestimator", "//tensorflow/python/training", "//tensorflow/python/training:quantize_training", diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 5f86568227670c..627f5548988d3f 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -170,6 +170,7 @@ # Summary from tensorflow.python.summary import summary +from tensorflow.python.summary import tb_summary # TPU from tensorflow.python.tpu import api diff --git a/tensorflow/python/summary/BUILD b/tensorflow/python/summary/BUILD index 5ed5b0f74dc2df..7af5c5cb277ae8 100644 --- a/tensorflow/python/summary/BUILD +++ b/tensorflow/python/summary/BUILD @@ -37,6 +37,7 @@ py_strict_library( srcs = ["summary.py"], visibility = ["//visibility:public"], deps = [ + ":tb_summary", "//tensorflow/core:protos_all_py", "//tensorflow/python/distribute:summary_op_util", "//tensorflow/python/eager:context", @@ -124,3 +125,10 @@ tf_py_strict_test( "@pypi_tb_nightly//:pkg", ], ) + +py_strict_library( + name = "tb_summary", + srcs = ["tb_summary.py"], + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/python/util:tf_export"], +) diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 161456a7aecae0..b6112b1d7db1d9 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -46,7 +46,7 @@ from tensorflow.python.ops import gen_summary_ops as _gen_summary_ops # pylint: disable=unused-import from tensorflow.python.ops import summary_op_util as _summary_op_util from tensorflow.python.ops import summary_ops_v2 as _summary_ops_v2 - +from tensorflow.python.summary import tb_summary # exports FileWriter, 
FileWriterCache # pylint: disable=unused-import from tensorflow.python.summary.writer.writer import FileWriter @@ -124,9 +124,8 @@ def scalar(name, tensor, collections=None, family=None): if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import scalar as scalar_v2 # pylint: disable=g-import-not-at-top with _compat_summary_scope(name, family) as tag: - scalar_v2(name=tag, data=tensor, step=_get_step_for_v2()) + tb_summary.scalar(name=tag, data=tensor, step=_get_step_for_v2()) # Return an empty Tensor, which will be acceptable as an input to the # `tf.compat.v1.summary.merge()` API. return _constant_op.constant(b'') @@ -235,9 +234,8 @@ def image(name, tensor, max_outputs=3, collections=None, family=None): if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import image as image_v2 # pylint: disable=g-import-not-at-top with _compat_summary_scope(name, family) as tag: - image_v2( + tb_summary.image( name=tag, data=tensor, step=_get_step_for_v2(), @@ -330,9 +328,8 @@ def histogram(name, values, collections=None, family=None): if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import histogram as histogram_v2 # pylint: disable=g-import-not-at-top with _compat_summary_scope(name, family) as tag: - histogram_v2(name=tag, data=values, step=_get_step_for_v2()) + tb_summary.histogram(name=tag, data=values, step=_get_step_for_v2()) # Return an empty Tensor, which will be acceptable as an input to the # `tf.compat.v1.summary.merge()` API. return _constant_op.constant(b'') @@ -440,12 +437,11 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None, if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import audio as audio_v2 # pylint: disable=g-import-not-at-top if tensor.shape.rank == 2: # TF2 op requires 3-D tensor, add the `channels` dimension. tensor = _array_ops.expand_dims_v2(tensor, axis=2) with _compat_summary_scope(name, family) as tag: - audio_v2( + tb_summary.audio( name=tag, data=tensor, sample_rate=sample_rate, @@ -540,8 +536,7 @@ def text(name, tensor, collections=None): return _constant_op.constant('') # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import text as text_v2 # pylint: disable=g-import-not-at-top - text_v2(name=name, data=tensor, step=_get_step_for_v2()) + tb_summary.text(name=name, data=tensor, step=_get_step_for_v2()) # Return an empty Tensor, which will be acceptable as an input to the # `tf.compat.v1.summary.merge()` API. return _constant_op.constant(b'') diff --git a/tensorflow/python/summary/summary_v2_test.py b/tensorflow/python/summary/summary_v2_test.py index d6454b46893f05..6e3721b311f209 100644 --- a/tensorflow/python/summary/summary_v2_test.py +++ b/tensorflow/python/summary/summary_v2_test.py @@ -43,7 +43,9 @@ def test_scalar_summary_v2__w_writer(self): # Returns empty string. 
self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) - mock_scalar_v2.assert_called_once_with('float', data=i, step=1) + mock_scalar_v2.assert_called_once_with( + name='float', data=i, step=1, description=test.mock.ANY + ) @test_util.run_v2_only def test_scalar_summary_v2__wo_writer(self): @@ -79,7 +81,11 @@ def test_scalar_summary_v2__family(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_scalar_v2.assert_called_once_with( - 'otter/otter/float', data=constant_op.constant(2.5), step=1) + name='otter/otter/float', + data=constant_op.constant(2.5), + step=1, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_scalar_summary_v2__family_w_outer_scope(self): @@ -95,7 +101,11 @@ def test_scalar_summary_v2__family_w_outer_scope(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_scalar_v2.assert_called_once_with( - 'crabnet/sea/crabnet/float', data=constant_op.constant(3.5), step=1) + name='crabnet/sea/crabnet/float', + data=constant_op.constant(3.5), + step=1, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_scalar_summary_v2__v1_set_step(self): @@ -111,7 +121,9 @@ def test_scalar_summary_v2__v1_set_step(self): # Returns empty string. self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) - mock_scalar_v2.assert_called_once_with('float', data=i, step=1024) + mock_scalar_v2.assert_called_once_with( + name='float', data=i, step=1024, description=test.mock.ANY + ) @test_util.run_v2_only def test_image_summary_v2(self): @@ -127,7 +139,12 @@ def test_image_summary_v2(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_image_v2.assert_called_once_with( - 'family/outer/family/image', data=i, step=2, max_outputs=3) + name='family/outer/family/image', + data=i, + step=2, + max_outputs=3, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_histogram_summary_v2(self): @@ -142,7 +159,12 @@ def test_histogram_summary_v2(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_histogram_v2.assert_called_once_with( - 'family/family/histogram', data=i, step=3) + name='family/family/histogram', + data=i, + step=3, + buckets=test.mock.ANY, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_audio_summary_v2(self): @@ -158,7 +180,14 @@ def test_audio_summary_v2(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_audio_v2.assert_called_once_with( - 'dolphin/wave', data=i, sample_rate=0.2, step=10, max_outputs=3) + name='dolphin/wave', + data=i, + sample_rate=0.2, + step=10, + max_outputs=3, + encoding=test.mock.ANY, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_audio_summary_v2__2d_tensor(self): @@ -175,7 +204,14 @@ def test_audio_summary_v2__2d_tensor(self): self.assertEqual(tensor.dtype, dtypes.string) mock_audio_v2.assert_called_once_with( - 'wave', data=test.mock.ANY, sample_rate=0.2, step=11, max_outputs=3) + name='wave', + data=test.mock.ANY, + sample_rate=0.2, + step=11, + max_outputs=3, + encoding=test.mock.ANY, + description=test.mock.ANY, + ) input_3d = array_ops.ones((5, 3, 1)) # 3-D input tensor self.assertAllEqual(mock_audio_v2.call_args[1]['data'], input_3d) @@ -191,7 +227,9 @@ def test_text_summary_v2(self): # Returns empty string. 
self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) - mock_text_v2.assert_called_once_with('text', data=i, step=22) + mock_text_v2.assert_called_once_with( + name='text', data=i, step=22, description=test.mock.ANY + ) if __name__ == '__main__': diff --git a/tensorflow/python/summary/tb_summary.py b/tensorflow/python/summary/tb_summary.py new file mode 100644 index 00000000000000..682ca5a2b7e1dd --- /dev/null +++ b/tensorflow/python/summary/tb_summary.py @@ -0,0 +1,374 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Re-exports the APIs of TF2 summary that live in TensorBoard.""" + +from tensorflow.python.util.tf_export import tf_export + +_TENSORBOARD_NOT_INSTALLED_ERROR = ( + "TensorBoard is not installed, missing implementation for" +) + + +class TBNotInstalledError(Exception): + + def __init__(self, summary_api): + self.error_message = f"{_TENSORBOARD_NOT_INSTALLED_ERROR} {summary_api}" + super().__init__(self.error_message) + + +@tf_export("summary.audio", v1=[]) +def audio( + name, + data, + sample_rate, + step=None, + max_outputs=3, + encoding=None, + description=None, +): + """Write an audio summary. + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A `Tensor` representing audio data with shape `[k, t, c]`, where `k` + is the number of audio clips, `t` is the number of frames, and `c` is the + number of channels. Elements should be floating-point values in `[-1.0, + 1.0]`. Any of the dimensions may be statically unknown (i.e., `None`). + sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the sample + rate, in Hz. Must be positive. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this many + audio clips will be emitted at each step. When more than `max_outputs` + many clips are provided, the first `max_outputs` many clips will be used + and the rest silently discarded. + encoding: Optional constant `str` for the desired encoding. Only "wav" is + currently supported, but this is not guaranteed to remain the default, so + if you want "wav" in particular, set this explicitly. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + try: + from tensorboard.summary.v2 import audio as audio_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.audio") from exc + return audio_v2( + name=name, + data=data, + sample_rate=sample_rate, + step=step, + max_outputs=max_outputs, + encoding=encoding, + description=description, + ) + + +@tf_export("summary.histogram", v1=[]) +def histogram(name, data, step=None, buckets=None, description=None): + """Write a histogram summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. + + Writes a histogram to the current default summary writer, for later analysis + in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written + using this API will appear in both places). Like `tf.summary.scalar` points, + each histogram is associated with a `step` and a `name`. All the histograms + with the same `name` constitute a time series of histograms. + + The histogram is calculated over all the elements of the given `Tensor` + without regard to its shape or rank. + + This example writes 2 histograms: + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0) + tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0) + ``` + + A common use case is to examine the changing activation patterns (or lack + thereof) at specific layers in a neural network, over time. + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + for step in range(100): + # Generate fake "activations". + activations = [ + tf.random.normal([1000], mean=step, stddev=1), + tf.random.normal([1000], mean=step, stddev=10), + tf.random.normal([1000], mean=step, stddev=100), + ] + + tf.summary.histogram("layer1/activate", activations[0], step=step) + tf.summary.histogram("layer2/activate", activations[1], step=step) + tf.summary.histogram("layer3/activate", activations[2], step=step) + ``` + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A `Tensor` of any shape. The histogram is computed over its elements, + which must be castable to `float64`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + buckets: Optional positive `int`. The output will have this many buckets, + except in two edge cases. If there is no data, then there are no buckets. + If there is data but all points have the same value, then all buckets' + left and right endpoints are the same and only the last bucket has nonzero + count. Defaults to 30 if not specified. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + try: + from tensorboard.summary.v2 import histogram as histogram_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.histogram") from exc + return histogram_v2( + name=name, data=data, step=step, buckets=buckets, description=description + ) + + +@tf_export("summary.image", v1=[]) +def image(name, data, step=None, max_outputs=3, description=None): + """Write an image summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. + + Writes a collection of images to the current default summary writer. Data + appears in TensorBoard's 'Images' dashboard. Like `tf.summary.scalar` points, + each collection of images is associated with a `step` and a `name`. All the + image collections with the same `name` constitute a time series of image + collections. + + This example writes 2 random grayscale images: + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + image1 = tf.random.uniform(shape=[8, 8, 1]) + image2 = tf.random.uniform(shape=[8, 8, 1]) + tf.summary.image("grayscale_noise", [image1, image2], step=0) + ``` + + To avoid clipping, data should be converted to one of the following: + + - floating point values in the range [0,1], or + - uint8 values in the range [0,255] + + ```python + # Convert the original dtype=int32 `Tensor` into `dtype=float64`. + rgb_image_float = tf.constant([ + [[1000, 0, 0], [0, 500, 1000]], + ]) / 1000 + tf.summary.image("picture", [rgb_image_float], step=0) + + # Convert original dtype=uint8 `Tensor` into proper range. + rgb_image_uint8 = tf.constant([ + [[1, 1, 0], [0, 0, 1]], + ], dtype=tf.uint8) * 255 + tf.summary.image("picture", [rgb_image_uint8], step=1) + ``` + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A `Tensor` representing pixel data with shape `[k, h, w, c]`, where + `k` is the number of images, `h` and `w` are the height and width of the + images, and `c` is the number of channels, which should be 1, 2, 3, or 4 + (grayscale, grayscale with alpha, RGB, RGBA). Any of the dimensions may be + statically unknown (i.e., `None`). Floating point data will be clipped to + the range [0,1]. Other data types will be clipped into an allowed range + for safe casting to uint8, using `tf.image.convert_image_dtype`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this many + images will be emitted at each step. When more than `max_outputs` many + images are provided, the first `max_outputs` many images will be used and + the rest silently discarded. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + try: + from tensorboard.summary.v2 import image as image_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.image") from exc + return image_v2( + name=name, + data=data, + step=step, + max_outputs=max_outputs, + description=description, + ) + + +@tf_export("summary.scalar", v1=[]) +def scalar(name, data, step=None, description=None): + """Write a scalar summary. + + See also `tf.summary.image`, `tf.summary.histogram`, + `tf.summary.SummaryWriter`. + + Writes simple numeric values for later analysis in TensorBoard. Writes go to + the current default summary writer. Each summary point is associated with an + integral `step` value. This enables the incremental logging of time series + data. A common usage of this API is to log loss during training to produce + a loss curve. + + For example: + + ```python + test_summary_writer = tf.summary.create_file_writer('test/logdir') + with test_summary_writer.as_default(): + tf.summary.scalar('loss', 0.345, step=1) + tf.summary.scalar('loss', 0.234, step=2) + tf.summary.scalar('loss', 0.123, step=3) + ``` + + Multiple independent time series may be logged by giving each series a unique + `name` value. + + See [Get started with + TensorBoard](https://www.tensorflow.org/tensorboard/get_started) + for more examples of effective usage of `tf.summary.scalar`. + + In general, this API expects that data points are logged with a monotonically + increasing step value. Duplicate points for a single step or points logged out + of order by step are not guaranteed to display as desired in TensorBoard. + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A real numeric scalar value, convertible to a `float32` Tensor. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was written because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + try: + from tensorboard.summary.v2 import scalar as scalar_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.scalar") from exc + return scalar_v2(name=name, data=data, step=step, description=description) + + +@tf_export("summary.text", v1=[]) +def text(name, data, step=None, description=None): + r"""Write a text summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`, `tf.summary.image`. + + Writes text Tensor values for later visualization and analysis in TensorBoard. + Writes go to the current default summary writer. Like `tf.summary.scalar` + points, text points are each associated with a `step` and a `name`. + All the points with the same `name` constitute a time series of text values. 
+ + For Example: + ```python + test_summary_writer = tf.summary.create_file_writer('test/logdir') + with test_summary_writer.as_default(): + tf.summary.text('first_text', 'hello world!', step=0) + tf.summary.text('first_text', 'nice to meet you!', step=1) + ``` + + The text summary can also contain Markdown, and TensorBoard will render the + text + as such. + + ```python + with test_summary_writer.as_default(): + text_data = ''' + | *hello* | *there* | + |---------|---------| + | this | is | + | a | table | + ''' + text_data = '\n'.join(l.strip() for l in text_data.splitlines()) + tf.summary.text('markdown_text', text_data, step=0) + ``` + + Since text is Tensor valued, each text point may be a Tensor of string values. + rank-1 and rank-2 Tensors are rendered as tables in TensorBoard. For higher + ranked + Tensors, you'll see just a 2D slice of the data. To avoid this, reshape the + Tensor + to at most rank-2 prior to passing it to this function. + + Demo notebook at + ["Displaying text data in + TensorBoard"](https://www.tensorflow.org/tensorboard/text_summaries). + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A UTF-8 string Tensor value. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + try: + from tensorboard.summary.v2 import text as text_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.text") from exc + return text_v2(name=name, data=data, step=step, description=description) diff --git a/tensorflow/tools/docs/generate2.py b/tensorflow/tools/docs/generate2.py index d56e508f8bf2dc..ed8ffc015fb7df 100644 --- a/tensorflow/tools/docs/generate2.py +++ b/tensorflow/tools/docs/generate2.py @@ -324,7 +324,7 @@ def edit_yaml_file(path): expected_path_contents = { "tf/summary/audio.md": - "tensorboard/plugins/audio/summary_v2.py", + "python/summary/tb_summary.py", "tf/estimator/DNNClassifier.md": "tensorflow_estimator/python/estimator/canned/dnn.py", "tf/nn/sigmoid_cross_entropy_with_logits.md": From 2cf8048c322f1471eea3d02f0773816f2a6ca2d1 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:43:18 -0800 Subject: [PATCH 026/381] PR #7222: [ROCM] fixes for hipblas-lt 6.0 enums and datatypes Imported from GitHub PR https://github.com/openxla/xla/pull/7222 These are local ROCM-side fixes to be compatible with hipblas-lt 6.0 release @xla-rotation: would you have a look, please ? 
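For reviewers skimming the diff: the change standardizes on the ROCm 6.0 spellings (`hipDataType`, `HIP_R_*`, `hipblasComputeType_t`, `HIPBLAS_COMPUTE_*`) and keeps pre-6.0 toolkits building through a small macro shim in `hip_blas_utils.h`. A minimal sketch of that shim and how call sites can then use the new names is below; it is illustrative only (the helper functions are hypothetical), the authoritative macro list is in the `hip_blas_utils.h` hunk further down.

```cpp
// Compatibility shim, assuming a pre-6.0 ROCm toolkit; the full macro list
// lives in hip_blas_utils.h in this patch.
#if TF_ROCM_VERSION < 60000
#define hipDataType hipblasDatatype_t                // new type name -> old
#define HIP_R_32F HIPBLAS_R_32F                      // new enum -> old enum
#define hipblasComputeType_t hipblasLtComputeType_t
#define HIPBLAS_COMPUTE_32F HIPBLASLT_COMPUTE_F32
#endif

// Call sites are then written once against the 6.0 spellings. These helpers
// are hypothetical, shown only to illustrate the usage pattern:
inline hipDataType F32DataType() { return HIP_R_32F; }
inline hipblasComputeType_t F32ComputeType() { return HIPBLAS_COMPUTE_32F; }
```

With this in place, the rest of the hipblas-lt wrapper code compiles unchanged against both the 5.x and 6.0 releases.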
Copybara import of the project: -- ee67b12bae4a146d2d5e28ab69cb68308e357696 by Pavel Emeliyanenko : fixes for hipblas-lt 6.0 enums and datatypes Merging this change closes #7222 PiperOrigin-RevId: 584758032 --- .../xla/stream_executor/rocm/hip_blas_lt.cc | 38 ++++++++----------- .../xla/stream_executor/rocm/hip_blas_lt.h | 19 +++++----- .../stream_executor/rocm/hip_blas_utils.cc | 22 +++++------ .../xla/stream_executor/rocm/hip_blas_utils.h | 27 +++++++------ 4 files changed, 51 insertions(+), 55 deletions(-) diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 6a20995b438a3f..e25a0a946c7912 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -421,31 +421,31 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( namespace { -template +template struct HipToNativeT; template <> -struct HipToNativeT { +struct HipToNativeT { using type = Eigen::bfloat16; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = Eigen::half; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = float; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = double; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = complex64; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = complex128; }; @@ -476,22 +476,14 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( } // Other data types: - TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_16B, - HIPBLASLT_R_16B) - TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_16F, - HIPBLASLT_R_16F) - TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_32F, - HIPBLASLT_R_32F) - TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_32F, - HIPBLASLT_R_32F) - TYPED_MATMUL(float, HIPBLASLT_R_32F, HIPBLASLT_R_32F, HIPBLASLT_R_32F, - HIPBLASLT_R_32F) - TYPED_MATMUL(double, HIPBLASLT_R_64F, HIPBLASLT_R_64F, HIPBLASLT_R_64F, - HIPBLASLT_R_64F) - TYPED_MATMUL(complex64, HIPBLASLT_C_32F, HIPBLASLT_C_32F, HIPBLASLT_C_32F, - HIPBLASLT_C_32F) - TYPED_MATMUL(complex128, HIPBLASLT_C_64F, HIPBLASLT_C_64F, HIPBLASLT_C_64F, - HIPBLASLT_C_64F) + TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF) + TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F) + TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, HIP_R_32F) + TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_32F, HIP_R_32F) + TYPED_MATMUL(float, HIP_R_32F, HIP_R_32F, HIP_R_32F, HIP_R_32F) + TYPED_MATMUL(double, HIP_R_64F, HIP_R_64F, HIP_R_64F, HIP_R_64F) + TYPED_MATMUL(complex64, HIP_C_32F, HIP_C_32F, HIP_C_32F, HIP_C_32F) + TYPED_MATMUL(complex128, HIP_C_64F, HIP_C_64F, HIP_C_64F, HIP_C_64F) #undef TYPED_MATMUL diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index 678608e1c57ed3..0ab58918a66f43 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -42,16 +42,16 @@ class BlasLt : public gpu::BlasLt { struct MatrixLayout { static tsl::StatusOr Create(const gpu::MatrixLayout& m); - hipblasltDatatype_t type() const { return datatype_; } + hipDataType type() const { return datatype_; } hipblasLtMatrixLayout_t get() const { return handle_.get(); } private: - MatrixLayout(hipblasLtMatrixLayout_t handle, hipblasltDatatype_t datatype) + 
MatrixLayout(hipblasLtMatrixLayout_t handle, hipDataType datatype) : handle_(handle, wrap::hipblasLtMatrixLayoutDestroy), datatype_(datatype) {} Owned handle_; - hipblasltDatatype_t datatype_; + hipDataType datatype_; }; class MatmulDesc { @@ -63,24 +63,23 @@ class BlasLt : public gpu::BlasLt { Epilogue epilogue = Epilogue::kDefault, PointerMode pointer_mode = PointerMode::kHost); - hipblasLtComputeType_t compute_type() const { return compute_type_; } - hipblasltDatatype_t scale_type() const { return datatype_; } + hipblasComputeType_t compute_type() const { return compute_type_; } + hipDataType scale_type() const { return datatype_; } hipblasPointerMode_t pointer_mode() const { return HIPBLAS_POINTER_MODE_HOST; } hipblasLtMatmulDesc_t get() const { return handle_.get(); } private: - MatmulDesc(hipblasLtMatmulDesc_t handle, - hipblasLtComputeType_t compute_type, - hipblasltDatatype_t datatype) + MatmulDesc(hipblasLtMatmulDesc_t handle, hipblasComputeType_t compute_type, + hipDataType datatype) : handle_(handle, wrap::hipblasLtMatmulDescDestroy), compute_type_(compute_type), datatype_(datatype) {} Owned handle_; - hipblasLtComputeType_t compute_type_; - hipblasltDatatype_t datatype_; + hipblasComputeType_t compute_type_; + hipDataType datatype_; }; struct MatmulPlan : public gpu::BlasLt::MatmulPlan { diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc index 69d2a48bfaf4be..8bd0be07c53464 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc @@ -32,36 +32,36 @@ tsl::Status ToStatus(hipblasStatus_t status, const char* prefix) { return tsl::OkStatus(); } -hipblasltDatatype_t AsHipblasDataType(blas::DataType type) { +hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: case blas::DataType::kF8E4M3FN: LOG(FATAL) << "hipblaslt does not support F8 yet"; case blas::DataType::kHalf: - return HIPBLASLT_R_16F; + return HIP_R_16F; case blas::DataType::kBF16: - return HIPBLASLT_R_16B; + return HIP_R_16BF; case blas::DataType::kFloat: - return HIPBLASLT_R_32F; + return HIP_R_32F; case blas::DataType::kDouble: - return HIPBLASLT_R_64F; + return HIP_R_64F; case blas::DataType::kInt8: - return HIPBLASLT_R_8I; + return HIP_R_8I; case blas::DataType::kInt32: - return HIPBLASLT_R_32I; + return HIP_R_32I; case blas::DataType::kComplexFloat: - return HIPBLASLT_C_32F; + return HIP_C_32F; case blas::DataType::kComplexDouble: - return HIPBLASLT_C_64F; + return HIP_C_64F; default: LOG(FATAL) << "unknown data type"; } } -hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type) { +hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type) { if (type == blas::ComputationType::kF32 || type == blas::ComputationType::kTF32AsF32) - return HIPBLASLT_COMPUTE_F32; + return HIPBLAS_COMPUTE_32F; else LOG(FATAL) << "unsupported hipblaslt computation type"; } diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h index c4f76767c02dc1..726386a1bb6f2b 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h @@ -26,15 +26,20 @@ limitations under the License. 
#if TF_HIPBLASLT #if TF_ROCM_VERSION < 60000 -#define hipblasltDatatype_t hipblasDatatype_t -#define HIPBLASLT_R_16F HIPBLAS_R_16F -#define HIPBLASLT_R_16B HIPBLAS_R_16B -#define HIPBLASLT_R_32F HIPBLAS_R_32F -#define HIPBLASLT_R_64F HIPBLAS_R_64F -#define HIPBLASLT_R_8I HIPBLAS_R_8I -#define HIPBLASLT_R_32I HIPBLAS_R_32I -#define HIPBLASLT_C_32F HIPBLAS_C_32F -#define HIPBLASLT_C_64F HIPBLAS_C_64F +#define hipDataType hipblasDatatype_t +#define HIP_R_16F HIPBLAS_R_16F +#define HIP_R_16BF HIPBLAS_R_16B +#define HIP_R_32F HIPBLAS_R_32F +#define HIP_R_64F HIPBLAS_R_64F +#define HIP_R_8I HIPBLAS_R_8I +#define HIP_R_32I HIPBLAS_R_32I +#define HIP_C_32F HIPBLAS_C_32F +#define HIP_C_64F HIPBLAS_C_64F + +#define hipblasComputeType_t hipblasLtComputeType_t +#define HIPBLAS_COMPUTE_32F HIPBLASLT_COMPUTE_F32 +#define HIPBLAS_COMPUTE_64F HIPBLASLT_COMPUTE_F64 +#define HIPBLAS_COMPUTE_32I HIPBLASLT_COMPUTE_I32 #endif namespace stream_executor { @@ -44,8 +49,8 @@ namespace rocm { TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr)) tsl::Status ToStatus(hipblasStatus_t status, const char* prefix); -hipblasltDatatype_t AsHipblasDataType(blas::DataType type); -hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type); +hipDataType AsHipblasDataType(blas::DataType type); +hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type); hipblasOperation_t AsHipblasOperation(blas::Transpose trans); } // namespace rocm From 01799dd97e878b6648f9ee57d107ffdf1674fd9c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 22 Nov 2023 17:44:28 -0800 Subject: [PATCH 027/381] [xla:gpu] Add support for conditional commands PiperOrigin-RevId: 584758194 --- .../gpu/runtime3/command_buffer_cmd.cc | 63 ++++++++++++-- .../service/gpu/runtime3/command_buffer_cmd.h | 49 ++++++++++- .../gpu/runtime3/command_buffer_thunk_test.cc | 82 +++++++++++++++++++ 3 files changed, 185 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index bfd1450f0d5dc5..26039ed5fba41f 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "absl/container/inlined_vector.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" @@ -34,6 +34,7 @@ limitations under the License. 
#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -46,6 +47,7 @@ namespace xla::gpu { void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { for (BufferAllocation::Slice& slice : cmd->slices()) { + slices_.insert(slice); allocs_indices_.insert(slice.index()); } commands_.push_back(std::move(cmd)); @@ -61,14 +63,34 @@ Status CommandBufferCmdSequence::Initialize( Status CommandBufferCmdSequence::Record( const CommandBufferCmd::RecordParams& params, - se::CommandBuffer* command_buffer) { - if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { - TF_RETURN_IF_ERROR(command_buffer->Update()); + se::CommandBuffer* command_buffer, RecordMode mode) { + if (mode == RecordMode::kExclusive) { + if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { + TF_RETURN_IF_ERROR(command_buffer->Update()); + } } + for (auto& cmd : commands_) { TF_RETURN_IF_ERROR(cmd->Record(params, command_buffer)); } - return command_buffer->Finalize(); + + if (mode == RecordMode::kExclusive) { + TF_RETURN_IF_ERROR(command_buffer->Finalize()); + } + + return OkStatus(); +} + +// Returns buffer allocation slices referenced by commands in this sequence. +const absl::flat_hash_set& +CommandBufferCmdSequence::slices() const { + return slices_; +} + +// Returns buffer allocations indices referenced by commands in this sequence. +const absl::flat_hash_set& +CommandBufferCmdSequence::allocs_indices() const { + return allocs_indices_; } //===----------------------------------------------------------------------===// @@ -151,6 +173,37 @@ CommandBufferCmd::Slices MemcpyDeviceToDeviceCmd::slices() { return {dst_, src_}; } +//===----------------------------------------------------------------------===// +// IfCmd +//===----------------------------------------------------------------------===// + +IfCmd::IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_cmds) + : pred_(pred), then_cmds_(std::move(then_cmds)) {} + +Status IfCmd::Initialize(se::StreamExecutor* executor, + ExecutableSource source) { + return then_cmds_.Initialize(executor, source); +} + +Status IfCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase pred = + params.buffer_allocations->GetDeviceAddress(pred_); + + return command_buffer->If( + params.executor, se::DeviceMemory(pred), + [&](se::CommandBuffer* then_cmd_buffer) { + return then_cmds_.Record( + params, then_cmd_buffer, + CommandBufferCmdSequence::RecordMode::kConditional); + }); +} + +CommandBufferCmd::Slices IfCmd::slices() { + auto& slices = then_cmds_.slices(); + return {slices.begin(), slices.end()}; +} + //===----------------------------------------------------------------------===// // GemmCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index 82658a38a4a296..7d9dccf861af76 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -89,6 +89,21 @@ class CommandBufferCmdSequence { public: CommandBufferCmdSequence() = default; + enum class RecordMode { + // In exclusive mode no one else is recording commands 
into the command + // buffer argument, and cmd sequence is responsible for updating command + // buffer state: finalizing after all commands recorded, and + // switching to update state before recording updates. + kExclusive, + + // In conditional mode multiple cmd sequences can be recorded into the + // command buffer argument, and with command buffer state managed externally + // cmd sequence should not finalize or update it. This mode is used when + // command buffer cmd sequence is recorded into conditional command buffers + // owned by the parent command buffer. + kConditional + }; + void Append(std::unique_ptr cmd); template @@ -102,16 +117,21 @@ class CommandBufferCmdSequence { // Records all commands added to a sequence into the given command buffer. Status Record(const CommandBufferCmd::RecordParams& params, - se::CommandBuffer* command_buffer); + se::CommandBuffer* command_buffer, + RecordMode mode = RecordMode::kExclusive); + + // Returns buffer allocation slices referenced by commands in this sequence. + const absl::flat_hash_set& slices() const; // Returns buffer allocations indices referenced by commands in this sequence. - const absl::flat_hash_set& allocs_indices() const { - return allocs_indices_; - } + const absl::flat_hash_set& allocs_indices() const; private: std::vector> commands_; + // Buffer allocation slices referenced by commands in this sequence. + absl::flat_hash_set slices_; + // Buffer allocations indices referenced by commands in this sequence. absl::flat_hash_set allocs_indices_; }; @@ -165,6 +185,27 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { int64_t num_bytes_; }; +//===----------------------------------------------------------------------===// +// IfCmd +//===----------------------------------------------------------------------===// + +class IfCmd : public CommandBufferCmd { + public: + IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_cmds); + + Status Initialize(se::StreamExecutor* executor, + ExecutableSource source) override; + + Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + BufferAllocation::Slice pred_; + CommandBufferCmdSequence then_cmds_; +}; + //===----------------------------------------------------------------------===// // GemmCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 7712ce1c8c7098..2c18b511c092d3 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_test_kernels.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" @@ -388,4 +389,85 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { ASSERT_EQ(dst, std::vector(4, 21 + 21)); } +TEST(CommandBufferThunkTest, IfCmd) { +#if !defined(XLA_GPU_USE_CUDA_GRAPH_CONDITIONAL) + GTEST_SKIP() << "CUDA graph conditionals not enabled"; +#endif + + se::StreamExecutor* executor = CudaExecutor(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: pred=true, a=42, b=0 + se::DeviceMemory pred = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + stream.ThenMemcpy(&pred, &kTrue, 1); + stream.ThenMemset32(&a, 42, byte_length); + stream.ThenMemZero(&b, byte_length); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_p(&alloc_p, 0, 1); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + + // Prepare commands sequence for `then` branch. + CommandBufferCmdSequence then_commands; + then_commands.Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(slice_p, std::move(then_commands)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + CommandBufferCmd::ExecutableSource source = { + /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize(executor, source)); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Prepare buffer allocation for updating command buffer: c=0 + se::DeviceMemory c = executor->AllocateArray(length, 0); + stream.ThenMemZero(&c, byte_length); + + // Update buffer allocation #2 to buffer `c`. + allocations = BufferAllocations({pred, a, c}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `c` data back to host. 
+ std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), c, byte_length); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); +} + } // namespace xla::gpu From c6287d55a0b4d110dbf1c645a68334a8b442e870 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Wed, 22 Nov 2023 17:47:45 -0800 Subject: [PATCH 028/381] Create a callback registry for tf.data v2 behavior functions to remove v2_compat.py's dependency on tf.data. PiperOrigin-RevId: 584758671 --- tensorflow/python/compat/BUILD | 7 +-- tensorflow/python/compat/v2_compat.py | 46 ++++++------------- tensorflow/python/data/experimental/ops/BUILD | 4 ++ .../python/data/experimental/ops/counter.py | 12 +++++ .../data/experimental/ops/interleave_ops.py | 14 ++++++ .../data/experimental/ops/random_ops.py | 12 +++++ .../python/data/experimental/ops/readers.py | 18 ++++++++ tensorflow/python/data/ops/BUILD | 2 + tensorflow/python/data/ops/dataset_ops.py | 12 +++++ tensorflow/python/data/ops/readers.py | 16 +++++++ tensorflow/python/ops/BUILD | 1 + tensorflow/python/ops/gradients_impl.py | 1 + 12 files changed, 106 insertions(+), 39 deletions(-) diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index 68bab012e8bf28..8765961c533f7c 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -13,14 +13,9 @@ py_strict_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:tf2", - "//tensorflow/python/data/experimental/ops:counter", - "//tensorflow/python/data/experimental/ops:interleave_ops", - "//tensorflow/python/data/experimental/ops:random_ops", - "//tensorflow/python/data/experimental/ops:readers", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:readers", "//tensorflow/python/eager:monitoring", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:registry", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:control_flow_v2_toggles", diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py index cef625b1dc355a..5820e477eb2e5f 100644 --- a/tensorflow/python/compat/v2_compat.py +++ b/tensorflow/python/compat/v2_compat.py @@ -15,19 +15,13 @@ """Switching v2 features on and off.""" from tensorflow.python import tf2 -from tensorflow.python.data.experimental.ops import counter -from tensorflow.python.data.experimental.ops import interleave_ops -from tensorflow.python.data.experimental.ops import random_ops -from tensorflow.python.data.experimental.ops import readers as exp_readers -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers from tensorflow.python.eager import monitoring from tensorflow.python.framework import ops +from tensorflow.python.framework import registry from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import resource_variables_toggle - from tensorflow.python.util.tf_export import tf_export # Metrics to track the status of v2_behavior @@ -35,6 +29,12 @@ "/tensorflow/version/v2_behavior", "whether v2_behavior is enabled or disabled", "status") +_DATA_V2_CALLBACKS = registry.Registry("data_v2_callbacks") + + +def register_data_v2_callback(data_v2_func): + _DATA_V2_CALLBACKS.register(data_v2_func, data_v2_func.__module__) + @tf_export(v1=["enable_v2_behavior"]) def enable_v2_behavior(): @@ -65,19 +65,9 @@ def enable_v2_behavior(): # Enables 
TensorArrayV2 and control flow V2. control_flow_v2_toggles.enable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V2 versions. - dataset_ops.Dataset = dataset_ops.DatasetV2 - readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV2 - readers.TFRecordDataset = readers.TFRecordDatasetV2 - readers.TextLineDataset = readers.TextLineDatasetV2 - counter.Counter = counter.CounterV2 - interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v2 - interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v2 - random_ops.RandomDataset = random_ops.RandomDatasetV2 - exp_readers.CsvDataset = exp_readers.CsvDatasetV2 - exp_readers.SqlDataset = exp_readers.SqlDatasetV2 - exp_readers.make_batched_features_dataset = ( - exp_readers.make_batched_features_dataset_v2) - exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v2 + for v2_enabler_name in _DATA_V2_CALLBACKS.list(): + v2_enabler = _DATA_V2_CALLBACKS.lookup(v2_enabler_name) + v2_enabler() @tf_export(v1=["disable_v2_behavior"]) @@ -110,16 +100,6 @@ def disable_v2_behavior(): # Disables TensorArrayV2 and control flow V2. control_flow_v2_toggles.disable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V1 versions. - dataset_ops.Dataset = dataset_ops.DatasetV1 - readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV1 - readers.TFRecordDataset = readers.TFRecordDatasetV1 - readers.TextLineDataset = readers.TextLineDatasetV1 - counter.Counter = counter.CounterV1 - interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v1 - interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v1 - random_ops.RandomDataset = random_ops.RandomDatasetV1 - exp_readers.CsvDataset = exp_readers.CsvDatasetV1 - exp_readers.SqlDataset = exp_readers.SqlDatasetV1 - exp_readers.make_batched_features_dataset = ( - exp_readers.make_batched_features_dataset_v1) - exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v1 + for v2_disabler_name in _DATA_V2_CALLBACKS.list(): + v2_disabler = _DATA_V2_CALLBACKS.lookup(v2_disabler_name) + v2_disabler() diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index c3e153a4e775e6..cc604a2afebe22 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -56,6 +56,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/util:deprecation", @@ -194,6 +195,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/util:deprecation", @@ -349,6 +351,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -365,6 +368,7 @@ py_strict_library( ":error_ops", ":parsing_ops", "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:options", "//tensorflow/python/data/ops:readers", diff --git a/tensorflow/python/data/experimental/ops/counter.py 
b/tensorflow/python/data/experimental/ops/counter.py index 2a8eaaae76afaa..e9dc2b49a0ea0d 100644 --- a/tensorflow/python/data/experimental/ops/counter.py +++ b/tensorflow/python/data/experimental/ops/counter.py @@ -14,6 +14,7 @@ # ============================================================================== """The Counter Dataset.""" from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.util import deprecation @@ -70,3 +71,14 @@ def CounterV1(start=0, step=1, dtype=dtypes.int64): Counter = CounterV2 else: Counter = CounterV1 + + +def _tf2_callback(): # pylint: disable=invalid-name + global Counter + if tf2.enabled(): + Counter = CounterV2 + else: + Counter = CounterV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py index 4cf61f9d5c7f9b..7f1d97d6a0e90e 100644 --- a/tensorflow/python/data/experimental/ops/interleave_ops.py +++ b/tensorflow/python/data/experimental/ops/interleave_ops.py @@ -14,6 +14,7 @@ # ============================================================================== """Non-deterministic dataset transformations.""" from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.util import deprecation @@ -245,3 +246,16 @@ def choose_from_datasets_v1(datasets, else: choose_from_datasets = choose_from_datasets_v1 sample_from_datasets = sample_from_datasets_v1 + + +def _tf2_callback(): + global choose_from_datasets, sample_from_datasets + if tf2.enabled(): + choose_from_datasets = choose_from_datasets_v2 + sample_from_datasets = sample_from_datasets_v2 + else: + choose_from_datasets = choose_from_datasets_v1 + sample_from_datasets = sample_from_datasets_v1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py index 8e951ea962c3d9..a88f14a8063b42 100644 --- a/tensorflow/python/data/experimental/ops/random_ops.py +++ b/tensorflow/python/data/experimental/ops/random_ops.py @@ -16,6 +16,7 @@ import functools from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import random_op from tensorflow.python.util import deprecation @@ -44,3 +45,14 @@ def __init__(self, seed=None): RandomDataset = RandomDatasetV2 else: RandomDataset = RandomDatasetV1 + + +def _tf2_callback(): + global RandomDataset + if tf2.enabled(): + RandomDataset = RandomDatasetV2 + else: + RandomDataset = RandomDatasetV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py index 1ae47f4c9c8e70..75a4a9c39ffa50 100644 --- a/tensorflow/python/data/experimental/ops/readers.py +++ b/tensorflow/python/data/experimental/ops/readers.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.experimental.ops import error_ops from tensorflow.python.data.experimental.ops import parsing_ops from tensorflow.python.data.ops import dataset_ops @@ -1220,3 +1221,20 @@ def 
__init__(self, driver_name, data_source_name, query, output_types): SqlDataset = SqlDatasetV1 make_batched_features_dataset = make_batched_features_dataset_v1 make_csv_dataset = make_csv_dataset_v1 + + +def _tf2_callback(): + global CsvDataset, SqlDataset, make_batched_features_dataset, make_csv_dataset + if tf2.enabled(): + CsvDataset = CsvDatasetV2 + SqlDataset = SqlDatasetV2 + make_batched_features_dataset = make_batched_features_dataset_v2 + make_csv_dataset = make_csv_dataset_v2 + else: + CsvDataset = CsvDatasetV1 + SqlDataset = SqlDatasetV1 + make_batched_features_dataset = make_batched_features_dataset_v1 + make_csv_dataset = make_csv_dataset_v1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index b7706ef00699c5..6b2b93bc412504 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -98,6 +98,7 @@ py_strict_library( "//tensorflow/python/autograph/operators:py_builtins", "//tensorflow/python/checkpoint", "//tensorflow/python/checkpoint:checkpoint_management", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/experimental/ops:take_while_ops", "//tensorflow/python/data/experimental/service:_pywrap_snapshot_utils", "//tensorflow/python/data/util:convert", @@ -268,6 +269,7 @@ py_strict_library( ":dataset_ops", ":structured_function", "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/util:convert", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 08ea8693d1cbbc..358b316ea0bd3b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -28,6 +28,7 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import struct_pb2 from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_autograph from tensorflow.python.data.ops import debug_mode from tensorflow.python.data.ops import iterator_ops @@ -4212,6 +4213,17 @@ def with_options(self, options, name=None) -> "DatasetV1Adapter": Dataset = DatasetV1 +def _tf2_callback(): + global Dataset + if tf2.enabled(): + Dataset = DatasetV2 + else: + Dataset = DatasetV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) + + class DatasetV1Adapter(DatasetV1): """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index 347b7a5c272973..566abb7b66eceb 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -16,6 +16,7 @@ import os from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import from_tensor_slices_op from tensorflow.python.data.ops import structured_function @@ -705,3 +706,18 @@ def _filenames(self, value): FixedLengthRecordDataset = FixedLengthRecordDatasetV1 TFRecordDataset = TFRecordDatasetV1 TextLineDataset = TextLineDatasetV1 + + +def _tf2_callback(): + global FixedLengthRecordDataset, TFRecordDataset, TextLineDataset + if tf2.enabled(): + FixedLengthRecordDataset = FixedLengthRecordDatasetV2 + TFRecordDataset = TFRecordDatasetV2 + TextLineDataset = TextLineDatasetV2 + else: + FixedLengthRecordDataset = 
FixedLengthRecordDatasetV1 + TFRecordDataset = TFRecordDatasetV1 + TextLineDataset = TextLineDatasetV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index f380992cc583b9..04b18ba19277ae 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -1632,6 +1632,7 @@ py_strict_library( ":cudnn_rnn_grad", ":gradients_util", ":image_grad", + ":io_ops", ":linalg_grad", ":linalg_ops", ":logging_ops", diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index ae88a6d6306831..a45b9965078898 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import cudnn_rnn_grad # pylint: disable=unused-import from tensorflow.python.ops import gradients_util from tensorflow.python.ops import image_grad # pylint: disable=unused-import +from tensorflow.python.ops import io_ops # pylint: disable=unused-import from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import from tensorflow.python.ops import logging_ops # pylint: disable=unused-import From b52f5bb4b297314bdb2bc0fdad873e33298b9a28 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Nov 2023 17:50:37 -0800 Subject: [PATCH 029/381] Add `WriteToString` API for saving `SavedModel` protos. - Add it APIs `third_party/tensorflow/cc/saved_model/image_format/internal_api.h`, which serialize a SavedModel to `std::string`. - Add `WriteToString` method to `ComposableSplitter`. No functionality or behavior changes to the `Write` method. - Add test cases for saving to `std::string`. PiperOrigin-RevId: 584759186 --- tensorflow/cc/saved_model/image_format/BUILD | 2 + .../saved_model/image_format/internal_api.cc | 24 ++- .../saved_model/image_format/internal_api.h | 20 ++- tensorflow/tools/proto_splitter/cc/BUILD | 7 + .../cc/composable_splitter_base.cc | 165 +++++++++++++----- .../cc/composable_splitter_base.h | 9 + .../cc/composable_splitter_test.cc | 77 ++++++-- 7 files changed, 239 insertions(+), 65 deletions(-) diff --git a/tensorflow/cc/saved_model/image_format/BUILD b/tensorflow/cc/saved_model/image_format/BUILD index 10a35871a708be..7fd743cf9c8356 100644 --- a/tensorflow/cc/saved_model/image_format/BUILD +++ b/tensorflow/cc/saved_model/image_format/BUILD @@ -32,7 +32,9 @@ cc_library( "//tensorflow/tools/proto_splitter/cc:max_size", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", ] + if_not_windows_or_mac([ "//tensorflow/tools/proto_splitter:merge", "//tensorflow/tools/proto_splitter/cc:saved_model_splitter", diff --git a/tensorflow/cc/saved_model/image_format/internal_api.cc b/tensorflow/cc/saved_model/image_format/internal_api.cc index b959602ba445c9..f9eea13682765b 100644 --- a/tensorflow/cc/saved_model/image_format/internal_api.cc +++ b/tensorflow/cc/saved_model/image_format/internal_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "tensorflow/cc/saved_model/metrics.h" #include "tensorflow/cc/saved_model/util.h" @@ -31,7 +32,7 @@ limitations under the License. 
#include "tensorflow/tools/proto_splitter/cc/saved_model_splitter.h" #include "tensorflow/tools/proto_splitter/merge.h" #endif - +#define IS_OSS false namespace tensorflow { namespace image_format { @@ -104,6 +105,27 @@ absl::Status WriteSavedModel(SavedModel* saved_model_proto, #endif } +absl::StatusOr WriteSavedModelToString( + SavedModel* saved_model_proto) { +#if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) + tools::proto_splitter::SavedModelSplitter splitter(saved_model_proto); + return splitter.WriteToString(); +#else + return absl::UnimplementedError( + "WriteSavedModelToString not implemented for Windows or MacOS."); +#endif +} + +#if !IS_OSS +// TODO(b/311769337): Define the function unconditionally after tf oss +// dependency is updated to protobuf v22.x. +absl::StatusOr WriteSavedModelToCord( + SavedModel* saved_model_proto) { + tools::proto_splitter::SavedModelSplitter splitter(saved_model_proto); + return splitter.WriteToCord(); +} +#endif + absl::Status WriteSavedModel(SavedModel* saved_model_proto, const std::string& file_prefix, int debug_max_size) { diff --git a/tensorflow/cc/saved_model/image_format/internal_api.h b/tensorflow/cc/saved_model/image_format/internal_api.h index 465b00a74bfada..7a14b4d031972f 100644 --- a/tensorflow/cc/saved_model/image_format/internal_api.h +++ b/tensorflow/cc/saved_model/image_format/internal_api.h @@ -19,8 +19,12 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "tensorflow/core/protobuf/saved_model.pb.h" +#define IS_OSS false + namespace tensorflow { namespace image_format { @@ -29,13 +33,21 @@ namespace image_format { absl::Status ReadSavedModel(const std::string& file_prefix, SavedModel* saved_model_proto); -// Writes the SavedModel proto to {file_prefix}{.pb|.cpb}. -// If the proto is < the protobuf maximum size, then it will be serialized -// as a `.pb` proto binary. When larger than the maximum size, the SavedModel -// proto is destructively separated into chunks and written to +// Writes the SavedModel proto to a file or to string. If the proto is < the +// protobuf maximum size, then it will be serialized as a `.pb` proto binary. +// When larger than the maximum size, the SavedModel proto is destructively +// separated into chunks and written to // `.cpb` (chunked proto). +// +// Write SavedModel to {file_prefix}{.pb|.cpb}. absl::Status WriteSavedModel(SavedModel* saved_model_proto, const std::string& file_prefix); +// Writes the SavedModel proto to std::string +absl::StatusOr WriteSavedModelToString( + SavedModel* saved_model_proto); +#if !IS_OSS +absl::StatusOr WriteSavedModelToCord(SavedModel* saved_model_proto); +#endif // See above. The `debug_max_size` argument can be used to the maximum size to // less than 2GB for testing purposes. 
diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index 716b7fa317fafe..e7c60cb3050b6b 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -48,8 +48,11 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@local_tsl//tsl/platform:protobuf", + "@riegeli//riegeli/bytes:cord_writer", "@riegeli//riegeli/bytes:fd_writer", + "@riegeli//riegeli/bytes:string_writer", "@riegeli//riegeli/records:record_writer", ] + if_oss([ "//tensorflow/tools/proto_splitter:protos_impl", @@ -87,11 +90,15 @@ tf_cc_test( "//tensorflow/tools/proto_splitter:chunk_proto_cc", "//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", + "@riegeli//riegeli/bytes:cord_reader", "@riegeli//riegeli/bytes:fd_reader", + "@riegeli//riegeli/bytes:string_reader", "@riegeli//riegeli/records:record_reader", ] + if_oss([ "//tensorflow/tools/proto_splitter:protos_impl", diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc index 4fdaa2777d9f8f..76f44d9bb8ed21 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc @@ -1,5 +1,7 @@ #include "tensorflow/tools/proto_splitter/cc/composable_splitter_base.h" +#include + /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +16,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -25,16 +29,25 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "riegeli/bytes/cord_writer.h" // from @riegeli #include "riegeli/bytes/fd_writer.h" // from @riegeli +#include "riegeli/bytes/string_writer.h" // from @riegeli #include "riegeli/records/record_writer.h" // from @riegeli #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" +#include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#define IS_OSS true + namespace tensorflow { namespace tools::proto_splitter { @@ -87,27 +100,67 @@ ComposableSplitterBase::Split() { return std::make_pair(&chunks_, &chunked_message_); } -absl::Status ComposableSplitterBase::Write(std::string file_prefix) { +template +static absl::Status WriteToRecordWriter( + riegeli::RecordWriter& writer, const std::vector& chunks, + ChunkedMessage& chunked_message, + const ::proto_splitter::VersionDef& version) { + // Export Riegeli / chunked file. + ChunkMetadata metadata; + *metadata.mutable_message() = chunked_message; + *metadata.mutable_version() = version; + auto* metadata_chunks = metadata.mutable_chunks(); + + for (const auto& chunk : chunks) { + auto* chunk_metadata = metadata_chunks->Add(); + if (std::holds_alternative>( + chunk)) { + const auto& msg_chunk = + std::get>(chunk); + LOG(INFO) << "Writing chunk of size " << msg_chunk->ByteSizeLong(); + writer.WriteRecord(*msg_chunk); + chunk_metadata->set_size(msg_chunk->ByteSizeLong()); + chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); + } else if (std::holds_alternative(chunk)) { + auto* msg_chunk = std::get(chunk); + writer.WriteRecord(*msg_chunk); + chunk_metadata->set_size(msg_chunk->ByteSizeLong()); + chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); + } else { + const auto& str_chunk = std::get(chunk); + writer.WriteRecord(str_chunk); + chunk_metadata->set_size(str_chunk.size()); + chunk_metadata->set_type(::proto_splitter::ChunkInfo::BYTES); + } + chunk_metadata->set_offset(writer.LastPos().get().numeric()); + } + writer.WriteRecord(metadata); + return absl::OkStatus(); +} + +absl::Status ComposableSplitterBase::CheckIfWriteImplemented() { if (parent_splitter_ != nullptr) { return absl::UnimplementedError( "The `Write` function behavior for children ComposableSplitter has not " - "been defined. Please call the parent ComposableSplitter's `Write` " - "instead."); - } - auto split_status = Split(); - if (!split_status.ok()) { - return split_status.status(); + "been defined. 
Please call `parent_splitter.Write()` instead."); } + return absl::OkStatus(); +} - auto chunks = split_status.value().first; - auto chunked_message = split_status.value().second; +absl::Status ComposableSplitterBase::Write(std::string file_prefix) { + TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); + + auto split_results = Split(); + if (!split_results.ok()) return split_results.status(); + auto& chunks = *split_results.value().first; + auto& chunked_message = *split_results.value().second; tsl::Env* env = tsl::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir( std::string{tensorflow::io::Dirname(file_prefix)})); std::string output_path; - if (chunked_message->chunked_fields().empty()) { + if (chunked_message.chunked_fields().empty()) { // Export regular pb. output_path = absl::StrCat(file_prefix, ".pb"); TF_RETURN_IF_ERROR( @@ -115,44 +168,72 @@ absl::Status ComposableSplitterBase::Write(std::string file_prefix) { } else { // Export Riegeli / chunked file. output_path = absl::StrCat(file_prefix, ".cpb"); - riegeli::RecordWriter writer((riegeli::FdWriter(output_path))); - - ChunkMetadata metadata; - metadata.mutable_message()->MergeFrom(*chunked_message); - metadata.mutable_version()->MergeFrom(Version()); - auto metadata_chunks = metadata.mutable_chunks(); - - for (auto chunk : *chunks) { - auto chunk_metadata = metadata_chunks->Add(); - if (std::holds_alternative>( - chunk)) { - auto msg_chunk = - std::get>(chunk); - LOG(INFO) << "Writing chunk of size " << msg_chunk->ByteSizeLong(); - writer.WriteRecord(*msg_chunk); - chunk_metadata->set_size(msg_chunk->ByteSizeLong()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); - } else if (std::holds_alternative(chunk)) { - auto msg_chunk = std::get(chunk); - writer.WriteRecord(*msg_chunk); - chunk_metadata->set_size(msg_chunk->ByteSizeLong()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); - } else { - auto str_chunk = std::get(chunk); - writer.WriteRecord(str_chunk); - chunk_metadata->set_size(str_chunk.size()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::BYTES); - } - chunk_metadata->set_offset(writer.LastPos().get().numeric()); - } - - writer.WriteRecord(metadata); + using WriterType = riegeli::FdWriter<>; + riegeli::RecordWriter writer((WriterType(output_path))); + if (!writer.is_open()) return writer.status(); + TF_RETURN_IF_ERROR(WriteToRecordWriter( + writer, chunks, chunked_message, Version())); if (!writer.Close()) return writer.status(); } LOG(INFO) << "Splitter output written to " << output_path; return absl::OkStatus(); } +absl::StatusOr ComposableSplitterBase::WriteToString() { + TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); + + auto split_results = Split(); + if (!split_results.ok()) return split_results.status(); + auto& chunks = *split_results.value().first; + auto& chunked_message = *split_results.value().second; + + std::string output; + if (chunked_message.chunked_fields().empty()) { + // Export regular pb. + if (!message_->SerializeToString(&output)) + return absl::InvalidArgumentError("Serialization to string failed"); + } else { + // Export Riegeli / chunked file. 
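+    // Each chunk is written as one Riegeli record, followed by a final
+    // ChunkMetadata record describing every chunk's type, size, and offset,
+    // mirroring the on-disk `.cpb` layout produced by Write().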
+ using WriterType = riegeli::StringWriter<>; + riegeli::RecordWriter writer((WriterType(&output))); + if (!writer.is_open()) return writer.status(); + TF_RETURN_IF_ERROR(WriteToRecordWriter( + writer, chunks, chunked_message, Version())); + if (!writer.Close()) return writer.status(); + } + LOG(INFO) << "Splitter output written to string"; + return output; +} + +#if !IS_OSS +absl::StatusOr ComposableSplitterBase::WriteToCord() { + TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); + + auto split_results = Split(); + if (!split_results.ok()) return split_results.status(); + auto& chunks = *split_results.value().first; + auto& chunked_message = *split_results.value().second; + + absl::Cord output; + if (chunked_message.chunked_fields().empty()) { + // Export regular pb. + if (!message_->SerializeToCord(&output)) + return absl::InvalidArgumentError("Serialization to absl::Cord failed"); + } else { + // Export Riegeli / chunked file. + using WriterType = riegeli::CordWriter<>; + riegeli::RecordWriter writer((WriterType(&output))); + if (!writer.is_open()) return writer.status(); + TF_RETURN_IF_ERROR(WriteToRecordWriter( + writer, chunks, chunked_message, Version())); + if (!writer.Close()) return writer.status(); + } + LOG(INFO) << "Splitter output written to absl::Cord"; + + return output; +} +#endif + absl::Status ComposableSplitterBase::SetMessageAsBaseChunk() { if (!chunks_.empty()) { return absl::FailedPreconditionError( diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h index 478638b43fb989..55ced8eb992ef3 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_COMPOSABLE_SPLITTER_BASE_H_ #define TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_COMPOSABLE_SPLITTER_BASE_H_ +#include #include #include #include @@ -22,11 +23,14 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tsl/platform/protobuf.h" +#define IS_OSS true + namespace tensorflow { namespace tools::proto_splitter { @@ -62,6 +66,10 @@ class ComposableSplitterBase : public Splitter { // attach a `.pb` or `.cpb` (chunked pb) suffix depending on whether the // proto is split. absl::Status Write(std::string file_prefix) override; + absl::StatusOr WriteToString(); +#if !IS_OSS + absl::StatusOr WriteToCord(); +#endif VersionDef Version() override; @@ -93,6 +101,7 @@ class ComposableSplitterBase : public Splitter { // the chunks were always added to the end of the list. However, this is not // always the case the indices must be updated. absl::Status FixChunks(); + absl::Status CheckIfWriteImplemented(); bool built_; tsl::protobuf::Message* message_; diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc index 85eeab4f5a2dad..39930cc04a7ff0 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc @@ -16,12 +16,18 @@ limitations under the License. 
#include #include +#include +#include +#include #include #include #include #include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "riegeli/bytes/cord_reader.h" // from @riegeli #include "riegeli/bytes/fd_reader.h" // from @riegeli +#include "riegeli/bytes/string_reader.h" // from @riegeli #include "riegeli/records/record_reader.h" // from @riegeli #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/env.h" @@ -33,10 +39,11 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" +#define IS_OSS true + namespace tensorflow { namespace tools::proto_splitter { namespace { @@ -120,23 +127,9 @@ TEST(RepeatedStringSplitterTest, TestSplitChunks) { EXPECT_EQ(chunked_message2, chunked_message); } -TEST(RepeatedStringSplitterTest, TestWrite) { - std::vector strings = {"piece-1", "piece-2", "piece-3"}; - auto message = SetUpRepeatedString(strings); - RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); - - std::string output_prefix = tensorflow::io::GetTempFilename(""); - TF_ASSERT_OK(splitter.Write(output_prefix)); - std::string expected_file = absl::StrCat(output_prefix, ".cpb"); - - TF_ASSERT_OK_AND_ASSIGN(auto exists, - internal::FileExists(Env::Default(), expected_file)); - EXPECT_TRUE(exists); - - // Look for the last chunk, which should contain a ChunkMetadata proto. - riegeli::RecordReader> reader( - (riegeli::FdReader(expected_file))); - +template +static void CheckChunks(riegeli::RecordReader& reader, + std::vector& strings) { ChunkMetadata chunk_metadata; reader.Seek(reader.Size().value()); reader.SeekBack(); @@ -169,6 +162,54 @@ TEST(RepeatedStringSplitterTest, TestWrite) { })pb")); } +TEST(RepeatedStringSplitterTest, TestWrite) { + std::vector strings = {"piece-1", "piece-2", "piece-3"}; + auto message = SetUpRepeatedString(strings); + RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); + + std::string output_prefix = tensorflow::io::GetTempFilename(""); + TF_ASSERT_OK(splitter.Write(output_prefix)); + std::string expected_file = absl::StrCat(output_prefix, ".cpb"); + + TF_ASSERT_OK_AND_ASSIGN(auto exists, + internal::FileExists(Env::Default(), expected_file)); + EXPECT_TRUE(exists); + + // Look for the last chunk, which should contain a ChunkMetadata proto. + riegeli::RecordReader> file_reader( + (riegeli::FdReader(expected_file))); + + CheckChunks(file_reader, strings); +} + +TEST(RepeatedStringSplitterTest, TestWriteToString) { + std::vector strings = {"piece-1", "piece-2", "piece-3"}; + auto message = SetUpRepeatedString(strings); + RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); + + TF_ASSERT_OK_AND_ASSIGN(std::string string_output, splitter.WriteToString()); + // Look for the last chunk, which should contain a ChunkMetadata proto. + riegeli::RecordReader> string_reader( + std::forward_as_tuple(string_output)); + + CheckChunks(string_reader, strings); +} + +#if !IS_OSS +TEST(RepeatedStringSplitterTest, TestWriteToCord) { + std::vector strings = {"piece-1", "piece-2", "piece-3"}; + auto message = SetUpRepeatedString(strings); + RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); + + TF_ASSERT_OK_AND_ASSIGN(absl::Cord cord_output, splitter.WriteToCord()); + // Look for the last chunk, which should contain a ChunkMetadata proto. 
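+  // The Cord-backed output shares the record layout of the file and string
+  // variants, so the same CheckChunks helper can validate it.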
+ riegeli::RecordReader> cord_reader( + std::forward_as_tuple(&cord_output)); + + CheckChunks(cord_reader, strings); +} +#endif + TEST(RepeatedStringSplitterTest, TestNoSplit) { RepeatedString message; // No strings RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); From 8936d190c386a54c39af55e5916344f40a685604 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 22 Nov 2023 18:27:39 -0800 Subject: [PATCH 030/381] Open source most of dtensor tests. Also open sourced the d_random random variable creation helpers. Although no public API was added for them. PiperOrigin-RevId: 584764935 --- tensorflow/dtensor/python/BUILD | 21 +- tensorflow/dtensor/python/d_random.py | 331 +++++++ tensorflow/dtensor/python/tests/BUILD | 229 +++++ tensorflow/dtensor/python/tests/api_test.py | 305 ++++++ .../python/tests/batchparallel_spmd_test.py | 660 +++++++++++++ tensorflow/dtensor/python/tests/conv_test.py | 350 +++++++ tensorflow/dtensor/python/tests/mnist_test.py | 197 ++++ .../tests/multi_client_input_util_test.py | 548 +++++++++++ .../python/tests/multi_client_test_util.py | 13 +- .../dtensor/python/tests/numerics_test.py | 125 +++ .../dtensor/python/tests/sparse_test.py | 141 +++ .../tests/tpu_device_assignment_test.py | 889 ++++++++++++++++++ tensorflow/python/ops/BUILD | 2 + 13 files changed, 3805 insertions(+), 6 deletions(-) create mode 100644 tensorflow/dtensor/python/d_random.py create mode 100644 tensorflow/dtensor/python/tests/api_test.py create mode 100644 tensorflow/dtensor/python/tests/batchparallel_spmd_test.py create mode 100644 tensorflow/dtensor/python/tests/conv_test.py create mode 100644 tensorflow/dtensor/python/tests/mnist_test.py create mode 100644 tensorflow/dtensor/python/tests/multi_client_input_util_test.py create mode 100644 tensorflow/dtensor/python/tests/numerics_test.py create mode 100644 tensorflow/dtensor/python/tests/sparse_test.py create mode 100644 tensorflow/dtensor/python/tests/tpu_device_assignment_test.py diff --git a/tensorflow/dtensor/python/BUILD b/tensorflow/dtensor/python/BUILD index 4090792e1e97a8..bf0c9564a97bb8 100644 --- a/tensorflow/dtensor/python/BUILD +++ b/tensorflow/dtensor/python/BUILD @@ -1,7 +1,7 @@ # DTensor Python API and libraries. -load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") +load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") default_visibility = [ @@ -100,6 +100,25 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "d_random", + srcs = ["d_random.py"], + srcs_version = "PY3", + deps = [ + ":api", + ":layout", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops:shape_util", + "//tensorflow/python/ops:stateless_random_ops_gen", + ], +) + pytype_strict_library( name = "d_variable", srcs = ["d_variable.py"], diff --git a/tensorflow/dtensor/python/d_random.py b/tensorflow/dtensor/python/d_random.py new file mode 100644 index 00000000000000..1697f3598151c7 --- /dev/null +++ b/tensorflow/dtensor/python/d_random.py @@ -0,0 +1,331 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""DTensor helpers for random generators.""" + +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_stateless_random_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import shape_util + +# ------------------------------------------------------------------------------ +# stateless rngs +# ------------------------------------------------------------------------------ + + +# TODO(b/171746536): switch all rng ops to official versions once supported. +def _old_tf_random_stateless_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless normal implementation that takes an layout.""" + with ops.name_scope( + name, "stateless_random_normal", [shape, seed, mean, stddev] + ) as name: + seed = ops.convert_to_tensor(seed, dtype=dtypes.int32, name="seed") + shape = shape_util.shape_tensor(shape) + mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") + stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") + rnd = api.call_with_layout( + gen_stateless_random_ops.stateless_random_normal, + layout, + shape, + seed, + dtype, + ) + result = math_ops.add(rnd * stddev, mean, name=name) + shape_util.maybe_set_static_shape(result, shape) + return result + + +def _old_tf_random_stateless_uniform( + shape, + seed, + minval=0, + maxval=None, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless uniform implementation that takes an layout.""" + dtype = dtypes.as_dtype(dtype) + accepted_dtypes = ( + dtypes.float16, + dtypes.bfloat16, + dtypes.float32, + dtypes.float64, + dtypes.int32, + dtypes.int64, + dtypes.uint32, + dtypes.uint64, + ) + if dtype not in accepted_dtypes: + raise ValueError( + f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are " + f"{accepted_dtypes}." + ) + if dtype.is_integer: + if (minval is None) != (maxval is None): + raise ValueError( + f"For integer `dtype` argument {dtype}, argument `minval` and " + f"`maxval` must be both None or not None. Got `minval`={minval} and " + f"`maxval`={maxval}." + ) + if minval is not None and dtype in (dtypes.uint32, dtypes.uint64): + raise ValueError( + f"Argument `dtype` got invalid value {dtype} when argument `minval` " + "is not None. Please don't use unsigned integers in this case." 
+ ) + + shape = shape_util.shape_tensor(shape) + with ops.name_scope( + name, "stateless_random_uniform", [shape, seed, minval, maxval] + ) as name: + seed = ops.convert_to_tensor(seed, dtype_hint=dtypes.int32, name="seed") + + if dtype.is_integer and minval is None and maxval is None: + result = api.call_with_layout( + gen_stateless_random_ops.stateless_random_uniform_full_int, + layout, + shape, + seed=seed, + dtype=dtype, + name=name, + ) + else: + if not dtype.is_integer and maxval is None: + maxval = 1 + val_range = ops.convert_to_tensor( + maxval - minval, dtype=dtype, name="range" + ) + minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") + if dtype.is_integer: + result = api.call_with_layout( + gen_stateless_random_ops.stateless_random_uniform_int, + layout, + shape, + seed=seed, + minval=minval, + maxval=maxval, + ) + else: + rnd = api.call_with_layout( + gen_stateless_random_ops.stateless_random_uniform, + layout, + shape, + seed=seed, + dtype=dtype, + ) + result = math_ops.add(rnd * val_range, minval, name=name) + shape_util.maybe_set_static_shape(result, shape) + return result + + +def _old_tf_stateless_truncated_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless truncated normal implementation that takes an layout.""" + with ops.name_scope( + name, "stateless_truncated_normal", [shape, seed, mean, stddev] + ) as name: + seed = ops.convert_to_tensor(seed, dtype=dtypes.int32, name="seed") + shape = shape_util.shape_tensor(shape) + mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") + stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") + rnd = api.call_with_layout( + gen_stateless_random_ops.stateless_truncated_normal, + layout, + shape, + seed, + dtype, + ) + result = math_ops.add(rnd * stddev, mean, name=name) + shape_util.maybe_set_static_shape(result, shape) + return result + + +def stateless_random_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless RNG.""" + if not context.executing_eagerly(): + layout = None + + return _old_tf_random_stateless_normal( + shape, + seed=seed, + mean=mean, + stddev=stddev, + dtype=dtype, + name=name, + layout=layout, + ) + + +def stateless_random_uniform( + shape, + seed, + minval=0, + maxval=None, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless random uniform.""" + if not context.executing_eagerly(): + layout = None + + return _old_tf_random_stateless_uniform( + shape, + seed=seed, + minval=minval, + maxval=maxval, + dtype=dtype, + name=name, + layout=layout, + ) + + +def stateless_truncated_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless RNG.""" + if not context.executing_eagerly(): + layout = None + + return _old_tf_stateless_truncated_normal( + shape, + seed=seed, + mean=mean, + stddev=stddev, + dtype=dtype, + name=name, + layout=layout, + ) + + +def stateless_split(seed, num=2, mesh=None): + seed = ops.convert_to_tensor(seed) + layout = None + if mesh: + layout = layout_lib.Layout.replicated(mesh, rank=2) + return stateless_random_uniform( + shape=[num, 2], + seed=seed, + dtype=seed.dtype, + minval=None, + maxval=None, + layout=layout, + ) + + +# ------------------------------------------------------------------------------ +# stateless dropout. 
+# ------------------------------------------------------------------------------ + + +def _get_noise_shape(x, noise_shape): + """Noisve shape util copied from tf nn_ops.""" + # If noise_shape is none return immediately. + if noise_shape is None: + return array_ops.shape(x) + + try: + # Best effort to figure out the intended shape. + # If not possible, let the op to handle it. + # In eager mode exception will show up. + noise_shape_ = tensor_shape.as_shape(noise_shape) + except (TypeError, ValueError): + return noise_shape + + if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims): + new_dims = [] + for i, dim in enumerate(x.shape.dims): + if noise_shape_.dims[i].value is None and dim.value is not None: + new_dims.append(dim.value) + else: + new_dims.append(noise_shape_.dims[i].value) + return tensor_shape.TensorShape(new_dims) + + return noise_shape + + +# TODO(b/171213877, b/169909066): Fix layout prop in function case for the rng +# Op used. The layout prop should be able to propagate the layout from input +# tensor `x` to the tf.mul and then back propagate the layout to the +# `random_tensor`. +def dropout(x, rate, noise_shape=None, seed=None, name=None): + """DTensor replacement for dropout.""" + if not isinstance(rate, float): + raise ValueError("rate should be float for dropout.") + if seed is None: + raise ValueError("seed must be specified for DTensor dropout. Got: None") + + with ops.name_scope(name, "dropout", [x]): + x_dtype = x.dtype + keep_prob = 1 - rate + scale = 1 / keep_prob + scale = ops.convert_to_tensor(scale, dtype=x_dtype) + ret = gen_math_ops.mul(x, scale) + + noise_shape = _get_noise_shape(x, noise_shape) + # stateless_random_uniform requires a shape [2] seed. + seed = [seed, 0] + + if context.executing_eagerly(): + layout = api.fetch_layout(x) + else: + layout = None + random_tensor = _old_tf_random_stateless_uniform( + noise_shape, seed=seed, minval=0, maxval=1, dtype=x_dtype, layout=layout + ) + keep_mask = random_tensor >= rate + ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype)) + if not context.executing_eagerly(): + ret.set_shape(x.get_shape()) + return ret + + +# TODO(b/195413777): error out for stateful dropout. diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index b26e2260f3cb3a..0a7d9f90345e30 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -70,6 +70,32 @@ pytype_strict_library( ], ) +py_strict_test( + name = "api_test", + srcs = [ + "api_test.py", + ], + python_version = "PY3", + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_random", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + # TODO(b/301286466): Investigate why python annotation type mismatch is not catptured by the type # strict BUILD rules. 
@@ -89,6 +115,40 @@ dtensor_test( ], ) +dtensor_test( + name = "batchparallel_spmd_test", + srcs = ["batchparallel_spmd_test.py"], + additional_backends = [TPU_V4_DONUT_BACKEND], + main = "batchparallel_spmd_test.py", + shard_count = { + "cpu": 4, + "gpu": 4, + "tpu": 4, + TPU_V4_DONUT_BACKEND: 8, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + "//tensorflow/python/ops:image_ops_gen", + "//tensorflow/python/ops:linalg_ops_gen", + "//tensorflow/python/ops:nn_impl", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + dtensor_test( name = "cache_test", srcs = ["cache_test.py"], @@ -187,6 +247,35 @@ dtensor_test( ], ) +dtensor_test( + name = "conv_test", + srcs = [ + "conv_test.py", + ], + additional_backends = [TPU_V3_DONUT_BACKEND], + # All tests require 8 TPUs. + disable = ["tpu"], + shard_count = { + "cpu": 4, + "gpu": 4, + TPU_V3_DONUT_BACKEND: 4, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:special_math_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + dtensor_test( name = "device_test", srcs = ["device_test.py"], @@ -249,6 +338,52 @@ py_strict_test( ], ) +py_strict_test( + name = "multi_client_input_util_test", + timeout = "long", + srcs = ["multi_client_input_util_test.py"], + env = { + "TF2_BEHAVIOR": "1", + }, + shard_count = 8, + tags = [ + # ThreadSanitizer does not support starting new threads after multi-threaded fork. + "notsan", + "no_oss", # Fails on OSS. 
+ "nosan", # b/195537906 + ], + deps = [ + ":multi_client_test_util", + ":test_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/dtensor/python:accelerator_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:config", + "//tensorflow/dtensor/python:input_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:mesh_util", + "//tensorflow/python/data/experimental/service:server_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:device_spec", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/lib/io:tf_record", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:check_ops", + "//tensorflow/python/ops:io_ops", + "//tensorflow/python/ops:parsing_config", + "//tensorflow/python/ops:parsing_ops", + "//tensorflow/python/ops:parsing_ops_gen", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/logging", + "@absl_py//absl/testing:parameterized", + ], +) + dtensor_test( name = "layout_test", srcs = ["layout_test.py"], @@ -807,3 +942,97 @@ dtensor_test( "//third_party/py/numpy", ], ) + +dtensor_test( + name = "mnist_test", + size = "large", + srcs = ["mnist_test.py"], + shard_count = { + "tpu": 2, + }, + tags = ["nosan"], # Non-opt builds has slow XLA compilation. + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_variable", + "//tensorflow/dtensor/python:input_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "numerics_test", + srcs = ["numerics_test.py"], + additional_backends = [TPU_V3_DONUT_BACKEND], + disable = ALL_BACKENDS, + enable = [ + "tpu", + ], + deps = [ + ":test_util", + "//tensorflow/dtensor/python:accelerator_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "sparse_test", + srcs = ["sparse_test.py"], + main = "sparse_test.py", + shard_count = { + "cpu": 4, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "tpu_device_assignment_test", + srcs = 
["tpu_device_assignment_test.py"], + disable = ALL_BACKENDS, + enable = [ + "tpu", + ], + deps = [ + ":test_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/dtensor/python:tpu_util", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + ], +) diff --git a/tensorflow/dtensor/python/tests/api_test.py b/tensorflow/dtensor/python/tests/api_test.py new file mode 100644 index 00000000000000..7231086651439f --- /dev/null +++ b/tensorflow/dtensor/python/tests/api_test.py @@ -0,0 +1,305 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for the internal DTensor Python API.""" + +from absl.testing import parameterized +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_random +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh +_MESH_DIM_X = 'x' +_MESH_DIM_Y = 'y' + + +class APITest(test_util.DTensorBaseTest): + + def setUp(self): + super(APITest, self).setUp() + global_ids = test_util.create_device_ids_array((2, 2)) + local_device_ids = np.ravel(global_ids).tolist() + mesh_dict = { + 'CPU': Mesh( + [_MESH_DIM_X, _MESH_DIM_Y], + global_ids, + local_device_ids, + test_util.create_device_list((2, 2), 'CPU'), + ) + } + self.mesh = self.configTestMesh(mesh_dict) + self.layouts_1d = [ + Layout.replicated(self.mesh, rank=1), + Layout.batch_sharded(self.mesh, _MESH_DIM_X, rank=1), + Layout.batch_sharded(self.mesh, _MESH_DIM_Y, rank=1), + ] + self.layouts_2d = [ + Layout.replicated(self.mesh, rank=2), + Layout.batch_sharded(self.mesh, _MESH_DIM_X, rank=2), + Layout.inner_sharded(self.mesh, _MESH_DIM_X, rank=2), + Layout([_MESH_DIM_X, _MESH_DIM_Y], self.mesh), + ] + + def testV2API(self): + layout = Layout.replicated(self.mesh, rank=1) + zero_tensor = array_ops.zeros([10], layout=layout) + zero_like_tensor = array_ops.zeros_like_v2(zero_tensor, layout=layout) + self.assertAllEqual(zero_like_tensor.numpy(), zero_tensor.numpy()) + + ones_tensor = array_ops.ones([10], layout=layout) + ones_like_tensor = array_ops.ones_like_v2(zero_tensor, layout=layout) + 
self.assertAllEqual(ones_like_tensor.numpy(), ones_tensor.numpy()) + + def testStatelessRandom(self): + # test dtype default float32 random + result = stateless_random_ops.stateless_random_uniform( + [10], + seed=constant_op.constant([0, 0], dtype=dtypes.int64), + minval=0.0, + maxval=10.0, + ) + self.assertEqual([10], result.shape) + + # test dtype default int32 minval maxval are both None + result = stateless_random_ops.stateless_random_uniform( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int64), + dtype=dtypes.int32, + minval=None, + maxval=None, + ) + self.assertEqual([10], result.shape) + + # test maxval is None or not given + result = stateless_random_ops.stateless_random_uniform( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int64), + maxval=12, + dtype=dtypes.int32, + ) + self.assertEqual([10], result.shape) + self.assertAllInRange(result, 0, 12) + + def testStatelessRandomNormal(self): + # test dtype default float32 random + result = stateless_random_ops.stateless_random_normal( + [10], seed=constant_op.constant([0, 0], dtype=dtypes.int32) + ) + self.assertEqual([10], result.shape) + + # test dtype double + result = stateless_random_ops.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + dtype=dtypes.double, + ) + self.assertEqual([10], result.shape) + + # test mean and stddev + result = stateless_random_ops.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + mean=0, + stddev=0, + ) + self.assertEqual([10], result.shape) + self.assertAllInRange(result, 0, 0) + + # test dtensor version of each, check layouts + layout = Layout.replicated(self.mesh, rank=1) + + # test dtype default float 32 random + result = d_random.stateless_random_normal( + [10], + seed=constant_op.constant([0, 0], dtype=dtypes.int32), + layout=layout, + ) + self.assertEqual([10], result.shape) + self.assertEqual(layout, api.fetch_layout(result)) + + # test dtype double + result = d_random.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + dtype=dtypes.double, + layout=layout, + ) + self.assertEqual([10], result.shape) + self.assertEqual(layout, api.fetch_layout(result)) + + # test mean and stddev + result = d_random.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + mean=0, + stddev=0, + layout=layout, + ) + self.assertEqual([10], result.shape) + self.assertAllInRange(result, 0, 0) + self.assertEqual(layout, api.fetch_layout(result)) + + @parameterized.named_parameters(*set( + test_util.product((('_labels_unsharded', 0), ('_labels_batch', 1), + ('_labels_inner', 2), ('_labels_both', 3)), + (('_logits_unsharded', 0), ('_logits_batch', 1), + ('_logits_inner', 2), ('_logits_both', 3))))) + def testSoftmaxCrossentropyWithLogits(self, labels_layout, logits_layout): + expected_layout = Layout.replicated(self.mesh, rank=1) + if (labels_layout == 1 or labels_layout == 3 or logits_layout == 1 or + logits_layout == 3): + expected_layout = Layout.inner_sharded(self.mesh, _MESH_DIM_X, rank=1) + + labels_layout = self.layouts_2d[labels_layout] + logits_layout = self.layouts_2d[logits_layout] + labels_numpy = np.random.uniform(size=[6, 4]) + logits_numpy = np.random.uniform(size=[6, 4]) + labels = constant_op.constant(labels_numpy, dtype=dtypes.float32) + logits = constant_op.constant(logits_numpy, dtype=dtypes.float32) + + # Should we test against the built in version or the patched version? 
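+    # The eager result on plain tensors serves as the reference; the same
+    # computation on the packed DTensor inputs below is compared against it.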
+ expected = nn_ops.softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + + labels = numpy_util.pack_numpy(labels, labels_layout) + logits = numpy_util.pack_numpy(logits, logits_layout) + dtensor_result = nn_ops.softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + self.assertDTensorEqual(expected, expected_layout, dtensor_result) + + @parameterized.named_parameters(*set( + test_util.product((('_labels_unsharded', 0), ('_labels_batch_x', 1), + ('_labels_batch_y', 2)), + (('_logits_unsharded', 0), ('_logits_batch', 1), + ('_logits_inner', 2), ('_logits_both', 3))))) + def testSparseSoftmaxCrossentropyWithLogits(self, labels_layout, + logits_layout): + expected_layout = Layout.replicated(self.mesh, rank=1) + if labels_layout == 1 or logits_layout == 1 or logits_layout == 3: + expected_layout = Layout.inner_sharded(self.mesh, _MESH_DIM_X, rank=1) + elif labels_layout == 2: + expected_layout = Layout.inner_sharded(self.mesh, _MESH_DIM_Y, rank=1) + + labels_layout = self.layouts_1d[labels_layout] + logits_layout = self.layouts_2d[logits_layout] + labels_numpy = np.random.randint(size=[6], low=0, high=4) + logits_numpy = np.random.uniform(size=[6, 4]) + labels = constant_op.constant(labels_numpy, dtype=dtypes.int64) + logits = constant_op.constant(logits_numpy, dtype=dtypes.float32) + + # Should we test against the built in version or the patched version? + expected = nn_ops.sparse_softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + + labels = numpy_util.pack_numpy(labels, labels_layout) + logits = numpy_util.pack_numpy(logits, logits_layout) + dtensor_result = nn_ops.sparse_softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + self.assertDTensorEqual(expected, expected_layout, dtensor_result) + + def test_dropout_raises_on_none_seed(self): + with api.default_mesh(self.mesh): + with self.assertRaisesRegex(ValueError, 'seed must be specified'): + _ = d_random.dropout( + array_ops.ones([2, 2], dtype=dtypes.float32), rate=0.5, seed=None + ) + + def test_default_mesh(self): + + @polymorphic_function.function + def func(a): + return a + 3.0 + + with api.default_mesh(self.mesh): + a = array_ops.zeros(shape=()) + result = func(a) + + self.assertEqual(result, 3.0) + self.assertEqual(api.fetch_layout(result).mesh, self.mesh) + self.assertTrue(api.fetch_layout(result).is_fully_replicated()) + self.assertEqual(result.device, api.device_name()) + + # Also make sure it works as wrapper + @api.default_mesh(self.mesh) + def func2(): + b = array_ops.ones(shape=()) + return func(b) + + result = func2() + self.assertEqual(result, 4.0) + self.assertEqual(api.fetch_layout(result).mesh, self.mesh) + self.assertTrue(api.fetch_layout(result).is_fully_replicated()) + self.assertEqual(result.device, api.device_name()) + + with self.assertRaisesRegex(ValueError, 'Expect `mesh` to be `Mesh`'): + with api.default_mesh(None): + pass + + def test_default_mesh_with_constant(self): + + @polymorphic_function.function + def func(): + return constant_op.constant([3, 4]) + + with api.default_mesh(self.mesh): + result = func() + + self.assertAllEqual(result, [3, 4]) + self.assertEqual(api.fetch_layout(result).mesh, self.mesh) + self.assertTrue(api.fetch_layout(result).is_fully_replicated()) + self.assertEqual(result.device, api.device_name()) + + def test_error_no_default_mesh(self): + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, + 'No default mesh has been registered to DTensor', + ): + with ops.device_v2(api.device_name()): + 
_ = constant_op.constant(3.0) + + def test_get_default_mesh(self): + self.assertIsNone(api.get_default_mesh()) + with api.default_mesh(self.mesh): + self.assertEqual(api.get_default_mesh(), self.mesh) + + with api.default_mesh(self.mesh.host_mesh()): + self.assertEqual(api.get_default_mesh(), self.mesh.host_mesh()) + + self.assertEqual(api.get_default_mesh(), self.mesh) + + self.assertIsNone(api.get_default_mesh()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/batchparallel_spmd_test.py b/tensorflow/dtensor/python/tests/batchparallel_spmd_test.py new file mode 100644 index 00000000000000..b6cbbd0459b8e6 --- /dev/null +++ b/tensorflow/dtensor/python/tests/batchparallel_spmd_test.py @@ -0,0 +1,660 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for batchparallel_spmd.""" + +import itertools +from absl.testing import parameterized +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.dtensor.python.tests import test_util_ops +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test +# pylint: enable=g-direct-tensorflow-import + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +class DTensorBatchParallelSPMDTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorBatchParallelSPMDTest, self).setUp() + + self.skipForDeviceType(['TPU'], + 'all tests require 8 TPU cores.', + unless_device_count_equals_to=8) + # Builds a 8x2 mesh. + self._mesh_dim_b = 'b' + self._mesh_dim_x = 'x' + self._dims = [self._mesh_dim_b, self._mesh_dim_x] + + global_ids = test_util.create_device_ids_array((4, 2)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = { + device: Mesh( + self._dims, + global_ids, + local_ids, + test_util.create_device_list((4, 2), device), + ) + for device in ('CPU', 'GPU', 'TPU') + } + self.mesh = self.configTestMesh(mesh_dict) + context.ensure_initialized() + + # Creates a bunch of common layouts used by tests later. 
+ # 4-d + self.replicated_layout_4d = Layout.replicated(self.mesh, rank=4) + self.batch_layout_4d = Layout.batch_sharded( + self.mesh, self._mesh_dim_b, rank=4) + + # 5-d + self.replicated_layout_5d = Layout.replicated(self.mesh, rank=5) + self.batch_layout_5d = Layout.batch_sharded( + self.mesh, self._mesh_dim_b, rank=5) + + @parameterized.named_parameters(('NoBatchDim', 0), ('SingleBatchDim', 1), + ('TwoBatchDim', 2)) + def testCholesky(self, num_batch_dim): + # Input needs to be symmetric and positive definite. + x = constant_op.constant( + [[1, 1, 1, 1], [1, 5, 5, 5], [1, 5, 14, 14], [1, 5, 14, 17]], + dtype=dtypes.float32, + ) + for _ in range(num_batch_dim): + x = array_ops.expand_dims_v2(x, 0) + s = [4] + [1 for _ in range(array_ops.rank(x) - 1)] + x = gen_array_ops.tile(x, s) + + expected_result = gen_linalg_ops.cholesky(x) + + if num_batch_dim == 0: + layout_spec = [] + elif num_batch_dim == 1: + layout_spec = [self._mesh_dim_b] + elif num_batch_dim == 2: + layout_spec = [self._mesh_dim_b, self._mesh_dim_x] + layout = Layout(layout_spec + ['unsharded'] * 2, self.mesh) + + x = numpy_util.pack_numpy(x, layout) + got = gen_linalg_ops.cholesky(input=x) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('NoBatchDim', 0), ('SingleBatchDim', 1), ('TwoBatchDim', 2)], + test_util_ops.FFT_OPS, + ) + ) + def testFFT(self, num_batch_dim, fft_op, num_nonbatch_dim): + shape = [4 for i in range(num_batch_dim + num_nonbatch_dim)] + np.random.seed(123) + x = constant_op.constant( + np.random.normal(0.0, 1.0, np.prod(shape)).reshape(shape), + dtype=dtypes.complex64, + ) + expected_result = fft_op(input=x) + + if num_batch_dim == 0: + layout_spec = [] + elif num_batch_dim == 1: + layout_spec = [self._mesh_dim_b] + elif num_batch_dim == 2: + layout_spec = [self._mesh_dim_b, self._mesh_dim_x] + layout = Layout(layout_spec + ['unsharded'] * num_nonbatch_dim, self.mesh) + + x = numpy_util.pack_numpy(x, layout) + got = fft_op(input=x) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('NoBatchDim', 0), ('SingleBatchDim', 1), ('TwoBatchDim', 2)], + test_util_ops.RFFT_OPS, + ) + ) + def testRFFT(self, num_batch_dim, rfft_op, num_nonbatch_dim, dtype): + self.skipForDeviceType(['GPU'], 'RFFT has numerical issues on GPU') + shape = [4 for i in range(num_batch_dim + num_nonbatch_dim)] + np.random.seed(123) + x = constant_op.constant( + np.random.normal(0.0, 1.0, np.prod(shape)).reshape(shape), dtype=dtype + ) + expected_result = rfft_op(input=x, fft_length=[2] * num_nonbatch_dim) + + if num_batch_dim == 0: + layout_spec = [] + elif num_batch_dim == 1: + layout_spec = [self._mesh_dim_b] + elif num_batch_dim == 2: + layout_spec = [self._mesh_dim_b, self._mesh_dim_x] + layout = Layout(layout_spec + ['unsharded'] * num_nonbatch_dim, self.mesh) + + x = numpy_util.pack_numpy(x, layout) + got = rfft_op(input=x, fft_length=[2] * num_nonbatch_dim) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('Replicated', 'replicated'), ('Sharded', 'batch')], + [ + ( + 'SamePadding', + 'SAME', + ), + ( + 'ValidPadding', + 'VALID', + ), + ], + test_util_ops.BATCH_PARALLEL_2D_WINDOW_OPS, + ) + ) + def test2DWindowOp(self, layout_spec, padding, op): + np.random.seed(123) + row_window_size = 3 + col_window_size = 4 + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size - 1, col_window_size - 1, 1] + 
+ num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_rows * num_cols * 3).reshape( + [8, num_rows, num_cols, 3]) + + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + expected_result = op(inputs, window_size, stride_size, padding) + + if layout_spec == 'replicated': + layout = self.replicated_layout_4d + else: + layout = self.batch_layout_4d + + x = numpy_util.pack_numpy(inputs, layout) + got = op(x, window_size, stride_size, padding) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('Replicated', 'replicated'), ('BatchSharded', 'batch')], + [ + ( + 'SamePadding', + 'SAME', + ), + ( + 'ValidPadding', + 'VALID', + ), + ], + test_util_ops.BATCH_PARALLEL_3D_WINDOW_OPS, + ) + ) + def test3DWindowOp(self, layout_spec, padding, op): + np.random.seed(123) + dep_window_size = 2 + row_window_size = 3 + col_window_size = 4 + window_size = [1, dep_window_size, row_window_size, col_window_size, 1] + stride_size = [ + 1, dep_window_size - 1, row_window_size - 1, col_window_size - 1, 1 + ] + + num_deps = 3 + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_deps * num_rows * num_cols * + 3).reshape([8, num_deps, num_rows, num_cols, 3]) + + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + expected_result = op(inputs, window_size, stride_size, padding) + + if layout_spec == 'replicated': + layout = self.replicated_layout_5d + else: + layout = self.batch_layout_5d + + x = numpy_util.pack_numpy(inputs, layout) + + got = op(x, window_size, stride_size, padding) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(test_util_ops.PADDINGS) + def testDepthwiseConv2dNative(self, padding): + np.random.seed(123) + x_in = np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]) + + kernel_in = np.array([ + [[[2, 0.1]], [[3, 0.2]]], + [[[0, 0.3]], [[1, 0.4]]], + ]) + + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + kernel = constant_op.constant(kernel_in, dtype=dtypes.float32) + expected_result = nn_impl.depthwise_conv2d_v2( + inputs, kernel, strides=[1, 1, 1, 1], padding=padding + ) + + layout = self.batch_layout_4d + + x = numpy_util.pack_numpy(inputs, layout) + kernel = numpy_util.pack_numpy(kernel, self.replicated_layout_4d) + got = nn_impl.depthwise_conv2d_v2( + x, kernel, strides=[1, 1, 1, 1], padding=padding + ) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testResizeBilinear(self, shard_spec): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.resize_bilinear( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.resize_bilinear( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testResizeNearestNeighbor(self, shard_spec): + np.random.seed(123) + 
images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.resize_nearest_neighbor( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.resize_nearest_neighbor( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testAdjustContrastv2(self, shard_spec): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9 * 3).reshape([8, 9, 9, 3]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.adjust_contrastv2( + images=images, contrast_factor=0.5 + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.adjust_contrastv2(images=images, contrast_factor=0.5) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testAdjustSaturation(self, shard_spec): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9 * 3).reshape([8, 9, 9, 3]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.adjust_saturation(images=images, scale=0.5) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.adjust_saturation(images=images, scale=0.5) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.parameters( + itertools.permutations(['sharded', 'replicated'], 2)) + def testResizeBilinearGradBatchSharded(self, spec1, spec2): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + grads = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + expected_result = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + specs = [spec1, spec2] + layouts = [ + self.batch_layout_4d if spec == 'sharded' else self.replicated_layout_4d + for spec in specs + ] + + # Test images is replicated, grads is batch sharded + images = numpy_util.pack_numpy(images, layouts[0]) + grads = numpy_util.pack_numpy(grads, layouts[1]) + + got = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + name=None, + ) + self.assertDTensorEqual(expected_result, self.batch_layout_4d, got) + + def testResizeBilinearGradReplicated(self): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + grads = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + expected_result = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + 
name=None, + ) + + images = numpy_util.pack_numpy(images, self.replicated_layout_4d) + grads = numpy_util.pack_numpy(grads, self.replicated_layout_4d) + + got = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + name=None, + ) + self.assertDTensorEqual(expected_result, self.replicated_layout_4d, got) + + @parameterized.named_parameters( + test_util.product([('Replicated', 'replicated'), ('Sharded', 'batch')], [( + 'SamePadding', + 'SAME', + ), ( + 'ValidPadding', + 'VALID', + )])) + def testMaxPool3DGrad(self, shard_spec, padding): + np.random.seed(123) + dep_window_size = 2 + row_window_size = 3 + col_window_size = 4 + window_size = [1, dep_window_size, row_window_size, col_window_size, 1] + stride_size = [ + 1, dep_window_size - 1, row_window_size - 1, col_window_size - 1, 1 + ] + + num_deps = 3 + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_deps * num_rows * num_cols * + 3).reshape([8, num_deps, num_rows, num_cols, 3]) + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + + with backprop.GradientTape() as tape: + tape.watch([inputs]) + expected_result = nn_ops.max_pool3d( + inputs, window_size, stride_size, padding + ) + expected_grad = tape.gradient(expected_result, [inputs]) + layout = ( + self.batch_layout_5d + if shard_spec == 'sharded' + else self.replicated_layout_5d + ) + + inputs = numpy_util.pack_numpy(inputs, layout) + + with ops.device_v2(api.device_name()): + with backprop.GradientTape() as tape: + tape.watch([inputs]) + dtensor_result = nn_ops.max_pool3d( + inputs, window_size, stride_size, padding + ) + dtensor_grad = tape.gradient(dtensor_result, [inputs]) + + self.assertDTensorEqual(expected_grad[0], layout, dtensor_grad[0]) + + @parameterized.named_parameters( + test_util.product([('Replicated', 'replicated'), ('Sharded', 'batch')], [( + 'SamePadding', + 'SAME', + ), ( + 'ValidPadding', + 'VALID', + )])) + def testMaxPool3DGradGrad(self, shard_spec, padding): + np.random.seed(123) + dep_window_size = 2 + row_window_size = 3 + col_window_size = 4 + window_size = [1, dep_window_size, row_window_size, col_window_size, 1] + stride_size = [ + 1, dep_window_size - 1, row_window_size - 1, col_window_size - 1, 1 + ] + + num_deps = 3 + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_deps * num_rows * num_cols * + 3).reshape([8, num_deps, num_rows, num_cols, 3]) + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + + with backprop.GradientTape() as outer_tape: + with backprop.GradientTape() as inner_tape: + outer_tape.watch([inputs]) + inner_tape.watch([inputs]) + expected_result = nn_ops.max_pool3d( + inputs, window_size, stride_size, padding + ) + expected_first_grad = inner_tape.gradient(expected_result, [inputs]) + expected_second_grad = outer_tape.gradient(expected_first_grad, [inputs]) + + if shard_spec == 'sharded': + layout = self.batch_layout_5d + else: + layout = self.replicated_layout_5d + + inputs = numpy_util.pack_numpy(inputs, layout) + + @polymorphic_function.function() + def compute_gradients(inputs): + with backprop.GradientTape() as outer_tape: + with backprop.GradientTape() as inner_tape: + outer_tape.watch([inputs]) + inner_tape.watch([inputs]) + dtensor_result = nn_ops.max_pool3d( + inputs, window_size, stride_size, padding + ) + dtensor_first_grad = inner_tape.gradient(dtensor_result, [inputs]) + 
dtensor_second_grad = outer_tape.gradient(dtensor_first_grad[0], [inputs]) + return dtensor_first_grad, dtensor_second_grad + + dtensor_first_grad, dtensor_second_grad = compute_gradients(inputs) + + self.assertDTensorEqual(expected_first_grad[0], layout, + dtensor_first_grad[0]) + self.assertDTensorEqual(expected_second_grad[0], layout, + dtensor_second_grad[0]) + + @parameterized.named_parameters( + test_util.product([('Replicated', 'replicated'), ('Sharded', 'batch')], [( + 'SamePadding', + 'SAME', + ), ( + 'ValidPadding', + 'VALID', + )])) + def testMaxPoolGradGrad(self, shard_spec, padding): + np.random.seed(123) + row_window_size = 3 + col_window_size = 4 + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size - 1, col_window_size - 1, 1] + + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_rows * num_cols * 3).reshape( + [8, num_rows, num_cols, 3]) + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + + with backprop.GradientTape() as outer_tape: + with backprop.GradientTape() as inner_tape: + outer_tape.watch([inputs]) + inner_tape.watch([inputs]) + expected_result = nn_ops.max_pool_v2( + inputs, window_size, stride_size, padding + ) + expected_first_grad = inner_tape.gradient(expected_result, [inputs]) + expected_second_grad = outer_tape.gradient(expected_first_grad, [inputs]) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + inputs = numpy_util.pack_numpy(inputs, layout) + + @polymorphic_function.function() + def compute_gradients(inputs): + with backprop.GradientTape() as outer_tape: + with backprop.GradientTape() as inner_tape: + outer_tape.watch([inputs]) + inner_tape.watch([inputs]) + dtensor_result = nn_ops.max_pool_v2( + inputs, window_size, stride_size, padding + ) + dtensor_first_grad = inner_tape.gradient(dtensor_result, [inputs]) + dtensor_second_grad = outer_tape.gradient(dtensor_first_grad[0], [inputs]) + return dtensor_first_grad, dtensor_second_grad + + dtensor_first_grad, dtensor_second_grad = compute_gradients(inputs) + + self.assertDTensorEqual(expected_first_grad[0], layout, + dtensor_first_grad[0]) + self.assertDTensorEqual(expected_second_grad[0], layout, + dtensor_second_grad[0]) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testResizeNearestNeighborGrad(self, shard_spec): + np.random.seed(123) + grads = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + expected_result = gen_image_ops.resize_nearest_neighbor_grad( + grads=grads, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + + grads = numpy_util.pack_numpy(grads, layout) + + got = gen_image_ops.resize_nearest_neighbor_grad( + grads=grads, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + self.assertDTensorEqual(expected_result, layout, got) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/conv_test.py b/tensorflow/dtensor/python/tests/conv_test.py new file mode 100644 index 00000000000000..25cab09e1096ac --- /dev/null +++ b/tensorflow/dtensor/python/tests/conv_test.py @@ -0,0 +1,350 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for executing ops needed to implement image model.""" + +from absl.testing import parameterized +import numpy as np + +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager import backprop +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.platform import test + + +UNSHARDED = layout_lib.UNSHARDED +Mesh = layout_lib.Mesh +Layout = layout_lib.Layout + +BATCH_DIM = 'batch' +DEPTH_DIM = 'depth' +HEIGHT_DIM = 'height' +WIDTH_DIM = 'width' +BATCH_SIZE = 4 +DEPTH = 8 +HEIGHT = 12 +WIDTH = 12 +CHANNEL_IN = 1 +CHANNEL_OUT = 3 + + +class ConvOpTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + global_ids = test_util.create_device_ids_array((2, 2, 2)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = {} + for device in ('CPU', 'GPU', 'TPU'): + mesh_dict[device] = Mesh( + [BATCH_DIM, HEIGHT_DIM, WIDTH_DIM], + global_ids, + local_ids, + test_util.create_device_list((2, 2, 2), device), + ) + + self.mesh = self.configTestMesh(mesh_dict) + + self.replicated_2d = Layout.replicated(self.mesh, 2) + self.batch_sharded_2d = Layout.batch_sharded(self.mesh, BATCH_DIM, 2) + + @parameterized.named_parameters( + test_util.product( + *[ + [ + ( + 'Conv2D', + nn_ops.conv2d_v2, + (BATCH_SIZE, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bhwc,xy->by', + [1, 2, 1, 1], + ), + ( + 'Conv3D', + nn_ops.conv3d_v2, + (BATCH_SIZE, DEPTH, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bdhwc,xy->by', + [1, 1, 2, 1, 1], + ), + ], + [ + ('Eager', True), + ('Graph', False), + ], + [ + ('ReplicatedInput', 'replicated'), + ('BatchShardedInput', 'batch_sharded'), + ], + [ + ('ValidPadding', 'VALID'), + ('SamePadding', 'SAME'), + ], + ] + ) + ) + def testConvFollowedByEinsum(self, conv_op, input_size, kernel_size, + einsum_eq, strides, eager_mode, input_sharding, + padding): + x_in = constant_op.constant( + np.random.random(size=input_size), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_size), dtype=dtypes.float32 + ) + weight = constant_op.constant( + np.random.random(size=(2, 2)), dtype=dtypes.float32 + ) + + def conv_fn(inputs, img_kernel, layer_weights): + output = conv_op(inputs, img_kernel, strides=strides, padding=padding) + output = special_math_ops.einsum(einsum_eq, output, layer_weights) + return output + + if not eager_mode: + conv_fn = polymorphic_function.function(conv_fn) + + golden_result = conv_fn(x_in, kernel_in, weight) + + if input_sharding == 'replicated': + input_layout = 
Layout.replicated(self.mesh, len(input_size)) + output_layout = self.replicated_2d + elif input_sharding == 'batch_sharded': + input_layout = Layout.batch_sharded(self.mesh, BATCH_DIM, len(input_size)) + output_layout = self.batch_sharded_2d + + kernel_layout = Layout.replicated(self.mesh, len(kernel_size)) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + d_weight = numpy_util.pack_numpy(weight, self.replicated_2d) + d_result = conv_fn(d_x_in, d_kernel_in, d_weight) + + self.assertDTensorEqual(golden_result, output_layout, d_result) + + @parameterized.named_parameters( + test_util.product( + *[ + [ + ( + 'Conv2D', + nn_ops.conv2d_v2, + (BATCH_SIZE, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bhwc,xy->by', + [1, 1, 1, 1], + ), + ( + 'Conv3D', + nn_ops.conv3d_v2, + (BATCH_SIZE, DEPTH, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bdhwc,xy->by', + [1, 1, 1, 1, 1], + ), + ], + [ + ('ReplicatedInput', 'replicated'), + ('BatchShardedInput', 'batch_sharded'), + ], + [ + ('ValidPadding', 'VALID'), + ('SamePadding', 'SAME'), + ], + ] + ) + ) + def testConvFollowedByEinsumWithGradient(self, conv_op, input_size, + kernel_size, einsum_eq, strides, + input_sharding, padding): + x_in = constant_op.constant( + np.random.random(size=input_size), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_size), dtype=dtypes.float32 + ) + weight = constant_op.constant( + np.random.random(size=(2, 2)), dtype=dtypes.float32 + ) + + @polymorphic_function.function + def conv_fn(inputs, img_kernel, layer_weights): + with backprop.GradientTape() as tape: + tape.watch([inputs, img_kernel, layer_weights]) + output = conv_op(inputs, img_kernel, strides=strides, padding=padding) + output = special_math_ops.einsum(einsum_eq, output, layer_weights) + + inputs_grad, kernel_grad, weight_grad = tape.gradient( + output, [inputs, img_kernel, layer_weights]) + return output, inputs_grad, kernel_grad, weight_grad + + result, inputs_grad, kernel_grad, weight_grad = conv_fn( + x_in, kernel_in, weight) + + if input_sharding == 'replicated': + input_layout = Layout.replicated(self.mesh, len(input_size)) + output_layout = self.replicated_2d + elif input_sharding == 'batch_sharded': + input_layout = Layout.batch_sharded(self.mesh, BATCH_DIM, len(input_size)) + output_layout = self.batch_sharded_2d + + kernel_layout = Layout.replicated(self.mesh, len(kernel_size)) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + d_weight = numpy_util.pack_numpy(weight, self.replicated_2d) + d_result, d_inputs_grad, d_kernel_grad, d_weight_grad = conv_fn( + d_x_in, d_kernel_in, d_weight) + + self.assertDTensorEqual(result, output_layout, d_result) + # TODO(b/208700444): layout of input grads should match layout of input. 
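+    # Until that is fixed, the SPMD-expanded input gradient currently comes
+    # back fully replicated, so the check below compares against a replicated
+    # layout rather than `input_layout`.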
+ self.assertDTensorEqual( + inputs_grad, + Layout.replicated(self.mesh, len(input_size)), + d_inputs_grad, + ) + self.assertDTensorEqual(kernel_grad, kernel_layout, d_kernel_grad) + self.assertDTensorEqual(weight_grad, self.replicated_2d, d_weight_grad) + + +SPATIALLY_PARTITIONED_CONV_TEST_CASES = [ + [ + ('Case1', (BATCH_SIZE, 8, 16, CHANNEL_IN), (3, 5, CHANNEL_IN, + CHANNEL_OUT)), + ('Case2', (BATCH_SIZE, 8, 128, CHANNEL_IN), (3, 9, CHANNEL_IN, + CHANNEL_OUT)), + ], + [ + ('ValidPadding', 'VALID'), + ('SamePadding', 'SAME'), + ], + [ + ('Batch_1d_2x4', [BATCH_DIM, UNSHARDED, WIDTH_DIM, UNSHARDED], (2, 4)), + ('2d_2x4', [UNSHARDED, HEIGHT_DIM, WIDTH_DIM, UNSHARDED], (2, 4)), + ('Batch_2d_2x2x2', [BATCH_DIM, HEIGHT_DIM, WIDTH_DIM, + UNSHARDED], (2, 2, 2)), + ], +] + + +class SpatiallyPartitionedConvOpTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + # TODO(b/261485237): Enable CPU testing once CollectivePermute is supported + # on CPU's. + if not test_util.is_tpu_present(): + self.skipTest('This test only runs on TPUs.') + + def _create_mesh(self, mesh_dims, topology): + global_ids = test_util.create_device_ids_array(topology) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = {} + for device in ('CPU', 'GPU', 'TPU'): + mesh_dict[device] = Mesh( + mesh_dims, + global_ids, + local_ids, + test_util.create_device_list(topology, device), + ) + + return self.configTestMesh(mesh_dict) + + @parameterized.named_parameters( + test_util.product(*SPATIALLY_PARTITIONED_CONV_TEST_CASES)) + def testConv(self, input_shape, kernel_shape, padding, sharding_specs, + topology): + mesh_dims = [spec for spec in sharding_specs if spec != UNSHARDED] + mesh = self._create_mesh(mesh_dims, topology) + + x_in = constant_op.constant( + np.random.random(size=input_shape), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_shape), dtype=dtypes.float32 + ) + + expected_output = nn_ops.conv2d_v2( + x_in, kernel_in, strides=[1, 1, 1, 1], padding=padding + ) + + input_layout = Layout(sharding_specs, mesh) + kernel_layout = Layout.replicated(mesh, 4) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + d_output = nn_ops.conv2d_v2( + d_x_in, d_kernel_in, strides=[1, 1, 1, 1], padding=padding + ) + + self.assertDTensorEqual(expected_output, input_layout, d_output) + + @parameterized.named_parameters( + test_util.product(*SPATIALLY_PARTITIONED_CONV_TEST_CASES)) + def testConvWithGradient(self, input_shape, kernel_shape, padding, + sharding_specs, topology): + # TODO(b/208700444): add support for SPMD expansion of spatially partitioned + # conv backprop. 
+ self.skipTest( + 'b/208700444: Spatially partitioned conv backprop not implemented.') + + mesh_dims = [spec for spec in sharding_specs if spec != UNSHARDED] + mesh = self._create_mesh(mesh_dims, topology) + + x_in = constant_op.constant( + np.random.random(size=input_shape), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_shape), dtype=dtypes.float32 + ) + + @polymorphic_function.function + def conv_fn(inputs, img_kernel, padding): + with backprop.GradientTape() as tape: + tape.watch([inputs, img_kernel]) + output = nn_ops.conv2d_v2( + inputs, img_kernel, strides=[1, 1, 1, 1], padding=padding + ) + inputs_grad, kernel_grad = tape.gradient(output, [inputs, img_kernel]) + return output, inputs_grad, kernel_grad + + expected_output, expected_inputs_grad, expected_kernel_grad = conv_fn( + x_in, kernel_in, padding) + + input_layout = Layout(sharding_specs, mesh) + kernel_layout = Layout.replicated(mesh, 4) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + + d_output, d_inputs_grad, d_kernel_grad = conv_fn(d_x_in, d_kernel_in, + padding) + + self.assertDTensorEqual(expected_output, input_layout, d_output) + self.assertDTensorEqual(expected_inputs_grad, input_layout, d_inputs_grad) + self.assertDTensorEqual(expected_kernel_grad, kernel_layout, d_kernel_grad) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/mnist_test.py b/tensorflow/dtensor/python/tests/mnist_test.py new file mode 100644 index 00000000000000..5fd08f18414ef0 --- /dev/null +++ b/tensorflow/dtensor/python/tests/mnist_test.py @@ -0,0 +1,197 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""DTensor MNIST test.""" + +from absl.testing import parameterized + +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_variable +from tensorflow.dtensor.python import input_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +_BATCH_DIM = 'batch' +_DEVICE_IDS = test_util.create_device_ids_array((2,)) +_ONE_D_MESH = layout_lib.Mesh( + [_BATCH_DIM], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'CPU'), +) +_ONE_D_TPU_MESH = layout_lib.Mesh( + [_BATCH_DIM], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'TPU'), +) +_BATCH_SIZE = 1024 +_STEPS = 5 +_LR = 1e-3 +_ATOL = 1 # absolute error becomes large as gradients approach zero. +_RTOL = 1e-3 +Layout = layout_lib.Layout + + +def mnist_fake_dataset(): + imgs = [] + labels = [] + for i in range(_STEPS * _BATCH_SIZE): + img = stateless_random_ops.stateless_random_uniform( + shape=(28, 28, 1), + seed=[1, i], + minval=0, + maxval=256, + dtype=dtypes.float32, + ) + imgs.append(img) + label = stateless_random_ops.stateless_random_uniform( + shape=(1,), seed=[2, i], minval=0, maxval=10, dtype=dtypes.int64 + ) + labels.append(label) + + return dataset_ops.DatasetV2.from_tensor_slices( + (array_ops_stack.stack(imgs), array_ops_stack.stack(labels)) + ) + + +def _run_step(inputs, w, b, k): + with backprop.GradientTape() as g: + g.watch([w, b]) + logits = nn_ops.conv2d_v2(inputs, k, strides=[1, 1, 1, 1], padding='SAME') + logits = array_ops.reshape(logits, [logits.shape[0], -1]) + logits = math_ops.matmul(logits, w) + logits = logits + b + loss = math_ops.reduce_sum(logits, axis=[0, 1]) + gw, gb = g.gradient(loss, [w, b]) + for v, v_grad in zip([w, b], [gw, gb]): + v.assign_sub(_LR * v_grad) + return gw, gb, loss + + +class DTensorMNISTTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorMNISTTest, self).setUp() + + global_ids = test_util.create_device_ids_array((2,)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = { + device: layout_lib.Mesh( + [_BATCH_DIM], + global_ids, + local_ids, + test_util.create_device_list((2,), device), + ) + for device in ['TPU', 'GPU', 'CPU'] + } + self.mesh = self.configTestMesh(mesh_dict) + + def init_var(self, mesh): + # Initialize TF randon normal variables(without using DTensor). + w_initializer = stateless_random_ops.stateless_random_normal( + shape=[28 * 28, 10], seed=[0, 1] + ) + b_initializer = stateless_random_ops.stateless_random_normal( + shape=[10], seed=[1, 2] + ) + # A filter with 3x3 shape, 1 input channel and 1 output channel. 
+ k_initializer = stateless_random_ops.stateless_random_normal( + [3, 3, 1, 1], seed=[2, 3] + ) + + n_w = variables.Variable(w_initializer) + n_b = variables.Variable(b_initializer) + n_k = variables.Variable(k_initializer) + + # Initialize DTensor variables. + w_initializer_on_mesh = api.copy_to_mesh( + w_initializer, Layout.replicated(mesh, 2) + ) + b_initializer_on_mesh = api.copy_to_mesh( + b_initializer, Layout.replicated(mesh, rank=1) + ) + k_initializer_on_mesh = api.copy_to_mesh( + k_initializer, Layout.replicated(mesh, rank=4) + ) + + w = d_variable.DVariable(w_initializer_on_mesh) + b = d_variable.DVariable(b_initializer_on_mesh) + k = d_variable.DVariable(k_initializer_on_mesh) + + return (n_w, n_b, n_k), (w, b, k) + + @parameterized.named_parameters(('Eager', False), ('Function', True)) + def testMnist(self, on_function): + mnist_dataset = mnist_fake_dataset() + + (n_w, n_b, n_k), (w, b, k) = self.init_var(self.mesh) + + n_dataset = mnist_dataset.batch(_BATCH_SIZE, drop_remainder=True) + n_iter = iter(n_dataset) + + input_layout = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=4) + label_layout = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + dtensor_dataset = input_util.DTensorDataset( + dataset=mnist_dataset, + global_batch_size=_BATCH_SIZE, + mesh=self.mesh, + layouts=(input_layout, label_layout), + batch_dim=_BATCH_DIM, + ) + dtensor_iter = iter(dtensor_dataset) + + step_fn = ( + polymorphic_function.function(_run_step) if on_function else _run_step + ) + + # Training loop. + for _ in range(_STEPS): + # Normal run without DTensor. + n_input, _ = next(n_iter) + g_nw, g_nb, n_loss = step_fn(n_input, n_w, n_b, n_k) + + # DTensor Run + dtensor_input, _ = next(dtensor_iter) + with ops.device_v2(api.device_name()): + gw, gb, loss = step_fn(dtensor_input, w, b, k) + + loss_unpack = api.unpack(loss) + self.assertAllEqual(loss_unpack[0], loss_unpack[1]) + + self.assertAllClose(n_loss, loss, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(g_nw, gw, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(g_nb, gb, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(n_w, w, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(n_b, b, atol=_ATOL, rtol=_RTOL) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/multi_client_input_util_test.py b/tensorflow/dtensor/python/tests/multi_client_input_util_test.py new file mode 100644 index 00000000000000..3538cb4e09b4e6 --- /dev/null +++ b/tensorflow/dtensor/python/tests/multi_client_input_util_test.py @@ -0,0 +1,548 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Multi-client tests for input_util.""" + +import os +from typing import Any, List, Mapping, Optional, Tuple + +from absl import logging +from absl.testing import parameterized +import numpy as np + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.dtensor.python import accelerator_util +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import config +from tensorflow.dtensor.python import input_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import mesh_util +from tensorflow.dtensor.python.tests import multi_client_test_util +from tensorflow.dtensor.python.tests import test_backend_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.data.experimental.service import server_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.eager import context +from tensorflow.python.framework import config as tf_config +from tensorflow.python.framework import device_spec +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_config +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + + +mp_context = test_backend_util.get_mp_context() + +# Multi-client test constants. +JOB_NAME = 'worker' +TF_DATA_SERVICE_JOB_NAME = 'dtensor_tf_data' +NUM_CLIENTS = 4 +NUM_DEVICES_PER_CLIENT = 4 + +# Mesh constants. +MESH_DIM_BATCH = 'batch' +MESH_DIM_HEIGHT = 'height' +MESH_DIM_WIDTH = 'width' + +# Data constants. +IMG_HEIGHT = 8 +IMG_WIDTH = 8 +IMG_CHANNELS = 3 + +UNSHARDED = layout_lib.UNSHARDED +Mesh = layout_lib.Mesh +Layout = layout_lib.Layout + + +def redirect_output(file_name): + # Redirect stderr/stdout to undeclared outputs on sponge. + artifact_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', '') + if artifact_dir: + with open(os.path.join(artifact_dir, file_name), 'wb') as fp: + os.dup2(fp.fileno(), 1) + os.dup2(fp.fileno(), 2) + + +def create_dispatcher(test_name, worker_addresses, port, pipe=None): + dispatcher = server_lib.DispatchServer( + config=server_lib.DispatcherConfig( + port=port, protocol='grpc', worker_addresses=worker_addresses + ) + ) + dispatcher.start() + if pipe is None: + # Dispatcher is not within subprocess, so do not block. + return dispatcher, dispatcher._address + else: + redirect_output(f'test-{test_name}-dispatcher.log') + pipe.send(dispatcher._address) + signal = pipe.recv() # blocks until a 'stop' signal is received + if signal == 'stop': + dispatcher._stop() + pipe.send('stopped') + else: + raise ValueError('Got unknown signal %s' % signal) + + +def create_worker(test_name, dispatcher_address, port=None, pipe=None): + worker = server_lib.WorkerServer( + config=server_lib.WorkerConfig( + port=port, dispatcher_address=dispatcher_address, protocol='grpc' + ) + ) + worker.start() + if pipe is None: + # Worker is not within subprocess, so do not block. 
+    return worker, worker._address
+  else:
+    redirect_output(f'test-{test_name}-worker.log')
+    pipe.send(worker._address)
+    signal = pipe.recv()  # blocks until a 'stop' signal is received
+    if signal == 'stop':
+      worker._stop()
+      pipe.send('stopped')
+    else:
+      raise ValueError('Got unknown signal %s' % signal)
+
+
+class TFDataServiceCluster:
+  """tf.data service cluster with dispatcher and workers as subprocesses.
+
+  To run the cluster in co-located mode, set `num_workers` to 0 and create the
+  tf.data service workers manually in each client process.
+  """
+
+  def __init__(self,
+               test_name,
+               num_workers,
+               worker_ports=None,
+               worker_addresses=None):
+    self._test_name = test_name
+    self._num_workers = num_workers
+    self._start_dispatcher(worker_addresses)
+    self._start_workers(worker_ports)
+
+  def _start_dispatcher(self, worker_addresses, port=0):
+    self._pipe_to_dispatcher, dispatcher_pipe = mp_context.Pipe(True)
+    logging.info(
+        'Starting remote dispatcher on port %d with worker addresses: %s', port,
+        worker_addresses)
+    self._dispatcher_process = mp_context.Process(
+        target=create_dispatcher,
+        args=(self._test_name, worker_addresses, port, dispatcher_pipe),
+    )
+    self._dispatcher_process.start()
+    self._dispatcher_address = self._pipe_to_dispatcher.recv()
+
+  def dispatcher_address(self):
+    return self._dispatcher_address
+
+  def _start_workers(self, worker_ports=None):
+    self._workers = []
+    self._worker_addresses = []
+    self._worker_pipes = []
+    for idx in range(self._num_workers):
+      port = worker_ports[idx] if worker_ports else None
+      self._start_worker(port)
+
+  def _start_worker(self, port=None):
+    pipe_to_worker, worker_pipe = mp_context.Pipe(True)
+    logging.info(
+        'Starting remote worker on port %s with dispatcher address: %s', port,
+        self._dispatcher_address)
+    worker_process = mp_context.Process(
+        target=create_worker,
+        args=(self._test_name, self._dispatcher_address, port, worker_pipe),
+    )
+    worker_process.start()
+    worker_address = pipe_to_worker.recv()
+    self._workers.append(worker_process)
+    self._worker_addresses.append(worker_address)
+    self._worker_pipes.append(pipe_to_worker)
+
+  def worker_addresses(self):
+    return self._worker_addresses
+
+  def stop(self):
+    # Segfault logs may still be printed because clean exit of child processes
+    # is not always possible. This will not affect the outcome of the test.
+    logging.info('Will try to stop TFDataServiceCluster!')
+
+    for idx in range(self._num_workers):
+      address = self._worker_addresses[idx]
+      pipe_to_worker = self._worker_pipes[idx]
+      logging.info('Stopping worker %s...', address)
+      pipe_to_worker.send('stop')
+      if pipe_to_worker.poll(2):
+        if pipe_to_worker.recv() == 'stopped':
+          logging.info('Successfully stopped worker %s', address)
+      self._workers[idx].terminate()
+
+    logging.info('Stopping dispatcher...')
+    self._pipe_to_dispatcher.send('stop')
+    if self._pipe_to_dispatcher.poll(2):
+      if self._pipe_to_dispatcher.recv() == 'stopped':
+        logging.info('Successfully stopped dispatcher')
+    self._dispatcher_process.terminate()
+
+
+def setup_local_devices(num_devices):
+  physical_cpus = tf_config.list_physical_devices('CPU')
+  tf_config.set_logical_device_configuration(
+      physical_cpus[0],
+      [context.LogicalDeviceConfiguration() for _ in range(num_devices)],
+  )
+
+
+def setup_client(client_id: int, test_name: str, env: Mapping[str, str],
+                 num_local_devices: int):
+  """Set up a DTensor client for use in multi-client tests.
+
+  Args:
+    client_id: the index of the client.
+    test_name: the name of the test under which this client is running, used to
+      identify the log file artifact containing the test output.
+    env: a dictionary of environment variables to update.
+    num_local_devices: number of local devices to set up.
+  """
+  # Redirect client's stderr/stdout to undeclared outputs on sponge.
+  redirect_output(f'test-{test_name}-process-{client_id}.log')
+
+  # Update any specified environment variables.
+  for var, val in env.items():
+    os.environ[var] = val
+
+  # Set up local devices.
+  setup_local_devices(num_local_devices)
+
+  # Set up DTensor cluster and enable collectives.
+  accelerator_util.initialize_accelerator_system()
+
+
+def run_client(
+    client_id: int,
+    test_name: str,
+    env: Mapping[str, str],
+    num_local_devices: int,
+    dispatcher_address: str,
+    worker_port: int,
+    batch_size: int,
+    dataset_paths: List[str],
+    mesh: Mesh,
+    batch_dim: Optional[str],
+    layouts: Tuple[Layout, Layout],
+) -> List[Tuple[Any, Any]]:
+  # Co-located tf.data service mode. It is important to hold the worker object
+  # until the end; otherwise it will get garbage collected.
+  worker, worker_address = create_worker(  # pylint: disable=unused-variable
+      test_name, dispatcher_address, port=worker_port)
+  logging.info(
+      'tf.data service worker running at %s',
+      worker_address,
+  )
+
+  setup_client(client_id, test_name, env, num_local_devices)
+
+  def decode_fn(record_bytes):
+    decoded = parsing_ops.parse_single_example_v2(
+        serialized=record_bytes,
+        features={
+            'idx': parsing_config.FixedLenFeature([], dtype=dtypes.int64),
+            'elem': parsing_config.FixedLenFeature([], dtype=dtypes.string),
+        },
+    )
+    parsed_elem = gen_parsing_ops.parse_tensor(decoded['elem'], dtypes.int32)
+    elem = check_ops.ensure_shape(
+        parsed_elem, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS]
+    )
+    return decoded['idx'], elem
+
+  dataset = dataset_ops.DatasetV2.from_tensor_slices(dataset_paths)
+  dataset = dataset.interleave(readers.TFRecordDatasetV2)
+  dataset = dataset.map(decode_fn)
+
+  tf_data_service_config = input_util.TFDataServiceConfig(
+      dispatcher_address=dispatcher_address, job_name=TF_DATA_SERVICE_JOB_NAME
+  )
+  d_dataset = input_util.DTensorDataset(
+      dataset=dataset,
+      global_batch_size=batch_size,
+      mesh=mesh,
+      layouts=layouts,
+      batch_dim=batch_dim,
+      tf_data_service_config=tf_data_service_config,
+  )
+
+  # Subprocesses cannot return a sharded DTensor as it triggers a copy and
+  # copying non-replicated DTensors is not supported. So instead we unpack it
+  # and return the component tensors.
+  ret = []
+  for batch_idx, elem in d_dataset:
+    n_batch_idx = api.unpack(batch_idx)
+    n_elem = api.unpack(elem)
+    ret.append((n_batch_idx, n_elem))
+  return ret
+
+
+class MultiClientDTensorDatasetTest(test_util.DTensorBaseTest):
+
+  def setUp(self):
+    super().setUp()
+
+    logging.info('Check per client log in Test artifacts.')
+
+    self.server_ports = [
+        multi_client_test_util.pick_unused_port() for _ in range(NUM_CLIENTS)
+    ]
+
+    self.worker_ports = [
+        multi_client_test_util.pick_unused_port() for _ in range(NUM_CLIENTS)
+    ]
+    worker_addresses = [f'localhost:{port}' for port in self.worker_ports]
+    self.cluster = TFDataServiceCluster(
+        test_name=self._testMethodName,
+        num_workers=0,  # Co-located mode.
+ worker_addresses=worker_addresses) + + def tearDown(self): + super().tearDown() + self.cluster.stop() + + def write_dataset(self, dataset, num_files, num_elems): + """Writes a dataset_ops.DatasetV2 to multiple files.""" + dataset_paths = [] + dataset_iter = iter(dataset) + + for file_idx in range(num_files): + dataset_path = os.path.join(self.get_temp_dir(), + f'dataset-{file_idx}.tfrecords') + dataset_paths.append(dataset_path) + with tf_record.TFRecordWriter(dataset_path) as writer: + for _ in range(num_elems // num_files): + idx, elem = next(dataset_iter) + elem_bytes = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'idx': feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[idx]) + ), + 'elem': feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[io_ops.serialize_tensor(elem).numpy()] + ) + ), + } + ) + ).SerializeToString() + writer.write(elem_bytes) + + return dataset_paths + + @parameterized.product( + ( + { + # batch=4 x height=2 x width=2 + # 1 replica per client. + 'mesh_dims': [(MESH_DIM_BATCH, 4), + (MESH_DIM_HEIGHT, 2), + (MESH_DIM_WIDTH, 2)], + }, { + # batch=4 x height=2 x width=2 (transposed) + # 1 replica per client with reordered local partitions. + 'mesh_dims': [(MESH_DIM_BATCH, 4), + (MESH_DIM_WIDTH, 2), + (MESH_DIM_HEIGHT, 2)], + }, { + # batch=8 x height=2 x width=1 + # 2 replicas per client. + 'mesh_dims': [(MESH_DIM_BATCH, 8), + (MESH_DIM_HEIGHT, 2), + (MESH_DIM_WIDTH, 1)], + }, { + # batch=8 x height=2 x width=1 (transposed) + # 2 replicas per client with reordered partitions. + 'mesh_dims': [(MESH_DIM_BATCH, 8), + (MESH_DIM_WIDTH, 1), + (MESH_DIM_HEIGHT, 2)], + }, { + # batch=2 x height=4 x width=2 + # 1 replica split over 2 clients. + 'mesh_dims': [(MESH_DIM_BATCH, 2), + (MESH_DIM_HEIGHT, 4), + (MESH_DIM_WIDTH, 2)], + }, { + # batch=2 x height=4 x width=2 (transposed) + # 1 replica split over 2 clients with reordered partitions. + 'mesh_dims': [(MESH_DIM_BATCH, 2), + (MESH_DIM_WIDTH, 2), + (MESH_DIM_HEIGHT, 4)], + }, + ), + ( + { + # Replicated + 'idx_sharding': [UNSHARDED], + 'images_sharding': [UNSHARDED, UNSHARDED, UNSHARDED, UNSHARDED], + }, { + # Batch sharded + 'idx_sharding': [MESH_DIM_BATCH], + 'images_sharding': + [MESH_DIM_BATCH, UNSHARDED, UNSHARDED, UNSHARDED], + }, { + # Spatially sharded + 'idx_sharding': [UNSHARDED], + 'images_sharding': + [UNSHARDED, MESH_DIM_HEIGHT, MESH_DIM_WIDTH, UNSHARDED], + }, { + # Batch and spatially sharded + 'idx_sharding': [MESH_DIM_BATCH], + 'images_sharding': + [MESH_DIM_BATCH, MESH_DIM_HEIGHT, MESH_DIM_WIDTH, UNSHARDED], + } + )) + def testMultiClientIter(self, mesh_dims, idx_sharding, images_sharding): + num_batches = 4 + batch_size = 16 + num_elems = num_batches * batch_size + + images = stateless_random_ops.stateless_random_uniform( + [num_elems, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS], + seed=(1, 2), + minval=0, + maxval=255, + dtype=dtypes.int32, + ) + dataset = dataset_ops.DatasetV2.from_tensor_slices(images) + + # Enumerate the dataset elements to make it easier to identify the batches + # returned by the DTensorDataset. + dataset = dataset.enumerate() + + # Store a mapping of index to dataset elements which can be looked up later + # to identify the batches returned by the DTensorDataset. + all_elems = {idx.numpy(): elem for idx, elem in dataset} + + # Write the dataset and shard it among multiple files. + dataset_paths = self.write_dataset( + dataset, num_files=8, num_elems=num_elems) + + # Construct args for starmap. 
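+    # Each entry below is the positional argument tuple for one run_client()
+    # call: client id, test name, DTensor env vars, local device count,
+    # tf.data service dispatcher address and worker port, batch size, dataset
+    # file paths, and the per-client mesh, batch dim, and layouts.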
+ args = [] + mesh_dim_names, mesh_dim_sizes = zip(*mesh_dims) + global_device_ids = test_util.create_device_ids_array(mesh_dim_sizes) + device_ids_split = np.split(np.ravel(global_device_ids), NUM_CLIENTS) + dtensor_jobs = [ + f'localhost:{self.server_ports[i]}' for i in range(NUM_CLIENTS) + ] + + for client_id in range(NUM_CLIENTS): + # Manually specify DTensor environment variables since we are in a test + # environment. + env = { + config._DT_CLIENT_ID: str(client_id), + config._DT_JOB_NAME: str(JOB_NAME), + config._DT_JOBS: ','.join(dtensor_jobs) + } + + local_device_ids = device_ids_split[client_id].tolist() + local_devices = [ + device_spec.DeviceSpecV2( # pylint: disable=g-complex-comprehension + job=JOB_NAME, + replica=0, + task=client_id, + device_type='CPU', + device_index=i, + ) + for i in range(len(local_device_ids)) + ] + mesh = Mesh( + dim_names=mesh_dim_names, + global_device_ids=global_device_ids, + local_device_ids=local_device_ids, + local_devices=local_devices, + ) + idx_layout = Layout(idx_sharding, mesh) + images_layout = Layout(images_sharding, mesh) + batch_dim = MESH_DIM_BATCH if MESH_DIM_BATCH in images_sharding else None + + args.append((client_id, self._testMethodName, env, NUM_DEVICES_PER_CLIENT, + self.cluster.dispatcher_address(), + self.worker_ports[client_id], batch_size, dataset_paths, + mesh, batch_dim, (idx_layout, images_layout))) + + def get_results(): + # Run the DTensor client processes and get the DTensor dataset components. + with mp_context.Pool(NUM_CLIENTS) as pool: + results = pool.starmap(run_client, args) + pool.close() + pool.join() + + return results + + # TODO(b/271162918): fix multi-client use case. + with self.assertRaises(NotImplementedError): + results = get_results() + + return + # pylint: disable=unreachable + + # Create a mesh on the main test process. The tensor components returned + # from each DTensor client subprocess will be packed onto this mesh to + # verify correctness. + test_mesh = mesh_util.create_mesh( + mesh_dims=mesh_dims, + devices=[ + 'CPU:%d' % i for i in range(NUM_CLIENTS * NUM_DEVICES_PER_CLIENT) + ]) + test_mesh = self.configTestMesh({'CPU': test_mesh}) + idx_test_layout = Layout(idx_sharding, test_mesh) + images_test_layout = Layout(images_sharding, test_mesh) + + for batch_elems in zip(*results): + # Collect the tensor components returned from each client. + idx_components = [] + images_components = [] + for client_id in range(NUM_CLIENTS): + local_idx, local_images = batch_elems[client_id] + idx_components.extend(local_idx) + images_components.extend(local_images) + + # Pack the dataset elements into a DTensor on the test mesh. + d_idx = api.pack(idx_components, idx_test_layout) + d_images = api.pack(images_components, images_test_layout) + + # Get the batch of elements from the original dataset using the element + # indices. 
+ batch_stack = [] + for elem_idx in d_idx: + batch_stack.append(all_elems.pop(elem_idx.numpy())) + batch = array_ops_stack.stack(batch_stack) + + self.assertDTensorEqual(batch, images_test_layout, d_images) + + self.assertEmpty( + all_elems, 'Not all batches were returned by DTensorDataset.') + + +if __name__ == '__main__': + test_backend_util.handle_test_main(test.main) diff --git a/tensorflow/dtensor/python/tests/multi_client_test_util.py b/tensorflow/dtensor/python/tests/multi_client_test_util.py index 35d3e7aa10ef98..dd4a69f14f77e4 100644 --- a/tensorflow/dtensor/python/tests/multi_client_test_util.py +++ b/tensorflow/dtensor/python/tests/multi_client_test_util.py @@ -31,6 +31,11 @@ 'Number of clients. 0 for local mode. 2 is the only allowed value for TPU.') +def pick_unused_port(): + """Helper function to return an unused port.""" + return portpicker.pick_unused_port() + + def multi_client_main(client_config_function): """Creates a Flock of TensorFlow Processes on localhost.""" flags.FLAGS(sys.argv, known_only=True) @@ -49,12 +54,11 @@ def multi_client_main(client_config_function): # Inverts the order of ports intentionally to rule out ordering bugs. server_ports = sorted( - [portpicker.pick_unused_port() for _ in range(num_process)], reverse=True) - - additional_ports = sorted( - [portpicker.pick_unused_port() for _ in range(num_process)] + [pick_unused_port() for _ in range(num_process)], reverse=True ) + additional_ports = sorted([pick_unused_port() for _ in range(num_process)]) + # Starts processes procs = [] for client_idx in range(num_process): @@ -138,4 +142,3 @@ def run_client(idx, num_clients, server_ports, additional_ports, # The following function call never returns. tf_test.main() - diff --git a/tensorflow/dtensor/python/tests/numerics_test.py b/tensorflow/dtensor/python/tests/numerics_test.py new file mode 100644 index 00000000000000..60bb7995adf6e2 --- /dev/null +++ b/tensorflow/dtensor/python/tests/numerics_test.py @@ -0,0 +1,125 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for numerics in DTensor Ops.""" + +import os + +from absl.testing import parameterized +import numpy as np + +from tensorflow.dtensor.python import accelerator_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh +UNSHARDED = layout_lib.UNSHARDED +_MESH_DIM_X = 'x' +_MESH_DIM_Y = 'y' +_MESH_DIMS = [_MESH_DIM_X, _MESH_DIM_Y] + + +class NumericTest(test_util.DTensorBaseTest): + + def setUp(self): + super(NumericTest, self).setUp() + + self.skipForDeviceType(['TPU'], + 'all tests require 8 TPU cores.', + unless_device_count_equals_to=8) + + test_util.reset_logical_devices('CPU', 8) + accelerator_util.initialize_accelerator_system() + + self.stateless_random_seed = [0, 1] + + def _create_mesh(self, topology, device): + device_ids = test_util.create_device_ids_array(topology) + return Mesh( + _MESH_DIMS, + device_ids, + np.ravel(device_ids).tolist(), + test_util.create_device_list(topology, device), + ) + + # Tests AllReduce numerics with and without mixed precision reduce enabled, + # based on go/dtensor-numerics. + @parameterized.named_parameters(('_without_mixed_precision_reduce', False), + ('_with_mixed_precision_reduce', True)) + def test_all_reduce(self, enable_mixed_precision_reduce): + if enable_mixed_precision_reduce: + os.environ['DTENSOR_ENABLE_MIXED_PRECISION_REDUCE'] = '' + # Override group size since we are testing on smaller mesh. + os.environ['DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE'] = '4' + else: + if 'DTENSOR_ENABLE_MIXED_PRECISION_REDUCE' in os.environ: + del os.environ['DTENSOR_ENABLE_MIXED_PRECISION_REDUCE'] + + @polymorphic_function.function + def _compute_reduction(inp): + return math_ops.reduce_sum(inp, axis=[2]) + + input_tensor = stateless_random_ops.stateless_random_uniform( + shape=(8, 8, 8, 64), + seed=self.stateless_random_seed, + minval=-5.0, + maxval=5.0, + dtype=dtypes.bfloat16, + ) + expected = _compute_reduction(input_tensor) + + # Compute reduction on 8x1, since dim 2 is unsharded AllReduce will not be + # needed. + mesh_8x1 = self._create_mesh((8, 1), 'TPU') + input_8x1 = numpy_util.pack_numpy( + input_tensor, + Layout([_MESH_DIM_X, UNSHARDED, UNSHARDED, UNSHARDED], mesh_8x1), + ) + result_8x1 = _compute_reduction(input_8x1) + result_8x1_np = numpy_util.to_numpy(result_8x1) + + # Compute reduction on 1x8, AllReduce will be needed since dim 2 is sharded. + mesh_1x8 = self._create_mesh((1, 8), 'TPU') + input_1x8 = numpy_util.pack_numpy( + input_tensor, + Layout([_MESH_DIM_X, UNSHARDED, _MESH_DIM_Y, UNSHARDED], mesh_1x8), + ) + result_1x8 = _compute_reduction(input_1x8) + result_1x8_np = numpy_util.to_numpy(result_1x8) + + self.assertEqual(result_8x1.dtype, dtypes.bfloat16) + self.assertEqual(result_1x8.dtype, dtypes.bfloat16) + + # Mixed precision does not apply since AllReduce was not used, result will + # always be close to the expected value. + self.assertAllClose(result_8x1_np, expected, atol=1e-5, rtol=1e-5) + + # AllReduce was needed, so result will be more accurate if mixed precision + # is enabled. 
+ if enable_mixed_precision_reduce: + self.assertAllClose(result_1x8_np, expected, atol=1e-5, rtol=1e-5) + else: + self.assertNotAllClose(result_1x8_np, expected, atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/sparse_test.py b/tensorflow/dtensor/python/tests/sparse_test.py new file mode 100644 index 00000000000000..b519da74e4ea57 --- /dev/null +++ b/tensorflow/dtensor/python/tests/sparse_test.py @@ -0,0 +1,141 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl.testing import parameterized +import numpy as np + +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +# Convenient constants to use for tests. +_BATCH_DIM = "batch" +_MESH_DIM_X = "x" + +# Shorter notation +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +class DTensorSPMDTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + self.skipForDeviceType(["GPU", "TPU"], + "SparseTensors only supported on CPU.") + + global_ids = test_util.create_device_ids_array((2, 2)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = { + device: Mesh( + [_BATCH_DIM, _MESH_DIM_X], + global_ids, + local_ids, + test_util.create_device_list((2, 2), device), + ) + for device in ("CPU", "GPU", "TPU") + } + self.mesh = self.configTestMesh(mesh_dict) + + @parameterized.parameters( + [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64] + ) + def testIdentityOpWithSparseTensorInputSimple(self, dtype): + inputs = array_ops.ones([6, 4], dtype=dtype) + layout = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + + @polymorphic_function.function + def f(x): + return array_ops.identity(x) + + self.assertDTensorEqual( + inputs, layout, + f(numpy_util.pack_numpy(inputs, layout, make_sparse=True))) + + @parameterized.product( + dtype=[dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64], + is_sparse_a=[True, False], + is_sparse_b=[True, False], + ) + def testIdentityOpWithSparseTensorInputComplex(self, dtype, is_sparse_a, + is_sparse_b): + inputs_a = array_ops.ones([2, 1], dtype=dtype) + inputs_b = array_ops.ones([32, 16], dtype=dtype) + + layout_a = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + layout_b = Layout.replicated(self.mesh, rank=2) + + @polymorphic_function.function + def f(x, y): + return array_ops.identity(x), array_ops.identity(y) + + got_a, got_b = f( + numpy_util.pack_numpy(inputs_a, layout_a, make_sparse=is_sparse_a), + numpy_util.pack_numpy(inputs_b, layout_b, make_sparse=is_sparse_b)) + + 
self.assertDTensorEqual(inputs_a, layout_a, got_a) + self.assertDTensorEqual(inputs_b, layout_b, got_b) + + def testMultipleIdentityOpFromOneSparseTensor(self): + inputs_a = array_ops.ones([2, 1]) + layout_a = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + + @polymorphic_function.function + def f(x): + return array_ops.identity(x), array_ops.identity(x) + + got_a, got_b = f( + numpy_util.pack_numpy(inputs_a, layout_a, make_sparse=True)) + + self.assertDTensorEqual(inputs_a, layout_a, got_a) + self.assertDTensorEqual(inputs_a, layout_a, got_b) + + @parameterized.product( + is_sparse_a=[True, False], + is_sparse_b=[True, False], + shard_type=["Replicated", "Sharded"]) + def testSparseTensorDenseMatMul(self, is_sparse_a, is_sparse_b, shard_type): + inputs_a = array_ops.ones([16, 16]) + inputs_b = array_ops.ones([16, 16]) + + if shard_type == "Replicated": + layout_a = Layout.replicated(self.mesh, rank=2) + layout_b = Layout.replicated(self.mesh, rank=2) + else: + layout_a = Layout([_MESH_DIM_X, _BATCH_DIM], self.mesh) + layout_b = Layout(["unsharded", _MESH_DIM_X], self.mesh) + + expected = math_ops.matmul(inputs_a, inputs_b) + + @polymorphic_function.function + def f(x, y): + return math_ops.matmul(x, y) + + got = f( + numpy_util.pack_numpy(inputs_a, layout_a, make_sparse=is_sparse_a), + numpy_util.pack_numpy(inputs_b, layout_b, make_sparse=is_sparse_b)) + + self.assertDTensorEqual(expected, Layout.replicated(self.mesh, rank=2), got) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/dtensor/python/tests/tpu_device_assignment_test.py b/tensorflow/dtensor/python/tests/tpu_device_assignment_test.py new file mode 100644 index 00000000000000..08ece48382e52b --- /dev/null +++ b/tensorflow/dtensor/python/tests/tpu_device_assignment_test.py @@ -0,0 +1,889 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for TPU device assignment.""" + +from tensorflow.dtensor.python import accelerator_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python import tpu_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +class DeviceAssignmentTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + accelerator_util.initialize_accelerator_system('TPU') + + def tearDown(self): + accelerator_util.shutdown_accelerator_system() + super().tearDown() + + def _build_all_reduce_ring(self, core_locations): + permutation = tpu_util._build_all_reduce_ring(core_locations) + return [core_locations[element] for element in permutation] + + # Picture of chips: + # 0 -- 1 + # | | + # 3 -- 2 + def testBuildAllReduceRing4Replicas(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 0), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 0), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips with core0/core1 assignments: + # 0/1 -- 2/3 + # | | + # 6/7 -- 4/5 + def testBuildAllReduceRing8ReplicasUsingTwoCores(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 0 -- 1 -- 2 -- 3 + # | | + # 15 6 -- 5 -- 4 + # | | + # 14 7 -- 8 -- 9 + # | | + # 13-- 12-- 11-- 10 + def testBuildAllReduceRing32Replicas(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + 
tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 7 -- 0 6 -- 5 + # | | + # 2 -- 1 3 -- 4 + def testBuildAllReduceRing3D(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + ] + expected = [ + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 31-- 0 -- 1 -- 2 30--29--28--27 + # | | + # 14 5 -- 4 -- 3 15 24--25--26 + # | | | | + # 13 6 -- 7 -- 8 16 23--22--21 + # | | | | + # 12-- 11-- 10-- 9 17--18--19--20 + def testBuildAllReduceRing3DLarge(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + 
tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(2, 0, 1, 0), + tpu_util._CoreLocation(2, 0, 1, 1), + tpu_util._CoreLocation(3, 0, 1, 0), + tpu_util._CoreLocation(3, 0, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(2, 1, 1, 0), + tpu_util._CoreLocation(2, 1, 1, 1), + tpu_util._CoreLocation(3, 1, 1, 0), + tpu_util._CoreLocation(3, 1, 1, 1), + tpu_util._CoreLocation(0, 2, 1, 0), + tpu_util._CoreLocation(0, 2, 1, 1), + tpu_util._CoreLocation(1, 2, 1, 0), + tpu_util._CoreLocation(1, 2, 1, 1), + tpu_util._CoreLocation(2, 2, 1, 0), + tpu_util._CoreLocation(2, 2, 1, 1), + tpu_util._CoreLocation(3, 2, 1, 0), + tpu_util._CoreLocation(3, 2, 1, 1), + tpu_util._CoreLocation(0, 3, 1, 0), + tpu_util._CoreLocation(0, 3, 1, 1), + tpu_util._CoreLocation(1, 3, 1, 0), + tpu_util._CoreLocation(1, 3, 1, 1), + tpu_util._CoreLocation(2, 3, 1, 0), + tpu_util._CoreLocation(2, 3, 1, 1), + tpu_util._CoreLocation(3, 3, 1, 0), + tpu_util._CoreLocation(3, 3, 1, 1), + ] + expected = [ + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + 
tpu_util._CoreLocation(0, 2, 1, 1), + tpu_util._CoreLocation(0, 2, 1, 0), + tpu_util._CoreLocation(0, 3, 1, 1), + tpu_util._CoreLocation(0, 3, 1, 0), + tpu_util._CoreLocation(1, 3, 1, 1), + tpu_util._CoreLocation(1, 3, 1, 0), + tpu_util._CoreLocation(2, 3, 1, 1), + tpu_util._CoreLocation(2, 3, 1, 0), + tpu_util._CoreLocation(3, 3, 1, 1), + tpu_util._CoreLocation(3, 3, 1, 0), + tpu_util._CoreLocation(3, 2, 1, 1), + tpu_util._CoreLocation(3, 2, 1, 0), + tpu_util._CoreLocation(2, 2, 1, 1), + tpu_util._CoreLocation(2, 2, 1, 0), + tpu_util._CoreLocation(1, 2, 1, 1), + tpu_util._CoreLocation(1, 2, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(2, 1, 1, 1), + tpu_util._CoreLocation(2, 1, 1, 0), + tpu_util._CoreLocation(3, 1, 1, 1), + tpu_util._CoreLocation(3, 1, 1, 0), + tpu_util._CoreLocation(3, 0, 1, 1), + tpu_util._CoreLocation(3, 0, 1, 0), + tpu_util._CoreLocation(2, 0, 1, 1), + tpu_util._CoreLocation(2, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 0 -- 1 4 -- 5 + # | | | | + # 3 -- 2 7 -- 6 + # + # 12-- 13 8 -- 9 + # | | | | + # 15-- 14 11-- 10 + def testBuildOrthogonalAllReduceRings(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + 
tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + ] + result = tpu_util._build_orthogonal_rings( + core_locations, ring_size=8, rotate_ring_across_rings=False) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 0 -- 1 12 -- 13 + # | | | | + # 3 -- 2 15 -- 14 + # + # 4 -- 5 8 -- 9 + # | | | | + # 7 -- 6 11-- 10 + def testBuildOrthogonalRotatedAllReduceRings(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + ] + result = tpu_util._build_orthogonal_rings( + core_locations, ring_size=8, rotate_ring_across_rings=True) + self.assertAllEqual(result, expected) + + # Create a 4x8 mesh on a 4x4 DF slice, disallowing splitting hosts. 
+ def testCreateDFMeshNoSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 1, 2], [4, 4, 1, 2], ['core', 'y', 'z', 'x'], + can_split_host_across_rings=False, + ring_size=8) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + self.assertAllEqual(result, expected) + + # Create a 4x8 mesh on a 4x4 DF slice with at most 2, 2, 1, 2 devices from + # each dimension, disallowing splitting hosts. + def testCreateDFMeshWithRingBoundsNoSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 1, 2], [2, 2, 1, 2], ['core', 'x', 'y', 'z'], + can_split_host_across_rings=False, + ring_size=8) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + self.assertAllEqual(result, expected) + + # Create a 4x8 mesh on a 4x4 DF slice, allowing splitting hosts. 
+ def testCreateDFMeshSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 1, 2], [4, 4, 1, 2], ['core', 'y', 'z', 'x'], + can_split_host_across_rings=True, + ring_size=8) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + self.assertAllEqual(result, expected) + + # Create a 2x64 mesh on a 4x4x4 PF slice, allowing splitting hosts. + def testCreateMeshPFSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 4, 2], [4, 4, 4, 2], ['core', 'x', 'y', 'z'], + can_split_host_across_rings=True, + ring_size=64) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(2, 0, 1, 0), + tpu_util._CoreLocation(2, 0, 1, 1), + tpu_util._CoreLocation(3, 0, 1, 0), + tpu_util._CoreLocation(3, 0, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(2, 1, 1, 0), + tpu_util._CoreLocation(2, 1, 1, 1), + tpu_util._CoreLocation(3, 1, 1, 0), + 
tpu_util._CoreLocation(3, 1, 1, 1), + tpu_util._CoreLocation(0, 2, 1, 0), + tpu_util._CoreLocation(0, 2, 1, 1), + tpu_util._CoreLocation(1, 2, 1, 0), + tpu_util._CoreLocation(1, 2, 1, 1), + tpu_util._CoreLocation(2, 2, 1, 0), + tpu_util._CoreLocation(2, 2, 1, 1), + tpu_util._CoreLocation(3, 2, 1, 0), + tpu_util._CoreLocation(3, 2, 1, 1), + tpu_util._CoreLocation(0, 3, 1, 0), + tpu_util._CoreLocation(0, 3, 1, 1), + tpu_util._CoreLocation(1, 3, 1, 0), + tpu_util._CoreLocation(1, 3, 1, 1), + tpu_util._CoreLocation(2, 3, 1, 0), + tpu_util._CoreLocation(2, 3, 1, 1), + tpu_util._CoreLocation(3, 3, 1, 0), + tpu_util._CoreLocation(3, 3, 1, 1), + tpu_util._CoreLocation(0, 0, 2, 0), + tpu_util._CoreLocation(0, 0, 2, 1), + tpu_util._CoreLocation(1, 0, 2, 0), + tpu_util._CoreLocation(1, 0, 2, 1), + tpu_util._CoreLocation(2, 0, 2, 0), + tpu_util._CoreLocation(2, 0, 2, 1), + tpu_util._CoreLocation(3, 0, 2, 0), + tpu_util._CoreLocation(3, 0, 2, 1), + tpu_util._CoreLocation(0, 1, 2, 0), + tpu_util._CoreLocation(0, 1, 2, 1), + tpu_util._CoreLocation(1, 1, 2, 0), + tpu_util._CoreLocation(1, 1, 2, 1), + tpu_util._CoreLocation(2, 1, 2, 0), + tpu_util._CoreLocation(2, 1, 2, 1), + tpu_util._CoreLocation(3, 1, 2, 0), + tpu_util._CoreLocation(3, 1, 2, 1), + tpu_util._CoreLocation(0, 2, 2, 0), + tpu_util._CoreLocation(0, 2, 2, 1), + tpu_util._CoreLocation(1, 2, 2, 0), + tpu_util._CoreLocation(1, 2, 2, 1), + tpu_util._CoreLocation(2, 2, 2, 0), + tpu_util._CoreLocation(2, 2, 2, 1), + tpu_util._CoreLocation(3, 2, 2, 0), + tpu_util._CoreLocation(3, 2, 2, 1), + tpu_util._CoreLocation(0, 3, 2, 0), + tpu_util._CoreLocation(0, 3, 2, 1), + tpu_util._CoreLocation(1, 3, 2, 0), + tpu_util._CoreLocation(1, 3, 2, 1), + tpu_util._CoreLocation(2, 3, 2, 0), + tpu_util._CoreLocation(2, 3, 2, 1), + tpu_util._CoreLocation(3, 3, 2, 0), + tpu_util._CoreLocation(3, 3, 2, 1), + tpu_util._CoreLocation(0, 0, 3, 0), + tpu_util._CoreLocation(0, 0, 3, 1), + tpu_util._CoreLocation(1, 0, 3, 0), + tpu_util._CoreLocation(1, 0, 3, 1), + tpu_util._CoreLocation(2, 0, 3, 0), + tpu_util._CoreLocation(2, 0, 3, 1), + tpu_util._CoreLocation(3, 0, 3, 0), + tpu_util._CoreLocation(3, 0, 3, 1), + tpu_util._CoreLocation(0, 1, 3, 0), + tpu_util._CoreLocation(0, 1, 3, 1), + tpu_util._CoreLocation(1, 1, 3, 0), + tpu_util._CoreLocation(1, 1, 3, 1), + tpu_util._CoreLocation(2, 1, 3, 0), + tpu_util._CoreLocation(2, 1, 3, 1), + tpu_util._CoreLocation(3, 1, 3, 0), + tpu_util._CoreLocation(3, 1, 3, 1), + tpu_util._CoreLocation(0, 2, 3, 0), + tpu_util._CoreLocation(0, 2, 3, 1), + tpu_util._CoreLocation(1, 2, 3, 0), + tpu_util._CoreLocation(1, 2, 3, 1), + tpu_util._CoreLocation(2, 2, 3, 0), + tpu_util._CoreLocation(2, 2, 3, 1), + tpu_util._CoreLocation(3, 2, 3, 0), + tpu_util._CoreLocation(3, 2, 3, 1), + tpu_util._CoreLocation(0, 3, 3, 0), + tpu_util._CoreLocation(0, 3, 3, 1), + tpu_util._CoreLocation(1, 3, 3, 0), + tpu_util._CoreLocation(1, 3, 3, 1), + tpu_util._CoreLocation(2, 3, 3, 0), + tpu_util._CoreLocation(2, 3, 3, 1), + tpu_util._CoreLocation(3, 3, 3, 0), + tpu_util._CoreLocation(3, 3, 3, 1), + ] + self.assertAllEqual(result, expected) + + def testCreateMeshNoSplittingHostsUnfulfillable(self): + with self.assertRaises(ValueError): + tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_unfulfillable_without_splitting_hosts', + can_split_host_across_rings=False) + + def testCreateMeshWithDefaultOptions(self): + mesh = tpu_util.create_tpu_mesh(['x'], [2], 'mesh_with_default_options') + self.assertAllEqual(mesh.shape(), [2]) + 
self.assertEqual(mesh.num_local_devices(), 2) + + def testCreateMeshWithWrongShape(self): + with self.assertRaises(ValueError): + tpu_util.create_tpu_mesh(['x'], [1], 'mesh_with_wrong_shape') + + # Build rings for the batch dimension. + def testCreateMeshWithPositiveRingDims(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_positive_ring_dims', + ring_dims=1) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # Build rings for all non-batch dimensions. + def testCreateMeshWithNegativeRingDims(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y', 'z'], [1, 2, 1], + 'mesh_with_negative_ring_dims', + ring_dims=-2) + self.assertAllEqual(mesh.shape(), [1, 2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # Build single-core rings. + def testCreateMeshWithZeroRingDims(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_zero_ring_dims', + ring_dims=0) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + def testCreateMeshWithCustomAxes(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_custom_axes', + ring_axes=['x', 'z', 'y', 'core']) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # More cores (2 cores) on the first axis (core) than ring size (1). + def testCreateMeshWithDividedAxis(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_divided_axis', + ring_dims=-1, + ring_axes=['core', 'z', 'y', 'x']) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # Both meshes should produce the same result despite different `ring_dim`. + def testCreateMultipleMeshes(self): + a = constant_op.constant([[0, 1], [2, 3]], dtype=dtypes.int32) + b_expected = math_ops.reduce_sum(a) + + mesh_1 = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], 'mesh_1', ring_dims=1) + a_1 = numpy_util.pack_numpy(a, Layout(['x', 'y'], mesh_1)) + b_1 = math_ops.reduce_sum(a_1) + self.assertDTensorEqual(b_expected, Layout.replicated(mesh_1, rank=0), b_1) + + mesh_2 = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_2', + ring_dims=-1) + a_2 = numpy_util.pack_numpy(a, Layout(['x', 'y'], mesh_2)) + b_2 = math_ops.reduce_sum(a_2) + self.assertDTensorEqual(b_expected, Layout.replicated(mesh_2, rank=0), b_2) + + def testCreateMeshWithEmptyName(self): + tpu_util.create_tpu_mesh(['x'], [2], '') + + def testCreateMeshWithExistingName(self): + tpu_util.create_tpu_mesh(['x'], [2], 'mesh_with_existing_name') + with self.assertRaises(ValueError): + tpu_util.create_tpu_mesh(['x'], [2], 'mesh_with_existing_name') + + def testGetDeviceIDs(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_to_get_device_ids') + self.assertAllEqual(tpu_util.get_device_ids(mesh), [0, 1]) + + def testGetDeviceLocations(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_to_get_device_locations') + self.assertAllEqual( + tpu_util.get_device_locations(mesh), [{ + 'x': 0, + 'y': 0 + }, { + 'x': 1, + 'y': 0 + }]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 04b18ba19277ae..7fabdd432da41f 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -311,6 +311,7 @@ tf_gen_op_strict_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/compiler/tests:__pkg__", + "//tensorflow/dtensor/python/tests:__pkg__", "//tensorflow/python:__pkg__", 
"//tensorflow/python/kernel_tests/image_ops:__pkg__", "//tensorflow/python/ops/parallel_for:__pkg__", @@ -486,6 +487,7 @@ tf_gen_op_strict_wrapper_private_py( name = "parsing_ops_gen", visibility = [ "//learning/brain/python/ops:__pkg__", + "//tensorflow/dtensor/python/tests:__pkg__", "//tensorflow/python:__pkg__", "//tensorflow/python/autograph/operators:__pkg__", "//tensorflow/python/data/ops:__pkg__", From 985383ebfd658d2ea67e75d06244d5c67ca41c8b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Nov 2023 19:25:40 -0800 Subject: [PATCH 031/381] Fix integer representation inconsistency in `tf.roll` kernel's `memcpy` path, causing overflow when the number of elements of the input tensor exceeds INT_MAX for 32-bit integers (2147483647). PiperOrigin-RevId: 584772247 --- tensorflow/core/kernels/roll_op.cc | 24 ++++++++++++------- .../python/kernel_tests/array_ops/BUILD | 6 +++-- .../kernel_tests/array_ops/manip_ops_test.py | 18 ++++++++++++-- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc index ac5a410a261ed9..ae9169df96ac3a 100644 --- a/tensorflow/core/kernels/roll_op.cc +++ b/tensorflow/core/kernels/roll_op.cc @@ -15,14 +15,19 @@ limitations under the License. #include "tensorflow/core/kernels/roll_op.h" +#include +#include + #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/register_types_traits.h" -#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" @@ -191,9 +196,9 @@ void DoRollWithMemcpy(const OpKernelContext* context, int64_t start, int64_t end) { // the number of indices over in the flattened tensor you need to skip in // order to make it over from one side of the isd to the other - const int64_t isd_range = std::max(dim_range[isd], 1); - // the distance along the flattend tensor to the next element in the isd - const int64_t isd_stride = isd_range / std::max(dim_size[isd], 1); + const int64_t isd_range = std::max(dim_range[isd], 1); + // the distance along the flattened tensor to the next element in the isd + const int64_t isd_stride = isd_range / std::max(dim_size[isd], 1); // start and end represent the i-th group currently so we will convert // them into numbers representing the i-th elements. 
@@ -295,9 +300,10 @@ void DoRollWithMemcpy(const OpKernelContext* context, // Shard auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); const int64_t ave_group_size = dim_range[isd] / 2; - const int total_work = 2 * num_elements / std::max(dim_range[isd], 1); + const int64_t total_work = + 2 * num_elements / std::max(dim_range[isd], 1); // 25000 - experimentally determined with float and bool types - const int cost_per_group = 25000 * sizeof(T) * ave_group_size; + const int64_t cost_per_group = 25000 * sizeof(T) * ave_group_size; Shard(worker_threads->num_threads, worker_threads->workers, total_work, cost_per_group, std::move(work)); } diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 4852a3c1768527..80cb2b53072a28 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -469,9 +469,11 @@ cuda_py_strict_test( cuda_py_strict_test( name = "manip_ops_test", - size = "small", + size = "medium", srcs = ["manip_ops_test.py"], - tags = ["no_windows_gpu"], + tags = [ + "no_windows_gpu", + ], deps = [ "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", diff --git a/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py b/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py index 65291165da8827..35e2c3c0f86e36 100644 --- a/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py @@ -105,11 +105,25 @@ def testEmptyInput(self): self._testAll(np.zeros([0, 1]), 1, 1) self._testAll(np.zeros([1, 0]), 1, 1) + @test_util.run_v2_only + def testLargeInput(self): + with test_util.force_cpu(): + # Num elements just over INT_MAX for int32 to ensure no overflow + np_input = np.arange(0, 128 * 524289 * 33, dtype=np.int8).reshape( + 128, -1, 33 + ) + + for shift in range(-5, 5): + roll = manip_ops.roll(np_input, shift, 0) + self.assertAllEqual(roll[shift], np_input[0], msg=f"shift={shift}") + self.assertAllEqual(roll[0], np_input[-shift], msg=f"shift={shift}") + @test_util.run_deprecated_v1 def testInvalidInputShape(self): # The input should be 1-D or higher, checked in shape function. - with self.assertRaisesRegex(ValueError, - "Shape must be at least rank 1 but is rank 0"): + with self.assertRaisesRegex( + ValueError, "Shape must be at least rank 1 but is rank 0" + ): manip_ops.roll(7, 1, 0) @test_util.run_deprecated_v1 From 0bc64d360438bac13f05e0bd0ed7d79acdf1551a Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Wed, 22 Nov 2023 19:29:19 -0800 Subject: [PATCH 032/381] Implement Component base class for StableHLO Quantizer. Each derived components is expected to derive this interface and implement the `Run` function, which will be run when running the pipeline. This provides the most basic unit of action applied on a StableHLO graph. 
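(For illustration only, not part of this change: a minimal sketch of what a derived component could look like, assuming the `Run` method returns `FailureOr<ModuleOp>` as the includes of `component.h` below suggest. `IdentityComponent` is a hypothetical name, not a class introduced by this patch.)

    #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
    #include "mlir/Support/LogicalResult.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h"
    #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

    namespace mlir::quant::stablehlo {

    // A trivial component that leaves the StableHLO module untouched.
    class IdentityComponent : public Component {
     public:
      FailureOr<ModuleOp> Run(
          ModuleOp module_op,
          const ::stablehlo::quantization::QuantizationConfig& config) override {
        (void)config;  // A real component would read its options from `config`.
        return module_op;
      }
    };

    }  // namespace mlir::quant::stablehlo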
PiperOrigin-RevId: 584772794 --- .../mlir/quantization/stablehlo/cc/BUILD | 24 +++++++++++ .../quantization/stablehlo/cc/component.h | 40 +++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD new file mode 100644 index 00000000000000..ec2bfa892e0c66 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -0,0 +1,24 @@ +load( + "//tensorflow:tensorflow.default.bzl", + "get_compatible_with_portable", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", + "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "component", + hdrs = ["component.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h new file mode 100644 index 00000000000000..0f2a53d906ea99 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// Component is a public abstraction for StableHLO Quantizer that represents the +// most basic unit of action applied to the StableHLO graph. Derived classes +// should override the `Run` method to implement the action. +class Component { + public: + virtual ~Component() = default; + + // Runs the action to the StableHLO graph, passed by the `module_op`. `config` + // should provide information necessary to configure the action's behavior. 
+ virtual FailureOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) = 0; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_ From 2d7328762fa39668b1f802788f6143c4ffa1e1c4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 22 Nov 2023 20:43:49 -0800 Subject: [PATCH 033/381] [stream_executor] NFC: Use build flags to enable/compile CUDA graph conditional nodes support PiperOrigin-RevId: 584784016 --- .../gpu/runtime3/command_buffer_thunk_test.cc | 9 ++++---- third_party/xla/xla/stream_executor/BUILD | 6 +++++ .../xla/xla/stream_executor/command_buffer.cc | 11 +++++++++ .../xla/xla/stream_executor/command_buffer.h | 9 ++++++++ .../xla/xla/stream_executor/cuda/BUILD | 23 ++++++++++++++++--- .../cuda/cuda_command_buffer_test.cc | 12 ++++------ .../cuda/cuda_conditional_kernels.cu.cc | 4 ++-- 7 files changed, 56 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 2c18b511c092d3..7bd5e1b1aa4395 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -30,12 +30,12 @@ limitations under the License. #include "xla/service/service_executable_run_options.h" #include "xla/shape_util.h" #include "xla/stream_executor/blas.h" +#include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_test_kernels.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" @@ -390,11 +390,10 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { } TEST(CommandBufferThunkTest, IfCmd) { -#if !defined(XLA_GPU_USE_CUDA_GRAPH_CONDITIONAL) - GTEST_SKIP() << "CUDA graph conditionals not enabled"; -#endif - se::StreamExecutor* executor = CudaExecutor(); + if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } se::Stream stream(executor); stream.Init(); diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 67bf9a3d8db120..638fbaca0278ae 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -439,6 +439,12 @@ cc_library( name = "command_buffer", srcs = ["command_buffer.cc"], hdrs = ["command_buffer.h"], + local_defines = select({ + "//xla/stream_executor/cuda:graph_conditional_enabled": [ + "STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":stream_executor_headers", diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 82861a15c5fd0b..6ac06e67f1e9a0 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -71,6 +71,17 @@ CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default; return command_buffer; } +/*static*/ bool CommandBuffer::SupportsConditionalCommands( + const Platform* platform) { + // TODO(ezhulenev): 
We should extend a Platform with a way to query + // implemented StreamExecutor features, for now we know that only CUDA + // platform supports conditional commands in command buffers. +#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) + return platform->Name() == "CUDA"; +#endif + return false; +} + const internal::CommandBufferInterface* CommandBuffer::implementation() const { return implementation_.get(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index ecbed0e872ba0f..0c12cc82f8ffcf 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -24,6 +24,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -98,6 +99,14 @@ class CommandBuffer { absl::AnyInvocable function, Mode mode = Mode::kPrimary); + //===--------------------------------------------------------------------===// + // Command buffer properties + //===--------------------------------------------------------------------===// + + // Returns true if command buffer on a given platform supports conditional + // commands (If, IfThen, While). + static bool SupportsConditionalCommands(const Platform* platform); + //===--------------------------------------------------------------------===// // Command buffer API //===--------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 92fbe7c4616e68..c1261666286d42 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1,7 +1,5 @@ -# Description: -# CUDA-platform specific StreamExecutor support code. - load("//xla/tests:build_defs.bzl", "xla_test") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( "//xla:xla.bzl", @@ -40,6 +38,21 @@ package_group( packages = stream_executor_friends(), ) +# Add `--//third_party/tensorflow/compiler/xla/stream_executor/cuda:enable_graph_conditional` to +# build command to enable CUDA graph conditional nodes support. Requires CUDA >=12.3. 
+# +# See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#conditional-graph-nodes +bool_flag( + name = "enable_graph_conditional", + build_setting_default = False, +) + +config_setting( + name = "graph_conditional_enabled", + flag_values = {":enable_graph_conditional": "True"}, + visibility = ["//visibility:public"], +) + cc_library( name = "cuda_platform_id", srcs = ["cuda_platform_id.cc"], @@ -446,6 +459,10 @@ cuda_library( cuda_library( name = "cuda_conditional_kernels", srcs = if_cuda_is_configured(["cuda_conditional_kernels.cu.cc"]), + local_defines = select({ + ":graph_conditional_enabled": ["STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1"], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = ["@local_config_cuda//cuda:cuda_headers"], ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 5c4faa4824201b..981187b854f637 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -218,15 +218,11 @@ TEST(CudaCommandBufferTest, LaunchNestedCommandBuffer) { } TEST(CudaCommandBufferTest, ConditionalIf) { -#if CUDA_VERSION < 12030 - GTEST_SKIP() << "CUDA graph conditionals are not supported"; -#endif - -#if !defined(XLA_GPU_USE_CUDA_GRAPH_CONDITIONAL) - GTEST_SKIP() << "CUDA graph conditionals not enabled"; -#endif - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + if (!CommandBuffer::SupportsConditionalCommands(platform)) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); Stream stream(executor); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc index 4117f74cab3296..3c6f492615e170 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -23,13 +23,13 @@ namespace { __global__ void SetCondition(cudaGraphConditionalHandle handle, bool* predicate) { -#if defined(XLA_GPU_USE_CUDA_GRAPH_CONDITIONAL) +#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) if (*predicate) { cudaGraphSetConditional(handle, 1); } else { cudaGraphSetConditional(handle, 0); } -#endif // defined(XLA_GPU_USE_CUDA_GRAPH_CONDITIONAL) +#endif // defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) } #else From b487f0246c5804978f86f0a6bf2add1b911af23c Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:25:44 -0800 Subject: [PATCH 034/381] PR #7034: [ROCM] enabling gemm_algorithm_picker on ROCM platform Imported from GitHub PR https://github.com/openxla/xla/pull/7034 Here we enable gemm_algorithm_picker as well as buffer_comparator on ROCM. Some comments are due: - the function BlasSupport::DoBlasGemmWithAlgorithm is currently NYI, therefore, I integrated a workaround to skip autotuning on ROCm for Blas (i.e. only Blas-lt autotuning is supported on ROCm). 
- there is no support for in-process kernels through driver API: so I added a temporary workaround until a required function is added to HIP runtime (basically, the analogue of **cudaGetFuncBySymbol** is needed - bugfixed algorithm selection for hip/cuda-blas-lt gemms: previously blas-lt API did not set algorithm ID in blas::ProfileResult, hence autotuning always returned **kDefaultAlgorithm**: see xla/stream_executor/cuda/cuda_blas_lt.cc and xla/stream_executor/rocm/hip_blas_lt.cc, respectively - extended xla/service/gpu/gemm_algorithm_picker_test.cc to test not only for blas$gemm but also for blas$lt$gemm (added a test boolean parameter) - added vector size check in xla/service/gpu/runtime/gpublas_lt_matmul.cc which otherwise may crash if the number of found algorithms is too small @xla-rotation: would you have a look, please ? Copybara import of the project: -- c8d4371d64f627616b32f1e7c77a1230f3f4c121 by Pavel Emeliyanenko : enabling gemm algorithm picker for amdgpu ongoing experimental updates bugfixes for gemm_alg_picker for blas-lt case hacky workaround to enable device-side buffer_comparator on rocm cleanup after enabling buffer_comparator hacky way better workaround to deal with in-process kernels some cosmetics -- 701c99f7d701f84403911279cec97e33d6b17474 by Pavel Emeliyanenko : addressing reviewer comments Merging this change closes #7034 PiperOrigin-RevId: 584815066 --- third_party/xla/xla/service/gpu/BUILD | 46 ++++++++++----- .../xla/xla/service/gpu/amdgpu_compiler.cc | 4 +- .../xla/xla/service/gpu/buffer_comparator.cc | 3 +- .../xla/service/gpu/buffer_comparator.cu.cc | 24 ++++++-- .../xla/service/gpu/buffer_comparator_test.cc | 18 ++++-- .../xla/service/gpu/gemm_algorithm_picker.cc | 55 +++++++++++------- .../xla/service/gpu/gemm_algorithm_picker.h | 3 +- .../service/gpu/gemm_algorithm_picker_test.cc | 57 ++++++++++++++----- .../xla/xla/service/gpu/gemm_rewriter.cc | 2 +- .../xla/xla/service/gpu/ir_emitter_context.h | 11 ++-- .../xla/xla/service/gpu/matmul_utils.cc | 1 + third_party/xla/xla/service/gpu/runtime/BUILD | 14 +++-- .../xla/xla/service/gpu/runtime/gemm.cc | 8 +-- .../service/gpu/runtime/gpublas_lt_matmul.cc | 7 +++ .../xla/stream_executor/cuda/cuda_blas_lt.cc | 6 +- third_party/xla/xla/stream_executor/gpu/BUILD | 4 +- .../stream_executor/gpu/redzone_allocator.cc | 2 +- .../xla/xla/stream_executor/rocm/BUILD | 1 + .../xla/stream_executor/rocm/hip_blas_lt.cc | 6 +- .../xla/stream_executor/rocm/rocm_driver.cc | 31 +++++++--- .../rocm/rocm_driver_wrapper.h | 2 + .../stream_executor/rocm/rocm_gpu_executor.cc | 39 ++++++++----- 22 files changed, 241 insertions(+), 103 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index c6615c857e1dca..1721e06572edb7 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -18,6 +18,7 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", "if_rocm_is_configured", + "rocm_copts", ) load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library") load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") @@ -29,6 +30,7 @@ load( "@local_tsl//tsl/platform:build_config_root.bzl", "if_static", "tf_cuda_tests_tags", + "tf_gpu_tests_tags", ) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", @@ -1458,11 +1460,13 @@ cc_library( cc_library( name = "gemm_algorithm_picker", - srcs = if_cuda_is_configured(["gemm_algorithm_picker.cc"]), - hdrs = 
if_cuda_is_configured(["gemm_algorithm_picker.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + srcs = if_gpu_is_configured(["gemm_algorithm_picker.cc"]), + hdrs = if_gpu_is_configured(["gemm_algorithm_picker.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = if_gpu_is_configured([ ":backend_configs_cc", ":buffer_comparator", ":gemm_thunk", @@ -1479,8 +1483,7 @@ cc_library( "//xla:status_macros", "//xla/stream_executor", "//xla/stream_executor:blas", - "//xla/stream_executor/cuda:cublas_lt_header", - "//xla/stream_executor/cuda:cublas_plugin", + "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/gpu:redzone_allocator", @@ -3244,6 +3247,7 @@ cc_library( ":autotuner_util", ":cusolver_rewriter", ":gemm_rewriter", + ":gemm_algorithm_picker", ":gpu_compiler", ":conv_algorithm_picker", ":gpu_conv_padding_legalization", @@ -3611,10 +3615,13 @@ xla_cc_test( cc_library( name = "buffer_comparator", - srcs = if_cuda_is_configured(["buffer_comparator.cc"]), - hdrs = if_cuda_is_configured(["buffer_comparator.h"]), + srcs = if_gpu_is_configured(["buffer_comparator.cc"]), + hdrs = if_gpu_is_configured(["buffer_comparator.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = if_gpu_is_configured([ ":launch_dimensions", ":buffer_comparator_kernel", ":gpu_asm_opts_util", @@ -3634,22 +3641,33 @@ cc_library( cuda_library( name = "buffer_comparator_kernel", - srcs = if_cuda_is_configured(["buffer_comparator.cu.cc"]), + srcs = if_gpu_is_configured(["buffer_comparator.cu.cc"]), + copts = rocm_copts(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], - deps = ["@local_config_cuda//cuda:cuda_headers"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) xla_cc_test( name = "buffer_comparator_test", - srcs = if_cuda_is_configured(["buffer_comparator_test.cc"]), - tags = tf_cuda_tests_tags(), + srcs = if_gpu_is_configured(["buffer_comparator_test.cc"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags(), deps = [ ":stream_executor_util", "//xla:shape_util", "//xla:types", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - ] + if_cuda_is_configured([ + ] + if_gpu_is_configured([ ":buffer_comparator", "//xla/stream_executor:device_memory", ]), diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index c3b5da52df569f..b6280e62a86bf7 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "xla/service/call_inliner.h" #include "xla/service/gpu/conv_algorithm_picker.h" #include "xla/service/gpu/cusolver_rewriter.h" +#include "xla/service/gpu/gemm_algorithm_picker.h" #include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/gpu_conv_padding_legalization.h" #include "xla/service/gpu/gpu_conv_rewriter.h" @@ -153,8 +154,7 @@ Status AMDGPUCompiler::AddConvAndGemmAutotuningPasses( if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) { pipeline->AddPass(autotune_config); } - // TODO: - // pipeline->AddPass(autotune_config); + pipeline->AddPass(autotune_config); return OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cc index eeb8ea6509f3c1..c5dbe883f28919 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cc @@ -168,7 +168,6 @@ static StatusOr CompareEqualParameterized(se::Stream* stream, stream, current, expected))); CHECK_EQ(host_return, result) << "Host comparison succeeded even though GPU comparison failed."; - return false; } @@ -176,6 +175,7 @@ StatusOr BufferComparator::CompareEqual( se::Stream* stream, se::DeviceMemoryBase current, se::DeviceMemoryBase expected) const { switch (shape_.element_type()) { +#if GOOGLE_CUDA // not available for ROCm yet.. case xla::F8E4M3FN: return CompareEqualParameterized( stream, current, expected, shape_, config_, "fp8_e4m3fn_comparison", @@ -184,6 +184,7 @@ StatusOr BufferComparator::CompareEqual( return CompareEqualParameterized( stream, current, expected, shape_, config_, "fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison()); +#endif // GOOGLE_CUDA case xla::F16: return CompareEqualParameterized( stream, current, expected, shape_, config_, "fp16_comparison", diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc index d42e772eeb4cd9..08d99184f096b3 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc @@ -13,10 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if GOOGLE_CUDA #include #include #include +using bfloat16 = __nv_bfloat16; +#define BF16_TO_F32 __bfloat162float + +#elif TENSORFLOW_USE_ROCM +#include +#include + +using bfloat16 = hip_bfloat16; +#define BF16_TO_F32 float + +#endif + #include namespace xla::gpu::buffer_comparator { @@ -36,6 +49,7 @@ __device__ __inline__ float Canonicalize(float input) { return isnan(input) ? 
input : max(-65505.0f, min(input, 65505.0f)); } +#if GOOGLE_CUDA __global__ void xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t* buffer_a, __nv_fp8_storage_t* buffer_b, float rel_error_threshold, @@ -81,6 +95,7 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a, if (rel_error > rel_error_threshold || isnan(rel_error)) atomicAdd(mismatch_count, 1); } +#endif // GOOGLE_CUDA __global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b, float rel_error_threshold, @@ -134,15 +149,14 @@ __global__ void xla_fp64_comparison(double* buffer_a, double* buffer_b, atomicAdd(mismatch_count, 1); } -__global__ void xla_bf16_comparison(__nv_bfloat16* buffer_a, - __nv_bfloat16* buffer_b, +__global__ void xla_bf16_comparison(bfloat16* buffer_a, bfloat16* buffer_b, float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) { int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= buffer_length) return; - float elem_a = __bfloat162float(buffer_a[idx]); - float elem_b = __bfloat162float(buffer_b[idx]); + float elem_a = BF16_TO_F32(buffer_a[idx]); + float elem_b = BF16_TO_F32(buffer_b[idx]); elem_a = Canonicalize(elem_a); elem_b = Canonicalize(elem_b); if (isnan(elem_a) && isnan(elem_b)) return; @@ -182,6 +196,7 @@ __global__ void xla_int32_comparison(int* buffer_a, int* buffer_b, } // namespace +#if GOOGLE_CUDA void* fp8_e4m3fn_comparison() { return reinterpret_cast(&xla_fp8_e4m3fn_comparison); } @@ -189,6 +204,7 @@ void* fp8_e4m3fn_comparison() { void* fp8_e5m2_comparison() { return reinterpret_cast(&xla_fp8_e5m2_comparison); } +#endif void* fp16_comparison() { return reinterpret_cast(&xla_fp16_comparison); diff --git a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc index 540f3591362bc0..839e5038419c0f 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc @@ -34,8 +34,13 @@ namespace { class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() - : platform_(se::MultiPlatformManager::PlatformWithName("cuda").value()), - stream_exec_(platform_->ExecutorForDevice(0).value()) {} +#if GOOGLE_CUDA + : platform_(se::MultiPlatformManager::PlatformWithName("CUDA").value()), +#elif TENSORFLOW_USE_ROCM + : platform_(se::MultiPlatformManager::PlatformWithName("ROCM").value()), +#endif + stream_exec_(platform_->ExecutorForDevice(0).value()) { + } // Take floats only for convenience. Still uses ElementType internally. 
template @@ -162,7 +167,7 @@ TEST_F(BufferComparatorTest, TestInfs) { EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); - +#if GOOGLE_CUDA EXPECT_TRUE( CompareEqualFloatBuffers({inf}, {std::nanf("")})); EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); @@ -182,6 +187,7 @@ TEST_F(BufferComparatorTest, TestInfs) { EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); +#endif // GOOGLE_CUDA } TEST_F(BufferComparatorTest, TestNumbers) { @@ -209,7 +215,7 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({90}, {100})); EXPECT_TRUE(CompareEqualFloatBuffers({100}, {90})); EXPECT_FALSE(CompareEqualFloatBuffers({-128}, {127})); - +#if GOOGLE_CUDA EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); @@ -221,6 +227,7 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); EXPECT_TRUE(CompareEqualFloatBuffers({11}, {12})); EXPECT_TRUE(CompareEqualFloatBuffers({12}, {11})); +#endif // GOOGLE_CUDA } TEST_F(BufferComparatorTest, TestMultiple) { @@ -291,7 +298,7 @@ TEST_F(BufferComparatorTest, TestMultiple) { rhs[i] = 0; } } - +#if GOOGLE_CUDA { EXPECT_TRUE(CompareEqualFloatBuffers( {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1})); @@ -325,6 +332,7 @@ TEST_F(BufferComparatorTest, TestMultiple) { rhs[i] = 0; } } +#endif // GOOGLE_CUDA } TEST_F(BufferComparatorTest, BF16) { diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index 20f9e709f0961a..2823e16fcd0373 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -47,10 +47,8 @@ limitations under the License. 
#include "tsl/platform/statusor.h" #include "tsl/util/proto/proto_utils.h" -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/buffer_comparator.h" -#include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" #endif namespace xla { @@ -113,7 +111,7 @@ StatusOr GetBestAlgorithm( if (!autotune_config.should_check_correctness()) { continue; } - +#if GOOGLE_CUDA // redzone check is not yet available on ROCm TF_ASSIGN_OR_RETURN( se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, allocator.CheckRedzones()); @@ -126,6 +124,7 @@ StatusOr GetBestAlgorithm( CHECK(!autotune_config.should_crash_on_check_failure()); continue; } +#endif // GOOGLE_CUDA if (!reference_algorithm) { stream->ThenMemcpy(&reference_buffer, output_buffer, @@ -195,31 +194,33 @@ StatusOr GetBestBlasAlgorithm( namespace { -StatusOr AsBlasLtEpilogue( +using se::gpu::BlasLt; + +StatusOr AsBlasLtEpilogue( GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: - return se::gpu::BlasLt::Epilogue::kDefault; + return BlasLt::Epilogue::kDefault; case GemmBackendConfig::RELU: - return se::gpu::BlasLt::Epilogue::kReLU; + return BlasLt::Epilogue::kReLU; case GemmBackendConfig::GELU: - return se::gpu::BlasLt::Epilogue::kGELU; + return BlasLt::Epilogue::kGELU; case GemmBackendConfig::GELU_AUX: - return se::gpu::BlasLt::Epilogue::kGELUWithAux; + return BlasLt::Epilogue::kGELUWithAux; case GemmBackendConfig::BIAS: - return se::gpu::BlasLt::Epilogue::kBias; + return BlasLt::Epilogue::kBias; case GemmBackendConfig::BIAS_RELU: - return se::gpu::BlasLt::Epilogue::kBiasThenReLU; + return BlasLt::Epilogue::kBiasThenReLU; case GemmBackendConfig::BIAS_GELU: - return se::gpu::BlasLt::Epilogue::kBiasThenGELU; + return BlasLt::Epilogue::kBiasThenGELU; case GemmBackendConfig::BIAS_GELU_AUX: - return se::gpu::BlasLt::Epilogue::kBiasThenGELUWithAux; + return BlasLt::Epilogue::kBiasThenGELUWithAux; default: return InternalError("Unsupported Epilogue."); } } -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM StatusOr DoGemmAutotuneNoCache( const HloInstruction* gemm, const AutotuneCacheKey& key, @@ -309,18 +310,18 @@ StatusOr DoGemmAutotuneNoCache( autotune_config, rng_state)); } - TF_ASSIGN_OR_RETURN( - auto plan, se::gpu::BlasLt::GetMatmulPlan(stream, config, epilogue)); + TF_ASSIGN_OR_RETURN(auto plan, + BlasLt::GetMatmulPlan(stream, config, epilogue)); TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms()); TF_ASSIGN_OR_RETURN( best_algorithm, - GetBestAlgorithm( + GetBestAlgorithm( stream, buffer_allocator, gemm->ToString(), autotune_config, lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, hlo_module_config, gemm_config.beta(), - [&](const se::gpu::BlasLt::MatmulAlgorithm& algorithm) + [&](const BlasLt::MatmulAlgorithm& algorithm) -> StatusOr { se::OwningScratchAllocator<> scratch_allocator( stream->parent()->device_ordinal(), allocator); @@ -336,6 +337,14 @@ StatusOr DoGemmAutotuneNoCache( std::vector algorithms; TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)); +#if TENSORFLOW_USE_ROCM // Blas gemm algorithms are not yet supported + if (algorithms.empty()) { // nothing to autotune + VLOG(1) << "Skipping autotuning for ROCm.."; + best_algorithm.mutable_gemm()->set_algorithm(se::blas::kDefaultAlgorithm); + return best_algorithm; + } +#endif + TF_ASSIGN_OR_RETURN( best_algorithm, GetBestBlasAlgorithm( @@ -366,7 +375,7 @@ StatusOr 
DoGemmAutotuneNoCache( return best_algorithm; } -#endif // (defined(GOOGLE_CUDA) && GOOGLE_CUDA) +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Do Gemm Autotune without stream executor. Use results from autotune cache // only. @@ -390,12 +399,16 @@ StatusOr RunOnInstruction(HloInstruction* gemm, return DoGemmAutotuneNoCache(gemm, key, config); })); - se::CudaComputeCapability capability = config.GetCudaComputeCapability(); GemmBackendConfig updated_config = gemm_config; // We only set the 'algorithm' field on non-Ampere architectures, as for // Ampere it's ignored in any case. - if (!capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { + bool update_algorithm = true; +#if GOOGLE_CUDA + auto capability = config.GetCudaComputeCapability(); + update_algorithm = !capability.IsAtLeast(se::CudaComputeCapability::AMPERE); +#endif + if (update_algorithm) { if (algorithm.has_gemm()) { updated_config.set_selected_algorithm(algorithm.gemm().algorithm()); } else { diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h index 9675d8968bfebf..99cae9f796f496 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h @@ -32,9 +32,8 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/stream_executor/cuda/cuda_blas_lt.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #endif diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc index 07a901fd781226..77bcb13661d7b0 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -33,12 +33,20 @@ namespace { namespace m = ::xla::match; -class GemmAlgorithmPickerTest : public HloTestBase { +class GemmAlgorithmPickerTest : public HloTestBase, + public ::testing::WithParamInterface { public: GemmAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); } + + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_cublaslt(GetParam()); + debug_options.set_xla_gpu_enable_triton_gemm(false); + return debug_options; + } }; -TEST_F(GemmAlgorithmPickerTest, SetAlgorithm) { +TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) { constexpr absl::string_view kHlo = R"( HloModule module @@ -47,7 +55,10 @@ ENTRY main { %arg1 = f32[100,100]{1,0} parameter(1) ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo)); + + auto module_cfg = GetModuleConfigForTest(); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kHlo, module_cfg)); se::Platform* platform = PlatformUtil::GetDefaultPlatform().value(); TF_ASSERT_OK_AND_ASSIGN(std::vector executors, @@ -57,7 +68,7 @@ ENTRY main { bool changed = false; TF_ASSERT_OK_AND_ASSIGN( changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .cuda_compute_capability()), + .gpu_compute_capability()), m.get())); changed = false; DebugOptions opts; @@ -79,11 +90,11 @@ ENTRY main { // Now send the same module through GemmAlgorithmPicker again. 
The dot should // have the new algorithm. - TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg)); changed = false; TF_ASSERT_OK_AND_ASSIGN( changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .cuda_compute_capability()), + .gpu_compute_capability()), m.get())); changed = false; TF_ASSERT_OK_AND_ASSIGN(changed, @@ -92,15 +103,20 @@ ENTRY main { SCOPED_TRACE(m->ToString()); HloInstruction* dot; - ASSERT_THAT(m->entry_computation()->root_instruction(), - GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0))); + if (module_cfg.debug_options().xla_gpu_enable_cublaslt()) { + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall(&dot))); + } else { + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0))); + } TF_ASSERT_OK_AND_ASSIGN(GemmBackendConfig config, dot->backend_config()); EXPECT_EQ(config.selected_algorithm(), new_algo_id); } -TEST_F(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) { +TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) { constexpr absl::string_view kHlo = R"( HloModule module @@ -109,7 +125,8 @@ ENTRY main { %arg1 = f32[100,100]{1,0} parameter(1) ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN( + auto m, ParseAndReturnVerifiedModule(kHlo, GetModuleConfigForTest())); se::Platform* platform = PlatformUtil::GetDefaultPlatform().value(); TF_ASSERT_OK_AND_ASSIGN(std::vector executors, @@ -120,7 +137,7 @@ ENTRY main { bool changed = false; TF_ASSERT_OK_AND_ASSIGN( changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .cuda_compute_capability()), + .gpu_compute_capability()), m.get())); changed = false; @@ -142,9 +159,10 @@ ENTRY main { AutotunerUtil::ClearAutotuneResults(); TF_ASSERT_OK(AutotunerUtil::LoadAutotuneResults(results)); + auto module_cfg = GetModuleConfigForTest(); // Now send the same module through GemmAlgorithmPicker again. The dot should // have the new algorithm. 
- TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg)); changed = false; DevicelessConfig deviceless_config{ @@ -153,7 +171,7 @@ ENTRY main { AutotuneConfig deviceless_cfg{deviceless_config, opts}; TF_ASSERT_OK_AND_ASSIGN( changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .cuda_compute_capability()), + .gpu_compute_capability()), m.get())); changed = false; TF_ASSERT_OK_AND_ASSIGN( @@ -162,13 +180,22 @@ ENTRY main { SCOPED_TRACE(m->ToString()); HloInstruction* dot; - ASSERT_THAT(m->entry_computation()->root_instruction(), - GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0))); + + if (module_cfg.debug_options().xla_gpu_enable_cublaslt()) { + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall(&dot))); + } else { + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0))); + } TF_ASSERT_OK_AND_ASSIGN(GemmBackendConfig config, dot->backend_config()); EXPECT_EQ(config.selected_algorithm(), new_algo_id); } +INSTANTIATE_TEST_SUITE_P(GemmAlgorithmPickerTestSuite, GemmAlgorithmPickerTest, + ::testing::Bool()); + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc index 0e41cfdcbdd1e0..919573f08cd02a 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc @@ -1979,7 +1979,7 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { // Pass a user-managed workspace to legacy cuBLAS operations, as // otherwise cuBLAS will use its own internal pool which will be competing // with XLA allocator for device memory. - int64_t workspace = cuda_cc == nullptr ? 0 + int64_t workspace = cuda_cc == nullptr ? GemmConfig::kDefaultWorkspace : cuda_cc->IsAtLeastHopper() ? GemmConfig::kHopperWorkspace : GemmConfig::kDefaultWorkspace; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.h b/third_party/xla/xla/service/gpu/ir_emitter_context.h index faa4379bc8e7d7..90614c715ee828 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.h @@ -64,14 +64,17 @@ class IrEmitterContext { const se::DeviceDescription& gpu_device_info() const { return gpu_device_info_; } + const se::GpuComputeCapability& gpu_compute_capability() const { + return gpu_device_info_.gpu_compute_capability(); + } se::CudaComputeCapability cuda_compute_capability() const { - auto* cc = std::get_if( - &gpu_device_info_.gpu_compute_capability()); + auto* cc = + std::get_if(&gpu_compute_capability()); return cc != nullptr ? *cc : se::CudaComputeCapability(); } se::RocmComputeCapability rocm_compute_capability() const { - auto* cc = std::get_if( - &gpu_device_info_.gpu_compute_capability()); + auto* cc = + std::get_if(&gpu_compute_capability()); return cc != nullptr ? 
*cc : se::RocmComputeCapability(); } mlir::MLIRContext* mlir_context() { return mlir_context_; } diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 8136ac0a48a3c5..3482c7edeefcf5 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -571,6 +571,7 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, se::blas::BlasSupport::ScopedWorkspace scoped_workspace( stream->parent()->AsBlas(), &workspace); +// TODO: enable DoGemmWithAlgorithm for ROCm ! #if GOOGLE_CUDA if (algorithm) { return DoGemmWithAlgorithm( diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 9dc7c482b520b1..c154f5956af650 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -505,7 +505,9 @@ cc_library( name = "gemm", srcs = ["gemm.cc"], hdrs = ["gemm.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], deps = [ ":support", @@ -525,7 +527,7 @@ cc_library( "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", - ] + if_cuda_is_configured([ + ] + if_gpu_is_configured([ "//xla/service/gpu:gemm_algorithm_picker", "//xla/stream_executor/gpu:redzone_allocator", ]), @@ -646,7 +648,9 @@ cc_library( name = "gpublas_lt_matmul", srcs = ["gpublas_lt_matmul.cc"], hdrs = ["gpublas_lt_matmul.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], deps = [ ":support", @@ -701,7 +705,9 @@ cc_library( name = "support", srcs = ["support.cc"], hdrs = ["support.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], deps = [ "//xla:shape_util", diff --git a/third_party/xla/xla/service/gpu/runtime/gemm.cc b/third_party/xla/xla/service/gpu/runtime/gemm.cc index a57e43ed8a8a69..d56bbad214b601 100644 --- a/third_party/xla/xla/service/gpu/runtime/gemm.cc +++ b/third_party/xla/xla/service/gpu/runtime/gemm.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "tsl/platform/errors.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/gemm_algorithm_picker.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #endif @@ -48,7 +48,7 @@ using xla::runtime::CustomCall; using xla::runtime::State; using xla::runtime::StridedMemrefView; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // TODO(ezhulenev): Delete run time auto tuning from XLA. 
Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, @@ -108,7 +108,7 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, return InternalError("Runtime autotuning failed to select an algorithm"); } } -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, const DebugOptions* debug_options, @@ -144,7 +144,7 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, // outside of state.GetOrCreate() because otherwise it would be a potential // deadlock. if (gemm_config->algorithm == stream_executor::blas::kRuntimeAutotuning) { -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM auto status = DoRuntimeAutotuning(stream, *gemm_config, lhs_data, rhs_data, output_data, output_shape, beta, debug_options, gpu_lock); diff --git a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc index c008bf5c3ffdd4..151571c9389523 100644 --- a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc +++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc @@ -109,6 +109,13 @@ absl::Status DoMatmul( })); TF_ASSIGN_OR_RETURN(auto algos, (*plan)->GetAlgorithms()); + if (static_cast(algorithm) >= algos.size()) { + return absl::InternalError( + absl::StrFormat("The requested gpublas-lt matmul " + "algorithm is not found. Total algorithms available: " + "%zu; requested: %zu", + algos.size(), static_cast(algorithm))); + } se::DeviceMemoryBase a_data = GetDeviceAddress(a); se::DeviceMemoryBase b_data = GetDeviceAddress(b); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index 36f01e48ab8bf0..af0487f6931e15 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -402,6 +402,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( workspace = gpu::GpuMemoryMutable(&alloc); } + auto palgo = std::any_cast(&algorithm.opaque_algo); { absl::MutexLock lock(&blas_lt_ref_.mu_); TF_RET_CHECK(blas_lt_ref_.blas_lt_ != nullptr); @@ -477,8 +478,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; - if (auto palgo = - std::any_cast(&algorithm.opaque_algo)) { + if (palgo != nullptr) { SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmul( blas_lt_ref_.blas_lt_.get(), op_desc_.get(), alpha, a.opaque(), a_desc_.get(), b.opaque(), b_desc_.get(), beta, c.opaque(), @@ -491,6 +491,8 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( if (profile_result != nullptr) { TF_ASSIGN_OR_RETURN(absl::Duration elapsed, timer->GetElapsedDuration()); + // set algorithm ID to be unique (otherwise it gets kDefaultAlgorithm ID) + profile_result->set_algorithm(reinterpret_cast(palgo)); profile_result->set_is_valid(true); profile_result->set_elapsed_time_in_ms(absl::ToDoubleMilliseconds(elapsed)); } diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 8d42981e19ebca..ba8f01430baeb9 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -391,7 +391,9 @@ cc_library( srcs = if_gpu_is_configured(["redzone_allocator.cc"]), hdrs = if_gpu_is_configured(["redzone_allocator.h"]), copts = tsl_copts(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + 
"TENSORFLOW_USE_ROCM=1", + ]), visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":asm_compiler", diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index c6415812a52ef2..ceee5503e29b73 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -336,7 +336,7 @@ tsl::StatusOr RedzoneAllocator::CheckRedzones() const { (LoadKernelOrGetPtr, uint8_t, uint64_t, DeviceMemory>( executor, "redzone_checker", redzone_checker_ptx, compiled_ptx))); -#else +#elif TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN( std::unique_ptr loaded_kernel, (executor->CreateTypedKernel, uint8, uint64_t, diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index c33983328c0fd9..4a6487b9d47947 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -114,6 +114,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/gpu:gpu_command_buffer", + "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/platform", diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index e25a0a946c7912..262a3a5c3122f9 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -373,6 +373,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( workspace = gpu::GpuMemoryMutable(&alloc); } + auto palgo = std::any_cast(&algorithm.opaque_algo); { absl::MutexLock lock(&blas_lt_ref_.mu_); TF_RET_CHECK(blas_lt_ref_.blas_lt_ != nullptr); @@ -399,8 +400,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_}; - if (auto palgo = - std::any_cast(&algorithm.opaque_algo)) { + if (palgo != nullptr) { SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatmul( blas_lt_ref_.blas_lt_.get(), op_desc_.get(), alpha, a.opaque(), a_desc_.get(), b.opaque(), b_desc_.get(), beta, c.opaque(), @@ -413,6 +413,8 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( if (profile_result != nullptr) { TF_ASSIGN_OR_RETURN(absl::Duration elapsed, timer->GetElapsedDuration()); + // set algorithm ID to be unique (otherwise it gets kDefaultAlgorithm ID) + profile_result->set_algorithm(reinterpret_cast(palgo)); profile_result->set_is_valid(true); profile_result->set_elapsed_time_in_ms(absl::ToDoubleMilliseconds(elapsed)); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 804348050a590e..bf20011441504d 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -432,8 +432,12 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, /* static */ tsl::Status GpuDriver::FuncSetCacheConfig( hipFunction_t function, hipFuncCache_t cache_config) { - RETURN_IF_ROCM_ERROR(wrap::hipFuncSetCacheConfig(function, cache_config), - "Failed to set ROCM kernel cache config."); + // NOTE: this function is only available for in-process GPU kernels: + // https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___execution.html#gafdb33ef569eb89808fc5178d04b508ba + // but it is no-op for the 
current HIP release ! + RETURN_IF_ROCM_ERROR( + wrap::hipFuncSetCacheConfig((const void*)function, cache_config), + "Failed to set ROCM kernel cache config."); return tsl::OkStatus(); } @@ -823,13 +827,26 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z << " bdx: " << block_dim_x << " bdy: " << block_dim_y << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes; - RETURN_IF_ROCM_ERROR(wrap::hipModuleLaunchKernel( - function, grid_dim_x, grid_dim_y, grid_dim_z, - block_dim_x, block_dim_y, block_dim_z, - shared_mem_bytes, stream, kernel_params, extra), - "Failed to launch ROCm kernel: ", kernel_name, + + // for in-process kernel this function returns mangled kernel function name, + // and null otherwise + auto name = wrap::hipKernelNameRefByPtr((const void*)function, stream); + + auto res = hipSuccess; + if (name != nullptr) { + res = wrap::hipLaunchKernel((const void*)function, + dim3(grid_dim_x, grid_dim_y, grid_dim_z), + dim3(block_dim_x, block_dim_y, block_dim_z), + kernel_params, shared_mem_bytes, stream); + } else { + res = wrap::hipModuleLaunchKernel( + function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, + block_dim_z, shared_mem_bytes, stream, kernel_params, extra); + } + RETURN_IF_ROCM_ERROR(res, "Failed to launch ROCm kernel: ", kernel_name, " with block dimensions: ", block_dim_x, "x", block_dim_y, "x", block_dim_z); + VLOG(2) << "successfully launched kernel"; return tsl::OkStatus(); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h index 6749e09dc74d0d..36bcfdd33e873e 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -120,7 +120,9 @@ namespace wrap { __macro(hipHostRegister) \ __macro(hipHostUnregister) \ __macro(hipInit) \ + __macro(hipKernelNameRefByPtr) \ __macro(hipLaunchHostFunc) \ + __macro(hipLaunchKernel) \ __macro(hipMalloc) \ __macro(hipMemGetAddressRange) \ __macro(hipMemGetInfo) \ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc index 8a489bbafea479..82e1c71186bd71 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/gpu/gpu_runtime.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/platform.h" @@ -213,13 +214,7 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, hipModule_t module = nullptr; const string* kernel_name; - const OnDiskKernelLoaderSpec* on_disk_spec = nullptr; - - VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); - - if (spec.has_cuda_cubin_on_disk()) on_disk_spec = &spec.cuda_cubin_on_disk(); - - if (on_disk_spec != nullptr) { + if (spec.has_cuda_cubin_on_disk()) { return tsl::errors::Internal( "Loading ROCM kernel from disk is not supported"); } else if (spec.has_cuda_cubin_in_memory()) { @@ -233,22 +228,40 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module)); } kernel_to_gpu_binary_[kernel] = hsaco; + } else if (spec.has_in_process_symbol()) { + kernel_name = &spec.in_process_symbol().kernel_name(); + void* symbol = spec.in_process_symbol().symbol(); + + VLOG(1) << "Resolve ROCM kernel " << *kernel_name + << " from symbol pointer: " << symbol; + + *rocm_kernel->gpu_function_ptr() = + static_cast(spec.in_process_symbol().symbol()); } else { return tsl::errors::Internal("No method of loading ROCM kernel provided"); } - VLOG(2) << "getting function " << *kernel_name << " from module " << module; - TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( - context_, module, kernel_name->c_str(), rocm_kernel->gpu_function_ptr())); + // If we resolved kernel from a symbol pointer, there is no need to load it + // from a module, as ROCm runtime did that automatically for us. + if (!spec.has_in_process_symbol()) { + VLOG(2) << "getting function " << *kernel_name << " from module " << module; + TF_RETURN_IF_ERROR( + GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), + rocm_kernel->gpu_function_ptr())); + } // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the ROCM API. rocm_kernel->set_arity(spec.arity()); - KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); + // unable to get kernel metadata for in-process kernel + if (!spec.has_in_process_symbol()) { + KernelMetadata kernel_metadata; + TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); + kernel->set_metadata(kernel_metadata); + } kernel->set_name(*kernel_name); + kernel->set_kernel_args_packing(spec.kernel_args_packing()); return tsl::OkStatus(); } From 3963ed0c9611f1e135658a648002e78bb9b09cf4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 23 Nov 2023 01:02:07 -0800 Subject: [PATCH 035/381] compat: Update forward compatibility horizon to 2023-11-23 PiperOrigin-RevId: 584821379 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 10b70e78257887..ff4c66e7386d61 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. 
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 22) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 23) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 1fe740c02b9e2c5b360838cbbefbf9e9870382a7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 23 Nov 2023 01:02:07 -0800 Subject: [PATCH 036/381] Update GraphDef version to 1689. PiperOrigin-RevId: 584821380 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 617f88a8e533cc..8515e9ca75f65a 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1688 // Updated: 2023/11/22 +#define TF_GRAPH_DEF_VERSION 1689 // Updated: 2023/11/23 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 4c5b5cdcca113e7e2e5666758200717d205b1124 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Thu, 23 Nov 2023 02:00:20 -0800 Subject: [PATCH 037/381] Remove dependency of `QuantizationOptions` from `run_calibration` function in py function library. PiperOrigin-RevId: 584831647 --- .../tensorflow/python/py_function_lib.h | 19 ++++++++--- .../tensorflow/python/py_function_lib.py | 31 ++++++++++-------- .../tensorflow/python/pywrap_function_lib.cc | 20 ++++++++---- .../tensorflow/python/pywrap_function_lib.pyi | 5 ++- .../python/pywrap_quantize_model.cc | 7 ++-- .../tensorflow/python/type_casters.h | 32 +++++++++++++++++++ 6 files changed, 86 insertions(+), 28 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h index da7d0a96d3b697..96d937a7e59a9c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" @@ -76,8 +77,12 @@ class PyFunctionLibrary { // Runs calibration on a model saved at `saved_model_path`. `exported_model` // should be the corresponding exported model resulting from the - // pre-calibration step. `representative_dataset` is a python object of type - // `RepresentativeDatasetOrMapping`, which is used to run the calibration. + // pre-calibration step. `signature_keys` is a set of keys that identify a + // SignatureDef to run the calibration on. `tags` is a set of strings that + // identify the `MetaGraphDef`. `calibration_options` provides configurations + // for the calibration behavior. `representative_dataset` is a python object + // of type `RepresentativeDatasetOrMapping`, which is used to run the + // calibration. // // Returns the updated exported model where the collected calibration // statistics are added to `CustomAggregator` nodes at the `min` and `max` @@ -85,10 +90,14 @@ class PyFunctionLibrary { // // If the function signature changes, likely its corresponding .pyi type // hinting and definition should also change. 
- // LINT.IfChange + // LINT.IfChange(run_calibration) virtual ExportedModel RunCalibration( - absl::string_view saved_model_path, const ExportedModel& exported_model, - const QuantizationOptions& quantization_options, + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const ExportedModel& exported_model, + const CalibrationOptions& calibration_options, + bool force_graph_mode_calibration, pybind11::object representative_dataset) const = 0; // LINT.ThenChange( // pywrap_function_lib.pyi:run_calibration, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index 82f72d0a157920..c55f11c98b23b7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -649,8 +649,11 @@ def save_exported_model( def run_calibration( self, saved_model_path: str, + signature_keys: list[str], + tags: set[str], exported_model_serialized: bytes, - quantization_options_serialized: bytes, + calibration_options_serialized: bytes, + force_graph_mode_calibration: bool, representative_dataset: rd.RepresentativeDatasetOrMapping, ) -> bytes: # LINT.ThenChange(py_function_lib.h:run_calibration) @@ -658,9 +661,13 @@ def run_calibration( Args: saved_model_path: Path to the SavedModel to run calibration. + signature_keys: List of signature keys corresponding to SignatureDefs to + run calibration on. + tags: A set of tags that identify the MetaGraphDef. exported_model_serialized: Serialized `ExportedModel` that corresponds to the SavedModel at `saved_model_path`. - quantization_options_serialized: Serialized `QuantizationOptions`. + calibration_options_serialized: Serialized `CalibrationOptions`. + force_graph_mode_calibration: If True, runs the calibration in graph mode. representative_dataset: Representative dataset to run calibration. Returns: @@ -668,30 +675,26 @@ def run_calibration( statistics are added to `CustomerAggregator` nodes at the `min` and `max` attributes. """ - quantization_options = ( - quantization_options_pb2.QuantizationOptions.FromString( - quantization_options_serialized - ) - ) - # Uses the representative dataset to collect statistics for calibration. # After this operation, min & max values are stored separately in a global # CalibratorSingleton instance. 
_run_graph_for_calibration( saved_model_path, - quantization_options.signature_keys, - quantization_options.tags, + signature_keys, + tags, representative_dataset, - quantization_options.force_graph_mode_calibration, + force_graph_mode_calibration, ) exported_model = exported_model_pb2.ExportedModel.FromString( exported_model_serialized ) - _add_calibration_statistics( - exported_model.graph_def, - quantization_options.calibration_options, + calibration_options = ( + quantization_options_pb2.CalibrationOptions.FromString( + calibration_options_serialized + ) ) + _add_calibration_statistics(exported_model.graph_def, calibration_options) return exported_model.SerializeToString() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index fdc442cbf97a5c..f850b7effbe0fd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" @@ -34,6 +35,7 @@ namespace { using ::tensorflow::GraphDef; using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::CalibrationOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::QuantizationOptions; @@ -66,12 +68,16 @@ class PyFunctionLibraryTrampoline : public PyFunctionLibrary { ExportedModel RunCalibration( const absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, const ExportedModel& exported_model, - const QuantizationOptions& quantization_options, + const CalibrationOptions& calibration_options, + const bool force_graph_mode_calibration, const py::object representative_dataset) const override { - PYBIND11_OVERRIDE_PURE(ExportedModel, PyFunctionLibrary, run_calibration, - saved_model_path, exported_model, - quantization_options, representative_dataset); + PYBIND11_OVERRIDE_PURE( + ExportedModel, PyFunctionLibrary, run_calibration, saved_model_path, + signature_keys, tags, exported_model, calibration_options, + force_graph_mode_calibration, representative_dataset); } GraphDef EnableDumpTensor(const GraphDef& graph_def) const override { @@ -100,8 +106,10 @@ PYBIND11_MODULE(pywrap_function_lib, m) { py::arg("src_saved_model_path"), py::arg("tags"), py::arg("serialized_signature_def_map")) .def("run_calibration", &PyFunctionLibrary::RunCalibration, - py::arg("saved_model_path"), py::arg("exported_model_serialized"), - py::arg("quantization_options_serialized"), + py::arg("saved_model_path"), py::arg("signature_keys"), + py::arg("tags"), py::arg("exported_model_serialized"), + py::arg("calibration_options_serialized"), + py::arg("force_graph_mode_calibration"), py::arg("representative_dataset")) .def("enable_dump_tensor", &PyFunctionLibrary::EnableDumpTensor, py::arg("graph_def_serialized")) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi index e0bb18cd9a95b4..d464670c5593f0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi @@ -37,8 +37,11 @@ class PyFunctionLibrary: def run_calibration( self, saved_model_path: str, + signature_keys: list[str], + tags: set[str], exported_model_serialized: bytes, - quantization_options_serialized: bytes, + calibration_options_serialized: bytes, + force_graph_mode_calibration: bool, representative_dataset: Any, ) -> bytes: ... # LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 9e02aba0f7d346..756b1b0fe94214 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -293,8 +293,11 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { ExportedModel calibrated_exported_model = py_function_library.RunCalibration( - precalibrated_saved_model_dir, exported_model_ids_assigned, - quantization_options, representative_dataset); + precalibrated_saved_model_dir, signature_keys, tags, + exported_model_ids_assigned, + quantization_options.calibration_options(), + quantization_options.force_graph_mode_calibration(), + representative_dataset); if (quantization_options.has_debugger_options()) { calibrated_exported_model = EnableDebugging( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h index 42aafd758d71fe..7c7d1ae46b42f9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h @@ -124,6 +124,38 @@ struct type_caster { } }; +// Handles type conversion for `CalibrationOptions`. +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(tensorflow::quantization::CalibrationOptions, + const_name("CalibrationOptions")); + + // Python -> C++. Converts a serialized protobuf string and deserializes into + // an instance of `CalibrationOptions`. + bool load(handle src, const bool convert) { + auto caster = make_caster(); + // The user should have passed a valid python string. + if (!caster.load(src, convert)) { + return false; + } + + const absl::string_view calibration_opts_serialized = + cast_op(std::move(caster)); + + // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. + return value.ParseFromString(std::string(calibration_opts_serialized)); + } + + // C++ -> Python. Constructs a `bytes` object after serializing `src`. + static handle cast(const tensorflow::quantization::CalibrationOptions& src, + return_value_policy policy, handle parent) { + // release() prevents the reference count from decreasing upon the + // destruction of py::bytes and returns a raw python object handle. + return py::bytes(internal::Serialize(src)).release(); + } +}; + template <> struct type_caster { public: From c7e6871736d594e8296aecf769fa835d8ccab3f1 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Thu, 23 Nov 2023 02:21:46 -0800 Subject: [PATCH 038/381] PR #7016: [NVIDIA XLA GPU] Restore dotdimension numbers for original bmms Imported from GitHub PR https://github.com/openxla/xla/pull/7016 This is trying to fix the error seen in https://github.com/openxla/xla/pull/5782. This was supposed to be in one of the previous bug fixes but got lost when porting from internal PR to github. 
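As a quick sketch for reviewers (condensed from the cudnn_fused_mha_rewriter.cc hunks further below; illustrative only, not additional code in this change), the fix snapshots the original dot dimension numbers before ChangeCheckedDimToFastest canonicalizes the bmms for config extraction, then writes them back so the matched dots are left untouched if no rewrite ends up happening:

    // Snapshot before canonicalization mutates the matched bmms.
    DotDimensionNumbers orig_bmm1_dot_dim = bmm_1->dot_dimension_numbers();
    DotDimensionNumbers orig_bmm2_dot_dim = bmm_2->dot_dimension_numbers();
    // ... canonicalize operands and populate the FMHA backend config ...
    // Restore the original DotDimensionNumbers on the matched dots.
    *DynCast<HloDotInstruction>(bmm_1)->mutable_dot_dimension_numbers() =
        orig_bmm1_dot_dim;
    *DynCast<HloDotInstruction>(bmm_2)->mutable_dot_dimension_numbers() =
        orig_bmm2_dot_dim;

The backward path applies the same snapshot/restore pattern to the four grad gemms, and when the backward pattern fails to match, the forward graph is restored by cloning the original activation and bmm2 from their saved producers (see the MatchBwdResult handling in the diff).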
Copybara import of the project: -- c5a6356bff563d92592348b4b030cd9488a4331a by TJ Xu : Restore dotdimension numbers for original bmms -- 053c1a86b59c8ddad6c6e73e291fe65c05578d5f by TJ Xu : Add a mechanism to restore forward when backward is not matching Add additional checks to make sure we only lower to valid FMHA patterns -- 326de24fe8193772214787c67f1eaf89589ef05f by Cjkkkk : get Q from bmm1 instead of fmha custom call in case of transpose -- 717b0fe5fde0adb3577b5fc5ccc715e83b7a99ea by Cjkkkk : add test for bwd pattern when Q is transposed by fmha Merging this change closes #7016 PiperOrigin-RevId: 584835733 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 14 +- third_party/xla/xla/hlo/ir/hlo_computation.h | 3 +- .../service/gpu/cudnn_fused_mha_rewriter.cc | 92 +++- .../gpu/cudnn_fused_mha_rewriter_test.cc | 395 ++++++++++++++++++ 4 files changed, 492 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 5bd2fa6f7146cb..30af92643a7b1a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -1133,7 +1133,8 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, StatusOr HloComputation::ReplaceInstructionWithDifferentShape( HloInstruction* old_instruction, HloInstruction* new_instruction, - bool preserve_sharding, bool relay_control_dependency) { + bool preserve_sharding, bool relay_control_dependency, + bool remove_unused_operands) { if (preserve_sharding && new_instruction->has_sharding() && old_instruction->has_sharding() && !new_instruction->has_compatible_sharding(old_instruction)) { @@ -1186,10 +1187,13 @@ StatusOr HloComputation::ReplaceInstructionWithDifferentShape( new_instruction->custom_call_target())) { new_instruction->SetAndSanitizeName(old_instruction->name()); } - - TF_RETURN_IF_ERROR(RemoveInstructionAndUnusedOperands( - old_instruction, /*cleanup=*/std::nullopt, - /*ignore_control_dependencies=*/relay_control_dependency)); + if (remove_unused_operands) { + TF_RETURN_IF_ERROR(RemoveInstructionAndUnusedOperands( + old_instruction, /*cleanup=*/std::nullopt, + /*ignore_control_dependencies=*/relay_control_dependency)); + } else { + TF_RETURN_IF_ERROR(RemoveInstruction(old_instruction)); + } return true; } diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 6f60f5806e3a3b..92d2e00013bb1a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -514,7 +514,8 @@ class HloComputation { // shape. 
StatusOr ReplaceInstructionWithDifferentShape( HloInstruction* old_instruction, HloInstruction* new_instruction, - bool preserve_sharding, bool relay_control_dependency = false); + bool preserve_sharding, bool relay_control_dependency = false, + bool remove_unused_operands = true); Status ReplaceInstructionWithDifferentShape(HloInstruction* old_instruction, HloInstruction* new_instruction); diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 099e6776a370d3..c5c1f77995a530 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -671,11 +671,10 @@ bool IsBmm2GradGemm2(HloInstruction* instr) { } MatchBwdResult MatchBmm1GradGemm1(MatchBwdResult previous_result, - HloInstruction* fwd_fmha_call, HloInstruction* bmm_1) { MatchBwdResult match_result = previous_result; match_result.has_match = false; - const HloInstruction* q_tensor = fwd_fmha_call->operand(0); + const HloInstruction* q_tensor = bmm_1->operand(0); for (int64_t i = 0; i < q_tensor->user_count(); i++) { HloInstruction* q_tensor_user_i = q_tensor->users()[i]; if (IsBatchedMatmul(q_tensor_user_i) && q_tensor_user_i != bmm_1) { @@ -972,7 +971,7 @@ MatchBwdResult MatchBackwardBmms(HloInstruction* fwd_fmha_call, return matched_result; } - matched_result = MatchBmm1GradGemm1(matched_result, fwd_fmha_call, bmm_1); + matched_result = MatchBmm1GradGemm1(matched_result, bmm_1); if (!matched_result.has_match) { return matched_result; } @@ -1154,6 +1153,9 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( HloInstruction* lhs_bmm1; HloInstruction* rhs_bmm1; HloInstruction* rhs_bmm2; + DotDimensionNumbers orig_bmm1_dot_dim = bmm_1->dot_dimension_numbers(); + DotDimensionNumbers orig_bmm2_dot_dim = bmm_2->dot_dimension_numbers(); + TF_ASSIGN_OR_RETURN(rhs_bmm1, ChangeCheckedDimToFastest( comp, bmm_1, false /*is_lhs*/, true /*should_contracting_be_fastest*/)); @@ -1176,6 +1178,11 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( bmm_2->dot_dimension_numbers(); TF_RET_CHECK((dropout_rate >= 0.0 && dropout_rate <= 1.0)); + // Restore original DotDimensionNumbers. + *((DynCast(bmm_1))->mutable_dot_dimension_numbers()) = + orig_bmm1_dot_dim; + *((DynCast(bmm_2))->mutable_dot_dimension_numbers()) = + orig_bmm2_dot_dim; // If scale node is assigned, extract value from it. 
if (scale != nullptr) { @@ -1288,9 +1295,15 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( HloInstruction::CreateGetTupleElement(bmm_2->shape(), fmha_call, 0))); if (activation_output) { - TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( - activation_output, HloInstruction::CreateGetTupleElement( - activation_output->shape(), fmha_call, 2))); + HloInstruction* activation_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + activation_output->shape(), fmha_call, 2)); + TF_RETURN_IF_ERROR(comp->ReplaceInstructionWithDifferentShape( + activation_output, activation_gte, + /*preserve_sharding=*/false, + /*relay_control_dependency=*/false, + /*remove_unused_operands=*/false) + .status()); } if (VLOG_IS_ON(2)) { @@ -1337,6 +1350,14 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( HloInstruction* lhs_bmm2_grad_gemm1; HloInstruction* rhs_bmm2_grad_gemm2; HloInstruction* d_output_grad; + DotDimensionNumbers orig_bmm1_grad1_config = + bmm_1_grad_1->dot_dimension_numbers(); + DotDimensionNumbers orig_bmm1_grad2_config = + bmm_1_grad_2->dot_dimension_numbers(); + DotDimensionNumbers orig_bmm2_grad1_config = + bmm_2_grad_1->dot_dimension_numbers(); + DotDimensionNumbers orig_bmm2_grad2_config = + bmm_2_grad_2->dot_dimension_numbers(); // Q tensor TF_ASSIGN_OR_RETURN( @@ -1420,6 +1441,16 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( *bwd_fmha_config.mutable_bmm2_grad_gemm2_dot_dimension_numbers() = bmm_2_grad_2->dot_dimension_numbers(); + // Restore original DotDimensionNumbers + *((DynCast(bmm_1_grad_1)) + ->mutable_dot_dimension_numbers()) = orig_bmm1_grad1_config; + *((DynCast(bmm_1_grad_2)) + ->mutable_dot_dimension_numbers()) = orig_bmm1_grad2_config; + *((DynCast(bmm_2_grad_1)) + ->mutable_dot_dimension_numbers()) = orig_bmm2_grad1_config; + *((DynCast(bmm_2_grad_2)) + ->mutable_dot_dimension_numbers()) = orig_bmm2_grad2_config; + bwd_fmha_config.set_fmha_scale(fwd_config.fmha_scale()); bwd_fmha_config.set_dropout_rate(fwd_config.dropout_rate()); // Set to an arbitrary seed for now, seed is not exposed to XLA in HLO @@ -1544,6 +1575,28 @@ StatusOr CudnnFusedMHARewriter::Run( matched_result.need_canonicalization, matched_result.is_training, matched_result.matched_custom_call_name, debug_options)); if (!is_mha_module_supported) continue; + + // If we have an activation with more than 1 users in non-training mode, + // we cannot rewrite the graph. So skip processing the rest. + HloInstruction* activation = + matched_result.need_canonicalization + ? matched_result.matched_bmm_2->mutable_operand(1) + : matched_result.matched_bmm_2->mutable_operand(0); + if (!matched_result.is_training && activation->user_count() > 1) { + VLOG(2) + << "Activation: " << activation->ToString() + << " cannot have more than 1 users in non-training mode. Skipping."; + continue; + } + HloInstruction* original_bmm2_producer0 = + matched_result.matched_bmm_2->mutable_operand(0); + HloInstruction* original_bmm2_producer1 = + matched_result.matched_bmm_2->mutable_operand(1); + + std::vector original_activation_producers; + for (HloInstruction* operand : activation->mutable_operands()) { + original_activation_producers.push_back(operand); + } // If we need to canonicalize the bmm, we will assign the newly // canonicalized bmm to bmm_2. 
if (matched_result.need_canonicalization) { @@ -1578,6 +1631,33 @@ StatusOr CudnnFusedMHARewriter::Run( fwd_fmha_call, matched_result.matched_bmm_1, matched_result.matched_mask, v_transposed); if (!matched_bwd_result.has_match) { + VLOG(2) << "Backward pattern not matching, skipping."; + // If backward pattern is not matched, we need to restore the + // original graph structure. + // Replacing new GTEs added by forward FMHA call with cloned old + // activations and bmm2. + HloInstruction* output_gte = fwd_fmha_call->users()[0]; + HloInstruction* activation_gte = fwd_fmha_call->users()[1]; + std::string suffix = "fmha_no_match_clone"; + HloInstruction* cloned_activation = + comp->AddInstruction(activation->CloneWithNewOperands( + activation->shape(), original_activation_producers, suffix)); + + // Since old activation is detached by forward FMHA rewrite, we need + // to use the newly cloned activation. + HloInstruction* lhs = activation == original_bmm2_producer0 + ? cloned_activation + : original_bmm2_producer1; + HloInstruction* rhs = activation == original_bmm2_producer0 + ? original_bmm2_producer1 + : cloned_activation; + HloInstruction* cloned_bmm2 = comp->AddInstruction( + matched_result.matched_bmm_2->CloneWithNewOperands( + matched_result.matched_bmm_2->shape(), {lhs, rhs}, suffix)); + + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2)); + TF_RETURN_IF_ERROR( + comp->ReplaceInstruction(activation_gte, cloned_activation)); continue; } // check if dbias is the only user of d_intermediate besides diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc index 1545b4a0e39e3f..944acedd15319c 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc @@ -2893,6 +2893,401 @@ ENTRY main.146 { EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } +TEST_F(CudnnFusedMhaRewriterTestHloTest, + ActivationHasMoreThan1UserShouldNotLower) { + const char* module_str = R"( +HloModule test + +%region_50.2457 (Arg_0.2458: bf16[], Arg_1.2459: bf16[]) -> bf16[] { + %Arg_0.2458 = bf16[] parameter(0) + %Arg_1.2459 = bf16[] parameter(1) + ROOT %maximum.2 = bf16[] maximum(bf16[] %Arg_0.2458, bf16[] %Arg_1.2459) +} + +%region_36.2316 (Arg_0.2317: f32[], Arg_1.2318: f32[]) -> f32[] { + %Arg_0.2317 = f32[] parameter(0) + %Arg_1.2318 = f32[] parameter(1) + ROOT %add.342 = f32[] add(f32[] %Arg_0.2317, f32[] %Arg_1.2318) +} + +ENTRY main { + %transpose.482 = bf16[4,5,64]{2,1,0} parameter(0) + %transpose.484 = bf16[4,64,5]{2,1,0} parameter(1) + %dot.20 = bf16[4,5,5]{2,1,0} dot(bf16[4,5,64]{2,1,0} %transpose.482, bf16[4,64,5]{2,1,0} %transpose.484), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + %constant.2515 = bf16[] constant(0.125) + %broadcast.789 = bf16[4,5,5]{2,1,0} broadcast(bf16[] %constant.2515), dimensions={} + %multiply.267 = bf16[4,5,5]{2,1,0} multiply(bf16[4,5,5]{2,1,0} %dot.20, bf16[4,5,5]{2,1,0} %broadcast.789) + %constant.287 = f32[] constant(-1) + %broadcast.792 = bf16[4,5,5]{2,1,0} parameter(3) + %add.348 = bf16[4,5,5]{2,1,0} add(bf16[4,5,5]{2,1,0} %multiply.267, bf16[4,5,5]{2,1,0} %broadcast.792) + %constant.2510 = bf16[] constant(-inf) + %reduce.2550 = bf16[4,5]{1,0} reduce(bf16[4,5,5]{2,1,0} %add.348, bf16[] %constant.2510), dimensions={2}, to_apply=%region_50.2457 + %broadcast.793 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %reduce.2550), dimensions={0,1} + 
%subtract.81 = bf16[4,5,5]{2,1,0} subtract(bf16[4,5,5]{2,1,0} %add.348, bf16[4,5,5]{2,1,0} %broadcast.793) + %exponential.21 = bf16[4,5,5]{2,1,0} exponential(bf16[4,5,5]{2,1,0} %subtract.81) + %convert.180 = f32[4,5,5]{2,1,0} convert(bf16[4,5,5]{2,1,0} %exponential.21) + %constant.2509 = f32[] constant(0) + %reduce.2558 = f32[4,5]{1,0} reduce(f32[4,5,5]{2,1,0} %convert.180, f32[] %constant.2509), dimensions={2}, to_apply=%region_36.2316 + %convert.182 = bf16[4,5]{1,0} convert(f32[4,5]{1,0} %reduce.2558) + %broadcast.794 = bf16[4,5,5]{2,1,0} broadcast(bf16[4,5]{1,0} %convert.182), dimensions={0,1} + %divide.25 = bf16[4,5,5]{2,1,0} divide(bf16[4,5,5]{2,1,0} %exponential.21, bf16[4,5,5]{2,1,0} %broadcast.794) + %transpose.481 = bf16[4,64,5]{2,1,0} parameter(2) + %dot.21 = bf16[4,64,5]{2,1,0} dot(bf16[4,64,5]{2,1,0} %transpose.481, bf16[4,5,5]{2,1,0} %divide.25), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2} + ROOT %tuple.2668 = (bf16[4,5,5]{2,1,0}, bf16[4,64,5]{2,1,0}) tuple(bf16[4,5,5]{2,1,0} %divide.25, bf16[4,64,5]{2,1,0} %dot.21) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision*/ true); + ASSERT_IS_OK(verifier.Run(m.get()).status()); + + EXPECT_EQ(CountFusedAttentionCall(m.get()), 0); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxBmm2ShouldNotBeLowered) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = f16[] parameter(0) + Arg_1.23 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = f16[] parameter(0) + Arg_1.57 = f16[] parameter(1) + ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + constant.18 = pred[2,6,128,128]{3,2,1,0} constant({...}) + Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = f16[] constant(2) + broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = f16[] constant(1) + broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={} + add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13) + constant.21 = f16[] constant(0) + broadcast.23 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.21), dimensions={} + select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.18, add.3, broadcast.23) + constant.15 = f16[] constant(-inf) + reduce.25 = f16[2,6,128]{2,1,0} reduce(select.1, constant.15), 
dimensions={3}, to_apply=region_0.21 + broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.17) + exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37) + broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26) + broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9) + divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.59), dimensions={0,1,2} + add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1) + select.3 = f16[2,6,128,128]{3,2,1,0} select(constant.18, multiply.8, broadcast.23) + multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(select.3, broadcast.24) + dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision*/ true); + ASSERT_IS_OK(verifier.Run(m.get()).status()); + + // The backward pattern in the graph is not a valid fmha pattern, + // we expect no rewrite happening. 
+ EXPECT_EQ(CountFusedAttentionCall(m.get()), 0); + EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward*/ true), 0); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2ShouldNotLower) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.38 { + Arg_0.39 = f16[] parameter(0) + Arg_1.40 = f16[] parameter(1) + ROOT maximum.1 = f16[] maximum(Arg_0.39, Arg_1.40) +} + +region_1.50 { + Arg_0.51 = f32[] parameter(0) + Arg_1.52 = f32[] parameter(1) + ROOT add.2 = f32[] add(Arg_0.51, Arg_1.52) +} + +region_2.99 { + Arg_0.100 = f16[] parameter(0) + Arg_1.101 = f16[] parameter(1) + ROOT add.3 = f16[] add(Arg_0.100, Arg_1.101) +} + +ENTRY main.126 { + constant.6 = u32[1]{0} constant({2718843009}) + constant.8 = u32[1]{0} constant({1272950319}) + constant.10 = u32[1]{0} constant({0}) + constant.12 = u32[1]{0} constant({2711844646}) + custom-call.65 = (u32[1]{0}, u32[1]{0}) custom-call(constant.6, constant.8, constant.10, constant.12), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000" + get-tuple-element.66 = u32[1]{0} get-tuple-element(custom-call.65), index=0 + bitcast.343 = u32[] bitcast(get-tuple-element.66) + broadcast.27 = u32[98304]{0} broadcast(bitcast.343), dimensions={} + get-tuple-element.67 = u32[1]{0} get-tuple-element(custom-call.65), index=1 + bitcast.344 = u32[] bitcast(get-tuple-element.67) + broadcast.28 = u32[98304]{0} broadcast(bitcast.344), dimensions={} + iota.68 = u32[196608]{0} iota(), iota_dimension=0 + slice = u32[98304]{0} slice(iota.68), slice={[0:98304]} + slice.1 = u32[98304]{0} slice(iota.68), slice={[98304:196608]} + custom-call.75 = (u32[98304]{0}, u32[98304]{0}) custom-call(broadcast.27, broadcast.28, slice, slice.1), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[98304]{0}, u32[98304]{0}, u32[98304]{0}, u32[98304]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\001\000\000\000\000\000" + get-tuple-element.76 = u32[98304]{0} get-tuple-element(custom-call.75), index=0 + get-tuple-element.77 = u32[98304]{0} get-tuple-element(custom-call.75), index=1 + concatenate.2 = u32[196608]{0} concatenate(get-tuple-element.76, get-tuple-element.77), dimensions={0} + constant.56 = u32[] constant(9) + broadcast.63 = u32[196608]{0} broadcast(constant.56), dimensions={} + shift-right-logical.3 = u32[196608]{0} shift-right-logical(concatenate.2, broadcast.63) + constant.57 = u32[] constant(1065353216) + broadcast.64 = u32[196608]{0} broadcast(constant.57), dimensions={} + or.3 = u32[196608]{0} or(shift-right-logical.3, broadcast.64) + bitcast-convert.3 = f32[196608]{0} bitcast-convert(or.3) + constant.58 = f32[] constant(-1) + broadcast.65 = f32[196608]{0} broadcast(constant.58), dimensions={} + add.10 = f32[196608]{0} add(bitcast-convert.3, broadcast.65) + constant.48 = f32[] constant(0) + broadcast.66 = f32[196608]{0} broadcast(constant.48), dimensions={} + maximum.4 = f32[196608]{0} maximum(add.10, broadcast.66) + constant.59 = f32[] constant(0.9) + broadcast.67 = f32[196608]{0} broadcast(constant.59), 
dimensions={} + compare.3 = pred[196608]{0} compare(maximum.4, broadcast.67), direction=LT + bitcast.308 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3) + constant.44 = pred[2,6,128,128]{3,2,1,0} constant({...}) + Arg_0.1 = f16[2,6,128,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated} + dot.34 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.55 = f16[] constant(2) + broadcast.61 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.55), dimensions={} + multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(dot.34, broadcast.61) + constant.52 = f16[] constant(1) + broadcast.39 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.52), dimensions={} + add.6 = f16[2,6,128,128]{3,2,1,0} add(multiply.8, broadcast.39) + constant.54 = f16[] constant(0) + broadcast.52 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.54), dimensions={} + select.1 = f16[2,6,128,128]{3,2,1,0} select(constant.44, add.6, broadcast.52) + constant.41 = f16[] constant(-inf) + reduce.42 = f16[2,6,128]{2,1,0} reduce(select.1, constant.41), dimensions={3}, to_apply=region_0.38 + broadcast.42 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.42), dimensions={0,1,2} + subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(select.1, broadcast.42) + exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1) + reduce.54 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.48), dimensions={3}, to_apply=region_1.50 + convert.9 = f16[2,6,128]{2,1,0} convert(reduce.54) + broadcast.68 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.68) + constant.60 = f16[] constant(1.1113) + broadcast.69 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.60), dimensions={} + multiply.20 = f16[2,6,128,128]{3,2,1,0} multiply(divide.5, broadcast.69) + select.8 = f16[2,6,128,128]{3,2,1,0} select(bitcast.308, multiply.20, broadcast.52) + Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.88 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + bitcast.248 = pred[2,6,128,128]{3,2,1,0} bitcast(compare.3) + Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.91 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + select.6 = f16[2,6,128,128]{3,2,1,0} select(bitcast.248, dot.91, broadcast.52) + multiply.17 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.69) + divide.4 = f16[2,6,128,128]{3,2,1,0} divide(multiply.17, broadcast.68) + broadcast.55 = f16[2,6,128]{2,1,0} broadcast(constant.52), dimensions={} + multiply.11 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9) + divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.55, multiply.11) + broadcast.56 = f16[2,6,128]{2,1,0} broadcast(constant.60), dimensions={} + multiply.12 = f16[2,6,128]{2,1,0} multiply(divide.3, broadcast.56) + broadcast.58 = f16[2,6,128,128]{3,2,1,0} broadcast(multiply.12), dimensions={0,1,2} + multiply.13 = f16[2,6,128,128]{3,2,1,0} multiply(select.6, broadcast.58) + multiply.14 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.13, exponential.1) + reduce.103 = f16[2,6,128]{2,1,0} reduce(multiply.14, constant.54), dimensions={3}, to_apply=region_2.99 + 
broadcast.62 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.103), dimensions={0,1,2} + add.9 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.62) + multiply.18 = f16[2,6,128,128]{3,2,1,0} multiply(add.9, exponential.1) + select.7 = f16[2,6,128,128]{3,2,1,0} select(constant.44, multiply.18, broadcast.52) + multiply.19 = f16[2,6,128,128]{3,2,1,0} multiply(select.7, broadcast.61) + dot.124 = f16[2,6,128,64]{3,2,1,0} dot(multiply.19, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.19), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = f16[2,6,128,64]{3,2,1,0} dot(select.8, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.125 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.88, dot.124, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision*/ true); + ASSERT_IS_OK(verifier.Run(m.get()).status()); + + // The backward pattern in the graph is not a valid fmha pattern, + // we expect no rewrite happening. + EXPECT_EQ(CountFusedAttentionCall(m.get()), 0); + EXPECT_EQ(CountFusedAttentionCall(m.get(), /*is_backward*/ true), 0); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16TrainingBmm1ScaleBiasSoftmaxBmm2QTranspose) { + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = f16[] parameter(0) + Arg_1.23 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = f16[] parameter(0) + Arg_1.57 = f16[] parameter(1) + ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = f16[] constant(2) + broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = f16[] constant(1) + broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={} + add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13) + constant.21 = f16[] constant(0) + constant.15 = f16[] constant(-inf) + reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21 + broadcast.17 = 
f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17) + exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37) + broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26) + broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9) + divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59) + broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24) + dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0) + .WithShape(F16, {2, 6, 128, 64}), + m::GetTupleElement( + m::CustomCall(&fmha, + {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), + 0) + .WithShape(F16, {2, 6, 128, 64}), + m::Transpose( + m::GetTupleElement( + 
m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), + 1)) + .WithShape(F16, {2, 6, 64, 128}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2) + .WithShape(F16, {2, 6, 128, 64})))); + TF_ASSERT_OK_AND_ASSIGN(auto config, + fmha->backend_config()); + EXPECT_EQ(fmha->operands().size(), 5); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); +} + } // anonymous namespace } // namespace gpu } // namespace xla From 2092542d47b6ef916815936098c9535846004bb1 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 23 Nov 2023 02:24:18 -0800 Subject: [PATCH 039/381] [XLA] Add `hlo.topk` -> `mhlo.topk` translation. PiperOrigin-RevId: 584836233 --- .../xla/translate/hlo_to_mhlo/hlo_function_importer.cc | 10 ++++++++++ third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc | 2 ++ .../xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt | 9 +++++++++ 3 files changed, 21 insertions(+) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index 08d5f49c806286..cc7aa9e9ed6e49 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -1178,6 +1178,16 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( ->create(loc, result_type, sort_op.getResults()) .getOperation(); } + case HloOpcode::kTopK: { + auto topk_instruction = Cast(instruction); + auto topk_op = func_builder->create( + loc, result_type.dyn_cast().getTypes(), operands[0], + builder_->getI64IntegerAttr(topk_instruction->k()), + builder_->getBoolAttr(topk_instruction->largest())); + return func_builder + ->create(loc, result_type, topk_op.getResults()) + .getOperation(); + } case HloOpcode::kCopyStart: { auto copy_start_instruction = Cast(instruction); if (auto cross_program_prefetch_index = diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc index 7c2d2c7fd24345..4c2f38a15a7a13 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc @@ -289,6 +289,8 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) { return xla::HloOpcode::kConvolution; } else if (isa(op)) { return xla::HloOpcode::kSort; + } else if (isa(op)) { + return xla::HloOpcode::kTopK; } else if (isa(op)) { return xla::HloOpcode::kRngBitGenerator; } else if (isa(op)) { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt index 396fdbc4c78f71..e18fc9926588f6 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -1850,3 +1850,12 @@ add { %Arg_0.1 = f32[?,784] parameter(0) ROOT %abs.2 = f32[?,784] abs(f32[?,784] %Arg_0.1) } + +// Test topk +%test_topk { + x = f32[4,4] parameter(0) + ROOT out = (f32[4,2], s32[4,2]) topk(x), k=2, largest=true +} +// CHECK-LABEL: func private @test_topk +// CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>) -> tuple, tensor<4x2xi32>> +// CHECK: mhlo.topk([[ARG]], k = 2, largest = true) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) From 95991bf74c707db790a7cbe47aff74fd13718e00 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Thu, 23 Nov 2023 02:37:37 -0800 Subject: [PATCH 040/381] Update XLA's Triton pipeline with the missing changes from the latest integrate PiperOrigin-RevId: 
584838739 --- third_party/xla/xla/service/gpu/ir_emitter_triton.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index b7ee8986d622d9..acdbd80fb94cd8 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -736,6 +736,7 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(ccAsInt)); pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); pm.addPass(mlir::createTritonGPUOptimizeDotOperandsPass()); + pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createTritonGPUPipelinePass(num_stages, num_warps, numCTAs, ccAsInt)); pm.addPass( @@ -754,6 +755,8 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); } pm.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); + pm.addPass(mlir::createTritonGPUOptimizeThreadLocalityPass()); + pm.addPass(mlir::createCanonicalizerPass()); // Based on translateTritonGPUToLLVMIR() in // @triton//:lib/Target/LLVMIR/LLVMIRTranslation.cpp pm.addPass(mlir::createConvertSCFToCFPass()); From 049a478faac0dfc6fe0dd01879b2850b1391f83b Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 23 Nov 2023 02:46:05 -0800 Subject: [PATCH 041/381] [xla:gpu] Allow missing ObjFile / MLIR in `GpuCompiler::Export`. PiperOrigin-RevId: 584840548 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 4eb572aa76a4e9..ceb72769ee7b79 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1925,16 +1925,14 @@ StatusOr> GpuCompiler::Export( auto* gpu_executable = tensorflow::down_cast(executable); if (!gpu_executable) return Internal("GpuExecutable is null"); HloModuleProto module_proto = gpu_executable->module().ToProto(); - TF_ASSIGN_OR_RETURN(auto obj_file, gpu_executable->GetObjFile()); - TF_ASSIGN_OR_RETURN(auto mlir_module, gpu_executable->GetMlirModule()); + auto obj_file = gpu_executable->GetObjFile().value_or(""); + auto mlir_module = gpu_executable->GetMlirModule().value_or(""); auto text = gpu_executable->text(); auto binary = gpu_executable->binary(); - std::unique_ptr result = - std::make_unique( - module_proto, obj_file, mlir_module, text, binary, - gpu_executable->constants()); - return result; + return std::make_unique( + module_proto, obj_file, mlir_module, text, binary, + gpu_executable->constants()); } Status GpuCompiler::RunPostSchedulingPipelines( From 4205319fa7da4cabac625f273053206956a88d8a Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 23 Nov 2023 02:50:17 -0800 Subject: [PATCH 042/381] [XLA] Disable autosharding tests in OSS until it's fully fixed upstream PiperOrigin-RevId: 584841378 --- .../xla/xla/hlo/experimental/auto_sharding/BUILD | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 7dbfa6a114694d..c27a2cc6984e03 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -283,6 +283,11 @@ build_test( xla_cc_test( name = "auto_sharding_test", srcs = ["auto_sharding_test.cc"], + tags = [ + # 
Disabled until autosharding is fully supported in OSS, + # https://github.com/openxla/xla/issues/7248. + "no_oss", + ], deps = [ ":auto_sharding", ":auto_sharding_option", @@ -302,6 +307,11 @@ xla_cc_test( xla_cc_test( name = "auto_sharding_solver_test", srcs = ["auto_sharding_solver_test.cc"], + tags = [ + # Disabled until autosharding is fully supported in OSS, + # https://github.com/openxla/xla/issues/7248. + "no_oss", + ], deps = [ ":auto_sharding_solver", ":auto_sharding_strategy", From 2fe0e059de523bf9b491c9aee16e77fa3c45bbd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= Date: Thu, 23 Nov 2023 02:54:31 -0800 Subject: [PATCH 043/381] PR #7201: [GPU] Improve NCCL error reporting Imported from GitHub PR https://github.com/openxla/xla/pull/7201 ncclGetLastError return the last log entry generated at the "WARN/ERROR" level. Here is an example of the new error: ``` NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: unhandled cuda error (run with NCCL_DEBUG=INFO for details). Last NCCL warning(error) log entry (may be unrelated) 'Cuda failure 2 'out of memory''.; current tracing scope: all-reduce-start.285; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=25#. ``` The new part is: ``` Last NCCL warning(error) log entry (may be unrelated) 'Cuda failure 2 'out of memory''. ``` Copybara import of the project: -- 348df80802b82a8015dad96ef2b23bf23c352eba by Frederic Bastien : Add extra error information when NCCL error out. Merging this change closes #7201 PiperOrigin-RevId: 584842170 --- third_party/xla/xla/service/gpu/nccl_utils.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/nccl_utils.cc b/third_party/xla/xla/service/gpu/nccl_utils.cc index 88850a032ca679..a1c0beb0b9ab7e 100644 --- a/third_party/xla/xla/service/gpu/nccl_utils.cc +++ b/third_party/xla/xla/service/gpu/nccl_utils.cc @@ -49,9 +49,10 @@ Status ToStatus(ncclResult_t s, const char* file, int64_t line, if (s == ncclSuccess) { return OkStatus(); } - return tsl::errors::Internal( - absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr, - ncclGetErrorString(s))); + return tsl::errors::Internal(absl::StrFormat( + "%s:%d: NCCL operation %s failed: %s." + " Last NCCL warning(error) log entry (may be unrelated) '%s'.", + file, line, expr, ncclGetErrorString(s), ncclGetLastError(NULL))); } ncclRedOp_t ToNcclReduction(ReductionKind kind) { @@ -210,7 +211,9 @@ void CheckNcclAsyncError(NcclComm& lockable_comm) { if (async_err != ncclSuccess) { LOG(ERROR) << "Aborting communicator: " << comm << " due to async NCCL error: " - << ncclGetErrorString(async_err); + << ncclGetErrorString(async_err) + << ". Last NCCL warning(error) log entry (may be unrelated): " + << ncclGetLastError(NULL); XLA_CUDA_RETURN_IF_ERROR(ncclCommAbort(comm)); } return XLA_CUDA_STATUS(async_err); From 08aea74139b16484842cf0449b08a97648a4fe00 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 23 Nov 2023 03:27:16 -0800 Subject: [PATCH 044/381] PR #7237: Update build_from_source.md doc to point to latest-gpu docker image Imported from GitHub PR https://github.com/openxla/xla/pull/7237 The "devel" docker image is not updated and does not seem maintained anymore. We probably should recommend a most up-to-date image. 
Copybara import of the project: -- 3412f300c0a57cea37c923719aa484b72c04956f by Mehdi Amini : Update build_from_source.md doc to point to latest-gpu docker image The "devel" docker image is not updated and does not seem maintained anymore. We probably should recommend a most up-to-date image. -- 4b7a36b9a7b58ec73a7d696a555ad70ee639d7ea by Mehdi Amini : Update build_from_source.md Merging this change closes #7237 PiperOrigin-RevId: 584848729 --- third_party/xla/docs/build_from_source.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/docs/build_from_source.md b/third_party/xla/docs/build_from_source.md index 9c4cc0e401fd37..f5b2ded3c4cd4e 100644 --- a/third_party/xla/docs/build_from_source.md +++ b/third_party/xla/docs/build_from_source.md @@ -33,7 +33,7 @@ We recommend using a suitable docker container to build/test XLA, such as [TensorFlow's docker container](https://www.tensorflow.org/install/docker): ``` -docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash +docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/tensorflow:latest-gpu bash ``` Using a docker container you can build XLA with CPU support using the following commands: From 17c2fc88376b2784d37a4755aa33189463f1901c Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 23 Nov 2023 03:50:16 -0800 Subject: [PATCH 045/381] Fix a typo in comment. PiperOrigin-RevId: 584852376 --- third_party/xla/xla/service/cpu/onednn_rewriter.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/cpu/onednn_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_rewriter.cc index 45b38c59d76a61..4452381ea6f191 100644 --- a/third_party/xla/xla/service/cpu/onednn_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_rewriter.cc @@ -82,7 +82,7 @@ class OneDnnRewriterVisitor : public DfsHloRewriteVisitor { // verifier already does the job. We, however, need to check if contraction // is over only 1 dimension (a.k.a. K dimension in matrix-multiplication // parlance). We also restrict that batch dimensions of the operands - // matches. + // match. if (!IsSupportedType(dot_instr->shape().element_type())) return OkStatus(); auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot_dim_numbers)); From ffc7fa5e28896c017d092810340422b6ea4c9932 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 23 Nov 2023 05:59:10 -0800 Subject: [PATCH 046/381] [TileAnalysis] Use IsOpElementwise instead of IsElementwise. Obviously, an experienced XLA engineer would know that IsElementwise/IsOpElementwise/IsElementwiseImpl/IsElementwiseOnOperand are very different functions and one should be very careful when using them. 
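For readers who have not had to tell these predicates apart before, here is a rough sketch of why they can disagree (illustration only, not part of the patch; the helper names are invented for this example, and the authoritative behavior lives in hlo_instruction.h):

```
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"

namespace xla {

// Opcode-level check: a pure function of the opcode. A kFusion opcode, for
// instance, is not elementwise by this predicate.
bool OpcodeSaysElementwise(const HloInstruction* instr) {
  return HloInstruction::IsOpElementwise(instr->opcode());
}

// Instance-level check: backed by IsElementwiseImpl, which subclasses may
// override, so e.g. a fusion whose fused computation is elementwise can
// report true here even though the opcode-level answer is false.
bool InstanceSaysElementwise(const HloInstruction* instr) {
  return instr->IsElementwise();
}

}  // namespace xla
```

The diff below switches tile analysis to the opcode-level variant.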
PiperOrigin-RevId: 584873063 --- third_party/xla/xla/service/gpu/model/tile_analysis.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index d3dfff05d9e568..41a9c73951bcb5 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -101,7 +101,7 @@ StatusOr ComputeReduceOpIndexing( std::vector exprs; for (auto [input_dim_id, input_dim] : llvm::enumerate(input_shape.dimensions())) { - if (reduce_dims_ids.count(input_dim_id)) { + if (reduce_dims_ids.contains(input_dim_id)) { exprs.push_back(getAffineSymbolExpr(reduced_dim_id++, mlir_context)); sizes.push_back(input_dim); continue; @@ -240,7 +240,7 @@ std::string HloInstructionIndexing::ToString() const { StatusOr ComputeInstructionIndexing( const HloInstruction* instr, int output_id, MLIRContext* mlir_context) { - if (instr->IsElementwise()) { + if (HloInstruction::IsOpElementwise(instr->opcode())) { return ComputeCwiseOpIndexing(instr, mlir_context); } if (auto bcast = DynCast(instr)) { From e16a51b514254f21815213717d36d0e2a99a0c0c Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 23 Nov 2023 06:01:58 -0800 Subject: [PATCH 047/381] Add MoveCopyToUsers pass to multi-headed attention fusion pipeline. This is needed once we want to enable it by default. PiperOrigin-RevId: 584873521 --- third_party/xla/xla/service/gpu/BUILD | 1 + third_party/xla/xla/service/gpu/nvptx_compiler.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 1721e06572edb7..fd65d44c92c4b3 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3104,6 +3104,7 @@ cc_library( ":gpu_sort_rewriter", ":ir_emission_utils", ":metrics", + ":move_copy_to_users", ":target_constants", ":triangular_solve_rewriter", ":triton_autotuner", diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index fa6dca5b2eafb6..2c1ca2b4400349 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -59,6 +59,7 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/metrics.h" +#include "xla/service/gpu/move_copy_to_users.h" #include "xla/service/gpu/target_constants.h" #include "xla/service/gpu/triangular_solve_rewriter.h" #include "xla/service/gpu/triton_autotuner.h" @@ -225,6 +226,7 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( false); if (debug_options.xla_gpu_normalize_layouts()) { mha_fusion_pipeline.AddPass(); + mha_fusion_pipeline.AddPass>(); mha_fusion_pipeline.AddPass(); } From 58db9ff66c201aaf576b04064e267f37f35833cd Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 23 Nov 2023 06:32:05 -0800 Subject: [PATCH 048/381] [XLA:GPU] Tiled fusion: fix nested slicing. 
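Quick illustration of what "nested slicing" composes to (a standalone sketch, not part of the change; Slice1D/Compose exist only for this example): slicing a dimension that is already sliced is equivalent to a single slice of the original operand whose start offsets add, which is what the analysis now records via fragment->slice_start() + slice->slice_starts(dim).

```
#include <cstdint>
#include <iostream>

// A 1-D slice described against some base: [start, start + size).
struct Slice1D {
  int64_t start;
  int64_t size;
};

// The outer slice is expressed within the inner slice's result, so when
// re-expressed against the original tensor the start offsets add and the
// outer size wins.
Slice1D Compose(Slice1D inner, Slice1D outer) {
  return {inner.start + outer.start, outer.size};
}

int main() {
  // Mirrors the new tests: f32[6,24] sliced to [1:6],[3:23], then that result
  // sliced to [1:4],[13:20].
  Slice1D dim0 = Compose({1, 5}, {1, 3});    // [2:5]   -> slice_start=2, slice_limit=5
  Slice1D dim1 = Compose({3, 20}, {13, 7});  // [16:23] -> slice_start=16, slice_limit=23
  std::cout << dim0.start << " " << dim1.start << "\n";  // prints: 2 16
  return 0;
}
```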
PiperOrigin-RevId: 584878714 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 10 ++++-- .../service/gpu/gemm_rewriter_triton_test.cc | 35 +++++++++++++++++++ .../xla/service/gpu/ir_emitter_triton_test.cc | 22 ++++++++++++ 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 815b9cf0c3e87c..0eb2cc6cb0b542 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -893,9 +893,13 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( if (dst_logical[dim].size() > 1) { return FusionDecision("Slicing of fragmented dimension."); } - dst_logical[dim].front()->set_size(dst->shape().dimensions(dim)); - dst_logical[dim].front()->set_slice(slice->slice_starts(dim), - slice->slice_limits(dim)); + auto fragment = dst_logical[dim].front(); + fragment->set_size(dst->shape().dimensions(dim)); + // Slicing of an already sliced dimension means adding offsets. + fragment->set_slice( + fragment->slice_start() + slice->slice_starts(dim), + fragment->slice_start() + slice->slice_starts(dim) + + fragment->sliced_size()); } } } else { diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 6a4d94268935a7..ab19d7809e98e9 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -1373,6 +1373,41 @@ ENTRY e { m::Negate())))); } +TEST_F(GemmRewriterTritonLevel2Test, NestedSlicingIsAnalyzedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_gemm_d_computation { + p0 = f32[6,24]{1,0} parameter(0) + s1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]} + n1 = f32[5,20]{1,0} negate(s1) + s2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]} + p1 = f32[7,37]{1,0} parameter(1) + ROOT d = f32[3,37]{1,0} dot(s2, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[7,37]{1,0} parameter(0) + p1 = f32[6,24]{1,0} parameter(1) + ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom, + calls=triton_gemm_d_computation +})")); + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, + computation->parameter_instruction(0), 0), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6, + /*slice_start=*/2, /*slice_limit=*/5, + /*subfragments=*/ElementsAre(3)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, + computation->parameter_instruction(0), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, + /*slice_start=*/16, /*slice_limit=*/23, + /*subfragments=*/ElementsAre(7)))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 77b75ae5ecb261..534305831c2a7c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -2003,6 +2003,28 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +TEST_F(TritonGemmLevel2Test, NestedSlicingWorks) { + const std::string kHloText 
= R"( +ENTRY e { + p1 = f32[6,24] parameter(1) + s1 = f32[5,20] slice(p1), slice={[1:6], [3:23]} + n1 = f32[5,20] negate(s1) + s2 = f32[3,7] slice(n1), slice={[1:4], [13:20]} + p0 = f32[7,37] parameter(0) + ROOT d = f32[3,37] dot(s2, p0), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-3})); +} + TEST_F(TritonGemmTest, SlicedBatchDimensionIsSupported) { const std::string kHloText = R"( ENTRY e { From b8a51b397213fc8dc317395cf0c4d7218b16a2f4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 23 Nov 2023 09:24:43 -0800 Subject: [PATCH 049/381] Import openai/triton from GitHub. PiperOrigin-RevId: 584907007 --- third_party/triton/b311157761.patch | 64 ------------------- third_party/triton/cl577379396.patch | 33 ---------- third_party/triton/cl582925648.patch | 24 ------- third_party/triton/workspace.bzl | 5 +- .../xla/third_party/triton/b311157761.patch | 64 ------------------- .../xla/third_party/triton/cl577379396.patch | 33 ---------- .../xla/third_party/triton/cl582925648.patch | 24 ------- .../xla/third_party/triton/workspace.bzl | 5 +- 8 files changed, 4 insertions(+), 248 deletions(-) delete mode 100644 third_party/triton/b311157761.patch delete mode 100644 third_party/triton/cl577379396.patch delete mode 100644 third_party/triton/cl582925648.patch delete mode 100644 third_party/xla/third_party/triton/b311157761.patch delete mode 100644 third_party/xla/third_party/triton/cl577379396.patch delete mode 100644 third_party/xla/third_party/triton/cl582925648.patch diff --git a/third_party/triton/b311157761.patch b/third_party/triton/b311157761.patch deleted file mode 100644 index b03fa04c142a42..00000000000000 --- a/third_party/triton/b311157761.patch +++ /dev/null @@ -1,64 +0,0 @@ -diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp ---- a/include/triton/Tools/Sys/GetEnv.hpp -+++ b/include/triton/Tools/Sys/GetEnv.hpp -@@ -30,6 +30,7 @@ - namespace triton { - - const std::set ENV_VARS = { -+ "ENABLE_MMA_V3", - "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", - "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", - "AMDGCN_ENABLE_DUMP"}; -diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp ---- a/lib/Analysis/Utility.cpp -+++ b/lib/Analysis/Utility.cpp -@@ -394,7 +394,8 @@ bool supportMMA(triton::DotOp op, int version) { - auto aElemTy = op.getA().getType().cast().getElementType(); - auto bElemTy = op.getB().getType().cast().getElementType(); - if (version == 3) { -- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) -+ // TODO(b/311157761): enable mma_v3 -+ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) - return false; - auto retType = op.getResult().getType().cast(); - auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); -diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp ---- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp -+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp -@@ -40,7 +40,8 @@ public: - // Only insert fences for compute capability 9.0 - if (computeCapability < 90) - return; -- if 
(::triton::tools::getBoolEnv("DISABLE_MMA_V3")) -+ // TODO(b/311157761): enable mma_v3 -+ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) - return; - ModuleOp mod = getOperation(); - mod.walk([&](Operation *op) { -diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir ---- a/test/Conversion/tritongpu_to_llvm_hopper.mlir -+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir -@@ -1,4 +1,4 @@ --// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s -+// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s - - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> - #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> -diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir ---- a/test/TritonGPU/accelerate-matmul.mlir -+++ b/test/TritonGPU/accelerate-matmul.mlir -@@ -1,4 +1,4 @@ --// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s -+// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s - - // CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> - // CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir ---- a/test/TritonGPU/fence-inserstion.mlir -+++ b/test/TritonGPU/fence-inserstion.mlir -@@ -1,4 +1,4 @@ --// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s -+// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> - #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> - #mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> \ No newline at end of file diff --git a/third_party/triton/cl577379396.patch b/third_party/triton/cl577379396.patch deleted file mode 100644 index ee569f9b8f55c3..00000000000000 --- a/third_party/triton/cl577379396.patch +++ /dev/null @@ -1,33 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ---- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -@@ -246,7 +246,7 @@ SmallVector LayoutPropagation::pr - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { - Value arg = forOp.getTiedLoopRegionIterArg(&use); -- Value result = 
forOp.getResultForOpOperand(use); -+ Value result = forOp.getTiedLoopResult(&use); - setEncoding({arg, result}, info, changed, user); - continue; - } -@@ -769,7 +769,7 @@ static void rewriteSlice(SetVector()) { - auto result = value.cast(); -- OpOperand &forOperand = nestedFor.getOpOperandForResult(result); -+ OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); - markLive(forOperand.get()); - auto nestedYieldOp = - cast(nestedFor.getBody()->getTerminator()); diff --git a/third_party/triton/cl582925648.patch b/third_party/triton/cl582925648.patch deleted file mode 100644 index 9d86a1e9a21d32..00000000000000 --- a/third_party/triton/cl582925648.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ---- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -@@ -787,7 +787,7 @@ static void rewriteSlice(SetVector(op)) { - auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); - for (Value operand : yieldOp.getOperands()) { -diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir ---- a/test/TritonGPU/combine.mlir -+++ b/test/TritonGPU/combine.mlir -@@ -53,7 +53,7 @@ tt.func @remat(%arg0: i32) -> tensor<102 - // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> - // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]> - // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]> -- // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]> -+ // CHECK: %6 = arith.addi %5, %4 : tensor<1024xi32, [[$target_layout]]> - // CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]> - } - diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index a9797e30c231e4..834668f112a38d 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl580208989" - TRITON_SHA256 = "bcf6e99a73c8797720325b0f2e48447cdae7f68c53c68bfe04c39104db542562" + TRITON_COMMIT = "cl584018112" + TRITON_SHA256 = "a0f2461af9fbcf576cef08e0b83ab7a1caa3cfe2041c60b2809cbd495ff14f08" tf_http_archive( name = "triton", @@ -16,7 +16,6 @@ def repo(): # For temporary changes which haven't landed upstream yet. 
patch_file = [ "//third_party/triton:b304456327.patch", - "//third_party/triton:b311157761.patch", "//third_party/triton:cl568176943.patch", "//third_party/triton:cl584230333.patch", ], diff --git a/third_party/xla/third_party/triton/b311157761.patch b/third_party/xla/third_party/triton/b311157761.patch deleted file mode 100644 index b03fa04c142a42..00000000000000 --- a/third_party/xla/third_party/triton/b311157761.patch +++ /dev/null @@ -1,64 +0,0 @@ -diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp ---- a/include/triton/Tools/Sys/GetEnv.hpp -+++ b/include/triton/Tools/Sys/GetEnv.hpp -@@ -30,6 +30,7 @@ - namespace triton { - - const std::set ENV_VARS = { -+ "ENABLE_MMA_V3", - "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", - "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", - "AMDGCN_ENABLE_DUMP"}; -diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp ---- a/lib/Analysis/Utility.cpp -+++ b/lib/Analysis/Utility.cpp -@@ -394,7 +394,8 @@ bool supportMMA(triton::DotOp op, int version) { - auto aElemTy = op.getA().getType().cast().getElementType(); - auto bElemTy = op.getB().getType().cast().getElementType(); - if (version == 3) { -- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) -+ // TODO(b/311157761): enable mma_v3 -+ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) - return false; - auto retType = op.getResult().getType().cast(); - auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); -diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp ---- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp -+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp -@@ -40,7 +40,8 @@ public: - // Only insert fences for compute capability 9.0 - if (computeCapability < 90) - return; -- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) -+ // TODO(b/311157761): enable mma_v3 -+ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) - return; - ModuleOp mod = getOperation(); - mod.walk([&](Operation *op) { -diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir ---- a/test/Conversion/tritongpu_to_llvm_hopper.mlir -+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir -@@ -1,4 +1,4 @@ --// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s -+// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s - - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> - #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> -diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir ---- a/test/TritonGPU/accelerate-matmul.mlir -+++ b/test/TritonGPU/accelerate-matmul.mlir -@@ -1,4 +1,4 @@ --// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s -+// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s - - // CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 
16, 16]}> - // CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir ---- a/test/TritonGPU/fence-inserstion.mlir -+++ b/test/TritonGPU/fence-inserstion.mlir -@@ -1,4 +1,4 @@ --// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s -+// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> - #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> - #mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> \ No newline at end of file diff --git a/third_party/xla/third_party/triton/cl577379396.patch b/third_party/xla/third_party/triton/cl577379396.patch deleted file mode 100644 index ee569f9b8f55c3..00000000000000 --- a/third_party/xla/third_party/triton/cl577379396.patch +++ /dev/null @@ -1,33 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ---- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -@@ -246,7 +246,7 @@ SmallVector LayoutPropagation::pr - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { - Value arg = forOp.getTiedLoopRegionIterArg(&use); -- Value result = forOp.getResultForOpOperand(use); -+ Value result = forOp.getTiedLoopResult(&use); - setEncoding({arg, result}, info, changed, user); - continue; - } -@@ -769,7 +769,7 @@ static void rewriteSlice(SetVector()) { - auto result = value.cast(); -- OpOperand &forOperand = nestedFor.getOpOperandForResult(result); -+ OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); - markLive(forOperand.get()); - auto nestedYieldOp = - cast(nestedFor.getBody()->getTerminator()); diff --git a/third_party/xla/third_party/triton/cl582925648.patch b/third_party/xla/third_party/triton/cl582925648.patch deleted file mode 100644 index 9d86a1e9a21d32..00000000000000 --- a/third_party/xla/third_party/triton/cl582925648.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ---- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -@@ -787,7 +787,7 @@ static void rewriteSlice(SetVector(op)) { - auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); - for (Value operand : yieldOp.getOperands()) { -diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir ---- a/test/TritonGPU/combine.mlir -+++ b/test/TritonGPU/combine.mlir -@@ -53,7 +53,7 @@ tt.func @remat(%arg0: i32) -> tensor<102 - // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> - // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]> - // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, 
[[$target_layout]]> -- // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]> -+ // CHECK: %6 = arith.addi %5, %4 : tensor<1024xi32, [[$target_layout]]> - // CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]> - } - diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index a9797e30c231e4..834668f112a38d 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl580208989" - TRITON_SHA256 = "bcf6e99a73c8797720325b0f2e48447cdae7f68c53c68bfe04c39104db542562" + TRITON_COMMIT = "cl584018112" + TRITON_SHA256 = "a0f2461af9fbcf576cef08e0b83ab7a1caa3cfe2041c60b2809cbd495ff14f08" tf_http_archive( name = "triton", @@ -16,7 +16,6 @@ def repo(): # For temporary changes which haven't landed upstream yet. patch_file = [ "//third_party/triton:b304456327.patch", - "//third_party/triton:b311157761.patch", "//third_party/triton:cl568176943.patch", "//third_party/triton:cl584230333.patch", ], From 1c4a834715c5159771a8ad844dd23b096570da7b Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 23 Nov 2023 09:33:33 -0800 Subject: [PATCH 050/381] [XLA:GPU] Do not error out in triton autotuner in deviceless mode when --xla_gpu_autotuner_level=0 is set Instead, pick the first tiling available. This is consistent with autotuner_level=0 behavior in non-deviceless mode, and allows for better QOL while developing without a (matching) GPU. PiperOrigin-RevId: 584908154 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/triton_autotuner.cc | 51 ++++--- .../xla/service/gpu/triton_autotuner_test.cc | 138 ++++++++++-------- 3 files changed, 110 insertions(+), 80 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index fd65d44c92c4b3..606ed0c969e019 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -707,6 +707,7 @@ xla_test( "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "//xla/tests:verified_hlo_module", diff --git a/third_party/xla/xla/service/gpu/triton_autotuner.cc b/third_party/xla/xla/service/gpu/triton_autotuner.cc index fb6286487794da..5909fe611e419b 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner.cc @@ -102,6 +102,9 @@ constexpr int kMinTileSize = 16; // Not a hard limit, just an assumption that should stay valid. constexpr int kMaxTileSize = 512; +// Default tiling when autotuning is disabled. 
+constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4}; + class TritonAutotunerVisitor : public DfsHloRewriteVisitor { public: explicit TritonAutotunerVisitor(const AutotuneConfig& config) @@ -126,7 +129,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { "compilation (HLO: ", hlo->ToString())); } - return InternalError("Expect autotune result cache hit."); + return absl::InternalError("Expect autotune result cache hit."); })); VLOG(4) << "Result: " << autotune_result.ShortDebugString(); @@ -226,12 +229,11 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { GemmConfigSet GetGemmConfigSet(const HloFusionInstruction* fusion) { const DebugOptions& debug_options = fusion->GetModule()->config().debug_options(); - se::StreamExecutor* stream_exec = config_.GetExecutor(); return {GetPossibleMatmulAutotuneConfigs( *Cast(hlo_query::GetFirstInstructionWithOpcode( *fusion->called_computations().at(0), HloOpcode::kDot)), - stream_exec->GetDeviceDescription().cuda_compute_capability(), - debug_options, config_.ExhaustiveTilingSearch())}; + config_.GetCudaComputeCapability(), debug_options, + config_.ExhaustiveTilingSearch())}; } AutotuneConfig config_; @@ -884,7 +886,7 @@ std::vector GetPossibleMatmulAutotuneConfigs( constexpr int kMinGemmElements = 32 * 32; if (ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements && ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements) { - return ReduceTileSizes(dot, {TritonGemmConfig(32, 32, 32, 1, 1, 4)}); + return ReduceTileSizes(dot, {kDefaultGemmTiling}); } // Split-K optimization enables more even utilization of a GPU in cases // where tiling just the non-contracting dimensions of a GEMM does not create @@ -914,22 +916,29 @@ StatusOr TritonAutotuner::Run( const absl::flat_hash_set& execution_threads) { XLA_SCOPED_LOGGING_TIMER("Triton autotuner"); const DebugOptions& debug_options = module->config().debug_options(); - if (debug_options.xla_gpu_autotune_level() == 0) { - return false; - } + TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, + AutotunerCompileUtil::Create(config_, debug_options)); - if (!config_.IsDeviceless()) { - TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, - AutotunerCompileUtil::Create(config_, debug_options)); + GemmConfigSetCollector gemm_config_set_collector(config_); + absl::flat_hash_map + gemm_config_sets; + TF_ASSIGN_OR_RETURN(gemm_config_sets, + gemm_config_set_collector.CollectGemmConfigSets( + module, execution_threads)); + + if (debug_options.xla_gpu_autotune_level() == 0 || + debug_options.xla_gpu_deterministic_ops()) { + // Pick the first option for each gemm instead of autotuning.. + for (const auto& [fusion, tilings] : gemm_config_sets) { + const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_); + AutotuneResult res; + *res.mutable_triton() = kDefaultGemmTiling.ToProto(); + *res.mutable_run_time() = + tsl::proto_utils::ToDurationProto(absl::ZeroDuration()); + AutotunerUtil::AddResult(key, res); + } + } else if (!config_.IsDeviceless()) { TF_RET_CHECK(opt_compile_util.has_value()); - AutotunerCompileUtil& compile_util = opt_compile_util.value(); - - GemmConfigSetCollector gemm_config_set_collector(config_); - absl::flat_hash_map - gemm_config_sets; - TF_ASSIGN_OR_RETURN(gemm_config_sets, - gemm_config_set_collector.CollectGemmConfigSets( - module, execution_threads)); if (!gemm_config_sets.empty()) { std::string correctness_check_str = config_.should_check_correctness() ? 
"(with correctness check)" @@ -940,7 +949,7 @@ StatusOr TritonAutotuner::Run( int fusion_id_for_dump = 0; if (debug_options.xla_gpu_single_wave_autotuning()) { // Tune all fusions at once to save time. - TF_RETURN_IF_ERROR(Autotune(config_, compile_util, thread_pool_, + TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, debug_options, gemm_config_sets, fusion_id_for_dump)); } else { @@ -948,7 +957,7 @@ StatusOr TritonAutotuner::Run( for (const auto& key_value : gemm_config_sets) { absl::flat_hash_map single_element_map({key_value}); - TF_RETURN_IF_ERROR(Autotune(config_, compile_util, thread_pool_, + TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, debug_options, single_element_map, fusion_id_for_dump)); } diff --git a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc index 3bdc62a3691e9a..0edf9c86a987d8 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" @@ -144,14 +145,28 @@ ENTRY entry { /*allow_mixed_precision=*/false)); } -class TritonAutotunerTest : public HloTestBase { +class StatelessAutotunerTest : public HloTestBase { public: - TritonAutotunerTest() + StatelessAutotunerTest() : HloTestBase(/*verifier_layout_sensitive=*/true, /*allow_mixed_precision_in_hlo_verifier=*/false) {} + void SetUp() override { + AutotunerUtil::ClearAutotuneResults(); + HloTestBase::SetUp(); + } + + void TearDown() override { + AutotunerUtil::ClearAutotuneResults(); + HloTestBase::TearDown(); + } +}; + +class TritonAutotunerTest : public StatelessAutotunerTest { + public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + DebugOptions debug_options = + StatelessAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(true); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; @@ -196,36 +211,6 @@ class TritonAutotunerTest : public HloTestBase { 0); }); } - - void CheckTritonAutotuningDeviceless(absl::string_view hlo) { - HloPassPipeline pipeline("gemm_rewrite_deviceless"); - pipeline.AddPass(backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability()); - tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", - tsl::port::MaxParallelism()); - DebugOptions opts; - pipeline.AddPass( - AutotuneConfig{DevicelessConfig{backend() - .default_stream_executor() - ->GetDeviceDescription() - .model_str(), - backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability()}, - opts}, - &thread_pool); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()), - tsl::testing::StatusIs( - tsl::error::INTERNAL, - ::testing::HasSubstr( - "Expect autotune result cache hit for deviceless"))); - } }; class TritonAutotunerTestWithMorePreciseReduction : public TritonAutotunerTest { @@ -568,11 +553,12 @@ ENTRY %e { // TODO(b/281489442): Write a testcase called // `SkipConfigsProducingDeviantResults` or similar. 
-class TritonAutotunerLevelTest : public HloTestBase, +class TritonAutotunerLevelTest : public StatelessAutotunerTest, public ::testing::WithParamInterface { public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + DebugOptions debug_options = + StatelessAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_autotune_level(GetParam()); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; @@ -591,23 +577,71 @@ ENTRY e { lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - AutotunerUtil::ClearAutotuneResults(); - - if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) { - MatchOptimizedHlo(kHloText, R"( -; CHECK: kind=kCustom -; CHECK-NOT: block_m - )"); - } else { - MatchOptimizedHlo(kHloText, R"( + MatchOptimizedHlo(kHloText, R"( ; CHECK: kind=kCustom ; CHECK-SAME: block_m )"); - } EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +TEST_P(TritonAutotunerLevelTest, Deviceless) { + const std::string hlo = R"( +HloModule module + +ENTRY e { + x = s8[16,16] parameter(0) + c = f16[16,16] convert(x) + y = f16[16,16] parameter(1) + ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + HloPassPipeline pipeline("gemm_rewrite_deviceless"); + pipeline.AddPass(backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability()); + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", + tsl::port::MaxParallelism()); + DebugOptions opts; + pipeline.AddPass( + AutotuneConfig{DevicelessConfig{backend() + .default_stream_executor() + ->GetDeviceDescription() + .model_str(), + backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability()}, + opts}, + &thread_pool); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) { + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloTestBase::RunHloPass(&pipeline, module.get())); + EXPECT_TRUE(changed); + + // Check default configuration. 
+ TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), + R"( +// CHECK: backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4"}} + )")); + EXPECT_TRUE(filecheck_matches); + } else { + EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()), + tsl::testing::StatusIs( + tsl::error::INTERNAL, + ::testing::HasSubstr( + "Expect autotune result cache hit for deviceless"))); + } +} + INSTANTIATE_TEST_SUITE_P(TritonAutotunerLevelSweep, TritonAutotunerLevelTest, ::testing::Range(0, 5)); @@ -640,20 +674,6 @@ ENTRY e { )"); } -TEST_F(TritonAutotunerExhaustiveTest, Deviceless_CompileOnly) { - const std::string hlo = R"( -HloModule module - -ENTRY e { - x = s8[16,16] parameter(0) - c = f16[16,16] convert(x) - y = f16[16,16] parameter(1) - ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - CheckTritonAutotuningDeviceless(hlo); -} class TritonAutotunerDisableSplitK : public TritonAutotunerTest { public: From 26b4d8f8b50538f9ad7263e906630666f9112a42 Mon Sep 17 00:00:00 2001 From: Joyce Date: Thu, 23 Nov 2023 17:04:38 -0300 Subject: [PATCH 051/381] Feat: hash pin actions using Step Security Signed-off-by: Joyce --- .github/workflows/stale-issues.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index 84118acca683fd..e439c0f180ed44 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -31,7 +31,7 @@ jobs: pull-requests: write steps: - name: Awaiting response issues - uses: actions/stale@v7 + uses: actions/stale@6f05e4244c9a0b2ed3401882b05d701dd0a7289b # v7.0.0 with: #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: 'override-stale' @@ -59,7 +59,7 @@ jobs: close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further." repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Contribution issues - uses: actions/stale@v7 + uses: actions/stale@6f05e4244c9a0b2ed3401882b05d701dd0a7289b # v7.0.0 with: #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: 'override-stale' From 18bb144a7a1661c23da678e20be2cd604a755e94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Thu, 23 Nov 2023 12:54:38 -0800 Subject: [PATCH 052/381] [XLA:GPU] Fix occasional nullptr dereference in IrEmitterUnnested::EmitTritonFusion xla/tests:dot_operation_test_autotune_disabled_gpu_a100 was flaky because of this. PiperOrigin-RevId: 584935941 --- .../xla/service/gpu/ir_emitter_unnested.cc | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index ac55cd06cbbd42..4015b728ae60b7 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -2059,8 +2059,15 @@ StatusOr IrEmitterUnnested::EmitTritonFusion( // because we only get the launch dimensions after code generation. So we // implement kernel reuse using lower level APIs, such as // `BuildKernelThunkImpl`. 
- - VLOG(3) << llvm_ir::DumpToString(op); + CHECK_NE(fusion, nullptr); + if (!ir_emitter_context_->emit_ir_from_hlo()) { + CHECK_NE(op, nullptr); + } + if (ir_emitter_context_->emit_ir_from_hlo()) { + VLOG(3) << fusion->ToString(); + } else { + VLOG(3) << llvm_ir::DumpToString(op); + } std::string suggested_kernel_name = std::string(fusion->name()); TF_ASSIGN_OR_RETURN( auto kernel_arguments, @@ -2109,8 +2116,13 @@ StatusOr IrEmitterUnnested::EmitTritonFusion( } else { // Must be a MatMul CHECK_EQ(fusion_kind, kTritonGemmFusionKind); if (!backend_config.has_triton_gemm_config()) { - LOG(WARNING) << "Using fallback triton GEMM config for op " - << GetIrNameFromLoc(op->getLoc()); + if (ir_emitter_context_->emit_ir_from_hlo()) { + LOG(WARNING) << "Using fallback triton GEMM config for op " + << fusion->name(); + } else { + LOG(WARNING) << "Using fallback triton GEMM config for op " + << GetIrNameFromLoc(op->getLoc()); + } auto& triton_config = *backend_config.mutable_triton_gemm_config(); triton_config.set_block_m(64); triton_config.set_block_k(64); From 4407e247cfb843978aa703ffe7b965e42ea879d9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 23 Nov 2023 13:31:17 -0800 Subject: [PATCH 053/381] [xla:ffi] Sketched basic diagnostics infrastructure and plumbed it through decoding APIs PiperOrigin-RevId: 584940021 --- third_party/xla/xla/ffi/api/api.h | 164 ++++++++++++++++++++++-------- third_party/xla/xla/ffi/api/ffi.h | 5 +- third_party/xla/xla/ffi/ffi.h | 6 +- 3 files changed, 130 insertions(+), 45 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index fac396aa709a08..9ebc01805d4b74 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -296,7 +296,8 @@ struct ArgDecoding; // // template <> // struct AttrDecoding { -// static std::optional Decode(XLA_FFI_AttrType type, void* attr); +// static std::optional Decode(XLA_FFI_AttrType type, void* attr, +// DiagnosticEngine&); // } // template @@ -341,6 +342,64 @@ struct CtxDecoding; template struct ResultEncoding; +//===----------------------------------------------------------------------===// +// Diagnostics +//===----------------------------------------------------------------------===// + +class DiagnosticEngine; + +// RAII wrapper around constructed, but but not yet emitted diagnostic. In +// flight diagnostic gives an opportunity to build a diagnostic before reporting +// it to the engine, similar to the builder pattern. 
+class InFlightDiagnostic { + public: + explicit InFlightDiagnostic(DiagnosticEngine* engine, std::string s) + : engine_(engine), stream_(std::move(s)) {} + + ~InFlightDiagnostic(); + + template + InFlightDiagnostic& operator<<(Arg&& arg) { + stream_ << std::forward(arg); + return *this; + } + + operator std::nullopt_t() const { // NOLINT + return std::nullopt; + } + + private: + InFlightDiagnostic& operator=(const InFlightDiagnostic&) = delete; + InFlightDiagnostic& operator=(InFlightDiagnostic&&) = delete; + + DiagnosticEngine* engine_; + std::stringstream stream_; +}; + +class DiagnosticEngine { + public: + DiagnosticEngine() = default; + DiagnosticEngine(const DiagnosticEngine&) = delete; + DiagnosticEngine& operator=(const DiagnosticEngine&) = delete; + + InFlightDiagnostic Emit(std::string message) { + return InFlightDiagnostic(this, std::move(message)); + } + + std::string Result() const { return s_; } + + private: + friend class InFlightDiagnostic; + + void append(std::string s) { s_.append(std::move(s)); } + + std::string s_; +}; + +inline InFlightDiagnostic::~InFlightDiagnostic() { + engine_->append(stream_.str()); +} + //===----------------------------------------------------------------------===// // Decoding arguments and attributes //===----------------------------------------------------------------------===// @@ -363,10 +422,11 @@ struct DecodingContext { template struct Decode { - static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx) { + static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, + DiagnosticEngine& diagnostic) { int64_t idx = offsets.args++; return ArgDecoding::Decode(ctx.call_frame->args.types[idx], - ctx.call_frame->args.args[idx]); + ctx.call_frame->args.args[idx], diagnostic); } }; @@ -374,7 +434,8 @@ struct Decode { template struct internal::Decode> { - static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx) { + static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, + DiagnosticEngine& diagnostic) { // Find decoded attribute corresponding to the given attribute index. int64_t i = offsets.attrs++; @@ -394,7 +455,7 @@ struct internal::Decode> { std::string_view attr_name_view = {attr_name->ptr, attr_name->len}; if (attr_name_view != ctx.attrs_names[i]) return std::nullopt; - return AttrDecoding::Decode(attr_type, attr); + return AttrDecoding::Decode(attr_type, attr, diagnostic); } }; @@ -402,8 +463,10 @@ template struct internal::Decode> { using R = typename CtxDecoding::Type; - static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx) { - return CtxDecoding::Decode(ctx.call_frame->api, ctx.call_frame->ctx); + static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return CtxDecoding::Decode(ctx.call_frame->api, ctx.call_frame->ctx, + diagnostic); } }; @@ -426,7 +489,10 @@ class RemainingArgs { size_t idx = offset_ + index; if (idx >= args_->num_args) return std::nullopt; - return ArgDecoding::Decode(args_->types[idx], args_->args[idx]); + // TODO(slebedev): Expose the collected diagnostic to the caller. 
+ DiagnosticEngine diagnostic; + return ArgDecoding::Decode(args_->types[idx], args_->args[idx], + diagnostic); } private: @@ -437,7 +503,8 @@ class RemainingArgs { template <> struct internal::Decode { static std::optional call(DecodingOffsets& offsets, - DecodingContext& ctx) { + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { return RemainingArgs(&ctx.call_frame->args, offsets.args); } }; @@ -464,7 +531,9 @@ class Dictionary { XLA_FFI_AttrType attr_type = attrs_->types[idx]; void* attr = attrs_->attrs[idx]; - return AttrDecoding::Decode(attr_type, attr); + // TODO(slebedev): Expose the collected diagnostic to the caller. + DiagnosticEngine diagnostic; + return AttrDecoding::Decode(attr_type, attr, diagnostic); } private: @@ -489,7 +558,8 @@ class Dictionary { template <> struct internal::Decode> { static std::optional call(DecodingOffsets& offsets, - DecodingContext& ctx) { + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { return Dictionary(&ctx.call_frame->attrs); } }; @@ -497,10 +567,11 @@ struct internal::Decode> { // Decode `AttrsTag` into a type `T` relying on struct decoding defined below. template struct internal::Decode> { - static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx) { + static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, + DiagnosticEngine& diagnostic) { return AttrDecoding::Decode( XLA_FFI_AttrType_DICTIONARY, - const_cast(&ctx.call_frame->attrs)); + const_cast(&ctx.call_frame->attrs), diagnostic); } }; @@ -648,12 +719,15 @@ class Handler : public Ffi { internal::DecodingContext ctx = {call_frame, attrs_.data(), attrs_idx_.data()}; + DiagnosticEngine diagnostic; + std::tuple>...> args = { - internal::Decode::call(offsets, ctx)...}; + internal::Decode::call(offsets, ctx, diagnostic)...}; bool all_decoded = (std::get(args).has_value() && ...); if (!all_decoded) { - return FailedDecodeError(call_frame, {std::get(args).has_value()...}); + return FailedDecodeError(call_frame, {std::get(args).has_value()...}, + diagnostic); } auto result = fn_(std::move(*std::get(args))...); @@ -662,7 +736,8 @@ class Handler : public Ffi { } XLA_FFI_Error* FailedDecodeError(const XLA_FFI_CallFrame* call_frame, - std::array decoded) const { + std::array decoded, + const DiagnosticEngine& diagnostic) const { std::string message = "Failed to decode all FFI handler operands (bad operands at: "; for (size_t cnt = 0, idx = 0; idx < kSize; ++idx) { @@ -672,6 +747,10 @@ class Handler : public Ffi { } } message.append(")"); + if (auto s = std::move(diagnostic).Result(); !s.empty()) { + message.append("\nDiagnostics:\n"); + message.append(s); + } return InvalidArgument(call_frame->api, message); } @@ -712,16 +791,17 @@ class Handler : public Ffi { // Builtin attributes decoding //===----------------------------------------------------------------------===// -#define XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(T, TYPE) \ - template <> \ - struct AttrDecoding { \ - static std::optional Decode(XLA_FFI_AttrType type, void* attr) { \ - if (type != TYPE) { \ - return std::nullopt; \ - } \ - \ - return *reinterpret_cast(attr); \ - } \ +#define XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(T, TYPE) \ + template <> \ + struct AttrDecoding { \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine&) { \ + if (type != TYPE) { \ + return std::nullopt; \ + } \ + \ + return *reinterpret_cast(attr); \ + } \ } XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int32_t, XLA_FFI_AttrType_I32); @@ -733,7 +813,7 @@ 
XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(float, XLA_FFI_AttrType_F32); template <> struct AttrDecoding { static std::optional Decode(XLA_FFI_AttrType type, - void* attr) { + void* attr, DiagnosticEngine&) { if (type != XLA_FFI_AttrType_STRING) { return std::nullopt; } @@ -745,7 +825,8 @@ struct AttrDecoding { template <> struct AttrDecoding { - static std::optional Decode(XLA_FFI_AttrType type, void* attr) { + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine&) { if (type != XLA_FFI_AttrType_DICTIONARY) { return std::nullopt; } @@ -825,19 +906,20 @@ auto DictionaryDecoder(Members... m) { // StructMember("a"), // StructMember("b")); // -#define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ - template <> \ - struct AttrDecoding { \ - static std::optional Decode(XLA_FFI_AttrType type, void* attr) { \ - if (type != XLA_FFI_AttrType_DICTIONARY) { \ - return std::nullopt; \ - } \ - \ - auto decoder = internal::DictionaryDecoder(__VA_ARGS__); \ - return decltype(decoder)::Decode( \ - reinterpret_cast(attr), \ - internal::StructMemberNames(__VA_ARGS__)); \ - } \ +#define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ + template <> \ + struct AttrDecoding { \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine&) { \ + if (type != XLA_FFI_AttrType_DICTIONARY) { \ + return std::nullopt; \ + } \ + \ + auto decoder = internal::DictionaryDecoder(__VA_ARGS__); \ + return decltype(decoder)::Decode( \ + reinterpret_cast(attr), \ + internal::StructMemberNames(__VA_ARGS__)); \ + } \ } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 2c485dc7a7e1ba..ed168f91488b5c 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -144,7 +144,7 @@ struct BufferBase { template struct ArgDecoding> { static std::optional> Decode(XLA_FFI_ArgType type, - void* arg) { + void* arg, DiagnosticEngine&) { if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt; auto* buf = reinterpret_cast(arg); // TODO(slebedev): Emit a user-friendly error instead. 
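Illustrative sketch (not part of this patch) of what the new decoding contract looks like from a user's side: a custom attribute decoder now receives a DiagnosticEngine and can explain a failed decode instead of silently returning std::nullopt. The MyScale type below is hypothetical, and the specialization is assumed to sit in the same namespace as AttrDecoding with xla/ffi/api/api.h included:

// Hypothetical user-defined attribute type, for illustration only.
struct MyScale {
  float value;
};

template <>
struct AttrDecoding<MyScale> {
  static std::optional<MyScale> Decode(XLA_FFI_AttrType type, void* attr,
                                       DiagnosticEngine& diagnostic) {
    if (type != XLA_FFI_AttrType_F32) {
      // The message is collected by the engine and surfaced in the
      // "Diagnostics:" section of the handler's decode error.
      diagnostic.Emit("MyScale expects an F32 attribute, got type ")
          << static_cast<int>(type);
      return std::nullopt;
    }
    return MyScale{*reinterpret_cast<float*>(attr)};
  }
};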
@@ -187,7 +187,8 @@ struct CtxDecoding> { static_assert(std::is_pointer_v, "stream type must be a pointer"); static std::optional Decode(const XLA_FFI_Api* api, - XLA_FFI_ExecutionContext* ctx) { + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { XLA_FFI_Stream_Get_Args args; args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; args.priv = nullptr; diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index b1502ce1b306fb..83878d2b085956 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -61,7 +61,8 @@ struct Buffer { template <> struct ArgDecoding { - static std::optional Decode(XLA_FFI_ArgType type, void* arg) { + static std::optional Decode(XLA_FFI_ArgType type, void* arg, + DiagnosticEngine&) { if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt; auto* buf = reinterpret_cast(arg); @@ -82,7 +83,8 @@ struct CtxDecoding { using Type = const ServiceExecutableRunOptions*; static std::optional Decode(const XLA_FFI_Api* api, - XLA_FFI_ExecutionContext* ctx) { + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { void* ptr = api->internal_api->XLA_FFI_ServiceExecutableRunOptions_Get(ctx); return reinterpret_cast(ptr); } From 50949edcb78a5d89b3994d1822adadc121fe63ff Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 23 Nov 2023 13:41:45 -0800 Subject: [PATCH 054/381] [stream_executor] Add If-Else conditional commands to CommandBuffer PiperOrigin-RevId: 584941044 --- .../xla/xla/stream_executor/command_buffer.cc | 7 + .../xla/xla/stream_executor/command_buffer.h | 18 +- .../cuda/cuda_command_buffer_test.cc | 105 +++++++++ .../cuda/cuda_conditional_kernels.cu.cc | 46 +++- .../xla/stream_executor/cuda/cuda_driver.cc | 11 + .../stream_executor/gpu/gpu_command_buffer.cc | 223 +++++++++++++----- .../stream_executor/gpu/gpu_command_buffer.h | 41 +++- .../xla/xla/stream_executor/gpu/gpu_driver.h | 6 + .../rocm/hip_conditional_kernels.cu.cc | 5 +- .../stream_executor_internal.h | 8 + 10 files changed, 390 insertions(+), 80 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 6ac06e67f1e9a0..94ce1f67a4d846 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -120,6 +120,13 @@ tsl::Status CommandBuffer::If(StreamExecutor* executor, DeviceMemory pred, return implementation_->If(executor, pred, std::move(then_builder)); } +tsl::Status CommandBuffer::IfElse(StreamExecutor* executor, + DeviceMemory pred, Builder then_builder, + Builder else_builder) { + return implementation_->IfElse(executor, pred, std::move(then_builder), + std::move(else_builder)); +} + CommandBuffer::Mode CommandBuffer::mode() const { return implementation_->mode(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 0c12cc82f8ffcf..2698293af7261a 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -122,13 +122,23 @@ class CommandBuffer { tsl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size); - // Adds a conditional operation that will execute a command buffer constructed - // by `then_builder` if predicate is true. Builder should not call `Update` or - // `Finalize` on command buffer argument, parent command buffer is responsible - // for updating and finalizing conditional command buffers. 
+ //--------------------------------------------------------------------------// + // Command buffer condtitional commands API + //--------------------------------------------------------------------------// + + // Adds a conditional operation that will run a command buffer constructed by + // `then_builder` if `predicate` value is `true`. tsl::Status If(StreamExecutor* executor, DeviceMemory pred, Builder then_builder); + // Adds a conditional operation that will run a command buffer constructed by + // `then_builder` if `predicate` value is `true`, or a command buffer + // constructed by `else_builder` if `predicate` is `false`. + tsl::Status IfElse(StreamExecutor* executor, DeviceMemory pred, + Builder then_builder, Builder else_builder); + + //--------------------------------------------------------------------------// + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. tsl::Status Finalize(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 981187b854f637..cc1f823e3608db 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -36,6 +36,8 @@ namespace stream_executor::cuda { using AddI32Kernel = TypedKernel, DeviceMemory, DeviceMemory>; +using MulI32Kernel = TypedKernel, DeviceMemory, + DeviceMemory>; using AddI32Ptrs3 = TypedKernel>; @@ -309,6 +311,109 @@ TEST(CudaCommandBufferTest, ConditionalIf) { ASSERT_EQ(dst, expected); } +TEST(CudaCommandBufferTest, ConditionalIfElse) { + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + if (!CommandBuffer::SupportsConditionalCommands(platform)) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + AddI32Kernel add(executor); + MulI32Kernel mul(executor); + + { // Load addition kernel. + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); + TF_ASSERT_OK(executor->GetKernel(spec, &add)); + } + + { // Load multiplication kernel. + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetMulI32CudaKernel(), "mul"); + TF_ASSERT_OK(executor->GetKernel(spec, &mul)); + } + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=2, b=3, c=0, pred=true + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + stream.ThenMemcpy(&pred, &kTrue, 1); + stream.ThenMemset32(&a, 2, byte_length); + stream.ThenMemset32(&b, 3, byte_length); + stream.ThenMemZero(&c, byte_length); + + // if (pred == true) c = a + b + CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); + }; + + // if (pred == false) c = a * b + CommandBuffer::Builder else_builder = [&](CommandBuffer* else_cmd) { + return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c); + }; + + // Create a command buffer with a single conditional operation. 
+ auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.IfElse(executor, pred, then_builder, else_builder)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `c` data back to host. + std::vector dst(4, 42); + stream.ThenMemcpy(dst.data(), c, byte_length); + + std::vector expected_add = {5, 5, 5, 5}; + ASSERT_EQ(dst, expected_add); + + // Reset predicate to false. + constexpr bool kFalse = false; + stream.ThenMemcpy(&pred, &kFalse, 1); + + // Submit the same command buffer, but this time it should execute `else` + // branch and multiply inputs. + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + stream.ThenMemcpy(dst.data(), c, byte_length); + std::vector expected_mul = {6, 6, 6, 6}; + ASSERT_EQ(dst, expected_mul); + + // Prepare argument for graph update: d = 0 + DeviceMemory d = executor->AllocateArray(length, 0); + stream.ThenMemZero(&d, byte_length); + + // if (pred == false) d = a * b (write to a new location). + else_builder = [&](CommandBuffer* else_cmd) { + return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, d); + }; + + // Update command buffer with a conditional to use new `else` builder. + TF_ASSERT_OK(cmd_buffer.Update()); + TF_ASSERT_OK(cmd_buffer.IfElse(executor, pred, then_builder, else_builder)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 42); + stream.ThenMemcpy(dst.data(), d, byte_length); + ASSERT_EQ(dst, expected_mul); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc index 3c6f492615e170..ad49ecade31c8e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
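For quick reference, a condensed sketch of how the new conditional command is used on a primary command buffer, with names (pred, a, b, c, add, mul) reused from the ConditionalIfElse test above, error handling elided, and a surrounding function returning tsl::Status assumed:

// c = pred ? a + b : a * b, recorded once and re-submittable.
TF_RETURN_IF_ERROR(cmd_buffer.IfElse(
    executor, pred,
    /*then_builder=*/[&](CommandBuffer* then_cmd) {
      return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c);
    },
    /*else_builder=*/[&](CommandBuffer* else_cmd) {
      return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c);
    }));
TF_RETURN_IF_ERROR(cmd_buffer.Finalize());
// After Finalize() the buffer can be submitted repeatedly; flipping the value
// stored in `pred` between submissions switches which branch runs, and
// Update() followed by a new IfElse() call re-records the branches, as the
// test above shows.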
@@ -19,30 +19,50 @@ namespace stream_executor { namespace cuda { namespace { -#if CUDA_VERSION >= 12030 +#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) && \ + CUDA_VERSION >= 12030 -__global__ void SetCondition(cudaGraphConditionalHandle handle, - bool* predicate) { -#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) +__global__ void SetIfCondition(cudaGraphConditionalHandle then_handle, + bool* predicate) { if (*predicate) { - cudaGraphSetConditional(handle, 1); + cudaGraphSetConditional(then_handle, 1); } else { - cudaGraphSetConditional(handle, 0); + cudaGraphSetConditional(then_handle, 0); } -#endif // defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) } -#else -__global__ void SetCondition() {} -#endif // CUDA_VERSION >= 12030 +__global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, + cudaGraphConditionalHandle else_handle, + bool* predicate) { + if (*predicate) { + cudaGraphSetConditional(then_handle, 1); + cudaGraphSetConditional(else_handle, 0); + } else { + cudaGraphSetConditional(then_handle, 0); + cudaGraphSetConditional(else_handle, 1); + } +} + +#else // CUDA graph conditionals are not available + +__global__ void SetIfCondition() {} +__global__ void SetIfElseCondition() {} + +#endif } // namespace } // namespace cuda namespace gpu { -void* GetSetConditionKernel() { - return reinterpret_cast(&cuda::SetCondition); + +void* GetSetIfConditionKernel() { + return reinterpret_cast(&cuda::SetIfCondition); } + +void* GetSetIfElseConditionKernel() { + return reinterpret_cast(&cuda::SetIfElseCondition); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 3a32a8f8df1801..79ee95668a6f5c 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -775,6 +775,17 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, return absl::UnimplementedError("unsupported node type"); } +/* static */ tsl::Status GpuDriver::GraphAddEmptyNode( + CUgraphNode* node, CUgraph graph, absl::Span deps) { + VLOG(2) << "Add empty node to a graph " << graph << "; deps: " << deps.size(); + + RETURN_IF_CUDA_RES_ERROR( + cuGraphAddEmptyNode(node, graph, deps.data(), deps.size()), + "Failed to add empty node to a CUDA graph"); + + return tsl::OkStatus(); +} + /* static */ tsl::Status GpuDriver::GraphAddKernelNode( CUgraphNode* node, CUgraph graph, absl::Span deps, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index a9d0e7fc907b17..09143ffbb5fcad 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -15,7 +15,9 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_command_buffer.h" +#include #include +#include #include #include #include @@ -40,6 +42,7 @@ limitations under the License. 
#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace stream_executor::gpu { @@ -186,6 +189,16 @@ tsl::Status GpuCommandBuffer::CheckPrimary() { return tsl::OkStatus(); } +tsl::Status GpuCommandBuffer::CheckNumCommandBuffers( + const ConditionalCommandBuffers& cmd_buffers, size_t num_cmd_buffers) { + if (cmd_buffers.handles.size() != num_cmd_buffers) { + return absl::InternalError(absl::StrCat( + "Expected to have ", num_cmd_buffers, + " conditional command buffers, got ", cmd_buffers.handles.size())); + } + return tsl::OkStatus(); +} + tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, @@ -265,38 +278,29 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } -tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder) { - DCHECK(executor->implementation() == parent_); // NOLINT - - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `If`. - SetConditionKernel set_condition(executor); +//--------------------------------------------------------------------------// +// Command buffer condtitional commands API +//--------------------------------------------------------------------------// - { // Load kernels that update condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/1); - spec.AddInProcessSymbol(gpu::GetSetConditionKernel(), "set_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_condition)); +tsl::StatusOr> +GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) { + std::vector handles; + for (size_t i = 0; i < num_handles; ++i) { + TF_RETURN_IF_ERROR(GpuDriver::GraphConditionalHandleCreate( + &handles.emplace_back(), graph_, parent_->gpu_context(), 0, 0)); } + return handles; +} + +tsl::StatusOr> +GpuCommandBuffer::CreateConditionalNodes( + absl::Span handles) { + std::vector conditional_graphs; using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result; - // Conditional command buffers always created in nested mode. - CommandBuffer::Mode nested = CommandBuffer::Mode::kNested; - - if (state_ == State::kCreate) { - // Create a handle for a conditional node. - GpuGraphConditionalHandle handle; - TF_RETURN_IF_ERROR(GpuDriver::GraphConditionalHandleCreate( - &handle, graph_, parent_->gpu_context(), 0, 0)); - - // Add a kernel to update conditional handle value based on a predicate. - TF_RETURN_IF_ERROR( - Launch(set_condition, ThreadDim(), BlockDim(), handle, predicate)); - - // Add conditional node to the graph. + for (GpuGraphConditionalHandle handle : handles) { Dependencies deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); @@ -309,19 +313,82 @@ tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, GpuDriver::GpuGraphNodeResult result, GpuDriver::GraphAddNode(node, graph_, absl::MakeSpan(deps), params)); - // Set up conditional command buffer. 
- GpuGraphHandle then_graph = std::get(result).graph; + conditional_graphs.push_back(std::get(result).graph); + } + + return conditional_graphs; +} + +tsl::StatusOr +GpuCommandBuffer::CreateConditionalCommandBuffers( + absl::Span handles, + absl::Span graphs, + absl::Span builders) { + ConditionalCommandBuffers cond_cmd_buffers; + + // Conditional command buffers always created in nested mode and with + // underlying graphs owned by a conditional node. + CommandBuffer::Mode nested = CommandBuffer::Mode::kNested; + bool is_owned_graph = false; - // Wrap conditional graph into command buffer and pass it to the builder. - auto then_cmd_buffer = - CommandBuffer::Wrap(parent_->GetCommandBufferImplementation( - nested, then_graph, /*is_owned_graph=*/false)); - TF_RETURN_IF_ERROR(then_builder(&then_cmd_buffer)); - TF_RETURN_IF_ERROR(then_cmd_buffer.Finalize()); + for (size_t i = 0; i < handles.size(); ++i) { + auto command_buffer_impl = parent_->GetCommandBufferImplementation( + nested, graphs[i], is_owned_graph); + + auto command_buffer = CommandBuffer::Wrap(std::move(command_buffer_impl)); + TF_RETURN_IF_ERROR(builders[i](&command_buffer)); + TF_RETURN_IF_ERROR(command_buffer.Finalize()); + + cond_cmd_buffers.Add(handles[i], std::move(command_buffer)); + } + + return cond_cmd_buffers; +} - // Keep track of conditional handle and command buffers for update. - auto& command_buffers = conditional_command_buffers_.emplace_back(); - command_buffers.Add(handle, std::move(then_cmd_buffer)); +tsl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( + absl::Span command_buffers, + absl::Span builders) { + for (size_t i = 0; i < command_buffers.size(); ++i) { + // Use parent graph executable for conditional command buffer update. + ScopedGpuGraphExec scoped_exec(Cast(&command_buffers[i]), exec_); + + // Update command buffer using user-provided builder callback. + TF_RETURN_IF_ERROR(command_buffers[i].Update()); + TF_RETURN_IF_ERROR(builders[i](&command_buffers[i])); + TF_RETURN_IF_ERROR(command_buffers[i].Finalize()); + } + return tsl::OkStatus(); +} + +tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, + DeviceMemory predicate, + CommandBuffer::Builder then_builder) { + DCHECK(executor->implementation() == parent_); // NOLINT + + // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on + // every call to `If`. + SetIfConditionKernel set_if_condition(executor); + + { // Load kernels that updates condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddInProcessSymbol(gpu::GetSetIfConditionKernel(), "set_if_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_condition)); + } + + std::array builders = {std::move(then_builder)}; + + if (state_ == State::kCreate) { + TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(1)); + + // Add a kernel to update conditional handle value based on a predicate. + TF_RETURN_IF_ERROR(Launch(set_if_condition, ThreadDim(), BlockDim(), + handles[0], predicate)); + + // Create conditional command buffer for then branch. + TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(handles)); + TF_ASSIGN_OR_RETURN( + conditional_command_buffers_.emplace_back(), + CreateConditionalCommandBuffers(handles, graphs, builders)); return tsl::OkStatus(); } @@ -331,35 +398,77 @@ tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, conditional_command_buffers_[update_state_.conditional_idx++]; // Sanity check that we got the correct conditional command buffers. 
- if (cond_cmd_buffers.handles.size() != 1 || - cond_cmd_buffers.command_buffers.size() != 1) { - return absl::InternalError( - "`If` command expected one conditional command buffer"); - } - - GpuGraphConditionalHandle& handle = cond_cmd_buffers.handles[0]; - CommandBuffer& then_cmd_buffer = cond_cmd_buffers.command_buffers[0]; + TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, 1)); // Update a kernel that updates conditional handle based on a predicate. - TF_RETURN_IF_ERROR( - Launch(set_condition, ThreadDim(), BlockDim(), handle, predicate)); + TF_RETURN_IF_ERROR(Launch(set_if_condition, ThreadDim(), BlockDim(), + cond_cmd_buffers.handles[0], predicate)); - // Conditional handle created only when we add conditional node first time - // and then owned by a `graph_`. We also don't need to update conditional - // node itself, as it reuses the same handle. - update_state_.node_idx++; // skip conditional node + // Skip updating conditional nodes. + update_state_.node_idx += cond_cmd_buffers.handles.size(); - // Use parent graph executable for conditional command buffer update. - ScopedGpuGraphExec scoped_exec(Cast(&then_cmd_buffer), exec_); + return UpdateConditionalCommandBuffers( + absl::MakeSpan(cond_cmd_buffers.command_buffers), builders); + } + + return UnsupportedStateError(state_); +} - // Update `then` command buffer using user-provided builder callback. - TF_RETURN_IF_ERROR(then_cmd_buffer.Update()); - TF_RETURN_IF_ERROR(then_builder(&then_cmd_buffer)); - TF_RETURN_IF_ERROR(then_cmd_buffer.Finalize()); +tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, + DeviceMemory predicate, + CommandBuffer::Builder then_builder, + CommandBuffer::Builder else_builder) { + DCHECK(executor->implementation() == parent_); // NOLINT + + // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on + // every call to `If`. + SetIfElseConditionKernel set_if_else_condition(executor); + + { // Load kernels that updates condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(gpu::GetSetIfElseConditionKernel(), + "set_if_else_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_else_condition)); + } + + std::array builders = {std::move(then_builder), + std::move(else_builder)}; + + if (state_ == State::kCreate) { + TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(2)); + + // Add a kernel to update conditional handle value based on a predicate. + TF_RETURN_IF_ERROR(Launch(set_if_else_condition, ThreadDim(), BlockDim(), + handles[0], handles[1], predicate)); + + // Create conditional command buffers for then/else branches. + TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(handles)); + TF_ASSIGN_OR_RETURN( + conditional_command_buffers_.emplace_back(), + CreateConditionalCommandBuffers(handles, graphs, builders)); return tsl::OkStatus(); } + if (state_ == State::kUpdate) { + ConditionalCommandBuffers& cond_cmd_buffers = + conditional_command_buffers_[update_state_.conditional_idx++]; + + // Sanity check that we got the correct conditional command buffers. + TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, 2)); + + // Update a kernel that updates conditional handles based on a predicate. + TF_RETURN_IF_ERROR(Launch(set_if_else_condition, ThreadDim(), BlockDim(), + cond_cmd_buffers.handles[0], + cond_cmd_buffers.handles[0], predicate)); + + // Skip updating conditional nodes. 
+ update_state_.node_idx += cond_cmd_buffers.handles.size(); + + return UpdateConditionalCommandBuffers( + absl::MakeSpan(cond_cmd_buffers.command_buffers), builders); + } + return UnsupportedStateError(state_); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 4a2b46dda07e0c..ed048d24f6675a 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -16,12 +16,14 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_COMMAND_BUFFER_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_COMMAND_BUFFER_H_ +#include #include #include #include #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_executor.h" @@ -31,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace stream_executor::gpu { @@ -57,6 +60,10 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) override; + tsl::Status IfElse(StreamExecutor* executor, DeviceMemory predicate, + CommandBuffer::Builder then_builder, + CommandBuffer::Builder else_builder) override; + tsl::Status Finalize() override; tsl::Status Update() override; @@ -97,9 +104,12 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { private: using Dependencies = absl::InlinedVector; - // A signature of a device kernels updating conditional handle. - using SetConditionKernel = + // A signature of a device kernels updating conditional handle(s). + using SetIfConditionKernel = TypedKernel>; + using SetIfElseConditionKernel = + TypedKernel>; // Overwrites the `exec_` handle in a Gpu command buffer by `exec`, and // restores to the original handle when destroyed. This allows us updating @@ -124,9 +134,24 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { std::vector command_buffers; }; + tsl::StatusOr> + CreateConditionalHandles(size_t num_handles); + + tsl::StatusOr> CreateConditionalNodes( + absl::Span handles); + + tsl::StatusOr CreateConditionalCommandBuffers( + absl::Span handles, + absl::Span graphs, + absl::Span builders); + + tsl::Status UpdateConditionalCommandBuffers( + absl::Span command_buffers, + absl::Span builders); + // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a - // dependency between all nodes added to a command buffer. We need a concept - // of a barrier at a command buffer level. + // dependency between all nodes added to a command buffer. We need a + // concept of a barrier at a command buffer level. Dependencies GetDependencies(); // Returns OK status if command buffer is not finalized and it is still @@ -137,6 +162,11 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // error. tsl::Status CheckPrimary(); + // Returns OK status if the number of command buffers is equal to the expected + // one, otherwise returns internal error. 
+ tsl::Status CheckNumCommandBuffers( + const ConditionalCommandBuffers& cmd_buffers, size_t num_cmd_buffers); + static_assert(std::is_pointer_v, "GpuGraphHandle must be a pointer"); static_assert(std::is_pointer_v, @@ -197,7 +227,8 @@ inline tsl::Status GpuCommandBuffer::Launch( // values, and allow implementing on-device control flow via conditional command // buffers. -void* GetSetConditionKernel(); +void* GetSetIfConditionKernel(); +void* GetSetIfElseConditionKernel(); } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 58592442908112..888d00f218dc96 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -462,6 +462,12 @@ class GpuDriver { GpuGraphNodeHandle* node, GpuGraphHandle graph, absl::Span deps, const GpuGraphNodeParams& params); + // Creates an empty node and adds it to a graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g14b625984430cb2d574c63f29c9b9223 + static tsl::Status GraphAddEmptyNode(GpuGraphNodeHandle* node, + GpuGraphHandle graph, + absl::Span deps); + // Creates a kernel execution node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management diff --git a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc index c22a051f2bb87c..a31990bf61b989 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc @@ -25,7 +25,10 @@ __global__ void SetCondition() {} } // namespace rocm namespace gpu { -void* GetSetConditionKernel() { +void* GetSetIfConditionKernel() { + return reinterpret_cast(&rocm::SetCondition); +} +void* GetSetIfElseConditionKernel() { return reinterpret_cast(&rocm::SetCondition); } } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index b2fb31023ecb7f..522186c4bad7e9 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -150,6 +150,14 @@ class CommandBufferInterface { virtual tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) = 0; + // Adds a conditional operation that will run a command buffer constructed by + // `then_builder` if `predicate` value is `true`, or a command buffer + // constructed by `else_builder` if `predicate` is `false`. + virtual tsl::Status IfElse(StreamExecutor* executor, + DeviceMemory predicate, + CommandBuffer::Builder then_builder, + CommandBuffer::Builder else_builder) = 0; + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. virtual tsl::Status Finalize() = 0; From 837382acc92e8c51316aef6612c9d1501e825e5c Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 23 Nov 2023 17:49:48 -0800 Subject: [PATCH 055/381] Allow unbounded dynamism representation for AddOp. 
This adds the necessary changes in XlaBuilder API, verifier, and shape inference following StableHLO rules for unbounded dynamism. Implicit broadcasting support in XlaBuilder API will be addressed in a follow up CL. PiperOrigin-RevId: 584967526 --- third_party/xla/xla/client/BUILD | 3 + third_party/xla/xla/client/xla_builder.cc | 17 +-- .../xla/xla/client/xla_builder_test.cc | 47 ++++++++ third_party/xla/xla/service/BUILD | 3 + .../xla/xla/service/shape_inference.cc | 114 ++++++++++++------ .../xla/xla/service/shape_inference_test.cc | 50 ++++++++ third_party/xla/xla/shape.h | 11 ++ third_party/xla/xla/shape_util.cc | 5 + third_party/xla/xla/shape_util.h | 5 +- third_party/xla/xla/shape_util_test.cc | 10 ++ .../xla/xla/translate/hlo_to_mhlo/hlo_utils.h | 2 +- 11 files changed, 220 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index ca9d8a952a26b9..7f95b2ade5060c 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -301,10 +301,13 @@ xla_cc_test( ":xla_computation", "//xla:debug_options_flags", "//xla:shape_util", + "//xla:statusor", + "//xla:test", "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc index 5b70a8e2e91466..ccb353f052989b 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/client/xla_builder.cc @@ -1002,15 +1002,18 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape, GetShapePtr(updated_lhs)); - if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) { - TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(shape, updated_lhs)); - } TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape, GetShapePtr(updated_rhs)); - if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) { - TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(shape, updated_rhs)); + if (!updated_lhs_shape->is_unbounded_dynamic() && + !updated_rhs_shape->is_unbounded_dynamic()) { + if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(shape, updated_lhs)); + } + if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(shape, updated_rhs)); + } } if (binop == HloOpcode::kCompare) { diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index 018ddb30f782fd..2fc57142837115 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -31,9 +31,13 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/test.h" #include "xla/test_helpers.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -197,6 +201,16 @@ TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { GmockMatch(m::Add(m::Parameter(), m::Broadcast(m::Constant())))); } +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcastReversed) { + XlaBuilder b(TestName()); + XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + Add(ConstantR0(&b, 1.0), x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + GmockMatch(m::Add(m::Broadcast(m::Constant()), m::Parameter()))); +} + TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); @@ -1543,5 +1557,38 @@ TEST_F(XlaBuilderTest, TopKDimensions) { EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(0), 6); EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(1), k); } + +TEST_F(XlaBuilderTest, UnboundedAdd) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Add(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedAddUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Add(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 6f3cb3499b63e2..0f615afd3fad88 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -613,6 +613,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -627,8 +628,10 @@ xla_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ + ":hlo_parser", ":shape_inference", "//xla:shape_util", + "//xla:statusor", "//xla:test", "//xla:test_helpers", "//xla:types", diff --git a/third_party/xla/xla/service/shape_inference.cc 
b/third_party/xla/xla/service/shape_inference.cc index dcb5d7eded4741..bf324e880c3963 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -55,6 +56,7 @@ limitations under the License. namespace xla { namespace { +using absl::InvalidArgumentError; using absl::StrFormat; using absl::StrJoin; @@ -813,12 +815,47 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::vector output_dimensions(lhs.rank()); std::vector output_dimensions_is_dynamic(lhs.rank()); for (int64_t i = 0; i < lhs.rank(); ++i) { - if (lhs.dimensions(i) == rhs.dimensions(i)) { - output_dimensions[i] = lhs.dimensions(i); - } else if (lhs.dimensions(i) == 1) { - output_dimensions[i] = rhs.dimensions(i); - } else if (rhs.dimensions(i) == 1) { + if (lhs.dimensions(i) == 1 || rhs.dimensions(i) == 1) { + // For the unbounded case, the operand with 1 should be broadcasted to the + // unbounded size which can be > 1. + // LHS | RHS | Result + // 1 | X | X + // 1 | <=X | <=X + // 1 | ? | ? + // X | 1 | X + // <=X | 1 | <=X + // ? | 1 | ? + output_dimensions[i] = + lhs.dimensions(i) == 1 ? rhs.dimensions(i) : lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.dimensions(i) == 1 + ? rhs.is_dynamic_dimension(i) + : lhs.is_dynamic_dimension(i); + } else if (lhs.dimensions(i) == rhs.dimensions(i)) { + // LHS | RHS | Result + // X | X | X + // X | <=X | <=X + // <=X | X | <=X + // <=X | <=X | <=X + // ? | ? | ? output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = + lhs.is_dynamic_dimension(i) || rhs.is_dynamic_dimension(i); + } else if (lhs.is_unbounded_dynamic_dimension(i) || + rhs.is_unbounded_dynamic_dimension(i)) { + // For the last two rows, consider when <=X turns out to be 1 and ? turns + // out to be 5. It would be wrong to infer <=1 as this is a degenerate + // dimension that should be broadcasted to 5. + // LHS | RHS | Result + // X | ? | X + // ? | X | X + // <=X | ? | ? + // ? | <=X | ? + output_dimensions[i] = lhs.is_unbounded_dynamic_dimension(i) + ? rhs.dimensions(i) + : lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_unbounded_dynamic_dimension(i) + ? rhs.is_dynamic_dimension(i) + : lhs.is_dynamic_dimension(i); } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", @@ -827,13 +864,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - // Merge dynamic dimensions from two shapes. 
- for (int64_t i = 0; i < rhs.rank(); ++i) { - if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) { - output_dimensions_is_dynamic[i] = true; - } - } - return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), output_dimensions, output_dimensions_is_dynamic); } @@ -841,20 +871,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, absl::Span broadcast_dimensions) { + if (smaller_shape.is_unbounded_dynamic() || + larger_shape.is_unbounded_dynamic()) { + return InvalidArgumentError(StrFormat( + "Unbounded dynamic shapes not supported, but we have %s and %s", + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape))); + } + if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. - return InvalidArgument("Shapes must be equal rank, but are %s and %s", - ShapeUtil::HumanString(smaller_shape), - ShapeUtil::HumanString(larger_shape)); + return InvalidArgumentError( + StrFormat("Shapes must be equal rank, but are %s and %s", + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape))); } else if (broadcast_dimensions.size() != smaller_shape.rank()) { - return InvalidArgument( + return InvalidArgumentError(StrFormat( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %d, size of broadcast_dimensions is " "%u.", - smaller_shape.rank(), broadcast_dimensions.size()); + smaller_shape.rank(), broadcast_dimensions.size())); } // broadcast_dimensions is a sequence of dimensions; its length is equal to @@ -902,15 +941,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int i = 0; i < smaller_shape.dimensions_size(); ++i) { int64_t dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { - return InvalidArgument( - "Broadcast dimension number (%d) cannot be negative.", - dimension_to_match); + return InvalidArgumentError( + StrFormat("Broadcast dimension number (%d) cannot be negative.", + dimension_to_match)); } if (dimension_to_match >= larger_shape.dimensions_size()) { - return InvalidArgument( - "Broadcast dimension number (%d) too large; higher-rank " - "operand has rank %d.", - dimension_to_match, larger_shape.dimensions_size()); + return InvalidArgumentError( + StrFormat("Broadcast dimension number (%d) too large; higher-rank " + "operand has rank %d.", + dimension_to_match, larger_shape.dimensions_size())); } int64_t small_dimension_size = smaller_shape.dimensions(i); int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match); @@ -922,11 +961,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // InDim broadcasting). 
if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { - return InvalidArgument( - "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, - small_dimension_size, large_dimension_size, - ShapeUtil::HumanString(smaller_shape), - ShapeUtil::HumanString(larger_shape)); + return InvalidArgumentError( + StrFormat("Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, + small_dimension_size, large_dimension_size, + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape))); } if (small_is_dynamic != large_is_dynamic) { if (small_dimension_size == large_dimension_size || @@ -934,18 +973,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, (large_dimension_size == 1 && !large_is_dynamic)) { // Do nothing. It's OK when the size-1 dimension is not static. } else { - return InvalidArgument( - "Broadcast dimension %d dynamism mismatch: %s and %s.", i, - ShapeUtil::HumanString(smaller_shape), - ShapeUtil::HumanString(larger_shape)); + return InvalidArgumentError( + StrFormat("Broadcast dimension %d dynamism mismatch: %s and %s.", i, + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape))); } } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { - return InvalidArgument( - "Broadcast dimensions order is wrong: %d comes after %d.", - dimension_to_match, broadcast_dimensions.at(i - 1)); + return InvalidArgumentError( + StrFormat("Broadcast dimensions order is wrong: %d comes after %d.", + dimension_to_match, broadcast_dimensions.at(i - 1))); } output_shape.set_dimensions(dimension_to_match, small_dimension_size); @@ -979,7 +1018,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs) && + !lhs.is_unbounded_dynamic() && !rhs.is_unbounded_dynamic()) { // If the shapes are the same other than layout, the output shape is the // same (elementwise op). Shape result = ShapeUtil::ChangeElementType( diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 6a58bd382477b6..229a1a02736747 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/strings/string_view.h" @@ -24,7 +25,11 @@ limitations under the License. 
#include "absl/types/span.h" #include "xla/client/padding.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/types.h" @@ -105,6 +110,10 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { ProgramShape scatter_program_shape_; }; +// Subclass for testing unbounded dynamic binary ops +class UnboundedBinaryOpShapeInferenceTest + : public ::testing::TestWithParam> {}; + TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status = @@ -3733,5 +3742,46 @@ INSTANTIATE_TEST_SUITE_P(All, ScatterShapeInferenceTest, BF16}), ScatterTestName()); +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { + StatusOr lhs = ParseShape(GetParam()[0]); + StatusOr rhs = ParseShape(GetParam()[1]); + StatusOr expected = ParseShape(GetParam()[2]); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + StatusOr inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kAdd, lhs.value(), rhs.value(), + /*broadcast_dimensions=*/{}); + if (inferred_status.ok()) { + ASSERT_IS_OK(expected.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); + } else { + EXPECT_THAT(inferred_status.status().message(), + HasSubstr("Binary op add with incompatible shapes")); + } +} + +INSTANTIATE_TEST_SUITE_P( + UnboundedDynamism, UnboundedBinaryOpShapeInferenceTest, + ::testing::Values( + // LHS | RHS | Result + // 1 | ? | ? + std::vector({"f32[1]", "f32[?]", "f32[?]"}), + // ? | 1 | ? + std::vector({"f32[?]", "f32[1]", "f32[?]"}), + // 2 | ? | 2 + std::vector({"f32[2]", "f32[?]", "f32[2]"}), + // ? | 2 | 2 + std::vector({"f32[?]", "f32[2]", "f32[2]"}), + // <=2 | ? | <=2 + std::vector({"f32[<=2]", "f32[?]", "f32[<=2]"}), + // ? | <=2 | <=2 + std::vector({"f32[?]", "f32[<=2]", "f32[<=2]"}), + // ? | ? | ? + std::vector({"f32[?]", "f32[?]", "f32[?]"}), + // ?,2 | ?,3 | error + std::vector({"f32[?,2]", "f32[?,3]", ""}))); + } // namespace } // namespace xla diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 9386fc18043ac7..b2828b429e586c 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -102,6 +102,17 @@ class Shape { // Tuple shapes are traversed recursively. bool is_unbounded_dynamic() const; + // Returns true if the given dimension is unbounded dynamic. + bool is_unbounded_dynamic_dimension(int dimension) const { + return dimensions_.at(dimension) == kUnboundedSize; + } + + // Sets a given dimension as unbounded dynamic. + void set_unbounded_dynamic_dimension(int dimension) { + dynamic_dimensions_[dimension] = true; + dimensions_.at(dimension) = kUnboundedSize; + } + // Returns true if the given dimension is dynamically-sized. 
bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_.at(dimension); diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 004fd94ea5fcdc..628b8011dee630 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -331,6 +331,11 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { } for (int i = 0, n = dimensions.size(); i < n; i++) { shape.set_dynamic_dimension(i, dynamic_dimensions[i]); + if (shape.dimensions(i) == Shape::kUnboundedSize && + !dynamic_dimensions[i]) { + return InvalidArgument( + "Cannot mark a dynamic dimension at dim=%d as static", i); + } } return shape; } diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 760fc42894894e..897aa29d525fd3 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -372,8 +372,9 @@ class ShapeUtil { const std::vector& dynamic_dimensions); // Constructs a new shape with the given element type and sequence of - // dimensions. Method checks if the element type is valid and the shape's - // size fits in std::numeric_limits::max(). + // dimensions. Method checks if the element type is valid, the shape's + // size fits in std::numeric_limits::max(), and dynamic size is not + // marked static. static StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions); static StatusOr MakeValidatedShape( diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index b63be84136e5dc..f18eb0bbcc0945 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -966,6 +966,16 @@ TEST(ShapeUtilTest, UpdateDynamicDimensions) { EXPECT_TRUE(ShapeUtil::GetSubshape(tuple_shape, {0}).is_dynamic_dimension(1)); } +TEST(ShapeUtilTest, InvalidDynamicDimension) { + StatusOr error_status = ShapeUtil::MakeValidatedShape( + F32, {Shape::kUnboundedSize, Shape::kUnboundedSize}, {true, false}); + + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT(error_status.status().message(), + ::testing::HasSubstr( + "Cannot mark a dynamic dimension at dim=1 as static")); +} + TEST(ShapeUtilTest, PermuteDynamicDimensions) { Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}, diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index 8dd290e0c2d307..098a682c82045c 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -63,7 +63,7 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, for (int64_t dim = 0; dim < rank; ++dim) { int64_t dim_size = xla_ty.dimensions(dim); if (xla_ty.is_dynamic_dimension(dim)) { - if (dim_size != Shape::kUnboundedSize) { + if (!xla_ty.is_unbounded_dynamic_dimension(dim)) { bounds[dim] = dim_size; is_bounded_dynamic = true; } From fbcea480ab1e10b44a71e2d0ed214b2a139484bd Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Thu, 23 Nov 2023 19:47:29 -0800 Subject: [PATCH 056/381] Match uniform_quantized_types functions with UniformQuantizedType dtype. 
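The motivation for the signature change is that mlir::quant::UniformQuantizedType and UniformQuantizedPerAxisType store scales as double and zero points as int64_t, so the previous float / int8_t / int32_t parameters forced narrowing conversions at every call site. The snippet below is an illustrative sketch of a call site after the change, not code from this patch: the wrapper function name, the loc/ctx arguments, the example scale and zero-point values, and the namespace qualification are all assumptions made for the example.

  #include <cstdint>

  #include "llvm/ADT/SmallVector.h"
  #include "mlir/IR/Location.h"
  #include "mlir/IR/MLIRContext.h"
  #include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"

  namespace mlir::quant {

  // Builds an i8-storage / f32-expressed per-axis quantized type. Scales are
  // passed as double and zero points as int64_t, matching what
  // UniformQuantizedPerAxisType itself stores.
  UniformQuantizedPerAxisType MakeExamplePerAxisType(Location loc,
                                                     MLIRContext& ctx) {
    llvm::SmallVector<double, 2> scales = {0.5, 0.25};
    llvm::SmallVector<int64_t, 2> zero_points = {0, 0};
    return CreateI8F32UniformQuantizedPerAxisType(
        loc, ctx, scales, zero_points, /*quantization_dimension=*/0);
  }

  }  // namespace mlir::quant

With the widened parameter types, per-channel scales computed in double (as in the composed convolution/dot rewrites below) can be forwarded without an intermediate float copy.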
PiperOrigin-RevId: 584980524 --- .../compose_uniform_quantized_type_pass.cc | 17 +++++++------ .../stablehlo/uniform_quantized_types.cc | 12 +++++----- .../stablehlo/uniform_quantized_types.h | 12 +++++----- .../stablehlo/uniform_quantized_types_test.cc | 24 +++++++++---------- 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index a5286025463a52..587c971cdffaef 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -731,17 +731,18 @@ class ComposeUniformQuantizedConvolutionOp auto combined_scale_constant_op = cast( scale_combined_broadcast_in_dim_op.getOperand().getDefiningOp()); - SmallVector filter_scale_values; + SmallVector filter_scale_values; for (const auto combined_scale_value : combined_scale_constant_op.getValue() .cast() .getValues()) { - const float filter_scale_value = - combined_scale_value * input_inverse_scales_value; + // UniformQuantizedPerAxisType requires scales to have double dtype. + const double filter_scale_value = static_cast( + combined_scale_value * input_inverse_scales_value); filter_scale_values.emplace_back(filter_scale_value); } // Assumes it is symmetric. - SmallVector filter_zero_point_values( + SmallVector filter_zero_point_values( /*Size=*/filter_scale_values.size(), /*Value=*/0); // Use quantization dimension = 3 that corresponds to the output channel @@ -1083,15 +1084,17 @@ class ComposeUniformQuantizedDotGeneralOp // s1 * s2 auto merged_scale_constant_op = cast(multiply_op_second_operand.getDefiningOp()); - SmallVector filter_scale_values; + SmallVector filter_scale_values; for (const auto merged_scale : merged_scale_constant_op.getValue() .cast() .getValues()) { // (s1 * s2) * (1 / s1) = s2 - filter_scale_values.push_back(merged_scale * input_inverse_scale_value); + // UniformQuantizedPerAxisType requires scales to have double dtype. 
+ filter_scale_values.push_back( + static_cast(merged_scale * input_inverse_scale_value)); } - SmallVector filter_zero_point_values( + SmallVector filter_zero_point_values( /*Size=*/filter_scale_values.size(), /*Value=*/0); const int quantization_dimension = GetFilterQuantizationDimension( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc index f8064220786442..eecc96b04be9eb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc @@ -32,8 +32,8 @@ namespace quant { UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, MLIRContext& context, - const float scale, - const int8_t zero_point) { + const double scale, + const int64_t zero_point) { return UniformQuantizedType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/8), @@ -42,8 +42,8 @@ UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, } UniformQuantizedType CreateI32F32UniformQuantizedType( - const Location loc, MLIRContext& context, const float scale, - const int32_t zero_point) { + const Location loc, MLIRContext& context, const double scale, + const int64_t zero_point) { return UniformQuantizedType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/32), @@ -53,8 +53,8 @@ UniformQuantizedType CreateI32F32UniformQuantizedType( } UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( - const Location loc, MLIRContext& context, const ArrayRef scales, - const ArrayRef zero_points, const int quantization_dimension) { + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension) { return UniformQuantizedPerAxisType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/8), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h index c422439e8472dc..d04dc5a5761b8f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h @@ -34,8 +34,8 @@ namespace quant { // values can be non-zero values. UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, MLIRContext& context, - float scale, - int8_t zero_point); + double scale, + int64_t zero_point); // Creates a `UniformQuantizedType` with the given `scale` and `zero_point` // values. The produced type has f32 as its expressed type and i32 as its @@ -44,8 +44,8 @@ UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, // non-zero values. UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, MLIRContext& context, - float scale, - int32_t zero_point); + double scale, + int64_t zero_point); // Creates a `UniformQuantizedPerAxisType` with the given `scales` and // `zero_points` values. The produced type has f32 as its expressed type and @@ -53,8 +53,8 @@ UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, // storage value, i.e. [-128, 127]. Assumes asymmetric quantization, meaning the // zero point values can be non-zero values. 
UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( - Location loc, MLIRContext& context, ArrayRef scales, - ArrayRef zero_points, int quantization_dimension); + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension); bool IsStorageTypeI8(QuantizedType quantized_type); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc index 43b78f505564fb..ab1ca261a4075f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc @@ -151,8 +151,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasI8StorageType) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); @@ -162,8 +162,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasF32ExpressedType) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_TRUE(quantized_type.getExpressedType().isF32()); @@ -173,8 +173,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, IsSigned) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_TRUE(quantized_type.isSigned()); @@ -185,8 +185,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_EQ(quantized_type.getStorageTypeMin(), -128); @@ -198,8 +198,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/3); EXPECT_EQ(quantized_type.getQuantizedDimension(), 3); @@ -210,8 +210,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{8.0, 9.0}, - /*zero_points=*/SmallVector{98, 99}, + /*scales=*/SmallVector{8.0, 9.0}, + /*zero_points=*/SmallVector{98, 99}, /*quantization_dimension=*/0); EXPECT_THAT(quantized_type.getScales(), ElementsAreArray({8.0, 9.0})); From a2ee21765553ae782439d833b7130dcba5897b6e Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 23 Nov 2023 23:06:22 -0800 Subject: [PATCH 057/381] Adding fingerprint to module's profile information PiperOrigin-RevId: 585006293 --- third_party/xla/xla/hlo/ir/hlo_module.cc | 1 + third_party/xla/xla/service/hlo.proto | 2 ++ 2 files changed, 3 insertions(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index d505039c06c65c..149fec7c3f1d45 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -480,6 +480,7 @@ HloModuleProto HloModule::ToProto() const { profile_info_proto.set_relative_speedup(profile_info.relative_speedup()); profile_info_proto.set_profile_source(profile_info.profile_source()); profile_info_proto.set_compilation_event(profile_info.compilation_event()); + profile_info_proto.set_fingerprint(profile_info.fingerprint()); } if (config_.get().has_static_device_assignment()) { DeviceAssignmentProto device_assignment; diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index 54a74e742756bf..ea63665e83e85b 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -588,6 +588,8 @@ message HloModuleProto { xla.ProfileSource profile_source = 3; // The compilation event that triggered the use of the profile. xla.CompilationEvent compilation_event = 4; + // The fingerprint of the unoptimized module this profile was applied to. + string fingerprint = 5; } // Profile information for the HLO module. From ac5fb9369fe1d85218722a6d985db91ec73b18ac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 24 Nov 2023 01:04:38 -0800 Subject: [PATCH 058/381] Update GraphDef version to 1690. PiperOrigin-RevId: 585023817 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 8515e9ca75f65a..7f560250ce2e94 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1689 // Updated: 2023/11/23 +#define TF_GRAPH_DEF_VERSION 1690 // Updated: 2023/11/24 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From f76a6f45183ae2dfea7815ac44585331d243fa89 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 24 Nov 2023 01:04:47 -0800 Subject: [PATCH 059/381] compat: Update forward compatibility horizon to 2023-11-24 PiperOrigin-RevId: 585023854 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index ff4c66e7386d61..85ae86d88d993d 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 23) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 24) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 2f91fbbf66f26dff40611f3be6bc076888387b8a Mon Sep 17 00:00:00 2001 From: Krasimir Georgiev Date: Fri, 24 Nov 2023 05:37:30 -0800 Subject: [PATCH 060/381] Integrate LLVM at llvm/llvm-project@af7a1453526a Updates LLVM usage to match [af7a1453526a](https://github.com/llvm/llvm-project/commit/af7a1453526a) PiperOrigin-RevId: 585072558 --- third_party/llvm/workspace.bzl | 4 ++-- .../translate/mhlo_to_hlo/type_to_shape.cc | 20 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 6634a12658711c..d7f3a809369277 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "2b0b0ad7e760e30b66cebfc994d0eb64de8846ad" - LLVM_SHA256 = "6257e7d8f95a42c248304a1f7ecf7a00ba9504cf6d8e901750781e22589580fe" + LLVM_COMMIT = "af7a1453526a88a0e242baf156244aa4ae42ae4b" + LLVM_SHA256 = "f9f75e4823c2f09a8141ab4db40ee2c79aef96017782a9338e26621ee547d3d5" tf_http_archive( name = name, diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc index 27fbcb2ad60e85..d5709b91b80378 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -82,12 +82,12 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { } std::optional> ConvertDimLevelType( - mlir::sparse_tensor::DimLevelType dlt) { - auto f = mlir::sparse_tensor::getLevelFormat(dlt); + mlir::sparse_tensor::DimLevelType lt) { + auto f = mlir::sparse_tensor::getLevelFormat(lt); if (!f) return std::nullopt; - bool unique = mlir::sparse_tensor::isUniqueDLT(dlt); - bool ordered = mlir::sparse_tensor::isOrderedDLT(dlt); + bool unique = mlir::sparse_tensor::isUniqueLT(lt); + bool ordered = mlir::sparse_tensor::isOrderedLT(lt); switch (*f) { case mlir::sparse_tensor::LevelFormat::Singleton: return std::make_tuple(DimLevelType::DIM_SINGLETON, unique, ordered); @@ -206,12 +206,12 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector lvl_types; llvm::SmallVector level_unique; llvm::SmallVector level_ordered; - for (auto dlt : sparse.getLvlTypes()) { - auto new_dlt = ConvertDimLevelType(dlt); - if (!new_dlt) return {}; - lvl_types.push_back(std::get<0>(*new_dlt)); - level_unique.push_back(std::get<1>(*new_dlt)); - level_ordered.push_back(std::get<2>(*new_dlt)); + for (auto lt : sparse.getLvlTypes()) { + auto new_lt = ConvertDimLevelType(lt); + if (!new_lt) return {}; + lvl_types.push_back(std::get<0>(*new_lt)); + level_unique.push_back(std::get<1>(*new_lt)); + level_ordered.push_back(std::get<2>(*new_lt)); } std::vector ordering(rank); From 07bcefc289a57e33acc0bdf2196a4f39464bce02 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 24 Nov 2023 06:08:01 -0800 Subject: [PATCH 061/381] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/4347953799d066962cb1897814de77c8e195499d. 
PiperOrigin-RevId: 585077428 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 7e03e86883ed38..ada8843278851c 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "16b6347e3e522ba322a2bcd90f1311018f20b47d" - TFRT_SHA256 = "4a8bf225e64641999b1290b7554270be0066e63886d713d77dcbe9317923ef9f" + TFRT_COMMIT = "4347953799d066962cb1897814de77c8e195499d" + TFRT_SHA256 = "26af1f500eab6aa22f47e05a36253faeee8786208d18e4f0ee385f9ac04f21bf" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 7e03e86883ed38..ada8843278851c 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "16b6347e3e522ba322a2bcd90f1311018f20b47d" - TFRT_SHA256 = "4a8bf225e64641999b1290b7554270be0066e63886d713d77dcbe9317923ef9f" + TFRT_COMMIT = "4347953799d066962cb1897814de77c8e195499d" + TFRT_SHA256 = "26af1f500eab6aa22f47e05a36253faeee8786208d18e4f0ee385f9ac04f21bf" tf_http_archive( name = "tf_runtime", From 790adb93923201082fea1c91cde9218276d5e12c Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 24 Nov 2023 06:41:08 -0800 Subject: [PATCH 062/381] [TileAnalysis] Add indexing computation for fusionOp. PiperOrigin-RevId: 585083299 --- third_party/xla/xla/service/gpu/model/BUILD | 2 + .../xla/service/gpu/model/tile_analysis.cc | 174 +++++++++- .../xla/xla/service/gpu/model/tile_analysis.h | 45 +-- .../service/gpu/model/tile_analysis_test.cc | 296 +++++++++++++++--- 4 files changed, 442 insertions(+), 75 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index c0b2bf8ec14937..1a3189df9eb272 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -221,10 +221,12 @@ cc_library( "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 41a9c73951bcb5..7005c0c89081a6 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -17,14 +17,17 @@ limitations under the License. #include #include +#include #include #include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project @@ -35,6 +38,7 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -51,10 +55,9 @@ using mlir::MLIRContext; StatusOr ComputeCwiseOpIndexing( const HloInstruction* instr, MLIRContext* mlir_context) { auto dims = instr->shape().dimensions(); - IndexingMap identity_map{ - .affine_map = - AffineMap::getMultiDimIdentityMap(dims.size(), mlir_context), - .sizes = std::vector{dims.begin(), dims.end()}}; + IndexingMap identity_map{.affine_map = AffineMap::getMultiDimIdentityMap( + dims.size(), mlir_context), + .input_dims_sizes = {}}; std::vector operand_indexing_maps; int64_t operand_count = instr->operand_count(); @@ -73,16 +76,150 @@ StatusOr ComputeBroadcastOpIndexing( for (int64_t bcast_dim : bcast->dimensions()) { exprs.push_back(getAffineDimExpr(bcast_dim, mlir_context)); } - IndexingMap indexing_map{ .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), - .sizes = std::vector{output_dims.begin(), output_dims.end()}}; + .input_dims_sizes = {}}; return HloInstructionIndexing{{HloOperandIndexing{ .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; } +// Composes affine maps, i.e. consumer_map ∘ producer_map. +IndexingMap ComposeIndexingMaps(const IndexingMap& producer_map, + const IndexingMap& consumer_map) { + // AffineMap::compose(some_affine_map) actually computes some_affine_map ∘ + // this. + AffineMap composed_map = mlir::simplifyAffineMap( + producer_map.affine_map.compose(consumer_map.affine_map)); + + // After the composition some of the symbols might become unused, e.g. when a + // dimension was added by broadcasting as then reduced. We should remove these + // dimensions from the composed affine map and also from the resulting + // `input_dim_sizes`. + // + // For example, if there is a reduction(broadcast): + // + // param = f32[15] parameter(0) + // bcast = f32[15, 20] broadcast(p0), dimensions={0} + // reduce = f32[15, 20] reduce(bcast, init) dimensions={1} + // + // then `reduce` has (d0)[s0] -> (d0, s0) with size(s0) = 20 + // and `bcast` has (d0, d1) -> (d0) indexing map. + // + // The composition of there two maps yields (d0)[s0] -> (d0) with size(s0), + // although `s0` is not used in the mapping. In order to remove such symbols, + // we get the indices of unused symbols and remove them from the composed + // affine map and the `input_dim_sizes`. + auto unused_symbols_bit_vector = + mlir::getUnusedSymbolsBitVector({composed_map}); + composed_map = mlir::compressSymbols(composed_map, unused_symbols_bit_vector); + + // The input dims symbols in the composed map, i.e. combined + // producer_map.compose(consumer_map) are packed as [symbols(producer_map) | + // symbols(consumer_map)]. In that order we are adding the sizes for the input + // dims while skipping the symbols that are unused. + std::vector combined_sizes; + int64_t symbol_id = 0; + for (int64_t dim : llvm::concat( + producer_map.input_dims_sizes, consumer_map.input_dims_sizes)) { + if (unused_symbols_bit_vector[symbol_id++]) continue; + combined_sizes.push_back(dim); + } + return IndexingMap{.affine_map = std::move(composed_map), + .input_dims_sizes = std::move(combined_sizes)}; +} + +// Computes HloInstructionIndexing that maps the iteration space of the +// consumer's output tensor to the iteration space of the producer's inputs and +// the remaining outputs of the consumer as if the producer was fused. 
+// +// Example: +// +// operand1 operand2 +// | | # producer_instr_indexing edges +// producer_instr +// | # consumer_operand_indexing edge +// consumer +// +// The function has two inputs: +// +// 1. `producer_instr_indexing` is the producer's HloInstructionIndexing +// that maps the iteration space of its output tensor to the inputs of +// producers. +// 2. `consumer_operand_indexing` is the consumer's HloOperandIndexing for the +// operand that corresponds to the provided producer. +HloInstructionIndexing ComputeFusedProducerConsumerIndexing( + const HloInstructionIndexing& producer_instr_indexing, + const HloOperandIndexing& consumer_operand_indexing) { + HloInstructionIndexing fused_instr_indexing; + + // Every operand can be read 1 or more times by the consumer which also can + // have 1 or more read accesses to its operands. So, to get the composed + // indexing maps we have to compute a "cross product" here. + for (const HloOperandIndexing& producer_operand_indexing : + producer_instr_indexing.operand_indexing_maps) { + auto& composed_operand_indexing = + fused_instr_indexing.operand_indexing_maps.emplace_back(); + composed_operand_indexing.operand_id = producer_operand_indexing.operand_id; + for (const IndexingMap& producer_map : + producer_operand_indexing.indexing_maps) { + for (const IndexingMap& consumer_map : + consumer_operand_indexing.indexing_maps) { + composed_operand_indexing.indexing_maps.insert( + ComposeIndexingMaps(producer_map, consumer_map)); + } + } + fused_instr_indexing.operand_indexing_maps.push_back( + std::move(composed_operand_indexing)); + } + return fused_instr_indexing; +} + +// Composes instruction indexing maps starting at the root instruction +// until the HloParameterInstruction is found. +StatusOr ComputeFusionOpIndexing( + const HloFusionInstruction* fusion, int output_id, + MLIRContext* mlir_context) { + const HloInstruction* root = + fusion->shape().IsTuple() + ? fusion->fused_expression_root()->operand(output_id) + : fusion->fused_expression_root(); + std::queue> bfs; + TF_ASSIGN_OR_RETURN(auto root_indexing, ComputeInstructionIndexing( + root, output_id, mlir_context)); + + bfs.push(std::make_pair(root, root_indexing)); + absl::flat_hash_map> + parameter_indexing_maps; + while (!bfs.empty()) { + const auto& [instr, instr_indexing] = bfs.front(); + for (const auto& operand_indexing : instr_indexing.operand_indexing_maps) { + const HloInstruction* producer_instr = + instr->operand(operand_indexing.operand_id); + // If the producer is a fusion op parameter, store the result. + if (auto parameter = DynCast(producer_instr)) { + parameter_indexing_maps[parameter->parameter_number()].insert( + operand_indexing.indexing_maps.begin(), + operand_indexing.indexing_maps.end()); + continue; + } + TF_ASSIGN_OR_RETURN(auto producer_instr_indexing, + ComputeInstructionIndexing( + producer_instr, /*output_id=*/0, mlir_context)); + bfs.push(std::make_pair(producer_instr, + ComputeFusedProducerConsumerIndexing( + producer_instr_indexing, operand_indexing))); + } + bfs.pop(); + } + HloInstructionIndexing fusion_indexing; + for (const auto& [operand_id, maps] : parameter_indexing_maps) { + fusion_indexing.operand_indexing_maps.push_back({maps, operand_id}); + } + return fusion_indexing; +} + StatusOr ComputeReduceOpIndexing( const HloReduceInstruction* reduce, int output_id, MLIRContext* mlir_context) { @@ -94,8 +231,7 @@ StatusOr ComputeReduceOpIndexing( ? 
ShapeUtil::GetSubshape(reduce->shape(), {0}) : reduce->shape(); - std::vector sizes(output_shape.dimensions().begin(), - output_shape.dimensions().end()); + std::vector input_dims_sizes; int64_t reduced_dim_id = 0; int64_t output_dim_id = 0; std::vector exprs; @@ -103,7 +239,7 @@ StatusOr ComputeReduceOpIndexing( llvm::enumerate(input_shape.dimensions())) { if (reduce_dims_ids.contains(input_dim_id)) { exprs.push_back(getAffineSymbolExpr(reduced_dim_id++, mlir_context)); - sizes.push_back(input_dim); + input_dims_sizes.push_back(input_dim); continue; } exprs.push_back(getAffineDimExpr(output_dim_id++, mlir_context)); @@ -111,7 +247,7 @@ StatusOr ComputeReduceOpIndexing( IndexingMap indexing_map{ .affine_map = AffineMap::get(output_shape.rank(), reduce_dims_ids.size(), exprs, mlir_context), - .sizes = std::vector{sizes.begin(), sizes.end()}}; + .input_dims_sizes = std::move(input_dims_sizes)}; std::vector operand_indexing_maps; int64_t input_count = reduce->input_count(); @@ -145,7 +281,7 @@ StatusOr ComputeReverseOpIndexing( IndexingMap indexing_map{ .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), - .sizes = std::vector{output_dims.begin(), output_dims.end()}}; + .input_dims_sizes = {}}; return HloInstructionIndexing{{HloOperandIndexing{ .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; @@ -170,21 +306,19 @@ StatusOr ComputeSliceOpIndexing( IndexingMap indexing_map{ .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), - .sizes = std::vector{output_dims.begin(), output_dims.end()}}; + .input_dims_sizes = {}}; return HloInstructionIndexing{{HloOperandIndexing{ .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; } StatusOr ComputeTransposeOpIndexing( const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { - auto output_dims = transpose->shape().dimensions(); std::vector permutation(transpose->dimensions().begin(), transpose->dimensions().end()); - IndexingMap permutation_map{ .affine_map = mlir::inversePermutation( AffineMap::getPermutationMap(permutation, mlir_context)), - .sizes = std::vector{output_dims.begin(), output_dims.end()}}; + .input_dims_sizes = {}}; return HloInstructionIndexing{{HloOperandIndexing{ .indexing_maps = {std::move(permutation_map)}, .operand_id = 0}}}; @@ -207,9 +341,14 @@ std::string ToString(const AffineMap& affine_map) { return s; } +bool operator==(const IndexingMap& lhs, const IndexingMap& rhs) { + return lhs.affine_map == rhs.affine_map && + lhs.input_dims_sizes == rhs.input_dims_sizes; +} + std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { out << ToString(indexing_map.affine_map) << " with sizes " - << absl::StrJoin(indexing_map.sizes, ", ") << "\n"; + << absl::StrJoin(indexing_map.input_dims_sizes, ", ") << "\n"; return out; } @@ -246,6 +385,9 @@ StatusOr ComputeInstructionIndexing( if (auto bcast = DynCast(instr)) { return ComputeBroadcastOpIndexing(bcast, mlir_context); } + if (auto fusion = DynCast(instr)) { + return ComputeFusionOpIndexing(fusion, output_id, mlir_context); + } if (auto reduce = DynCast(instr)) { return ComputeReduceOpIndexing(reduce, output_id, mlir_context); } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.h b/third_party/xla/xla/service/gpu/model/tile_analysis.h index 6e4c9b354a7788..601272b1f68360 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.h @@ -16,11 +16,14 @@ limitations under the License. 
#ifndef XLA_SERVICE_GPU_MODEL_TILE_ANALYSIS_H_ #define XLA_SERVICE_GPU_MODEL_TILE_ANALYSIS_H_ +#include #include #include #include #include +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/Hashing.h" #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" @@ -36,13 +39,11 @@ namespace gpu { // of dimensions of output. For example, for broadcasts and cwise ops all // dimensions of the inputs are covered by the output dimensions. // Symbols s_j correspond to the dimensions that are present ONLY in inputs. -// `sizes` is an array that holds the upper bounds for the iteration sizes for -// every output/input dimension -// i.e. elements size_0, ..., size_{N - 1} correspond to the output dimensions -// d0, ..., d_{N - 1} and elements size_N, ..., size_{N + M - 1} correspond -// to the input-only dimensions s_0, ..., s_{M - 1}. -// Note, that the sizes have upper bounds only and the lower bounds are always -// 0, since we can encode the offsets in the affine map. +// `input_dims_sizes` is an array that holds the upper bounds for the iteration +// sizes for every input-only dimension. Note, that the sizes have upper +// bounds only and the lower bounds are always 0, since we can encode the +// offsets in the affine map. The sizes for the output dimensions can be deduced +// from the shape of the output tensor. // // Example: // @@ -51,8 +52,8 @@ namespace gpu { // p0 = f32[150, 20, 10, 50] parameter(0) // reduce = f32[150, 10] reduce(p0, p0_init), dimensions={3, 1} // ``` -// can be written as `(d0, d1)[s0, s1] -> (d0, s0, d1, s1)` with the sizes -// `[/*d0 size=*/150, /*d1 size=*/10, /*s0 size=*/20, /*s1 size=*/50]`. +// can be written as `(d0, d1)[s0, s1] -> (d0, s0, d1, s1)` with the input +// dimensions sizes `[/*s0 size=*/20, /*s1 size=*/50]`. // // 2. Indexing map for the input of the reverse op // ``` @@ -60,23 +61,31 @@ namespace gpu { // reverse = f32[1, 17, 9, 9] reverse(%p0), dimensions={1, 2} // ``` // can be written as `(d0, d1, d2, d3) -> (d0, -d1 + 17, -d2 + 9, d3)` with the -// sizes `[/*d0 size=*/1, /*d1 size=*/17, /*d2 size=*/9, /*d3 size=*/9]`. +// empty 'input_dims_sizes`, because there are no dimensions in the input that +// could not be expressed via dimensions of the output. struct IndexingMap { - mlir::AffineMap affine_map; - std::vector sizes; - std::string ToString() const; + + mlir::AffineMap affine_map; + std::vector input_dims_sizes; }; std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); +bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); + +template +H AbslHashValue(H h, const IndexingMap& indexing_map) { + llvm::hash_code affine_map_hash = llvm::hash_combine(indexing_map.affine_map); + return H::combine(std::move(h), static_cast(affine_map_hash)); +} // Contains 1 or more indexing maps for the `operand_id`. There are cases, when // the same input operand is read multiple times in various ways. Especially, it // happens a lot in fusion ops. struct HloOperandIndexing { - std::vector indexing_maps; - int64_t operand_id; - std::string ToString() const; + + absl::flat_hash_set indexing_maps; + int64_t operand_id; }; std::ostream& operator<<(std::ostream& out, const HloOperandIndexing& operand_indexing); @@ -84,9 +93,9 @@ std::ostream& operator<<(std::ostream& out, // Contains indexing maps for all N-dimensional tensor input operands that // correspond to a particular output. 
struct HloInstructionIndexing { - std::vector operand_indexing_maps; - std::string ToString() const; + + std::vector operand_indexing_maps; }; std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing); diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 3a764d39aae31f..452f0424802d09 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -38,15 +38,16 @@ using ::testing::Eq; using ::testing::ExplainMatchResult; using ::testing::HasSubstr; using ::testing::PrintToString; +using ::testing::UnorderedElementsAre; -MATCHER_P2(MatchIndexingMap, affine_map_string, sizes, +MATCHER_P2(MatchIndexingMap, affine_map_string, input_dims_sizes, absl::StrCat(negation ? "equals " : "doesn't equal ", "affine map ", - affine_map_string, " with sizes ", - PrintToString(sizes))) { + affine_map_string, " with input dim sizes ", + PrintToString(input_dims_sizes))) { return ExplainMatchResult(HasSubstr(affine_map_string), ToString(arg.affine_map), result_listener) && - ExplainMatchResult(ElementsAreArray(sizes), arg.sizes, - result_listener); + ExplainMatchResult(ElementsAreArray(input_dims_sizes), + arg.input_dims_sizes, result_listener); } MATCHER_P2(MatchOperandIndexing, operand_id, indexing_map_matchers, "") { @@ -79,13 +80,12 @@ TEST_F(TileAnalysisTest, ElementwiseOp) { ASSERT_IS_OK(input_indexing_or); EXPECT_THAT( input_indexing_or->operand_indexing_maps, - ElementsAre( - MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{10, 20}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{10, 20}))))); + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))), + MatchOperandIndexing( + 1, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, BroadcastOp) { @@ -97,11 +97,227 @@ TEST_F(TileAnalysisTest, BroadcastOp) { } )"); ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT(input_indexing_or->operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + ROOT a0 = f32[100] add(p0, p1) + } + ENTRY e { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + ROOT fusion = f32[100] fusion(p0, p1), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT( + input_indexing_or->operand_indexing_maps, + UnorderedElementsAre( + MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))), + MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[1000, 1000] parameter(0) + transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + ROOT a0 = f32[1000, 1000] add(p0, transpose_p0) + } + ENTRY e { + p0 = f32[1000,1000] parameter(0) + ROOT fusion = f32[1000,1000] fusion(p0), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); EXPECT_THAT( input_indexing_or->operand_indexing_maps, 
ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", - std::vector{10, 20, 30}))))); + 0, + UnorderedElementsAre( + MatchIndexingMap("(d0, d1) -> (d1, d0)", std::vector{}), + MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionExponentialDuplication) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + add0 = f32[4] add(p0, p1) + slice1.0 = f32[3] slice(add0), slice={[0:3]} + slice1.1 = f32[3] slice(add0), slice={[1:4]} + add1 = f32[3]{0} add(slice1.0, slice1.1) + slice2.0 = f32[2] slice(add1), slice={[0:2]} + slice2.1 = f32[2] slice(add1), slice={[1:3]} + ROOT add2 = f32[2] add(slice2.0, slice2.1) + })"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT( + input_indexing_or->operand_indexing_maps, + ElementsAre( + MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))), + MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[150, 20, 10, 50] parameter(0) + p0_init = f32[] parameter(1) + reduce_1 = f32[20, 10] reduce(p0, p0_init), + dimensions={0, 3}, to_apply=max + ROOT reduce_2 = f32[10] reduce(reduce_1, p0_init), + dimensions={0}, to_apply=max + } + ENTRY e { + p0 = f32[150, 20, 10, 50] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[10] fusion(p0, p0_init), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT(input_indexing_or->operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0)[s0, s1, s2] -> (s0, s2, d0, s1)", + std::vector{150, 50, 20}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] parameter(1) + p0_bcast = f32[15, 32, 20, 64] broadcast(p0), dimensions={0, 2} + + ROOT reduce_2 = f32[15, 64] reduce(p0_bcast, p0_init), + dimensions={1, 2}, to_apply=max + } + ENTRY e { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT(input_indexing_or->operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1)[s0] -> (d0, s0)", + std::vector{20}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[20, 10, 50] parameter(0) + + lhs_transpose_1 = f32[10, 20, 50] + transpose(p0), dimensions={1, 0, 2} + lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1) + lhs_transpose_2 = f32[10, 50, 20] + transpose(lhs_e), dimensions={0, 2, 1} + + rhs_transpose_1 = f32[50, 10, 20] + transpose(p0), dimensions={2, 1, 0} + rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1) + rhs_transpose_2 = f32[10, 50, 20] + transpose(rhs_log), dimensions={1, 0, 2} + + ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2) + } + ENTRY e { + p0 = 
f32[20, 10, 50] parameter(0) + ROOT fusion = f32[10, 50, 20] fusion(p0), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT( + input_indexing_or->operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d2, d0, d1)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[150, 64, 1024] parameter(0) + p0_init = f32[] parameter(1) + p0_slice = f32[16, 32, 128] slice(f32[150, 64, 1024] p0), + slice={[5:21:1], [0:64:2], [50:434:3]} + ROOT reduce = f32[32] reduce(p0_slice, p0_init), + dimensions={0, 2}, to_apply=max + } + ENTRY e { + p0 = f32[150, 64, 1024] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[32] fusion(p0, p0_init), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT(input_indexing_or->operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)", + std::vector{16, 128}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { + auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[150, 64, 1024] parameter(0) + p0_slice_1 = f32[16, 32, 128] slice(f32[150, 64, 1024] p0), + slice={[5:21:1], [0:64:2], [50:434:3]} + ROOT p0_slice_2 = f32[7, 9, 24] slice(f32[16, 32, 128] p0_slice_1), + slice={[3:16:2], [4:30:3], [5:100:4]} + } + ENTRY e { + p0 = f32[150, 64, 1024] parameter(0) + ROOT fusion = f32[7, 9, 24] fusion(p0), kind=kLoop, calls=f + } + )"); + ASSERT_IS_OK(input_indexing_or); + EXPECT_THAT( + input_indexing_or->operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReduceOp) { @@ -124,7 +340,7 @@ TEST_F(TileAnalysisTest, ReduceOp) { ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1)[s0, s1] -> (d0, s0, d1, s1)", - std::vector{150, 10, 20, 50}))))); + std::vector{20, 50}))))); } TEST_F(TileAnalysisTest, VariadicReduceOp) { @@ -157,33 +373,31 @@ TEST_F(TileAnalysisTest, VariadicReduceOp) { ASSERT_IS_OK(input_indexing_0); EXPECT_THAT( input_indexing_0->operand_indexing_maps, - ElementsAre( - MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{10, 256}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{10, 256}))))); + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", + std::vector{256}))), + MatchOperandIndexing( + 1, ElementsAre(MatchIndexingMap( + "(d0)[s0] -> (s0, d0)", std::vector{256}))))); auto input_indexing_1 = ComputeInstructionIndexing(root, 1, &mlir_context_); ASSERT_IS_OK(input_indexing_1); EXPECT_THAT( input_indexing_1->operand_indexing_maps, - ElementsAre( - MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{10, 256}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{10, 256}))))); + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", + std::vector{256}))), + MatchOperandIndexing( + 1, ElementsAre(MatchIndexingMap( + "(d0)[s0] -> (s0, d0)", std::vector{256}))))); } 
TEST_F(TileAnalysisTest, ReverseOp) { auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { - %p0 = f32[1, 17, 9, 9] parameter(0) - ROOT reverse = f32[1, 17, 9, 9] reverse(%p0), dimensions={1, 2} + p0 = f32[1, 17, 9, 9] parameter(0) + ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} } )"); ASSERT_IS_OK(input_indexing_or); @@ -191,15 +405,15 @@ TEST_F(TileAnalysisTest, ReverseOp) { ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, -d1 + 17, -d2 + 9, d3)", - std::vector{1, 17, 9, 9}))))); + std::vector{}))))); } TEST_F(TileAnalysisTest, SliceOp) { auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { - %p0 = f32[10, 20, 50] parameter(0) - ROOT %slice = f32[5, 3, 25] slice(f32[10, 20, 50] %p0), + p0 = f32[10, 20, 50] parameter(0) + ROOT slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0), slice={[5:10:1], [3:20:7], [0:50:2]} } )"); @@ -208,16 +422,16 @@ TEST_F(TileAnalysisTest, SliceOp) { ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)", - std::vector{5, 3, 25}))))); + std::vector{}))))); } TEST_F(TileAnalysisTest, TransposeOp) { auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { - %p0 = f16[1, 8, 1536, 512] parameter(0) + p0 = f16[1, 8, 1536, 512] parameter(0) ROOT transpose = f16[1, 8, 512, 1536]{2, 3, 1, 0} - transpose(%p0), dimensions={0, 1, 3, 2} + transpose(p0), dimensions={0, 1, 3, 2} } )"); ASSERT_IS_OK(input_indexing_or); @@ -225,16 +439,16 @@ TEST_F(TileAnalysisTest, TransposeOp) { ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, d1, d3, d2)", - std::vector{1, 8, 512, 1536}))))); + std::vector{}))))); } TEST_F(TileAnalysisTest, UnsupportedOp) { auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { - %p0 = f32[1, 17, 9, 9] parameter(0) - %p1 = f32[5, 17, 9, 9] parameter(1) - ROOT %concat = f32[6, 17, 9, 9] concatenate(%p0, %p1) + p0 = f32[1, 17, 9, 9] parameter(0) + p1 = f32[5, 17, 9, 9] parameter(1) + ROOT concat = f32[6, 17, 9, 9] concatenate(p0, p1) } )"); ASSERT_IS_NOT_OK(input_indexing_or); From e4914abb5ff8ffb7bb08b589cc1c1cc08e9aadd9 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 24 Nov 2023 07:04:25 -0800 Subject: [PATCH 063/381] Replace boundary functions with a fusion adaptor class. Boundary functions seemed like a nice and easy abstraction for fusions, but they turned out to be too difficult to use in practice. The main problem is that everything is still based on HloInstructions, whose users and operands are difficult to traverse in general. The solution introduced here is to introduce an HloFusionAdaptor class with a simple interface, and an HloInstructionAdaptor which always behaves as if the HLO was completely unfused. If I had more time, I would have made smaller change. 
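For illustration, a minimal sketch (mirroring the updated hlo_traversal_test.cc below) of how the adaptor-based API can walk a producer/consumer pair as if it were a single fusion. The function name, the includes, and the choice of collecting node names are assumptions made for the example, not part of this change:

    #include <string>
    #include <vector>

    #include "xla/hlo/ir/hlo_instruction.h"
    #include "xla/service/gpu/hlo_traversal.h"

    // Sketch only: collects the names of all nodes that would belong to a
    // producer/consumer fusion, visiting consumers before producers.
    std::vector<std::string> CollectFusedNodeNames(
        const xla::HloInstruction* producer,
        const xla::HloInstruction* consumer) {
      // Wrap each instruction (fused or not) in an adaptor and view the pair
      // as one fusion; parameter and fusion-root edges are resolved by the
      // adaptors themselves.
      xla::gpu::ProducerConsumerFusion fusion(
          xla::gpu::HloFusionAdaptor::ForInstruction(producer),
          xla::gpu::HloFusionAdaptor::ForInstruction(consumer));
      std::vector<std::string> names;
      xla::gpu::HloBfsConsumersFirstTraversal(
          fusion.GetRoots(), fusion,
          [&](xla::gpu::HloInstructionAdaptor node) {
            names.emplace_back(node.name());
            return xla::gpu::TraversalResult::kVisitOperands;
          });
      return names;
    }

Compared to the boundary-function version, the caller no longer has to special-case parameter or fusion edges in the visit callback; the adaptors make the graph look completely unfused.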
PiperOrigin-RevId: 585087631 --- third_party/xla/xla/service/gpu/BUILD | 2 +- .../xla/service/gpu/hlo_fusion_analysis.cc | 304 +++++++++--------- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 12 +- .../service/gpu/hlo_fusion_analysis_test.cc | 79 +++-- .../xla/xla/service/gpu/hlo_traversal.cc | 293 ++++++++++------- .../xla/xla/service/gpu/hlo_traversal.h | 120 ++++--- .../xla/xla/service/gpu/hlo_traversal_test.cc | 293 ++++++++--------- .../xla/xla/service/gpu/ir_emission_utils.cc | 74 ++--- .../xla/xla/service/gpu/ir_emission_utils.h | 15 +- .../xla/service/gpu/ir_emission_utils_test.cc | 125 +++---- .../xla/xla/service/gpu/ir_emitter_triton.cc | 38 ++- .../xla/xla/service/gpu/ir_emitter_triton.h | 12 +- .../xla/service/gpu/ir_emitter_unnested.cc | 8 +- .../xla/xla/service/gpu/priority_fusion.cc | 4 +- 14 files changed, 701 insertions(+), 678 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 606ed0c969e019..f07b8878c36b6f 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3592,6 +3592,7 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", @@ -4761,7 +4762,6 @@ xla_cc_test( name = "hlo_traversal_test", srcs = ["hlo_traversal_test.cc"], deps = [ - ":gpu_fusible", ":hlo_traversal", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index 71fdc475eb7013..05182f8e660f1c 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/node_hash_map.h" #include "absl/log/check.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" @@ -110,32 +111,30 @@ bool AllSliceInputsAreCompatible( }); } -bool MayPreventVectorization( - const std::vector& fusion_roots, - const FusionBoundaryFn& fusion_boundary_fn) { +bool MayPreventVectorization(const HloFusionAdaptor& fusion) { // An empirically chosen constant: unrolling concat with a large amount of // arguments causes excessive register spilling. 
static constexpr int kMaxConcatArgumentsForUnrolling = 10; - return HloAnyOf( - fusion_roots, fusion_boundary_fn, [&](const HloInstruction& node) { - switch (node.opcode()) { - case HloOpcode::kReduceWindow: - case HloOpcode::kSort: - case HloOpcode::kDot: - case HloOpcode::kSin: - case HloOpcode::kCos: - case HloOpcode::kTan: - case HloOpcode::kPower: - case HloOpcode::kAtan2: - return true; - case HloOpcode::kConcatenate: - return node.operand_count() > kMaxConcatArgumentsForUnrolling; - case HloOpcode::kReduce: - return node.shape().tuple_shapes_size() > 1; - default: - return false; - } - }); + return HloAnyOf(fusion.GetRoots(), fusion, [&](auto node) { + switch (node.opcode()) { + case HloOpcode::kReduceWindow: + case HloOpcode::kSort: + case HloOpcode::kDot: + case HloOpcode::kSin: + case HloOpcode::kCos: + case HloOpcode::kTan: + case HloOpcode::kPower: + case HloOpcode::kAtan2: + return true; + case HloOpcode::kConcatenate: + return node.instruction().operand_count() > + kMaxConcatArgumentsForUnrolling; + case HloOpcode::kReduce: + return node.instruction().shape().tuple_shapes_size() > 1; + default: + return false; + } + }); } // Determines if we enable the row optimized codegen. When we have a fusion with @@ -144,15 +143,15 @@ bool MayPreventVectorization( // particular on A100. The int is the number of inputs with rank `out_rank`. Its // value is only defined if row vectorization is enabled. std::pair RowVectorizationEnabled( - const std::vector& fusion_roots, int64_t out_rank) { - const auto is_row_major = [](const HloInstruction* instr) { + const HloFusionAdaptor& fusion, int64_t out_rank) { + auto roots = fusion.GetRoots(); + const auto is_row_major = [](auto instr) { // Only tested when the inputs are row-major. So only enable that case. // Maybe it would work if only the inner dimensions is contiguous. - return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()); + return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout()); }; - bool row_vectorized = fusion_roots.size() == 1 && - !fusion_roots[0]->shape().IsTuple() && - is_row_major(fusion_roots[0]); + bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() && + is_row_major(roots[0]); if (!row_vectorized) { return {false, 0}; } @@ -168,12 +167,13 @@ std::pair RowVectorizationEnabled( int num_big_inputs = 0; bool some_row_broadcasting = false; HloBfsConsumersFirstTraversal( - {fusion_roots.front()}, - [&](const HloInstruction& producer, const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kParameter; - }, - [&](const HloInstruction& node) -> TraversalResult { - if (node.IsElementwise()) { + roots, fusion, + [&](auto node) -> TraversalResult { + if (!row_vectorized) { + return TraversalResult::kAbortTraversal; + } + + if (node.instruction().IsElementwise()) { return TraversalResult::kVisitOperands; } @@ -181,33 +181,33 @@ std::pair RowVectorizationEnabled( case HloOpcode::kConstant: return TraversalResult::kDoNotVisitOperands; case HloOpcode::kParameter: - if (node.shape().rank() == out_rank) { - ++num_big_inputs; - } - // TODO(jreiffers): When extending this to work with unfused HLO, - // move this check to the boundary function. 
- if (!is_row_major(&node)) { - row_vectorized = false; - return TraversalResult::kAbortTraversal; - } return TraversalResult::kVisitOperands; - case HloOpcode::kBroadcast: - if (node.dimensions().empty()) { + case HloOpcode::kBroadcast: { + auto dims = node.instruction().dimensions(); + if (dims.empty()) { return TraversalResult::kVisitOperands; } - if (node.dimensions().size() == 1 && - node.dimensions().front() == node.shape().rank() - 1) { + if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) { some_row_broadcasting = true; return TraversalResult::kVisitOperands; } TF_FALLTHROUGH_INTENDED; + } default: VLOG(2) << "Row vectorization not enabled due to: " << node.ToString(); row_vectorized = false; return TraversalResult::kAbortTraversal; } + }, + [&](auto argument) { + if (argument.shape().rank() == out_rank) { + ++num_big_inputs; + } + if (!is_row_major(argument)) { + row_vectorized = false; + } }); // Trigger only when there is a row broadcasting. return std::make_pair(row_vectorized && some_row_broadcasting, @@ -299,14 +299,14 @@ int SmallestInputDtypeBits(const std::vector& args) { HloFusionAnalysis::HloFusionAnalysis( FusionBackendConfig fusion_backend_config, std::vector fusion_roots, - FusionBoundaryFn fusion_boundary_fn, + std::unique_ptr fusion, std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, HloFusionAnalysis::InputOutputInfo input_output_info) : fusion_backend_config_(std::move(fusion_backend_config)), fusion_roots_(std::move(fusion_roots)), - fusion_boundary_fn_(std::move(fusion_boundary_fn)), + fusion_(std::move(fusion)), fusion_heroes_(std::move(fusion_heroes)), device_info_(device_info), tiled_transpose_(tiled_transpose), @@ -318,19 +318,19 @@ HloFusionAnalysis::HloFusionAnalysis( // static StatusOr HloFusionAnalysis::Create( FusionBackendConfig backend_config, - std::vector hlo_roots, FusionBoundaryFn boundary_fn, + std::unique_ptr fusion, const se::DeviceDescription* device_info) { + std::vector roots; std::vector heroes; - heroes.reserve(hlo_roots.size()); - for (auto* root : hlo_roots) { - heroes.push_back(&FindNonTrivialHero(*root, boundary_fn)); + for (auto root : fusion->GetRoots()) { + roots.push_back(&root.instruction()); + heroes.push_back(&FindNonTrivialHero(*roots.back(), *fusion)); } std::vector fusion_arguments; - FindFusionArguments(hlo_roots, boundary_fn, - [&](const HloInstruction& argument) { - fusion_arguments.push_back(&argument); - }); + FindFusionArguments(*fusion, [&](auto argument) { + fusion_arguments.push_back(&argument.instruction()); + }); auto is_4bit = [](const HloInstruction* arg) { return primitive_util::Is4BitType(arg->shape().element_type()); @@ -338,17 +338,16 @@ StatusOr HloFusionAnalysis::Create( InputOutputInfo input_output_info{ .has_4_bit_input = absl::c_any_of(fusion_arguments, is_4bit), - .has_4_bit_output = absl::c_any_of(hlo_roots, is_4bit), + .has_4_bit_output = absl::c_any_of(roots, is_4bit), .smallest_input_dtype_bits = SmallestInputDtypeBits(fusion_arguments), }; std::optional tiled_transpose_hero = - FindConsistentTransposeHero(hlo_roots, heroes); + FindConsistentTransposeHero(roots, heroes); - return HloFusionAnalysis(std::move(backend_config), std::move(hlo_roots), - std::move(boundary_fn), std::move(heroes), - device_info, tiled_transpose_hero, - std::move(input_output_info)); + return HloFusionAnalysis(std::move(backend_config), std::move(roots), + std::move(fusion), std::move(heroes), device_info, + tiled_transpose_hero, std::move(input_output_info)); } 
// static @@ -358,10 +357,8 @@ StatusOr HloFusionAnalysis::Create( CHECK(device_info != nullptr); TF_ASSIGN_OR_RETURN(auto backend_config, fusion->backend_config()); - - auto hlo_roots = GetFusionRoots(*fusion->fused_instructions_computation()); - return Create(std::move(backend_config), std::move(hlo_roots), - DefaultFusionBoundaryFn, device_info); + return Create(std::move(backend_config), + HloFusionAdaptor::ForInstruction(fusion), device_info); } // Returns true if the fusion has consistent transpose heros. @@ -495,8 +492,7 @@ HloFusionAnalysis::ComputeLoopFusionConfig() const { int64_t num_elements = ShapeUtil::ElementsIn(GetElementShape()); int64_t n_threads_max = device_info_->threads_per_core_limit() * device_info_->core_count(); - if (num_elements >= n_threads_max && - !MayPreventVectorization(fusion_roots_, fusion_boundary_fn_)) { + if (num_elements >= n_threads_max && !MayPreventVectorization(*fusion_)) { unroll_factor = ComputeMaxUnrollFactor(num_elements); } // CHECK that unroll_factor is a power-of-2, as needed by the logic below. @@ -523,24 +519,25 @@ HloFusionAnalysis::ComputeLoopFusionConfig() const { bool row_vectorized; int num_big_inputs; std::tie(row_vectorized, num_big_inputs) = - RowVectorizationEnabled(fusion_roots(), GetElementShape().rank()); - bool few_waves = !HloAnyOf( - fusion_roots_, fusion_boundary_fn_, [&](const HloInstruction& instr) { - if (instr.opcode() == HloOpcode::kParameter || - instr.opcode() == HloOpcode::kConstant || - HloInstruction::IsOpElementwise(instr.opcode())) { - return false; - } - if (auto broadcast = DynCast(&instr)) { - if (broadcast->dimensions().empty() || - // More than 3 big inputs cause a speed regression. - (row_vectorized && num_big_inputs <= 3)) { - return false; - } - } - VLOG(2) << "few_waves not enabled due to: " << instr.ToString(); - return true; - }); + RowVectorizationEnabled(*fusion_, GetElementShape().rank()); + bool few_waves = !HloAnyOf(fusion_->GetRoots(), *fusion_, [&](auto instr) { + if (instr.opcode() == HloOpcode::kParameter || + instr.opcode() == HloOpcode::kConstant || + HloInstruction::IsOpElementwise(instr.opcode())) { + return false; + } + if (auto broadcast = + DynCast(&instr.instruction())) { + if (broadcast->dimensions().empty() || + // More than 3 big inputs cause a speed regression. + (row_vectorized && num_big_inputs <= 3)) { + return false; + } + } + VLOG(2) << "few_waves not enabled due to: " + << instr.instruction().ToString(); + return true; + }); LaunchDimensionsConfig launch_config{unroll_factor, few_waves, row_vectorized}; @@ -609,64 +606,68 @@ HloFusionAnalysis::GroupDisjointReductions() const { return {{fusion_roots()[0]}}; } - ConstHloInstructionMap> + absl::node_hash_map> disjoint_sets; // TODO(b/249976438): we currently do not treat properly // aliasing between inputs and outputs of the fusion, so for now put all // non-reduction roots into one group to avoid read-after-write conflicts. 
- const HloInstruction* first_non_reduction_root = nullptr; + std::optional first_non_reduction_root = std::nullopt; - ConstHloInstructionMap> + absl::node_hash_map> reachable_outputs; - absl::flat_hash_set roots_with_reduction; - for (auto [root, hero] : llvm::zip(fusion_roots(), fusion_heroes_)) { + absl::flat_hash_set roots_with_reduction; + auto roots = fusion_->GetRoots(); + for (auto [root, hero] : llvm::zip(roots, fusion_heroes_)) { disjoint_sets[root].Get() = root; reachable_outputs[root].insert(root); - if (IsRealReductionHero(*root, *hero)) { + if (IsRealReductionHero(root.instruction(), *hero)) { roots_with_reduction.insert(root); } else if (first_non_reduction_root) { - disjoint_sets[first_non_reduction_root].Merge(&disjoint_sets[root]); + disjoint_sets[*first_non_reduction_root].Merge(&disjoint_sets[root]); } else { first_non_reduction_root = root; } } - std::vector instructions; + std::vector instructions; HloBfsConsumersFirstTraversal( - fusion_roots_, - [&](const HloInstruction& producer, const HloInstruction& consumer) { - auto& producer_reachable = reachable_outputs[&producer]; - for (auto* instruction : reachable_outputs[&consumer]) { - producer_reachable.insert(instruction); + roots, *fusion_, + [&](HloInstructionAdaptor consumer) { + auto& consumer_reachable = reachable_outputs[consumer]; + for (auto producer : consumer.GetOperands()) { + reachable_outputs[producer].insert(consumer_reachable.begin(), + consumer_reachable.end()); } - return fusion_boundary_fn_(producer, consumer); - }, - [&](const HloInstruction& node) { - instructions.push_back(&node); + instructions.push_back(consumer); return TraversalResult::kVisitOperands; + }, + [&](HloInstructionAdaptor argument) { + instructions.push_back(argument); }); - for (const HloInstruction* instr : instructions) { + for (auto instr : instructions) { const auto& reachable = reachable_outputs[instr]; - std::vector reached_output_ids; + std::vector reached_output_ids; bool added_to_reduce = false; - for (const HloInstruction* output : fusion_roots()) { + for (auto output : roots) { bool has_real_hero = roots_with_reduction.contains(output); - if (has_real_hero && (hlo_query::IsBroadcastedConstantOrScalar(*instr))) { + if (has_real_hero && + (hlo_query::IsBroadcastedConstantOrScalar(instr.instruction()))) { if (added_to_reduce) { // Do not group more than one output reduce instructions through // broadcasted constants or scalars, as the recomputation should be // acceptable. - VLOG(3) << "Skip broadcasted constant or scalar " - << instr->ToString(); + VLOG(3) << "Skip broadcasted constant or scalar " << instr.ToString(); continue; } } // Now group output instructions if they have common predecessors. if (reachable.contains(output)) { - VLOG(3) << "Reaching " << output->ToString() << " from " - << instr->ToString(); + VLOG(3) << "Reaching " << output.ToString() << " from " + << instr.ToString(); reached_output_ids.push_back(output); if (has_real_hero) { added_to_reduce = true; @@ -681,8 +682,9 @@ HloFusionAnalysis::GroupDisjointReductions() const { // Place output instructions in the same set into the same group. 
ConstHloInstructionMap> groups; - for (const HloInstruction* root : fusion_roots()) { - groups[disjoint_sets[root].Get()].push_back(root); + for (auto root : roots) { + groups[&disjoint_sets[root].Get().instruction()].push_back( + &root.instruction()); } std::vector> ret; @@ -719,10 +721,9 @@ bool HloFusionAnalysis::IsUnrollingColumnReductionBeneficial( // Fusion inputs that have the same dimension as the reduce input and // only involve in element-wise operations can be vectorized. - absl::flat_hash_set reachable_through_non_elementwise; + absl::flat_hash_set reachable_through_non_elementwise; HloBfsConsumersFirstTraversal( - fusion_roots_, - [&](const HloInstruction& producer, const HloInstruction& consumer) { + fusion_->GetRoots(), *fusion_, [&](auto consumer) { // We check if the consumer is elementwise, unless this edge is a // virtual edge that only exists in partially fused HLO. There are two // types of such edges: @@ -730,38 +731,33 @@ bool HloFusionAnalysis::IsUnrollingColumnReductionBeneficial( // within a fusion. Here, the producer is a parameter of the fusion // instruction. // 2. Edges from fusion roots to fusion nodes. - if (reachable_through_non_elementwise.contains(&consumer) || - (!(consumer.opcode() == HloOpcode::kParameter || - consumer.opcode() == HloOpcode::kFusion || - consumer.IsElementwise()) && - !use_chain_endings.contains(&consumer))) { - reachable_through_non_elementwise.insert(&producer); + if (reachable_through_non_elementwise.contains(consumer) || + (!consumer.instruction().IsElementwise() && + !use_chain_endings.contains(&consumer.instruction()))) { + for (auto producer : consumer.GetOperands()) { + reachable_through_non_elementwise.insert(producer); + } } - - return fusion_boundary_fn_(producer, consumer); - }, - [&](const HloInstruction& node) { return TraversalResult::kVisitOperands; }); int64_t num_elements = ShapeUtil::ElementsIn(input_shape); - FindFusionArguments( - fusion_roots_, fusion_boundary_fn_, [&](const HloInstruction& arg) { - if (!reachable_through_non_elementwise.contains(&arg) && - ShapeUtil::SameDimensions(input_shape, arg.shape())) { - ++can_be_vectorized; - } + FindFusionArguments(*fusion_, [&](auto arg) { + if (!reachable_through_non_elementwise.contains(arg) && + ShapeUtil::SameDimensions(input_shape, arg.shape())) { + ++can_be_vectorized; + } - // Fusion inputs with more elements than the reduce op input must - // participate in non-elementwise operations and we assume that they are - // not vectorizable for the purpose of estimating the benefit of - // unrolling. If the kernel is unrolled even with such an assumption, - // and the accesses to those inputs turn out to be vectorizable, the - // compiler will still vectorize them. - if (ShapeUtil::ElementsIn(arg.shape()) > num_elements) { - ++cannot_be_vectorized; - } - }); + // Fusion inputs with more elements than the reduce op input must + // participate in non-elementwise operations and we assume that they are + // not vectorizable for the purpose of estimating the benefit of + // unrolling. If the kernel is unrolled even with such an assumption, + // and the accesses to those inputs turn out to be vectorizable, the + // compiler will still vectorize them. 
+ if (ShapeUtil::ElementsIn(arg.shape()) > num_elements) { + ++cannot_be_vectorized; + } + }); if (can_be_vectorized < cannot_be_vectorized) { return false; @@ -781,7 +777,7 @@ bool HloFusionAnalysis::CanVectorizeReduction( } if (reduction_dimensions.dimensions[kDimX] % 2 != 0 || - MayPreventVectorization(fusion_roots_, fusion_boundary_fn_)) { + MayPreventVectorization(*fusion_)) { return false; } @@ -922,19 +918,15 @@ HloFusionAnalysis::ComputeReductionCodegenInfo( reduction_is_race_free, std::move(instr_index_groups), hero_reduction); } -static std::vector GetRoots( - const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kFusion - ? GetFusionRoots(*consumer.fused_instructions_computation()) - : std::vector{&consumer}; -} - std::optional AnalyzeProducerConsumerFusion( const HloInstruction& producer, const HloInstruction& consumer, const se::DeviceDescription& device_info) { auto ret = HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), GetRoots(consumer), - MakeProducerConsumerFusion(producer, consumer), &device_info); + FusionBackendConfig::default_instance(), + std::make_unique( + HloFusionAdaptor::ForInstruction(&producer), + HloFusionAdaptor::ForInstruction(&consumer)), + &device_info); if (!ret.ok()) return std::nullopt; return {std::move(*ret)}; } @@ -942,8 +934,8 @@ std::optional AnalyzeProducerConsumerFusion( std::optional AnalyzeFusion( const HloInstruction& consumer, const se::DeviceDescription& device_info) { auto ret = HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), GetRoots(consumer), - MakeSingleInstructionFusion(consumer), &device_info); + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForInstruction(&consumer), &device_info); if (!ret.ok()) return std::nullopt; return {std::move(*ret)}; } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index 38066ff9234f12..786fe3d8a86d9d 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -50,8 +50,8 @@ class HloFusionAnalysis { static StatusOr Create( FusionBackendConfig backend_config, - std::vector hlo_roots, - FusionBoundaryFn boundary_fn, const se::DeviceDescription* device_info); + std::unique_ptr fusion, + const se::DeviceDescription* device_info); static StatusOr Create( const HloFusionInstruction* fusion, const se::DeviceDescription* device_info); @@ -59,9 +59,7 @@ class HloFusionAnalysis { const std::vector& fusion_roots() const { return fusion_roots_; } - const FusionBoundaryFn& fusion_boundary() const { - return fusion_boundary_fn_; - } + const HloFusionAdaptor& fusion() const { return *fusion_; } // Determines the fusion type for the emitter. 
EmitterFusionKind GetEmitterFusionKind() const; @@ -104,7 +102,7 @@ class HloFusionAnalysis { HloFusionAnalysis(FusionBackendConfig fusion_backend_config, std::vector fusion_roots, - FusionBoundaryFn fusion_boundary_fn, + std::unique_ptr fusion, std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, @@ -130,7 +128,7 @@ class HloFusionAnalysis { FusionBackendConfig fusion_backend_config_; std::vector fusion_roots_; - FusionBoundaryFn fusion_boundary_fn_; + std::unique_ptr fusion_; std::vector fusion_heroes_; const se::DeviceDescription* device_info_; std::optional tiled_transpose_; diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc index 1c39371d4148bf..0a02760077547a 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc @@ -47,18 +47,15 @@ TEST_F(HloFusionAnalysisTest, DoesNotPeekOutsideBoundary) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis, HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), {root}, - MakeSingleInstructionFusion(*root), &device_info)); - EXPECT_EQ(analysis.GetEmitterFusionKind(), + auto analysis = AnalyzeFusion(*root, device_info); + ASSERT_NE(analysis, std::nullopt); + EXPECT_EQ(analysis->GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis_fused, - HloFusionAnalysis::Create(FusionBackendConfig::default_instance(), {root}, - DefaultFusionBoundaryFn, &device_info)); - EXPECT_EQ(analysis_fused.GetEmitterFusionKind(), + auto analysis_fused = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + ASSERT_NE(analysis_fused, std::nullopt); + EXPECT_EQ(analysis_fused->GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -84,11 +81,12 @@ TEST_F(HloFusionAnalysisTest, ReductionWithMultipleUsers) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - auto* root = module->entry_computation()->root_instruction(); TF_ASSERT_OK_AND_ASSIGN( auto analysis, - HloFusionAnalysis::Create(FusionBackendConfig::default_instance(), {root}, - DefaultFusionBoundaryFn, &device_info)); + HloFusionAnalysis::Create( + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForComputation(module->entry_computation()), + &device_info)); // This fusion cannot use the reduction emitter because the reduce has two // users. 
EXPECT_EQ(analysis.GetEmitterFusionKind(), @@ -105,11 +103,17 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusion) { ROOT add = f32[] add(p0, p1) } - ENTRY main { + fused_computation { %p0 = f32[1024] parameter(0) %p1 = f32[] parameter(1) %reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add ROOT %negate = f32[] negate(%reduce) + } + + ENTRY main { + %p0 = f32[1024] parameter(0) + %p1 = f32[] parameter(1) + ROOT %fusion = f32[] fusion(%p0, %p1), kind=kInput, calls=fused_computation })") .value(); @@ -117,9 +121,9 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusion) { auto* root = module->entry_computation()->root_instruction(); TF_ASSERT_OK_AND_ASSIGN( - auto analysis, - HloFusionAnalysis::Create(FusionBackendConfig::default_instance(), {root}, - DefaultFusionBoundaryFn, &device_info)); + auto analysis, HloFusionAnalysis::Create( + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForInstruction(root), &device_info)); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -151,12 +155,11 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFused) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis, - HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), {root}, - MakeProducerConsumerFusion(*root->operand(0), *root), &device_info)); - EXPECT_EQ(analysis.GetEmitterFusionKind(), + + auto analysis = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + ASSERT_NE(analysis, std::nullopt); + EXPECT_EQ(analysis->GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -186,13 +189,10 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInConsumer) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis, - HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), - {root->fused_expression_root()}, - MakeProducerConsumerFusion(*root->operand(0), *root), &device_info)); - EXPECT_EQ(analysis.GetEmitterFusionKind(), + auto analysis = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + ASSERT_NE(analysis, std::nullopt); + EXPECT_EQ(analysis->GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -228,12 +228,10 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInBoth) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis, - HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), {root}, - MakeProducerConsumerFusion(*root->operand(0), *root), &device_info)); - EXPECT_EQ(analysis.GetEmitterFusionKind(), + auto analysis = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + ASSERT_NE(analysis, std::nullopt); + EXPECT_EQ(analysis->GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -262,11 +260,10 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) { stream_executor::DeviceDescription device_info(device_info_proto); auto* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis_fused, - HloFusionAnalysis::Create(FusionBackendConfig::default_instance(), {root}, - DefaultFusionBoundaryFn, &device_info)); - 
EXPECT_EQ(analysis_fused.GetEmitterFusionKind(), + auto analysis_fused = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + ASSERT_NE(analysis_fused, std::nullopt); + EXPECT_EQ(analysis_fused->GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.cc b/third_party/xla/xla/service/gpu/hlo_traversal.cc index e729c2467e16e9..dc0b1dcb57a10b 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal.cc @@ -15,9 +15,11 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include +#include #include #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -25,92 +27,181 @@ limitations under the License. namespace xla { namespace gpu { +namespace { + +template +void ResolveUsers(const HloInstruction* value, const HloInstruction* user, + F&& fn) { + if (user->opcode() == HloOpcode::kFusion) { + auto* param = user->fused_parameter(user->operand_index(value)); + for (const auto* param_user : param->users()) { + fn(param_user); + } + } else { + fn(user); + } +} -bool DefaultFusionBoundaryFn(const HloInstruction&, - const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kParameter; +const HloInstruction* ResolveOperand(const HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kFusion) { + return operand->fused_expression_root(); + } + if (operand->opcode() == HloOpcode::kParameter) { + if (auto* fusion = operand->parent()->FusionInstruction()) { + return ResolveOperand(fusion->operand(operand->parameter_number())); + } + } + return operand; } -FusionBoundaryFn MakeProducerConsumerFusion( - const HloInstruction& fused_producer, - const HloInstruction& fused_consumer) { - if (fused_consumer.opcode() == HloOpcode::kFusion && - fused_producer.opcode() == HloOpcode::kFusion) { - // fusion -> fusion. - return [&](const HloInstruction& producer, const HloInstruction& consumer) { - return DefaultFusionBoundaryFn(producer, consumer) && - &producer != &fused_producer; - }; +class SingleInstructionFusion : public HloFusionAdaptor { + public: + explicit SingleInstructionFusion(const HloInstruction* instruction) + : instruction_(*instruction) { + CHECK_NE(instruction->opcode(), HloOpcode::kFusion) + << "Use HloFusionFusion"; } - if (fused_consumer.opcode() == HloOpcode::kFusion) { - // non-fusion -> fusion. 
- return [&](const HloInstruction& producer, const HloInstruction& consumer) { - if (DefaultFusionBoundaryFn(producer, consumer)) { - return &producer != &fused_producer; + + bool ContainsInstruction(HloInstructionAdaptor instruction) const override { + return instruction == instruction_; + } + + absl::InlinedVector GetRoots() const override { + return {instruction_}; + } + + private: + HloInstructionAdaptor instruction_; +}; + +class HloComputationFusion : public HloFusionAdaptor { + public: + explicit HloComputationFusion(const HloComputation* computation) + : computation_(computation) { + std::function get_roots; + absl::flat_hash_set roots_set; + get_roots = [&](const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kTuple) { + for (const auto* operand : instr->operands()) { + get_roots(operand); + } + } else { + HloInstructionAdaptor wrapped{*instr}; + if (roots_set.insert(wrapped).second) { + roots_.push_back(wrapped); + } } - // Otherwise, don't follow edges above the fused producer. - return &consumer == &fused_producer; }; + get_roots(computation->root_instruction()); } - if (fused_producer.opcode() == HloOpcode::kFusion) { - // fusion -> non-fusion. - return [&](const HloInstruction& producer, const HloInstruction& consumer) { - if (&consumer == &fused_consumer) { - // If the consumer is the fused user, only follow edges to the fused - // producer. - return &fused_producer != &producer; - } - if (&producer == &fused_consumer) { - return true; - } - // Otherwise, fall back to the default; we're already in the fused - // producer. - return DefaultFusionBoundaryFn(producer, consumer); - }; + bool ContainsInstruction(HloInstructionAdaptor instruction) const override { + return instruction.instruction().parent() == computation_; } - // non-fusion -> non-fusion. - return [&](const HloInstruction& producer, const HloInstruction& consumer) { - if (&consumer == &fused_consumer) { - // If the consumer is the fused user, only follow edges to the fused - // producer. - return &fused_producer != &producer; + absl::InlinedVector GetRoots() const override { + return roots_; + } + + private: + const HloComputation* computation_; + absl::InlinedVector roots_; +}; + +} // namespace + +std::unique_ptr HloFusionAdaptor::ForInstruction( + const HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kFusion) { + return ForComputation(instruction->fused_instructions_computation()); + } + return std::make_unique(instruction); +} + +std::unique_ptr HloFusionAdaptor::ForComputation( + const HloComputation* computation) { + return std::make_unique(computation); +} + +absl::InlinedVector +HloInstructionAdaptor::GetOperands() const { + absl::InlinedVector operands; + if (instruction_->opcode() == HloOpcode::kParameter) { + // The only time this should happen is when a fusion has a parameter + // that is also a root. This probably never makes sense, but it technically + // is valid HLO, so we support it by treating the parameter as an identity + // function in this context. 
+ auto operand = ResolveOperand(instruction_); + if (operand != instruction_) { + operands.emplace_back(*operand); } - return &producer == &fused_consumer || &consumer == &fused_producer; - }; + } else { + for (const auto* operand : instruction_->operands()) { + operands.emplace_back(*ResolveOperand(operand)); + } + } + return operands; } -FusionBoundaryFn MakeSingleInstructionFusion(const HloInstruction& root) { - if (root.opcode() == HloOpcode::kFusion) { - return DefaultFusionBoundaryFn; +HloInstructionAdaptor HloInstructionAdaptor::GetOperand(int index) const { + return HloInstructionAdaptor{*ResolveOperand(instruction_->operand(index))}; +} + +absl::InlinedVector HloInstructionAdaptor::GetUsers() + const { + absl::InlinedVector users; + auto add_user = [&](const HloInstruction* instr) { + users.emplace_back(*instr); + }; + + if (instruction_->IsRoot()) { + if (auto* fusion = instruction_->parent()->FusionInstruction()) { + for (auto* user : fusion->users()) { + ResolveUsers(fusion, user, add_user); + } + } + } + + for (auto* user : instruction_->users()) { + ResolveUsers(instruction_, user, add_user); } - return [](const HloInstruction&, const HloInstruction&) { return true; }; + + return users; +} + +bool operator==(const HloInstructionAdaptor& lhs, + const HloInstructionAdaptor& rhs) { + return lhs.instruction_->GetModule() == rhs.instruction_->GetModule() && + lhs.instruction_->unique_id() == rhs.instruction_->unique_id(); } void HloBfsConsumersFirstTraversal( - absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit) { - absl::flat_hash_set visited; - std::queue q; - auto enqueue_operands = [&](const HloInstruction& node) { - for (const auto* predecessor : FindPredecessors(node, boundary)) { - if (visited.insert(predecessor).second) { - q.push(predecessor); + absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit, + const std::function& visit_arg) { + absl::flat_hash_set visited; + std::queue q; + auto enqueue_operands = [&](const HloInstructionAdaptor& node) { + for (auto operand : node.GetOperands()) { + if (visited.insert(operand).second) { + if (fusion.ContainsInstruction(operand)) { + q.push(operand); + } else { + visit_arg(operand); + } } } }; - - for (auto* root : roots) { + for (auto root : roots) { q.push(root); } while (!q.empty()) { - const HloInstruction* node = q.front(); + HloInstructionAdaptor node = q.front(); q.pop(); - switch (visit(*node)) { + switch (visit(node)) { case TraversalResult::kVisitOperands: - enqueue_operands(*node); + enqueue_operands(node); break; case TraversalResult::kAbortTraversal: return; @@ -121,71 +212,33 @@ void HloBfsConsumersFirstTraversal( } void FindFusionArguments( - absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit) { - absl::flat_hash_set visited; + const HloFusionAdaptor& fusion, + const std::function& visit) { HloBfsConsumersFirstTraversal( - roots, - [&](const HloInstruction& producer, const HloInstruction& consumer) { - auto is_boundary = boundary(producer, consumer); - if (is_boundary) { - if (visited.insert(&producer).second) { - visit(producer); - } - } - return is_boundary; - }, - [&](const HloInstruction&) { return TraversalResult::kVisitOperands; }); + fusion.GetRoots(), fusion, + [&](HloInstructionAdaptor) { return TraversalResult::kVisitOperands; }, + visit); } -bool HloAnyOf(absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit) { - return HloFindIf(roots, boundary, visit) != nullptr; -} - -const 
HloInstruction* HloFindIf( - absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit) { - const HloInstruction* result = nullptr; - HloBfsConsumersFirstTraversal(roots, boundary, - [&](const HloInstruction& node) { - if (visit(node)) { - result = &node; - return TraversalResult::kAbortTraversal; - } - return TraversalResult::kVisitOperands; - }); - return result; +bool HloAnyOf(absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit) { + return HloFindIf(roots, fusion, visit).has_value(); } -absl::InlinedVector FindPredecessors( - const HloInstruction& node, const FusionBoundaryFn& boundary) { - absl::InlinedVector predecessors; - auto visit = [&](const HloInstruction& predecessor) { - if (!boundary(predecessor, node)) { - predecessors.push_back(&predecessor); +std::optional HloFindIf( + absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit) { + std::optional result = std::nullopt; + HloBfsConsumersFirstTraversal(roots, fusion, [&](HloInstructionAdaptor node) { + if (visit(node)) { + result = node; + return TraversalResult::kAbortTraversal; } - }; - - switch (node.opcode()) { - case HloOpcode::kParameter: - if (auto* fusion = node.parent()->FusionInstruction()) { - // If the parent is the entry computation, there's no predecessor. - visit(*fusion->operand(node.parameter_number())); - } - break; - case HloOpcode::kFusion: - visit(*node.fused_expression_root()); - break; - default: - for (HloInstruction* operand : node.operands()) { - visit(*operand); - } - } - return predecessors; + return TraversalResult::kVisitOperands; + }); + return result; } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.h b/third_party/xla/xla/service/gpu/hlo_traversal.h index 97a71ff39f8901..9a4f29d621f922 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.h +++ b/third_party/xla/xla/service/gpu/hlo_traversal.h @@ -20,10 +20,77 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" namespace xla { namespace gpu { +// Treats HloInstructions as if they were unfused. +class HloInstructionAdaptor { + public: + HloInstructionAdaptor() = default; + explicit HloInstructionAdaptor(const HloInstruction& instruction) + : instruction_(&instruction) {} + + HloOpcode opcode() const { return instruction_->opcode(); } + absl::string_view name() const { return instruction_->name(); } + + HloInstructionAdaptor GetOperand(int index) const; + absl::InlinedVector GetOperands() const; + absl::InlinedVector GetUsers() const; + const xla::Shape& shape() const { return instruction_->shape(); } + std::string ToString() const { return instruction_->ToString(); } + + friend bool operator==(const HloInstructionAdaptor& lhs, + const HloInstructionAdaptor& rhs); + template + friend H AbslHashValue(H h, const HloInstructionAdaptor& m); + + // Use sparingly; prefer extending the interface. 
+ const HloInstruction& instruction() const { return *instruction_; } + + private: + const HloInstruction* instruction_; +}; + +template +H AbslHashValue(H h, const HloInstructionAdaptor& m) { + return H::combine(std::move(h), m.instruction_->GetModule(), + m.instruction_->unique_id()); +} + +class HloFusionAdaptor { + public: + virtual ~HloFusionAdaptor() = default; + virtual bool ContainsInstruction(HloInstructionAdaptor instruction) const = 0; + virtual absl::InlinedVector GetRoots() const = 0; + + static std::unique_ptr ForInstruction( + const HloInstruction* instruction); + static std::unique_ptr ForComputation( + const HloComputation* computation); +}; + +class ProducerConsumerFusion : public HloFusionAdaptor { + public: + ProducerConsumerFusion(std::unique_ptr producer, + std::unique_ptr consumer) + : producer_(std::move(producer)), consumer_(std::move(consumer)) {} + + bool ContainsInstruction(HloInstructionAdaptor instruction) const override { + return producer_->ContainsInstruction(instruction) || + consumer_->ContainsInstruction(instruction); + } + + absl::InlinedVector GetRoots() const override { + return consumer_->GetRoots(); + } + + private: + std::unique_ptr producer_; + std::unique_ptr consumer_; +}; + enum class TraversalResult { // Visit the operands of this node. kVisitOperands, @@ -35,56 +102,35 @@ enum class TraversalResult { kDoNotVisitOperands, }; -using FusionBoundaryFn = std::function; - -// Boundary function for HloFusionInstructions. -bool DefaultFusionBoundaryFn(const HloInstruction& producer, - const HloInstruction& consumer); - -// Creates a fusion boundary function for fusing the given producer and -// consumer. `fused_consumer` must be a consumer of `fused_producer`. -FusionBoundaryFn MakeProducerConsumerFusion( - const HloInstruction& fused_producer, const HloInstruction& fused_consumer); - -// Creates a fusion boundary function for a fusion consisting only of `root`. If -// `root` is a fusion, the result is the same as `DefaultFusionBuondaryFn`. If -// `root` is the root of a fusion, the result is just that root, not the entire -// computation. -FusionBoundaryFn MakeSingleInstructionFusion(const HloInstruction& root); - // Visit the HLO nodes starting from `roots` in BFS order (consumers before -// producers). Each node will be visited exactly once. The graph is not -// traversed along edges for which `boundary` returns true. +// producers). Each node will be visited exactly once. void HloBfsConsumersFirstTraversal( - absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit); + absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& + visit_node, + const std::function& visit_arg = + [](HloInstructionAdaptor) {}); // Visit the HLO nodes starting from `roots`, returning true if the return value // of `visit` for any of nodes is true. Uses the same order as // `HloBfsConsumersFirstTraversal`. -bool HloAnyOf(absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit); +bool HloAnyOf(absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit); // Visit the HLO nodes stating from `roots`, returning the first // node for which `visit` returns true, or `nullptr` if no node matches. Uses // the same order as `HloBfsConsumersFirstTraversal`. 
-const HloInstruction* HloFindIf( - absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit); +std::optional HloFindIf( + absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit); // Visit the producers of all parameters that are needed by the fusion. void FindFusionArguments( - absl::Span roots, - const FusionBoundaryFn& boundary, - const std::function& visit); - -// Returns all predecessors of node that lie within the boundary. -absl::InlinedVector FindPredecessors( - const HloInstruction& node, const FusionBoundaryFn& boundary); + const HloFusionAdaptor& fusion, + const std::function& visit); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/hlo_traversal_test.cc b/third_party/xla/xla/service/gpu/hlo_traversal_test.cc index ca29d9b06906ef..d6f18c17a30f55 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/gpu_fusible.h" #include "xla/tests/hlo_test_base.h" namespace xla { @@ -58,153 +57,133 @@ const char kTestModule[] = R"( ROOT difference = f32[] subtract(fusion, p0) })"; -TEST_F(HloTraversalTest, TraverseFusion) { +TEST_F(HloTraversalTest, AdaptorOperands) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); - std::vector visited_nodes; - HloBfsConsumersFirstTraversal( - {module->GetComputationWithName("fused_computation")->root_instruction()}, - DefaultFusionBoundaryFn, [&](const HloInstruction& node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; - }); - EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul", "p0.1", "p1.1")); + HloInstructionAdaptor instr{ + *module->entry_computation()->GetInstructionWithName("difference")}; + + auto operands = instr.GetOperands(); + ASSERT_EQ(operands.size(), 2); + EXPECT_EQ(operands[0].name(), "reduce.1"); + EXPECT_EQ(operands[1].name(), "p0"); } -TEST_F(HloTraversalTest, TraverseFusionPartially) { - auto module = ParseAndReturnVerifiedModule(kTestModule).value(); - std::vector visited_nodes; - HloBfsConsumersFirstTraversal( - {module->GetComputationWithName("fused_computation")->root_instruction()}, - DefaultFusionBoundaryFn, [&](const HloInstruction& node) { - visited_nodes.emplace_back(node.name()); - return node.opcode() == HloOpcode::kReduce - ? 
TraversalResult::kVisitOperands - : TraversalResult::kDoNotVisitOperands; - }); +TEST_F(HloTraversalTest, AdaptorUsers) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test - EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul", "p0.1")); + fused_computation { + p0 = f32[] parameter(0) + neg = f32[] negate(p0) + add = f32[] add(p0, neg) + ROOT t = (f32[], f32[]) tuple(neg, add) + } + + ENTRY entry { + p0 = f32[] parameter(0) + fusion = (f32[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation + ROOT gte = f32[] get-tuple-element(fusion), index=0 + } + )") + .value(); + + auto get_single_user = [](auto instr) { + auto users = instr.GetUsers(); + EXPECT_EQ(users.size(), 1); + return users[0]; + }; + + HloInstructionAdaptor add{*module->GetComputationWithName("fused_computation") + ->GetInstructionWithName("add")}; + EXPECT_EQ(get_single_user(add).name(), "t"); + EXPECT_EQ(get_single_user(get_single_user(add)).name(), "gte"); } -TEST_F(HloTraversalTest, AbortTraversal) { +TEST_F(HloTraversalTest, TraverseFusion) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); std::vector visited_nodes; + std::vector visited_args; + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); HloBfsConsumersFirstTraversal( - {module->GetComputationWithName("fused_computation")->root_instruction()}, - DefaultFusionBoundaryFn, [&](const HloInstruction& node) { + fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { visited_nodes.emplace_back(node.name()); - return node.opcode() == HloOpcode::kReduce - ? TraversalResult::kVisitOperands - : TraversalResult::kAbortTraversal; + return TraversalResult::kVisitOperands; + }, + [&](HloInstructionAdaptor arg) { + visited_args.emplace_back(arg.name()); }); EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); + EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); } -TEST_F(HloTraversalTest, TraversePartialFusion) { - // Verifies that we correctly traverse the fusion that would result if we - // fused the negation into fused_computation. +TEST_F(HloTraversalTest, AbortTraversal) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); std::vector visited_nodes; + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return node.opcode() == HloOpcode::kReduce + ? 
TraversalResult::kVisitOperands + : TraversalResult::kAbortTraversal; + }); - auto* fused_computation = module->GetComputationWithName("fused_computation"); - HloBfsConsumersFirstTraversal( - {fused_computation->root_instruction()}, - [&](const HloInstruction& producer, const HloInstruction& consumer) { - return &consumer == fused_computation->parameter_instruction(0) || - consumer.opcode() == HloOpcode::kNegate; - }, - [&](const HloInstruction& node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; - }); - - EXPECT_THAT(visited_nodes, - ElementsAre("reduce.1", "mul", "p0.1", "p1.1", "negate")); + EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); } TEST_F(HloTraversalTest, FindArguments) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); std::vector producers; - FindFusionArguments( - {module->GetComputationWithName("fused_computation")->root_instruction()}, - DefaultFusionBoundaryFn, [&](const HloInstruction& producer) { - producers.emplace_back(producer.name()); - }); + FindFusionArguments(*fusion, [&](HloInstructionAdaptor producer) { + producers.emplace_back(producer.name()); + }); EXPECT_THAT(producers, ElementsAre("p0", "negate")); } TEST_F(HloTraversalTest, FindArgumentsAfterFusion) { // Verifies that we correctly find the arguments after fusing the negation. auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto producer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("negate")); + auto consumer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); std::vector producers; - auto* fused_computation = module->GetComputationWithName("fused_computation"); FindFusionArguments( - {fused_computation->root_instruction()}, - [&](const HloInstruction& producer, const HloInstruction& consumer) { - return &consumer == fused_computation->parameter_instruction(0) || - consumer.opcode() == HloOpcode::kNegate; - }, - [&](const HloInstruction& producer) { + ProducerConsumerFusion(std::move(producer), std::move(consumer)), + [&](HloInstructionAdaptor producer) { producers.emplace_back(producer.name()); }); EXPECT_THAT(producers, ElementsAre("p0", "log")); } -TEST_F(HloTraversalTest, FuseEverything) { - auto module = ParseAndReturnVerifiedModule(kTestModule).value(); - std::vector producers; - auto* fused_computation = module->GetComputationWithName("fused_computation"); - FindFusionArguments( - {fused_computation->root_instruction()}, - [&](const HloInstruction& producer, const HloInstruction& consumer) { - return producer.opcode() == HloOpcode::kParameter && - producer.parent()->IsEntryComputation(); - }, - [&](const HloInstruction& producer) { - producers.emplace_back(producer.name()); - }); - EXPECT_THAT(producers, ElementsAre("p0", "p1")); -} - -TEST_F(HloTraversalTest, FuseConsumer) { - auto module = ParseAndReturnVerifiedModule(kTestModule).value(); - std::vector visited_nodes; - HloBfsConsumersFirstTraversal( - {module->entry_computation()->root_instruction()}, - [](const HloInstruction& producer, const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kParameter || - (producer.opcode() == HloOpcode::kParameter && - consumer.opcode() == HloOpcode::kSubtract); - }, - [&](const HloInstruction& node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; - }); - 
EXPECT_THAT(visited_nodes, ElementsAre("difference", "fusion", "reduce.1", - "mul", "p0.1", "p1.1")); -} - TEST_F(HloTraversalTest, FindIf) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); std::vector visited_nodes; - auto* result = HloFindIf( - {module->GetComputationWithName("fused_computation")->root_instruction()}, - DefaultFusionBoundaryFn, [&](const HloInstruction& node) { + auto result = + HloFindIf(fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { return node.opcode() == HloOpcode::kMultiply; }); - ASSERT_NE(result, nullptr); + ASSERT_NE(result, std::nullopt); ASSERT_EQ(result->name(), "mul"); } TEST_F(HloTraversalTest, NotFound) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); std::vector visited_nodes; - auto* result = HloFindIf( - {module->GetComputationWithName("fused_computation")->root_instruction()}, - DefaultFusionBoundaryFn, - [&](const HloInstruction& node) { return false; }); - ASSERT_EQ(result, nullptr); + auto result = HloFindIf(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { return false; }); + ASSERT_EQ(result, std::nullopt); } const char kTwoFusions[] = R"( @@ -241,87 +220,85 @@ const char kTwoFusions[] = R"( TEST_F(HloTraversalTest, FuseFusionConsumer) { auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); - auto* producer = - module->entry_computation()->GetInstructionWithName("negate"); - auto* consumer = - module->entry_computation()->GetInstructionWithName("fusion.1"); - auto roots = GetFusionRoots(*consumer->fused_instructions_computation()); - auto boundary = MakeProducerConsumerFusion(*producer, *consumer); + auto producer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("negate")); + auto consumer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion.1")); + ProducerConsumerFusion fusion(std::move(producer), std::move(consumer)); + std::vector nodes; - HloBfsConsumersFirstTraversal(roots, boundary, - [&](const HloInstruction& node) { - nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; - }); std::vector params; - FindFusionArguments(roots, boundary, [&](const HloInstruction& param) { - params.emplace_back(param.name()); - }); + HloBfsConsumersFirstTraversal( + fusion.GetRoots(), fusion, + [&](HloInstructionAdaptor node) { + nodes.emplace_back(node.name()); + return TraversalResult::kVisitOperands; + }, + [&](HloInstructionAdaptor param) { params.emplace_back(param.name()); }); - EXPECT_THAT(nodes, ElementsAre("reduce.1", "mul", "p0.1", "p1.1", "negate")); + EXPECT_THAT(nodes, ElementsAre("reduce.1", "mul", "negate")); EXPECT_THAT(params, ElementsAre("p0", "sum")); } TEST_F(HloTraversalTest, FuseFusionProducer) { auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); - auto* producer = - module->entry_computation()->GetInstructionWithName("fusion.2"); - auto* consumer = - module->entry_computation()->GetInstructionWithName("difference"); - auto boundary = MakeProducerConsumerFusion(*producer, *consumer); + auto producer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion.2")); + auto consumer = HloFusionAdaptor::ForInstruction( + 
module->entry_computation()->GetInstructionWithName("difference")); + ProducerConsumerFusion fusion(std::move(producer), std::move(consumer)); + std::vector nodes; - HloBfsConsumersFirstTraversal({consumer}, boundary, - [&](const HloInstruction& node) { - nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; - }); std::vector params; - FindFusionArguments({consumer}, boundary, [&](const HloInstruction& param) { - params.emplace_back(param.name()); - }); + HloBfsConsumersFirstTraversal( + fusion.GetRoots(), fusion, + [&](HloInstructionAdaptor node) { + nodes.emplace_back(node.name()); + return TraversalResult::kVisitOperands; + }, + [&](HloInstructionAdaptor arg) { params.emplace_back(arg.name()); }); - EXPECT_THAT( - nodes, ElementsAre("difference", "fusion.2", "reduce.2", "p1.2", "p0.2")); - EXPECT_THAT(params, ElementsAre("p0", "negate", "fusion.1")); + EXPECT_THAT(nodes, ElementsAre("difference", "reduce.2")); + EXPECT_THAT(params, ElementsAre("p0", "negate", "reduce.1")); } TEST_F(HloTraversalTest, FuseFusionConsumerAndProducer) { auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); - auto* producer = - module->entry_computation()->GetInstructionWithName("fusion.1"); - auto* consumer = - module->entry_computation()->GetInstructionWithName("fusion.2"); + auto producer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion.1")); + auto consumer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion.2")); + ProducerConsumerFusion fusion(std::move(producer), std::move(consumer)); - auto roots = GetFusionRoots(*consumer->fused_instructions_computation()); - auto boundary = MakeProducerConsumerFusion(*producer, *consumer); std::vector nodes; - HloBfsConsumersFirstTraversal(roots, boundary, - [&](const HloInstruction& node) { + HloBfsConsumersFirstTraversal(fusion.GetRoots(), fusion, + [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); return TraversalResult::kVisitOperands; }); std::vector params; - FindFusionArguments(roots, boundary, [&](const HloInstruction& param) { + FindFusionArguments(fusion, [&](const HloInstructionAdaptor& param) { params.emplace_back(param.name()); }); - EXPECT_THAT(nodes, ElementsAre("reduce.2", "p1.2", "p0.2", "fusion.1", - "reduce.1", "mul", "p0.1", "p1.1")); + EXPECT_THAT(nodes, ElementsAre("reduce.2", "reduce.1", "mul")); EXPECT_THAT(params, ElementsAre("negate", "p0")); } TEST_F(HloTraversalTest, FuseNonFusionConsumerAndProducer) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); - auto* producer = module->entry_computation()->GetInstructionWithName("log"); - auto* consumer = - module->entry_computation()->GetInstructionWithName("negate"); - auto boundary = MakeProducerConsumerFusion(*producer, *consumer); + auto producer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("log")); + auto consumer = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("negate")); + ProducerConsumerFusion fusion(std::move(producer), std::move(consumer)); + std::vector nodes; - HloBfsConsumersFirstTraversal({consumer}, boundary, - [&](const HloInstruction& node) { + HloBfsConsumersFirstTraversal(fusion.GetRoots(), fusion, + [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); return TraversalResult::kVisitOperands; }); @@ -331,29 +308,27 @@ TEST_F(HloTraversalTest, FuseNonFusionConsumerAndProducer) { TEST_F(HloTraversalTest, 
SingleInstructionFusionOfFusion) { auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); - auto* fusion = - module->entry_computation()->GetInstructionWithName("fusion.1"); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion.1")); - auto boundary = MakeSingleInstructionFusion(*fusion); std::vector nodes; - HloBfsConsumersFirstTraversal({fusion}, boundary, - [&](const HloInstruction& node) { + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); return TraversalResult::kVisitOperands; }); - EXPECT_THAT(nodes, - ElementsAre("fusion.1", "reduce.1", "mul", "p0.1", "p1.1")); + EXPECT_THAT(nodes, ElementsAre("reduce.1", "mul")); } TEST_F(HloTraversalTest, SingleInstructionFusionOfInstruction) { auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); - auto* negate = module->entry_computation()->GetInstructionWithName("negate"); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("negate")); - auto boundary = MakeSingleInstructionFusion(*negate); std::vector nodes; - HloBfsConsumersFirstTraversal({negate}, boundary, - [&](const HloInstruction& node) { + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); return TraversalResult::kVisitOperands; }); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 3fc8f0c302a481..be7a9438d2f3c2 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -765,7 +765,7 @@ std::optional GetDescriptionForTiledTransposeEmitter( } bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count, - FusionBoundaryFn boundary) { + const HloFusionAdaptor* fusion) { // Number of operands should be in range [1, allowed_operand_count]. if (instr->operand_count() == 0 || instr->operand_count() > allowed_operand_count) { @@ -774,14 +774,13 @@ bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count, // Intermediate `instr` can't have multiple users. // If we have a boundary function, only consider users within the - // boundary. This isn't really correct, since the real users aren't - // necessarily the instruction's users at this point. + // boundary. // TODO(jreiffers): Figure out the point of this check. int64_t num_users = - boundary ? absl::c_count_if( - instr->users(), - [&](const auto* user) { return !boundary(*instr, *user); }) - : instr->user_count(); + fusion ? absl::c_count_if( + HloInstructionAdaptor{*instr}.GetUsers(), + [&](auto user) { return fusion->ContainsInstruction(user); }) + : instr->user_count(); if (num_users > 1) { return false; } @@ -815,57 +814,59 @@ static bool IsParameter(const HloInstruction& instr) { return instr.opcode() == HloOpcode::kParameter; } -const HloInstruction& FindNonTrivialHero( - const HloInstruction& instr, - const std::function& is_boundary) { - const HloInstruction* idx = &instr; +const HloInstruction& FindNonTrivialHero(const HloInstruction& instr, + const HloFusionAdaptor& fusion) { + HloInstructionAdaptor idx{instr}; // Go up the chain of trivial element-wise(+bitcast, -copy) operations. Such // chains are bound to be quite small, as we restrict the number of users as // well. Note that no memoization is needed due to user number constraints: we // never have to revisit same nodes. 
- auto get_intermediate_arg = [&](const HloInstruction* node) { - if (node->opcode() == HloOpcode::kFusion || - node->opcode() == HloOpcode::kParameter) { - auto preds = FindPredecessors(*node, is_boundary); - return preds.size() == 1 ? preds.front() : nullptr; + auto get_intermediate_arg = + [&](HloInstructionAdaptor node) -> std::optional { + if (IsIntermediate(&node.instruction(), 1, &fusion) && + fusion.ContainsInstruction(node.GetOperand(0))) { + return node.GetOperand(0); } - return IsIntermediate(node, 1, is_boundary) && - !is_boundary(*node->operand(0), *node) - ? node->operand(0) - : nullptr; + return std::nullopt; }; - while (auto* arg = get_intermediate_arg(idx)) { - idx = arg; + while (auto arg = get_intermediate_arg(idx)) { + idx = *arg; + } + + // The reduction emitter can't handle multiple users. + if (idx.opcode() == HloOpcode::kReduce && + absl::c_count_if(idx.GetUsers(), [&](auto user) { + return fusion.ContainsInstruction(user); + }) > 1) { + return instr; } - const HloInstruction* transpose = nullptr; + std::optional transpose = std::nullopt; // Try a bit harder to find a transpose hero. The shared memory transpose // emitter also works if there are ops with more than 1 operand on the path // between root and the transpose op, we still want the restriction though // that each op on the path is elementwise and has only 1 user. - auto visit = [&transpose](const HloInstruction& node) { - if (FindTiledLogicalTranspose(node)) { + auto visit = [&transpose](HloInstructionAdaptor node) { + if (FindTiledLogicalTranspose(node.instruction())) { // If we do not find a unique transpose op, use the original non-trivial // hero. if (transpose) { - transpose = nullptr; + transpose = std::nullopt; return TraversalResult::kAbortTraversal; } - transpose = &node; + transpose = node; return TraversalResult::kDoNotVisitOperands; } - if (node.opcode() != HloOpcode::kParameter && - node.opcode() != HloOpcode::kFusion && - !IsIntermediate(&node, /*allowed_operand_count=*/3)) { + if (!IsIntermediate(&node.instruction(), /*allowed_operand_count=*/3)) { return TraversalResult::kDoNotVisitOperands; } return TraversalResult::kVisitOperands; }; - HloBfsConsumersFirstTraversal({idx}, is_boundary, visit); - return transpose ? *transpose : *idx; + HloBfsConsumersFirstTraversal({idx}, fusion, visit); + + return transpose ? transpose->instruction() : idx.instruction(); } const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) { @@ -873,10 +874,9 @@ const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) { // happens. Return the fusion itself for historical reasons. // TODO(jreiffers): Clean this up. if (instr.opcode() == HloOpcode::kFusion) return instr; - return FindNonTrivialHero(instr, [](const HloInstruction& producer, - const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kParameter; - }); + + return FindNonTrivialHero(instr, + *HloFusionAdaptor::ForComputation(instr.parent())); } void VLogModule(int level, const llvm::Module& module) { diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 172d405086f947..f3349762868c9c 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -156,12 +156,11 @@ Shape GetShape(mlir::Value value); // or vice versa. // Note: when this is called with a fusion instruction, it will traverse into // the fusion (unless the boundary function stops it). 
-const HloInstruction& FindNonTrivialHero( - const HloInstruction& instr, - const std::function& is_boundary); -// Like above, with the default boundary function. Additionally, this will not -// traverse into `instr`'s computation if it is a fusion. +const HloInstruction& FindNonTrivialHero(const HloInstruction& instr, + const HloFusionAdaptor& fusion); + +// Like above, but assumes the instruction is inside an HloFusionInstruction. +// Returns the instruction itself if it is an HloFusionInstruction. const HloInstruction& FindNonTrivialHero(const HloInstruction& instr); /// Description of how to emit a given transposition. @@ -206,8 +205,10 @@ std::optional FindTiledLogicalTranspose( std::optional GetDescriptionForTiledTransposeEmitter( const HloInstruction& root, const HloInstruction& hero); +// Checks if the instruction is elementwise and only has a single user. If +// a fusion adaptor is provided, only checks for users within the fusion. bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1, - FusionBoundaryFn boundary = nullptr); + const HloFusionAdaptor* fusion = nullptr); // Log the given module if the VLOG level is >= level. void VLogModule(int level, const llvm::Module& module); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 21a05ae26a757d..1b3e323a58db76 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -184,7 +184,42 @@ ENTRY entry { // emitter is fast for S8 output. EXPECT_FALSE( GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)).has_value()); - EXPECT_EQ(&FindNonTrivialHero(*r), r->operand(0)); + EXPECT_EQ(FindNonTrivialHero(*r).name(), "t"); +} + +TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusion) { + const char* hlo = R"( + HloModule module + + %add { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) + } + + %fused_computation (param_0.4: f32[128,64], param_1.4: bf16[]) -> bf16[64] { + %param_0 = f32[128,64]{1,0} parameter(0) + %param_1 = bf16[] parameter(1) + %convert.0 = f32[] convert(bf16[] %param_1) + %reduce.0 = f32[64]{0} reduce(f32[128,64]{1,0} %param_0, f32[] %convert.0), dimensions={0}, to_apply=%add + ROOT %convert.1 = bf16[64]{0} convert(f32[64]{0} %reduce.0) + } + + ENTRY %main { + %param_0 = f32[128,64]{1,0} parameter(0) + %param_1 = bf16[] parameter(1) + ROOT fusion = bf16[64]{0} fusion(%param_0, %param_1), kind=kInput, calls=fused_computation + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + HloInstruction* r = module->entry_computation()->root_instruction(); + auto fusion = HloFusionAdaptor::ForInstruction(r); + const auto& result = + FindNonTrivialHero(fusion->GetRoots()[0].instruction(), *fusion); + EXPECT_EQ(result.name(), "reduce.0"); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateBinaryOp) { @@ -280,46 +315,13 @@ ENTRY entry { HloInstruction* r = module->GetComputationWithName("f")->root_instruction(); HloInstruction* transpose = - module->entry_computation()->parameter_instruction(0)->users().front(); + module->entry_computation()->GetInstructionWithName("t"); + HloInstruction* fusion = + module->entry_computation()->GetInstructionWithName("fusion"); EXPECT_EQ( - &FindNonTrivialHero( - *r, - [](const HloInstruction& producer, const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kTranspose; - }), - transpose); -} 
- -TEST_F(IrEmissionUtilsTest, FindNonTrivialHeroThroughFusion) { - const char* hlo = R"( -HloModule module - -f { - p0 = f32[100,200,300]{2,1,0} parameter(0) - ROOT add = f32[100,200,300]{2,1,0} add(p0, p0) -} - -ENTRY entry { - p0 = f32[300,200,100]{2,1,0} parameter(0) - p1 = f32[100,200,300]{2,1,0} parameter(1) - t = f32[100,200,300]{2,1,0} transpose(p0), dimensions={2,1,0} - fusion = f32[100,200,300]{2,1,0} fusion(t), kind=kLoop, calls=f - ROOT add = f32[100,200,300]{2,1,0} add(p1, fusion) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - - HloInstruction* r = module->entry_computation()->root_instruction(); - HloInstruction* transpose = - module->entry_computation()->parameter_instruction(0)->users().front(); - EXPECT_EQ( - &FindNonTrivialHero( - *r, - [](const HloInstruction& producer, const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kTranspose; - }), + &FindNonTrivialHero(*r, ProducerConsumerFusion( + HloFusionAdaptor::ForInstruction(transpose), + HloFusionAdaptor::ForInstruction(fusion))), transpose); } @@ -349,50 +351,15 @@ ENTRY entry { ->parameter_instruction(0) ->users() .front(); + HloInstruction* fusion = + module->entry_computation()->GetInstructionWithName("fusion"); EXPECT_EQ( &FindNonTrivialHero( - *r, - [](const HloInstruction& producer, const HloInstruction& consumer) { - return consumer.opcode() == HloOpcode::kParameter; - }), + *r, ProducerConsumerFusion(HloFusionAdaptor::ForInstruction(fusion), + HloFusionAdaptor::ForInstruction(r))), transpose); } -TEST_F(IrEmissionUtilsTest, FindNonTrivialHeroSomeOperandsInFusion) { - const char* hlo = R"( -HloModule module - -ENTRY entry { - p0 = f32[300,200,100]{2,1,0} parameter(0) - p1 = f32[100,200,300]{2,1,0} parameter(1) - - transpose = f32[100,200,300]{2,1,0} transpose(p0), dimensions={2,1,0} - subtract = f32[100,200,300]{2,1,0} subtract(transpose, p1) - ROOT add = f32[100,200,300]{2,1,0} add(subtract, p1) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - - HloInstruction* r = module->entry_computation()->root_instruction(); - HloInstruction* transpose = - module->entry_computation()->parameter_instruction(0)->users().front(); - // The transpose is the hero if everything is on one fusion. - EXPECT_EQ(&FindNonTrivialHero( - *r, [](const HloInstruction& producer, - const HloInstruction& consumer) { return false; }), - transpose); - // The transpose isn't the hero if we cut the fusion at the subtraction. 
- EXPECT_EQ( - &FindNonTrivialHero( - *r, - [](const HloInstruction& producer, const HloInstruction& consumer) { - return producer.opcode() == HloOpcode::kSubtract; - }), - r); -} - TEST_F(IrEmissionUtilsTest, FindTiledTransposeOneSwapDimIsSmall) { const char* hlo = R"( HloModule module diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index acdbd80fb94cd8..ab8d1c8cf59b6c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -1281,17 +1281,17 @@ class MatMulEmitterHelper { } // namespace -LaunchDimensions GetMatMulLaunchDimensions( - const TritonFusionAnalysis& analysis, - absl::Span roots, - const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config) { - const auto* dot = static_cast( - HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) { - return node.opcode() == HloOpcode::kDot; - })); - CHECK_NE(dot, nullptr); - const MatMulDims dims(config, *dot, analysis); - const MatMulLaunchConfig launch_config(config, *dot, dims); +LaunchDimensions GetMatMulLaunchDimensions(const TritonFusionAnalysis& analysis, + const HloFusionAdaptor& fusion, + const TritonGemmConfig& config) { + auto dot = HloFindIf(fusion.GetRoots(), fusion, [](auto node) { + return node.opcode() == HloOpcode::kDot; + }); + CHECK(dot != std::nullopt); + const auto& dot_instr = + *static_cast(&dot->instruction()); + MatMulDims dims(config, dot_instr, analysis); + MatMulLaunchConfig launch_config(config, dot_instr, dims); return launch_config.launch_dims; } @@ -1560,15 +1560,13 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, return OkStatus(); } -LaunchDimensions GetSoftMaxLaunchDimensions( - absl::Span roots, - const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config) { - const HloInstruction* reduce = - HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) { - return node.opcode() == HloOpcode::kReduce; - }); - CHECK_NE(reduce, nullptr); - const Shape& reduce_input_shape = reduce->operand(0)->shape(); +LaunchDimensions GetSoftMaxLaunchDimensions(const HloFusionAdaptor& fusion, + const TritonGemmConfig& config) { + auto reduce = HloFindIf(fusion.GetRoots(), fusion, [](auto node) { + return node.opcode() == HloOpcode::kReduce; + }); + CHECK(reduce != std::nullopt); + const Shape& reduce_input_shape = reduce->instruction().operand(0)->shape(); int num_rows = 1; for (int minor_axis = 1; minor_axis < reduce_input_shape.rank(); ++minor_axis) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.h b/third_party/xla/xla/service/gpu/ir_emitter_triton.h index 52d66dd9c8c9a0..dbb7b160f06504 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.h @@ -42,10 +42,9 @@ struct TritonWrapperResult { }; // Compute the launch dimensions for the given Triton MatMul. -LaunchDimensions GetMatMulLaunchDimensions( - const TritonFusionAnalysis& analysis, - absl::Span roots, - const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config); +LaunchDimensions GetMatMulLaunchDimensions(const TritonFusionAnalysis& analysis, + const HloFusionAdaptor& fusion, + const TritonGemmConfig& config); // Use tiling and execution parameters from 'config'. 
Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, const TritonFusionAnalysis& analysis, @@ -53,9 +52,8 @@ Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, const TritonGemmConfig& config, int shmem_budget); // Compute the launch dimensions for the given Triton SoftMax. -LaunchDimensions GetSoftMaxLaunchDimensions( - absl::Span roots, - const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config); +LaunchDimensions GetSoftMaxLaunchDimensions(const HloFusionAdaptor& fusion, + const TritonGemmConfig& config); // Generate Softmax in Triton IR inside 'fn'. // Use execution parameters from 'config'. Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path, diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 4015b728ae60b7..cd3c05a63d369d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -2110,9 +2110,8 @@ StatusOr IrEmitterUnnested::EmitTritonFusion( ir_emitter_context_->cuda_compute_capability(), ir_emitter_context_->gpu_device_info(), config, module_, &EmitSoftMax, *ir_emitter_context_->mlir_context())); - launch_dimensions = GetSoftMaxLaunchDimensions( - hlo_fusion_analysis.fusion_roots(), - hlo_fusion_analysis.fusion_boundary(), config); + launch_dimensions = + GetSoftMaxLaunchDimensions(hlo_fusion_analysis.fusion(), config); } else { // Must be a MatMul CHECK_EQ(fusion_kind, kTritonGemmFusionKind); if (!backend_config.has_triton_gemm_config()) { @@ -2144,8 +2143,7 @@ StatusOr IrEmitterUnnested::EmitTritonFusion( ir_emitter_context_->gpu_device_info(), config, module_, &EmitMatMul, *ir_emitter_context_->mlir_context())); launch_dimensions = GetMatMulLaunchDimensions( - analysis, hlo_fusion_analysis.fusion_roots(), - hlo_fusion_analysis.fusion_boundary(), config); + analysis, hlo_fusion_analysis.fusion(), config); } llvm::Function* impl_fn = module_->getFunction(impl_fn_name); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 6350d127ad01a1..491170ff67df00 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -481,8 +481,8 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, // understand this case due to a lack of tiling analysis. // TODO(b/312200883): Remove this. auto contains_reduce = [&](const HloInstruction* instr) { - return HloAnyOf({instr}, MakeSingleInstructionFusion(*instr), - [](const HloInstruction& node) { + return HloAnyOf({HloInstructionAdaptor{*instr}}, + *HloFusionAdaptor::ForInstruction(instr), [](auto node) { return node.opcode() == HloOpcode::kReduce; }); }; From cbe5ea316354f78a3fab2fb0da1d3f62ae846a74 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 24 Nov 2023 07:37:36 -0800 Subject: [PATCH 064/381] Fix documentation of `tensorflow::test::TensorEq` Matchers need to be used with `EXPECT_THAT` not `EXPECT_EQ`. 
PiperOrigin-RevId: 585093103 --- tensorflow/core/framework/tensor_matcher.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/framework/tensor_matcher.h b/tensorflow/core/framework/tensor_matcher.h index 094d66f81f72f3..e89cfc15cd1f2a 100644 --- a/tensorflow/core/framework/tensor_matcher.h +++ b/tensorflow/core/framework/tensor_matcher.h @@ -34,7 +34,7 @@ namespace test { // // Use this like: // -// EXPECT_EQ(lhs, TensorEq(rhs)); +// EXPECT_THAT(lhs, TensorEq(rhs)); // // All POD types and DT_STRING type tensors are supported. Note that this // utility requires Tensors to point to CPU memory. From 8c0a58a22054ece262aa5e5968c3bc1fa808cc43 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 24 Nov 2023 07:52:50 -0800 Subject: [PATCH 065/381] [XLA:GPU][TileAnalysis] Add support for dot in tile analysis. This is another step towards tile analysis being able to tile all HLOs. Tiling dot requires some care to ensure that output dimensions are mapped to the appropriate dimensions. The [StableHLO specification for dot_general](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general) describes how output dimensions are constructed from the input dimensions and the operation's attributes. PiperOrigin-RevId: 585095167 --- .../xla/xla/service/gpu/matmul_utils.h | 1 + third_party/xla/xla/service/gpu/model/BUILD | 4 + .../xla/service/gpu/model/tile_analysis.cc | 92 +++++++++++++++++++ .../service/gpu/model/tile_analysis_test.cc | 25 +++++ 4 files changed, 122 insertions(+) diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 07ba566ad128cb..4d329fea522a78 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -43,6 +43,7 @@ limitations under the License. namespace xla { namespace gpu { +// Ordered non-contracting dimensions for a dot instruction operand. StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims); diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 1a3189df9eb272..7ef915f9fcaf6f 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -220,10 +220,14 @@ cc_library( "//xla:shape_util", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/gpu:matmul_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 7005c0c89081a6..e36317645ce7fd 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -25,25 +25,31 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/matmul_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { namespace { +using llvm::SmallVector; using mlir::AffineExpr; using mlir::AffineExprKind; using mlir::AffineMap; @@ -220,6 +226,89 @@ StatusOr ComputeFusionOpIndexing( return fusion_indexing; } +StatusOr ComputeDotOpIndexing( + const HloDotInstruction* dot, MLIRContext* mlir_context) { + CHECK_NE(dot, nullptr); + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + absl::Span lhs_contracting_dims( + dim_numbers.lhs_contracting_dimensions()); + absl::Span rhs_contracting_dims = + dim_numbers.rhs_contracting_dimensions(); + + absl::Span lhs_batch_dims = dim_numbers.lhs_batch_dimensions(); + absl::Span rhs_batch_dims = dim_numbers.rhs_batch_dimensions(); + + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); + // According to the StableHLO specification, the dimensions of the output + // shape are ordered as follows: + // lhs_batch_dims | lhs_non_contracting_dims | rhs_non_contracting_dims + SmallVector lhs_exprs(lhs_shape.rank()); + SmallVector rhs_exprs(rhs_shape.rank()); + int64_t output_dim_id = 0; + + // lhs_batch_dims + for (auto [lhs_batch_dim, rhs_batch_dim] : + llvm::zip(lhs_batch_dims, rhs_batch_dims)) { + AffineExpr output_dim_expr = getAffineDimExpr(output_dim_id, mlir_context); + lhs_exprs[lhs_batch_dim] = output_dim_expr; + rhs_exprs[rhs_batch_dim] = output_dim_expr; + ++output_dim_id; + } + + // lhs_non_contracting_dims + TF_ASSIGN_OR_RETURN( + std::vector lhs_non_contracting_dims, + GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_contracting_dims)); + + for (int64_t lhs_non_contracting_dim : lhs_non_contracting_dims) { + lhs_exprs[lhs_non_contracting_dim] = + getAffineDimExpr(output_dim_id++, mlir_context); + } + + // rhs_non_contracting_dims + TF_ASSIGN_OR_RETURN( + std::vector rhs_non_contracting_dims, + GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_contracting_dims)); + + for (int64_t rhs_non_contracting_dim : rhs_non_contracting_dims) { + rhs_exprs[rhs_non_contracting_dim] = + getAffineDimExpr(output_dim_id++, mlir_context); + } + + int64_t input_dim_id = 0; + std::vector input_dim_sizes; + input_dim_sizes.reserve(lhs_contracting_dims.size()); + + for (auto [lhs_contracting_dim, rhs_contracting_dim] : + llvm::zip(lhs_contracting_dims, rhs_contracting_dims)) { + AffineExpr input_dim_expr = getAffineSymbolExpr(input_dim_id, mlir_context); + lhs_exprs[lhs_contracting_dim] = input_dim_expr; + rhs_exprs[rhs_contracting_dim] = input_dim_expr; + ++input_dim_id; + + // LHS and RHS contracting dimensions must match pairwise, and we therefore + // need only populate a single input_dim_sizes vector. 
+ input_dim_sizes.push_back(lhs_shape.dimensions(lhs_contracting_dim)); + } + + IndexingMap lhs_indexing_map{ + .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), + lhs_exprs, mlir_context), + .input_dims_sizes = input_dim_sizes}; + + IndexingMap rhs_indexing_map{ + .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), + rhs_exprs, mlir_context), + .input_dims_sizes = input_dim_sizes}; + + return HloInstructionIndexing{ + {HloOperandIndexing{.indexing_maps = {std::move(lhs_indexing_map)}, + .operand_id = 0}, + HloOperandIndexing{.indexing_maps = {std::move(rhs_indexing_map)}, + .operand_id = 1}}}; +} + StatusOr ComputeReduceOpIndexing( const HloReduceInstruction* reduce, int output_id, MLIRContext* mlir_context) { @@ -385,6 +474,9 @@ StatusOr ComputeInstructionIndexing( if (auto bcast = DynCast(instr)) { return ComputeBroadcastOpIndexing(bcast, mlir_context); } + if (auto dot = DynCast(instr)) { + return ComputeDotOpIndexing(dot, mlir_context); + } if (auto fusion = DynCast(instr)) { return ComputeFusionOpIndexing(fusion, output_id, mlir_context); } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 452f0424802d09..a93ecb75cf94c6 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -442,6 +442,31 @@ TEST_F(TileAnalysisTest, TransposeOp) { std::vector{}))))); } +TEST_F(TileAnalysisTest, DotOp) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing_or, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4, 38, 17, 11, 18, 10] parameter(0) + p1 = f32[17, 10, 16, 18, 22, 38] parameter(1) + ROOT dot = f32[10, 38, 4, 11, 16, 22] dot(p0, p1), + lhs_batch_dims={5,1}, rhs_batch_dims={1,5}, + lhs_contracting_dims={4,2}, rhs_contracting_dims={3,0} + } + )")); + EXPECT_THAT( + input_indexing_or.operand_indexing_maps, + ElementsAre( + MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " + "(d2, d1, s1, d3, s0, d0)", + std::vector{18, 17}))), + MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " + "(s1, d0, d4, s0, d5, d1)", + std::vector{18, 17}))))); +} + TEST_F(TileAnalysisTest, UnsupportedOp) { auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( HloModule m From f52cc58a7a0f3ffadda87fbb311cdd8cf702b28a Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 24 Nov 2023 08:19:10 -0800 Subject: [PATCH 066/381] Move vectorize_copy.cc and copy_removal.cc passes out of gml_st. They don't use gml_st dialect. 
PiperOrigin-RevId: 585099478 --- third_party/xla/xla/mlir_hlo/BUILD | 5 +++-- .../mlir_hlo/gml_st/transforms/CMakeLists.txt | 2 -- .../xla/mlir_hlo/gml_st/transforms/passes.h | 7 ------- .../xla/mlir_hlo/gml_st/transforms/passes.td | 21 ------------------- ...dead_copy.mlir => naive_copy_removal.mlir} | 0 .../{Dialect/gml_st => }/vectorize_copy.mlir | 2 +- .../xla/mlir_hlo/transforms/CMakeLists.txt | 2 ++ .../naive_copy_removal.cc} | 12 ++++++----- .../xla/xla/mlir_hlo/transforms/passes.h | 7 +++++++ .../xla/xla/mlir_hlo/transforms/passes.td | 12 +++++++++++ .../vectorize_copy.cc | 19 ++++++----------- .../service/cpu/hlo_xla_runtime_pipeline.cc | 6 ++---- 12 files changed, 40 insertions(+), 55 deletions(-) rename third_party/xla/xla/mlir_hlo/tests/{Dialect/gml_st/simplify_dead_copy.mlir => naive_copy_removal.mlir} (100%) rename third_party/xla/xla/mlir_hlo/tests/{Dialect/gml_st => }/vectorize_copy.mlir (97%) rename third_party/xla/xla/mlir_hlo/{gml_st/transforms/copy_removal/copy_removal.cc => transforms/naive_copy_removal.cc} (93%) rename third_party/xla/xla/mlir_hlo/{gml_st/transforms/vectorization => transforms}/vectorize_copy.cc (95%) diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index abbd1dafbb860d..a1c6c941e9c6bd 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -1289,11 +1289,13 @@ cc_library( "transforms/detensorize_scf_ops.cc", "transforms/generic_host_to_llvm.cc", "transforms/lower_index_cast_pass.cc", + "transforms/naive_copy_removal.cc", "transforms/propagate_static_shapes_to_kernel.cc", "transforms/test_hlo_transform_dialect_interpreter.cc", "transforms/tile_loops_pass.cc", "transforms/unbufferize_pass.cc", "transforms/unroll_loops.cc", + "transforms/vectorize_copy.cc", ], hdrs = [ "transforms/passes.h", @@ -1351,6 +1353,7 @@ cc_library( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefToLLVM", "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:PDLDialect", "@llvm-project//mlir:Pass", @@ -1554,7 +1557,6 @@ cc_library( "gml_st/transforms/collapse_shape/collapse_shape.cc", "gml_st/transforms/collect_stats/collect_stats.cc", "gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc", - "gml_st/transforms/copy_removal/copy_removal.cc", "gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc", "gml_st/transforms/cpu_tiling/fusion_outlining.cc", "gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc", @@ -1578,7 +1580,6 @@ cc_library( "gml_st/transforms/transforms.h", "gml_st/transforms/vectorization/lower_vectors.cc", "gml_st/transforms/vectorization/vectorization.cc", - "gml_st/transforms/vectorization/vectorize_copy.cc", "gml_st/transforms/vectorization/vectorize_for_cpu.cc", "gml_st/utils/linalg_utils.cc", "gml_st/utils/tensor_utils.cc", diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt index 379af7e56a1d9e..5a96a1a77827f0 100644 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt @@ -32,7 +32,6 @@ add_mlir_library(GmlStPasses collapse_shape/collapse_shape.cc collect_stats/collect_stats.cc compose_extract_insert_slice/compose_extract_insert_slice.cc - copy_removal/copy_removal.cc cpu_tiling/cpu_tiling_pipeline.cc cpu_tiling/fusion_outlining.cc cpu_tiling/fusion_planning_for_cpu.cc @@ -54,7 +53,6 @@ 
add_mlir_library(GmlStPasses tiling_softmax/tiling_softmax.cc vectorization/lower_vectors.cc vectorization/vectorization.cc - vectorization/vectorize_copy.cc vectorization/vectorize_for_cpu.cc DEPENDS diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h index 12b02debc705fa..b9562e722177d3 100644 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h @@ -61,13 +61,6 @@ createComposeExtractInsertSlicePass(); std::unique_ptr> createVectorizeForCPUPass( int64_t numElementsThreshold = 1024); -/// Pass to vectorize `memref.copy`. -std::unique_ptr> createVectorizeCopyPass( - int64_t numElementsThreshold = 8); - -/// Pass to remove redundant `memref.copy` ops. -std::unique_ptr> createNaiveCopyRemovalPass(); - /// Pass to gradually lower vector ops to SCF. std::unique_ptr> createLowerVectorsPass( bool enableAVX2 = true, bool flatten = false); diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td b/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td index ce462a3adcd850..cb41946b07eb80 100644 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td @@ -70,27 +70,6 @@ def VectorizeForCPUPass : Pass<"vectorize-for-cpu", "mlir::func::FuncOp"> { ]; } -def VectorizeCopyPass : Pass<"vectorize-copy", "mlir::func::FuncOp"> { - let summary = "Pass to vectorize `memref.copy`."; - let constructor = "::mlir::gml_st::createVectorizeCopyPass()"; - let dependentDialects = [ - "scf::SCFDialect", - "vector::VectorDialect", - ]; - let options = [ - Option<"numElementsThreshold", "num-elements-threshold", "int64_t", - /*default=*/"8", - "Max number of elements in src and dst memref for a copy to be " - "vectorized.">, - ]; -} - -def NaiveCopyRemovalPass : Pass<"naive-copy-removal", "mlir::func::FuncOp"> { - let summary = "Pass to remove redundant `memref.copy` ops."; - let constructor = "::mlir::gml_st::createNaiveCopyRemovalPass()"; - let dependentDialects = ["::mlir::memref::MemRefDialect"]; -} - def LowerVectorsPass : Pass<"lower-vectors", "mlir::func::FuncOp"> { let summary = "Pass to lower vector operations progressively."; let constructor = "::mlir::gml_st::createLowerVectorsPass()"; diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/simplify_dead_copy.mlir b/third_party/xla/xla/mlir_hlo/tests/naive_copy_removal.mlir similarity index 100% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/simplify_dead_copy.mlir rename to third_party/xla/xla/mlir_hlo/tests/naive_copy_removal.mlir diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir b/third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir similarity index 97% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir rename to third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir index b2dff59e53cbce..8c57281a7041c9 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/vectorize_copy.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --vectorize-copy="num-elements-threshold=8" --split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s --vectorize-copy --split-input-file | FileCheck %s func.func @vectorize_copy(%arg: memref<2x2xf32>) -> memref<2x2xf32> { %subview = memref.subview %arg[0, 0] [2, 2] [1, 1] : memref<2x2xf32> to memref<2x2xf32, strided<[16, 1]>> diff --git 
a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt index 4d61e07d8e6cb1..5187d9707ea74e 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt @@ -31,9 +31,11 @@ add_mlir_library(MLIRBufferTransforms detensorize_scf_ops.cc generic_host_to_llvm.cc lower_index_cast_pass.cc + naive_copy_removal.cc propagate_static_shapes_to_kernel.cc test_hlo_transform_dialect_interpreter.cc tile_loops_pass.cc + vectorize_copy.cc unbufferize_pass.cc unroll_loops.cc diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/copy_removal/copy_removal.cc b/third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc similarity index 93% rename from third_party/xla/xla/mlir_hlo/gml_st/transforms/copy_removal/copy_removal.cc rename to third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc index a5ebefdef58b1f..ddd6e6916971f6 100644 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/copy_removal/copy_removal.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,18 @@ limitations under the License. #include #include -#include "gml_st/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "transforms/passes.h" -namespace mlir::gml_st { +namespace mlir { namespace { #define GEN_PASS_DEF_NAIVECOPYREMOVALPASS -#include "gml_st/transforms/passes.h.inc" +#include "transforms/passes.h.inc" /// Remove memref::CopyOp whose target (can be either a memref::SubViewOp or /// memref::AllocOp) has no other users. @@ -88,4 +90,4 @@ std::unique_ptr> createNaiveCopyRemovalPass() { return std::make_unique(); } -} // namespace mlir::gml_st +} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/transforms/passes.h index d18c60320525fa..ac322b01aac4a5 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.h @@ -49,6 +49,7 @@ using BufferizePatternsCallback = std::function> createTileLoopsPass( // and scf.if. std::unique_ptr> createDetensorizeScfOpsPass(); +/// Pass to remove redundant `memref.copy` ops. +std::unique_ptr> createNaiveCopyRemovalPass(); + +/// Pass to vectorize `memref.copy`. +std::unique_ptr> createVectorizeCopyPass(); + /// Registers the test pass for erasing transform dialect ops. 
void registerTestHloTransformDialectEraseSchedulePass(); diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.td b/third_party/xla/xla/mlir_hlo/transforms/passes.td index a75696b9de2fed..d9966b11f764db 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.td @@ -172,4 +172,16 @@ def AllocToArgPass : Pass<"alloc-to-arg", "mlir::func::FuncOp"> { let constructor = "hlo::createAllocToArgPass()"; } +def NaiveCopyRemovalPass : Pass<"naive-copy-removal", "mlir::func::FuncOp"> { + let summary = "Pass to remove redundant `memref.copy` ops."; + let constructor = "createNaiveCopyRemovalPass()"; + let dependentDialects = ["memref::MemRefDialect"]; +} + +def VectorizeCopyPass : Pass<"vectorize-copy", "mlir::func::FuncOp"> { + let summary = "Pass to vectorize `memref.copy`."; + let constructor = "createVectorizeCopyPass()"; + let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"]; +} + #endif // TENSORFLOW_COMPILER_MLIR_HLO_TRANSFORMS_PASSES diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc b/third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc similarity index 95% rename from third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc rename to third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc index 3fc5e99cce5286..4d1a9fa213e0b0 100644 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc @@ -14,24 +14,21 @@ limitations under the License. ==============================================================================*/ #include -#include #include -#include #include -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/vectorization/vectorization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { -namespace gml_st { namespace { #define GEN_PASS_DEF_VECTORIZECOPYPASS -#include "gml_st/transforms/passes.h.inc" +#include "transforms/passes.h.inc" /// Transforms a big non-contiguous `memref.copy` into a loop over smaller /// copies that are either contiguous or can be vectorized. 
@@ -217,7 +214,7 @@ struct VectorizeCopyPass RewritePatternSet patterns(ctx); patterns.add( - ctx, numElementsThreshold); + ctx, /*numElementsThreshold = */ 8); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { return signalPassFailure(); } @@ -226,12 +223,8 @@ struct VectorizeCopyPass } // namespace -std::unique_ptr> createVectorizeCopyPass( - int64_t numElementsThreshold) { - VectorizeCopyPassOptions opts; - opts.numElementsThreshold = numElementsThreshold; - return std::make_unique(opts); +std::unique_ptr> createVectorizeCopyPass() { + return std::make_unique(); } -} // namespace gml_st } // namespace mlir diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index 535730e5dc41e7..fe402498d76d1e 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -267,10 +267,8 @@ static Status CreateHloXlaPipeline( } pm.addNestedPass(mlir::gml_st::createInlineFusionClustersPass()); - if (options.enable_tiling_and_fusion) { - pm.addNestedPass(mlir::gml_st::createVectorizeCopyPass()); - pm.addNestedPass(mlir::gml_st::createNaiveCopyRemovalPass()); - } + pm.addNestedPass(mlir::createVectorizeCopyPass()); + pm.addNestedPass(mlir::createNaiveCopyRemovalPass()); // Handle framework specific requirements for buffers and then insert // deallocations for temporary buffers. pm.addNestedPass(mlir::createConvertLinalgToLoopsPass()); From 47f46ceb9e26268d71ff8ce0511bc81b3c200a05 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 24 Nov 2023 11:13:12 -0800 Subject: [PATCH 067/381] [TileAnalysis] Add a test for dynamic-update-slice. This op is unsupported by tile analysis. Adding a test so that we don't shoot ourselves in the foot by using `isElementwise` method, for example. PiperOrigin-RevId: 585123986 --- .../xla/service/gpu/model/tile_analysis_test.cc | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index a93ecb75cf94c6..917bea1851f3dc 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -467,16 +467,25 @@ TEST_F(TileAnalysisTest, DotOp) { std::vector{18, 17}))))); } -TEST_F(TileAnalysisTest, UnsupportedOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( +TEST_F(TileAnalysisTest, UnsupportedOps) { + ASSERT_IS_NOT_OK(GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[1, 17, 9, 9] parameter(0) p1 = f32[5, 17, 9, 9] parameter(1) ROOT concat = f32[6, 17, 9, 9] concatenate(p0, p1) } - )"); - ASSERT_IS_NOT_OK(input_indexing_or); + )")); + ASSERT_IS_NOT_OK(GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + input = s32[1,1,25,1] parameter(0) + update = s32[1,1,2,1] parameter(1) + start_indices = s32[4] parameter(2) + ROOT dyn-update = s32[1,1,25,1] dynamic-update-slice( + s32[1,1,25,1] input, s32[1,1,2,1] update, s32[4] start_indices) + } + )")); } } // namespace From 3f8022ead6ec3eb04073b7f1dff697593485bc69 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 24 Nov 2023 11:21:01 -0800 Subject: [PATCH 068/381] Put thlo & gml_st on ice. We can revert this when/if we need this. 
PiperOrigin-RevId: 585124862 --- tensorflow/compiler/aot/tests/BUILD | 41 +- tensorflow/compiler/aot/tfcompile.bzl | 36 +- tensorflow/compiler/mlir/tfrt/BUILD | 2 - tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc | 4 - .../compiler/mlir/tools/kernel_gen/BUILD | 1 - .../tools/kernel-gen-opt/kernel-gen-opt.cc | 4 +- .../mlir/tools/kernel_gen/transforms/BUILD | 12 +- third_party/xla/xla/mlir/backends/cpu/BUILD | 6 +- .../xla/xla/mlir/backends/cpu/xla-cpu-opt.cc | 7 - third_party/xla/xla/mlir_hlo/BUILD | 418 ----- third_party/xla/xla/mlir_hlo/CMakeLists.txt | 2 - .../xla/mlir_hlo/bindings/c/CMakeLists.txt | 2 - .../xla/xla/mlir_hlo/gml_st/CMakeLists.txt | 18 - .../xla/xla/mlir_hlo/gml_st/IR/CMakeLists.txt | 43 - .../xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc | 186 --- .../xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.h | 36 - .../xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.td | 77 - .../xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td | 38 - third_party/xla/xla/mlir_hlo/gml_st/README.md | 173 --- .../mlir_hlo/gml_st/interfaces/CMakeLists.txt | 33 - .../bufferizable_op_interface_impl.cc | 216 --- .../bufferizable_op_interface_impl.h | 29 - .../mlir_hlo/gml_st/transforms/CMakeLists.txt | 111 -- .../add_debug_info/add_debug_info.cc | 74 - .../canonicalization/optimize_linalg_ops.cc | 217 --- .../collapse_shape/collapse_shape.cc | 351 ----- .../transforms/collect_stats/collect_stats.cc | 123 -- .../compose_extract_insert_slice.cc | 87 -- .../cpu_tiling/cpu_tiling_pipeline.cc | 113 -- .../transforms/cpu_tiling/fusion_outlining.cc | 119 -- .../cpu_tiling/fusion_planning_for_cpu.cc | 269 ---- .../transforms/cpu_tiling/pack_matmul.cc | 331 ---- .../transforms/cpu_tiling/remove_label.cc | 43 - .../cpu_tiling/transform_dot_for_cpu.cc | 681 -------- .../transform_elementwise_for_cpu.cc | 401 ----- .../cpu_tiling/transform_mmt4d_for_cpu.cc | 148 -- .../cpu_tiling/transform_pack_for_cpu.cc | 143 -- .../cpu_tiling/transform_reduce_for_cpu.cc | 605 -------- .../cpu_tiling/transform_scatter_for_cpu.cc | 105 -- .../gml_st/transforms/fusion/fusion.cc | 796 ---------- .../gml_st/transforms/fusion/fusion.h | 83 - .../xla/mlir_hlo/gml_st/transforms/passes.h | 244 --- .../xla/mlir_hlo/gml_st/transforms/passes.td | 243 --- .../gml_st/transforms/peeling/peeling.cc | 191 --- .../gml_st/transforms/peeling/peeling.h | 51 - .../rewrite_from_elements_op.cc | 65 - .../rewrite_scf_forall/rewrite_scf_forall.cc | 112 -- .../transforms/scalarization/scalarization.cc | 670 -------- .../transforms/scalarization/scalarization.h | 58 - .../mlir_hlo/gml_st/transforms/test_passes.cc | 95 -- .../mlir_hlo/gml_st/transforms/test_passes.h | 39 - .../mlir_hlo/gml_st/transforms/test_passes.td | 21 - .../gml_st/transforms/tiling/tile_by_one.cc | 128 -- .../gml_st/transforms/tiling/tiling.cc | 253 --- .../gml_st/transforms/tiling/tiling.h | 50 - .../tiling_softmax/tiling_softmax.cc | 292 ---- .../mlir_hlo/gml_st/transforms/transforms.cc | 42 - .../mlir_hlo/gml_st/transforms/transforms.h | 47 - .../transforms/vectorization/lower_vectors.cc | 270 ---- .../transforms/vectorization/vectorization.cc | 41 - .../transforms/vectorization/vectorization.h | 53 - .../vectorization/vectorize_for_cpu.cc | 421 ----- .../xla/mlir_hlo/gml_st/utils/CMakeLists.txt | 24 - .../xla/mlir_hlo/gml_st/utils/linalg_utils.cc | 238 --- .../xla/mlir_hlo/gml_st/utils/linalg_utils.h | 77 - .../xla/mlir_hlo/gml_st/utils/tensor_utils.cc | 34 - .../xla/mlir_hlo/gml_st/utils/tensor_utils.h | 57 - .../mlir_hlo/mhlo/transforms/CMakeLists.txt | 34 - .../legalize_mhlo_to_thlo.cc | 465 ------ 
.../mlir_hlo/mhlo/transforms/mhlo_passes.td | 16 - .../tests/Dialect/gml_st/add_debug_info.mlir | 22 - .../tests/Dialect/gml_st/bufferization.mlir | 303 ---- .../tests/Dialect/gml_st/collapse-shape.mlir | 288 ---- .../tests/Dialect/gml_st/collect_stats.mlir | 77 - .../gml_st/compose_extract_insert_slice.mlir | 45 - .../gml_st/cpu_tiling/batch_matmul.mlir | 35 - .../gml_st/cpu_tiling/conv_2d_nhwc_hwcf.mlir | 52 - .../tests/Dialect/gml_st/cpu_tiling/dot.mlir | 260 ---- .../gml_st/cpu_tiling/duplicate_fusions.mlir | 45 - .../Dialect/gml_st/cpu_tiling/fibonacci.mlir | 60 - .../gml_st/cpu_tiling/fusion_outlining.mlir | 130 -- .../cpu_tiling/fusion_planning_for_cpu.mlir | 480 ------ .../cpu_tiling/inline_fusion_clusters.mlir | 74 - .../gml_st/cpu_tiling/map_bcast_map.mlir | 37 - .../Dialect/gml_st/cpu_tiling/map_matmul.mlir | 46 - .../gml_st/cpu_tiling/map_reduce_map.mlir | 113 -- .../gml_st/cpu_tiling/map_reshape_map.mlir | 92 -- .../Dialect/gml_st/cpu_tiling/matmul.mlir | 180 --- .../Dialect/gml_st/cpu_tiling/reduce_1d.mlir | 51 - .../gml_st/cpu_tiling/reduce_1d_map.mlir | 35 - .../Dialect/gml_st/cpu_tiling/reduce_2d.mlir | 92 -- .../gml_st/cpu_tiling/reduce_window.mlir | 50 - .../Dialect/gml_st/cpu_tiling/reverse.mlir | 61 - .../Dialect/gml_st/cpu_tiling/scatter.mlir | 70 - .../tests/Dialect/gml_st/cpu_tiling/sort.mlir | 24 - .../Dialect/gml_st/cpu_tiling/transpose.mlir | 41 - .../tests/Dialect/gml_st/greedy_fusion.mlir | 490 ------ .../tests/Dialect/gml_st/invalid.mlir | 18 - .../tests/Dialect/gml_st/lower_vectors.mlir | 219 --- .../Dialect/gml_st/nested_tiling_softmax.mlir | 110 -- .../mlir_hlo/tests/Dialect/gml_st/ops.mlir | 26 - .../Dialect/gml_st/optimize_linalg_ops.mlir | 183 --- .../Dialect/gml_st/rewrite_forall_to_for.mlir | 62 - .../tests/Dialect/gml_st/tile_by_one.mlir | 77 - .../tests/Dialect/gml_st/tiling_softmax.mlir | 172 --- .../Dialect/gml_st/vectorize_for_cpu.mlir | 395 ----- .../Dialect/mhlo/legalize-mhlo-to-thlo.mlir | 314 ---- .../tests/Dialect/thlo/bufferize.mlir | 41 - .../tests/Dialect/thlo/canonicalize.mlir | 15 - .../mlir_hlo/tests/Dialect/thlo/invalid.mlir | 352 ----- .../tests/Dialect/thlo/legalize_sort.mlir | 203 --- .../xla/mlir_hlo/tests/Dialect/thlo/ops.mlir | 179 --- .../mlir_hlo/tests/Dialect/thlo/tiling.mlir | 416 ----- .../xla/xla/mlir_hlo/tests/scalarization.mlir | 586 ------- .../xla/xla/mlir_hlo/thlo/CMakeLists.txt | 17 - .../xla/xla/mlir_hlo/thlo/IR/CMakeLists.txt | 43 - .../xla/xla/mlir_hlo/thlo/IR/thlo_ops.cc | 1363 ----------------- .../xla/xla/mlir_hlo/thlo/IR/thlo_ops.h | 37 - .../xla/xla/mlir_hlo/thlo/IR/thlo_ops.td | 346 ----- .../mlir_hlo/thlo/interfaces/CMakeLists.txt | 27 - .../bufferizable_op_interface_impl.cc | 151 -- .../bufferizable_op_interface_impl.h | 29 - .../mlir_hlo/thlo/transforms/CMakeLists.txt | 38 - .../transforms/legalize_sort/legalize_sort.cc | 561 ------- .../xla/xla/mlir_hlo/thlo/transforms/passes.h | 46 - .../mlir_hlo/thlo/transforms/thlo_passes.td | 24 - .../tools/mlir-hlo-opt/CMakeLists.txt | 5 - .../tools/mlir-hlo-opt/mlir-hlo-opt.cc | 26 +- .../xla/mlir_hlo/transforms/CMakeLists.txt | 5 - .../xla/mlir_hlo/transforms/bufferize_pass.cc | 50 +- third_party/xla/xla/service/cpu/BUILD | 3 - .../xla/xla/service/cpu/cpu_compiler.cc | 3 +- .../service/cpu/hlo_xla_runtime_pipeline.cc | 55 +- 133 files changed, 44 insertions(+), 19921 deletions(-) delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/IR/CMakeLists.txt delete mode 100644 
third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.td delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/README.md delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/canonicalization/optimize_linalg_ops.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/collect_stats/collect_stats.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_outlining.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/pack_matmul.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/remove_label.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_dot_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_elementwise_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_mmt4d_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_pack_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_from_elements_op/rewrite_from_elements_op.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_scf_forall/rewrite_scf_forall.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.td delete mode 100644 
third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tile_by_one.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/lower_vectors.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/utils/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.h delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.cc delete mode 100644 third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.h delete mode 100644 third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collect_stats.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/batch_matmul.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/conv_2d_nhwc_hwcf.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/dot.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/duplicate_fusions.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fibonacci.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_outlining.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_planning_for_cpu.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/inline_fusion_clusters.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reshape_map.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d_map.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_window.mlir delete mode 100644 
third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/optimize_linalg_ops.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_forall_to_for.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tile_by_one.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/tiling.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/scalarization.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/IR/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.cc delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.h delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.td delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/transforms/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/transforms/passes.h delete mode 100644 third_party/xla/xla/mlir_hlo/thlo/transforms/thlo_passes.td diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 6c276dbedef1f2..92d62b34be8bf9 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -1,8 +1,8 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "filegroup", "genrule") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment 
default_applicable_licenses = ["//tensorflow:license"], @@ -131,7 +131,6 @@ genrule( tfcompile_test_dep_configs = [ ("", "None"), ("_mlir_bridge", "Bridge"), - ("_mhlo_lowering", "HloLowering"), ] [ @@ -473,42 +472,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "tfcompile_test_mhlo_lowering", - srcs = ["tfcompile_test.cc"], - extra_copts = ["-DMHLO_LOWERING_TEST"], - tags = [ - "manual", - "no_mac", # TODO(b/228273415) - ], - deps = [ - ":test_graph_tfadd_mhlo_lowering", - ":test_graph_tfadd_with_ckpt_mhlo_lowering", - ":test_graph_tfadd_with_ckpt_saver_mhlo_lowering", - ":test_graph_tfassert_eq_mhlo_lowering", - ":test_graph_tfcond_mhlo_lowering", - ":test_graph_tffunction_mhlo_lowering", - ":test_graph_tfgather_mhlo_lowering", - ":test_graph_tfmatmul_mhlo_lowering", - ":test_graph_tfmatmulandadd_mhlo_lowering", - ":test_graph_tfsplits_mhlo_lowering", - ":test_graph_tftop_k_mhlo_lowering", - ":test_graph_tfvariable_mhlo_lowering", - ":test_graph_tfvariable_readonly_mhlo_lowering", - ":test_graph_tfvariable_sequential_updates_mhlo_lowering", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:regexp", - "@com_google_absl//absl/strings", - "@eigen_archive//:eigen3", - "@local_xla//xla:shape_util", - "@local_xla//xla:test", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/service:hlo_profile_printer", - ], -) - tf_cc_test( name = "tfcompile_test_mlir_bridge", srcs = ["tfcompile_test.cc"], diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index f056533d1b21e6..a543aae5b92997 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -329,16 +329,7 @@ def _tf_library( "@local_xla//xla/service/cpu:runtime_single_threaded_conv2d", "@local_xla//xla/service/cpu:runtime_single_threaded_matmul", "@eigen_archive//:eigen3", - ] or []) + ( - mlir_components.count("HloLowering") > 0 and [ - "@local_xla//xla/runtime:aot_ffi_c_symbols", - "@local_xla//xla/service/cpu:runtime_mlir_utils", - ] or [] - ) + ( - include_standard_runtime_deps and mlir_components == "HloLowering" and [ - "@local_xla//xla/service/cpu/runtime:retain", - ] or [] - ) + (deps or []), + ] or []) + (deps or []), tags = tags, copts = copts, ) @@ -559,31 +550,6 @@ def tf_library( copts, xla_flags, ) - if mlir_components == "None": - _tf_library( - name + "_mlir", - graph, - config, - debug_info, - freeze_checkpoint, - freeze_saver, - cpp_class, - gen_test, - gen_benchmark, - gen_compiler_log, - visibility, - testonly, - tfcompile_flags, - tfcompile_tool, - include_standard_runtime_deps, - enable_xla_hlo_profiling, - enable_tracemes, - "HloLowering", - deps, - tags + ["notap", "local", "manual"], - copts, - xla_flags, - ) def target_llvm_triple(): """Returns the target LLVM triple to be used for compiling the target.""" diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index ab43fc214e039d..4abd2607d6b183 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -521,8 +521,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:gml_st", - "@local_xla//xla/mlir_hlo:gml_st_passes", "@tf_runtime//:init_tfrt_dialects", "@tf_runtime//:print_stream_pass", ], diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc index 1ae3e8f1c54d31..a07558bac45f77 100644 --- 
a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc @@ -33,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" -#include "xla/mlir_hlo/gml_st/IR/gml_st_ops.h" -#include "xla/mlir_hlo/gml_st/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tfrt/init_tfrt_dialects.h" // from @tf_runtime @@ -46,7 +44,6 @@ int main(int argc, char **argv) { mlir::registerTensorFlowPasses(); - mlir::gml_st::registerGmlStPasses(); tensorflow::mlrt_compiler::RegisterMlrtPasses(); tensorflow::ifrt_serving::RegisterTfIfrtPasses(); @@ -54,7 +51,6 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::RegisterAllTensorFlowDialects(registry); - registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index e6ce181074de7f..d391f35e9adf77 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -142,7 +142,6 @@ tf_cc_binary( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", "@local_xla//xla/mlir_hlo:all_passes", - "@local_xla//xla/mlir_hlo:gml_st", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@stablehlo//:register", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc index 681896f2a235a7..178e899cb33a72 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "xla/mlir_hlo/gml_st/IR/gml_st_ops.h" #include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -37,8 +36,7 @@ int main(int argc, char **argv) { mlir::stablehlo::registerAllDialects(registry); mlir::RegisterAllTensorFlowDialects(registry); - registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 7cf5ef8522bb23..7c2e9d45d12db9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -1,13 +1,13 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -155,7 +155,6 @@ cc_library( "@local_xla//xla:debug_options_flags", "@local_xla//xla:xla_proto_cc", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:gml_st", "@local_xla//xla/mlir_hlo:lhlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:type_conversion", @@ -218,7 +217,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo:gml_st", "@local_xla//xla/mlir_hlo:lhlo", "@local_xla//xla/mlir_hlo:transforms_passes", ], diff --git a/third_party/xla/xla/mlir/backends/cpu/BUILD b/third_party/xla/xla/mlir/backends/cpu/BUILD index 7478d4ab70e0a7..b4623938eb2cba 100644 --- a/third_party/xla/xla/mlir/backends/cpu/BUILD +++ b/third_party/xla/xla/mlir/backends/cpu/BUILD @@ -1,5 +1,5 @@ -load("//xla:xla.bzl", "xla_cc_binary") load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("//xla:xla.bzl", "xla_cc_binary") package( default_visibility = ["//visibility:public"], @@ -20,12 +20,8 @@ xla_cc_binary( "//xla/mlir/backends/cpu/transforms:passes", "//xla/mlir/xla_cpu/ir:xla_cpu", "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:gml_st", - "//xla/mlir_hlo:gml_st_passes", - "//xla/mlir_hlo:gml_st_test_passes", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:thlo", "//xla/service/cpu:cpu_compiler", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:FuncExtensions", diff --git a/third_party/xla/xla/mlir/backends/cpu/xla-cpu-opt.cc b/third_party/xla/xla/mlir/backends/cpu/xla-cpu-opt.cc index dd939396fdfda3..88a8e5c8663b9a 100644 --- a/third_party/xla/xla/mlir/backends/cpu/xla-cpu-opt.cc +++ b/third_party/xla/xla/mlir/backends/cpu/xla-cpu-opt.cc @@ -24,20 +24,14 @@ limitations under the License. 
#include "stablehlo/dialect/Register.h" // from @stablehlo #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/xla_cpu/ir/xla_cpu.h" -#include "xla/mlir_hlo/gml_st/IR/gml_st_ops.h" -#include "xla/mlir_hlo/gml_st/transforms/passes.h" -#include "xla/mlir_hlo/gml_st/transforms/test_passes.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/mlir_hlo/thlo/IR/thlo_ops.h" int main(int argc, char **argv) { mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); - mlir::gml_st::registerGmlStPasses(); - mlir::gml_st::registerGmlStTestPasses(); mlir::bufferization::registerBufferizationPasses(); mlir::DialectRegistry registry; @@ -45,7 +39,6 @@ int main(int argc, char **argv) { mlir::stablehlo::registerAllDialects(registry); registry.insert(); mlir::func::registerAllExtensions(registry); diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index a1c6c941e9c6bd..f9ca7a0a2b0c69 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -750,7 +750,6 @@ cc_library( "mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc", "mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc", "mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc", - "mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc", "mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc", "mhlo/transforms/legalize_sort/legalize_sort.cc", "mhlo/transforms/legalize_sparse_ops/legalize_sparse_ops.cc", @@ -812,8 +811,6 @@ cc_library( ":mlir_hlo", ":shape_component_analysis", ":stablehlo_legalize_to_hlo", - ":thlo", - ":thlo_bufferizable_op_interface", ":type_conversion", ":unfuse_batch_norm", "@llvm-project//llvm:Support", @@ -1236,16 +1233,12 @@ cc_library( # shouldn't be. Ideally, this entire target should be removed. 
"deallocation/transforms/passes.h.inc", "lhlo/transforms/lmhlo_passes.h.inc", - "gml_st/transforms/passes.h.inc", - "thlo/transforms/thlo_passes.h.inc", "transforms/passes.h.inc", ], hdrs = [ "deallocation/transforms/passes.h", - "gml_st/transforms/passes.h", "lhlo/transforms/passes.h", "mhlo/transforms/passes.h", - "thlo/transforms/passes.h", "transforms/passes.h", ], strip_include_prefix = ".", @@ -1254,16 +1247,12 @@ cc_library( ":chlo_legalize_to_hlo", ":deallocation_passes", ":deallocation_passes_inc_gen", - ":gml_st_passes", - ":gml_st_passes_inc_gen", ":lhlo", ":lmhlo_pass_inc_gen", ":lmhlo_passes", ":mhlo_pass_inc_gen", ":mhlo_passes", ":stablehlo_legalize_to_hlo", - ":thlo_passes", - ":thlo_passes_inc_gen", ":transforms_passes", ":transforms_passes_inc_gen", ":userange_analysis", @@ -1307,15 +1296,10 @@ cc_library( deps = [ ":deallocation", ":deallocation_passes", - ":gml_st", - ":gml_st_bufferizable_op_interface", - ":gml_st_passes", ":lhlo", ":mhlo_passes", ":mlir_hlo", ":shape_component_analysis", - ":thlo", - ":thlo_bufferizable_op_interface", ":transforms_passes_inc_gen", ":type_conversion", ":userange_analysis", @@ -1393,7 +1377,6 @@ cc_library( strip_include_prefix = ".", visibility = ["//visibility:public"], deps = [ - ":gml_st_passes", ":gpu_transforms_passes_inc_gen", ":lhlo", ":mhlo_passes", @@ -1438,48 +1421,6 @@ cc_library( ], ) -gentbl_cc_library( - name = "gml_st_test_passes_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=GmlStTest", - ], - "gml_st/transforms/test_passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "gml_st/transforms/test_passes.td", - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "gml_st_test_passes", - srcs = [ - "gml_st/transforms/test_passes.cc", - "gml_st/transforms/test_passes.h.inc", - ], - hdrs = ["gml_st/transforms/test_passes.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":gml_st_passes", - ":gml_st_test_passes_inc_gen", - ":gml_st_transforms", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - ], -) - gentbl_cc_library( name = "transforms_passes_inc_gen", compatible_with = get_compatible_with_portable(), @@ -1549,108 +1490,6 @@ cc_library( ], ) -cc_library( - name = "gml_st_passes", - srcs = [ - "gml_st/transforms/add_debug_info/add_debug_info.cc", - "gml_st/transforms/canonicalization/optimize_linalg_ops.cc", - "gml_st/transforms/collapse_shape/collapse_shape.cc", - "gml_st/transforms/collect_stats/collect_stats.cc", - "gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc", - "gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc", - "gml_st/transforms/cpu_tiling/fusion_outlining.cc", - "gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc", - "gml_st/transforms/cpu_tiling/pack_matmul.cc", - "gml_st/transforms/cpu_tiling/remove_label.cc", - "gml_st/transforms/cpu_tiling/transform_dot_for_cpu.cc", - "gml_st/transforms/cpu_tiling/transform_elementwise_for_cpu.cc", - "gml_st/transforms/cpu_tiling/transform_mmt4d_for_cpu.cc", - "gml_st/transforms/cpu_tiling/transform_pack_for_cpu.cc", - "gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc", - 
"gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc", - "gml_st/transforms/fusion/fusion.cc", - "gml_st/transforms/passes.h.inc", - "gml_st/transforms/peeling/peeling.cc", - "gml_st/transforms/rewrite_from_elements_op/rewrite_from_elements_op.cc", - "gml_st/transforms/rewrite_scf_forall/rewrite_scf_forall.cc", - "gml_st/transforms/scalarization/scalarization.cc", - "gml_st/transforms/tiling/tile_by_one.cc", - "gml_st/transforms/tiling/tiling.cc", - "gml_st/transforms/tiling_softmax/tiling_softmax.cc", - "gml_st/transforms/transforms.h", - "gml_st/transforms/vectorization/lower_vectors.cc", - "gml_st/transforms/vectorization/vectorization.cc", - "gml_st/transforms/vectorization/vectorize_for_cpu.cc", - "gml_st/utils/linalg_utils.cc", - "gml_st/utils/tensor_utils.cc", - ], - hdrs = [ - "gml_st/transforms/fusion/fusion.h", - "gml_st/transforms/passes.h", - "gml_st/transforms/peeling/peeling.h", - "gml_st/transforms/scalarization/scalarization.h", - "gml_st/transforms/tiling/tiling.h", - "gml_st/transforms/vectorization/vectorization.h", - "gml_st/utils/linalg_utils.h", - "gml_st/utils/tensor_utils.h", - ], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":gml_st", - ":gml_st_passes_inc_gen", - ":gml_st_transforms", - ":lhlo", - ":mlir_hlo", - ":thlo", - ":type_conversion", - "@llvm-project//llvm:BinaryFormat", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:DestinationStyleOpInterface", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LinalgUtils", - "@llvm-project//mlir:LoopLikeInterface", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MemRefUtils", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:ShapeTransforms", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", - "@llvm-project//mlir:TensorTilingInterfaceImpl", - "@llvm-project//mlir:TensorTransforms", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TilingInterface", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:VectorToSCF", - "@llvm-project//mlir:VectorTransforms", - "@llvm-project//mlir:X86VectorTransforms", - "@stablehlo//:chlo_ops", - ], -) - CAPI_HEADERS = [ "bindings/c/Attributes.h", "bindings/c/Dialects.h", @@ -1715,13 +1554,9 @@ cc_binary( deps = [ ":all_passes", ":deallocation", - ":gml_st", - ":gml_st_passes", - ":gml_st_test_passes", ":hlo_dialect_registration", ":lhlo", ":lhlo_gpu", - ":thlo", ":transforms_gpu_passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllExtensions", @@ -1764,259 +1599,6 @@ filegroup( visibility = ["//visibility:public"], ) -td_library( - name = "gml_st_ops_td_files", - srcs 
= glob(["gml_st/IR/*.td"]), - compatible_with = get_compatible_with_portable(), - includes = ["."], - deps = [ - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", - "@llvm-project//mlir:DialectUtilsTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "gml_st_ops_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "gml_st/IR/gml_st_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "gml_st/IR/gml_st_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "gml_st/IR/gml_st_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "gml_st/IR/gml_st_dialect.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "gml_st/IR/gml_st_ops.td", - deps = [":gml_st_ops_td_files"], -) - -cc_library( - name = "gml_st", - srcs = ["gml_st/IR/gml_st_ops.cc"], - hdrs = ["gml_st/IR/gml_st_ops.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":gml_st_ops_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:DestinationStyleOpInterface", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:LoopLikeInterface", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:ViewLikeInterface", - ], -) - -cc_library( - name = "gml_st_bufferizable_op_interface", - srcs = ["gml_st/interfaces/bufferizable_op_interface_impl.cc"], - hdrs = ["gml_st/interfaces/bufferizable_op_interface_impl.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":gml_st", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Support", - ], -) - -gentbl_cc_library( - name = "gml_st_passes_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=GmlSt", - ], - "gml_st/transforms/passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "gml_st/transforms/passes.td", - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "gml_st_transforms", - srcs = ["gml_st/transforms/transforms.cc"], - hdrs = ["gml_st/transforms/transforms.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":gml_st", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - ], -) - -td_library( - name = "thlo_ops_td_files", - srcs = glob(["thlo/IR/*.td"]), - compatible_with = get_compatible_with_portable(), - includes = ["."], - deps = [ - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - 
"@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "thlo_ops_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "thlo/IR/thlo_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "thlo/IR/thlo_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "thlo/IR/thlo_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "thlo/IR/thlo_dialect.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "thlo/IR/thlo_ops.td", - deps = [ - ":thlo_ops_td_files", - "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", - "@llvm-project//mlir:TilingInterfaceTdFiles", - ], -) - -cc_library( - name = "thlo", - srcs = ["thlo/IR/thlo_ops.cc"], - hdrs = ["thlo/IR/thlo_ops.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":gml_st", - ":thlo_ops_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:DestinationStyleOpInterface", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgUtils", - "@llvm-project//mlir:LoopLikeInterface", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TilingInterface", - "@llvm-project//mlir:ViewLikeInterface", - ], -) - -cc_library( - name = "thlo_bufferizable_op_interface", - srcs = ["thlo/interfaces/bufferizable_op_interface_impl.cc"], - hdrs = ["thlo/interfaces/bufferizable_op_interface_impl.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":thlo", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:DestinationStyleOpInterface", - ], -) - -gentbl_cc_library( - name = "thlo_passes_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=AllThlo", - ], - "thlo/transforms/thlo_passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "thlo/transforms/thlo_passes.td", - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "thlo_passes", - srcs = [ - "thlo/transforms/legalize_sort/legalize_sort.cc", - "thlo/transforms/thlo_passes.h.inc", - ], - hdrs = [ - "thlo/transforms/passes.h", - ], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":thlo", - ":thlo_passes_inc_gen", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Transforms", - ], -) - # A light-weight runtime support library, used by MLIR code that results # after lowering some ops in the vector and sparse tensor dialects. 
cc_binary( diff --git a/third_party/xla/xla/mlir_hlo/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/CMakeLists.txt index cbbaa42f9c16bf..9bfdc58b3a3eb2 100644 --- a/third_party/xla/xla/mlir_hlo/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/CMakeLists.txt @@ -161,13 +161,11 @@ add_custom_target(check-mlir-hlo) add_subdirectory(analysis) add_subdirectory(bindings) add_subdirectory(deallocation) -add_subdirectory(gml_st) add_subdirectory(lhlo) add_subdirectory(lhlo_gpu) add_subdirectory(mhlo) add_subdirectory(stablehlo) add_subdirectory(tests) -add_subdirectory(thlo) add_subdirectory(tools) add_subdirectory(transforms) add_subdirectory(utils) diff --git a/third_party/xla/xla/mlir_hlo/bindings/c/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/bindings/c/CMakeLists.txt index 858b2b63dc9b28..d3f4158293adfc 100644 --- a/third_party/xla/xla/mlir_hlo/bindings/c/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/bindings/c/CMakeLists.txt @@ -5,14 +5,12 @@ add_mlir_public_c_api_library(MLIRHLOCAPIDialects Passes.cc LINK_LIBS PUBLIC MhloDialect - THLODialect # For AllMhLoPasses: ChloPasses MhloPasses MhloToArithmeticConversion MhloToMemrefConversion MhloToStandard - MhloToThloConversion MhloToLinalg MhloToStablehlo MhloShapeOpsToStandard diff --git a/third_party/xla/xla/mlir_hlo/gml_st/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/gml_st/CMakeLists.txt deleted file mode 100644 index 47c038050ca6b2..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_subdirectory(IR) -add_subdirectory(interfaces) -add_subdirectory(transforms) -add_subdirectory(utils) diff --git a/third_party/xla/xla/mlir_hlo/gml_st/IR/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/gml_st/IR/CMakeLists.txt deleted file mode 100644 index b7b4af6c0a3d55..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/IR/CMakeLists.txt +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -set(LLVM_TARGET_DEFINITIONS gml_st_ops.td) -mlir_tablegen(gml_st_ops.h.inc -gen-op-decls) -mlir_tablegen(gml_st_ops.cc.inc -gen-op-defs) -mlir_tablegen(gml_st_dialect.h.inc -gen-dialect-decls) -mlir_tablegen(gml_st_dialect.cc.inc -gen-dialect-defs) - -add_public_tablegen_target(MLIRgml_st_opsIncGen) -add_dependencies(mlir-headers MLIRgml_st_opsIncGen) - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_dialect_library(GmlStDialect - gml_st_ops.cc - - DEPENDS - MLIRgml_st_opsIncGen - - LINK_LIBS PUBLIC - MLIRArithUtils - MLIRControlFlowInterfaces - MLIRIR - MLIRMemRefDialect - MLIRSideEffectInterfaces - MLIRSupport - MLIRTensorDialect - MLIRVectorDialect -) diff --git a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc b/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc deleted file mode 100644 index ff65f6841aea3f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "gml_st/IR/gml_st_ops.h" - -#include -#include - -#include "llvm/ADT/SetVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/Transforms/InliningUtils.h" -#include "mlir/Transforms/RegionUtils.h" - -// Generated dialect definitions. -#include "gml_st/IR/gml_st_dialect.cc.inc" - -namespace mlir { -namespace gml_st { -namespace { - -//===----------------------------------------------------------------------===// -// GmlSt Dialect Interfaces -//===----------------------------------------------------------------------===// - -struct GmlStInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - // Operations in GmlSt dialect are always legal to inline since they are - // pure. - bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { - return true; - } - // Handle the given inlined terminator by replacing it with a new operation - // as necessary. Required when the region has only one block. 
- void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - auto yieldOp = dyn_cast(op); - if (!yieldOp) return; - - for (auto [valueToRepl, operand] : - llvm::zip(valuesToRepl, yieldOp.getOperands())) { - valueToRepl.replaceAllUsesWith(operand); - } - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// GmlStDialect -//===----------------------------------------------------------------------===// - -void GmlStDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "gml_st/IR/gml_st_ops.cc.inc" - >(); - addInterfaces(); -} - -//===----------------------------------------------------------------------===// -// FusionOp -//===----------------------------------------------------------------------===// - -YieldOp FusionOp::getTerminator() { - return cast(getBody()->getTerminator()); -} - -void FusionOp::print(OpAsmPrinter &p) { - p << " "; - if (!getInputs().empty()) { - p << "ins("; - llvm::interleaveComma( - llvm::zip(getBody()->getArguments(), getInputs()), p, [&](auto it) { - Value inputRegionArg, input; - std::tie(inputRegionArg, input) = it; - p << inputRegionArg << " = " << input << ": " << input.getType(); - }); - p << ") "; - } - - if (!getInits().empty()) { - p << "inits("; - llvm::interleaveComma( - llvm::zip(getBody()->getArguments().drop_front(getInputs().size()), - getInits()), - p, [&](auto it) { - Value inputRegionArg, input; - std::tie(inputRegionArg, input) = it; - p << inputRegionArg << " = " << input << ": " << input.getType(); - }); - p << ") "; - } - - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); - - p.printOptionalAttrDict(getOperation()->getAttrs(), - {getOperandSegmentSizesAttrName()}); - - if (!getResultTypes().empty()) { - p << " : "; - llvm::interleave(getResultTypes(), p, ", "); - } -} - -ParseResult FusionOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector operands, regionOperands; - SmallVector operandTypes; - - auto parseElt = [&]() -> ParseResult { - if (parser.parseOperand(regionOperands.emplace_back(), - /*allowResultNumber=*/false) || - parser.parseEqual()) { - return failure(); - } - if (parser.parseOperand(operands.emplace_back()) || parser.parseColon() || - parser.parseType(operandTypes.emplace_back())) { - return failure(); - } - return success(); - }; - - size_t numInputs = 0, numInits = 0; - if (succeeded(parser.parseOptionalKeyword("ins"))) { - if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt)) - return failure(); - } - numInputs = operands.size(); - - if (succeeded(parser.parseOptionalKeyword("inits"))) { - if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt)) - return failure(); - } - numInits = operands.size() - numInputs; - - SMLoc loc = parser.getCurrentLocation(); - if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) - return failure(); - - // Parse region. - SmallVector regionArgs; - for (auto argAndType : llvm::zip(regionOperands, operandTypes)) { - auto &arg = regionArgs.emplace_back(); - std::tie(arg.ssaName, arg.type) = argAndType; - } - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) return failure(); - - // Parse attributes. - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - // Parser result types. 
- if (parser.parseOptionalColonTypeList(result.types)) return failure(); - - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {static_cast(numInputs), static_cast(numInits)})); - - return success(); -} - -//===----------------------------------------------------------------------===// -// YieldOp -//===----------------------------------------------------------------------===// - -LogicalResult YieldOp::verify() { return success(); } - -} // namespace gml_st -} // namespace mlir - -// Generated op classes. -#define GET_OP_CLASSES -#include "gml_st/IR/gml_st_ops.cc.inc" diff --git a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.h b/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.h deleted file mode 100644 index 55ade3e4c4a76a..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the operations used in the GML ST dialect. - -#ifndef MLIR_HLO_GML_ST_IR_GML_ST_OPS_H -#define MLIR_HLO_GML_ST_IR_GML_ST_OPS_H - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -// Generated dialect declarations. -#include "gml_st/IR/gml_st_dialect.h.inc" - -// Generated operation classes. -#define GET_OP_CLASSES -#include "gml_st/IR/gml_st_ops.h.inc" - -#endif // MLIR_HLO_GML_ST_IR_GML_ST_OPS_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.td b/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.td deleted file mode 100644 index b19d067027d752..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops.td +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the operation definition file for ST ops. 
- -#ifndef GML_ST_OPS -#define GML_ST_OPS - -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "gml_st/IR/gml_st_ops_base.td" - -/////////////////////////////////////////////////////////////////////////////// -// Ops -/////////////////////////////////////////////////////////////////////////////// - -def GMLST_FusionOp : GMLST_Op<"fusion", [ - AttrSizedOperandSegments, - DestinationStyleOpInterface, - IsolatedFromAbove, - SingleBlockImplicitTerminator<"gml_st::YieldOp"> - ]> { - let summary = "A cluster of operations to be tiled and fused."; - - let arguments = (ins Variadic:$inputs, - Variadic:$inits, - OptionalAttr:$parallel_tile_sizes, - OptionalAttr:$reduction_tile_sizes); - let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$region); - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 0; - - code extraClassDeclaration = [{ - /// Return terminator of the region body. - YieldOp getTerminator(); - - // Implement method necessary for DestinationStyleOpInterface. - mlir::MutableOperandRange getDpsInitsMutable() { - return getInitsMutable(); - } - }]; -} - -def GMLST_YieldOp : GMLST_Op<"yield", [Pure, ReturnLike, Terminator, - HasParent<"::mlir::gml_st::FusionOp">]>, - Arguments<(ins Variadic:$values)> { - let summary = "Yield operation"; - let description = [{ - `gml_st.yield` is a special terminator operation for accumulator regions of - `gml_st.set_yield` and `gml_st.fusion` region. - - Example: - - ```mlir - gml_st.yield %f0: tensor - ``` - }]; - let assemblyFormat = "attr-dict $values `:` type($values)"; -} - -#endif // GML_ST_OPS diff --git a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td b/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td deleted file mode 100644 index 64364a36f77361..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef GML_ST_OPS_BASE -#define GML_ST_OPS_BASE - -include "mlir/Dialect/Utils/StructuredOpsUtils.td" -include "mlir/IR/EnumAttr.td" -include "mlir/IR/OpBase.td" - -def GmlSt_Dialect : Dialect { - let name = "gml_st"; - let cppNamespace = "::mlir::gml_st"; - let description = [{ - The GmlSt (Google ML Structured) dialect is intended to hold operations, - types and transformations to assist structured code generation. 
- }]; - let usePropertiesForAttributes = 0; -} - -class GMLST_Op traits> : - Op { - let hasVerifier = 1; -} - -#endif // GML_ST_OPS_BASE diff --git a/third_party/xla/xla/mlir_hlo/gml_st/README.md b/third_party/xla/xla/mlir_hlo/gml_st/README.md deleted file mode 100644 index e121d7c7e86fa5..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/README.md +++ /dev/null @@ -1,173 +0,0 @@ -# Google ML Structured Dialect - -The `gml_st` dialect will contain a loop-like construct and subset operations -that should allow support for fusion beyond rectangular tiles. This is necessary -for operations like `gather`, `scatter`, `concat` and more. - -## Overview -### Tiling and fusion - -Tiling of an op is performed by creating a loop that computes subsets of the -result. Usually the tiling is needed to enable vectorization or distribution. - -Before tiling - -``` -%0 = op(%input) -``` - -After tiling - -``` -loop (%ivs) - %1 = subset(%input, %ivs) - %2 = op (%1) -``` - -Fusion of a producer op into a tiled consumer consists of two main parts: -computing subsets of producer's operands and moving the producer op into the -loop body so that it operates on the subsets of its original operands. - -After consumer tiling -``` -%0 = producer (%input) -loop (%ivs) - %1 = subset(%0, %ivs) - %2 = consumer(%1) -``` - -After producer fusion - -``` -loop (%ivs) - %0 = subset(%input, %ivs) - %1 = producer(%0) - %2 = consumer (%1) -``` - -There is some duality between tiling and fusion. One can consider tiling as -fusion of the op into a loop that partitions the iteration space and just -returns identity for every subset. On the other hand, fusion can be seen as -tiling of the producer and then merging of the loop bodies. - -### Subset operations - -Linalg has support for hyperrectangular subsets (tiles) of tensor/memref -operands. Currently, Linalg's fusion assumes that the tiling is performed only -using `tensor.extract_slice/tensor.insert_slice` and `memref.subview` -operations. -There are several disadvantages to that approach: - -If some of the operands are not affected by tiling, i.e. the tiling was -performed along dimensions that are not present in the operand, then we cannot -fuse anymore the producer of the operand. That can happen when `linalg.generic` -broadcasts one of the operands or when the output is tiled, but not the -reduction dimensions - -Support for fusion with ops like `gather`, `scatter`, `concat` for some of the -cases can only be done via `TilingInterface` -([RFC](https://llvm.discourse.group/t/rfc-for-tilinginterface-for-tiling-operations-that-dont-fit-into-linalg-structured-operation-definition/3897/7)). 
- -**Example of a tiled op** - -``` -%sum = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c80, %c60) step (%c4, %c4) - ins (%in_ = %in: tensor<80x60xf32>, %cst_ = %cst: f32) - outs (%out_ = %out: tensor<80xf32>) - iterators["parallel", "reduction"] { - %in_sub = tensor.extract_slice %in_[%i, %j] [4, 4] [1, 1] - : tensor<80x60xf32> to tensor<4x4xf32> - %out_sub = tensor.extract_slice %out_[%i] [4] [1] - : tensor<80xf32> to tensor<4xf32> - %reduction = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%in_sub : tensor<4x4xf32>) - outs(%out_sub : tensor<4xf32>) { - ^bb0(%a: f32, %b: f32): - %0 = arith.addf %a, %b : f32 - linalg.yield %0 : f32 - } -> tensor<4xf32> - %update = tensor.insert_slice %reduction into %out_[%i] [4] [1] - : tensor<4xf32> into tensor<80xf32> - linalg.yield %update : tensor<80xf32> -} -``` - -The body of this loop models read-modify-write of the output tensor. The tile -that we extract from `%out_` should have the same sizes/offsets/strides as the -destination of `tensor.insert_slice`. The arguments of `tensor.extract_slice` -and `tensor.insert_slice` are currently not required to encode the same tile. - -We introduce new operations that define subsets on tensors/memrefs - - * `subset.full %tensor` - the subset spans the original tensor fully - * `subset.tile %tensor [%offsets][%sizes][%strides]` - defines a rectangular - tile - * `subset.filter %tensor[%indices]` - the subset has the same shape as the - original tensor, but only the values at %indices are populated. This can be a - sparse tensor. - * `subset.point %tensor[%index]` - the subset contains a single element - -### Structured loop - -We introduce `gml_st.loop` that keeps the subset definition separately from the -materialization. - -`linalg.generic` has `AffineMap` attributes that specify the indexing maps and a -region that models the computation on the element types of the operand -tensors/memrefs. The region ends with `linalg.yield` terminator that yields the -element of the output. The load and store ops in that case are implicit, so -are extraction/insertion in `gml_st.loop`. - -`gml_st.loop` has one region that contains subset operations to define the -dense/sparse ranges that we are working with and also `gml_st.materialize` ops -to convert subset spec to a tensor or memref. - -`gml_st.yield` is the terminator for `gml_st.loop` that takes computed tensors -and a subset specification for which the computation was done. Note that this -way we don't have to explicitly write a destructive update with -`tensor.insert_slice` and then yield a full tensor. Here, we yield values for a -subset. 
- - -``` -%sum = gml_st.loop (%i, %j) = (%c0, %c0) to (%c80, %c60) step (%c4, %c4) - ins (%in_ = %in: tensor<80x60xf32>, %cst_ = %cst: f32) - outs (%out_ = %out: tensor<80xf32>) - iterators["parallel", "sequential"] { - %in_tile = gml_st.tile %in_[%i, %j] [4, 4] [1, 1] - : tensor<80x60xf32> to !gml_st.subset<4x4xf32> - %out_tile = gml_st.tile %out_[%i] [4] [1] - : tensor<80xf32> to !gml_st.subset<4xf32> - - %in_sub = gml_st.materialize %in_tile - : !gml_st.subset<4x4xf32> to tensor<4x4xf32> - %out_sub = gml_st.materialize %in_tile - : !gml_st.subset<4xf32> to tensor<4xf32> - %reduction = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%in_sub : tensor<4x4xf32>) - outs(%out_sub : tensor<4xf32>) { - ^bb0(%a: f32, %b: f32): - %0 = arith.addf %a, %b : f32 - linalg.yield %0 : f32 - } -> tensor<4xf32> - gml_st.yield %reduction to %out_tile - : tensor<4xf32> to !gml_st.subset<4xf32> -} -``` - -Currently, tiling of the consumer and fusion of its producers are tightly -coupled. If the fusion is happening not in the same pass, then some analysis is -required to find the [consumer - `tensor.extract_slice` - producer] triple to -perform the fusion. Keeping the subset computations separately from the -"compute" ops not only improves readability but also simplifies fusion, since we -have a subset computation per operand and we can just specify what argument of -the loop we want to fuse. - -It also simplifies the bufferization, since we don't need to introduce the -additional operations in MemRef dialect for every subset operation in TensorOps. diff --git a/third_party/xla/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt deleted file mode 100644 index da038529bba406..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -set(LLVM_OPTIONAL_SOURCES - bufferizable_op_interface_impl.cc -) - -add_mlir_library(GmlStBufferizableOpInterface - bufferizable_op_interface_impl.cc - - LINK_LIBS PUBLIC - GmlStDialect - MLIRBufferizationDialect - MLIRBufferizationTransforms - MLIRDialectUtils - MLIRIR - MLIRSupport -) diff --git a/third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc b/third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc deleted file mode 100644 index 398d78c9cd9961..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc +++ /dev/null @@ -1,216 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "gml_st/interfaces/bufferizable_op_interface_impl.h" - -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace gml_st { -namespace { - -using mlir::bufferization::AliasingOpOperandList; -using mlir::bufferization::AliasingValueList; -using mlir::bufferization::AnalysisState; -using mlir::bufferization::BufferizableOpInterface; -using mlir::bufferization::BufferizationOptions; -using mlir::bufferization::BufferRelation; - -struct FusionOpBufferizationInterface - : public BufferizableOpInterface::ExternalModel< - FusionOpBufferizationInterface, FusionOp> { - bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/, - const AnalysisState & /*state*/) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState & /*state*/) const { - return cast(op).isDpsInit(&opOperand); - } - - AliasingOpOperandList getAliasingOpOperands( - Operation *op, Value value, const AnalysisState & /*state*/) const { - auto fusionOp = cast(op); - auto opResult = value.dyn_cast(); - if (!opResult) return {}; - - // The i-th OpResult aliases with the i-th "out" tensor. - return {{fusionOp.getDpsInitOperand(opResult.getResultNumber()), - BufferRelation::Equivalent}}; - } - - AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, - const AnalysisState & /*state*/) const { - auto fusionOp = cast(op); - - // The i-th "out" tensor aliases with the i-th OpResult. - if (fusionOp.isDpsInit(&opOperand)) { - return { - {fusionOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}; - } - return {}; - } - - bool isWritable(Operation * /*op*/, Value /*value*/, - const AnalysisState & /*state*/) const { - return true; - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(op); - - auto loc = op->getLoc(); - FusionOp fusionOp = cast(op); - - // Nothing to do. This op is already bufferized. - if (fusionOp.hasBufferSemantics()) return success(); - - if (!fusionOp.hasTensorSemantics()) { - return op->emitError() << "expected either buffer or tensor semantics"; - } - - size_t numOutputs = fusionOp.getNumDpsInits(); - - // New operands for the cloned op. - SmallVector newOperands; - newOperands.reserve(fusionOp.getNumDpsInputs() + numOutputs); - - for (OpOperand *opOperand : fusionOp.getDpsInputOperands()) { - if (fusionOp.isScalar(opOperand)) { - newOperands.push_back(opOperand->get()); - continue; - } - FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); - if (failed(buffer)) return failure(); - newOperands.push_back(*buffer); - } - - // New output operands for the cloned op. 
- SmallVector newOutputs; - newOutputs.reserve(numOutputs); - - for (OpResult opResult : fusionOp->getOpResults()) { - OpOperand *opOperand = - fusionOp.getDpsInitOperand(opResult.getResultNumber()); - FailureOr resultBuffer = - getBuffer(rewriter, opOperand->get(), options); - if (failed(resultBuffer)) return failure(); - newOutputs.push_back(*resultBuffer); - } - - newOperands.append(newOutputs.begin(), newOutputs.end()); - - // Set insertion point now that potential alloc/dealloc are introduced. - rewriter.setInsertionPoint(op); - - // Clone the op, but use the new operands. Move the existing block into the - // new op. Since the new op does not have any tensor results, it does not - // return anything. - auto newFusionOp = cast(cloneWithoutRegions( - rewriter, op, /*resultTypes=*/TypeRange{}, newOperands)); - - // Create empty region in the new bufferized op. - Region ®ion = newFusionOp.getRegion(); - SmallVector blockArgTypes = - llvm::to_vector(TypeRange(ValueRange(newOperands))); - SmallVector blockArgLocs(blockArgTypes.size(), loc); - rewriter.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); - - ArrayRef bbArgs = - newFusionOp.getRegion().front().getArguments(); - SmallVector bbArgsToTensors; - for (auto buf : bbArgs) { - if (isa(buf.getType())) { - Value tensor = rewriter.create(loc, buf); - bbArgsToTensors.push_back(tensor); - } else { - bbArgsToTensors.push_back(buf); - } - } - - // Move old body into new fusion op. - rewriter.mergeBlocks(fusionOp.getBody(), newFusionOp.getBody(), - bbArgsToTensors); - - // Copy results to output memrefs. In most of the cases it's not necessary, - // because clusters are constructed in a way that the result is produced by - // an dst-style op that already put everything in the output memrefs, but - // there are corner cases when it doesn't happen. For example, tiled 1d - // linalg.reduce. - rewriter.setInsertionPoint(newFusionOp.getTerminator()); - for (auto [bbArg, resultValue] : - llvm::zip(bbArgs.take_back(numOutputs), - newFusionOp.getTerminator().getValues())) { - if (auto toTensorOp = - resultValue.getDefiningOp()) { - rewriter.create(loc, toTensorOp.getMemref(), bbArg); - } - } - - // Replace gml_st.yield values with output buffers. - rewriter.replaceOpWithNewOp(newFusionOp.getTerminator(), - bbArgs.take_back(numOutputs)); - - // Replace the results of the old op with the new output buffers. - bufferization::replaceOpWithBufferizedValues(rewriter, op, newOutputs); - - return success(); - } - - FailureOr getBufferType( - Operation *op, Value value, const BufferizationOptions &options, - SmallVector &invocationStack) const { - auto fusionOp = cast(op); - - if (auto bbArg = value.dyn_cast()) { - // A tensor block argument has the same bufferized type as the - // corresponding output operand. - return bufferization::getBufferType( - fusionOp->getOpOperand(bbArg.getArgNumber()).get(), options, - invocationStack); - } - - // The bufferized result type is the same as the bufferized type of the - // corresponding output operand. 
- return bufferization::getBufferType( - fusionOp.getDpsInitOperand(value.cast().getResultNumber()) - ->get(), - options, invocationStack); - } -}; - -} // namespace -} // namespace gml_st -} // namespace mlir - -void mlir::gml_st::registerBufferizableOpInterfaceExternalModels( - DialectRegistry ®istry) { - registry.addExtension( - +[](MLIRContext *ctx, gml_st::GmlStDialect * /*dialect*/) { - FusionOp::attachInterface(*ctx); - }); -} diff --git a/third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h b/third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h deleted file mode 100644 index 54739c5b3acc76..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H -#define MLIR_HLO_GML_ST_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H - -namespace mlir { -class DialectRegistry; - -namespace gml_st { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); - -} // namespace gml_st -} // namespace mlir - -#endif diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt deleted file mode 100644 index 5a96a1a77827f0..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt +++ /dev/null @@ -1,111 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -set(LLVM_TARGET_DEFINITIONS passes.td) -mlir_tablegen(passes.h.inc -gen-pass-decls -name GmlSt) -add_public_tablegen_target(MLIRGmlStPassIncGen) - -set(LLVM_TARGET_DEFINITIONS test_passes.td) -mlir_tablegen(test_passes.h.inc -gen-pass-decls -name GmlStTest) -add_public_tablegen_target(MLIRGmlStTestPassIncGen) - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_library(GmlStPasses - add_debug_info/add_debug_info.cc - canonicalization/optimize_linalg_ops.cc - collapse_shape/collapse_shape.cc - collect_stats/collect_stats.cc - compose_extract_insert_slice/compose_extract_insert_slice.cc - cpu_tiling/cpu_tiling_pipeline.cc - cpu_tiling/fusion_outlining.cc - cpu_tiling/fusion_planning_for_cpu.cc - cpu_tiling/pack_matmul.cc - cpu_tiling/remove_label.cc - cpu_tiling/transform_dot_for_cpu.cc - cpu_tiling/transform_elementwise_for_cpu.cc - cpu_tiling/transform_mmt4d_for_cpu.cc - cpu_tiling/transform_pack_for_cpu.cc - cpu_tiling/transform_reduce_for_cpu.cc - cpu_tiling/transform_scatter_for_cpu.cc - fusion/fusion.cc - peeling/peeling.cc - rewrite_from_elements_op/rewrite_from_elements_op.cc - rewrite_scf_forall/rewrite_scf_forall.cc - scalarization/scalarization.cc - tiling/tile_by_one.cc - tiling/tiling.cc - tiling_softmax/tiling_softmax.cc - vectorization/lower_vectors.cc - vectorization/vectorization.cc - vectorization/vectorize_for_cpu.cc - - DEPENDS - MLIRGmlStPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRAffineDialect - MLIRArithDialect - MLIRDestinationStyleOpInterface - MLIRDialectUtils - MLIRFuncDialect - MLIRGmlStUtils - MLIRIR - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRMemRefDialect - MLIRPass - MLIRSCFUtils - MLIRSupport - MLIRVectorDialect - MLIRVectorToSCF - MLIRX86VectorTransforms - MhloDialect -) - -add_mlir_library(GmlStTransforms - transforms.cc - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - GmlStDialect - MLIRAffineDialect - MLIRDialectUtils - MLIRIR -) - -add_mlir_library(GmlStTestPasses - test_passes.cc - - DEPENDS - MLIRGmlStTestPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - GmlStDialect - GmlStTransforms - MLIRPass - MLIRTransforms -) diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc deleted file mode 100644 index 8cf648cb9b2b58..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "llvm/BinaryFormat/Dwarf.h" -#include "llvm/Support/Path.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_ADDDEBUGINFOPASS -#include "gml_st/transforms/passes.h.inc" - -struct AddDebugInfoPass : public impl::AddDebugInfoPassBase { - void runOnOperation() override { - auto module = getOperation(); - auto *context = &getContext(); - OpBuilder builder(context); - std::string inputFilePath("-"); - - if (auto fileLoc = module.getLoc().dyn_cast()) - inputFilePath = fileLoc.getFilename().getValue(); - - auto fileAttr = - LLVM::DIFileAttr::get(context, llvm::sys::path::filename(inputFilePath), - llvm::sys::path::parent_path(inputFilePath)); - - auto producer = StringAttr::get(context, "XLA CPU"); - auto cuAttr = LLVM::DICompileUnitAttr::get( - context, llvm::dwarf::DW_LANG_C_plus_plus_17, fileAttr, producer, - /*isOptimized=*/false, LLVM::DIEmissionKind::LineTablesOnly); - module.walk([&](func::FuncOp funcOp) { - StringAttr funcName = StringAttr::get(context, funcOp.getName()); - auto bT = LLVM::DIBasicTypeAttr::get( - context, llvm::dwarf::DW_TAG_base_type, "void", /*sizeInBits=*/0, - /*encoding=*/1); - auto subTypeAttr = LLVM::DISubroutineTypeAttr::get( - context, llvm::dwarf::DW_CC_normal, {bT}); - auto spAttr = LLVM::DISubprogramAttr::get( - context, cuAttr, fileAttr, funcName, funcName, fileAttr, /*line=*/1, - /*scopeline=*/1, LLVM::DISubprogramFlags::Definition, subTypeAttr); - funcOp->setLoc(builder.getFusedLoc({funcOp->getLoc()}, spAttr)); - }); - } -}; -} // namespace - -std::unique_ptr> createAddDebugInfoPass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/canonicalization/optimize_linalg_ops.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/canonicalization/optimize_linalg_ops.cc deleted file mode 100644 index f7b29a7cb1f1bf..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/canonicalization/optimize_linalg_ops.cc +++ /dev/null @@ -1,217 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_OPTIMIZELINALGOPSPASS -#include "gml_st/transforms/passes.h.inc" - -std::optional getSplatValue(PatternRewriter& rewriter, Location loc, - Value value) { - auto* definingOp = value.getDefiningOp(); - if (!definingOp) return std::nullopt; - - if (auto constantOp = dyn_cast_or_null(definingOp)) { - auto denseElementsAttr = - constantOp.getValue().dyn_cast(); - - if (!denseElementsAttr.isSplat()) return std::nullopt; - - auto splatAttr = denseElementsAttr.getSplatValue(); - auto splatType = denseElementsAttr.getElementType(); - - if (complex::ConstantOp::isBuildableWith(splatAttr, splatType)) - return rewriter.create(loc, splatType, - splatAttr.cast()); - - return rewriter.create(loc, cast(splatAttr)); - } - - if (auto fillOp = dyn_cast_or_null(definingOp)) - return fillOp.getInputs()[0]; - return std::nullopt; -} - -LogicalResult foldConstantOperandsIntoMap(linalg::MapOp op, - PatternRewriter& rewriter) { - auto loc = op->getLoc(); - SmallVector newInputs; - IRMapping mapping; - - for (auto [operand, bbArg] : - llvm::zip(op.getDpsInputOperands(), op.getBody()->getArguments())) { - auto constantValue = getSplatValue(rewriter, loc, operand->get()); - if (constantValue.has_value()) { - mapping.map(bbArg, *constantValue); - } else { - newInputs.push_back(operand->get()); - } - } - - // No constant operands found. - if (newInputs.size() == op.getInputs().size()) return failure(); - - auto newMapOp = rewriter.create(loc, op.getResultTypes(), - /*inputs=*/newInputs, - /*init=*/op.getInit()); - rewriter.cloneRegionBefore(op.getRegion(), newMapOp.getRegion(), - newMapOp.getRegion().begin(), mapping); - rewriter.replaceOp(op, newMapOp.getResults()); - - return success(); -} - -// Replace linalg.map with no inputs with an linalg.fill. -LogicalResult replaceConstantMapWithFill(linalg::MapOp op, - PatternRewriter& rewriter) { - // Only replace linalg.map that has no inputs. - if (!op.getInputs().empty()) return failure(); - - // linalg.index indicates that region result is not constant. - if (!op.getBody()->getOps().empty()) return failure(); - - // Move all ops outside of the region. It's safe, because this linalg.map has - // only implicit arguments. - for (Operation& regionOp : - llvm::make_early_inc_range(op.getBody()->without_terminator())) { - regionOp.moveBefore(op); - } - - // Get fill value from gml_st.yield operand. - auto yieldValue = op.getBody()->getTerminator()->getOperand(0); - - rewriter.replaceOpWithNewOp(op, yieldValue, op.getInit()); - return success(); -} - -// Replace linalg.broadcast(single_element_tensor) with linalg.fill. 
-LogicalResult replaceBroadcastWithFill(linalg::BroadcastOp op, - PatternRewriter& rewriter) { - Value input = op.getInput(); - auto inputType = dyn_cast(input.getType()); - if (!inputType) return failure(); - - Location loc = op.getLoc(); - Value scalar; - if (auto splatValue = getSplatValue(rewriter, loc, input)) { - scalar = *splatValue; - } else if (hasSingleElement(inputType)) { - SmallVector indicesInput( - inputType.getRank(), rewriter.create(loc, 0)); - scalar = rewriter.create(loc, input, indicesInput); - } - if (!scalar) return failure(); - rewriter.replaceOpWithNewOp(op, scalar, op.getInit()); - return success(); -} - -// Rewrite `tensor.extract_slice(op(arg1, ...))` into -// `op(tensor.extract_slice(arg1, ...))`. -LogicalResult rewriteExtractSliceOfTileableOp(Operation* op, - PatternRewriter& rewriter) { - auto tileableOp = dyn_cast(op); - if (!tileableOp) return failure(); - - // Support only ops with a single result for now. - if (op->getNumResults() != 1) return failure(); - auto result = op->getResult(0); - - // If the op has several uses, then it is not always beneficial to rewrite. - if (!result.hasOneUse()) return failure(); - auto sliceOp = dyn_cast(*result.getUsers().begin()); - // Check if the defining op and the slice op are located in the same block. - // Cases when they are not are covered by fusion. - if (!sliceOp || sliceOp->getBlock() != op->getBlock()) return failure(); - - rewriter.setInsertionPointAfter(sliceOp); - FailureOr tilingResult = - tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp, result); - - if (failed(tilingResult)) return failure(); - rewriter.replaceOp(sliceOp, tilingResult->tiledValues); - - return success(); -} - -LogicalResult rewriteExtractSliceOfReverseOp(thlo::ReverseOp reverseOp, - PatternRewriter& rewriter) { - return rewriteExtractSliceOfTileableOp(reverseOp, rewriter); -} - -struct RewriteExtractSliceOfLinalgOpPattern - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter& rewriter) const override { - return rewriteExtractSliceOfTileableOp(linalgOp, rewriter); - } -}; - -struct OptimizeLinalgOpsPass - : public impl::OptimizeLinalgOpsPassBase { - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext* ctx = &getContext(); - - // Populate patterns. - RewritePatternSet patterns(ctx); - patterns.add(ctx); - patterns.add(foldConstantOperandsIntoMap); - patterns.add(replaceBroadcastWithFill); - patterns.add(replaceConstantMapWithFill); - patterns.add(rewriteExtractSliceOfReverseOp); - tensor::populateFoldTensorEmptyPatterns(patterns); - tensor::populateReassociativeReshapeFoldingPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -createOptimizeLinalgOpsPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc deleted file mode 100644 index 7da420ef6333c2..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc +++ /dev/null @@ -1,351 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "gml_st/utils/linalg_utils.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_COLLAPSESHAPEPASS -#include "gml_st/transforms/passes.h.inc" - -// Creates reassociation indices for `shape_collapse` and `shape_expand` ops. -// Given `rank`(N) and `retainTrailingDims`(M), returns the following -// reassociation: -// [[0, 1, ..., N-M-1], [N-M], [N-M+1], ..., [N-1]] -// |--- retainTrailingDims ---| -// |-------------------- rank --------------------| -SmallVector getCollapsingReassociationIndices( - int64_t rank, int64_t retainTrailingDims) { - SmallVector reassociation; - reassociation.reserve(retainTrailingDims + 1); - if (rank > retainTrailingDims) { - auto seq = llvm::seq(0, rank - retainTrailingDims); - reassociation.emplace_back(seq.begin(), seq.end()); - } - for (int64_t i = rank - retainTrailingDims; i < rank; ++i) - reassociation.push_back({i}); - return reassociation; -} - -struct CollapseBcastPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - CollapseBcastPattern(MLIRContext* ctx, int64_t retainTrailingDims) - : OpRewritePattern(ctx), - retainTrailingDims(retainTrailingDims) {} - - LogicalResult matchAndRewrite(linalg::BroadcastOp op, - PatternRewriter& rewriter) const override { - Value init = op.getInit(); - auto initTy = init.getType().cast(); - int64_t initRank = initTy.getRank(); - int64_t numCollapsedDims = initRank - retainTrailingDims; - - if (numCollapsedDims < 2) { - return rewriter.notifyMatchFailure(op, "no dimension to collapse"); - } - - // Dimensions to be collapsed must either be all broadcasted or not - // broadcasted. - llvm::ArrayRef nonBroadcastedDims = op.getDimensions(); - - bool firstDimsBroadcasted = true; - if (!nonBroadcastedDims.empty()) { - int64_t i = 0; - while (i < (int64_t)nonBroadcastedDims.size() && - nonBroadcastedDims[i] == i && i < numCollapsedDims) { - ++i; - } - if (i >= numCollapsedDims) { - firstDimsBroadcasted = false; - } else if (llvm::any_of(nonBroadcastedDims, - [numCollapsedDims](unsigned dim) { - return dim < numCollapsedDims; - })) { - return rewriter.notifyMatchFailure( - op, "collapsed dims are not broadcasted in order"); - } - } - - Value operand = op.getInput(); - auto operandTy = operand.getType().cast(); - int64_t operandRank = operandTy.getRank(); - llvm::DenseSet nonBroadcastedDimsSet(nonBroadcastedDims.begin(), - nonBroadcastedDims.end()); - llvm::SmallVector collapsedNonBroadcastedDims; - collapsedNonBroadcastedDims.reserve(numCollapsedDims + - (firstDimsBroadcasted ? 
1 : 0)); - for (int64_t dim = numCollapsedDims; dim < initRank; ++dim) { - if (nonBroadcastedDimsSet.contains(dim)) { - collapsedNonBroadcastedDims.push_back(dim - numCollapsedDims + 1); - } - } - int64_t operandRetainTrailingDims = - retainTrailingDims - collapsedNonBroadcastedDims.size(); - - // Collapse operand and init tensor. - // For bcasts, this retains the last `retainTrailingDims` dimensions of the - // *result* and collapses all others. - Location loc = op.getLoc(); - Value collapsedOperand = operand; - if (operandRank > operandRetainTrailingDims + 1) { - SmallVector operandReassociation = - getCollapsingReassociationIndices(operandRank, - operandRetainTrailingDims); - collapsedOperand = rewriter.createOrFold( - loc, operand, operandReassociation); - } - SmallVector initReassociation = - getCollapsingReassociationIndices(initRank, retainTrailingDims); - Value collapsedInit = - rewriter.create(loc, init, initReassociation); - - // Create collapsed bcast op. - if (!firstDimsBroadcasted) { - collapsedNonBroadcastedDims.push_back(0); - } - Value collapsedBcastOp = - rewriter - .create( - loc, collapsedOperand, collapsedInit, - ArrayRef(collapsedNonBroadcastedDims)) - .getResult() - .front(); - - // Re-expand broadcast op and replace the original. - auto reexpandedBcastOp = rewriter.create( - loc, initTy, collapsedBcastOp, initReassociation); - rewriter.replaceOp(op, reexpandedBcastOp.getResult()); - return success(); - } - - private: - int64_t retainTrailingDims; -}; - -struct CollapseReductionPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - CollapseReductionPattern(MLIRContext* ctx, int64_t retainTrailingDims) - : OpRewritePattern(ctx), - retainTrailingDims(retainTrailingDims) {} - - LogicalResult matchAndRewrite(linalg::ReduceOp op, - PatternRewriter& rewriter) const override { - if (op.getNumDpsInits() != 1 || op.getDimensions().empty()) - return failure(); - int64_t reductionDim = op.getDimensions()[0]; - - Value operand = op.getInputs().front(); - auto operandTy = operand.getType().cast(); - int64_t operandRank = operandTy.getRank(); - - if (operandRank <= retainTrailingDims + 1) { - return rewriter.notifyMatchFailure(op, "no dimension to collapse"); - } - - if (operandRank - 1 - reductionDim >= retainTrailingDims) { - return rewriter.notifyMatchFailure( - op, "reduction dimension must be retained"); - } - - Value init = op.getInits().front(); - auto initTy = init.getType().cast(); - int64_t initRank = initTy.getRank(); - - // Collapse operand and init tensor. - // For reductions, this retains the last `retainTrailingDims` dimensions of - // the *operand* and collapses all others. - Location loc = op.getLoc(); - SmallVector operandReassociation = - getCollapsingReassociationIndices(operandRank, retainTrailingDims); - Value collapsedOperand = rewriter.create( - loc, operand, operandReassociation); - SmallVector initReassociation = - getCollapsingReassociationIndices(initRank, retainTrailingDims - 1); - Value collapsedInit = - rewriter.create(loc, init, initReassociation); - - auto collapsedOperandTy = - collapsedOperand.getType().cast(); - int64_t collapsedOperandRank = collapsedOperandTy.getRank(); - auto collapsedInitTy = collapsedInit.getType().cast(); - - // Create collapsed reduction op. 
- int64_t collapsedReductionDim = - reductionDim - operandRank + collapsedOperandRank; - SmallVector collapsedIteratorTypes( - collapsedOperandRank, utils::IteratorType::parallel); - collapsedIteratorTypes[collapsedReductionDim] = - utils::IteratorType::reduction; - auto collapsedReductionOp = rewriter.create( - loc, collapsedInitTy, collapsedOperand, collapsedInit, - ArrayRef({collapsedReductionDim})); - collapsedReductionOp.getRegion().takeBody(op.getBodyRegion()); - - // Re-expand reduction op and replace the original. - auto reexpandedReductionOp = rewriter.create( - loc, initTy, collapsedReductionOp.getResults().front(), - initReassociation); - rewriter.replaceOp(op, reexpandedReductionOp.getResult()); - return success(); - } - - private: - int64_t retainTrailingDims; -}; - -linalg::MapOp createCollapsedMapOp( - linalg::MapOp mapOp, PatternRewriter& rewriter, - const SmallVector& reassociation) { - // Collapsed operands and init tensor. - Location loc = mapOp.getLoc(); - SmallVector collapsedOperands = llvm::to_vector( - llvm::map_range(mapOp.getInputs(), [&](Value it) -> Value { - return rewriter.create(loc, it, reassociation); - })); - Value init = mapOp.getInit(); - Value collapsedInit = - rewriter.create(loc, init, reassociation); - - // Create collapsed map op. - auto collapsedInitTy = collapsedInit.getType().cast(); - auto collapsedMapOp = rewriter.create( - loc, collapsedInitTy, collapsedOperands, collapsedInit); - IRMapping bvm; - mapOp.getBodyRegion().cloneInto(&collapsedMapOp.getRegion(), bvm); - return collapsedMapOp; -} - -struct CollapseMapPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - CollapseMapPattern(MLIRContext* ctx, int64_t retainTrailingDims) - : OpRewritePattern(ctx), - retainTrailingDims(retainTrailingDims) {} - - LogicalResult matchAndRewrite(linalg::MapOp op, - PatternRewriter& rewriter) const override { - Value init = op.getInit(); - auto initTy = init.getType().cast(); - int64_t rank = initTy.getRank(); - - if (rank <= retainTrailingDims + 1) { - return rewriter.notifyMatchFailure(op, "no dimension to collapse"); - } - - SmallVector reassociation = - getCollapsingReassociationIndices(rank, retainTrailingDims); - auto collapsedMapOp = createCollapsedMapOp(op, rewriter, reassociation); - - // Re-expand map op and replace the original. - auto reexpandedMapOp = rewriter.create( - op.getLoc(), initTy, collapsedMapOp.getResult().front(), reassociation); - rewriter.replaceOp(op, reexpandedMapOp.getResult()); - return success(); - } - - private: - int64_t retainTrailingDims; -}; - -struct MoveCollapseBeforeMapPattern - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - explicit MoveCollapseBeforeMapPattern(MLIRContext* ctx) - : OpRewritePattern(ctx) {} - - LogicalResult matchAndRewrite(tensor::CollapseShapeOp op, - PatternRewriter& rewriter) const override { - auto mapOp = op.getSrc().getDefiningOp(); - if (!mapOp) return failure(); - auto collapsedMapOp = - createCollapsedMapOp(mapOp, rewriter, op.getReassociationIndices()); - rewriter.replaceOp(op, collapsedMapOp.getResult()); - return success(); - } -}; - -struct CollapseShapePass - : public impl::CollapseShapePassBase { - using CollapseShapePassBase::CollapseShapePassBase; - - void getDependentDialects(DialectRegistry& registry) const override { - CollapseShapePassBase::getDependentDialects(registry); - - // TODO(frgossen): Move these iface implementations into the tensor dialect. - // Some of its canonicalizations depend on it. 
Until then, we have to - // register them explicitly. - tensor::registerInferTypeOpInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext* ctx = &getContext(); - - // Populate shape-collapsing patterns for cwise ops, reductions, and bcasts. - RewritePatternSet patterns(ctx); - patterns.add(ctx, retainTrailingDims); - // By moving CollapseShapeOp before MapOp, we can potentially remove it if - // it cancels out with an ExpandShapeOp. - patterns.add(ctx); - - // Collect some related canonicalization patterns. - linalg::BroadcastOp::getCanonicalizationPatterns(patterns, ctx); - linalg::FillOp::getCanonicalizationPatterns(patterns, ctx); - linalg::MapOp::getCanonicalizationPatterns(patterns, ctx); - linalg::ReduceOp::getCanonicalizationPatterns(patterns, ctx); - tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx); - tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); - tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, ctx); - tensor::populateFoldTensorEmptyPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> createCollapseShapePass() { - return std::make_unique(); -} - -std::unique_ptr> createCollapseShapePass( - const CollapseShapePassOptions& options) { - return std::make_unique(options); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/collect_stats/collect_stats.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/collect_stats/collect_stats.cc deleted file mode 100644 index 8c75eb08cc88b1..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/collect_stats/collect_stats.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "gml_st/utils/tensor_utils.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_COLLECTSTATSPASS -#include "gml_st/transforms/passes.h.inc" - -using NameToOpMap = - std::unordered_map>; - -struct CollectStatsPass : public impl::CollectStatsPassBase { - using CollectStatsPassBase::CollectStatsPassBase; - - explicit CollectStatsPass(int64_t level) { detailLevel = level; } - - void runOnOperation() override { - if (detailLevel <= 0) return; - func::FuncOp func = getOperation(); - - func.walk([&](Operation *op) { - if (!isa(op)) - return WalkResult::advance(); - - std::string key = op->getName().getStringRef().str(); - if (auto collapseShapeOp = dyn_cast(op)) { - key += isDegenerateReshapeOp(collapseShapeOp) ? " (degenerate)" - : " (non-degenerate)"; - } - map[key].push_back(op); - return WalkResult::advance(); - }); - - printStats(); - } - - private: - void printStats() { - llvm::outs() << "*** Tileable ops stats (detail level " << detailLevel - << ") ***\n"; - for (const auto &it : map) { - auto name = it.first; - auto ops = it.second; - llvm::outs() << ops.size() << "x " << name << "\n"; - // If we want the op name only, stop here. - if (detailLevel == 1) continue; - for (size_t i = 0; i < ops.size(); ++i) { - auto *op = ops[i]; - llvm::outs().indent(2) << i + 1 << ". "; - op->print(llvm::outs()); - llvm::outs() << '\n'; - // If we want the full op string only, stop here. - if (detailLevel == 2) continue; - // Otherwise print info about the producers and consumers of the op. - llvm::outs().indent(4) << "Producers:\n"; - for (auto operand : op->getOperands()) { - if (auto loopLikeProducer = - operand.getDefiningOp()) { - llvm::outs().indent(6) - << loopLikeProducer->getName().getStringRef() << '\n'; - } else { - operand.print(llvm::outs().indent(6)); - llvm::outs() << '\n'; - } - } - llvm::outs().indent(4) << "Consumers:\n"; - for (auto user : op->getUsers()) { - user->print(llvm::outs().indent(6)); - llvm::outs() << '\n'; - } - } - llvm::outs() << '\n'; - } - } - - int64_t detailLevel; - NameToOpMap map; -}; -} // namespace - -std::unique_ptr> createCollectStatsPass() { - return std::make_unique(); -} - -std::unique_ptr> createCollectStatsPass( - int64_t level) { - return std::make_unique(level); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc deleted file mode 100644 index 1d2e1f0b1e8381..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "gml_st/transforms/passes.h" -#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -using tensor::ExtractOp; -using tensor::ExtractSliceOp; - -#define GEN_PASS_DEF_COMPOSEEXTRACTINSERTSLICEPASS -#include "gml_st/transforms/passes.h.inc" - -LogicalResult composeExtractOfExtractSlice(ExtractOp extractOp, - PatternRewriter& rewriter) { - auto sliceOp = extractOp.getTensor().getDefiningOp(); - if (!sliceOp) return failure(); - - Location loc = extractOp.getLoc(); - SmallVector combinedOffsets, combinedSizes, combinedStrides; - - // ExtractOp can be viewed as ExtractSliceOp as extracts 1x...x1 slice. - int64_t rank = extractOp.getTensor().getType().getRank(); - SmallVector consumerOffsets( - getAsOpFoldResult(extractOp.getIndices())); - SmallVector consumerSizes(rank, rewriter.getIndexAttr(1)); - SmallVector consumerStrides(rank, rewriter.getIndexAttr(1)); - - if (failed(affine::mergeOffsetsSizesAndStrides( - rewriter, loc, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), - sliceOp.getMixedStrides(), sliceOp.getDroppedDims(), consumerOffsets, - consumerSizes, consumerStrides, combinedOffsets, combinedSizes, - combinedStrides))) - return failure(); - - rewriter.replaceOpWithNewOp( - extractOp, sliceOp.getSource(), - getValueOrCreateConstantIndexOp(rewriter, loc, combinedOffsets)); - return success(); -} - -struct ComposeExtractInsertSlicePass - : public impl::ComposeExtractInsertSlicePassBase< - ComposeExtractInsertSlicePass> { - void runOnOperation() override { - MLIRContext* ctx = &getContext(); - RewritePatternSet patterns(ctx); - patterns.add(composeExtractOfExtractSlice); - tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -createComposeExtractInsertSlicePass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc deleted file mode 100644 index 5e2649621f1314..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "gml_st/transforms/passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" - -namespace mlir { -namespace gml_st { - -GmlStCPUTilingOptions getDefaultCPUPipelineOptions(StringRef cpuName, - int64_t statsDetailLevel) { - GmlStCPUTilingOptions opts; - opts.vectorSize = 8; - opts.reductionEnableHeuristic = false; - opts.reduction1DSplitRatio = 8; - opts.reduction1DTileSize = 8; - opts.reduction2DParallelDimTileSize = 4; - opts.reduction2DReductionDimTileSize = 4; - opts.matmulTileSizes = {}; - // TODO(vuson): Re-enable or remove this: - opts.vectorizationSizeThreshold = 0; - opts.vectorizationTiledSizeThreshold = 1024; - opts.lowerToMmt4d = false; - opts.cpuName = cpuName; - opts.statsDetailLevel = statsDetailLevel; - opts.fuseDegenerateReshapes = false; - opts.inlineFusionClusters = true; - return opts; -} - -void addCPUTilingPipeline(OpPassManager& pm, - const GmlStCPUTilingOptions& options) { - using func::FuncOp; - - pm.addNestedPass(createCollectStatsPass(options.statsDetailLevel)); - pm.addNestedPass(createScalarizationPass(false)); - pm.addNestedPass( - createVectorizeForCPUPass(options.vectorizationSizeThreshold)); - - if (options.lowerToMmt4d) pm.addNestedPass(createPackMatmulPass()); - - pm.addNestedPass(createTransformScatterForCpuPass()); - - pm.addNestedPass( - createTransformDotForCpuPass(options.matmulTileSizes, options.cpuName)); - TransformReduceForCpuPassOptions reductionOpts; - reductionOpts.enableHeuristic = options.reductionEnableHeuristic; - reductionOpts.tileSize1D = options.reduction1DTileSize; - reductionOpts.splitRatio1D = options.reduction1DSplitRatio; - reductionOpts.parallelDimTileSize2D = options.reduction2DParallelDimTileSize; - reductionOpts.reductionDimTileSize2D = - options.reduction2DReductionDimTileSize; - pm.addNestedPass(createTransformReduceForCpuPass(reductionOpts)); - - // Upstream generalization of tensor.pack/unpack (i.e. tensor.pack/unpack -> - // tensor.pad + linalg.transpose + tensor.insert_slice) does not transfer - // transformed labels from tensor.pack/unpack to linalg.transpose and thus - // makes the latter being tiled again. - // Hence, elementwise ops transformation needs to be run before pack/unpack - // transformation. - pm.addNestedPass(createTransformElementwiseForCpuPass( - options.vectorSize, options.fuseDegenerateReshapes)); - pm.addNestedPass(createTransformMmt4DForCpuPass()); - pm.addNestedPass(createTransformPackForCpuPass()); - - if (options.inlineFusionClusters) - pm.addNestedPass(createInlineFusionClustersPass()); - - pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); - - pm.addNestedPass(createRewriteForallOpPass()); - pm.addNestedPass(createComposeExtractInsertSlicePass()); - pm.addNestedPass( - createVectorizeForCPUPass(options.vectorizationTiledSizeThreshold)); - - // Tile remaining ops by size one and scalarize what we can. 
- pm.addNestedPass(createTileByOnePass()); - pm.addNestedPass(createScalarizationPass()); - pm.addNestedPass(createComposeExtractInsertSlicePass()); - - pm.addPass(createCanonicalizerPass()); - - // Remove transformed labels after tiling all ops. - pm.addNestedPass(createRemoveLabelPass()); -} - -void addDefaultCPUTilingPipeline(OpPassManager& pm, StringRef cpuName, - int64_t statsDetailLevel) { - addCPUTilingPipeline(pm, - getDefaultCPUPipelineOptions(cpuName, statsDetailLevel)); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_outlining.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_outlining.cc deleted file mode 100644 index dd912a13c58af8..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_outlining.cc +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/RegionUtils.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_FUSIONOUTLININGPASS -#include "gml_st/transforms/passes.h.inc" - -constexpr llvm::StringRef kFusionFunctionLabel = "fusion"; - -void outlineFusionOp(func::FuncOp parentFuncOp, gml_st::FusionOp fusionOp, - int64_t localFusionId, PatternRewriter& rewriter) { - Location loc = fusionOp.getLoc(); - MLIRContext* ctx = fusionOp.getContext(); - - // Generate outlined fusion func ops right before the parent func op. - rewriter.setInsertionPoint(parentFuncOp); - std::string funcName = - llvm::formatv("{0}_fusion_{1}", parentFuncOp.getName(), localFusionId) - .str(); - TypeRange funcArgTypes = fusionOp->getOperandTypes(); - TypeRange funcResultTypes = fusionOp.getResultTypes(); - auto funcTy = FunctionType::get(ctx, funcArgTypes, funcResultTypes); - auto funcOp = - rewriter.create(fusionOp.getLoc(), funcName, funcTy); - setLabel(funcOp, kFusionFunctionLabel); - - // Generate entry block. - Region& funcRegion = funcOp.getBody(); - Block* funcBlock = - rewriter.createBlock(&funcRegion, funcRegion.begin(), funcArgTypes, - SmallVector(funcArgTypes.size(), loc)); - rewriter.setInsertionPointToStart(funcBlock); - - // Generate new fusion op and steal body. - auto newFusionOp = rewriter.create( - loc, funcResultTypes, funcBlock->getArguments(), fusionOp->getAttrs()); - newFusionOp.getRegion().takeBody(fusionOp.getRegion()); - - // Forward fusion op results. - rewriter.create(loc, newFusionOp->getResults()); - - // Replace fusion op with a call to the newly outlined function. 
- rewriter.setInsertionPoint(fusionOp); - rewriter.replaceOpWithNewOp(fusionOp, funcOp, - fusionOp->getOperands()); -} - -LogicalResult outlineFusionOpPattern(func::FuncOp funcOp, - PatternRewriter& rewriter) { - // Only apply to functions that are not the result of outlining. - if (hasLabel(funcOp, kFusionFunctionLabel)) return failure(); - - // Outline fusion ops one by one. - int64_t numOutlinedFusions = 0; - funcOp.walk([&](gml_st::FusionOp fusionOp) { - // Outline only outermost cluster. - if (fusionOp->getParentOfType()) return; - - outlineFusionOp(funcOp, fusionOp, numOutlinedFusions++, rewriter); - }); - - // Successfully applied pattern if at least one fusion was outlined. - if (numOutlinedFusions > 0) return success(); - return failure(); -} - -struct FusionOutliningPass - : public impl::FusionOutliningPassBase { - void runOnOperation() override { - ModuleOp moduleOp = getOperation(); - MLIRContext* ctx = &getContext(); - - // Populate patterns. - RewritePatternSet patterns(ctx); - patterns.add(outlineFusionOpPattern); - - if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> createFusionOutliningPass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc deleted file mode 100644 index a6d127c6eb0de4..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/fusion_planning_for_cpu.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_FUSIONPLANNINGFORCPUPASS -#define GEN_PASS_DEF_INLINEFUSIONCLUSTERSPASS -#include "gml_st/transforms/passes.h.inc" - -static constexpr llvm::StringRef kFusionPlanningLabel = - "__fusion_planning_label__"; - -// Returns true if the op is linalg.reduce or one of the variations of matmul. 
-bool isReducingOp(Operation* op) { - return isa(op); -} - -// Returns true if the op is either a map (linalg.map or linalg.fill) or the op -// has only parallel tiling dimensions and doesn't perform any computations -// (linalg.broadcast, linalg.transpose, thlo.reverse). -bool isElementwiseOp(Operation* op) { - return isa(op); -} - -// Returns true is consumer and producer should be fused and tiled together. -bool allowedToFuse(Operation* consumerOp, Operation* producerOp) { - // Verify that only known ops are fused. - if (!isa( - producerOp->getDialect())) - return false; - - if (isa(producerOp)) return false; - - if (isa(producerOp)) { - auto dstStyleOp = dyn_cast(consumerOp); - if (!dstStyleOp) return false; - - if (llvm::any_of(dstStyleOp.getDpsInits(), [&](Value operand) { - return operand.getDefiningOp() == producerOp; - })) - return true; - } - - if (isElementwiseOp(consumerOp) && isElementwiseOp(producerOp)) return true; - - if (isa(consumerOp)) return true; - if (isa(consumerOp)) return false; - - if (isa(consumerOp)) - return isa(producerOp); - if (isa(consumerOp)) - return isa(producerOp); - if (isa(consumerOp)) return isa(producerOp); - return false; -} - -// Runs graph search to find ops that can be fused together. -template -LogicalResult fusionPattern(OpTy op, PatternRewriter& rewriter) { - // The op is already in a fusion cluster. - if (isa(op.getOperation()->getParentOp())) return failure(); - - // The op was already processed. - if (hasLabel(op, kFusionPlanningLabel)) return failure(); - - for (auto& use : op->getUses()) { - auto* useOp = use.getOwner(); - // This op can be potentially fused into one of the consumers. Wait until - // that other op is processed. - if (useOp && allowedToFuse(useOp, op.getOperation())) return failure(); - } - - SetVector resultOps; - SmallVector remainingProducers; - bool hasReducingOp = isReducingOp(op); - resultOps.insert(op.getOperation()); - for (auto operand : op.getOperands()) - remainingProducers.push_back(operand.getDefiningOp()); - - while (!remainingProducers.empty()) { - Operation* curOp = remainingProducers.pop_back_val(); - if (!curOp) continue; - - if (llvm::is_contained(resultOps, curOp)) continue; - - if (!llvm::all_of(curOp->getUses(), [&](mlir::OpOperand& use) { - auto* consumerOp = use.getOwner(); - // Check that curOp is allowed to fused with all consumers. - if (!allowedToFuse(consumerOp, curOp)) return false; - // Check that all consumers are already in the fusion cluster. - if (!llvm::is_contained(resultOps, consumerOp)) return false; - return true; - })) - continue; - - // Only one reducing op should be added to the cluster. - if (isReducingOp(curOp)) { - if (hasReducingOp) continue; - hasReducingOp = true; - } - - resultOps.insert(curOp); - - for (auto operand : curOp->getOperands()) - remainingProducers.push_back(operand.getDefiningOp()); - } - - FusionCluster fusionCluster; - fusionCluster.root = op; - fusionCluster.operations = resultOps; - if (failed(wrapFusionCluster(rewriter, fusionCluster))) return failure(); - - // Mark all ops as processed. - for (auto* op : resultOps) setLabel(op, kFusionPlanningLabel); - - return success(); -} - -// Add attributes with tile sizes for parallel and reduction dimensions. -// Attribute is empty if there is nothing to tile across respective dimensions. 
-struct ComputeTileSizesPattern : public OpRewritePattern { - ComputeTileSizesPattern(MLIRContext* context, int64_t vectorSize, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - vectorSize(vectorSize) {} - - LogicalResult matchAndRewrite(gml_st::FusionOp fusionOp, - PatternRewriter& rewriter) const override { - if (fusionOp.getParallelTileSizes().has_value()) return failure(); - - if (!llvm::all_of(fusionOp.getRegion().getOps(), [](Operation& op) { - return isa(op); - })) - return failure(); - - auto rootOp = dyn_cast_or_null( - fusionOp.getTerminator().getOperand(0).getDefiningOp()); - if (!rootOp) return failure(); - - const int64_t numLoops = rootOp.getLoopIteratorTypes().size(); - - fusionOp.setParallelTileSizes(getParallelTileSizes(numLoops)); - fusionOp.setReductionTileSizes(SmallVector(numLoops, 0)); - - return success(); - }; - - private: - SmallVector getParallelTileSizes(int64_t numLoops) const { - SmallVector result(numLoops, 1); - if (!result.empty()) result.back() = vectorSize; - return result; - } - - int64_t vectorSize; -}; - -struct FusionPlanningForCpuPass - : public impl::FusionPlanningForCpuPassBase { - explicit FusionPlanningForCpuPass(int64_t vs = 8) { vectorSize = vs; } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext* ctx = &getContext(); - - // Cleanup passes to prepare ops for better clustering. - { - RewritePatternSet patterns(ctx); - populateDuplicateInitOpsPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - // Move ops to gml_st.fusion clusters. - { - RewritePatternSet patterns(ctx); - patterns.add(fusionPattern); - patterns.add(fusionPattern); - patterns.add(fusionPattern); - patterns.add(fusionPattern); - patterns.add(fusionPattern); - patterns.add(fusionPattern); - patterns.add(fusionPattern); - - GreedyRewriteConfig config = GreedyRewriteConfig(); - // TODO(shyshkov): Refactor the fusion pattern so it doesn't visit all ops - // too many times. Currently pattern might need O(N^2) iterations to - // create fusion clusters for N ops. - config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed( - applyPatternsAndFoldGreedily(f, std::move(patterns), config))) { - return signalPassFailure(); - } - } - - // Add attributes with tile sizes. - { - RewritePatternSet patterns(ctx); - patterns.add(ctx, vectorSize); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - } -}; - -struct InlineFusionClustersPass - : public impl::InlineFusionClustersPassBase { - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext* ctx = &getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(inlineFusionCluster); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -createFusionPlanningForCpuPass(int64_t vectorSize) { - return std::make_unique(vectorSize); -} - -std::unique_ptr> -createInlineFusionClustersPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/pack_matmul.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/pack_matmul.cc deleted file mode 100644 index e119d88dbb291f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/pack_matmul.cc +++ /dev/null @@ -1,331 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_PACKMATMULPASS -#include "gml_st/transforms/passes.h.inc" - -// Helper to pick the tile shapes to use as the 2 inner dimensions of the -// 4D shapes appearing in a Mmt4D. -class Mmt4DTileParams { - public: - Mmt4DTileParams(ArrayRef m0k0n0, const llvm::StringRef comment) - : m0(m0k0n0[0]), k0(m0k0n0[1]), n0(m0k0n0[2]), comment(comment) {} - std::array lhs() const { return {m0, k0}; } - std::array rhs() const { return {k0, n0}; } - std::array acc() const { return {m0, n0}; } - std::array rhsTranspose() const { return {n0, k0}; } - const std::string &getComment() const { return comment; } - - private: - const int64_t m0; - const int64_t k0; - const int64_t n0; - const std::string comment; -}; - -std::optional getPaddingValue(Value &source) { - auto padOp = source.getDefiningOp(); - if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) - return std::nullopt; - - Value constantPaddingValue = padOp.getConstantPaddingValue(); - if (!constantPaddingValue) return std::nullopt; - - source = padOp.getSource(); - return constantPaddingValue; -} - -// Returns a tiled and packed value of |source|, the data layout is described by -// |innerDimsPos|, |innerTileSizes| and |outerDimsPerm|. -Value pack(Location loc, PatternRewriter &rewriter, Value source, - ArrayRef innerDimsPos, ArrayRef innerTileSizes, - ArrayRef outerDimsPerm) { - SmallVector innerTileSizesOfr = - getAsOpFoldResult(rewriter.getI64ArrayAttr(innerTileSizes)); - auto empty = tensor::PackOp::createDestinationTensor( - rewriter, loc, source, innerTileSizesOfr, innerDimsPos, outerDimsPerm); - std::optional paddingValue = getPaddingValue(source); - return rewriter.create(loc, source, empty, innerDimsPos, - innerTileSizesOfr, paddingValue, - outerDimsPerm); -} - -// Returns an unpacked value of |source|, the data layout is described by -// |innerDimsPos|, |innerTileSizes| and |outerDimsPerm|. |resultShapeValue| is -// used to create the destination tensor for the resulting unpacked value. 
-Value unpack(Location loc, PatternRewriter &rewriter, Value source, - Value resultShapeValue, ArrayRef innerDimsPos, - ArrayRef innerTileSizes, - ArrayRef outerDimsPerm) { - SmallVector resultDims = - tensor::getMixedSizes(rewriter, loc, resultShapeValue); - auto empty = rewriter.create( - loc, resultDims, - source.getType().cast().getElementType()); - - SmallVector innerTileSizesOfr = - getAsOpFoldResult(rewriter.getI64ArrayAttr(innerTileSizes)); - - return rewriter.create(loc, source, empty, innerDimsPos, - innerTileSizesOfr, outerDimsPerm); -} - -bool haveEqualShapeDim(Value x, Value y, int i) { - return x.getType().cast().getDimSize(i) == - y.getType().cast().getDimSize(i); -} - -// Returns a top-left slice from |input| shaped like |likeWhat|. -Value extractSliceLike(Location loc, PatternRewriter &rewriter, Value input, - Value likeWhat) { - SmallVector offsets, dims, strides; - auto resultType = likeWhat.getType().cast(); - int64_t rank = resultType.getRank(); - auto resultShape = likeWhat.getType().cast().getShape(); - for (int i = 0; i < rank; ++i) { - offsets.push_back(rewriter.getIndexAttr(0)); - strides.push_back(rewriter.getIndexAttr(1)); - if (resultShape[i] == ShapedType::kDynamic) { - dims.emplace_back(rewriter.create(loc, likeWhat, i)); - } else { - dims.push_back(rewriter.getIndexAttr(resultShape[i])); - } - } - return rewriter.create(loc, resultType, input, - offsets, dims, strides); -} - -// Returns true if an input of the given |inputShape| needs padding to -// ensure that its shape will be a multiple of |tileShape|. That's always true -// in the dynamic shape case. -bool needsPadding(ArrayRef inputShape, ArrayRef tileShape) { - assert(inputShape.size() == tileShape.size()); - for (size_t i = 0; i < inputShape.size(); i++) { - if (inputShape[i] == ShapedType::kDynamic) { - return true; - } - if (inputShape[i] % tileShape[i] != 0) { - return true; - } - } - return false; -} - -// Pads |input| on the bottom and on the right to the next multiple of -// |tileShape|. -Value pad(Location loc, PatternRewriter &rewriter, Value input, - ArrayRef tileShape) { - SmallVector lowPadding, highPadding; - SmallVector resultTypeShape; - auto inputType = input.getType().cast(); - ArrayRef inputShape = inputType.getShape(); - if (!needsPadding(inputShape, tileShape)) { - return input; - } - int64_t rank = inputType.getRank(); - for (int64_t i = 0; i < rank; ++i) { - // No 'low' padding i.e. no padding at the top and on the left. - lowPadding.push_back(rewriter.getIndexAttr(0)); - // 'High' padding i.e. padding at the bottom and on the right, and the - // result type shape, will be dynamic in any dimension if and only if the - // input shape is. - if (inputShape[i] == ShapedType::kDynamic) { - resultTypeShape.push_back(ShapedType::kDynamic); - // There only remains to compute the 'high' padding Value. - auto add = [&](Value a, Value b) { - return rewriter.create(loc, a, b); - }; - auto sub = [&](Value a, Value b) { - return rewriter.create(loc, a, b); - }; - auto rem = [&](Value a, Value b) { - return rewriter.create(loc, a, b); - }; - // Compare to the plainer distanceToNextMultipleOf in the static - // dimension case below. 
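Both the IR-building lambda below and the "plainer" static case compute the same quantity; as standalone arithmetic it is:

#include <cassert>
#include <cstdint>

// Padding needed to round `a` up to the next multiple of `b`.
int64_t distanceToNextMultipleOf(int64_t a, int64_t b) {
  int64_t bMinusOne = b - 1;
  return bMinusOne - ((a + bMinusOne) % b);
}

int main() {
  assert(distanceToNextMultipleOf(5, 8) == 3);   // 5 is padded up to 8
  assert(distanceToNextMultipleOf(16, 8) == 0);  // already a multiple
  assert(distanceToNextMultipleOf(17, 8) == 7);  // 17 is padded up to 24
  return 0;
}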
- auto distanceToNextMultipleOf = [&](Value a, Value b) { - Value one = rewriter.create(loc, 1); - Value bMinusOne = sub(b, one); - return sub(bMinusOne, rem(add(a, bMinusOne), b)); - }; - Value inputDim = rewriter.create(loc, input, i); - Value tileDim = - rewriter.create(loc, tileShape[i]); - Value padding = distanceToNextMultipleOf(inputDim, tileDim); - highPadding.push_back(padding); - } else { - auto distanceToNextMultipleOf = [=](int64_t a, int64_t b) { - int64_t bMinusOne = b - 1; - return bMinusOne - ((a + bMinusOne) % b); - }; - int64_t inputDim = inputShape[i]; - int64_t tileDim = tileShape[i]; - int64_t padding = distanceToNextMultipleOf(inputDim, tileDim); - resultTypeShape.push_back(inputDim + padding); - highPadding.push_back(rewriter.getIndexAttr(padding)); - } - } - Type elementType = inputType.getElementType(); - RankedTensorType resultType = - RankedTensorType::get(resultTypeShape, elementType); - Value padValue; - if (auto complexTy = elementType.dyn_cast()) { - auto zero = rewriter.getZeroAttr(complexTy.getElementType()); - padValue = rewriter.create( - loc, elementType, rewriter.getArrayAttr({zero, zero})); - } else { - auto zero = rewriter.getZeroAttr(elementType); - padValue = rewriter.create(loc, elementType, zero); - } - return rewriter.create(loc, resultType, input, lowPadding, - highPadding, padValue); -} - -// Pattern to convert linalg.matmul to an equivalent subgraph using -// linalg.mmt4d. Currently, m0, n0 and k0 (packing parameters, aka layout tiling -// parameters) are compile-time constants. -LogicalResult packMatmul(linalg::MatmulOp matmulOp, PatternRewriter &rewriter) { - Location loc = matmulOp.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - - Value lhs = matmulOp.getDpsInputOperand(0)->get(); - Value rhs = matmulOp.getDpsInputOperand(1)->get(); - Value acc = matmulOp.getDpsInitOperand(0)->get(); - - // This transformation supports any mixing of static and dynamic dimensions, - // with one exception: the dynamic-ness of each dimension of the accumulator - // must match the dynamic-ness of the corresponding lhs/rhs dimension. - // This limitation is not inherent to this transformation's code, it's just - // here to avoid a current linalg folding limitation: at the moment, - // removing this gives the following error in e2e matmul tests, - // "error: failed to legalize operation 'tensor.cast' that was explicitly - // marked illegal" - // apparently due to some missing folding of tensor.cast op into reshapes. 
- if (!haveEqualShapeDim(lhs, acc, 0) || !haveEqualShapeDim(rhs, acc, 1)) { - return failure(); - } - - ShapedType lhsType = lhs.getType().cast(); - ShapedType rhsType = rhs.getType().cast(); - int64_t shapeM = lhsType.getShape()[0]; - int64_t shapeN = rhsType.getShape()[1]; - auto chooseMatMulOrMatVec = - [=](ArrayRef m0k0n0, ArrayRef m0k0n0ForMatVec, - ArrayRef m0k0n0ForWhenRhsHas2Columns, std::string comment) { - assert(m0k0n0ForMatVec[2] == 1 && "not a matrix*vector shape"); - assert(m0k0n0ForWhenRhsHas2Columns[2] == 2 && - "N=2 is expected when RHS has 2 columns"); - - SmallVector params; - if (shapeN == 1 || shapeM == 1) { - params.assign(m0k0n0ForMatVec.begin(), m0k0n0ForMatVec.end()); - } else if (shapeN == 2 || shapeM == 2) { - params.assign(m0k0n0ForWhenRhsHas2Columns.begin(), - m0k0n0ForWhenRhsHas2Columns.end()); - } else { - return Mmt4DTileParams(m0k0n0, comment); - } - - if (shapeN == 1 || shapeN == 2) { - comment += ", matrix * narrow matrix, where the narrow matrix has " + - std::to_string(shapeN) + " column(s)"; - } else { - // The vector*matrix case is intentionally derived from the - // matrix*vector case by swapping M and N dims so that in kernel - // codegen we can reuse matrix*vector kernels by swapping LHS and RHS. - std::swap(params[0], params[2]); - comment += ", narrow matrix * matrix, where the narrow matrix has " + - std::to_string(shapeM) + " column(s)"; - } - return Mmt4DTileParams(params, comment); - }; - - const auto &tileParams = chooseMatMulOrMatVec({8, 1, 8}, {8, 1, 1}, {8, 1, 2}, - "f32*f32->f32, generic"); - - Value paddedLhs = pad(loc, rewriter, lhs, tileParams.lhs()); - Value paddedRhs = pad(loc, rewriter, rhs, tileParams.rhs()); - Value paddedAcc = pad(loc, rewriter, acc, tileParams.acc()); - - Value packed4DLhs = - pack(loc, rewriter, paddedLhs, {0, 1}, tileParams.lhs(), {}); - Value packed4DRhs = - pack(loc, rewriter, paddedRhs, {1, 0}, tileParams.rhsTranspose(), {1, 0}); - Value packed4DAcc = - pack(loc, rewriter, paddedAcc, {0, 1}, tileParams.acc(), {}); - - auto mmt4d = rewriter.create( - loc, packed4DAcc.getType(), ValueRange{packed4DLhs, packed4DRhs}, - ValueRange{packed4DAcc}); - mmt4d->setAttr(StringAttr::get(ctx, "comment"), - StringAttr::get(ctx, tileParams.getComment())); - - Value paddedResult = unpack(loc, rewriter, mmt4d.getResult(0), paddedAcc, - {0, 1}, tileParams.acc(), {}); - - Value result = extractSliceLike(loc, rewriter, paddedResult, acc); - rewriter.replaceOp(matmulOp, ArrayRef{result}); - - return success(); -} - -struct PackMatmulPass : public impl::PackMatmulPassBase { - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - } - - void runOnOperation() override { - func::FuncOp func = getOperation(); - - RewritePatternSet patterns(&getContext()); - patterns.add(packMatmul); - - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -createPackMatmulPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/remove_label.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/remove_label.cc deleted file mode 100644 index bc6e41a905645f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/remove_label.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_REMOVELABELPASS -#include "gml_st/transforms/passes.h.inc" - -struct RemoveLabelPass : public impl::RemoveLabelPassBase { - using Base::Base; - - void runOnOperation() override { - getOperation().walk( - [](Operation *op) { removeLabel(op, kTransformedLabel); }); - } -}; -} // namespace - -std::unique_ptr> -createRemoveLabelPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_dot_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_dot_for_cpu.cc deleted file mode 100644 index 316c9e5cc92ec2..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_dot_for_cpu.cc +++ /dev/null @@ -1,681 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/peeling/peeling.h" -#include "gml_st/transforms/tiling/tiling.h" -#include "gml_st/transforms/transforms.h" -#include "gml_st/utils/linalg_utils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/IR/Dominance.h" -#include "mlir/Pass/Pass.h" // IWYU pragma: keep -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMDOTFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -constexpr llvm::StringRef kFusionPlanningLabel = "__fusion_planning_label__"; - -struct MatmulSizes { - // [m, k] x [k, n] - int64_t m; - int64_t n; - int64_t k; -}; - -using MatmulTileSizeComputationFn = std::function; - -int64_t roundDownToPowerOfTwo(int64_t n) { - if ((n & (n - 1)) == 0) return n; - n |= n >> 1; - n |= n >> 2; - n |= n >> 4; - n |= n >> 8; - n |= n >> 16; - n |= n >> 32; - return (n + 1) >> 1; -} - -bool isPowerOfTwo(int64_t n) { return (n & (n - 1)) == 0; } - -// Tiling heuristic that was tuned for static power-of-two sized shapes on -// Skylake. -MatmulSizes skylakeTilingHeuristic(MatmulSizes sizes) { - if (sizes.m == 1) { - // Limit the maximum tiling to an arbitrary 32 to limit code growth. This - // needs re-tuning. - return {1, std::min(sizes.n, 32), 1}; - } - - if (sizes.n == 1) { - if (sizes.k <= 8) { - return {1, 1, 1}; - } - return {std::min(8, sizes.m), 1, 4}; - } - - MatmulSizes result; - result.k = sizes.k <= 8 ? 1 : 4; - result.n = std::min(8, sizes.n) << (sizes.m <= 16 ? 1 : 0); - result.m = std::min(32, sizes.m) << (sizes.n <= 4 ? 1 : 0); - return result; -} - -// Tiling heuristic that was tuned for static power-of-two sized shapes on Zen -// v2 ("Rome"). -MatmulSizes znver2TilingHeuristic(MatmulSizes sizes) { - MatmulSizes result; - result.k = sizes.n == 1 ? 8 : 1; - if (sizes.n == 1) { - result.m = sizes.k >= 32 ? 16 : 8; - } else { - result.m = sizes.n <= 8 ? 8 : 4; - } - if (sizes.m == 1) { - result.n = std::min(64, sizes.n) * (sizes.k <= 64 ? 1 : 2); - } else { - result.n = std::min(16, sizes.n); - } - return result; -} - -// Tiling heuristic that was tuned for static sized shapes on generic Haswell. -MatmulSizes haswellTilingHeuristic(MatmulSizes sizes) { - MatmulSizes result; - // Dot - if (sizes.m == 1 && sizes.n == 1) { - // At this point we only have small tensors, dots with bigger tensors are - // already turned into reduce(map). 
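The roundDownToPowerOfTwo helper above is a standard bit-smearing trick, reused by wrapHeuristic below to normalize shapes before the per-target heuristics run. A self-contained version with the same logic, for reference:

#include <cassert>
#include <cstdint>

// Round n (> 0) down to the nearest power of two. Smearing fills every bit
// below the leading one, giving 2^(k+1) - 1 where 2^k is the leading bit;
// (x + 1) >> 1 then recovers 2^k.
int64_t roundDownToPowerOfTwo(int64_t n) {
  if ((n & (n - 1)) == 0) return n;  // already a power of two
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
  n |= n >> 32;
  return (n + 1) >> 1;
}

int main() {
  assert(roundDownToPowerOfTwo(1) == 1);
  assert(roundDownToPowerOfTwo(6) == 4);
  assert(roundDownToPowerOfTwo(1000) == 512);
  return 0;
}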
- return {1, std::min(sizes.n, 32), 1}; - } - - // Vecmat - if (sizes.m == 1) { - result.m = 1; - constexpr int64_t kVecmatNThreshold = 64; - constexpr int64_t kVecmatSizeThreshold = 16 * kVecmatNThreshold; - int64_t numElements = sizes.k * sizes.n; - if (sizes.n < kVecmatNThreshold) { - result.n = sizes.n; - if (numElements < kVecmatSizeThreshold) { - result.k = sizes.k; - } else if (isPowerOfTwo(sizes.n)) { - result.k = 2; - } else { - result.k = std::min(result.k / 2, 64); - } - } else { - result.n = kVecmatNThreshold; - if (sizes.k < 16) { - result.k = sizes.k; - } else { - if (sizes.n >= 256) { - result.k = isPowerOfTwo(sizes.k) ? 1 : 8; - } else { - result.k = isPowerOfTwo(sizes.k) ? 8 : 16; - } - } - } - return result; - } - - result.k = sizes.n == 1 ? 8 : 1; - // Matvec - if (sizes.n == 1) { - if (sizes.k <= 8) { - return {1, 1, 1}; - } - return {std::min(8, sizes.m), 1, 4}; - } - // Matmul - result.k = sizes.k <= 8 ? 1 : 4; - result.n = std::min(8, sizes.n) << (sizes.m <= 16 ? 1 : 0); - result.m = std::min(32, sizes.m) << (sizes.n <= 4 ? 1 : 0); - return result; -} - -std::function wrapHeuristic( - const std::function &heuristic, - MatmulSizes dynamicDefault) { - return [=](MatmulSizes sizes) { - if (ShapedType::isDynamic(sizes.n) || ShapedType::isDynamic(sizes.m) || - ShapedType::isDynamic(sizes.k)) { - return dynamicDefault; - } - - sizes.m = roundDownToPowerOfTwo(sizes.m); - sizes.n = roundDownToPowerOfTwo(sizes.n); - sizes.k = roundDownToPowerOfTwo(sizes.k); - - return heuristic(sizes); - }; -} - -MatmulSizes getMatmulSizes(linalg::MatmulOp op) { - // [m, k] x [k, n] - auto lhsTy = op->getOperand(0).getType().cast(); - auto rhsTy = op->getOperand(1).getType().cast(); - MatmulSizes sizes; - sizes.m = lhsTy.getDimSize(0); - sizes.k = rhsTy.getDimSize(0); - sizes.n = rhsTy.getDimSize(1); - return sizes; -} - -MatmulSizes getMatmulSizes(linalg::VecmatOp op) { - // [1, k] x [k, n] - auto ty = op->getOperand(1).getType().cast(); - MatmulSizes sizes; - sizes.m = 1; - sizes.k = ty.getDimSize(0); - sizes.n = ty.getDimSize(1); - return sizes; -} - -MatmulSizes getMatmulSizes(linalg::MatvecOp op) { - // [m, k] x [k, 1] - auto ty = op->getOperand(0).getType().cast(); - MatmulSizes sizes; - sizes.m = ty.getDimSize(0); - sizes.k = ty.getDimSize(1); - sizes.n = 1; - return sizes; -} - -MatmulSizes getMatmulSizes(linalg::DotOp op) { - // [1, k] x [k, 1] - auto ty = op->getOperand(0).getType().cast(); - MatmulSizes sizes; - sizes.m = 1; - sizes.k = ty.getDimSize(0); - sizes.n = 1; - return sizes; -} - -SmallVector dropZeros(ArrayRef tileSizes) { - return to_vector(llvm::make_filter_range( - tileSizes, [](int64_t size) { return size != 0; })); -} - -struct DotAddPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - explicit DotAddPattern(MLIRContext *context, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit) {} - - LogicalResult matchAndRewrite(linalg::MapOp mapOp, - PatternRewriter &rewriter) const override { - auto ®ion = mapOp.getMapper(); - if (!region.hasOneBlock()) return failure(); - - auto &body = region.front(); - // The body region should only have one add operation and a linalg.yield. - if (body.getOperations().size() != 2) return failure(); - - auto &mapperOp = body.front(); - if (!isa(mapperOp)) return failure(); - - // Map of add should always be binary. 
- if (mapOp.getInputs().size() != 2) return failure(); - if (ValueRange{body.getArguments()} != ValueRange{mapperOp.getOperands()}) - return failure(); - - if (!llvm::any_of(mapOp.getInputs(), [](Value operand) { - auto linalgOp = operand.getDefiningOp(); - return linalg::isaContractionOpInterface(linalgOp); - })) - return failure(); - - auto foldAddIntoDotOperand = [&](unsigned opIdx) { - auto dotOp = mapOp.getInputs()[opIdx].getDefiningOp(); - auto otherOp = mapOp.getInputs()[1 - opIdx]; - if (!linalg::isaContractionOpInterface(dotOp)) return false; - if (!dotOp.getDpsInitOperand(0)->get().getDefiningOp()) - return false; - if (!dotOp->hasOneUse()) return false; - // TODO(vuson): handle the case where we need to move dotOp up or otherOp - // down. - mlir::DominanceInfo domInfo(mapOp->getParentOp()); - if (!domInfo.properlyDominates(otherOp, dotOp)) return false; - rewriter.updateRootInPlace( - dotOp, [&]() { dotOp.setDpsInitOperand(0, otherOp); }); - rewriter.replaceOp(mapOp, dotOp->getResults()); - return true; - }; - - return success(foldAddIntoDotOperand(0) || foldAddIntoDotOperand(1)); - } -}; - -LogicalResult tileAndPeelReductionDim(PatternRewriter &rewriter, - Operation *reduceOp, - ArrayRef reductionDimTileSizes) { - FailureOr reductionDimTilingResult = - tileUsingSCFForOpAndFuseGreedily( - rewriter, reduceOp, - getSCFTilingOptions(rewriter.getContext(), reductionDimTileSizes)); - if (failed(reductionDimTilingResult)) return failure(); - - SCFForPeelingResult reductionDimPeelingResult = peelSCFForOp( - rewriter, cast(reductionDimTilingResult->loops.front())); - if (reductionDimPeelingResult.mainLoop) { - setLabel(reductionDimPeelingResult.mainLoop, kPerfectlyTiledLoopLabel); - } - return success(); -} - -SmallVector getTileSizesForDimsOfType(Operation *iop, - ArrayRef tileSizes, - utils::IteratorType iterType) { - TilingInterface op = cast(iop); - SmallVector iteratorTypes = op.getLoopIteratorTypes(); - SmallVector tileSizesOfType(iteratorTypes.size(), 0); - assert(tileSizes.size() == iteratorTypes.size() && - "the number of provided tile sizes should match the iteration domain " - "of the op"); - SmallVector iteratorTypeDimsPositions; - findPositionsOfType(iteratorTypes, iterType, iteratorTypeDimsPositions); - for (unsigned pos : iteratorTypeDimsPositions) - tileSizesOfType[pos] = tileSizes[pos]; - return tileSizesOfType; -} - -/// Helper to tile dot operations (linalg.matvec, linalg.vecmat, linalg.dot) -/// and peel the generated loops. This can be extended to support any op that -/// implements TilingInterface. -template -LogicalResult tileAndPeelMatmulOp(PatternRewriter &rewriter, DotOpTy dotOp, - ArrayRef tileSizes) { - Operation *tilingRoot = dotOp; - if (auto fusionOp = dyn_cast(dotOp->getParentOp())) { - tilingRoot = fusionOp.getTerminator().getValues()[0].getDefiningOp(); - } - - // First level tiling: parallel dimension. 
- auto parallelDimsTileSizes = getTileSizesForDimsOfType( - dotOp.getOperation(), tileSizes, utils::IteratorType::parallel); - auto reductionDimsTileSizes = getTileSizesForDimsOfType( - dotOp.getOperation(), tileSizes, utils::IteratorType::reduction); - if (!isa(tilingRoot)) - parallelDimsTileSizes = dropZeros(parallelDimsTileSizes); - - auto tilingParallelDimsResult = tileUsingSCFForallOpAndFuseGreedily( - rewriter, tilingRoot, - getSCFTilingOptions(rewriter.getContext(), parallelDimsTileSizes)); - if (failed(tilingParallelDimsResult)) return failure(); - - if (!tilingParallelDimsResult->loop) { - return tileAndPeelReductionDim(rewriter, dotOp, reductionDimsTileSizes); - } - auto peeledParallelLoop = - peelAllLoops(tilingParallelDimsResult->loop, rewriter); - - // Process main parallel loop. - scf::ForallOp mainParallelLoop = peeledParallelLoop.mainLoop; - if (mainParallelLoop) { - auto tiledDotOp = *mainParallelLoop.getBody()->getOps().begin(); - if (failed(tileAndPeelReductionDim(rewriter, tiledDotOp, - reductionDimsTileSizes))) { - return failure(); - } - } - - // Process tail parallel loop. - for (scf::ForallOp tailParallelLoop : peeledParallelLoop.tailLoops) { - for (auto tiledDotOp : llvm::to_vector( - tailParallelLoop.getBody()->template getOps())) { - auto reductionDimTilingResult = tileUsingSCFForOpAndFuseGreedily( - rewriter, tiledDotOp, - getSCFTilingOptions(rewriter.getContext(), reductionDimsTileSizes)); - if (failed(reductionDimTilingResult)) return failure(); - } - } - return success(); -} - -// Tile linalg.conv_2d_nhwc_hwcf to convert it to linalg.matmul.. -struct Conv2DNhwcHwcfOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, - PatternRewriter &rewriter) const override { - if (!isTransformableIntoMatmul(convOp)) return failure(); - FailureOr tilingResult = scf::tileUsingSCFForOp( - rewriter, cast(convOp.getOperation()), - getSCFTilingOptions(rewriter.getContext(), {0, 0, 0, 0, 1, 0, 0})); - if (failed(tilingResult)) return failure(); - rewriter.replaceOp(convOp, tilingResult->replacements); - - auto tiledConv = - cast(tilingResult->tiledOps.front()); - return convertConvToMatmul(tiledConv, rewriter); - } -}; - -// Tile linalg.batch_matmul to 1 in the outermost dimension, then transform a -// unit linalg.batch_matmul into a matmul using reshape ops. -struct BatchMatmulOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp, - PatternRewriter &rewriter) const override { - // Tile and fuse fillOp into the loop nest. 
- auto tilingResult = tileUsingSCFForallOpAndFuseGreedily( - rewriter, batchMatmulOp.getOperation(), - getSCFTilingOptions(rewriter.getContext(), {1, 0, 0, 0})); - if (failed(tilingResult)) return failure(); - - auto tiledBatchMatmulOp = - cast(tilingResult->tiledOps.front()); - return convertBatchMatmulToMatmul(tiledBatchMatmulOp, rewriter); - } -}; - -struct MatmulPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - MatmulPattern(MLIRContext *context, MatmulTileSizeComputationFn tileSizeFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - tileSizeFn(std::move(tileSizeFn)) {} - - LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, - PatternRewriter &rewriter) const override { - if (hasLabel(matmulOp, kTransformedLabel)) - return rewriter.notifyMatchFailure(matmulOp, "already transformed"); - - MatmulSizes tileSizes = tileSizeFn(getMatmulSizes(matmulOp)); - return tileAndPeelMatmulOp(rewriter, matmulOp, - {tileSizes.m, tileSizes.n, tileSizes.k}); - } - - private: - MatmulTileSizeComputationFn tileSizeFn; -}; - -struct MatvecPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - MatvecPattern(MLIRContext *context, MatmulTileSizeComputationFn tileSizeFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - tileSizeFn(std::move(tileSizeFn)) {} - - LogicalResult matchAndRewrite(linalg::MatvecOp matvecOp, - PatternRewriter &rewriter) const override { - if (hasLabel(matvecOp, kTransformedLabel)) - return rewriter.notifyMatchFailure(matvecOp, "already transformed"); - - MatmulSizes matmulSizes = getMatmulSizes(matvecOp); - // For large K it is beneficial to perform reduction in two steps, i.e. - // reduce tensor to tensor and then perform a horizontal - // add to reduce tensoSr to a single element. 
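The two-step strategy described here (accumulate into a small vector of partial sums, then one horizontal add) looks like the following in scalar C++; the tail handling that the real transformation also generates is omitted, and the names are illustrative:

#include <cstddef>
#include <iostream>
#include <vector>

// Dot product in two steps: keep V running partial sums (the shape a
// vectorized inner loop would hold in registers), then reduce them
// horizontally. Assumes x.size() == y.size() and that the size is a
// multiple of V.
float twoStepDot(const std::vector<float>& x, const std::vector<float>& y) {
  constexpr std::size_t V = 8;
  float partial[V] = {};
  for (std::size_t i = 0; i < x.size(); i += V)
    for (std::size_t lane = 0; lane < V; ++lane)
      partial[lane] += x[i + lane] * y[i + lane];
  float result = 0.0f;  // horizontal add
  for (float p : partial) result += p;
  return result;
}

int main() {
  std::vector<float> x(96, 1.0f), y(96, 2.0f);
  std::cout << twoStepDot(x, y) << '\n';  // prints: 192
}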
- constexpr int64_t kReductionDimSizeThreshold = 96; - if (!ShapedType::isDynamic(matmulSizes.k) && - matmulSizes.k > kReductionDimSizeThreshold) { - auto tilingParallelDim = tileUsingSCFForallOpAndFuseGreedily( - rewriter, matvecOp, - getSCFTilingOptions(rewriter.getContext(), {1, 0}), nullptr); - if (failed(tilingParallelDim)) return failure(); - - auto tiledMatvecOp = - cast(tilingParallelDim->tiledOps.front()); - return convertMatvecToDotOp(rewriter, tiledMatvecOp); - } - - MatmulSizes tileSizes = tileSizeFn(matmulSizes); - return tileAndPeelMatmulOp(rewriter, matvecOp, {tileSizes.m, tileSizes.k}); - } - - private: - MatmulTileSizeComputationFn tileSizeFn; -}; - -struct VecmatPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - VecmatPattern(MLIRContext *context, MatmulTileSizeComputationFn tileSizeFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - tileSizeFn(std::move(tileSizeFn)) {} - - LogicalResult matchAndRewrite(linalg::VecmatOp dotOp, - PatternRewriter &rewriter) const override { - if (hasLabel(dotOp, kTransformedLabel)) - return rewriter.notifyMatchFailure(dotOp, "already transformed"); - - MatmulSizes tileSizes = tileSizeFn(getMatmulSizes(dotOp)); - return tileAndPeelMatmulOp(rewriter, dotOp, {tileSizes.n, tileSizes.k}); - } - - private: - MatmulTileSizeComputationFn tileSizeFn; -}; - -struct DotPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - DotPattern(MLIRContext *context, MatmulTileSizeComputationFn tileSizeFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - tileSizeFn(std::move(tileSizeFn)) {} - - LogicalResult matchAndRewrite(linalg::DotOp dotOp, - PatternRewriter &rewriter) const override { - if (hasLabel(dotOp, kTransformedLabel)) - return rewriter.notifyMatchFailure(dotOp, "already transformed"); - - MatmulSizes matmulSizes = getMatmulSizes(dotOp); - constexpr int64_t kReductionDimSizeThreshold = 32; - if (!ShapedType::isDynamic(matmulSizes.k) && - matmulSizes.k > kReductionDimSizeThreshold) { - return convertDotOpToReduce(dotOp, rewriter); - } - MatmulSizes tileSizes = tileSizeFn(matmulSizes); - return tileAndPeelMatmulOp(rewriter, dotOp, {tileSizes.k}); - } - - private: - MatmulTileSizeComputationFn tileSizeFn; -}; - -Value transposeMatrixConstant(ImplicitLocOpBuilder &builder, Value input) { - ElementsAttr inputValues; - matchPattern(input, m_Constant(&inputValues)); - - auto inputType = input.getType().cast(); - ArrayRef inputShape = inputType.getShape(); - assert(inputShape.size() == 2); - - auto outputType = RankedTensorType::get({inputShape[1], inputShape[0]}, - inputType.getElementType()); - - SmallVector outputValues(inputType.getNumElements()); - for (const auto &it : llvm::enumerate(inputValues.getValues())) { - auto row = it.index() / inputShape[1]; - auto col = it.index() % inputShape[1]; - outputValues[col * inputShape[0] + row] = it.value(); - } - return builder.create( - outputType, DenseElementsAttr::get(outputType, outputValues)); -} - -// If we have a matvec with a constant matrix it's profitable to transpose the -// matrix at compile time and use vecmat instead. This has a friendlier memory -// access pattern. 
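The index arithmetic performed by transposeMatrixConstant above, written as a standalone row-major transpose (illustrative names):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Transpose a row-major rows x cols matrix: the element at linear index i
// lives at (i / cols, i % cols) and moves to linear index col * rows + row.
std::vector<float> transposeRowMajor(const std::vector<float>& in,
                                     int64_t rows, int64_t cols) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    int64_t row = static_cast<int64_t>(i) / cols;
    int64_t col = static_cast<int64_t>(i) % cols;
    out[col * rows + row] = in[i];
  }
  return out;
}

int main() {
  std::vector<float> m = {1, 2, 3, 4, 5, 6};  // 2x3 matrix
  for (float v : transposeRowMajor(m, 2, 3)) std::cout << v << ' ';  // 1 4 2 5 3 6
  std::cout << '\n';
}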
-struct MatVecToVecMatPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::MatvecOp matvecOp, - PatternRewriter &rewriter) const override { - auto constantMatrix = - matvecOp.getOperand(0).getDefiningOp(); - if (!constantMatrix) return failure(); - - ImplicitLocOpBuilder builder(constantMatrix.getLoc(), rewriter); - Value transposed = transposeMatrixConstant(builder, constantMatrix); - rewriter.replaceOpWithNewOp( - matvecOp, ValueRange{matvecOp.getOperand(1), transposed}, - matvecOp.getOutputs()); - return success(); - } -}; - -template -LogicalResult fusionClusterPattern(OpTy dotOp, PatternRewriter &rewriter) { - // The op was already processed. - if (dotOp->template getParentOfType()) return failure(); - if (hasLabel(dotOp, kFusionPlanningLabel)) return failure(); - - auto producerFilterFn = [](Operation *op) { - return isa(op); - }; - auto consumerFilterFn = [](Operation *op) { - if (auto mapOp = dyn_cast(op)) - return mapOp.getNumDpsInputs() == 1; - return isa(op); - }; - - auto fusionCluster = - getFusionCluster(dotOp, producerFilterFn, consumerFilterFn); - - for (auto *op : fusionCluster.operations) setLabel(op, kFusionPlanningLabel); - - if (failed(wrapFusionCluster(rewriter, fusionCluster))) return failure(); - - return success(); -} - -struct TransformDotForCpuPass - : public impl::TransformDotForCpuPassBase { - TransformDotForCpuPass() = default; - - explicit TransformDotForCpuPass(MatmulTileSizeComputationFn tileSizeFn) - : tileSizeFn(std::move(tileSizeFn)) {} - - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - tensor::registerTilingInterfaceExternalModels(registry); - tensor::registerInferTypeOpInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - // Peephole optimization of dot followed by add. - { - RewritePatternSet patterns(ctx); - patterns.add(ctx); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns(ctx); - patterns.add(ctx); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - // Cleanup passes to prepare ops for better clustering. 
- { - RewritePatternSet patterns(ctx); - populateDuplicateInitOpsPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns(ctx); - patterns.add(fusionClusterPattern); - patterns.add(fusionClusterPattern); - patterns.add(fusionClusterPattern); - patterns.add(fusionClusterPattern); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - - f.walk([](Operation *op) { removeLabel(op, kFusionPlanningLabel); }); - } - - { - RewritePatternSet patterns(ctx); - patterns.add( - ctx, tileSizeFn); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - } - - MatmulTileSizeComputationFn tileSizeFn; -}; - -} // namespace - -std::unique_ptr> -createTransformDotForCpuPass(ArrayRef tileSizes, StringRef cpuName) { - std::function tilingHeuristic; - if (!tileSizes.empty()) { - assert(tileSizes.size() == 3 && "Expected exactly 3 tile sizes for matmul"); - MatmulSizes fixedSizes{tileSizes[0], tileSizes[1], tileSizes[2]}; - tilingHeuristic = [=](MatmulSizes) { return fixedSizes; }; - } else { - if (cpuName.starts_with("znver")) - tilingHeuristic = wrapHeuristic(znver2TilingHeuristic, {16, 8, 8}); - else if (cpuName.contains("skylake")) - tilingHeuristic = wrapHeuristic(skylakeTilingHeuristic, {16, 16, 4}); - else - // Default to generic Haswell target. - tilingHeuristic = wrapHeuristic(haswellTilingHeuristic, {8, 8, 8}); - } - return std::make_unique( - std::move(tilingHeuristic)); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_elementwise_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_elementwise_for_cpu.cc deleted file mode 100644 index b9a058d25904db..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_elementwise_for_cpu.cc +++ /dev/null @@ -1,401 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/peeling/peeling.h" -#include "gml_st/transforms/transforms.h" -#include "gml_st/utils/tensor_utils.h" -#include "llvm/ADT/TypeSwitch.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMELEMENTWISEFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -constexpr llvm::StringRef kFusionPlanningLabel = "__fusion_planning_label__"; -constexpr llvm::StringRef kElementwiseLabel = "__elementwise_label__"; - -// Indicates the the dimension is not mapped to dimensions of the root op. -constexpr int64_t kNotMappedToRootDims = -1; - -using FusionFilterFn = llvm::function_ref; -using CandidatesMap = llvm::SmallMapVector, 4>; - -// Find the root of the fusion cluster. -Operation *findRootElementwiseOp(Operation *op, FusionFilterFn fusionFilterFn) { - Operation *rootOp = op; - Operation *curOp = nullptr; - do { - curOp = nullptr; - for (OpOperand &use : rootOp->getUses()) { - Operation *owner = use.getOwner(); - if (!fusionFilterFn(owner)) continue; - if (hasLabel(owner, kTransformedLabel)) continue; - if (hasLabel(owner, kFusionPlanningLabel)) continue; - if (auto dpsOp = dyn_cast(owner)) { - SmallVector opOperands = llvm::to_vector(llvm::map_range( - dpsOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); - if (llvm::is_contained(opOperands, &use)) continue; - } - curOp = owner; - rootOp = curOp; - break; - } - } while (curOp != nullptr); - // If the root is a reshape, don't use it, use the defining op for the - // argument instead. - if (isa(rootOp)) - return rootOp->getOperand(0).getDefiningOp(); - return rootOp; -} - -// Depending on the type of the defining op for the `result`, adds its arguments -// with the maps to the root result dimensions. 
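As a concrete illustration of the per-argument map that addMappedTensorArgs (defined next) maintains: every candidate operand carries a vector telling, for each of its dimensions, which root-result dimension it corresponds to (kNotMappedToRootDims if none). The transpose case composes the current map with the permutation; a plain-C++ sketch that mirrors that composition, with hypothetical names:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// composed[i] = resultMap[perm[i]], exactly as in the transpose branch of the
// deleted addMappedTensorArgs.
std::vector<int64_t> composeThroughTranspose(const std::vector<int64_t>& resultMap,
                                             const std::vector<int64_t>& perm) {
  std::vector<int64_t> composed(resultMap.size(), 0);
  for (std::size_t i = 0; i < perm.size(); ++i) composed[i] = resultMap[perm[i]];
  return composed;
}

int main() {
  // Transpose result maps identically onto a rank-2 root; permutation {1, 0}
  // means the transpose input's dims map to root dims (d1, d0).
  for (int64_t d : composeThroughTranspose({0, 1}, {1, 0})) std::cout << d << ' ';  // 1 0
  std::cout << '\n';
}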
-void addMappedTensorArgs(Value result, const SmallVector &map, - CandidatesMap &args) { - Operation *defOp = result.getDefiningOp(); - if (!defOp) return; - - mlir::TypeSwitch(defOp) - .Case([&](auto op) { - for (OpOperand *operand : - cast(op.getOperation()) - .getDpsInputOperands()) { - Value val = operand->get(); - if (!isa(val.getType())) continue; - args[val] = map; - } - }) - .Case([&](auto op) { - auto transposeOp = cast(op); - SmallVector composed(map.size(), 0); - for (auto [index, id] : llvm::enumerate(transposeOp.getPermutation())) { - composed[index] = map[id]; - } - args[transposeOp.getInput()] = composed; - }) - .Case([&](auto op) { - auto broadcastOp = cast(op); - SmallVector composed; - SmallVector bcastDims = to_vector(broadcastOp.getDimensions()); - - for (auto [index, id] : llvm::enumerate(map)) { - if (llvm::is_contained(bcastDims, index)) continue; - composed.push_back(id); - } - args[broadcastOp.getInput()] = composed; - }) - .Case([&](auto op) { - auto collapseShapeOp = cast(op); - auto srcType = collapseShapeOp.getSrcType(); - - SmallVector preservedDims = getPreservedDimensions( - srcType.getShape(), collapseShapeOp.getReassociationIndices()); - - SmallVector composed(srcType.getRank(), kNotMappedToRootDims); - for (auto [index, mapDim] : llvm::enumerate(map)) - composed[preservedDims[index]] = mapDim; - args[collapseShapeOp.getSrc()] = composed; - }) - .Case([&](auto op) { - auto expandShapeOp = cast(op); - auto dstType = expandShapeOp.getResultType(); - - SmallVector preservedDims = getPreservedDimensions( - dstType.getShape(), expandShapeOp.getReassociationIndices()); - - SmallVector composed(expandShapeOp.getSrcType().getRank()); - for (auto [index, preservedDim] : llvm::enumerate(preservedDims)) - composed[index] = map[preservedDim]; - args[expandShapeOp.getSrc()] = composed; - }) - .Default( - [](Operation *) { llvm_unreachable("The op is not supported"); }); -} - -// Starts a graph traversal from the root trying to fuse all ops that satisfy -// `fusionFilterFn` and also have no users outside of this fusion cluster. -FusionCluster findElementwiseCluster(Operation *rootOp, - FusionFilterFn fusionFilterFn) { - Value rootResult = rootOp->getResult(0); - - SetVector resultOps; - resultOps.insert(rootOp); - CandidatesMap mappedArgs, candidates; - - // Add operands of root. - int64_t rootRank = rootResult.getType().cast().getRank(); - auto identityMap = llvm::to_vector(llvm::seq(0, rootRank)); - addMappedTensorArgs(rootResult, identityMap, candidates); - - while (!candidates.empty()) { - bool fusionHappened = false; - SmallVector argsToErase; - for (auto [arg, map] : llvm::reverse(candidates)) { - // If the arg is already coming outside of the cluster, i.e. it is a - // function argument or a result of some op that is not included by the - // fusionFilterFn, then we remove such arg. - Operation *defOp = arg.getDefiningOp(); - if (mappedArgs.contains(arg) || !defOp || resultOps.contains(defOp) || - !fusionFilterFn(defOp)) { - mappedArgs[arg] = map; - argsToErase.push_back(arg); - continue; - } - - // If there are any users of this op outside of fusion cluster, then skip. - if (llvm::any_of(arg.getUsers(), [&](Operation *user) { - return !resultOps.contains(user); - })) { - continue; - } - - resultOps.insert(defOp); - addMappedTensorArgs(arg, map, candidates); - fusionHappened = true; - break; - } - for (Value argToErase : argsToErase) { - candidates.erase(argToErase); - } - - // If an op to fuse was not found, we add all current candidates to the - // result. 
- if (!fusionHappened) { - for (auto &candidate : candidates) { - mappedArgs.insert(candidate); - } - break; - } - } - FusionCluster fusionCluster; - fusionCluster.root = rootOp; - fusionCluster.operations = resultOps; - - // Add tensor.empty ops to the cluster. - for (auto *op : resultOps) { - if (auto dpsOp = dyn_cast(op)) { - for (auto operand : dpsOp.getDpsInits()) { - if (auto emptyOp = - dyn_cast_or_null(operand.getDefiningOp())) - fusionCluster.operations.insert(emptyOp); - } - } - } - - llvm::append_range(fusionCluster.argDimsMapping, mappedArgs); - return fusionCluster; -} - -// Searches through the inner-most dimensions of the arguments of the fusion -// cluster to find the most beneficial dimension to tile. Default tile size is 1 -// x ... x 1 x vector_size, which leads to vector.transfer_write to the init -// tensor. -// In case of broadcast, transpose and other maps with the non-identity mapping -// between op input and op result the innermost dimension of the input can be -// different from the one of result. -SmallVector optimizeTileSizes(const FusionCluster &fusionCluster, - int64_t vectorSize) { - auto rootTy = - cast(fusionCluster.root->getResultTypes().front()); - - if (rootTy.getRank() == 0) return {}; - SmallVector tileSizes(rootTy.getRank(), 1); - tileSizes.back() = vectorSize; - - int64_t rootInnermostDim = rootTy.getRank() - 1; - int64_t innermostDimWithMostElements = rootInnermostDim; - int64_t innermostDimMaxElements = std::numeric_limits::min(); - for (auto &[arg, map] : fusionCluster.argDimsMapping) { - auto argInnermostDimIt = llvm::find_if( - llvm::reverse(map), - [](int64_t item) { return item != kNotMappedToRootDims; }); - if (argInnermostDimIt == map.rend()) continue; - int64_t argInnermostDim = *argInnermostDimIt; - if (argInnermostDim == rootInnermostDim) continue; - - int64_t numElements = rootTy.getDimSize(argInnermostDim); - if (innermostDimMaxElements >= numElements && - !ShapedType::isDynamic(numElements)) - continue; - innermostDimMaxElements = numElements; - innermostDimWithMostElements = argInnermostDim; - } - tileSizes[innermostDimWithMostElements] = vectorSize; - return tileSizes; -} - -template -struct FusionClusterPattern : public OpRewritePattern { - FusionClusterPattern(MLIRContext *context, int64_t vectorSize, - bool fuseDegenerateReshapes, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - vectorSize(vectorSize), - fuseDegenerateReshapes(fuseDegenerateReshapes) {} - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - if (hasSingleElementOperandsAndResults(op)) return failure(); - if (hasLabel(op, kFusionPlanningLabel)) return failure(); - if (hasLabel(op, kTransformedLabel)) return failure(); - if (op->template getParentOfType()) return failure(); - - // Find the root from which to start tiling and fusion. - auto fusionFilterFn = [&](Operation *op) { - if (fuseDegenerateReshapes) { - if (auto reshapeOp = dyn_cast(op)) - return isDegenerateReshapeOp(reshapeOp); - if (auto reshapeOp = dyn_cast(op)) - return isDegenerateReshapeOp(reshapeOp); - } - // Add thlo.concatenate here. - return isa(op); - }; - Operation *fusionRoot = findRootElementwiseOp(op, fusionFilterFn); - - // Find the fusion cluster and its arguments. - FusionCluster fusionCluster = - findElementwiseCluster(fusionRoot, fusionFilterFn); - - // Find what dimensions to tile. 
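The call that follows implements the heuristic described in the optimizeTileSizes comment above: start from [1, ..., 1, vectorSize] and, if some cluster argument's innermost dimension maps to a different root dimension (for example through a transpose), vectorize that dimension as well, preferring the one with the most elements. A simplified plain-C++ sketch that ignores dynamic shapes and unmapped dimensions (hypothetical names):

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> chooseTileSizes(const std::vector<int64_t>& rootShape,
                                     const std::vector<int64_t>& argInnermostRootDims,
                                     int64_t vectorSize) {
  if (rootShape.empty()) return {};
  std::vector<int64_t> tiles(rootShape.size(), 1);
  tiles.back() = vectorSize;  // default: vectorize the root's innermost dim

  int64_t rootInnermost = static_cast<int64_t>(rootShape.size()) - 1;
  int64_t bestDim = rootInnermost;
  int64_t bestElements = -1;
  for (int64_t dim : argInnermostRootDims) {
    if (dim == rootInnermost) continue;  // argument already reads contiguously
    if (rootShape[dim] > bestElements) {
      bestElements = rootShape[dim];
      bestDim = dim;
    }
  }
  tiles[bestDim] = vectorSize;
  return tiles;
}

int main() {
  // Root is 16x128; a transposed argument's innermost dim maps to root dim 0.
  for (int64_t t : chooseTileSizes({16, 128}, {0}, 8)) std::cout << t << ' ';  // 8 8
  std::cout << '\n';
}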
- SmallVector tileSizes = - optimizeTileSizes(fusionCluster, vectorSize); - - for (auto *clusterOp : fusionCluster.operations) - setLabel(clusterOp, kFusionPlanningLabel); - - auto fusionOp = wrapFusionCluster(rewriter, fusionCluster); - if (failed(fusionOp)) return failure(); - - fusionOp->setParallelTileSizes(tileSizes); - setLabel(*fusionOp, kElementwiseLabel); - - return success(); - } - - private: - int64_t vectorSize; - bool fuseDegenerateReshapes; -}; - -LogicalResult tileAndFuse(FusionOp fusionOp, PatternRewriter &rewriter) { - if (hasLabel(fusionOp, kTransformedLabel)) return failure(); - if (!hasLabel(fusionOp, kElementwiseLabel)) return failure(); - - auto *tilingRootOp = fusionOp.getTerminator().getValues()[0].getDefiningOp(); - auto tileSizes = *fusionOp.getParallelTileSizes(); - - // Tile and fuse. - auto tiledLoop = tileUsingSCFForallOpAndFuseGreedily( - rewriter, tilingRootOp, - getSCFTilingOptions(rewriter.getContext(), tileSizes)); - if (failed(tiledLoop)) return failure(); - - // Peel. - auto peelingResult = peelAllLoops(tiledLoop->loop, rewriter); - setLabel(tiledLoop->loop, kPerfectlyTiledLoopLabel); - - // Tile ops in the peeled loop again, to size 1, so they can be - // scalarized. - if (failed(tilePeeledOpsToScalars(rewriter, peelingResult))) return failure(); - - setLabel(fusionOp, kTransformedLabel); - return success(); -} - -struct TransformElementwiseForCpuPass - : public impl::TransformElementwiseForCpuPassBase< - TransformElementwiseForCpuPass> { - using Base::Base; - - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - // Cleanup passes to prepare ops for better clustering. - { - RewritePatternSet patterns(ctx); - populateDuplicateInitOpsPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns(ctx); - // clang-format off - patterns.add< - FusionClusterPattern, - FusionClusterPattern, - FusionClusterPattern, - FusionClusterPattern, - FusionClusterPattern - >(ctx, vectorSize, fuseDegenerateReshapes); - // clang-format on - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns(ctx); - patterns.add(tileAndFuse); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -createTransformElementwiseForCpuPass(int64_t vectorSize, - bool fuseDegenerateReshapes) { - TransformElementwiseForCpuPassOptions opts; - opts.vectorSize = vectorSize; - opts.fuseDegenerateReshapes = fuseDegenerateReshapes; - return std::make_unique(opts); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_mmt4d_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_mmt4d_for_cpu.cc deleted file mode 100644 index 8616c3fbee1a68..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_mmt4d_for_cpu.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMMMT4DFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -FailureOr tileUsingSCFForAndReplace( - PatternRewriter &rewriter, Operation *op, ArrayRef tilingSizes) { - scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes( - getAsIndexOpFoldResult(rewriter.getContext(), tilingSizes)); - auto tilingResult = scf::tileUsingSCFForOp( - rewriter, cast(op), tilingOptions); - if (failed(tilingResult) || tilingResult->loops.empty()) return failure(); - rewriter.replaceOp(op, tilingResult->replacements); - return tilingResult->tiledOps.front(); -} - -/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the -/// reduction loops. -void splitParallelAndReductionTiles(linalg::LinalgOp op, - SmallVectorImpl ¶llelSizes, - SmallVectorImpl &reductionSizes) { - reductionSizes.assign(parallelSizes.begin(), parallelSizes.end()); - for (auto [index, iteratorType] : - llvm::enumerate(op.getIteratorTypesArray())) { - if (iteratorType == utils::IteratorType::parallel) { - reductionSizes[index] = 0; - } else { - parallelSizes[index] = 0; - } - } -} - -// We tile towards SIMD codegen, so the tile sizes depend on the target -// architecture (vector instruction sizes, etc.). Luckily, this information is -// already captured in linalg.mmt4d during linalg.matmul -> linalg.mmt4d -// lowering phase. It is hardcoded for AVX on x86 for now. -LogicalResult tileMmt4DOp(linalg::Mmt4DOp mmt4dOp, PatternRewriter &rewriter) { - if (hasLabel(mmt4dOp, kTransformedLabel)) { - return rewriter.notifyMatchFailure(mmt4dOp, - "has already been transformed."); - } - - // Compute the tile sizes. Note that at this stage we only do layout tiling. - // Later we might also want to do traversal tiling (only on M and N dims). - auto getL1TileSizes = [&]() -> SmallVector { - auto lhsShape = - mmt4dOp.getInputs()[0].getType().cast().getShape(); - auto rhsShape = - mmt4dOp.getInputs()[1].getType().cast().getShape(); - int64_t m0 = lhsShape[2]; - int64_t n0 = rhsShape[2]; - int64_t k0 = lhsShape[3]; - return {1, 1, 1, m0, n0, k0}; - }; - - SmallVector parallelTileSizes = getL1TileSizes(); - SmallVector reductionTileSizes; - - // Search the number of outer parallel loops to separate them from possible - // inner reduction dimensions. 
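For reference, splitParallelAndReductionTiles above just splits one tile-size vector in two by iterator type. A standalone sketch, assuming the usual (M, N, K, m0, n0, k0) loop order of linalg.mmt4d with K and k0 as reductions:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class IterType { kParallel, kReduction };

// Copy the sizes, then zero reduction positions in the parallel vector and
// parallel positions in the reduction vector, so each tiling step only
// touches its own loops.
void splitTiles(const std::vector<IterType>& iterTypes,
                std::vector<int64_t>& parallelSizes,
                std::vector<int64_t>& reductionSizes) {
  reductionSizes = parallelSizes;
  for (std::size_t i = 0; i < iterTypes.size(); ++i) {
    if (iterTypes[i] == IterType::kParallel)
      reductionSizes[i] = 0;
    else
      parallelSizes[i] = 0;
  }
}

int main() {
  std::vector<IterType> iters = {IterType::kParallel, IterType::kParallel,
                                 IterType::kReduction, IterType::kParallel,
                                 IterType::kParallel, IterType::kReduction};
  std::vector<int64_t> parallel = {1, 1, 1, 8, 8, 1}, reduction;
  splitTiles(iters, parallel, reduction);
  for (int64_t s : parallel) std::cout << s << ' ';   // 1 1 0 8 8 0
  std::cout << '\n';
  for (int64_t s : reduction) std::cout << s << ' ';  // 0 0 1 0 0 1
  std::cout << '\n';
}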
- auto iterTypes = mmt4dOp.getIteratorTypesArray(); - // Make sure to only look at the leading loops for tiling---we will scan - // this array to find the first non-parallel loop later and use that for - // indexing into the tile sizes. - if (iterTypes.size() > parallelTileSizes.size()) { - iterTypes.resize(parallelTileSizes.size()); - } - - splitParallelAndReductionTiles(cast(mmt4dOp.getOperation()), - parallelTileSizes, reductionTileSizes); - - // Tile the parallel loops. - auto tiledOp = tileUsingSCFForAndReplace(rewriter, mmt4dOp.getOperation(), - parallelTileSizes); - if (failed(tiledOp)) return failure(); - mmt4dOp = cast(*tiledOp); - - // Tile the reduction loops. - tiledOp = tileUsingSCFForAndReplace(rewriter, mmt4dOp.getOperation(), - reductionTileSizes); - if (failed(tiledOp)) return failure(); - mmt4dOp = cast(*tiledOp); - - setLabel(mmt4dOp, kTransformedLabel); - return success(); -} - -struct TransformMmt4DForCpuPass - : public impl::TransformMmt4DForCpuPassBase { - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp func = getOperation(); - MLIRContext *ctx = &getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(tileMmt4DOp); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -createTransformMmt4DForCpuPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_pack_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_pack_for_cpu.cc deleted file mode 100644 index 117003b1d9952f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_pack_for_cpu.cc +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMPACKFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -FailureOr tileUsingSCFForAndReplace( - PatternRewriter &rewriter, Operation *op, - const scf::SCFTilingOptions &tilingOptions) { - if (hasLabel(op, kTransformedLabel)) return failure(); - - auto tilingResult = scf::tileUsingSCFForOp( - rewriter, cast(op), tilingOptions); - if (failed(tilingResult) || tilingResult->loops.empty()) return failure(); - - for (Operation *tiledOp : tilingResult->tiledOps) - setLabel(tiledOp, kTransformedLabel); - rewriter.replaceOp(op, tilingResult->replacements); - return tilingResult->tiledOps.front(); -} - -LogicalResult tilePackOp(tensor::PackOp packOp, PatternRewriter &rewriter) { - // Tile tensor.pack ops. - auto packTilingOptions = - scf::SCFTilingOptions().setTileSizeComputationFunction( - [&](OpBuilder b, Operation *op) { - auto numLoops = - cast(op).getLoopIteratorTypes().size(); - SmallVector tiles( - numLoops, getAsIndexOpFoldResult(b.getContext(), 1)); - return tiles; - }); - - return tileUsingSCFForAndReplace(rewriter, packOp, packTilingOptions); -} - -LogicalResult tileUnpackOp(tensor::UnPackOp unpackOp, - PatternRewriter &rewriter) { - // Tile tensor.unpack op. 
- auto unpackTilingOptions = - scf::SCFTilingOptions().setTileSizeComputationFunction( - [](OpBuilder &builder, Operation *op) { - Location loc = op->getLoc(); - auto unpackOp = cast(op); - auto numLoops = unpackOp.getDestRank(); - auto dimAndTileMapping = unpackOp.getDimAndTileMapping(); - SmallVector tileSizes; - for (size_t i = 0; i < numLoops; ++i) { - if (dimAndTileMapping.count(i)) { - tileSizes.push_back(dimAndTileMapping[i]); - } else { - tileSizes.push_back( - builder.create(loc, unpackOp.getDest(), i) - .getResult()); - } - } - return tileSizes; - }); - - return tileUsingSCFForAndReplace(rewriter, unpackOp, unpackTilingOptions); -} - -struct TransformPackForCpuPass - : public impl::TransformPackForCpuPassBase { - void getDependentDialects(DialectRegistry ®istry) const final { - registry - .insert(); - tensor::registerTilingInterfaceExternalModels(registry); - tensor::registerInferTypeOpInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp func = getOperation(); - MLIRContext *ctx = &getContext(); - - { - RewritePatternSet patterns(ctx); - patterns.add(tilePackOp); - patterns.add(tileUnpackOp); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } - - // Expanding pack and unpack ops to other primitive tensor/linalg ops and - // canonicalize tiled ops. - { - RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -createTransformPackForCpuPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc deleted file mode 100644 index 0c1c878a5f0847..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc +++ /dev/null @@ -1,605 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/peeling/peeling.h" -#include "gml_st/transforms/tiling/tiling.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMREDUCEFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -constexpr llvm::StringRef kReduceCluster = "__reduce_cluster__"; - -struct Reduce1DTileSizes { - int64_t tileSize; - int64_t splitRatio; -}; -using Reduce1DTileSizeComputationFn = std::function; - -SmallVector getParallelDimTileSizes(int64_t reductionDim, - int64_t parallelDimTileSize) { - return reductionDim ? SmallVector{parallelDimTileSize, 0} - : SmallVector{0, parallelDimTileSize}; -} - -SmallVector getReductionDimTileSizes(int64_t reductionDim, - int64_t reductionDimTileSize) { - return reductionDim ? SmallVector{0, reductionDimTileSize} - : SmallVector{reductionDimTileSize, 0}; -} - -LogicalResult validateOp(linalg::ReduceOp reduceOp, PatternRewriter &rewriter, - int64_t expectedRank) { - ArrayRef reduceDimensions = reduceOp.getDimensions(); - if (reduceDimensions.size() != 1) { - return rewriter.notifyMatchFailure( - reduceOp, "expects 1 reduction dimension element. 0 or > 1 received."); - } - SmallVector operands = reduceOp.getDpsInputOperands(); - if (operands.size() != 1) { - return rewriter.notifyMatchFailure(reduceOp, - "expects 1 operand. 0 or > 1 received."); - } - const int64_t operandRank = reduceOp.getRank(operands[0]); - if (operandRank != expectedRank) { - return rewriter.notifyMatchFailure(reduceOp, [&](Diagnostic &diag) { - diag << "expects rank " << expectedRank << ". 
" << operandRank - << "received."; - }); - } - return success(); -} - -bool reduce1DFusionFilter(Operation *op) { - return isa(op); -} - -bool reduce2DProducerFusionFilter(Operation *op) { - return isa(op); -} - -bool reduce2DConsumerFusionFilter(Operation *op) { - return isa(op); -} - -LogicalResult wrapReduceFusionCluster( - PatternRewriter &rewriter, linalg::ReduceOp reduceOp, - llvm::function_ref producerFilterFn, - llvm::function_ref consumerFilterFn) { - auto fusionCluster = - getFusionCluster(reduceOp, producerFilterFn, consumerFilterFn); - - auto fusionOp = wrapFusionCluster(rewriter, fusionCluster); - if (failed(fusionOp)) return failure(); - - setLabel(reduceOp, kTransformedLabel); - setLabel(*fusionOp, kReduceCluster); - return success(); -} - -LogicalResult fusionClusterPattern(linalg::ReduceOp reduceOp, - PatternRewriter &rewriter) { - if (hasLabel(reduceOp, kTransformedLabel)) return failure(); - - auto fusionOp = reduceOp->getParentOfType(); - if (fusionOp && hasLabel(fusionOp, kReduceCluster)) return failure(); - - const int64_t rank = reduceOp.getRank(reduceOp.getDpsInputOperand(0)); - - if (rank == 1) { - return wrapReduceFusionCluster(rewriter, reduceOp, reduce1DFusionFilter, - [](Operation *) { return false; }); - } - - if (rank == 2) { - return wrapReduceFusionCluster(rewriter, reduceOp, - reduce2DProducerFusionFilter, - reduce2DConsumerFusionFilter); - } - - return failure(); -} - -struct Reduce1DTransformPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - explicit Reduce1DTransformPattern(MLIRContext *context, - Reduce1DTileSizeComputationFn tileSizeFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - tileSizeFn(std::move(tileSizeFn)) {} - - LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp, - PatternRewriter &rewriter) const override { - if (hasLabel(reduceOp, kTransformedLabel)) { - return rewriter.notifyMatchFailure(reduceOp, - "has already been transformed."); - } - if (failed(validateOp(reduceOp, rewriter, /*expectedRank=*/1))) - return failure(); - - int64_t inputSize = - reduceOp.getOperand(0).getType().cast().getDimSize(0); - Reduce1DTileSizes tileSizes = tileSizeFn(inputSize); - - // Rewrite as a tree reduction. - FailureOr splitReduce = rewriteReduce1D( - rewriter, reduceOp, tileSizes.tileSize, tileSizes.splitRatio); - if (failed(splitReduce)) { - return rewriter.notifyMatchFailure(reduceOp, - "failed to split reduction dimension"); - } - scf::ForOp mainLoop = splitReduce->mainLoop; - scf::ForOp tailLoop = splitReduce->tailLoop; - - // Fusion. - SmallVector blocks; - if (mainLoop) blocks.push_back(mainLoop.getBody()); - if (tailLoop) blocks.push_back(tailLoop.getBody()); - fuseGreedily(rewriter, blocks, reduce1DFusionFilter); - - // Tiling to 1 and fusion in the tail loop. - if (tailLoop) { - for (auto reduOp : - llvm::to_vector(tailLoop.getBody()->getOps())) { - if (failed(tileUsingSCFForOpAndFuseGreedily( - rewriter, reduOp, getSCFTilingOptions(rewriter.getContext(), 1), - reduce1DFusionFilter))) { - return failure(); - } - } - } - return success(); - } - - private: - struct SplitReduce1DResult { - scf::ForOp mainLoop; - scf::ForOp tailLoop; - linalg::ReduceOp horizontalReduce; - Value result; - }; - // Split reduction tensor -> tensor into - // * scf.for that reduces - // tensor -> tensor - // * horizontal reduce tensor -> tensor - // * scf.for that reduces the remaining M elements. 
- FailureOr rewriteReduce1D(PatternRewriter &rewriter, - linalg::ReduceOp reduceOp, - int64_t tileSize, - int64_t splitRatio) const { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(reduceOp); - - // 0-d tensor with the neutral elements. - auto fillOp = reduceOp.getInits().front().getDefiningOp(); - if (!fillOp) - return rewriter.notifyMatchFailure(reduceOp, - "init not defined by fill op"); - auto neutralValue = fillOp.value(); - - // Constants. - Location loc = reduceOp.getLoc(); - Value zero = rewriter.create(loc, 0); - Value tileSizeValue = - rewriter.create(loc, tileSize); - - // Input. - Value input = reduceOp.getInputs().front(); - FailureOr inputSizeOfr = - tensor::getMixedSize(rewriter, loc, input, 0); - if (failed(inputSizeOfr)) - return rewriter.notifyMatchFailure(reduceOp, "cannot get input size"); - - // Loop boundaries. - // tileableBound = inputSize - inputSize % tileSize - // remainderSize = inputSize - tileableBound - OpFoldResult tileableBoundOfr = - getTileableBound(rewriter, loc, *inputSizeOfr, tileSize); - Value tileableBoundValue = - getValueOrCreateConstantIndexOp(rewriter, loc, tileableBoundOfr); - - OpFoldResult remainderSize = - getRemainderSize(rewriter, loc, tileableBoundOfr, *inputSizeOfr); - - // Create tensor with neutral elements for tile loop - // init. - Type elementType = neutralValue.getType(); - Value emptyVector = rewriter.create( - loc, llvm::ArrayRef({splitRatio}), elementType); - Value filledVector = - rewriter.create(loc, neutralValue, emptyVector) - .getResult(0); - - // Create a tiled loop - SplitReduce1DResult splitResult; - splitResult.result = fillOp.getResult(0); - - std::optional tileableBoundConstant = - getConstantIntValue(tileableBoundOfr); - if (!tileableBoundConstant || tileableBoundConstant != 0) { - auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv, - ValueRange inits) { - // Tile input as tensor and reshape into - // tensor<(TILE_SIZE/SPLIT_RATIO)xSPLIT_RATIOxELEM_TYPE>. - Value inputSlice = tileAndReshapeInput(b, loc, iv, input, elementType, - tileSize, splitRatio); - - tensor::ExtractSliceOp initSlice = - create1DSlice(b, loc, inits.front(), b.getIndexAttr(0), - b.getIndexAttr(splitRatio)); - - // Create `linalg.reduce` to combine - // `tensor<(TILE_SIZE/SPLIT_RATIO)xSPLIT_RATIOxELEM_TYPE> input with the - // `tensor` accumulator. - auto tiledReduceOp = b.create( - loc, ValueRange{inputSlice}, ValueRange{initSlice}, - /*dimensions=*/SmallVector{0}, - /*bodyBuilder=*/nullptr, linalg::getPrunedAttributeList(reduceOp)); - OpBuilder::InsertionGuard g(rewriter); - Region ®ion = tiledReduceOp.getRegion(); - rewriter.cloneRegionBefore(reduceOp.getRegion(), region, region.end()); - setLabel(tiledReduceOp, kTransformedLabel); - - b.create(loc, tiledReduceOp.getResults()); - }; - - splitResult.mainLoop = rewriter.create( - loc, zero, tileableBoundValue, tileSizeValue, filledVector, - tiledLoopBodyBuilder); - setLabel(splitResult.mainLoop, kPerfectlyTiledLoopLabel); - - // Create `linalg.reduce` from tensor to - // tensor. - splitResult.horizontalReduce = - cloneReduceOp(rewriter, reduceOp, splitResult.mainLoop.getResult(0), - reduceOp.getInits().front()); - splitResult.result = splitResult.horizontalReduce.getResult(0); - } - - // Combine `horizontal reduce` with the tail of the input. The tail is - // always smaller than TILE_SIZE. 
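-    // If the trip count of the tail is statically known to be zero, the tail
-    // loop is not emitted at all.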
- std::optional tripCount = constantTripCount( - tileableBoundOfr, *inputSizeOfr, rewriter.getIndexAttr(tileSize)); - scf::ForOp remainderLoop; - if (!tripCount || *tripCount > 0) { - auto remainderLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv, - ValueRange inits) { - Value inputSlice = create1DSlice(b, loc, input, iv, remainderSize); - - Value initSlice = b.create( - loc, inits.front(), /*offsets=*/SmallVector{}, - /*sizes=*/SmallVector{}, - /*strides=*/SmallVector{}); - - linalg::ReduceOp newReduce = - cloneReduceOp(b, reduceOp, inputSlice, initSlice); - b.create(loc, newReduce->getResults()); - }; - splitResult.tailLoop = rewriter.create( - loc, tileableBoundValue, - getValueOrCreateConstantIndexOp(rewriter, loc, *inputSizeOfr), - tileSizeValue, splitResult.result, remainderLoopBodyBuilder); - splitResult.result = splitResult.tailLoop.getResult(0); - } - rewriter.replaceOp(reduceOp, splitResult.result); - return splitResult; - } - - OpFoldResult getTileableBound(OpBuilder &b, Location loc, - OpFoldResult inputSizeOfr, - int64_t tileSize) const { - if (tileSize == 1) return inputSizeOfr; - - auto inputSizeInt = getConstantIntValue(inputSizeOfr); - if (inputSizeInt && *inputSizeInt < tileSize) return b.getIndexAttr(0); - - AffineExpr sym0; - bindSymbols(b.getContext(), sym0); - - auto modMap = AffineMap::get(0, 1, {sym0 - sym0 % tileSize}); - return affine::makeComposedFoldedAffineApply(b, loc, modMap, inputSizeOfr); - } - - OpFoldResult getRemainderSize(OpBuilder &b, Location loc, - OpFoldResult tileableBoundOfr, - OpFoldResult inputSize) const { - AffineExpr sym0, sym1; - bindSymbols(b.getContext(), sym0, sym1); - auto diffMap = AffineMap::get(0, 2, {sym1 - sym0}); - return affine::makeComposedFoldedAffineApply(b, loc, diffMap, - {tileableBoundOfr, inputSize}); - } - - tensor::ExtractSliceOp create1DSlice(OpBuilder &b, Location loc, Value source, - OpFoldResult offset, - OpFoldResult size) const { - SmallVector offsets{offset}; - SmallVector sizes{size}; - SmallVector strides{b.getIndexAttr(1)}; - - return b.create(loc, source, offsets, sizes, - strides); - } - - linalg::ReduceOp cloneReduceOp(OpBuilder &b, linalg::ReduceOp reduceOp, - ValueRange newInputs, Value newInit) const { - IRMapping bvm; - bvm.map(reduceOp.getInputs(), newInputs); - bvm.map(reduceOp.getInits(), ValueRange{newInit}); - - auto *newReduceOp = b.clone(*reduceOp.getOperation(), bvm); - setLabel(newReduceOp, kTransformedLabel); - return cast(newReduceOp); - } - - Value tileAndReshapeInput(OpBuilder &b, Location loc, Value iv, Value input, - Type elementType, int64_t tileSize, - int64_t splitRatio) const { - Value inputSlice = - create1DSlice(b, loc, input, iv, b.getIndexAttr(tileSize)); - - auto reshapeType = - RankedTensorType::get({tileSize / splitRatio, splitRatio}, elementType); - SmallVector ri = {{0, 1}}; - return b.create(loc, reshapeType, inputSlice, ri); - } - - Reduce1DTileSizeComputationFn tileSizeFn; -}; - -/// Pattern to tile `linalg.reduce` and fuse `linalg.fill` into generated -/// `scf.forall`. 
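-///
-/// The parallel dimension is tiled first into an scf.forall (tile size
-/// parallelDimTileSize) and the loops are peeled; the reduction dimension of
-/// each tiled reduce is then tiled into an scf.for (tile size
-/// reductionDimTileSize) and peeled again, so that the main loops are
-/// perfectly tiled and only the tail loops handle partial tiles.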
-struct Reduce2DTransformPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - explicit Reduce2DTransformPattern(MLIRContext *context, - int64_t parallelDimTileSize = 4, - int64_t reductionDimTileSize = 2, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - parallelDimTileSize(parallelDimTileSize), - reductionDimTileSize(reductionDimTileSize) {} - - LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp, - PatternRewriter &rewriter) const override { - if (reduceOp.getDimensions().size() != 1) return failure(); - int64_t reductionDim = reduceOp.getDimensions()[0]; - - if (hasLabel(reduceOp, kTransformedLabel)) { - return rewriter.notifyMatchFailure(reduceOp, - "has already been transformed."); - } - if (failed(validateOp(reduceOp, rewriter, /*expectedRank=*/2))) - return failure(); - - auto fusionOp = reduceOp->getParentOfType(); - auto *tilingRoot = fusionOp.getTerminator().getValues()[0].getDefiningOp(); - - // First level tiling: parallel dimension. - auto parallelDimsTileSizes = - isa(tilingRoot) - ? getParallelDimTileSizes(reduceOp.getDimensions()[0], - parallelDimTileSize) - : SmallVector{parallelDimTileSize}; - auto tilingParallelDimsResult = tileUsingSCFForallOpAndFuseGreedily( - rewriter, tilingRoot, - getSCFTilingOptions(rewriter.getContext(), parallelDimsTileSizes)); - if (failed(tilingParallelDimsResult)) return failure(); - - auto peeledParallelLoop = - peelAllLoops(tilingParallelDimsResult->loop, rewriter); - - auto filterFn = [&](Operation *op) { - return reduce2DProducerFusionFilter(op) || isa(op); - }; - - // Process main parallel loop. - scf::ForallOp mainParallelLoop = peeledParallelLoop.mainLoop; - if (mainParallelLoop) { - auto tiledReduceOp = - *mainParallelLoop.getBody()->getOps().begin(); - if (failed(tileAndPeelReductionDim(rewriter, tiledReduceOp, reductionDim, - filterFn))) { - return failure(); - } - } - - // Process tail parallel loop. - scf::ForallOp tailParallelLoop = peeledParallelLoop.tailLoops.size() == 1 - ? peeledParallelLoop.tailLoops.front() - : nullptr; - if (tailParallelLoop) { - Value yieldedTensor = - getYieldedValues(tailParallelLoop.getTerminator()).front(); - auto *definingOp = yieldedTensor.getDefiningOp(); - if (!definingOp) return failure(); - - auto opts = - getSCFTilingOptions(rewriter.getContext(), - SmallVector(definingOp->getResult(0) - .getType() - .cast() - .getRank(), - 1)); - auto parallelDimTilingOpts = - isa(definingOp) - ? 
getSCFTilingOptions(rewriter.getContext(), - getParallelDimTileSizes(reductionDim, 1)) - : getSCFTilingOptions(rewriter.getContext(), 1); - auto parallelDimTilingResult = tileUsingSCFForallOpAndFuseGreedily( - rewriter, definingOp, parallelDimTilingOpts, filterFn); - if (failed(parallelDimTilingResult)) return failure(); - - for (auto tiledReduceOp : - llvm::to_vector(parallelDimTilingResult->loop.getBody() - ->getOps())) { - auto reductionDimTilingResult = tileUsingSCFForOpAndFuseGreedily( - rewriter, tiledReduceOp, - getSCFTilingOptions(rewriter.getContext(), - getReductionDimTileSizes(reductionDim, 1)), - reduce2DProducerFusionFilter); - if (failed(reductionDimTilingResult)) return failure(); - } - } - - return success(); - } - - private: - LogicalResult tileAndPeelReductionDim( - PatternRewriter &rewriter, linalg::ReduceOp reduceOp, - int64_t reductionDim, - llvm::function_ref producerFilterFn) const { - FailureOr reductionDimTilingResult = - tileUsingSCFForOpAndFuseGreedily( - rewriter, reduceOp, - getSCFTilingOptions( - rewriter.getContext(), - getReductionDimTileSizes(reductionDim, reductionDimTileSize)), - producerFilterFn); - if (failed(reductionDimTilingResult)) return failure(); - - SCFForPeelingResult reductionDimPeelingResult = peelSCFForOp( - rewriter, cast(reductionDimTilingResult->loops.front())); - if (reductionDimPeelingResult.mainLoop) { - setLabel(reductionDimPeelingResult.mainLoop, kPerfectlyTiledLoopLabel); - } - if (reductionDimPeelingResult.tailLoop) { - for (auto reduOp : - llvm::to_vector(reductionDimPeelingResult.tailLoop.getBody() - ->getOps())) { - // Column reductions have to be tiled even further, otherwise we - // would get vector.multi_reduction 4x1 -> 1, which is expensive. - // Potentially, we could lower it to a horizontal add. - if (reductionDim == 0) { - auto parallelDimSizeOneTilingResult = - tileUsingSCFForOpAndFuseGreedily( - rewriter, reduOp, - getSCFTilingOptions(rewriter.getContext(), - getParallelDimTileSizes(reductionDim, 1)), - producerFilterFn); - if (failed(parallelDimSizeOneTilingResult)) return failure(); - - reduOp = cast( - parallelDimSizeOneTilingResult->tiledOps.front()); - } - if (failed(tileUsingSCFForOpAndFuseGreedily( - rewriter, reduOp, - getSCFTilingOptions(rewriter.getContext(), - getReductionDimTileSizes(reductionDim, 1)), - producerFilterFn))) { - return failure(); - } - } - } - return success(); - } - - int64_t parallelDimTileSize; - int64_t reductionDimTileSize; -}; - -struct TransformReduceForCpuPass - : public impl::TransformReduceForCpuPassBase { - using Base::Base; - - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - // Cleanup passes to prepare ops for better clustering. 
- { - RewritePatternSet patterns(ctx); - populateDuplicateInitOpsPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns(ctx); - patterns.add(fusionClusterPattern); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns(ctx); - Reduce1DTileSizeComputationFn tilingHeuristic; - if (enableHeuristic) { - tilingHeuristic = [](int64_t size) { - if (!ShapedType::isDynamic(size) && size > 96) - return Reduce1DTileSizes{32, 8}; - return Reduce1DTileSizes{8, 8}; - }; - } else { - tilingHeuristic = [=](int64_t) { - return Reduce1DTileSizes{tileSize1D, splitRatio1D}; - }; - } - patterns.add(ctx, std::move(tilingHeuristic)); - patterns.add(ctx, parallelDimTileSize2D, - reductionDimTileSize2D); - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr createTransformReduceForCpuPass( - const TransformReduceForCpuPassOptions &opts) { - return std::make_unique(opts); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc deleted file mode 100644 index b69ccee201a7e3..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/scalarization/scalarization.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMSCATTERFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -struct TileScatterPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(thlo::ScatterOp scatterOp, - PatternRewriter &rewriter) const override { - if (hasLabel(scatterOp, kTransformedLabel)) return failure(); - - // Tile everything to points and fuse. 
- scf::SCFTilingOptions opts; - opts.setTileSizes(SmallVector( - scatterOp.getLoopIteratorTypes().size(), - getAsIndexOpFoldResult(rewriter.getContext(), 1))); - - auto fuseFilterFn = [](Operation *op) { - return isa(op); - }; - auto tilingResult = tileUsingSCFForOpAndFuseGreedily(rewriter, scatterOp, - opts, fuseFilterFn); - - if (failed(tilingResult)) return failure(); - - assert(tilingResult->tiledOps.size() == 1 && - "Tiling of thlo.scatter should generate a single op"); - - // Scalarize scatter op. - scatterOp = cast(tilingResult->tiledOps.front()); - FailureOr ifOpOr = rewriteScatterOpAsIfOp(scatterOp, rewriter); - if (failed(ifOpOr)) return failure(); - - // Fuse into `then` block. - fuseGreedily(rewriter, &ifOpOr->getThenRegion().front(), fuseFilterFn); - - // Remove tiling label to continue generating code inside the region. - ifOpOr->walk([](Operation *op) { removeLabel(op, kTransformedLabel); }); - return success(); - } -}; - -struct TransformScatterForCpuPass - : public impl::TransformScatterForCpuPassBase { - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(ctx); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -createTransformScatterForCpuPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc deleted file mode 100644 index 04cf7f87181bac..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc +++ /dev/null @@ -1,796 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "gml_st/transforms/fusion/fusion.h" - -#include -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/tiling/tiling.h" -#include "gml_st/transforms/transforms.h" -#include "gml_st/utils/tensor_utils.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetOperations.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/InliningUtils.h" -#include "mlir/Transforms/RegionUtils.h" -#include "mlir/Transforms/TopologicalSortUtils.h" - -namespace mlir::gml_st { -namespace { - -struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const Operation* opC) { - return OperationEquivalence::computeHash( - const_cast(opC), - /*hashOperands=*/OperationEquivalence::directHashValue, - /*hashResults=*/OperationEquivalence::ignoreHashValue, - OperationEquivalence::IgnoreLocations); - } - static bool isEqual(const Operation* lhsC, const Operation* rhsC) { - auto* lhs = const_cast(lhsC); - auto* rhs = const_cast(rhsC); - if (lhs == rhs) return true; - if (lhs == getTombstoneKey() || lhs == getEmptyKey() || - rhs == getTombstoneKey() || rhs == getEmptyKey()) - return false; - return OperationEquivalence::isEquivalentTo( - const_cast(lhsC), const_cast(rhsC), - OperationEquivalence::IgnoreLocations); - } -}; - -template -void eliminateEqualOps(PatternRewriter& rewriter, Block& block) { - llvm::DenseMap uniqueOps; - - for (auto op : llvm::make_early_inc_range(block.getOps())) { - if (auto* equivalentOp = uniqueOps.lookup(op)) { - rewriter.replaceOp(op, equivalentOp->getResults()); - } else { - uniqueOps.insert(std::make_pair(op, op)); - } - } -} - -void eliminateTriviallyDeadUsers(PatternRewriter& rewriter, Operation* op) { - for (auto* user : - DenseSet(op->getUsers().begin(), op->getUsers().end())) { - if (isOpTriviallyDead(user)) rewriter.eraseOp(user); - } -} - -void reifyDimOp(PatternRewriter& rewriter, tensor::DimOp dimOp) { - auto dimValue = dimOp.getSource().template dyn_cast(); - if (!dimValue) return; - - std::optional dimIndex = dimOp.getConstantIndex(); - if (!dimIndex) return; - - ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), - reifiedResultShapes))) { - return; - } - - if (reifiedResultShapes.size() != dimValue.getOwner()->getNumResults()) - return; - - unsigned resultNumber = dimValue.getResultNumber(); - auto sourceType = dimValue.getType().dyn_cast(); - if (reifiedResultShapes[resultNumber].size() != - static_cast(sourceType.getRank())) - return; - - rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp( - rewriter, dimOp.getLoc(), - reifiedResultShapes[resultNumber][*dimIndex])); -} - -void reifyDimOpsUsers(PatternRewriter& rewriter, Operation* op) { - 
OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(op); - - for (auto* user : llvm::make_early_inc_range(op->getUsers())) { - auto dimOp = dyn_cast(user); - if (dimOp) reifyDimOp(rewriter, dimOp); - } -} - -LogicalResult fuseTensorCast(PatternRewriter& rewriter, tensor::CastOp castOp, - tensor::ExtractSliceOp sliceOp) { - if (!tensor::canFoldIntoConsumerOp(castOp)) return failure(); - - /// Deduce the type of the result to use for the canonicalized operation. - RankedTensorType resultType = - tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - sliceOp.getType().getRank(), sliceOp.getSourceType(), - sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), - sliceOp.getMixedStrides()); - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(sliceOp); - Value newSlice = rewriter.create( - sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(), - sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), - sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); - rewriter.replaceOpWithNewOp(sliceOp, sliceOp.getType(), - newSlice); - return success(); -} - -// TODO(vuson): maybe overload this function instead of templating it. -// Fuse a reshape op being used by an extract_slice op (inside a loop) into the -// loop by reversing the order of these two ops (and fixing their -// operands/result accordingly). For example, fusing a tensor.collapse_shape: -// -// %1 = tensor.collapse_shape %0 [[0], [1, 2]] : -// tensor<10x10x1xf32> into tensor<10x10xf32> -// some_scf_loop (%a1, %a2) = (0, 0) to (10, 10) step (1, 8) ... -// %3 = tensor.extract_slice %1[%a1, %a2] [1, 8] [1, 1] : -// tensor<10x10xf32> to tensor<1x?xf32> -// -// into -// -// some_scf_loop (%a1, %a2) = (0, 0) to (10, 10) step (1, 8) ... -// %3 = tensor.extract_slice %0[%a1, %a2, 0] [1, 8, 1] [1, 1, 1] : -// tensor<10x10x1xf32> to tensor<1x?x1xf32> -// %1 = tensor.collapse_shape %3 [[0], [1, 2]] : -// tensor<1x?x1xf32> into tensor<1x?xf32> -// -// This also works for tensor.expand_shape instead of tensor.collapse_shape: -// -// %1 = tensor.expand_shape %0 [[0], [1, 2]] : -// tensor<10x10xf32> into tensor<10x10x1xf32> -// some_scf_loop (%a1, %a2) = (0, 0) to (10, 10) step (1, 8) ... -// %3 = tensor.extract_slice %1[%a1, %a2, 0] [1, 8, 1] [1, 1, 1] : -// tensor<10x10x1xf32> to tensor<1x?x1xf32> -// -// into -// -// some_scf_loop (%a1, %a2) = (0, 0) to (10, 10) step (1, 8) ... -// %3 = tensor.extract_slice %0[%a1, %a2] [1, 8] [1, 1] : -// tensor<10x10xf32> to tensor<1x?xf32> -// %1 = tensor.expand_shape %2 [[0], [1, 2]] : -// tensor<1x?xf32> into tensor<1x?x1xf32> -template -LogicalResult fuseTensorReshape(PatternRewriter& rewriter, - TensorReshapeOp reshapeOp, - tensor::ExtractSliceOp sliceOp) { - if (!isDegenerateReshapeOp(reshapeOp)) return failure(); - - auto newSliceSrcType = reshapeOp.getSrcType(); - llvm::ArrayRef newSliceSrcShape = newSliceSrcType.getShape(); - auto newSliceRank = newSliceSrcType.getRank(); - // If the source type of reshape op is a rank-0 tensor, there will be no - // extract_slice possible from that source value, let's bail out. 
- if (newSliceRank == 0) return failure(); - - auto one = rewriter.getIndexAttr(1); - auto zero = rewriter.getIndexAttr(0); - SmallVector newOffsets(newSliceRank, zero); - SmallVector newSizes(newSliceRank, one); - SmallVector newStrides(newSliceRank, one); - - llvm::ArrayRef sliceSrcShape = sliceOp.getSourceType().getShape(); - auto reassociation = reshapeOp.getReassociationIndices(); - constexpr bool isExpanding = - std::is_same::value; - llvm::ArrayRef shape = - isExpanding ? sliceSrcShape : newSliceSrcShape; - // For each reassociation indices, a degenerate reshape op only has at most - // 1 non-unit-dimension. If there is none, it means the source shape already - // has some unit-dimensions (e.g. tensor<1x1x1xf32> collapsed into - // tensor<1xf32>) - assert( - static_cast( - llvm::count_if(reassociation, - [&](auto indices) { - return llvm::count_if(indices, [&](int64_t idx) { - return shape[idx] != 1; - }) <= 1; - })) == reassociation.size() && - "Degenerate reshape op should only have at most 1 non-unit dimension for " - "each reassociation indices"); - for (const auto& [enumIdx, indices] : llvm::enumerate(reassociation)) { - auto findIt = - llvm::find_if(indices, [&](int64_t idx) { return shape[idx] != 1; }); - // No non-unit dimension, which means the source shape already has some - // unit-dimensions. The default values for offset/size/stride (0/1/1) should - // be usable. Skip updating offset/size/stride for this dimension. - if (findIt == indices.end()) continue; - auto newIdx = isExpanding ? enumIdx : *findIt; - auto idx = isExpanding ? *findIt : enumIdx; - newOffsets[newIdx] = sliceOp.getMixedOffsets()[idx]; - newSizes[newIdx] = sliceOp.getMixedSizes()[idx]; - newStrides[newIdx] = sliceOp.getMixedStrides()[idx]; - } - - RankedTensorType newSliceResultType = - tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - newSliceRank, newSliceSrcType, newOffsets, newSizes, newStrides); - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(sliceOp); - auto newSlice = rewriter.create( - sliceOp.getLoc(), newSliceResultType, reshapeOp.getSrc(), newOffsets, - newSizes, newStrides); - - rewriter.replaceOpWithNewOp(sliceOp, sliceOp.getResultType(), - newSlice, reassociation); - return success(); -} - -// Checks that there is at most one user in each given block. -bool atMostOneUserPerBlock(ArrayRef users, - ArrayRef blocks) { - if (users.size() == 1) return true; - if (users.size() > blocks.size()) return false; - - llvm::SmallSetVector blocksWithUsers; - - Block* funcBlock = &users.front() - ->getParentWithTrait() - ->getRegion(0) - .front(); - // Return false if there two users in a block. - for (Operation* user : users) { - Block* block = user->getBlock(); - while (block != funcBlock) { - if (llvm::is_contained(blocks, block)) { - if (!blocksWithUsers.insert(block)) return false; - } - block = block->getParentOp()->getBlock(); - } - } - return true; -} - -DenseSet getFusionCandidates( - Region& region, llvm::function_ref filterFn) { - DenseSet fusionCandidates; - visitUsedValuesDefinedAbove(region, [&](OpOperand* operand) { - auto* fusionCandidate = operand->get().getDefiningOp(); - // Do not fuse if there is no defining op. Of example, if it's an - // extract_slice from a function argument. - if (!fusionCandidate) return; - - // Filter candidates that we don't want to fuse. - if (filterFn && !filterFn(fusionCandidate)) return; - - // Check that the candidate doesn't have users that will block fusion. 
- if (!llvm::all_of(fusionCandidate->getUsers(), [](Operation* op) { - // Fusion candidates can only be fused into tensor.extract_slice or - // tensor.extract. - return isa(op) || - // tensor.dim is pushed 'above' the fusion candidate. - isa(op) || - // Trivially dead ops will be removed. - isOpTriviallyDead(op); - })) { - return; - } - fusionCandidates.insert(fusionCandidate); - }); - return fusionCandidates; -} - -LogicalResult fuseIntoUser(PatternRewriter& rewriter, - Operation* fusionCandidate, - Operation* candidateUser) { - // If the user of the fusion candidate is `tensor.extract_slice`, we use - // TilingInterface to rewrite `tensor.extract_slice(fusionOp)` into - // `tiledFusionOp(tensor.extract_slice)`. - if (auto extractSliceOp = dyn_cast(candidateUser)) { - if (auto castOp = dyn_cast(fusionCandidate)) { - return fuseTensorCast(rewriter, castOp, extractSliceOp); - } - if (auto collapseShapeOp = - dyn_cast(fusionCandidate)) { - return fuseTensorReshape(rewriter, collapseShapeOp, extractSliceOp); - } - if (auto expandShapeOp = dyn_cast(fusionCandidate)) { - return fuseTensorReshape(rewriter, expandShapeOp, extractSliceOp); - } - auto fusedOp = fuse(rewriter, extractSliceOp); - if (succeeded(fusedOp)) { - setLabel(*fusedOp, kTransformedLabel); - return success(); - } - return failure(); - } - - // TODO(shyshkov): Implement fusion into `tensor.extract` using - // TilingInterface. - if (auto extractOp = dyn_cast(candidateUser)) { - return failure(); - } - - // Otherwise, the fusion candidate op is moved inside of the region. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(candidateUser); - Operation* clonedCandidate = rewriter.clone(*fusionCandidate); - rewriter.replaceOp(fusionCandidate, clonedCandidate->getResults()); - return success(); -} - -LogicalResult fuseIntoUsers(PatternRewriter& rewriter, - Operation* fusionCandidate, - ArrayRef fusionCandidateUsers) { - for (Operation* candidateUser : fusionCandidateUsers) { - if (failed(fuseIntoUser(rewriter, fusionCandidate, candidateUser))) - return failure(); - } - return success(); -} - -// Iterates over tensor::ExtractSliceOp inside the block, finds a suitable -// candidate for fusion and fuses it. The fusion candidate should satisfy the -// filter function and not have uses outside of the block. Fails if nothing -// can be fused. -LogicalResult fuseGreedilyOneOpIntoBlock( - PatternRewriter& rewriter, ArrayRef blocks, - llvm::function_ref filterFn) { - if (blocks.empty()) return failure(); - - // Ad-hoc CSE to eliminate duplicate MatrializeOp that could have been added - // after previous fusions. Running the whole CSE pass would be to expensive - // here and unnecessary. Without removing those duplicate, some ops will be - // fused multiple times resulting in exponential code growth. - DenseSet fusionCandidates; - for (auto [index, block] : llvm::enumerate(blocks)) { - eliminateEqualOps(rewriter, *block); - - if (index == 0) { - fusionCandidates = getFusionCandidates(*block->getParent(), filterFn); - continue; - } - llvm::set_intersect(fusionCandidates, - getFusionCandidates(*block->getParent(), filterFn)); - } - - for (Operation* fusionCandidate : fusionCandidates) { - // Ad-hoc DCE to trim the fusion candidate from dead users that could have - // been added in the previous fusion cycles. Normally those ops would be - // garbage collected after the pattern rewriter driver finished working, - // but here it requires manual handling. 
- eliminateTriviallyDeadUsers(rewriter, fusionCandidate); - - // Push tensor.dim ops 'above' the fusion candidate. This is normally done - // by canonicalization passes, but running the whole canonicalization - // pipeline here is too expensive. - reifyDimOpsUsers(rewriter, fusionCandidate); - - // After the previous steps, there should be at most one user of the - // fusion candidate per block. Otherwise this candidate should not be fused. - // We always want to fuse linalg.fill. - auto fusionCandidateUsers = llvm::to_vector(fusionCandidate->getUsers()); - if (!atMostOneUserPerBlock(fusionCandidateUsers, blocks)) { - continue; - } - if (succeeded( - fuseIntoUsers(rewriter, fusionCandidate, fusionCandidateUsers))) { - return success(); - } - } - return failure(); -} - -FailureOr createFusedOp(PatternRewriter& rewriter, - tensor::ExtractSliceOp extractSliceOp) { - Value src = extractSliceOp.getSource(); - if (!src) return failure(); - auto tileableOp = src.getDefiningOp(); - if (!tileableOp) { - return rewriter.notifyMatchFailure( - extractSliceOp, - "expected source to be defined by tiling interface op "); - } - - SmallVector offsets = extractSliceOp.getMixedOffsets(); - SmallVector sizes = extractSliceOp.getMixedSizes(); - - // Tile the producer. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(extractSliceOp); - FailureOr tiledProducer = tileableOp.generateResultTileValue( - rewriter, /*resultNumber=*/0, offsets, sizes); - if (failed(tiledProducer)) { - return rewriter.notifyMatchFailure(tileableOp, - "failed to tile the producer"); - } - - return tiledProducer; -} - -void fuseFillOpsIntoForallOp(PatternRewriter& rewriter, - scf::ForallOp parallelOp) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointToStart(parallelOp.getBody()); - for (OpOperand& output : - parallelOp->getOpOperands().take_back(parallelOp.getNumResults())) { - auto fillOp = output.get().getDefiningOp(); - if (!fillOp) continue; - - // Clone `linalg.fill` op inside the loop, update the uses of bbArg. - BlockArgument regionOutputArg = parallelOp.getTiedBlockArgument(&output); - auto clonedFill = cast( - mlir::clone(rewriter, fillOp, fillOp.getResultTypes(), - {fillOp.value(), regionOutputArg})); - - output.set(fillOp.output()); - setLabel(clonedFill, kTransformedLabel); - - SmallVector sliceOps; - regionOutputArg.replaceUsesWithIf( - clonedFill.getResult(0), [&](OpOperand& operand) { - Operation* owner = operand.getOwner(); - if (auto sliceOp = dyn_cast_or_null(owner)) - sliceOps.push_back(sliceOp); - return owner != clonedFill && - !isa(owner) && - owner->getParentOfType() == parallelOp; - }); - - // Use standard fusion logic to swap extract_slice(fill) -> - // fill(extract_slice). - for (tensor::ExtractSliceOp sliceOp : sliceOps) - (void)fuse(rewriter, sliceOp); - } -} - -// Finds the source of the operand. It could be a tensor.empty, a region arg or -// an op outside of the cluster. -Value getTiedSourceOp(PatternRewriter& rewriter, OpOperand* operand, - const FusionCluster& fusionCluster) { - auto* definingOp = operand->get().getDefiningOp(); - if (!definingOp) return operand->get(); - - // A tensor.empty used tied to fusion cluster result should not be fused, so - // bufferization can properly handle allocations. If the same tensor.empty is - // used in other ops for temporary result, it should be fused. Copied op is - // not in the cluster, so it will not be fused. 
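-  // Operands defined outside the cluster are returned unchanged and later
-  // become block arguments of the fusion op; inits produced by another DPS op
-  // inside the cluster are traced further upwards through their tied init
-  // operands.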
- if (auto emptyOp = dyn_cast(definingOp)) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(emptyOp); - - auto newEmptyOp = cast(rewriter.clone(*emptyOp)); - operand->set(newEmptyOp); - return newEmptyOp; - } - - // Source of the operand is outside of the cluster, so pass it as an argument. - if (!llvm::is_contained(fusionCluster.operations, definingOp)) { - return operand->get(); - } - - // Source of the operand is another DPS op from the cluster. Look higher in - // the chain. - if (auto dstStyleOp = dyn_cast(definingOp)) { - OpOperand* tiedOperand = - dstStyleOp.getTiedOpOperand(operand->get().dyn_cast()); - return getTiedSourceOp(rewriter, tiedOperand, fusionCluster); - } - - return operand->get(); -} - -SmallVector getRootOpInitOperands(PatternRewriter& rewriter, - const FusionCluster& fusionCluster) { - auto dstStyleOp = dyn_cast(fusionCluster.root); - if (!dstStyleOp) return {}; - - SmallVector initOperands; - - for (OpOperand& operand : dstStyleOp.getDpsInitsMutable()) { - initOperands.push_back(getTiedSourceOp(rewriter, &operand, fusionCluster)); - } - - return initOperands; -} - -} // namespace - -FailureOr fuse(PatternRewriter& rewriter, - tensor::ExtractSliceOp extractSliceOp) { - Location loc = extractSliceOp.getLoc(); - FailureOr fusedOr = createFusedOp(rewriter, extractSliceOp); - if (failed(fusedOr)) return failure(); // Match failure already notified. - - // Insert cast if needed. - Value fused = fusedOr->tiledOps.front()->getResult(0); - if (fused.getType() != extractSliceOp.getType()) { - // The result should be a tensor, cast it to the correct shape - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(fused.getDefiningOp()); - fused = - rewriter.create(loc, extractSliceOp.getType(), fused); - } - - rewriter.replaceOp(extractSliceOp, fused); - return fused.getDefiningOp(); -} - -void fuseGreedily(PatternRewriter& rewriter, ArrayRef blocks, - llvm::function_ref filterFn) { - while (succeeded(fuseGreedilyOneOpIntoBlock(rewriter, blocks, filterFn))) - ; -} - -// Cluster producers and consumers around the root op. -FusionCluster getFusionCluster( - Operation* op, llvm::function_ref producerFilterFn, - llvm::function_ref consumerFilterFn) { - // Find a chain of users and use the last one as a root of cluster. - SetVector resultOps; - - Operation* rootOp = op; - while (true) { - auto users = llvm::to_vector(rootOp->getUsers()); - - if (users.size() != 1) break; - - if (!consumerFilterFn(users[0])) break; - resultOps.insert(rootOp); - - rootOp = users[0]; - } - resultOps.insert(rootOp); - - // Run DFS to find all ops that satisfy producerFilterFn. 
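-  // A producer is added only if every one of its users already belongs to the
-  // cluster and it passes producerFilterFn.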
- SmallVector remainingProducers; - for (Value operand : op->getOperands()) - remainingProducers.push_back(operand.getDefiningOp()); - - while (!remainingProducers.empty()) { - Operation* curOp = remainingProducers.pop_back_val(); - if (!curOp || resultOps.contains(curOp)) continue; - if (!llvm::all_of(curOp->getUsers(), - [&](Operation* op) { return resultOps.contains(op); })) { - continue; - } - - if (curOp == op || producerFilterFn(curOp)) { - resultOps.insert(curOp); - for (Value operand : curOp->getOperands()) - remainingProducers.push_back(operand.getDefiningOp()); - } - } - return {resultOps, rootOp, {}}; -} - -FailureOr tileUsingSCFForallOpAndFuseGreedily( - PatternRewriter& rewriter, Operation* op, const scf::SCFTilingOptions& opts, - llvm::function_ref fuseFilterFn) { - auto tilingResult = - tileUsingSCFForallOp(rewriter, cast(op), opts); - if (failed(tilingResult)) return failure(); - - for (Operation* tiledOp : tilingResult->tiledOps) - setLabel(tiledOp, kTransformedLabel); - - // If tiling created an `scf.forall` loop, we fuse. - if (tilingResult->loop != nullptr) { - rewriter.replaceOp(op, tilingResult->loop->getResults()); - // Fuse ops into the loop. - fuseGreedily(rewriter, tilingResult->tiledOps.front()->getBlock(), - fuseFilterFn); - fuseFillOpsIntoForallOp(rewriter, tilingResult->loop); - } - return tilingResult; -} - -FailureOr tileUsingSCFForOpAndFuseGreedily( - PatternRewriter& rewriter, Operation* op, const scf::SCFTilingOptions& opts, - llvm::function_ref fuseFilterFn) { - auto tilingResult = - scf::tileUsingSCFForOp(rewriter, cast(op), opts); - if (failed(tilingResult)) return failure(); - rewriter.replaceOp(op, tilingResult->replacements); - - for (Operation* tiledOp : tilingResult->tiledOps) - setLabel(tiledOp, kTransformedLabel); - - // If tiling created an `scf.for` loop nest, we fuse. - if (!tilingResult->loops.empty()) { - scf::ForOp innerLoop = cast(tilingResult->loops.back()); - fuseGreedily(rewriter, innerLoop.getBody(), fuseFilterFn); - } - return tilingResult; -} - -LogicalResult tilePeeledOpsToScalars( - PatternRewriter& rewriter, const GmlStPeelingResult& peelingResult, - llvm::function_ref fuseFilterFn) { - for (scf::ForallOp peeledLoop : peelingResult.tailLoops) { - SmallVector yieldedTensors = - getYieldedValues(peeledLoop.getTerminator()); - - assert(yieldedTensors.size() == 1 && - "expected to have a single result in scf.forall loop"); - auto definingOp = yieldedTensors.front().getDefiningOp(); - if (!definingOp) return failure(); - - auto opts = getSCFTilingOptions( - rewriter.getContext(), - SmallVector(definingOp.getLoopIteratorTypes().size(), 1)); - if (failed(tileUsingSCFForallOpAndFuseGreedily(rewriter, definingOp, opts, - fuseFilterFn))) { - return failure(); - } - } - return success(); -} - -FailureOr wrapFusionCluster( - PatternRewriter& rewriter, const FusionCluster& fusionCluster) { - auto loc = fusionCluster.root->getLoc(); - - SetVector inputOperands; - SmallVector initOperands = - getRootOpInitOperands(rewriter, fusionCluster); - - // 1. Find operands and results of the cluster op. 
- SmallVector clusterResults; - SmallVector constantOps; - auto visitOpOperand = [&](OpOperand* operand) { - Value operandValue = operand->get(); - auto* definingOp = operandValue.getDefiningOp(); - - if (definingOp && definingOp->hasTrait()) { - constantOps.push_back(operandValue); - return; - } - - if (fusionCluster.operations.contains(definingOp)) return; - if (llvm::is_contained(initOperands, operandValue)) return; - - inputOperands.insert(operandValue); - }; - - for (Operation* op : fusionCluster.operations) { - for (OpOperand& operand : op->getOpOperands()) visitOpOperand(&operand); - - visitUsedValuesDefinedAbove(op->getRegions(), visitOpOperand); - - if (llvm::any_of(op->getUsers(), [&](Operation* user) { - return !fusionCluster.operations.contains(user); - })) { - llvm::append_range(clusterResults, op->getResults()); - } - } - - SetVector clusterOperands = inputOperands; - clusterOperands.insert(initOperands.begin(), initOperands.end()); - - // 2. Create an empty op. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(fusionCluster.root); - auto fusionClusterOp = rewriter.create( - loc, TypeRange(ValueRange(clusterResults)), - ValueRange(inputOperands.getArrayRef()), ValueRange(initOperands), - nullptr, nullptr); - - // 3. Create block with mapping between operands and block arguments. - SmallVector blockArgTypes = - llvm::to_vector(TypeRange(ValueRange(clusterOperands.getArrayRef()))); - SmallVector blockArgLocs(blockArgTypes.size(), loc); - - Region& region = fusionClusterOp.getRegion(); - Block* block = - rewriter.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); - - IRMapping mapper; - mapper.map(clusterOperands, block->getArguments()); - - // 4. Copy ops into the cluster region. - // 4.1. Copy constant ops. - for (auto v : constantOps) { - auto newOp = rewriter.clone(*v.getDefiningOp())->getResult(0); - mapper.map(v, newOp); - } - - // 4.2. Copy ops into the cluster region in topoligical order to avoid - // swapping depending ops. - SmallVector clusterOps(fusionCluster.operations.begin(), - fusionCluster.operations.end()); - - mlir::computeTopologicalSorting(clusterOps); - for (Operation* op : clusterOps) { - rewriter.clone(*op, mapper); - } - - // 4.3 Create terminator gml_st.yield. - SmallVector yieldOpOperands = llvm::to_vector(llvm::map_range( - clusterResults, [&](Value v) { return mapper.lookupOrDefault(v); })); - auto yieldOp = rewriter.create(loc, yieldOpOperands); - - // 5. Replace all uses of ops in the cluster with results of the new fusion - // cluster op. - for (auto [fromV, toV] : - llvm::zip(clusterResults, fusionClusterOp.getResults())) { - rewriter.replaceAllUsesExcept(fromV, toV, yieldOp); - } - - return fusionClusterOp; -} - -LogicalResult inlineFusionCluster(FusionOp fusionOp, - PatternRewriter& rewriter) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(fusionOp); - - IRMapping mapper; - mapper.map(fusionOp.getRegion().getArguments(), fusionOp.getOperands()); - - for (auto& op : fusionOp.getBody()->without_terminator()) { - rewriter.clone(op, mapper); - } - - if (fusionOp.hasTensorSemantics()) { - SmallVector yieldOpOperands = llvm::to_vector( - llvm::map_range(fusionOp.getTerminator().getOperands(), - [&](Value v) { return mapper.lookupOrDefault(v); })); - - rewriter.replaceOp(fusionOp, yieldOpOperands); - } else { - rewriter.eraseOp(fusionOp); - } - - return success(); -} - -// Duplicates the op so each copy has only one use as init parameter. 
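-// For example, a single tensor.empty that feeds the inits of two different
-// linalg ops is cloned so that each op receives its own init value.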
-template -LogicalResult duplicateInitOps(OpTy op, PatternRewriter& rewriter) { - // Nothing to do, because the op has 0 or 1 users. - if (std::distance(op->user_begin(), op->user_end()) <= 1) return failure(); - - bool modified = false; - for (auto& use : llvm::make_early_inc_range(op->getUses())) { - Operation* ownerOp = use.getOwner(); - - auto dstStyleOp = dyn_cast(ownerOp); - if (!dstStyleOp || !dstStyleOp.isDpsInit(&use)) continue; - - auto newOp = cast(rewriter.clone(*op)); - use.set(newOp->getResult(0)); - modified = true; - } - return success(modified); -} - -void populateDuplicateInitOpsPatterns(RewritePatternSet& patterns) { - patterns.add(duplicateInitOps); - patterns.add(duplicateInitOps); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h deleted file mode 100644 index b1b9181108e0e8..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_FUSION_FUSION_H -#define MLIR_HLO_GML_ST_TRANSFORMS_FUSION_FUSION_H - -#include - -#include "gml_st/transforms/peeling/peeling.h" -#include "gml_st/transforms/tiling/tiling.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir::gml_st { - -struct FusionCluster { - SetVector operations; - Operation *root; - // Map from Value of the fusion cluster argument to the root dimensions. - llvm::SmallVector>> argDimsMapping; -}; -// Cluster producers and consumers around the root op. -FusionCluster getFusionCluster( - Operation *op, llvm::function_ref producerFilterFn, - llvm::function_ref consumerFilterFn); - -// Creates gml_st.fusion op with a region with ops from the fusion cluster. -// Operands of the ops in the region are replaced with region arguments to -// isolate the fusion cluster form above. Usages of the ops are replaces with -// the fusion op results. -FailureOr wrapFusionCluster( - PatternRewriter &rewriter, const FusionCluster &fusionCluster); - -// Replaces gml_st.fusion op with ops from the region. -LogicalResult inlineFusionCluster(FusionOp fusionOp, PatternRewriter &rewriter); - -// Adds patterns to duplicate linalg.fill and tensor.empty that used as init -// parameters. -void populateDuplicateInitOpsPatterns(RewritePatternSet &patterns); - -// Fuses an op into `tensor.extract_slice` and performs the necessary updates to -// the surrounding loop if any. -FailureOr fuse(PatternRewriter &rewriter, - tensor::ExtractSliceOp materializeOp); - -// Finds `tensor.extract_slice` ops in the block and fuses ops into them. 
-// Verifies that fusion candidate doesn't have any uses except the one -// `tensor.extract_slice` in the block to avoid exponential code growth. -void fuseGreedily(PatternRewriter &rewriter, ArrayRef blocks, - llvm::function_ref filterFn = nullptr); - -// Tiles the op to gml_st.parallel and fuses greedily according to the filter. -FailureOr tileUsingSCFForallOpAndFuseGreedily( - PatternRewriter &rewriter, Operation *op, const scf::SCFTilingOptions &opts, - llvm::function_ref fuseFilterFn = nullptr); - -// Tiles the op to scf.for and fuses greedily according to the filter. -FailureOr tileUsingSCFForOpAndFuseGreedily( - PatternRewriter &rewriter, Operation *op, const scf::SCFTilingOptions &opts, - llvm::function_ref fuseFilterFn = nullptr); - -// Tiles the op to 1 for all dimensions and fuses greedily according to the -// filter function. -LogicalResult tilePeeledOpsToScalars( - PatternRewriter &rewriter, const GmlStPeelingResult &peelingResult, - llvm::function_ref fuseFilterFn = nullptr); - -} // namespace mlir::gml_st - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_FUSION_FUSION_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h deleted file mode 100644 index b9562e722177d3..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.h +++ /dev/null @@ -1,244 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_PASSES_H -#define MLIR_HLO_GML_ST_TRANSFORMS_PASSES_H - -#include -#include -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // IWYU pragma: keep -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::gml_st { - -#define GEN_PASS_DECL -#include "gml_st/transforms/passes.h.inc" - -/// Pass to fuse producers into a tiled consumer. -std::unique_ptr> createFusionPass( - StringRef producer = "", StringRef consumer = ""); - -/// Pass to match, tile, and fuse softmax implementations. -std::unique_ptr> createTilingSoftmaxPass( - ArrayRef tileSizes); -std::unique_ptr> createTilingSoftmaxPass(); - -/// Pass to tile the root operation and to greedily fuse producers into it. -std::unique_ptr> createGreedyFusionPass( - ArrayRef tileSizes); -std::unique_ptr> createGreedyFusionPass(); - -// Pass to collapse dimensions of bcasts, reductions, and cwise ops. -std::unique_ptr> createCollapseShapePass(); -std::unique_ptr> createCollapseShapePass( - const CollapseShapePassOptions &options); - -// Pass to tile all tileable ops to size 1. -std::unique_ptr> createTileByOnePass(); - -/// Pass to compose tensor.extract_slice/insert_slice ops. -std::unique_ptr> -createComposeExtractInsertSlicePass(); - -/// Pass to vectorize compute ops and scf.for loops that are tiled perfectly. 
-std::unique_ptr> createVectorizeForCPUPass( - int64_t numElementsThreshold = 1024); - -/// Pass to gradually lower vector ops to SCF. -std::unique_ptr> createLowerVectorsPass( - bool enableAVX2 = true, bool flatten = false); - -/// Pass to pack linalg.matmul as linalg.mmt4d. -std::unique_ptr> createPackMatmulPass(); - -/// Pass to transform a thlo.scatter op for CPU backend. -std::unique_ptr> createTransformScatterForCpuPass(); - -/// Pass to transform a dot operation for CPU backend. -std::unique_ptr> createTransformDotForCpuPass( - ArrayRef tileSizes = {}, StringRef cpuName = ""); - -/// Pass to transform tensor.pack/unpack ops for CPU backend. -std::unique_ptr> createTransformPackForCpuPass(); - -/// Pass to transform a linalg.mmt4d op for CPU backend. -std::unique_ptr> createTransformMmt4DForCpuPass(); - -/// Pass to fuse linalg on tensor operations. -std::unique_ptr> createFusionOfTensorOpsPass(); - -/// Pass to convert ops on tensors with 1 element to scalar ops. -std::unique_ptr> createScalarizationPass( - bool scalarizeAllThlo = true); - -/// Pass to transform elementwise ops for CPU backend. -std::unique_ptr> -createTransformElementwiseForCpuPass(int64_t vectorSize = 8, - bool fuseDegenerateReshapes = false); - -/// Pass to transform a linalg.reduce op for CPU backend. -std::unique_ptr createTransformReduceForCpuPass( - const TransformReduceForCpuPassOptions &option = {}); - -/// Pass to create fusion clusters. -std::unique_ptr> -createFusionPlanningForCpuPass(int64_t vectorSize = 8); - -/// Pass to outline fusion regions into functions. -std::unique_ptr> createFusionOutliningPass(); - -/// Pass to inline fusion clusters. -std::unique_ptr> -createInlineFusionClustersPass(); - -/// Pass with canonicalization patterns for linalg ops. -std::unique_ptr> -createOptimizeLinalgOpsPass(); - -/// Pass to rewrite tensor.from_elements into tensor.insert. -std::unique_ptr> -createRewriteFromElementsOpPass(); - -/// Pass to rewrite scf.forall to scf.for. -std::unique_ptr> -createRewriteForallOpPass(); - -/// Pass to add debug info to be propagated into LLVM backend. -std::unique_ptr> createAddDebugInfoPass(); - -/// Pass to print stats about tileable ops. -std::unique_ptr> createCollectStatsPass( - int64_t level = 0); - -/// Pass to remove all transformed labels from tiled ops. -std::unique_ptr> -createRemoveLabelPass(); - -/// Populate pattern to remove single/zero iteration scf.forall dimensions. 
-void populateCollapseForallOpDimensionsPattern(RewritePatternSet &patterns); - -struct GmlStCPUTilingOptions - : public mlir::PassPipelineOptions { - GmlStCPUTilingOptions() = default; - GmlStCPUTilingOptions(const GmlStCPUTilingOptions &opts) { - this->lowerToMmt4d = opts.lowerToMmt4d; - this->matmulTileSizes = opts.matmulTileSizes; - this->reduction1DTileSize = opts.reduction1DTileSize; - this->reduction1DSplitRatio = opts.reduction1DSplitRatio; - this->reduction2DParallelDimTileSize = opts.reduction2DParallelDimTileSize; - this->reduction2DReductionDimTileSize = - opts.reduction2DReductionDimTileSize; - this->vectorSize = opts.vectorSize; - this->statsDetailLevel = opts.statsDetailLevel; - this->cpuName = opts.cpuName; - this->inlineFusionClusters = opts.inlineFusionClusters; - } - - Option vectorSize{*this, "vector-size", - llvm::cl::desc("Vector size for a 1D reduction."), - llvm::cl::init(8)}; - - Option reductionEnableHeuristic{ - *this, "reduction-enable-heuristic", - llvm::cl::desc("Enable tiling parameters heuristic for reductions."), - llvm::cl::init(false)}; - - Option reduction1DTileSize{ - *this, "reduction-1d-tile-size", - llvm::cl::desc("Tile size for a 1D reduction."), llvm::cl::init(32)}; - - Option reduction1DSplitRatio{ - *this, "reduction-1d-split-ratio", - llvm::cl::desc("Ratio used to split the reduction dimension"), - llvm::cl::init(8)}; - - Option reduction2DParallelDimTileSize{ - *this, "reduction-2d-parallel-dim-tile-size", - llvm::cl::desc("Tile size for the parallel dimension of a 2D reduction."), - llvm::cl::init(4)}; - - Option reduction2DReductionDimTileSize{ - *this, "reduction-2d-reduction-dim-tile-size", - llvm::cl::desc( - "Tile size for the reduction dimension of a 2D reduction."), - llvm::cl::init(4)}; - - ListOption matmulTileSizes{ - *this, "matmul-tile-sizes", - llvm::cl::desc("Tile sizes for `linalg.matmul`. Leave empty to determine " - "sizes automatically."), - llvm::cl::list_init({}), llvm::cl::ZeroOrMore}; - - Option vectorizationSizeThreshold{ - *this, "vectorization-size-threshold", - llvm::cl::desc("Threshold size for vectorization."), llvm::cl::init(128)}; - - Option vectorizationTiledSizeThreshold{ - *this, "vectorization-tiled-size-threshold", - llvm::cl::desc("Threshold size for vectorization after tiling."), - llvm::cl::init(1024)}; - - Option lowerToMmt4d{ - *this, "lower-to-mmt4d", - llvm::cl::desc("Enable the specific code generation (packing) for matmul " - "operations."), - llvm::cl::init(false)}; - - Option cpuName{ - *this, "cpu", - llvm::cl::desc("CPU name, similar to llc's -mcpu flag. e.g. 'znver2', " - "'skylake-avx512'."), - llvm::cl::init("")}; - - Option statsDetailLevel{ - *this, "stats-detail-level", - llvm::cl::desc("Detail level for collecting IR statistics."), - llvm::cl::init(0)}; - - Option fuseDegenerateReshapes{ - *this, "fuse-degenerate-reshapes", - llvm::cl::desc("Fuse through tensor.expand/collapse_shape"), - llvm::cl::init(false)}; - - Option inlineFusionClusters{ - *this, "inline-fusion-clusters", - llvm::cl::desc("Inline fusion clusters at the end of the pipeline."), - llvm::cl::init(true)}; -}; - -// Returns default "optimized" tiling parameters. -GmlStCPUTilingOptions getDefaultCPUPipelineOptions( - StringRef cpuName, int64_t statsDetailLevel = 0); - -// Adds tiling-fusion-vectorization passes for tHLO/Linalg ops mix. 
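// A usage sketch (hedged: the pipeline name, the wrapper function, and the
// textual option syntax are assumptions for illustration; only
// `GmlStCPUTilingOptions` and `addCPUTilingPipeline` come from this header):
//
//   void registerGmlStCpuTilingPipeline() {
//     mlir::PassPipelineRegistration<GmlStCPUTilingOptions>(
//         "gml-st-cpu-tiling-pipeline",
//         "Tile, fuse and vectorize tHLO/Linalg ops for CPU",
//         [](mlir::OpPassManager &pm, const GmlStCPUTilingOptions &opts) {
//           addCPUTilingPipeline(pm, opts);
//         });
//   }
//
// After such a registration the options declared above can be set from a
// textual pipeline description, e.g.
// `gml-st-cpu-tiling-pipeline{vector-size=8 lower-to-mmt4d=true}`.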
-void addCPUTilingPipeline(OpPassManager &pm, - const GmlStCPUTilingOptions &options); - -// Adds tiling-fusion-vectorization passes for tHLO/Linalg ops mix with the -// "optimized" tiling parameters. -void addDefaultCPUTilingPipeline(OpPassManager &pm, StringRef cpuName, - int64_t statsDetailLevel = 0); - -#define GEN_PASS_REGISTRATION -#include "gml_st/transforms/passes.h.inc" - -} // namespace mlir::gml_st - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_PASSES_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td b/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td deleted file mode 100644 index cb41946b07eb80..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/passes.td +++ /dev/null @@ -1,243 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -include "mlir/Pass/PassBase.td" - -def TilingSoftmaxPass : Pass<"gml-tiling-softmax", "mlir::func::FuncOp"> { - let summary = "Match, tile, and fuse softmax implementations"; - let constructor = "::mlir::gml_st::createTilingSoftmaxPass()"; - let options = [ - ListOption<"tileSizes", "tile-sizes", "int64_t", - "Right-aligned tile sizes. Do not tile possible remaining " - "dimensions", "llvm::cl::ZeroOrMore">, - ]; -} - -def TileByOnePass : Pass<"gml-tile-by-one", "mlir::func::FuncOp"> { - let summary = "Tile all tileable ops by size 1"; - let description = [{ - Tile all tileable ops to size 1. This is meant as a fallback for those ops - that were not previously tiled and vectorized. - }]; - let constructor = "::mlir::gml_st::createTileByOnePass()"; -} - -def CollapseShapePass : Pass<"gml-collapse-shape", "mlir::func::FuncOp"> { - let summary = "Collapse dimensions of bcasts, reductions, and cwise ops"; - let description = [{ - Pass to collapse dimensions of bcasts, reductions, and cwise ops. A given - number of trailing dimensions remains untouched while the remaining leading - dimensions will be collapsed where possible. 
- }]; - let constructor = "::mlir::gml_st::createCollapseShapePass()"; - let options = [ - Option<"retainTrailingDims", "retain-trailing-dims", "int64_t", - /*default=*/"0", - "Number of trailing dimensions that will not be collapsed.">, - ]; - let dependentDialects = ["::mlir::tensor::TensorDialect"]; -} - -def ComposeExtractInsertSlicePass : Pass<"gml-compose-extract-insert-slice", - "mlir::func::FuncOp"> { - let summary = "Compose tensor.extract_slice/insert_slice ops."; - let constructor = "::mlir::gml_st::createComposeExtractInsertSlicePass()"; -} - -def VectorizeForCPUPass : Pass<"vectorize-for-cpu", "mlir::func::FuncOp"> { - let summary = "Pass to vectorize gml_st.for loops that are tiled perfectly."; - let constructor = "::mlir::gml_st::createVectorizeForCPUPass()"; - let dependentDialects = [ - "::mlir::vector::VectorDialect", - "::mlir::tensor::TensorDialect" - ]; - let options = [ - Option<"numElementsThreshold", "num-elements-threshold", "int64_t", - /*default=*/"128", - "Number of elements max of the tensor operands in order for the op to be vectorized.">, - ]; -} - -def LowerVectorsPass : Pass<"lower-vectors", "mlir::func::FuncOp"> { - let summary = "Pass to lower vector operations progressively."; - let constructor = "::mlir::gml_st::createLowerVectorsPass()"; - let dependentDialects = [ - "::mlir::LLVM::LLVMDialect", - "::mlir::vector::VectorDialect", - "::mlir::affine::AffineDialect", - ]; - let options = [ - Option<"enableAVX2", "enable-avx2", "bool", /*default=*/"true", - "Enable specialized lowerings for AVX2.">, - Option<"flatten", "flatten", "bool", /*default=*/"false", - "Flatten multiple small n-D vector transfers into a large 1-D transfer.">, - ]; -} - -def ScalarizationPass : Pass<"scalarize", "mlir::func::FuncOp"> { - let summary = "Converts ops on tensors with 1 element to scalar ops."; - let dependentDialects = [ - "arith::ArithDialect", - "scf::SCFDialect", - "tensor::TensorDialect" - ]; - let constructor = "createScalarizationPass()"; - let options = [ - Option<"scalarizeAllThlo", "scalarize-all-thlo", "bool", /*default=*/"true", - "Enable scalarization of thlo.concatenate/gather/scatter.">, - ]; -} - -def PackMatmulPass : Pass<"xla-cpu-pack-matmul", "mlir::func::FuncOp"> { - let summary = "Pack linalg.matmul as linalg.mmt4d"; - let constructor = "createPackMatmulPass()"; -} - -def TransformScatterForCpuPass : - Pass<"xla-cpu-transform-scatter", "mlir::func::FuncOp"> { - let summary = "Transform scatter ops for running on CPU"; - - let constructor = "createTransformScatterForCpuPass()"; -} - -def TransformDotForCpuPass : - Pass<"xla-cpu-transform-dot", "mlir::func::FuncOp"> { - let summary = "Transform dot ops for running on CPU"; - let constructor = "createTransformDotForCpuPass()"; -} - -def TransformPackForCpuPass : - Pass<"xla-cpu-transform-pack", "mlir::func::FuncOp"> { - let summary = "Transform tensor.pack/unpack ops for running on CPU"; - let constructor = "createTransformPackForCpuPass()"; -} - -def TransformMmt4DForCpuPass : - Pass<"xla-cpu-transform-mmt4d", "mlir::func::FuncOp"> { - let summary = "Transform linalg.mmt4d ops for running on CPU"; - let constructor = "createTransformMmt4DForCpuPass()"; -} - -def TransformElementwiseForCpuPass : - Pass <"xla-cpu-transform-elementwise", "mlir::func::FuncOp"> { - let summary = "Transform elementwise ops for running on CPU"; - let description = [{ - Transforms elementwise ops, i.e. map, transpose, broadcast, concat, reverse. 
- }]; - let constructor = "::mlir::gml_st::createTransformElementwiseForCpuPass()"; - - let options = [ - Option<"vectorSize", "vector-size", "int64_t", "8", "Vector size.">, - Option<"fuseDegenerateReshapes", "fuse-degenerate-reshapes", "bool", - /*default=*/"false", - "Fuse through degenerate tensor.expand/collapse_shape">, - ]; -} - -def TransformReduceForCpuPass : - Pass<"xla-cpu-transform-reduce", "mlir::func::FuncOp"> { - let summary = "Transform reduce ops for running on CPU"; - let options = [ - Option<"enableHeuristic", "enable_heuristic", "bool", "false", - "Enable heuristic for tiling sizes. Currently only for 1D.">, - Option<"tileSize1D", "reduction-1d-tile-size", "int64_t", "32", - "Tile size for a 1D reduction.">, - Option<"splitRatio1D", "reduction-1d-split-ratio", "int64_t", "8", - "Ratio used to split the reduction dimension, i.e. tiled reduce op " - "`reduce(tensor)` will be split into a composition of a " - " column reduction `reduce(tensor)` " - "and a row 1D reductionreduce(tensor)`." >, - Option<"parallelDimTileSize2D", - "reduction-2d-parallel-dim-tile-size", "int64_t", "4", - "Tile size for the parallel dimension of a 2D reduction.">, - Option<"reductionDimTileSize2D", - "reduction-2d-reduction-dim-tile-size", "int64_t", "4", - "Tile size for the reduction dimension of a 2D reduction.">, - ]; - let constructor = "::mlir::gml_st::createTransformReduceForCpuPass()"; -} - -def AddDebugInfoPass : - Pass<"add-debug-info", "mlir::ModuleOp"> { - let summary = "Add debug info for the whole module"; - let constructor = "::mlir::gml_st::createAddDebugInfoPass()"; - let dependentDialects = ["::mlir::LLVM::LLVMDialect"]; -} - -def CollectStatsPass : - Pass<"collect-stats", "mlir::func::FuncOp"> { - let summary = "Print stats about tileable ops"; - let constructor = "::mlir::gml_st::createCollectStatsPass()"; -} - -def RemoveLabelPass : Pass<"remove-label", "mlir::func::FuncOp"> { - let summary = "Remove transformed labels from tiled ops"; - let constructor = "::mlir::gml_st::createRemoveLabelPass()"; -} - -def FusionPlanningForCpuPass : - Pass<"gml-st-cpu-fusion-planning", "mlir::func::FuncOp"> { - let summary = "Create fusion clusters."; - let constructor = "createFusionPlanningForCpuPass()"; - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::gml_st::GmlStDialect", - "::mlir::linalg::LinalgDialect", - "::mlir::tensor::TensorDialect" - ]; - - let options = [ - Option<"vectorSize", "vector-size", "int64_t", "8", - "Tile size for the innermost dimension of `linalg.map`">, - ]; -} - -def RewriteFromElementsOpPass - : Pass<"gml-st-rewrite-from-elements-ops", "mlir::func::FuncOp"> { - let summary = "Pass to rewrite tensor.from_elements into tensor.insert."; - let constructor = "createRewriteFromElementsOpPass()"; -} - -def RewriteForallOpPass - : Pass<"gml-st-rewrite-forall-ops", "mlir::func::FuncOp"> { - let summary = "Pass to rewrite scf.forall to scf.for."; - let constructor = "createRewriteForallOpPass()"; -} - -def FusionOutliningPass : Pass<"gml-fusion-outlining", "mlir::ModuleOp"> { - let summary = "Pass to outline fusion regions into functions."; - let constructor = "createFusionOutliningPass()"; -} - -def InlineFusionClustersPass : - Pass<"gml-st-inline-fusion-clusters", "mlir::func::FuncOp"> { - let summary = "Replaces all gml_st.fusion op with ops from the region."; - let constructor = "createInlineFusionClustersPass()"; - let dependentDialects = [ - "::mlir::gml_st::GmlStDialect", - ]; -} - -def OptimizeLinalgOpsPass - : 
Pass<"gml-st-optimize-linalg-ops-pass", "mlir::func::FuncOp"> { - let summary = "Canonicalization patterns for linalg ops."; - let constructor = "createOptimizeLinalgOpsPass()"; - - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::linalg::LinalgDialect", - "::mlir::tensor::TensorDialect" - ]; -} diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc deleted file mode 100644 index 587f43561c1c32..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "gml_st/transforms/peeling/peeling.h" - -#include "gml_st/IR/gml_st_ops.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/IRMapping.h" - -namespace mlir { -namespace gml_st { -namespace { - -bool isATensor(Type t) { return t.isa(); } - -/// Return true if the given op has only tensor-typed results or operands. -bool hasTensorSemantics(Operation *op) { - return llvm::all_of(op->getResultTypes(), isATensor) || - llvm::all_of(op->getOperandTypes(), isATensor); -} - -LogicalResult peelLoop(RewriterBase &b, scf::ForallOp loopOp, int64_t idx, - scf::ForallOp &result, Value &splitBound) { - if (!hasTensorSemantics(loopOp)) return failure(); - - Location loc = loopOp.getLoc(); - Value lb = - getValueOrCreateConstantIndexOp(b, loc, loopOp.getMixedLowerBound()[idx]); - Value ub = - getValueOrCreateConstantIndexOp(b, loc, loopOp.getMixedUpperBound()[idx]); - Value step = - getValueOrCreateConstantIndexOp(b, loc, loopOp.getMixedStep()[idx]); - auto ubInt = getConstantIntValue(ub); - - AffineExpr exprLb, exprUb, exprStep; - bindSymbols(b.getContext(), exprLb, exprUb, exprStep); - // New upper bound: %ub - (%ub - %lb) mod %step - auto modMap = AffineMap::get(0, 3, exprUb - ((exprUb - exprLb) % exprStep)); - SmallVector operands{lb, ub, step}; - affine::canonicalizeMapAndOperands(&modMap, &operands); - modMap = simplifyAffineMap(modMap); - RewriterBase::InsertionGuard guard(b); - b.setInsertionPoint(loopOp); - splitBound = b.createOrFold(loc, modMap, operands); - - // No specialization necessary if step already divides upper bound evenly. - if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound))) - return failure(); - - // Create remainder loop. 
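  // For intuition, a worked instance of the split computed above (numbers are
  // illustrative): with lb = 0, ub = 10, step = 4 we get
  // splitBound = 10 - ((10 - 0) mod 4) = 8, so the main loop keeps the evenly
  // divisible range [0, 8) with full steps, while the remainder loop cloned
  // below covers the partial range [8, 10). If the step already divided
  // ub - lb, splitBound would equal ub and we bailed out just above.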
- IRMapping bvm; - for (const auto &[res, termDst] : - llvm::zip(loopOp.getResults(), loopOp.getOutputs())) { - bvm.map(termDst, res); - } - b.setInsertionPointAfter(loopOp); - auto remainderLoop = - cast(b.clone(*loopOp.getOperation(), bvm)); - - Operation *remainderLoopOp = remainderLoop.getOperation(); - - for (auto [oldRes, newRes] : - llvm::zip(loopOp.getResults(), remainderLoop.getResults())) { - SmallPtrSet exceptions({remainderLoopOp}); - for (OpOperand &use : oldRes.getUses()) { - Operation *user = use.getOwner(); - if (user->getParentOp() == remainderLoopOp) exceptions.insert(user); - } - oldRes.replaceAllUsesExcept(newRes, exceptions); - } - - // Set new loop bounds. - SmallVector ubs = loopOp.getMixedUpperBound(); - ubs[idx] = splitBound; - SmallVector dynamicUbs; - SmallVector staticUbs; - dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs); - b.updateRootInPlace(loopOp, [&]() { - loopOp.getDynamicUpperBoundMutable().assign(dynamicUbs); - loopOp.setStaticUpperBound(staticUbs); - }); - - SmallVector lbs = remainderLoop.getMixedLowerBound(); - lbs[idx] = splitBound; - SmallVector dynamicLbs; - SmallVector staticLbs; - dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs); - b.updateRootInPlace(remainderLoop, [&]() { - remainderLoop.getDynamicLowerBoundMutable().assign(dynamicLbs); - remainderLoop.setStaticLowerBound(staticLbs); - }); - - result = remainderLoop; - return success(); -} - -template -void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, Operation *mainLoop, - Operation *remainderLoop, Value mainIv, - Value remainderIv, Value ub, Value step) { - mainLoop->walk([&](OpTy affineOp) { - (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, ub, step, - /*insideLoop=*/true); - }); - remainderLoop->walk([&](OpTy affineOp) { - (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, remainderIv, ub, step, - /*insideLoop=*/false); - }); -} - -} // namespace - -GmlStPeelingResult peelAllLoops(scf::ForallOp loop, - mlir::PatternRewriter &rewriter) { - GmlStPeelingResult peelingResult; - - bool hasMainLoop = true; - for (unsigned peeledIdx = 0; peeledIdx < loop.getRank(); ++peeledIdx) { - int64_t numLoops = loop.getRank(); - if (peeledIdx < 0 || numLoops <= peeledIdx) continue; - - OpFoldResult ubOfr = loop.getMixedUpperBound()[peeledIdx]; - OpFoldResult stepOfr = loop.getMixedStep()[peeledIdx]; - auto ubInt = getConstantIntValue(ubOfr); - auto stepInt = getConstantIntValue(stepOfr); - - // If the loop is smaller than the step, then append loop as tail. Needs to - // be done only once. - if (ubInt && stepInt && ubInt < stepInt) { - if (hasMainLoop) { - peelingResult.tailLoops.push_back(loop); - hasMainLoop = false; - } - continue; - } - - Location loc = loop.getLoc(); - Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, ubOfr); - Value step = getValueOrCreateConstantIndexOp(rewriter, loc, stepOfr); - scf::ForallOp remainderLoop; - Value splitBound; - if (failed(peelLoop(rewriter, loop, peeledIdx, remainderLoop, splitBound))) - continue; - - // Rewrite affine.min and affine.max ops. - Value mainIv = loop.getInductionVars()[peeledIdx], - remainderIv = remainderLoop.getInductionVars()[peeledIdx]; - - rewriteAffineOpAfterPeeling( - rewriter, loop, remainderLoop, mainIv, remainderIv, ub, step); - rewriteAffineOpAfterPeeling( - rewriter, loop, remainderLoop, mainIv, remainderIv, ub, step); - - // Mark the new loop if one was created. - peelingResult.tailLoops.push_back(remainderLoop); - } - - // Update main loop if applicable. 
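  // For intuition (sizes are illustrative): peeling a forall over (10, 7) with
  // steps (4, 4) leaves a main loop over the evenly divisible region
  // [0, 8) x [0, 4) and records two tail loops, one for the leftover rows
  // [8, 10) over the full column range and one for the leftover columns
  // [4, 7) of the peeled rows [0, 8). If every extent is already smaller than
  // its step, nothing is peeled: the original loop was recorded as a tail loop
  // above and `mainLoop` is left null.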
- if (hasMainLoop) peelingResult.mainLoop = loop; - - return peelingResult; -} - -SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp loop) { - // Peeling fails, if the step divides the upper bound. In that case, - // we still want to return {loop, nullptr}. - scf::ForOp tailLoop; - return succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, loop, tailLoop)) - ? SCFForPeelingResult{loop, tailLoop} - : SCFForPeelingResult{loop, nullptr}; -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h deleted file mode 100644 index da3412b8b9901f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_PEELING_PEELING_H -#define MLIR_HLO_GML_ST_TRANSFORMS_PEELING_PEELING_H - -#include "gml_st/IR/gml_st_ops.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace gml_st { - -struct GmlStPeelingResult { - scf::ForallOp mainLoop = nullptr; - SmallVector tailLoops = {}; -}; - -/// Rewrite a scf::ForallOp with bounds/step that potentially do not divide -/// evenly into a scf::ForallOp where the step divides the iteration space -/// evenly, followed by another scf::ForallOp for the last (partial) -/// iteration (if any). This transformation is called "loop peeling". -/// -/// These functions peel all loops in the loop nest by calling -/// peelAndCanonicalizeGmlStLoop. Additionally, they mark all loops (main and -/// remainder loops) as peeled, so the same loop is not rewritten a second time. -GmlStPeelingResult peelAllLoops(scf::ForallOp loop, - mlir::PatternRewriter &rewriter); - -struct SCFForPeelingResult { - scf::ForOp mainLoop = nullptr; - scf::ForOp tailLoop = nullptr; -}; -SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_PEELING_PEELING_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_from_elements_op/rewrite_from_elements_op.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_from_elements_op/rewrite_from_elements_op.cc deleted file mode 100644 index 56ee84469abefe..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_from_elements_op/rewrite_from_elements_op.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "gml_st/transforms/passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_REWRITEFROMELEMENTSOPPASS -#include "gml_st/transforms/passes.h.inc" - -// Rewrite `tensor.from_elements(x)` into `tensor.insert(x, tensor.empty)`. -// In combination with `empty-tensor-elimination` it removes the alloc that can -// result from `tensor.from_elements`. -struct RewriteFromElementsOpInDestinationPassingStyle - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::FromElementsOp op, - PatternRewriter &rewriter) const override { - return linalg::rewriteInDestinationPassingStyle(rewriter, op); - } -}; - -class RewriteFromElementsOpPass - : public impl::RewriteFromElementsOpPassBase { - void runOnOperation() override { - auto func = getOperation(); - auto *context = &getContext(); - - RewritePatternSet patterns(context); - patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> createRewriteFromElementsOpPass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_scf_forall/rewrite_scf_forall.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_scf_forall/rewrite_scf_forall.cc deleted file mode 100644 index 0fe18bd3af6363..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/rewrite_scf_forall/rewrite_scf_forall.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -// Rewrites `scf.forall` to an `scf.for` loop nest. 
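// A hedged before/after sketch of the rewrite (the IR is abbreviated and
// illustrative, not copied from a test): a parallel loop writing its result
// through the in_parallel terminator,
//
//   scf.forall (%i) in (8) shared_outs(%o = %init) -> (tensor<8xf32>) {
//     %v = "some.compute"(%i) : (index) -> tensor<1xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %v into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<8xf32>
//     }
//   }
//
// becomes a sequential loop that threads the output through iter_args:
//
//   %r = scf.for %i = %c0 to %c8 step %c1 iter_args(%o = %init) -> (tensor<8xf32>) {
//     %v = "some.compute"(%i) : (index) -> tensor<1xf32>
//     %u = tensor.insert_slice %v into %o[%i] [1] [1]
//         : tensor<1xf32> into tensor<8xf32>
//     scf.yield %u : tensor<8xf32>
//   }
//
// Loops carrying a thread `mapping` attribute are deliberately left alone, as
// the early bail-out below shows.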
-LogicalResult rewriteScfForallToScfFor(scf::ForallOp forallOp, - PatternRewriter &rewriter) { - if (forallOp.getRank() == 0) return failure(); - // Do not convert to scf.for if scf.forall is mapped to threads. - if (forallOp.getMapping().has_value()) return failure(); - - Location loc = forallOp.getLoc(); - scf::LoopNest loopNest = scf::buildLoopNest( - rewriter, loc, forallOp.getLowerBound(rewriter), - forallOp.getUpperBound(rewriter), forallOp.getStep(rewriter), - forallOp.getOutputs(), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - IRMapping map; - map.map(forallOp.getInductionVars(), ivs); - map.map(forallOp.getOutputBlockArguments(), iterArgs); - - for (auto &op : forallOp.getBody()->without_terminator()) - nestedBuilder.clone(op, map); - - auto inParallelOp = forallOp.getTerminator(); - scf::ValueVector results; - for (auto &op : inParallelOp.getYieldingOps()) { - auto mappedOperands = - llvm::to_vector(llvm::map_range(op.getOperands(), [&](Value val) { - return map.lookupOrDefault(val); - })); - results.push_back(rewriter.create( - nestedLoc, mappedOperands, op.getAttrs())); - } - rewriter.eraseOp(forallOp.getTerminator()); - return results; - }); - - // Copy attributes from `scf.forall` to the output - SmallVector elidedAttrs{forallOp.getOperandSegmentSizesAttrName(), - forallOp.getStaticLowerBoundAttrName(), - forallOp.getStaticUpperBoundAttrName(), - forallOp.getStaticStepAttrName()}; - SmallVector attrs = llvm::to_vector(llvm::make_filter_range( - forallOp->getAttrs(), [&](const NamedAttribute &attr) { - return !llvm::is_contained(elidedAttrs, attr.getName()); - })); - - for (scf::ForOp loop : loopNest.loops) { - rewriter.updateRootInPlace(loop, [&]() { - for (const auto &attr : attrs) - loop->setAttr(attr.getName(), attr.getValue()); - }); - } - rewriter.replaceOp(forallOp, loopNest.results); - return success(); -} - -#define GEN_PASS_DEF_REWRITEFORALLOPPASS -#include "gml_st/transforms/passes.h.inc" - -class RewriteForallOpPass - : public impl::RewriteForallOpPassBase { - void runOnOperation() override { - auto func = getOperation(); - auto *context = &getContext(); - - RewritePatternSet patterns(context); - patterns.add(rewriteScfForallToScfFor); - scf::ForOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> createRewriteForallOpPass() { - return std::make_unique(); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc deleted file mode 100644 index df12ce7714b3ba..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc +++ /dev/null @@ -1,670 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "gml_st/transforms/scalarization/scalarization.h" - -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_SCALARIZATIONPASS -#include "gml_st/transforms/passes.h.inc" - -using linalg::LinalgOp; -using tensor::ExtractOp; -using tensor::ExtractSliceOp; -using tensor::FromElementsOp; -using tensor::InsertOp; - -// Fold `tensor.insert_slice(tensor.from_elements(x), dst)` into -// `tensor.insert(x, dst)` for single-element tensors. -struct FoldTensorFromElementsIntoInsertSlice - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, - PatternRewriter &rewriter) const override { - auto fromElementsOp = - insertSliceOp.getSource().getDefiningOp(); - if (!fromElementsOp || !hasSingleElement(fromElementsOp.getType())) { - return failure(); - } - SmallVector indices = getValueOrCreateConstantIndexOp( - rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets()); - rewriter.replaceOpWithNewOp( - insertSliceOp, fromElementsOp.getElements().front(), - insertSliceOp.getDest(), indices); - return success(); - } -}; - -LogicalResult inlinePayload(PatternRewriter &rewriter, Location loc, - LinalgOp linalgOp, ValueRange argValues) { - // Clone everything but terminator. - Block *body = linalgOp.getBlock(); - IRMapping map; - map.map(body->getArguments(), argValues); - for (auto &op : body->without_terminator()) { - if (auto indexOp = dyn_cast(&op)) { - Value zero = rewriter.create(loc, 0); - map.map(indexOp.getResult(), zero); - continue; - } - rewriter.clone(op, map); - } - - // Wrap every scalar result into a tensor using `tensor.from_elements`. - SmallVector newResults; - for (auto [resultType, yieldOperand] : llvm::zip( - linalgOp->getResultTypes(), body->getTerminator()->getOperands())) { - auto scalarValue = map.lookupOrDefault(yieldOperand); - newResults.push_back( - rewriter.create(loc, resultType, scalarValue)); - } - rewriter.replaceOp(linalgOp, newResults); - return success(); -} - -// `scalarizeLinalgOp` has to be wrapped in OpInterfaceRewritePattern, because -// `patterns.add` does not support adding interface rewriter patterns yet. -struct ScalarizeLinalgOp : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(LinalgOp linalgOp, - PatternRewriter &rewriter) const override { - return scalarizeLinalgOp(linalgOp, rewriter); - } -}; - -// Get reassociation indices to collapse first dimension. -SmallVector getCollapseFirstDimReassociation( - unsigned rank) { - SmallVector result{{0, 1}}; - for (unsigned i = 2; i < rank; ++i) result.push_back({i}); - return result; -} - -// Returns `startIndices`[0, :] for `startIndices` of shape 1xn. Returns None if -// startIndices has a different shape. 
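// For example (values are illustrative): for start indices of type
// tensor<1x2xindex> holding [[3, 7]] this returns two index Values produced by
// tensor.extract %startIndices[0, 0] and tensor.extract %startIndices[0, 1],
// i.e. {3, 7}; anything that is not shaped 1xn (say 2x2, or rank 3) yields
// std::nullopt and the caller gives up on scalarizing.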
-std::optional> extractStartIndices( - ImplicitLocOpBuilder &b, TypedValue startIndices) { - if (startIndices.getType().getRank() != 2 || - startIndices.getType().getDimSize(0) != 1) { - return std::nullopt; - } - - int64_t indexVectorSize = startIndices.getType().getDimSize(1); - SmallVector result; - result.reserve(indexVectorSize); - Value zero = b.create(0); - for (int64_t i = 0; i < indexVectorSize; ++i) { - result.push_back(b.create( - startIndices, ValueRange{zero, b.create(i)})); - } - return result; -} - -// Return i1 value after checking that 0 <= indices < dims(tensor). -Value isValidIndex(OpBuilder &b, Location loc, ArrayRef indices, - ArrayRef tensorDims, Value &zero) { - auto i1Type = b.getI1Type(); - Value isValid = b.create( - loc, i1Type, IntegerAttr::get(i1Type, APInt(1, 1))); - - for (auto [dim, index] : llvm::zip(tensorDims, indices)) { - Value geZero = - b.create(loc, arith::CmpIPredicate::sge, index, zero); - Value ltDim = - b.create(loc, arith::CmpIPredicate::slt, index, dim); - Value dimInBounds = b.create(loc, geZero, ltDim); - isValid = b.create(loc, isValid, dimInBounds); - } - return isValid; -} - -Value isIndexInBounds(ImplicitLocOpBuilder &b, Location loc, - ArrayRef updatesDimValues, - ArrayRef scatterIndices, - ArrayRef initDimValues, Value &zero, Value &one) { - SmallVector limitIndex{updatesDimValues.drop_front()}; - for (const auto &en : llvm::enumerate(scatterIndices)) { - limitIndex[en.index()] = - b.create(loc, limitIndex[en.index()], en.value()); - } - for (auto &value : limitIndex) { - value = b.create(loc, value, one); - } - - Value inBounds = isValidIndex(b, loc, limitIndex, initDimValues, zero); - return b.create( - loc, inBounds, isValidIndex(b, loc, scatterIndices, initDimValues, zero)); -} - -Value tensorHasElement(OpBuilder &b, Location loc, Value input, - int64_t concatDim) { - Value zero = b.create(loc, 0); - Value concatDimSize = b.create(loc, input, concatDim); - return b.create(loc, arith::CmpIPredicate::ne, concatDimSize, - zero); -} - -Value extractElementFromInputs( - OpBuilder &b, Location loc, ValueRange inputs, Type resultType, - int64_t concatDim, - llvm::function_ref - materializeAndInsert) { - if (inputs.size() == 1) { - return materializeAndInsert(b, loc, inputs.front()); - } - - return b - .create( - loc, tensorHasElement(b, loc, inputs.front(), concatDim), - [&](OpBuilder &thenBuilder, Location thenLoc) { - thenBuilder.create( - thenLoc, - materializeAndInsert(thenBuilder, thenLoc, inputs.front())); - }, - [&](OpBuilder &elseBuilder, Location elseLoc) { - elseBuilder.create( - elseLoc, extractElementFromInputs( - elseBuilder, elseLoc, inputs.drop_front(), - resultType, concatDim, materializeAndInsert)); - }) - .getResult(0); -} - -LogicalResult scalarizeOp(Operation *op, PatternRewriter &rewriter, - TypedValue &input, - TypedValue &output) { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); - - auto outputType = output.getType().dyn_cast(); - if (!outputType) { - return rewriter.notifyMatchFailure( - op, "failed to cast output to RankedTensorType"); - } - if (!hasSingleElement(outputType)) { - return rewriter.notifyMatchFailure( - op, "has output with number of elements not equal to 1"); - } - - auto inputType = input.getType().dyn_cast(); - if (!inputType) { - return rewriter.notifyMatchFailure( - op, "failed to cast input to RankedTensorType"); - } - - Value zero = b.create(0); - llvm::SmallVector indicesInput(inputType.getRank(), zero); - llvm::SmallVector indicesOutput(outputType.getRank(), zero); - - Value extractedValue 
= b.create(input, indicesInput); - Value result = b.create(outputType, extractedValue); - - rewriter.replaceOp(op, result); - return success(); -} - -LogicalResult hoistTensorExtractFromForOp(scf::ForOp forOp, - PatternRewriter &rewriter) { - if (forOp.getInitArgs().size() != 1) return failure(); - OpOperand &iterOperand = forOp.getInitArgsMutable()[0]; - auto iterArgTensorTy = - dyn_cast(iterOperand.get().getType()); - if (!iterArgTensorTy || !hasSingleElement(iterArgTensorTy)) return failure(); - - Value bbArg = forOp.getTiedLoopRegionIterArg(&iterOperand); - - if (!bbArg.hasOneUse()) return failure(); - - Operation *user = *bbArg.getUsers().begin(); - auto extractOp = dyn_cast(user); - if (!extractOp) return failure(); - - Operation *terminator = forOp.getBody()->getTerminator(); - auto fromTensorOp = - terminator->getOperand(0).getDefiningOp(); - if (!fromTensorOp) return failure(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(forOp); - Location loc = forOp.getLoc(); - Value extractedElement = rewriter.create(loc, iterOperand.get(), - extractOp.getIndices()); - auto newForOp = rewriter.create( - loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), - ValueRange{extractedElement}); - newForOp->setAttrs(forOp->getAttrs()); - Block *newLoopBody = newForOp.getBody(); - - // Move old body into new for loop. - rewriter.setInsertionPointToStart(newLoopBody); - SmallVector blockArgs{ - newForOp.getInductionVar(), - rewriter.create(loc, iterArgTensorTy, - newForOp.getRegionIterArg(0))}; - rewriter.mergeBlocks(forOp.getBody(), newLoopBody, blockArgs); - - // Replace terminator that yields a tensor with the one that yields the - // element. - Operation *newTerminator = newForOp.getBody()->getTerminator(); - rewriter.setInsertionPointAfter(newTerminator); - Value elemOfYieldedTensor = rewriter.create( - loc, terminator->getOperand(0), extractOp.getIndices()); - rewriter.replaceOpWithNewOp(newTerminator, elemOfYieldedTensor); - - // Replace the old loop with the new loop result wrapped in a tensor. - rewriter.setInsertionPointAfter(newForOp); - rewriter.replaceOpWithNewOp( - forOp, forOp.getResultTypes().front(), newForOp.getResult(0)); - - return success(); -} - -LogicalResult hoistTensorExtractFromIfOp(scf::IfOp ifOp, - PatternRewriter &rewriter) { - // Analyse result types and determine what we can scalarize. - int64_t numResults = ifOp.getNumResults(); - SmallVector isScalarizableResult(numResults, false); - SmallVector unscalarizedResultType = - llvm::to_vector(ifOp.getResultTypes()); - SmallVector scalarizedResultType = - llvm::to_vector(ifOp.getResultTypes()); - bool isAnyResultScalarizable = false; - for (int64_t i = 0; i < numResults; ++i) { - auto rankedTy = scalarizedResultType[i].dyn_cast(); - if (!rankedTy || !hasSingleElement(rankedTy)) continue; - isScalarizableResult[i] = true; - scalarizedResultType[i] = rankedTy.getElementType(); - isAnyResultScalarizable = true; - } - - if (!isAnyResultScalarizable) { - return rewriter.notifyMatchFailure(ifOp, "cannot scalarize any result"); - } - - // Create new if ifOp. - Location loc = ifOp.getLoc(); - Value zero = rewriter.create(loc, 0); - auto scalarizedOp = rewriter.create(loc, scalarizedResultType, - ifOp.getCondition()); - scalarizedOp.getThenRegion().takeBody(ifOp.getThenRegion()); - scalarizedOp.getElseRegion().takeBody(ifOp.getElseRegion()); - for (int64_t i = 0; i < numResults; ++i) { - if (!isScalarizableResult[i]) continue; - - // Insert `extract` ops to yield value as a scalar. 
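    // E.g. (illustrative): a branch that used to end in
    //   scf.yield %t : tensor<1xi32>
    // now extracts the single element first,
    //   %s = tensor.extract %t[%c0] : tensor<1xi32>
    //   scf.yield %s : i32
    // and the original tensor type is re-established after the scf.if with
    // tensor.from_elements further below.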
- llvm::SmallVector zeroIndices( - unscalarizedResultType[i].cast().getRank(), zero); - rewriter.setInsertionPoint(scalarizedOp.thenYield()); - Value thenScalar = rewriter.createOrFold( - loc, scalarizedOp.thenYield().getOperand(i), zeroIndices); - scalarizedOp.thenYield().setOperand(i, thenScalar); - rewriter.setInsertionPoint(scalarizedOp.elseYield()); - Value elseScalar = rewriter.createOrFold( - loc, scalarizedOp.elseYield().getOperand(i), zeroIndices); - scalarizedOp.elseYield().setOperand(i, elseScalar); - } - - // Insert `from_elements` ifOp to be type compatible. - rewriter.setInsertionPointAfter(scalarizedOp); - SmallVector results(scalarizedOp.getResults()); - for (int64_t i = 0; i < numResults; ++i) { - if (!isScalarizableResult[i]) continue; - - // Wrap scalar. - results[i] = rewriter.create( - loc, unscalarizedResultType[i], results[i]); - } - - rewriter.replaceOp(ifOp, results); - return success(); -} - -struct ScalarizationPass - : public impl::ScalarizationPassBase { - using Base::Base; - - void runOnOperation() override { - auto func = getOperation(); - auto *ctx = &getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(ctx); - patterns.add(hoistTensorExtractFromForOp); - patterns.add(hoistTensorExtractFromIfOp); - patterns.add(scalarizeDynamicBroadcastInDimOp); - patterns.add(scalarizeReverseOp); - - if (scalarizeAllThlo) { - patterns.add(scalarizeConcatenateOp); - patterns.add(scalarizeGatherOp); - patterns.add(scalarizeScatterOp); - } - - FromElementsOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - signalPassFailure(); - } -}; - -} // namespace - -LogicalResult scalarizeConcatenateOp(thlo::ConcatenateOp concatenateOp, - PatternRewriter &rewriter) { - Location loc = concatenateOp.getLoc(); - int64_t concatDim = concatenateOp.getDimension().getSExtValue(); - - auto initTensor = concatenateOp.getInit(); - auto initType = initTensor.getType(); - int64_t rank = initTensor.getType().getRank(); - - // Only scalarize when it's statically known that output concatenation dim - // size is one. 
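  // E.g. (shapes are illustrative): concatenating inputs of type tensor<?x4xf32>
  // along dim 0 into an init of type tensor<1x4xf32> means exactly one row ends
  // up in the result, so the op can be lowered to the chain of scf.if ops built
  // below that picks the first input with a non-zero extent along dim 0 and
  // copies its single slice into the init. A dynamic or larger-than-one output
  // extent keeps the op untouched.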
- if (initType.getShape()[concatDim] != 1) { - return failure(); - } - - IntegerAttr oneAttr = rewriter.getIndexAttr(1); - SmallVector offsets(rank, rewriter.getIndexAttr(0)); - SmallVector strides(rank, oneAttr); - - SmallVector sizes; - for (int i = 0; i < rank; ++i) { - if (i == concatDim) { - sizes.push_back(oneAttr); - } else { - sizes.emplace_back(rewriter.create(loc, initTensor, i)); - } - } - - auto materializeAndInsert = [&](OpBuilder &b, Location l, Value input) { - Value slice = b.create(l, input, offsets, sizes, strides); - return b.create(l, slice, initTensor, offsets, sizes, - strides); - }; - - Value res = - extractElementFromInputs(rewriter, loc, concatenateOp.getInputs(), - initType, concatDim, materializeAndInsert); - - rewriter.replaceOp(concatenateOp, res); - - return success(); -} - -LogicalResult scalarizeDynamicBroadcastInDimOp( - thlo::DynamicBroadcastInDimOp broadcastOp, PatternRewriter &rewriter) { - auto input = broadcastOp.getOperand(); - auto output = broadcastOp.getInit(); - return scalarizeOp(broadcastOp, rewriter, input, output); -} - -LogicalResult scalarizeGatherOp(thlo::GatherOp gatherOp, - PatternRewriter &rewriter) { - Location loc = gatherOp.getLoc(); - ImplicitLocOpBuilder b(loc, rewriter); - auto startIndices = extractStartIndices(b, gatherOp.getStartIndices()); - if (!startIndices) return failure(); - - TypedValue init = gatherOp.getInit(); - ShapedType initTy = init.getType(); - int64_t initRank = initTy.getRank(); - SmallVector initDimSizes = tensor::getMixedSizes(b, loc, init); - SmallVector initDimSizeValues = - getValueOrCreateConstantIndexOp(b, loc, initDimSizes); - - IntegerAttr oneAttr = b.getI64IntegerAttr(1); - - TypedValue operand = gatherOp.getOperand(); - auto operandSizes = getValueOrCreateConstantIndexOp( - b, loc, tensor::getMixedSizes(b, loc, operand)); - Value zero = b.create(0); - Value one = b.create(1); - - SmallVector sliceSizes{initDimSizeValues.begin() + 1, - initDimSizeValues.end()}; - while (sliceSizes.size() < startIndices->size()) { - sliceSizes.push_back(one); - } - - // Clamp the indices. - for (auto &&[startIndex, max, sliceSize] : - llvm::zip(*startIndices, operandSizes, sliceSizes)) { - auto maxMinusSize = b.createOrFold(loc, max, sliceSize); - startIndex = b.create(loc, startIndex, maxMinusSize); - startIndex = b.create(loc, startIndex, zero); - } - - SmallVector lbs(initRank, zero); - SmallVector steps(initRank, one); - - scf::LoopNest loopNest = scf::buildLoopNest( - rewriter, loc, lbs, initDimSizeValues, steps, ValueRange{init}, - [&](OpBuilder &nestedBuilder, Location bodyLoc, ValueRange ivs, - ValueRange loopInits) { - // Compute the index in the operand. - SmallVector readIndices(operand.getType().getRank(), zero); - llvm::copy(ivs.drop_front(1), readIndices.begin()); - for (auto &&[readIndex, startIndex] : - llvm::zip(readIndices, *startIndices)) { - readIndex = nestedBuilder.create(bodyLoc, readIndex, - startIndex); - } - - // Materialize the value and yield it. - SmallVector ones(initRank, oneAttr); - Value val = - nestedBuilder.create(bodyLoc, operand, readIndices); - Value updatedInit = nestedBuilder.create( - bodyLoc, val, loopInits.front(), ivs); - - return scf::ValueVector({updatedInit}); - }); - - rewriter.replaceOp(gatherOp, loopNest.results); - return success(); -} - -LogicalResult scalarizeLinalgOp(LinalgOp linalgOp, PatternRewriter &rewriter) { - // Fail if not every argument is a scalar or a single-element tensor. 
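  // Sketch of the end-to-end rewrite (types and ops below are illustrative):
  // for
  //   %r = linalg.map { arith.addf }
  //          ins(%a, %b : tensor<1x1xf32>, tensor<1x1xf32>)
  //          outs(%init : tensor<1x1xf32>)
  // the operands are loaded with tensor.extract %a[%c0, %c0] and
  // tensor.extract %b[%c0, %c0], the payload (`arith.addf`) runs directly on
  // the scalars, and the result is rebuilt with tensor.from_elements, leaving
  // no loop or tensor computation behind.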
- if (!hasSingleElementOperandsAndResults(linalgOp)) return failure(); - - // Do not scalarize linalg::FillOp that is only used by DPS ops as init - // operands. - if (isa(linalgOp)) { - if (llvm::all_of(linalgOp->getUses(), [&](OpOperand &use) { - Operation *user = use.getOwner(); - if (auto dpsOp = dyn_cast(user)) { - SmallVector opOperands = llvm::to_vector( - llvm::map_range(dpsOp.getDpsInitsMutable(), - [](OpOperand &o) { return &o; })); - return llvm::is_contained(opOperands, &use); - } - return false; - })) - return failure(); - } - - // Load the data corresponding to the block arguments that - // represent input operands. - SmallVector indexedValues; - indexedValues.reserve(linalgOp->getNumOperands()); - Location loc = linalgOp->getLoc(); - auto zero = rewriter.create(loc, 0); - for (OpOperand &operand : linalgOp->getOpOperands()) { - if (!linalgOp.payloadUsesValueFromOperand(&operand)) { - indexedValues.push_back(nullptr); - continue; - } - if (linalgOp.isScalar(&operand)) { - indexedValues.push_back(operand.get()); - continue; - } - Value operandValue = operand.get(); - Type operandType = operandValue.getType(); - SmallVector indices(operandType.cast().getRank(), - zero); - Value load = rewriter.create(loc, operandValue, indices); - indexedValues.push_back(load); - } - - // Inline the op payload and rewrite the operation. - return inlinePayload(rewriter, loc, linalgOp, indexedValues); -} - -LogicalResult scalarizeReverseOp(thlo::ReverseOp reverseOp, - PatternRewriter &rewriter) { - auto input = reverseOp.getInput(); - auto output = reverseOp.getInit(); - return scalarizeOp(reverseOp, rewriter, input, output); -} - -FailureOr rewriteScatterOpAsIfOp(thlo::ScatterOp scatterOp, - PatternRewriter &rewriter) { - Location loc = scatterOp.getLoc(); - ImplicitLocOpBuilder b(loc, rewriter); - b.setInsertionPoint(scatterOp); - - auto scatterIndices = extractStartIndices(b, scatterOp.getIndices()); - if (!scatterIndices) return failure(); - Value updates = scatterOp.getUpdates(); - auto updatesType = updates.getType().dyn_cast(); - if (!updatesType) return failure(); - unsigned updatesRank = updatesType.getRank(); - - SmallVector updatesDimSizes = - tensor::getMixedSizes(b, loc, updates); - SmallVector updatesDimValues = - getValueOrCreateConstantIndexOp(b, loc, updatesDimSizes); - - Value init = scatterOp.getInit(); - auto initType = init.getType().dyn_cast(); - if (!initType) return failure(); - SmallVector initDimValues = getValueOrCreateConstantIndexOp( - b, loc, tensor::getMixedSizes(b, loc, init)); - - Value zero = b.create(0); - Value one = b.create(1); - - Value indexIsInBounds = - isIndexInBounds(b, loc, updatesDimValues, scatterIndices.value(), - initDimValues, zero, one); - auto ifOp = b.create( - loc, indexIsInBounds, - [&](OpBuilder &thenBuilder, Location thenLoc) { - SmallVector collapsedOffsets; - for (size_t i = 0; i < updatesRank - 1; ++i) { - collapsedOffsets.push_back( - i < (scatterIndices->size()) ? (*scatterIndices)[i] : zero); - } - SmallVector collapsedSizes; - for (size_t i = 1; i < updatesRank; ++i) { - collapsedSizes.push_back(updatesDimSizes[i]); - } - - auto collapsedStrides = SmallVector(updatesRank - 1, one); - - // If body consists only from terminator, then insert the update - // slice into `init`, otherwise reduce the update slice with the same - // body. 
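          // E.g. (regions shown are illustrative): a scatter whose region is
          // just
          //   (%upd: f32, %cur: f32) { thlo.yield %upd }
          // plainly overwrites, so the update slice is written into `init`
          // with a single tensor.insert_slice (the branch right below). A
          // combining region such as
          //   %s = arith.addf %upd, %cur; thlo.yield %s
          // instead becomes a linalg.reduce over the extracted destination
          // slice that reuses that same region.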
- if (scatterOp.getBody()->getOperations().size() == 1) { - SmallVector offsets(updatesRank, zero); - SmallVector strides(updatesRank, one); - - // Create rank-reducing `tensor.extract_slice` to avoid insertion of - // `tensor.collapse_shape` to get rid of the outer size-1 dimension. - Value extracted = thenBuilder.create( - thenLoc, updates, offsets, updatesDimSizes, strides); - Value collapsed = thenBuilder.create( - thenLoc, extracted, - getCollapseFirstDimReassociation(updatesRank)); - - // Insert resized `updates` into `init`. - Value inserted = thenBuilder.create( - thenLoc, collapsed, init, collapsedOffsets, collapsedSizes, - collapsedStrides); - thenBuilder.create(thenLoc, inserted); - return; - } - - // Extract a slice for `init`. - Value extracted = thenBuilder.create( - thenLoc, init, collapsedOffsets, collapsedSizes, collapsedStrides); - - // Insert indentity slice for `updates`. - Value updatesSlice = thenBuilder.create( - thenLoc, updates, - SmallVector(updatesRank, b.getIndexAttr(0)), - updatesDimSizes, - SmallVector(updatesRank, b.getIndexAttr(1))); - - // Reduce `updates` into that slice. - auto reduced = thenBuilder.create( - thenLoc, extracted.getType().cast(), updatesSlice, - extracted, ArrayRef({0})); - reduced.getRegion().takeBody(scatterOp.getBodyRegion()); - - Operation *yield = reduced.getBlock()->getTerminator(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yield); - rewriter.replaceOpWithNewOp(yield, - yield->getOperands()); - // Put that slice back. - auto inserted = thenBuilder.create( - thenLoc, reduced.getResults().front(), init, collapsedOffsets, - collapsedSizes, collapsedStrides); - thenBuilder.create(thenLoc, inserted.getResult()); - }, - [&](OpBuilder &elseBuilder, Location elseLoc) { - elseBuilder.create(elseLoc, init); - }); - rewriter.replaceOp(scatterOp, ifOp.getResults()); - return ifOp; -} - -LogicalResult scalarizeScatterOp(thlo::ScatterOp scatterOp, - PatternRewriter &rewriter) { - return rewriteScatterOpAsIfOp(scatterOp, rewriter); -} - -std::unique_ptr> createScalarizationPass( - bool scalarizeAllThlo) { - ScalarizationPassOptions opts; - opts.scalarizeAllThlo = scalarizeAllThlo; - return std::make_unique(opts); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.h deleted file mode 100644 index 9455aa1c95eb18..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_SCALARIZATION_SCALARIZATION_H -#define MLIR_HLO_GML_ST_TRANSFORMS_SCALARIZATION_SCALARIZATION_H - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/PatternMatch.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir { -namespace gml_st { - -/// Rewrites thlo.concatenate, returns `failure` if IR was not changed. -LogicalResult scalarizeConcatenateOp(thlo::ConcatenateOp concatenateOp, - PatternRewriter &rewriter); - -/// Rewrites thlo.dynamic_broadcast_in_dim, returns `failure` if IR was not -/// changed. -LogicalResult scalarizeDynamicBroadcastInDimOp( - thlo::DynamicBroadcastInDimOp broadcastOp, PatternRewriter &rewriter); - -/// Rewrites thlo.gather, returns `failure` if IR was not changed. -LogicalResult scalarizeGatherOp(thlo::GatherOp gatherOp, - PatternRewriter &rewriter); - -/// Rewrites LinalgOp interface ops, returns `failure` if IR was not changed. -LogicalResult scalarizeLinalgOp(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter); - -/// Rewrites thlo.reverse, returns `failure` if IR was not changed. -LogicalResult scalarizeReverseOp(thlo::ReverseOp reverseOp, - PatternRewriter &rewriter); - -/// Rewrites thlo.scatter, returns `failure` if IR was not changed. -LogicalResult scalarizeScatterOp(thlo::ScatterOp scatterOp, - PatternRewriter &rewriter); - -FailureOr rewriteScatterOpAsIfOp(thlo::ScatterOp scatterOp, - PatternRewriter &rewriter); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_SCALARIZATION_SCALARIZATION_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.cc deleted file mode 100644 index a5aa31770e11b4..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "gml_st/transforms/test_passes.h" - -#include -#include -#include - -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/peeling/peeling.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" -#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_TESTGMLSTGREEDYFUSION -#include "gml_st/transforms/test_passes.h.inc" - -static constexpr llvm::StringRef kTestFusionAppliedLabel = - "__test_fusion_applied_label__"; - -struct GreedyFusionPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ForallOp op, - PatternRewriter &rewriter) const override { - if (hasLabel(op, kTestFusionAppliedLabel)) return failure(); - - rewriter.updateRootInPlace(op, [&]() { - fuseGreedily(rewriter, &op.getRegion().front(), [](Operation *op) { - return isa(op); - }); - }); - - setLabel(op, kTestFusionAppliedLabel); - return success(); - } -}; - -struct TestGmlStGreedyFusionPass - : public impl::TestGmlStGreedyFusionBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - linalg::registerTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - - MLIRContext *ctx = funcOp.getContext(); - RewritePatternSet patterns(ctx); - - patterns.add(ctx); - - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) - return signalPassFailure(); - - funcOp.walk( - [](scf::ForallOp op) { removeLabel(op, kTestFusionAppliedLabel); }); - } -}; - -} // namespace - -std::unique_ptr> createTestGmlStGreedyFusionPass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.h deleted file mode 100644 index 7b4399b9c2fabc..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_TEST_PASSES_H -#define MLIR_HLO_GML_ST_TRANSFORMS_TEST_PASSES_H - -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace gml_st { - -#define GEN_PASS_DECL -#include "gml_st/transforms/test_passes.h.inc" - -std::unique_ptr> createTestGmlStGreedyFusionPass(); - -#define GEN_PASS_REGISTRATION -#include "gml_st/transforms/test_passes.h.inc" - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_TEST_PASSES_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.td b/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.td deleted file mode 100644 index 2be3577fc79b99..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/test_passes.td +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -include "mlir/Pass/PassBase.td" - -def TestGmlStGreedyFusion : Pass<"test-gml-st-greedy-fusion", "mlir::func::FuncOp"> { - let summary = "Fuse ops greedily into gml-st loops."; - let constructor = "::mlir::gml_st::createTestGmlStGreedyFusionPass()"; -} diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tile_by_one.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tile_by_one.cc deleted file mode 100644 index 6e0d28273e543f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tile_by_one.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_TILEBYONEPASS -#include "gml_st/transforms/passes.h.inc" - -static constexpr llvm::StringRef kTileByOneLabel = "__tile_by_one_label__"; - -template -struct TileByOnePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Skip ops that are already tiled. - if (hasLabel(op, kTileByOneLabel)) return failure(); - - // Skip if iteration domain is statically known to be of size 1. - auto iface = llvm::cast(op.getOperation()); - // TODO(frgossen): Avoid creating the IR for these ranges. Instead, the - // tiling interface should allow to access statically known iteration - // domains. - SmallVector iterationDomain = iface.getIterationDomain(rewriter); - auto isRangeSizeOne = [](Range range) { - if (!range.size.is()) return false; - auto intAttr = range.size.get().dyn_cast(); - if (!intAttr) return false; - return intAttr.getInt() == 1; - }; - if (llvm::all_of(iterationDomain, isRangeSizeOne)) return failure(); - - // Tile. - scf::SCFTilingOptions opts; - opts.setTileSizes(SmallVector( - iface.getLoopIteratorTypes().size(), - getAsIndexOpFoldResult(rewriter.getContext(), 1))); - FailureOr tilingResult = - tileUsingSCFForOp(rewriter, iface, opts); - if (failed(tilingResult)) - return rewriter.notifyMatchFailure(op, "tiling to scf.for failed"); - - // Mark resulting tiled ops. - for (Operation *tiled : tilingResult->tiledOps) { - setLabel(tiled, kTileByOneLabel); - } - - rewriter.replaceOp(op, tilingResult->replacements); - return success(); - } -}; - -struct TileByOnePass : public impl::TileByOnePassBase { - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - // Populate patterns. - RewritePatternSet patterns(ctx); - // clang-format off - patterns.add< - TileByOnePattern, - TileByOnePattern, - TileByOnePattern, - TileByOnePattern, - TileByOnePattern, - TileByOnePattern>(ctx); - // clang-format on - - // Apply patterns. - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { - return signalPassFailure(); - } - - // Clean up by removing temporary attributes. 
- f->walk([](Operation *op) { removeLabel(op, kTileByOneLabel); }); - } -}; - -} // namespace - -std::unique_ptr> createTileByOnePass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc deleted file mode 100644 index 85c0e72babf384..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc +++ /dev/null @@ -1,253 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "gml_st/transforms/tiling/tiling.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributeInterfaces.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir::gml_st { -namespace { - -// Compute tile size for the tile that starts at `offset`, has size `tileSize` -// for the tensor with the dimension size `dimSize`. -// The tile size is static when `tileSize` divides `dimSize` or when the -// `tileSize` is 1. -// Otherwise, it is minimum of `tileSize` and `dimSize - offset` to avoid out of -// bounds access. -OpFoldResult computeTileSizeInDim(OpBuilder &builder, Location loc, - OpFoldResult tileSize, OpFoldResult dimSize, - OpFoldResult offset) { - std::optional tileCst = getConstantIntValue(tileSize); - std::optional dimCst = getConstantIntValue(dimSize); - - bool hasTileSizeOne = tileCst && *tileCst == 1; - bool dividesEvenly = tileCst && dimCst && ((*dimCst % *tileCst) == 0); - if (hasTileSizeOne || dividesEvenly) return builder.getIndexAttr(*tileCst); - - AffineExpr d0, s0; - bindDims(builder.getContext(), d0); - bindSymbols(builder.getContext(), s0); - OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( - builder, loc, s0 - d0, {offset, dimSize}); - - return affine::makeComposedFoldedAffineMin( - builder, loc, AffineMap::getMultiDimIdentityMap(2, loc.getContext()), - {residualTileSize, tileSize}); -} - -// Updates offsets, sizes as functions of ivs and insert parallel_insert_slices -// into `in_parallel` terminator. 
-void calculateTileOffsetsAndSizes(OpBuilder &b, Location loc, - scf::ForallOp forallOp, - ArrayRef steps, - ArrayRef ubs, - ArrayRef nonemptyRangeIndices, - SmallVector &offsets, - SmallVector &sizes) { - OpBuilder::InsertionGuard g(b); - b.setInsertionPointToStart(forallOp.getBody(0)); - for (const auto &[index, iv] : llvm::enumerate(forallOp.getInductionVars())) { - offsets[nonemptyRangeIndices[index]] = iv; - sizes[nonemptyRangeIndices[index]] = - computeTileSizeInDim(b, loc, steps[index], ubs[index], iv); - } -} - -/// Generate an empty loop nest that represents the tiled loop nest shell. -/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. -/// - In `offsets` and `sizes` return the multi-dimensional offset and size of -/// the tile processed within the inner most loop. -scf::ForallOp generateTileLoopNest(OpBuilder &builder, Location loc, - ArrayRef loopRanges, - ArrayRef tileSizeVals, - ArrayRef dstOperands, - SmallVector &offsets, - SmallVector &sizes) { - assert(!loopRanges.empty() && "expected at least one loop range"); - assert(loopRanges.size() == tileSizeVals.size() && - "expected as many tile sizes as loop ranges"); - OpBuilder::InsertionGuard guard(builder); - - SmallVector lbs, ubs, steps; - SmallVector nonemptyRangeIndices; - for (const auto &loopRange : llvm::enumerate(loopRanges)) { - OpFoldResult offset = loopRange.value().offset; - OpFoldResult size = loopRange.value().size; - // No loops if tile size is zero. Set offset and size to the loop offset and - // size. - offsets.push_back(offset); - sizes.push_back(size); - if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) continue; - lbs.push_back(offset); - ubs.push_back(size); - steps.push_back(tileSizeVals[loopRange.index()]); - nonemptyRangeIndices.push_back(loopRange.index()); - } - auto loop = builder.create(loc, lbs, ubs, steps, dstOperands, - std::nullopt); - - calculateTileOffsetsAndSizes(builder, loc, loop, steps, ubs, - nonemptyRangeIndices, offsets, sizes); - return loop; -} - -void updateOutputs(const GMLSTTilingResult &tilingResult, - ValueRange dstOperands) { - scf::ForallOp parallelLoop = tilingResult.loop; - - if (auto dstOp = dyn_cast( - tilingResult.tiledOps.front())) { - for (auto [dst, regionArg] : - llvm::zip(dstOperands, parallelLoop.getOutputBlockArguments())) { - dst.replaceUsesWithIf(regionArg, [&](OpOperand &operand) { - Operation *owner = operand.getOwner(); - return isa(owner) && - owner->getParentOfType() == - parallelLoop.getOperation(); - }); - } - } -} - -} // namespace - -scf::SCFTilingOptions getSCFTilingOptions(MLIRContext *context, - ArrayRef tileSizes) { - scf::SCFTilingOptions opts; - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(context, tileSizes); - opts.setTileSizes(tileSizesOfr); - return opts; -} - -FailureOr tileUsingSCFForallOp( - PatternRewriter &rewriter, TilingInterface op, - const scf::SCFTilingOptions &options) { - rewriter.setInsertionPoint(op); - if (!options.tileSizeComputationFunction) { - return rewriter.notifyMatchFailure( - op, "missing tile size computation function"); - } - Location loc = op.getLoc(); - - // 1. Get the range of the loops that are represented by the operation. - SmallVector iterationDomain = op.getIterationDomain(rewriter); - size_t numLoops = iterationDomain.size(); - if (numLoops == 0) return failure(); - - // 2. Materialize the tile sizes. Enforce the convention that "tiling by - // zero" skips tiling a particular dimension. 
This convention is - // significantly simpler to handle instead of adjusting affine maps to - // account for missing dimensions. - SmallVector tileSizeVector; - { - OpBuilder::InsertionGuard guard(rewriter); - tileSizeVector = llvm::to_vector( - llvm::map_range(options.tileSizeComputationFunction(rewriter, op), - [&](const OpFoldResult &ofr) -> Value { - if (Value value = mlir::dyn_cast(ofr)) - return value; - if (Attribute attr = mlir::dyn_cast(ofr)) - return rewriter.create( - loc, attr.cast()); - return Value(); - })); - } - - if (tileSizeVector.size() < iterationDomain.size()) { - auto zero = rewriter.create(loc, 0); - tileSizeVector.append(numLoops - tileSizeVector.size(), zero); - } - - if (llvm::all_of(tileSizeVector, - [](Value v) { return matchPattern(v, m_Zero()); })) { - return GMLSTTilingResult{{op}, nullptr}; - } - - // 3. Materialize an empty loop nest that iterates over the tiles. - SmallVector dstOperands; - if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dstOperands))) - return rewriter.notifyMatchFailure(op, "failed to get destinations"); - SmallVector offsets, sizes; - GMLSTTilingResult tilingResult; - tilingResult.loop = - generateTileLoopNest(rewriter, loc, iterationDomain, tileSizeVector, - dstOperands, offsets, sizes); - - Block *loopBody = &tilingResult.loop->getRegion(0).front(); - auto terminator = cast(loopBody->getTerminator()); - rewriter.setInsertionPoint(terminator); - - // 4. Insert the tiled implementation within the loop. - FailureOr tiledImplementation = - op.getTiledImplementation(rewriter, offsets, sizes); - if (failed(tiledImplementation)) - return rewriter.notifyMatchFailure(op, - "failed to get tiled implementation"); - tilingResult.tiledOps = tiledImplementation->tiledOps; - - // 5. Compute tiles for the insertion. - int64_t numResults = op->getNumResults(); - SmallVector outputTiles; - auto oneAttr = rewriter.getI64IntegerAttr(1); - for (const auto &result : llvm::enumerate(op->getResults())) { - rewriter.setInsertionPoint(terminator); - SmallVector resultOffsetsList(numResults), - resultSizesList(numResults); - if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, - sizes, resultOffsetsList, - resultSizesList))) { - return rewriter.notifyMatchFailure( - op, "failed to get slice of result produced"); - } - rewriter.setInsertionPointToEnd(terminator.getBody()); - rewriter.create( - loc, tilingResult.tiledOps.front()->getResult(result.index()), - tilingResult.loop.getOutputBlockArguments()[result.index()], - resultOffsetsList, resultSizesList, - SmallVector(resultSizesList.size(), oneAttr)); - } - rewriter.setInsertionPoint(tilingResult.loop); - - // 6. Update the uses of `outputs` with the output bbArgs. - updateOutputs(tilingResult, dstOperands); - return tilingResult; -} - -SmallVector getYieldedValues(scf::InParallelOp inParallelOp) { - return llvm::to_vector(llvm::map_range( - inParallelOp.getYieldingOps(), [](Operation &op) -> Value { - auto insertSliceOp = cast(&op); - return insertSliceOp.getSource(); - })); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h deleted file mode 100644 index 47d852dcdf1aac..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_TILING_TILING_H -#define MLIR_HLO_GML_ST_TRANSFORMS_TILING_TILING_H - -#include - -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/TilingInterface.h" - -namespace mlir::gml_st { - -// Creates SCFTilingOptions from the list of tile sizes. -scf::SCFTilingOptions getSCFTilingOptions(MLIRContext *context, - ArrayRef tileSizes); - -/// Returns `failure`, when there occurs a problem during tiling. If the tile -/// sizes are smaller then the iteration domain of the op, it will still create -/// an `scf.forall` op. This is matches the behavior of tiling to `scf.for` -/// upstream. -struct GMLSTTilingResult { - SmallVector tiledOps; - scf::ForallOp loop = nullptr; -}; -FailureOr tileUsingSCFForallOp( - PatternRewriter &rewriter, TilingInterface op, - const scf::SCFTilingOptions &options); - -/// Extracts all yielded values from scf.in_parallel terminator. It should be -/// upstreamed. -SmallVector getYieldedValues(scf::InParallelOp inParallelOp); - -} // namespace mlir::gml_st - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_TILING_TILING_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc deleted file mode 100644 index c1ea034f0a64a9..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc +++ /dev/null @@ -1,292 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "gml_st/transforms/fusion/fusion.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/tiling/tiling.h" -#include "gml_st/transforms/transforms.h" -#include "gml_st/utils/linalg_utils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TILINGSOFTMAXPASS -#include "gml_st/transforms/passes.h.inc" - -constexpr llvm::StringRef kTileSoftmaxAppliedLabel = - "__tile_softmax_applied_label__"; - -Operation *fuseIthOperandInPlace(PatternRewriter &rewriter, Operation *op, - int64_t i) { - auto matOp = - llvm::cast(op->getOperand(i).getDefiningOp()); - FailureOr fused = fuse(rewriter, matOp); - assert(succeeded(fused) && "expect success after matching"); - return *fused; -} - -LogicalResult tilePartialSoftmax( - TilingInterface op, PatternRewriter &rewriter, - llvm::function_ref(Operation *, int64_t)> - tileOperationFn) { - // Match cwise root op. - // Match all operands to be derived from the same source value in one of two - // ways: - // i) by a reduction and subsequent bcast in one dimension, or - // ii) by using the source value as is. - Value commonSource; - std::optional commonReductionDim; - SmallVector> simpleBcastReductions; - auto mapOp = llvm::dyn_cast_or_null(op.getOperation()); - if (!mapOp || mapOp.getNumDpsInits() != 1) - return rewriter.notifyMatchFailure(op, "no mapOp"); - for (Value operand : mapOp.getInputs()) { - // Case i. - SimpleBcastReduction bcastReduction; - int64_t reductionDim; - if (isSimpleBcastReduction(operand.getDefiningOp(), &reductionDim, - &bcastReduction)) { - if (commonSource && commonSource != bcastReduction.operand) { - return rewriter.notifyMatchFailure(bcastReduction.bcast, - "no common reduction source"); - } - commonSource = bcastReduction.operand; - if (commonReductionDim && *commonReductionDim != reductionDim) { - return rewriter.notifyMatchFailure(bcastReduction.reduction, - "no common reduction dim"); - } - commonReductionDim = reductionDim; - simpleBcastReductions.push_back(bcastReduction); - continue; - } - - // Case ii. - if (commonSource && commonSource != operand) - return rewriter.notifyMatchFailure(op, "common source != operand"); - commonSource = operand; - simpleBcastReductions.push_back(std::nullopt); - } - - if (!commonReductionDim || !commonSource) - return rewriter.notifyMatchFailure(op, "no common dim/src"); - - // Tile or fuse cwise root op. - FailureOr tilingResult = - tileOperationFn(op, *commonReductionDim); - if (failed(tilingResult)) - return rewriter.notifyMatchFailure(op, "call to tileOperationFn failed"); - Operation *tiledOp = tilingResult->tiledOps[0]; - setLabel(tiledOp, kTileSoftmaxAppliedLabel); - - // Fuse through the bcast reduction chains. - Value commonTiledSource; - for (int64_t i = 0; i < static_cast(simpleBcastReductions.size()); - i++) { - if (!simpleBcastReductions[i]) continue; - - // Fuse. - Operation *tiledBcast = fuseIthOperandInPlace(rewriter, tiledOp, i); - Operation *tiledReduction = - fuseIthOperandInPlace(rewriter, tiledBcast, /*i=*/0); - - // Use common tiled source value. 
- if (commonTiledSource) { - tiledReduction->setOperand(0, commonTiledSource); - } else { - commonTiledSource = tiledReduction->getOperands().front(); - } - } - - // Also use the common tiled source value for the remaining operands. - for (size_t i = 0; i < simpleBcastReductions.size(); i++) { - if (simpleBcastReductions[i]) continue; - tiledOp->setOperand(i, commonTiledSource); - } - - return success(); -} - -struct TilePartialSoftmaxPattern - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - TilePartialSoftmaxPattern(MLIRContext *ctx, SmallVector tileSizes, - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(ctx, benefit), - tileSizes(std::move(tileSizes)) {} - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - if (hasLabel(op, kTileSoftmaxAppliedLabel)) - return rewriter.notifyMatchFailure(op, "has tranformation attr"); - - // Only apply to non-fusable occurrences. - bool hasFusableOccurrences = llvm::any_of( - op->getUsers(), - [](Operation *op) { return llvm::isa(op); }); - if (hasFusableOccurrences) - return rewriter.notifyMatchFailure(op, "has fusable occurrences"); - - return tilePartialSoftmax( - op, rewriter, - [&](Operation *op, - int64_t commonReductionDim) -> FailureOr { - // Populate tiling options. - scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizeComputationFunction( - [&](OpBuilder &b, Operation *op) -> SmallVector { - SmallVector tileSizeValues; - for (int64_t i = 0; i < static_cast(tileSizes.size()); - i++) { - // Skip tiling the reduction dimension. By convention, this is - // a tile size of 0. - int64_t tileSizeInDim = - i == commonReductionDim ? 0 : tileSizes[i]; - tileSizeValues.push_back( - getAsIndexOpFoldResult(b.getContext(), tileSizeInDim)); - } - return tileSizeValues; - }); - // Tile. - FailureOr tilingResult = tileUsingSCFForallOp( - rewriter, cast(op), tilingOptions); - if (failed(tilingResult)) return failure(); - - rewriter.replaceOp(op, tilingResult->loop->getResults()); - setLabel(tilingResult->tiledOps.front(), kTileSoftmaxAppliedLabel); - Operation *tiledOp = tilingResult->tiledOps.front(); - return TilingResult{{tiledOp}, - SmallVector(tiledOp->result_begin(), - tiledOp->result_end())}; - }); - } - - private: - SmallVector tileSizes; -}; - -struct FusePartialSoftmaxPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, - PatternRewriter &rewriter) const override { - Value source = op.getSource(); - Operation *def = source.getDefiningOp(); - if (!def) return failure(); - - if (!llvm::isa(def)) return failure(); - - return tilePartialSoftmax( - cast(def), rewriter, - [&](Operation *cwiseOp, - int64_t /*commonReductionDim*/) -> FailureOr { - auto iface = llvm::dyn_cast_or_null(cwiseOp); - if (!iface) { - return rewriter.notifyMatchFailure( - cwiseOp, "doesn't implement tiling iface"); - } - - // By construction, we assume that the tile spans the operand in the - // common reduction dimension (`commonReductionDim`). - // TODO(frgossen): Assert this assumption when we have moved to - // unnested tiles. - - // Fuse. 
- SmallVector offsets = op.getMixedOffsets(); - SmallVector sizes = op.getMixedSizes(); - FailureOr tilingResult = - iface.generateResultTileValue(rewriter, 0, offsets, sizes); - if (failed(tilingResult)) { - return rewriter.notifyMatchFailure( - cwiseOp, "failed to generate result tile"); - } - - rewriter.replaceOp(op, tilingResult->tiledValues[0]); - return tilingResult; - }); - } -}; - -struct FuseUnaryCwisePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, - PatternRewriter &rewriter) const override { - // Match unary cwise ops. - Operation *source = op.getSource().getDefiningOp(); - auto mapOp = dyn_cast_or_null(source); - if (!mapOp || mapOp.getNumDpsInputs() != 1) return failure(); - // Fuse. - return fuse(rewriter, op); - } -}; - -struct TilingSoftmaxPass - : public impl::TilingSoftmaxPassBase { - TilingSoftmaxPass() = default; - explicit TilingSoftmaxPass(ArrayRef ts) { this->tileSizes = ts; } - - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - // Populate tiling and fusion patterns for partial softmax and unary cwise - // ops. - RewritePatternSet patterns(ctx); - SmallVector tileSizes(this->tileSizes.begin(), - this->tileSizes.end()); - patterns.insert(ctx, tileSizes); - patterns.insert(ctx); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { - return signalPassFailure(); - } - - // Clean up by removing temporary attributes. - f.walk([](Operation *op) { removeLabel(op, kTileSoftmaxAppliedLabel); }); - } -}; - -} // namespace - -std::unique_ptr> createTilingSoftmaxPass() { - return std::make_unique(); -} - -std::unique_ptr> createTilingSoftmaxPass( - ArrayRef tileSizes) { - return std::make_unique(tileSizes); -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.cc deleted file mode 100644 index 315a0f07838902..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "gml_st/transforms/transforms.h" - -#include "mlir/IR/BuiltinTypes.h" - -namespace mlir { -namespace gml_st { - -bool hasSingleElementOperandsAndResults(Operation *op) { - auto isScalar = [](Type type) { - return !type.isa() || - (type.isa() && - hasSingleElement(type.cast())); - }; - return llvm::all_of(op->getOperandTypes(), isScalar) && - llvm::all_of(op->getResultTypes(), isScalar); -} - -void setLabel(Operation *op, StringRef name) { - op->setAttr(name, UnitAttr::get(op->getContext())); -} - -void removeLabel(Operation *op, StringRef name) { op->removeAttr(name); } - -bool hasLabel(Operation *op, StringRef name) { return op->hasAttr(name); } - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.h deleted file mode 100644 index 97d0ef0cf750a3..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/transforms.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_TRANSFORMS_H -#define MLIR_HLO_GML_ST_TRANSFORMS_TRANSFORMS_H - -#include "mlir/IR/Operation.h" - -namespace mlir { -namespace gml_st { - -constexpr llvm::StringRef kPerfectlyTiledLoopLabel = - "__perfectly_tiled_loop_label__"; - -static constexpr llvm::StringRef kTransformedLabel = "__transformed_label__"; - -template -bool hasSingleElement(ShapedTy type) { - return type.hasStaticShape() && type.getNumElements() == 1; -} -bool hasSingleElementOperandsAndResults(Operation *op); - -// Sets the attribute to the `op` that indicates that the op was transformed. -void setLabel(Operation *op, StringRef name); - -// Removes the attribute that indicates that it was transformed. -void removeLabel(Operation *op, StringRef name); - -// Checks if `op` has the attribute that indicates that it was transformed. -bool hasLabel(Operation *op, StringRef name); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_TRANSFORMS_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/lower_vectors.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/lower_vectors.cc deleted file mode 100644 index 5622027e00d205..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/lower_vectors.cc +++ /dev/null @@ -1,270 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "gml_st/transforms/passes.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/Dialect/X86Vector/Transforms.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_LOWERVECTORSPASS -#include "gml_st/transforms/passes.h.inc" - -using func::FuncOp; - -LogicalResult rewriteVectorContract(MLIRContext* ctx, FuncOp funcOp) { - // Reduce vector.contract dimensions to fit one of the lowering patterns to - // vector.outerproduct. - { - RewritePatternSet castAwayUnitDimPatterns(ctx); - vector::populateCastAwayVectorLeadingOneDimPatterns( - castAwayUnitDimPatterns); - if (failed(applyPatternsAndFoldGreedily( - funcOp, std::move(castAwayUnitDimPatterns)))) { - return failure(); - } - - RewritePatternSet reductionToContractPatterns(ctx); - vector::populateVectorReductionToContractPatterns( - reductionToContractPatterns); - vector::ExtractOp::getCanonicalizationPatterns(reductionToContractPatterns, - ctx); - if (failed(applyPatternsAndFoldGreedily( - funcOp, std::move(reductionToContractPatterns)))) { - return failure(); - } - } - - RewritePatternSet patterns(ctx); - vector::populateVectorToVectorCanonicalizationPatterns(patterns); - - // Currently we always lower vector.contract into vector.outerproduct. - vector::populateVectorContractLoweringPatterns( - patterns, - vector::VectorTransformsOptions().setVectorTransformsOptions( - vector::VectorContractLowering::OuterProduct), - /*benefit=*/2, - /*disableOuterProductLowering*/ true); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - - return applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - -// Rewrite `vector.transpose` into vector.shuffle ops. -LogicalResult rewriteVectorTranspose(MLIRContext* ctx, Operation* funcOp, - bool enableAVX2) { - RewritePatternSet patterns(ctx); - vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions = vectorTransformOptions.setVectorTransposeLowering( - vector::VectorTransposeLowering::EltWise); - vector::populateVectorTransposeLoweringPatterns(patterns, - vectorTransformOptions); - - if (enableAVX2) { - // Options for controlling specialized AVX2 lowerings. These lowerings may - // either use intrin or inline_asm depending on needs. So they won't work - // for SSE. 
- auto avxLoweringOptions = - x86vector::avx2::LoweringOptions().setTransposeOptions( - x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32() - .lower8x8xf32()); - - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avxLoweringOptions, /*benefit=*/10); - } - - return applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - -// Rewrite N-D reductions as the sequence of vector operations without -// horizontal reduction, i.e. `vector.reduction`. -LogicalResult rewriteVectorReductionsND(MLIRContext* ctx, FuncOp funcOp) { - ConversionTarget target(*ctx); - target.addLegalDialect(); - target.addDynamicallyLegalOp( - [&](vector::MultiDimReductionOp op) { - return op.getSourceVectorType().getRank() == 1; - }); - - RewritePatternSet patterns(ctx); - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vector::VectorMultiReductionLowering::InnerParallel); - return applyPartialConversion(funcOp, target, std::move(patterns)); -} - -// Rewrite 1D reductions as a `vector.reduction`. -LogicalResult rewriteVectorReductions1D(MLIRContext* ctx, Operation* op) { - RewritePatternSet patterns(ctx); - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vector::VectorMultiReductionLowering::InnerReduction); - return applyPatternsAndFoldGreedily(op, std::move(patterns)); -} - -// Return the uses of op if they all are either StoreOp, TransferWriteOp, or -// SubviewOp with only StoreOp/TransferWriteOp users. -std::optional> getUsesIfAllStores(Operation* op) { - llvm::SmallVector opUses; - for (OpOperand& use : op->getUses()) { - Operation* useOp = use.getOwner(); - if (isa(useOp)) { - opUses.push_back(useOp); - continue; - } - if (isa(useOp)) { - if (auto subviewUses = getUsesIfAllStores(useOp)) { - opUses.insert(opUses.end(), subviewUses->begin(), subviewUses->end()); - opUses.push_back(useOp); - continue; - } - } - return std::nullopt; - } - return opUses; -} - -// Track temporary allocations that are never read from. If this is the case -// it means both the allocations and associated stores can be removed. -void eraseDeadAllocAndStores(func::FuncOp func) { - SmallVector opToErase; - func.walk([&](memref::AllocOp op) { - if (auto uses = getUsesIfAllStores(op)) { - // Insert the uses first, - opToErase.insert(opToErase.end(), uses->begin(), uses->end()); - // then the op itself, since we will be erasing from opToErase's start. - opToErase.push_back(op.getOperation()); - } - }); - for (Operation* op : opToErase) { - op->erase(); - } -} - -// Pattern to canonialize tranpose where only one dimension is not unit -// dimension. In this case the transpose is a no-op and should be simplified -// before getting to the conversion to llvm/spirv. -class TransposeUnitDimToShapeCast - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransposeOp op, - PatternRewriter& rewriter) const override { - unsigned numNonUnitSrcDim = - llvm::count_if(op.getSourceVectorType().getShape(), - [](int64_t dim) { return dim != 1; }); - if (numNonUnitSrcDim != 1) return failure(); - rewriter.replaceOpWithNewOp( - op, op.getResultVectorType(), op.getVector()); - return success(); - } -}; - -// Run optimization transformations on vector transfer operations. -LogicalResult optimizeVectorTransfers(MLIRContext* ctx, FuncOp funcOp, - bool flatten) { - // Generate vector.shape_cast for dropping leading one dimensions in vector - // ops. 
This increases the chance that we can forward more transfer writes - // to transfer reads. - { - RewritePatternSet patterns(ctx); - mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); - vector::ExtractOp::getCanonicalizationPatterns(patterns, ctx); - patterns.add(ctx); - mlir::vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns( - patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - } - - // Move bitcast inwards from loop region boundaries to increase chances to - // cancel them. - { - RewritePatternSet patterns(ctx); - vector::populateBubbleVectorBitCastOpPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - } - - // Third stage of patterns to flatten transfer ops. - if (flatten) { - RewritePatternSet patterns(ctx); - mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns); - mlir::vector::populateFlattenVectorTransferPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return failure(); - } - } - // Delete potential dead alloc and associated ops after store to load - // forwarding. - eraseDeadAllocAndStores(funcOp); - return success(); -} - -LogicalResult lowerVectorOpsToSCF(MLIRContext* ctx, FuncOp funcOp) { - RewritePatternSet patterns(ctx); - auto vectorTransferToSCFOptions = - VectorTransferToSCFOptions().enableFullUnroll(true).setTargetRank(1); - - populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); - return applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - -struct LowerVectorsPass : public impl::LowerVectorsPassBase { - using Base::Base; - - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext* ctx = &getContext(); - - if (failed(rewriteVectorContract(ctx, funcOp))) signalPassFailure(); - if (failed(rewriteVectorTranspose(ctx, funcOp, enableAVX2))) - signalPassFailure(); - if (failed(rewriteVectorReductionsND(ctx, funcOp))) signalPassFailure(); - if (failed(rewriteVectorReductions1D(ctx, funcOp))) signalPassFailure(); - if (failed(optimizeVectorTransfers(ctx, funcOp, flatten))) - signalPassFailure(); - if (failed(lowerVectorOpsToSCF(ctx, funcOp))) signalPassFailure(); - } -}; -} // namespace - -std::unique_ptr> createLowerVectorsPass( - bool enableAVX2, bool flatten) { - LowerVectorsPassOptions opts; - opts.enableAVX2 = enableAVX2; - opts.flatten = flatten; - return std::make_unique(opts); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc deleted file mode 100644 index 2ae94e276921a5..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "gml_st/transforms/vectorization/vectorization.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace gml_st { - -using mlir::vector::TransferWriteOp; - -RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx) { - RewritePatternSet patterns(ctx); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - vector::populateVectorReductionToContractPatterns(patterns); - patterns.add(ctx, /*benefit=*/2); - TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); - return patterns; -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h b/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h deleted file mode 100644 index ede8e9320ff11d..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_TRANSFORMS_VECTORIZATION_VECTORIZATION_H -#define MLIR_HLO_GML_ST_TRANSFORMS_VECTORIZATION_VECTORIZATION_H - -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace gml_st { - -// TODO(manany): This should be parameterized later on depending on hardware. -static constexpr int64_t kNumElementsVectorization = 8; - -template -struct VectorizationPattern : public mlir::OpRewritePattern { - VectorizationPattern(MLIRContext *context, - llvm::function_ref matchFn, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit), filterFn(matchFn) {} - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - if (!filterFn(op)) - return rewriter.notifyMatchFailure(op, "did not match filter"); - return mlir::linalg::vectorize(rewriter, op); - } - - private: - llvm::function_ref filterFn; -}; - -RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_GML_ST_TRANSFORMS_VECTORIZATION_VECTORIZATION_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc deleted file mode 100644 index 2089351d6b01fd..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc +++ /dev/null @@ -1,421 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/transforms.h" -#include "gml_st/transforms/vectorization/vectorization.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_VECTORIZEFORCPUPASS -#include "gml_st/transforms/passes.h.inc" - -using mlir::linalg::BroadcastOp; -using mlir::linalg::DotOp; -using mlir::linalg::FillOp; -using mlir::linalg::GenericOp; -using mlir::linalg::MapOp; -using mlir::linalg::MatmulOp; -using mlir::linalg::MatvecOp; -using mlir::linalg::Mmt4DOp; -using mlir::linalg::ReduceOp; -using mlir::linalg::TransposeOp; -using mlir::linalg::VecmatOp; -using mlir::tensor::ExpandShapeOp; -using mlir::thlo::ReverseOp; -using mlir::vector::TransferReadOp; -using mlir::vector::TransferWriteOp; - -struct PassVectorizedValuesThroughIfOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::IfOp op, - PatternRewriter &rewriter) const override { - int64_t numResults = op.getNumResults(); - if (numResults == 0) { - return rewriter.notifyMatchFailure(op, - "cannot vectorize if op w/o results"); - } - - // Derive vectorized types. - SmallVector vectorizedTypes(op.getResultTypes()); - int64_t numActuallyVectorizedTypes = 0; - scf::YieldOp thenYieldOp = op.thenYield(); - scf::YieldOp elseYieldOp = op.elseYield(); - for (int64_t i = 0; i < numResults; ++i) { - Value result = op.getResult(i); - - // Can only vectorized statically shaped results. - auto rankedTy = result.getType().dyn_cast(); - if (!rankedTy || !rankedTy.hasStaticShape()) continue; - - // Vectorize only results that are either always used as a vector or - // always produced as a vector. - bool allVectorConsumers = - llvm::all_of(result.getUsers(), [](Operation *user) { - return llvm::isa_and_nonnull(user); - }); - bool allVectorProducers = - llvm::isa_and_nonnull( - thenYieldOp.getOperand(i).getDefiningOp()) && - llvm::isa_and_nonnull( - elseYieldOp.getOperand(i).getDefiningOp()); - if (!allVectorProducers && !allVectorConsumers) continue; - - // Derive vectorized type. 
- vectorizedTypes[i] = - VectorType::get(rankedTy.getShape(), rankedTy.getElementType()); - numActuallyVectorizedTypes++; - } - - // Fail if there isn't anything to vectorize. - if (numActuallyVectorizedTypes == 0) { - return rewriter.notifyMatchFailure(op, "nothing to vectorize"); - } - - // Create vectorized if op and steal bodies. - Location loc = op.getLoc(); - auto vectorizedIfOp = - rewriter.create(loc, vectorizedTypes, op.getCondition()); - vectorizedIfOp.getThenRegion().takeBody(op.getThenRegion()); - vectorizedIfOp.getElseRegion().takeBody(op.getElseRegion()); - - // Insert `transfer_read/write` ops for type compatibility. - auto zero = rewriter.create(loc, 0); - SmallVector replacements(vectorizedIfOp.getResults()); - for (int64_t i = 0; i < numResults; ++i) { - // Skip non-vectorizable values. - auto vectorTy = vectorizedTypes[i].dyn_cast(); - if (!vectorTy) continue; - - // Yield vectorized value in then-case. - rewriter.setInsertionPoint(vectorizedIfOp.thenYield()); - SmallVector indices(vectorTy.getRank(), zero); - Value unvectorizedThen = vectorizedIfOp.thenYield().getOperand(i); - Value vectorizedThen = rewriter.create( - loc, vectorTy, unvectorizedThen, indices); - vectorizedIfOp.thenYield().setOperand(i, vectorizedThen); - - // Yield vectorized value in else-case. - rewriter.setInsertionPoint(vectorizedIfOp.elseYield()); - Value unvectorizedElse = vectorizedIfOp.elseYield().getOperand(i); - Value vectorizedElse = rewriter.create( - loc, vectorTy, unvectorizedElse, indices); - vectorizedIfOp.elseYield().setOperand(i, vectorizedElse); - - // Insert `transfer_write` op after the vectorized if op for type - // compatibility. - rewriter.setInsertionPointAfter(vectorizedIfOp); - Value init = rewriter.create( - loc, vectorTy.getShape(), vectorTy.getElementType(), ValueRange{}); - replacements[i] = rewriter - .create( - loc, vectorizedIfOp.getResult(i), init, indices) - .getResult(); - } - - // Replace op. - rewriter.replaceOp(op, replacements); - return success(); - } -}; - -// TODO(b/269643522): Upstream this as a canonicalization for `scf.if`. -struct InlineCastInIfOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::CastOp op, - PatternRewriter &rewriter) const override { - auto srcTy = op.getSource().getType().cast(); - auto dstTy = op.getType().cast(); - if (srcTy.hasStaticShape() || !dstTy.hasStaticShape()) { - return rewriter.notifyMatchFailure( - op, "not cast from dynamic to static shape"); - } - - if (!op.getSource().hasOneUse()) - return rewriter.notifyMatchFailure(op, "source has more than one use"); - - auto ifOp = op.getSource().getDefiningOp(); - if (!ifOp || ifOp.getNumResults() != 1) { - return rewriter.notifyMatchFailure( - op, "source is not an if op with a unique result"); - } - - // Determine result types for the new if op. - SmallVector newResultTypes(ifOp.getResultTypes()); - auto ifOpResult = llvm::cast(op.getSource()); - int64_t resultIdx = ifOpResult.getResultNumber(); - newResultTypes[resultIdx] = dstTy; - - // Create new if op and steal bodies. - rewriter.setInsertionPoint(ifOp); - Location loc = ifOp.getLoc(); - auto newIfOp = - rewriter.create(loc, newResultTypes, ifOp.getCondition()); - newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); - newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); - - // Insert inner casts. 
- rewriter.setInsertionPoint(newIfOp.thenYield()); - newIfOp.thenYield().setOperand( - resultIdx, rewriter.create( - loc, dstTy, newIfOp.thenYield().getOperand(resultIdx))); - rewriter.setInsertionPoint(newIfOp.elseYield()); - newIfOp.elseYield().setOperand( - resultIdx, rewriter.create( - loc, dstTy, newIfOp.elseYield().getOperand(resultIdx))); - - // Replace op. - rewriter.replaceOp(op, newIfOp.getResults()); - rewriter.eraseOp(ifOp); - return success(); - } -}; - -// This currently matches for all thlo.reverse of the form 1x1x..x1xVectorSize. -// DimSize < kNumElementsVectorization will be handled by Scalarization. -bool isPerfectlyTiledReverse(thlo::ReverseOp reverseOp) { - auto inputType = reverseOp.getInput().getType(); - for (unsigned i = 0; i < inputType.getRank(); ++i) { - if (inputType.isDynamicDim(i)) { - return false; - } - if (i == inputType.getRank() - 1) { - return inputType.getDimSize(i) == kNumElementsVectorization && - llvm::is_contained(reverseOp.getReverseDimensions(), i); - } - if (inputType.getDimSize(i) != 1) { - return false; - } - } - return false; -} - -// Rewrite thlo.reverse of pattern 1x1x..x1xVectorSize as vector.transfer_read -// followed by vector.shuffle followed by vector.transfer_write. -struct ThloReverseVectorizationPattern - : public mlir::OpRewritePattern { - explicit ThloReverseVectorizationPattern(MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit) {} - - LogicalResult matchAndRewrite(thlo::ReverseOp op, - PatternRewriter &rewriter) const override { - if (!isPerfectlyTiledReverse(op)) - return rewriter.notifyMatchFailure(op, "did not match filter"); - - auto inputType = op.getInput().getType(); - if (!VectorType::isValidElementType(inputType.getElementType())) { - return rewriter.notifyMatchFailure(op, "cannot be vectorized"); - } - auto vecTargetType = - VectorType::get(inputType.getShape()[inputType.getRank() - 1], - inputType.getElementType()); - Value zero = rewriter.create(op.getLoc(), 0); - SmallVector indices(op.getInit().getType().getRank(), zero); - - auto readInput = rewriter.create( - op.getLoc(), vecTargetType, op.getInput(), indices); - - SmallVector mask; - int64_t maskSize = inputType.getShape()[inputType.getRank() - 1]; - mask.reserve(maskSize); - for (int64_t i = maskSize - 1; i >= 0; --i) { - mask.push_back(i); - } - auto shuffle = rewriter.create(op.getLoc(), readInput, - readInput, mask); - - rewriter.replaceOpWithNewOp( - op, shuffle.getResult(), op.getInit(), indices); - return success(); - } -}; - -struct IdentityTransposeOpFoldingPattern - : public OpRewritePattern { - explicit IdentityTransposeOpFoldingPattern(MLIRContext *context, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit) {} - - LogicalResult matchAndRewrite(TransposeOp op, - PatternRewriter & /*rewriter*/) const override { - auto perm = op.getPermutation(); - for (int64_t i = 0; static_cast(i) < perm.size(); ++i) { - if (perm[i] != i) return failure(); - } - - if (!hasSingleElementOperandsAndResults(op)) return failure(); - - op.replaceAllUsesWith(SmallVector(1, op.getInput())); - return success(); - } -}; - -// Rewrite `vector.transfer_read(linalg.expand_shape)` as -// `vector.shape_cast(vector.transfer_read)`. 
-struct TransferReadOfOneDimExpandShape - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - vector::TransferReadOp vectorRead, - mlir::PatternRewriter &rewriter) const override { - auto expand = vectorRead.getSource().getDefiningOp(); - if (!expand) return failure(); - - auto expandSrc = expand.getSrc(); - auto expandSrcType = expand.getSrcType(); - auto expandDstType = expand.getResultType(); - if (expandSrcType.getRank() != 1 || expandDstType.getRank() != 2) - return failure(); - - auto resultType = vectorRead.getType().dyn_cast(); - if (!resultType || resultType.getShape() != expandDstType.getShape()) - return failure(); - - auto zero = rewriter.create(vectorRead.getLoc(), 0); - auto map = mlir::AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0)}, - vectorRead.getContext()); - // TODO(pifon): Also support canonicalization in case the map is not an - // identity. - if (!map.isIdentity()) return failure(); - - auto newRead = rewriter.create( - vectorRead.getLoc(), - mlir::VectorType::get(expandSrcType.getShape(), - expandSrcType.getElementType()), - expandSrc, mlir::ValueRange{zero}, mlir::AffineMapAttr::get(map), - vectorRead.getPadding(), - /*mask=*/mlir::Value(), rewriter.getBoolArrayAttr({true})); - rewriter.replaceOpWithNewOp( - vectorRead, vectorRead.getType(), newRead); - return success(); - } -}; - -struct VectorizeForCPUPass - : public impl::VectorizeForCPUPassBase { - using Base::Base; - - void runOnOperation() override { - auto func = getOperation(); - auto *ctx = func.getContext(); - - auto isNonComplexSmallTensorOrScalar = [&](Type ty) { - if (getElementTypeOrSelf(ty).isa()) return false; - if (auto rankedTy = ty.dyn_cast()) { - return rankedTy.hasStaticShape() && - rankedTy.getNumElements() < numElementsThreshold; - } - - return !isa(ty); - }; - - auto isOpOnNonComplexSmallTensorOrScalar = [&](Operation *op) { - return llvm::all_of(op->getOperandTypes(), - isNonComplexSmallTensorOrScalar) && - llvm::all_of(op->getResultTypes(), - isNonComplexSmallTensorOrScalar); - }; - auto isInsidePerfectlyTiledLoop = [&](Operation *op) { - Operation *parent = op->getParentOp(); - return (isa(parent)) && - hasLabel(parent, kPerfectlyTiledLoopLabel); - }; - auto isInsidePerfectlyTiledLoopOrSmall = [&](Operation *op) { - return !hasSingleElementOperandsAndResults(op) && - (isInsidePerfectlyTiledLoop(op) || - isOpOnNonComplexSmallTensorOrScalar(op)); - }; - { - RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); - TransferReadOp::getCanonicalizationPatterns(patterns, ctx); - // clang-format off - patterns.add< - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern, - VectorizationPattern - >(ctx, isInsidePerfectlyTiledLoopOrSmall); - // clang-format on - patterns - .add(ctx); - tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); - tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } - - { - RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); - TransferReadOp::getCanonicalizationPatterns(patterns, ctx); - patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } - - // Hoisting transfer_read/transfer_write. 
- IRRewriter rewriter(func->getContext()); - func.walk( - [&](scf::ForOp forOp) { hoistLoopInvariantSubsets(rewriter, forOp); }); - } -}; - -} // namespace - -std::unique_ptr> createVectorizeForCPUPass( - int64_t numElementsThreshold) { - VectorizeForCPUPassOptions opts; - opts.numElementsThreshold = numElementsThreshold; - return std::make_unique(opts); -} - -} // namespace gml_st -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/gml_st/utils/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/gml_st/utils/CMakeLists.txt deleted file mode 100644 index 809a3e3a892f77..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/utils/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -add_mlir_dialect_library(MLIRGmlStUtils - linalg_utils.cc - tensor_utils.cc - - LINK_LIBS PUBLIC - MLIRLinalgDialect - MLIRTensorDialect -) diff --git a/third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.cc b/third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.cc deleted file mode 100644 index f1f1f7c300f559..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.cc +++ /dev/null @@ -1,238 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "gml_st/utils/linalg_utils.h" - -#include - -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" - -namespace mlir::gml_st { -namespace { - -using tensor::CollapseShapeOp; -using tensor::ExpandShapeOp; - -Value collapseDpsInit(OpBuilder &b, Location loc, Value init, - ArrayRef reassociation) { - auto fillOp = init.getDefiningOp(); - if (!fillOp) return b.create(loc, init, reassociation); - - Value collapsedInit = b.create( - loc, fillOp.getOutputs().front(), reassociation); - auto newFill = b.create(loc, fillOp.getInputs(), - ValueRange{collapsedInit}); - return newFill.getResult(0); -} - -} // namespace - -bool isCwiseGenericOp(Operation *op, int64_t *arity) { - auto genericOp = llvm::dyn_cast_or_null(op); - if (!genericOp || genericOp.getNumDpsInits() != 1) return false; - - // Check all-parallel iterator types. - if (!llvm::all_of(genericOp.getIteratorTypesArray(), - linalg::isParallelIterator)) - return false; - - // Check all-identity maps. 
- if (!llvm::all_of(genericOp.getIndexingMapsArray(), - [](AffineMap map) { return map.isIdentity(); })) { - return false; - } - - // Allow for pattern matching the arity. - if (arity != nullptr) *arity = genericOp.getNumDpsInputs(); - return true; -} - -bool isSimpleBcastReduction(Operation *op, int64_t *dimension, - SimpleBcastReduction *chain) { - // Match bcast. - auto broadcastOp = llvm::dyn_cast_or_null(op); - if (!broadcastOp) return false; - - // Match reduction. - auto reduceOp = llvm::dyn_cast_or_null( - broadcastOp.getOperands().front().getDefiningOp()); - if (!reduceOp || reduceOp.getNumDpsInits() != 1) return false; - - // Check that bcast and reduction dimensions match. - auto bcstDimensions = broadcastOp.getDimensions(); - if (!bcstDimensions.empty() && bcstDimensions != reduceOp.getDimensions()) - return false; - - // Allow for pattern matching the reduction dimension and operation chain. - if (dimension != nullptr) *dimension = bcstDimensions.front(); - if (chain != nullptr) { - chain->bcast = op; - chain->reduction = reduceOp; - chain->operand = reduceOp.getInputs().front(); - } - return true; -} - -bool isTransformableIntoMatmul(linalg::Conv2DNhwcHwcfOp convOp) { - if (!convOp.hasTensorSemantics()) return false; - - Value input = convOp.getInputs()[0]; - auto inputType = input.getType().cast(); - - Value kernel = convOp.getInputs()[1]; - auto kernelType = kernel.getType().cast(); - - Value init = convOp.getOutputs()[0]; - auto initType = init.getType().cast(); - - if (!inputType.hasStaticShape() || !kernelType.hasStaticShape() || - !initType.hasStaticShape()) { - return false; - } - - auto allOnes = [](DenseIntElementsAttr attr) { - return attr.isSplat() && attr.getValues()[0] == 1; - }; - if (!allOnes(convOp.getDilations()) || !allOnes(convOp.getStrides())) - return false; - - if (inputType.getDimSize(0) != 1 || inputType.getDimSize(3) != 1 || - kernelType.getDimSize(2) != 1 || initType.getDimSize(0) != 1 || - initType.getDimSize(2) != 1) - return false; - return true; -} - -FailureOr convertConvToMatmul(linalg::Conv2DNhwcHwcfOp convOp, - PatternRewriter &rewriter) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(convOp); - Value input = convOp.getInputs()[0]; - Value kernel = convOp.getInputs()[1]; - Value init = convOp.getOutputs()[0]; - - auto kernelType = kernel.getType().cast(); - if (!isTransformableIntoMatmul(convOp) || kernelType.getDimSize(0) != 1) - return failure(); - - Location loc = convOp.getLoc(); - SmallVector map{{0, 1}, {2, 3}}; - Value newInput = rewriter.create(loc, input, map); - Value newKernel = rewriter.create(loc, kernel, map); - Value newInit = rewriter.create(loc, init, map); - - auto matmul = rewriter.create( - loc, newInit.getType(), ValueRange{newInput, newKernel}, - ValueRange{newInit}); - - rewriter.replaceOpWithNewOp(convOp, convOp.getType(0), - matmul.getResult(0), map); - return matmul; -} - -FailureOr convertBatchMatmulToMatmul( - linalg::BatchMatmulOp batchMatmulOp, PatternRewriter &rewriter) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(batchMatmulOp); - Value lhs = batchMatmulOp.getInputs()[0]; - Value rhs = batchMatmulOp.getInputs()[1]; - Value init = batchMatmulOp.getOutputs()[0]; - - Location loc = batchMatmulOp.getLoc(); - SmallVector map{{0, 1}, {2}}; - Value newLhs = rewriter.create(loc, lhs, map); - Value newRhs = rewriter.create(loc, rhs, map); - Value newInit = collapseDpsInit(rewriter, loc, init, map); - auto matmul = rewriter.create( - loc, newInit.getType(), 
ValueRange{newLhs, newRhs}, ValueRange{newInit}); - - rewriter.replaceOpWithNewOp( - batchMatmulOp, batchMatmulOp.getType(0), matmul.getResult(0), map); - return matmul; -} - -FailureOr convertMatvecToDotOp(PatternRewriter &rewriter, - linalg::MatvecOp matvecOp) { - auto resultType = matvecOp.getType(0).cast(); - if (resultType.getDimSize(0) != 1) return failure(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(matvecOp); - - Location loc = matvecOp.getLoc(); - Value lhs = matvecOp.getInputs().front(); - Value rhs = matvecOp.getInputs().back(); - Value init = matvecOp.getOutputs().front(); - - Value collapsedLhs = - rewriter.create(loc, lhs, ReassociationIndices{{0, 1}}); - Value collapsedInit = collapseDpsInit(rewriter, loc, init, {}); - auto dotOp = rewriter.create(loc, collapsedInit.getType(), - ValueRange{collapsedLhs, rhs}, - ValueRange{collapsedInit}); - Value expandResult = - rewriter.create(loc, init.getType(), dotOp.getResult(0), - ArrayRef{}); - - rewriter.replaceOp(matvecOp, expandResult); - return dotOp; -} - -FailureOr convertDotOpToReduce(linalg::DotOp dotOp, - PatternRewriter &rewriter) { - Location loc = dotOp.getLoc(); - - // Create empty tensor for linalg.map. - Value lhs = dotOp.getInputs().front(); - FailureOr inputSizeOfr = - tensor::getMixedSize(rewriter, loc, lhs, 0); - - if (failed(inputSizeOfr)) { - return rewriter.notifyMatchFailure( - dotOp, "cannot get the size of the input tensor"); - } - - Type elementType = getElementTypeOrSelf(lhs.getType()); - Value emptyTensor = - rewriter.create(loc, *inputSizeOfr, elementType); - - // Create linalg.map. - Operation *arithMul = &dotOp.getBody()->front(); - auto mul = rewriter.create( - loc, dotOp.getOperands().take_front(2), emptyTensor, - [&](OpBuilder &b, Location loc, ValueRange args) { - auto *n = mlir::clone(b, arithMul, arithMul->getResultTypes(), - args.take_front(2)); - b.create(loc, n->getResults()); - }); - - // Create linalg.reduce. - Operation *arithAdd = &(*std::next(dotOp.getBody()->begin())); - auto add = rewriter.create( - loc, ValueRange{mul.getResult()}, ValueRange{dotOp.getOperand(2)}, - SmallVector{0}, - [&](OpBuilder &b, Location loc, ValueRange args) { - auto *n = mlir::clone(b, arithAdd, arithAdd->getResultTypes(), - {args[1], args[0]}); - b.create(loc, n->getResults()); - }); - - rewriter.replaceOp(dotOp, add->getResults()); - return add; -} - -} // namespace mlir::gml_st diff --git a/third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.h b/third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.h deleted file mode 100644 index aef0a7603ad3f7..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/utils/linalg_utils.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef MLIR_HLO_GML_ST_UTILS_LINALG_UTILS_H -#define MLIR_HLO_GML_ST_UTILS_LINALG_UTILS_H - -#include "mlir/Dialect/Linalg/IR/Linalg.h" - -namespace mlir::gml_st { - -// Helper functions to match Linalg ops that implement simple reductions, -// bcasts, and cwise ops. - -struct SimpleBcastReduction { - Operation *bcast; - Operation *reduction; - Value operand; -}; - -bool isSimpleBcastReduction(Operation *op, int64_t *dimension = nullptr, - SimpleBcastReduction *chain = nullptr); - -// The Conv2D is transformable into a matmul, if it has the following shape -// -// linalg.conv_2d_nhwc_hwcf -// {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} -// ins(%input, %kernel : tensor<1x(N+L-1)xKx1xf32>, tensor) -// outs(%fill : tensor<1xNx1xM>) -> tensor<1xNx1xMxf32> -// -// in that case we can tile w.r.t. L to bring it to the following form -// -// linalg.conv_2d_nhwc_hwcf -// {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} -// ins(%input, %kernel : tensor<1xNxKx1xf32>, tensor<1xKx1xMxf32>) -// outs(%fill : tensor<1xNx1xM>) -> tensor<1xNx1xMxf32> -bool isTransformableIntoMatmul(linalg::Conv2DNhwcHwcfOp convOp); - -// linalg.conv_2d_nhwc_hwcf -// {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} -// ins(%input, %kernel : tensor<1xNxKx1xf32>, tensor<1xKx1xMxf32>) -// outs(%fill : tensor<1xNx1xM>) -> tensor<1xNx1xMxf32> -// -// into -// -// linalg.matmul -// ins(%lhs, %rhs : tensor, tensor) -// outs(%fill : tensor) -> tensor<1xNx1xMxf32> -FailureOr convertConvToMatmul(linalg::Conv2DNhwcHwcfOp convOp, - PatternRewriter &rewriter); - -// Converts linalg.batch_matmul into linalg.matmul. -FailureOr convertBatchMatmulToMatmul( - linalg::BatchMatmulOp batchMatmulOp, PatternRewriter &rewriter); - -// Converts linalg.matvec into linalg.dot. -FailureOr convertMatvecToDotOp(PatternRewriter &rewriter, - linalg::MatvecOp matvecOp); - -// Converts linalg.dot into linalg.reduce(linalg.map). -FailureOr convertDotOpToReduce(linalg::DotOp dotOp, - PatternRewriter &rewriter); - -} // namespace mlir::gml_st - -#endif // MLIR_HLO_GML_ST_UTILS_LINALG_UTILS_H diff --git a/third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.cc b/third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.cc deleted file mode 100644 index 50e6cd31c495de..00000000000000 --- a/third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "gml_st/utils/tensor_utils.h" - -namespace mlir::gml_st { - -// Returns ids of size-1 dims that were expanded or collapsed by -// tensor.expand_shape/tensor.collapse_shape. 
-SmallVector<int64_t> getPreservedDimensions(
-    ArrayRef<int64_t> shape,
-    ArrayRef<ReassociationIndices> reassociationIndices) {
-  SmallVector<int64_t> result;
-  for (ReassociationIndicesRef indices : reassociationIndices) {
-    const auto* findIt =
-        llvm::find_if(indices, [&](int64_t idx) { return shape[idx] != 1; });
-    result.push_back(findIt == indices.end() ? 0 : *findIt);
-  }
-  return result;
-}
-
-}  // namespace mlir::gml_st
diff --git a/third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.h b/third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.h
deleted file mode 100644
index 0d0de996984ef0..00000000000000
--- a/third_party/xla/xla/mlir_hlo/gml_st/utils/tensor_utils.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef MLIR_HLO_GML_ST_UTILS_TENSOR_UTILS_H
-#define MLIR_HLO_GML_ST_UTILS_TENSOR_UTILS_H
-
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
-
-namespace mlir {
-namespace gml_st {
-
-// TODO(vuson): maybe overload this function instead of templating it.
-// Check if the reshape operation is only expanding into/collapsing of
-// unit-dimension.
-template <typename TensorReshapeOp>
-bool isDegenerateReshapeOp(TensorReshapeOp reshapeOp) {
-  constexpr bool isExpanding =
-      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
-  llvm::ArrayRef<int64_t> expandedShape =
-      (isExpanding ? reshapeOp.getResultType().getShape()
-                   : reshapeOp.getSrcType().getShape());
-  for (auto& indices : reshapeOp.getReassociationIndices()) {
-    // For each reassociation indices, a degenerate reshape op only has at most
-    // 1 non-unit-dimension, i.e. number of unit-dimensions is greater or equal
-    // to the indices size - 1.
-    if (static_cast<size_t>(
-            llvm::count_if(indices, [&expandedShape](int64_t idx) {
-              return expandedShape[idx] == 1;
-            })) < indices.size() - 1)
-      return false;
-  }
-  return true;
-}
-
-// Returns ids of size-1 dims that were expanded or collapsed by
-// tensor.expand_shape/tensor.collapse_shape.
-SmallVector getPreservedDimensions( - ArrayRef shape, - ArrayRef reassociationIndices); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_GML_ST_UTILS_TENSOR_UTILS_H diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 438013e4b3bcc1..1bd60985e1a0d3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -107,26 +107,6 @@ add_mlir_library(MhloPasses StablehloBroadcastUtils ) -add_mlir_library(MhloToThloConversion - legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc - - DEPENDS - MLIRMhloPassIncGen - THLODialect - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MhloDialect - MhloToArithmeticConversion - MhloTypeConversion - THLODialect - MLIRIR - MLIRMhloUtils - MLIRPass - MLIRTransformUtils -) add_mlir_library(MhloToMemrefConversion hlo_legalize_to_memref/hlo_legalize_to_memref.cc @@ -316,21 +296,7 @@ target_link_libraries(AllMhloPasses INTERFACE MhloToStandard HloToLinalgUtils MhloToLinalg - MhloToThloConversion MhloShapeOpsToStandard MhloToStablehlo StablehloToMhlo ) - -add_library(AllGmlStPasses INTERFACE) -target_link_libraries(AllGmlStPasses INTERFACE - GmlStPasses - GmlStTestPasses - MLIRFuncDialect - MLIRPass -) - -add_library(AllThloPasses INTERFACE) -target_link_libraries(AllThloPasses INTERFACE - ThloPasses -) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc deleted file mode 100644 index f40b19ffb5fb55..00000000000000 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc +++ /dev/null @@ -1,465 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" -#include "mhlo/utils/legalize_to_linalg_utils.h" -#include "mhlo/utils/mhlo_scatter_gather_utils.h" -#include "mhlo/utils/type_conversion.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir { -namespace mhlo { - -#define GEN_PASS_DEF_LEGALIZEMHLOTOTHLOPASS -#include "mhlo/transforms/mhlo_passes.h.inc" - -namespace { - -Value castToIndex(OpBuilder& b, Location loc, TensorType originalType, - Value value) { - Type elementTy = originalType.getElementType(); - if (elementTy.isIndex()) return value; - - Type indexType = b.getIndexType(); - Value emptyTensor = b.create( - loc, tensor::getMixedSizes(b, loc, value), indexType); - - auto map = b.create( - loc, value, emptyTensor, - [&](OpBuilder& nestedB, Location loc, ValueRange args) { - Value elem = args.front(); - Value res = - elementTy.isUnsignedInteger() - ? nestedB.create(loc, indexType, elem) - .getResult() - : nestedB.create(loc, indexType, elem) - .getResult(); - - b.create(loc, res); - }); - return map->getResult(0); -} - -struct ConcatenateOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::ConcatenateOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - const int64_t concatDim = op.getDimension(); - const Location loc = op.getLoc(); - const Value anyOperand = adaptor.getVal().front(); - - auto resultTy = typeConverter->convertType(op.getResult().getType()) - .cast(); - const ArrayRef resultShape = resultTy.getShape(); - const int64_t rank = resultTy.getRank(); - - // Determine empty tensor size. - SmallVector staticInitSizes(resultShape.begin(), - resultShape.end()); - SmallVector dynamicInitSizes; - for (int64_t i = 0; i < rank; ++i) { - // No need to materialize anything for static dimensions. - if (staticInitSizes[i] != ShapedType::kDynamic) { - continue; - } - - // For all dimensions other than the concatenation dimension, we can copy - // the size from any operand. - if (i != static_cast(concatDim)) { - dynamicInitSizes.push_back( - rewriter.create(loc, anyOperand, i)); - continue; - } - - // For the concatenation dimensions, sum up the sizes of all operands in - // that dimension. 
- int64_t staticSum = 0; - Value dynamicSum; - for (const Value operand : adaptor.getVal()) { - auto operandTy = operand.getType().cast(); - if (operandTy.getDimSize(concatDim) == ShapedType::kDynamic) { - const Value dynamicSummand = - rewriter.create(loc, operand, concatDim); - if (dynamicSum) { - dynamicSum = - rewriter.create(loc, dynamicSum, dynamicSummand); - } else { - dynamicSum = dynamicSummand; - } - } else { - staticSum += operandTy.getDimSize(concatDim); - } - } - assert(dynamicSum && "expect at least one dynamic summand in this case"); - if (staticSum != 0) { - dynamicSum = rewriter.create( - loc, dynamicSum, - rewriter.create(loc, staticSum)); - } - dynamicInitSizes.push_back(dynamicSum); - } - - // Create empty tensor and the new concat op. - auto emptyTensor = rewriter.create( - loc, staticInitSizes, resultTy.getElementType(), dynamicInitSizes); - rewriter.replaceOpWithNewOp( - op, resultTy, adaptor.getVal(), emptyTensor, - rewriter.getIndexAttr(concatDim)); - return success(); - } -}; - -struct DynamicBroadcastInDimOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - auto loc = op.getLoc(); - Value outputDimensions = adaptor.getOutputDimensions(); - auto operandTy = adaptor.getOperand().getType().cast(); - auto resultTy = - typeConverter->convertType(op.getType()).cast(); - - // Only apply to broadcasts that cannot be lowered to linalg, i.e. those - // for which we do not know their expansion behavior at compile time. - int64_t countKnownExpansionBehavior = 0; - if (auto expandingDims = op.getKnownExpandingDimensions()) { - countKnownExpansionBehavior += expandingDims->size(); - } - if (auto nonexpandingDims = op.getKnownNonexpandingDimensions()) { - countKnownExpansionBehavior += nonexpandingDims->size(); - } - if (operandTy.getRank() == countKnownExpansionBehavior) return failure(); - - // Create empty tensor as none of the operands are reusable/updatable. - SmallVector dynamicDims; - SmallVector staticShapeInfo; - for (int i = 0; i < resultTy.getRank(); i++) { - dynamicDims.push_back(rewriter.create( - loc, outputDimensions, - ValueRange{rewriter.create(loc, i)})); - staticShapeInfo.push_back(ShapedType::kDynamic); - } - auto emptyTensor = rewriter.create( - loc, staticShapeInfo, resultTy.getElementType(), dynamicDims); - - auto broadcastDims = rewriter.getDenseI64ArrayAttr( - llvm::to_vector(op.getBroadcastDimensions().getValues())); - - DenseI64ArrayAttr knownExpandingDims; - if (op.getKnownExpandingDimensions().has_value()) { - knownExpandingDims = rewriter.getDenseI64ArrayAttr(llvm::to_vector( - op.getKnownExpandingDimensionsAttr().getValues())); - } - DenseI64ArrayAttr knownNonexpandingDims; - if (op.getKnownNonexpandingDimensions().has_value()) { - knownNonexpandingDims = rewriter.getDenseI64ArrayAttr(llvm::to_vector( - op.getKnownNonexpandingDimensionsAttr().getValues())); - } - - rewriter.replaceOpWithNewOp( - op, resultTy, adaptor.getOperand(), emptyTensor, broadcastDims, - knownExpandingDims, knownNonexpandingDims); - return success(); - } -}; - -// Rewrites simple gather patterns (as checked below). 
-struct GatherPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::GatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - if (!isCanonicalGather(op)) return failure(); - auto startIndicesType = - adaptor.getStartIndices().getType().dyn_cast(); - auto operandType = - adaptor.getOperand().getType().dyn_cast(); - - if (!startIndicesType || !operandType) return failure(); - - auto resultType = - typeConverter->convertType(op.getType()).cast(); - SmallVector sizes; - sizes.reserve(resultType.getRank()); - if (resultType.getDimSize(0) != ShapedType::kDynamic) { - sizes.push_back(rewriter.getI64IntegerAttr(resultType.getDimSize(0))); - } else { - sizes.push_back( - rewriter - .create(op.getLoc(), adaptor.getStartIndices(), 0) - .getResult()); - } - llvm::copy(op.getSliceSizes().getValues(), - std::back_inserter(sizes)); - - auto emptyTensor = rewriter.create( - op.getLoc(), sizes, resultType.getElementType()); - rewriter.replaceOpWithNewOp( - op, resultType, adaptor.getOperand(), - castToIndex(rewriter, op.getLoc(), op.getStartIndices().getType(), - adaptor.getStartIndices()), - emptyTensor); - return success(); - } -}; - -bool isInBodyOfThloOp(Operation* op) { - auto* parentOp = op->getParentRegion()->getParentOp(); - return isa(*parentOp) || isa(*parentOp); -} - -// Rewrites a mhlo::ReturnOp inside a thlo::ReductionOp to thlo::YieldOp. -struct ThloRegionReturnOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - if (!isInBodyOfThloOp(op)) return failure(); - SmallVector operands(adaptor.getOperands()); - auto loc = op.getLoc(); - for (size_t i = 0; i < operands.size(); ++i) { - if (operands[i].getType().isa()) { - operands[i] = rewriter.create(loc, operands[i]); - } - } - rewriter.replaceOpWithNewOp(op, operands); - return success(); - } -}; - -struct ScatterPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::ScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - // Only canonicalized single-result scatter ops are supported. - if (!isCanonicalScatter(op) || op.getNumResults() != 1) return failure(); - - auto opType = - typeConverter->convertType(op.getType(0)).dyn_cast(); - if (!opType) return failure(); - - Location loc = op.getLoc(); - auto thloScatter = rewriter.create( - loc, opType, - castToIndex(rewriter, loc, op.getScatterIndices().getType(), - adaptor.getScatterIndices()), - adaptor.getUpdates().front(), adaptor.getInputs().front()); - - Region& region = thloScatter.getUpdateComputation(); - rewriter.inlineRegionBefore(op.getRegion(), region, region.end()); - - // Convert the signature of the body by inserting - // tensor.from_elements/tensor.extract. 
- TypeConverter::SignatureConversion signatureConverter(2); - for (const auto& [idx, val] : llvm::enumerate( - thloScatter.getUpdateComputation().getArgumentTypes())) { - signatureConverter.addInputs( - 1 - idx, typeConverter->convertType( - val.cast().getElementType())); - } - rewriter.applySignatureConversion(®ion, signatureConverter, - getTypeConverter()); - - rewriter.replaceOp(op, thloScatter.getResults()); - return success(); - } -}; - -struct SortPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::SortOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - Location loc = op.getLoc(); - - SmallVector outputs; - SmallVector operandTypes; - SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - - for (auto [operand, resultType] : - llvm::zip(adaptor.getInputs(), resultTypes)) { - RankedTensorType operandType = - operand.getType().dyn_cast(); - if (!operandType) - return rewriter.notifyMatchFailure(op, "expects known-rank operands"); - operandTypes.push_back(operandType); - auto tensorResultType = resultType.cast(); - - Value emptyTensor = rewriter.create( - loc, tensorResultType.getShape(), tensorResultType.getElementType()); - - outputs.push_back(emptyTensor); - } - - int64_t dimension = op.getDimension(); - // TODO(bchetioui): MHLO accepts dimensions in the range [-rank, rank), - // while THLO accepts only dimensions in the range [0, rank). Ideally, they - // should agree on the range of acceptable arguments, but while it is not - // the case, this is a (reliable) workaround. - if (dimension < 0) dimension = dimension + operandTypes.front().getRank(); - bool isStable = op.getIsStable(); - - auto thloSort = rewriter.create( - loc, resultTypes, adaptor.getInputs(), outputs, - rewriter.getIndexAttr(dimension), rewriter.getBoolAttr(isStable)); - - Region& region = thloSort.getComparator(); - rewriter.inlineRegionBefore(op.getComparator(), region, region.end()); - - assert(thloSort.getNumDpsInputs() == thloSort.getNumDpsInits()); - - // Convert the signature of the comparator. 
- TypeConverter::SignatureConversion signatureConverter( - thloSort.getNumDpsInputs() * 2); - for (const auto& [idx, val] : llvm::enumerate(operandTypes)) { - signatureConverter.addInputs( - /*origInputNo=*/2 * idx, - typeConverter->convertType(val.getElementType())); - signatureConverter.addInputs( - /*origInputNo=*/2 * idx + 1, - typeConverter->convertType(val.getElementType())); - } - - rewriter.applySignatureConversion(®ion, signatureConverter, - getTypeConverter()); - - rewriter.replaceOp(op, thloSort.getResults()); - return success(); - } -}; - -struct ReversePattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::ReverseOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - auto reverseDimensions = - llvm::to_vector(op.getDimensions().getValues()); - Type resultType = typeConverter->convertType(op->getResultTypes()[0]); - if (!resultType) - return rewriter.notifyMatchFailure(op, "failed to convert result type"); - Location loc = op.getLoc(); - auto operandType = - adaptor.getOperand().getType().dyn_cast(); - if (!operandType) - return rewriter.notifyMatchFailure(op, "expects known-rank operand"); - auto tensorResultType = resultType.cast(); - SmallVector dynShape = - tensor::createDynamicDimValues(rewriter, loc, adaptor.getOperand()); - Value initTensor = rewriter.create( - loc, tensorResultType.getShape(), tensorResultType.getElementType(), - dynShape); - rewriter.replaceOpWithNewOp( - op, resultType, adaptor.getOperand(), initTensor, reverseDimensions); - return success(); - } -}; - -class LegalizeMHLOToTHLOPass - : public impl::LegalizeMHLOToTHLOPassBase { - public: - explicit LegalizeMHLOToTHLOPass(bool enableExperimentalOps) { - enableExperimental = enableExperimentalOps; - } - - private: - void runOnOperation() final { - MLIRContext* ctx = &getContext(); - RewritePatternSet patterns(ctx); - ConversionTarget target(*ctx); - // clang-format off - target.addLegalDialect< - arith::ArithDialect, - complex::ComplexDialect, - linalg::LinalgDialect, - math::MathDialect, - shape::ShapeDialect, - tensor::TensorDialect, - thlo::THLODialect>(); - // clang-format on - target.addLegalOp(); - - auto typeConverter = std::make_unique(); - - populateScalarHloToArithmeticConversionPatterns( - ctx, *typeConverter, &patterns, - [](Operation* op) { return isInBodyOfThloOp(op); }); - - // List of patterns. 
- // clang-format off - patterns.insert< - ConcatenateOpPattern, - GatherPattern, - ReversePattern, - ScatterPattern, - SortPattern, - ThloRegionReturnOpConversion>(*typeConverter, ctx); - // clang-format on - - if (enableExperimental) { - patterns.insert(*typeConverter, ctx); - } - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> createLegalizeMHLOToTHLOPass( - bool enableExperimentalOps) { - return std::make_unique(enableExperimentalOps); -} - -} // namespace mhlo -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index 32c5822b52de82..e29f17382437d3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -146,22 +146,6 @@ def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "func::FuncOp"> { "transpose) when possible, instead of linalg.generic">]; } -def LegalizeMHLOToTHLOPass : Pass<"legalize-mhlo-to-thlo", "mlir::func::FuncOp"> { - let summary = "Legalize from HLO dialect to tHLO dialect."; - let constructor = "::mlir::mhlo::createLegalizeMHLOToTHLOPass()"; - let options = - [Option<"enableExperimental", "enable-experimental", "bool", - /*default=*/"false", - "Enable conversion to operations that are still under " - "developement and might not be working in some pipelines. For " - "example, thlo.map and thlo.transpose.">]; - let dependentDialects = [ - "arith::ArithDialect", "complex::ComplexDialect", - "linalg::LinalgDialect", "math::MathDialect", "shape::ShapeDialect", - "tensor::TensorDialect", "thlo::THLODialect" - ]; -} - def HloLegalizeShapeComputationsPass : Pass<"hlo-legalize-shape-computations", "func::FuncOp"> { let summary = "Legalize HLOs shape operations to core-mlir operations."; let constructor = "createLegalizeShapeComputationsPass()"; diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir deleted file mode 100644 index cc3690fea32e95..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: mlir-hlo-opt %s --add-debug-info --mlir-print-debuginfo | FileCheck %s - -builtin.module { - func.func @foo() { - return - } -} - -// CHECK: module -// CHECK: func.func @[[SUBPROGRAM_NAME:.*]]() { -// CHECK: return loc(#[[RET_LOC:.*]]) -// CHECK: } loc(#[[FUSED_SUBPROGRAM_LOC:.*]]) -// CHECK: } loc(#[[MODULE_LOC:.*]]) -// CHECK: #di_basic_type = #llvm.di_basic_type -// CHECK: #di_file = #llvm.di_file<"[[FILE_NAME:.*]]" in "[[DIR_NAME:.*]]"> -// CHECK: #[[MODULE_LOC]] = loc("[[DIR_NAME]]/[[FILE_NAME]]":[[#MODULE_LINE:]]:1) -// CHECK: #[[SUBPROGRAM_LOC:.*]] = loc("[[DIR_NAME]]/[[FILE_NAME]]":[[#MODULE_LINE+1]]:3) -// CHECK: #[[RET_LOC]] = loc("[[DIR_NAME]]/[[FILE_NAME]]":[[#MODULE_LINE+2]]:5) -// CHECK: #di_compile_unit = #llvm.di_compile_unit -// CHECK: #di_subroutine_type = #llvm.di_subroutine_type -// CHECK: #di_subprogram = #llvm.di_subprogram -// CHECK: #[[FUSED_SUBPROGRAM_LOC]] = loc(fused<#di_subprogram>[#[[SUBPROGRAM_LOC]]]) \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir deleted file mode 100644 index 9c4fd72e90b0df..00000000000000 --- 
a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir +++ /dev/null @@ -1,303 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-rewrite-from-elements-ops \ -// RUN: -eliminate-empty-tensors -empty-tensor-to-alloc-tensor \ -// RUN: -hlo-one-shot-bufferize -canonicalize -cse -canonicalize \ -// RUN: -split-input-file | FileCheck %s - -func.func @set_tile(%input: tensor) -> tensor<2x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %dim_0 = tensor.dim %input, %c0 : tensor - %dim_1 = tensor.dim %input, %c1 : tensor - - %slice = tensor.extract_slice %input[0, 1][2, 4][1, 1] - : tensor to tensor<2x4xf32> - - return %slice : tensor<2x4xf32> -} -// CHECK-LABEL: func @set_tile( -// CHECK-SAME: %[[ARG:.*]]: memref) -// CHECK-NEXT: %[[VIEW:.*]] = memref.subview %[[ARG]][0, 1] [2, 4] [1, 1] -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<2x4xf32> -// CHECK-NEXT: memref.copy %[[VIEW]], %[[ALLOC]] -// CHECK-NEXT: return %[[ALLOC]] : memref<2x4xf32> - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @parallel_with_tiles(%lhs: tensor, %rhs: tensor, - %out : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %dim_0 = tensor.dim %lhs, %c0 : tensor - %dim_1 = tensor.dim %lhs, %c1 : tensor - - %result = scf.forall (%i, %j) = (%c0, %c0) to (%dim_0, %dim_1) - step (%c4, %c1) shared_outs (%out_ = %out) -> (tensor) { - %7 = arith.addi %i, %c4 : index - %8 = arith.cmpi sgt, %7, %dim_0 : index - %9 = arith.subi %dim_0, %i : index - %size_0 = arith.select %8, %9, %c4 : index - - %lhs_tile = tensor.extract_slice %lhs[%i, %j] [%size_0, 1] [1, 1] - : tensor to tensor - %rhs_tile = tensor.extract_slice %rhs[%i, %j] [%size_0, 1] [1, 1] - : tensor to tensor - %init_tile = tensor.extract_slice %out_[%i, %j] [%size_0, 1] [1, 1] - : tensor to tensor - %sum = linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%lhs_tile, %rhs_tile : tensor, tensor) - outs(%init_tile : tensor) { - ^bb0(%l: f32, %r: f32, %o: f32): - %add = arith.addf %l, %r : f32 - linalg.yield %add : f32 - } -> tensor - scf.forall.in_parallel { - tensor.parallel_insert_slice %sum into %out_[%i, %j] [%size_0, 1] [1, 1] - : tensor into tensor - } - } - return %result : tensor -} -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: func @parallel_with_tiles( -// CHECK-SAME: %[[LHS:.*]]: memref, %[[RHS:.*]]: memref, -// CHECK-SAME: %[[OUT:.*]]: memref) -> memref { - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK: %[[DIM_0:.*]] = memref.dim %[[LHS]], %[[C0]] : memref -// CHECK: %[[DIM_1:.*]] = memref.dim %[[LHS]], %[[C1]] : memref - -// CHECK: scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) -// CHECK-SAME: to (%[[DIM_0]], %[[DIM_1]]) step (4, 1) { - -// CHECK-DAG: %[[LHS_SUB:.*]] = memref.subview %[[LHS]][%[[I]], %[[J]]] -// CHECK-SAME: : memref to memref> -// CHECK-DAG: %[[RHS_SUB:.*]] = memref.subview %[[RHS]][%[[I]], %[[J]]] -// CHECK-SAME: : memref to memref> -// CHECK-DAG: %[[OUT_SUB:.*]] = memref.subview %[[OUT]][%[[I]], %[[J]]] -// CHECK-SAME: : memref to memref> - -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]] -// CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : memref> -// CHECK-SAME: outs(%[[OUT_SUB]] : memref>) -// CHECK: } -// CHECK: return %[[OUT]] : memref - -// ----- - -func.func 
@materialize_and_yield_with_constants( - %in: tensor<8x2xf32>, %out: tensor<8x2xf32>) -> tensor<8x2xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - - %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c8, %c2) step (%c1, %c1) - shared_outs (%out_ = %out) -> (tensor<8x2xf32>) { - %2 = tensor.extract_slice %in[%i, %j] [1, 1] [1, 1] - : tensor<8x2xf32> to tensor<1x1xf32> - %3 = tensor.extract %2[%c0, %c0] : tensor<1x1xf32> - %4 = math.absf %3: f32 - %5 = tensor.from_elements %4 : tensor - scf.forall.in_parallel { - tensor.parallel_insert_slice %5 into %out_[%i, %j] [1, 1] [1, 1] - : tensor into tensor<8x2xf32> - } - } - return %1 : tensor<8x2xf32> -} -// CHECK-LABEL: func @materialize_and_yield_with_constants -// CHECK-SAME: %[[IN:.*]]: memref<8x2xf32>, %[[OUT:.*]]: memref<8x2xf32>) - -// CHECK: scf.forall (%[[I:.*]], %[[J:.*]]) in (8, 2) -// CHECK-NEXT: %[[SLICE:.*]] = memref.subview %[[IN]][%[[I]], %[[J]]] -// CHECK-NEXT: %[[ELEM:.*]] = memref.load %[[SLICE]] -// CHECK-NEXT: %[[ABS:.*]] = math.absf %[[ELEM]] : f32 -// CHECK-NEXT: %[[OUT_SLICE:.*]] = memref.subview %[[OUT]] -// CHECK-SAME: [%[[I]], %[[J]]] [1, 1] [1, 1] -// CHECK-NEXT: memref.store %[[ABS]], %[[OUT_SLICE]][] - -// ----- - -func.func @same_enclosing_repetitive_region(%2: tensor<320xf32>, - %3: tensor<320x10240xf32>) - -> tensor<320xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant -0.000000e+00 : f32 - %c320 = arith.constant 320 : index - %4 = scf.forall (%i) = (%c0) to (%c320) step (%c1) - shared_outs(%arg1 = %2) -> (tensor<320xf32>) { - %5 = tensor.extract_slice %3[%i, 0] [1, 10240] [1, 1] : tensor<320x10240xf32> to tensor<1x10240xf32> - %6 = tensor.extract_slice %arg1[%i] [1] [1] : tensor<320xf32> to tensor<1xf32> - %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1xf32>) -> tensor<1xf32> - %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1xf32>) -> tensor<1xf32> - - scf.forall.in_parallel { - tensor.parallel_insert_slice %8 into %arg1[%i] [1] [1] - : tensor<1xf32> into tensor<320xf32> - } - } - return %4 : tensor<320xf32> -} -// CHECK-LABEL: @same_enclosing_repetitive_region -// CHECK-NOT: memref.alloc - -// ----- - -// CHECK-LABEL: func @scf.forall_private_var( -// CHECK-SAME: %[[t:.*]]: memref<10xf32 -func.func @scf.forall_private_var(%t: tensor<10xf32>) -> f32 { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c5 = arith.constant 5 : index - - // A copy is inserted for the uses of %t in the loop. - // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32> - // CHECK: memref.copy %[[t]], %[[t_copy]] - - // CHECK: scf.forall - - // Load from the copy and store into the shared output. 
- // CHECK: %[[subview:.*]] = memref.subview %[[t_copy]] - // CHECK: memref.load %[[t]] - // CHECK: memref.store %{{.*}}, %[[subview]] - %0 = scf.forall (%tid) = (%c0) to (%c2) step (%c1) - shared_outs (%o = %t) -> (tensor<10xf32>) { - %offset = arith.muli %c5, %tid : index - %slice = tensor.extract_slice %o[%offset] [5] [1] - : tensor<10xf32> to tensor<5xf32> - %r2 = tensor.extract %t[%tid] : tensor<10xf32> - %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32> - - scf.forall.in_parallel { - tensor.parallel_insert_slice %i into %o[%offset][5][1] - : tensor<5xf32> into tensor<10xf32> - } - } - %r = tensor.extract %0[%c2] : tensor<10xf32> - return %r : f32 -} - -// ----- - -func.func @gml_st_fusion(%arg0: tensor, - %init: tensor) -> tensor { - %0 = gml_st.fusion ins(%a0 = %arg0 : tensor) - inits(%in = %init : tensor) { - %res = linalg.map { math.exp } - ins(%a0 : tensor) - outs(%in : tensor) - gml_st.yield %res : tensor - } : tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @gml_st_fusion -// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref -// CHECK: gml_st.fusion -// CHECK-SAME: ins(%[[ARG0_:.*]] = %[[ARG0]]: memref) -// CHECK-SAME: inits(%[[ARG1_:.*]] = %[[ARG1]]: memref) -// CHECK: linalg.map { math.exp } -// CHECK-SAME: ins(%[[ARG0_]] : memref) -// CHECK-SAME: outs(%[[ARG1_]] : memref) -// CHECK: gml_st.yield %[[ARG1_]] : memref -// CHECK: return %[[ARG1]] : memref - -// ----- - -func.func @gml_st_fusion_temp_tensor( - %arg0: tensor, %arg1: tensor) -> tensor { - %c0 = arith.constant 0 : index - %dim0 = tensor.dim %arg0, %c0 : tensor - %init = tensor.empty(%dim0) : tensor - %0 = gml_st.fusion ins(%arg0_ = %arg0 : tensor, - %arg1_ = %arg1 : tensor) - inits(%init_ = %init : tensor) { - %c0_ = arith.constant 0 : index - %dim0_ = tensor.dim %arg0_, %c0_ : tensor - %temp = tensor.empty(%dim0_) : tensor - %map0 = linalg.map { math.exp } - ins(%arg0_ : tensor) - outs(%temp : tensor) - %map1 = linalg.map { arith.mulf } - ins(%map0, %arg1_ : tensor, tensor) - outs(%init_ : tensor) - gml_st.yield %map1 : tensor - } : tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @gml_st_fusion_temp_tensor -// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) -// CHECK: gml_st.fusion -// CHECK-SAME: ins(%[[ARG0_:.*]] = %[[ARG0]]: memref, -// CHECK-SAME: %[[ARG1_:.*]] = %[[ARG1]]: memref) -// CHECK-SAME: inits(%[[INIT_:.*]] = %[[ALLOC]]: memref) -// CHECK-DAG: %[[C0_:.*]] = arith.constant 0 : index -// CHECK: %[[DIM_:.*]] = memref.dim %[[ARG0_]], %[[C0_]] -// CHECK: %[[ALLOC_:.*]] = memref.alloc(%[[DIM_]]) -// CHECK: linalg.map { math.exp } -// CHECK-SAME: ins(%[[ARG0_]] -// CHECK-SAME: outs(%[[ALLOC_]] -// CHECK: linalg.map { arith.mulf } -// CHECK-SAME: ins(%[[ALLOC_]], %[[ARG1_]] -// CHECK-SAME: outs(%[[INIT_]] -// CHECK: gml_st.yield %[[INIT_]] : memref -// CHECK: return %[[ALLOC]] : memref - -// ----- - -func.func @gml_st_fusion_scalar_scf_for(%arg0: tensor) -> tensor { - %0 = tensor.empty() : tensor - %1 = gml_st.fusion - ins(%arg1 = %arg0: tensor) - inits(%arg2 = %0: tensor) { - %c1_i64 = arith.constant 1 : i64 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %dim = tensor.dim %arg1, %c0 : tensor - %2 = scf.for %arg3 = %c0 to %dim step %c1 - iter_args(%arg4 = %c1_i64) -> (i64) { - %extracted = tensor.extract %arg1[%arg3] : tensor - %3 = arith.muli %arg4, %extracted : i64 - scf.yield %3 : 
i64 - } - %from_elements = tensor.from_elements %2 : tensor - gml_st.yield %from_elements : tensor - } : tensor - return %1 : tensor -} - -// CHECK-LABEL: func.func @gml_st_fusion_scalar_scf_for -// CHECK-SAME: (%[[ARG0:.*]]: memref) -// CHECK: %[[ALLOC:.*]] = memref.alloc() -// CHECK: gml_st.fusion -// CHECK-SAME: ins(%[[ARG0_:.*]] = %[[ARG0]]: memref) -// CHECK-SAME: inits(%[[ALLOC_:.*]] = %[[ALLOC]]: memref) -// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM:.*]] = memref.dim %[[ARG0_]], %[[C0]] -// CHECK: %[[FOR:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[DIM]] -// CHECK-SAME: step %[[C1]] iter_args(%[[ARG4:.*]] = %[[C1_I64]]) -// CHECK: %[[LOAD:.*]] = memref.load %[[ARG0_]][%[[ARG3]]] -// CHECK: %[[MULI:.*]] = arith.muli %[[ARG4]], %[[LOAD]] -// CHECK: scf.yield %[[MULI]] : i64 -// CHECK: %[[ALLOC_0:.*]] = memref.alloc() -// CHECK: memref.store %[[FOR]], %[[ALLOC_0]][] -// CHECK: memref.copy %[[ALLOC_0]], %[[ALLOC_]] -// CHECK: gml_st.yield %[[ALLOC_]] : memref -// CHECK: return %[[ALLOC]] : memref diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir deleted file mode 100644 index a67378fb603e8d..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir +++ /dev/null @@ -1,288 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-collapse-shape | FileCheck %s - -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-collapse-shape="retain-trailing-dims=1" | \ -// RUN: FileCheck %s --check-prefix=CHECK-1 - -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-collapse-shape="retain-trailing-dims=2" | \ -// RUN: FileCheck %s --check-prefix=CHECK-2 - -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-collapse-shape="retain-trailing-dims=3" | \ -// RUN: FileCheck %s --check-prefix=CHECK-3 - -func.func @bcast(%arg0: tensor<2x4x2048xf32>) -> tensor<2x4x2048x4096xf32> { - %0 = tensor.empty() : tensor<2x4x2048x4096xf32> - %1 = linalg.broadcast - ins(%arg0 : tensor<2x4x2048xf32>) - outs(%0 : tensor<2x4x2048x4096xf32>) - dimensions = [3] - return %1 : tensor<2x4x2048x4096xf32> -} - -// CHECK: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) -// CHECK-NOT: collapse_shape -// CHECK-NOT: expand_shape - -// CHECK-1: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) -// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-1-SAME: [0, 1, 2]] -// CHECK-1: %[[EMPTY:.*]] = tensor.empty() -// CHECK-1: %[[BROADCAST:.*]] = linalg.broadcast -// CHECK-1: ins(%[[COLLAPSED]] : tensor<16384xf32>) -// CHECK-1: outs(%[[EMPTY]] : tensor<16384x4096xf32>) -// CHECK-1: dimensions = [1] -// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: return %[[EXPANDED]] - -// CHECK-2: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) -// CHECK-2: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-2-SAME: [0, 1], [2]] -// CHECK-2: %[[EMPTY:.*]] = tensor.empty() -// CHECK-2: %[[BROADCASTED:.*]] = linalg.broadcast -// CHECK-2-SAME: ins(%[[COLLAPSED]] : tensor<8x2048xf32>) -// CHECK-2-SAME: outs(%[[EMPTY]] : tensor<8x2048x4096xf32>) -// CHECK-2: dimensions = [2] -// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCASTED]] [ -// CHECK-2-SAME: [0, 1], [2], [3]] -// CHECK-2: return %[[EXPANDED]] - -// CHECK-3: func.func @bcast(%[[ARG0:.*]]: 
tensor<2x4x2048xf32>) -// CHECK-3-NOT: collapse_shape -// CHECK-3-NOT: expand_shape - -// ----- - -func.func @bcast_from_scalar() -> tensor<2x4x2048x4096xf32> { - %0 = tensor.empty() : tensor<2x4x2048x4096xf32> - %cst = arith.constant 0xFF800000 : f32 - %1 = tensor.empty() : tensor - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %3 = linalg.broadcast - ins(%2 : tensor) - outs(%0 : tensor<2x4x2048x4096xf32>) - dimensions = [0, 1, 2, 3] - return %3 : tensor<2x4x2048x4096xf32> -} - -// CHECK: func.func @bcast_from_scalar() -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<67108864xf32> -// CHECK: %[[BROADCAST:.*]] = linalg.broadcast -// CHECK: ins(%{{.*}} : tensor) -// CHECK: outs(%[[EMPTY]] : tensor<67108864xf32>) -// CHECK: dimensions = [0] -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ -// CHECK-SAME: 0, 1, 2, 3]] -// CHECK: return %[[EXPANDED]] - -// CHECK-1: func.func @bcast_from_scalar() -// CHECK-1: %[[EMPTY:.*]] = tensor.empty() : tensor<16384x4096xf32> -// CHECK-1: %[[BROADCAST:.*]] = linalg.broadcast -// CHECK-1-SAME: ins(%{{.*}} : tensor) -// CHECK-1-SAME: outs(%[[EMPTY]] : tensor<16384x4096xf32>) -// CHECK-1-SAME: dimensions = [1, 0] -// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: return %[[EXPANDED]] - -// CHECK-2: func.func @bcast_from_scalar() -// CHECK-2: %[[EMPTY:.*]] = tensor.empty() : tensor<8x2048x4096xf32> -// CHECK-2: %[[BROADCAST:.*]] = linalg.broadcast -// CHECK-2-SAME: ins(%{{.*}} : tensor -// CHECK-2-SAME: outs(%[[EMPTY]] : tensor<8x2048x4096xf32>) -// CHECK-2-SAME: dimensions = [1, 2, 0] -// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ -// CHECK-2-SAME: [0, 1], [2], [3]] -// CHECK-2: return %[[EXPANDED]] - -// CHECK-3: func.func @bcast_from_scalar() -// CHECK-3-NOT: collapse_shape -// CHECK-3-NOT: expand_shape - -// ----- - -func.func @reduction(%arg0: tensor<2x4x2048x4096xf32>) -> tensor<2x4x2048xf32> { - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<2x4x2048xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4x2048xf32>) - -> tensor<2x4x2048xf32> - %2 = linalg.reduce { arith.maximumf } - ins(%arg0 : tensor<2x4x2048x4096xf32>) - outs(%1 : tensor<2x4x2048xf32>) - dimensions = [3] - return %2 : tensor<2x4x2048xf32> -} - -// CHECK: func.func @reduction(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK-NOT: collapse_shape -// CHECK-NOT: expand_shape - -// CHECK-1: func.func @reduction(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK-1-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 -// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: %[[EMPTY:.*]] = tensor.empty() -// CHECK-1: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<16384xf32>) -// CHECK-1: %[[REDUCED:.*]] = linalg.reduce { arith.maximumf } -// CHECK-1-SAME: ins(%[[COLLAPSED]] : tensor<16384x4096xf32>) -// CHECK-1-SAME: outs(%[[FILL]] : tensor<16384xf32>) -// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[REDUCED]] [ -// CHECK-1-SAME: [0, 1, 2]] -// CHECK-1: return %[[EXPANDED]] - - -// ----- - -func.func @cwise(%arg0: tensor<2x4x2048x4096xf32>, - %arg1: tensor<2x4x2048x4096xf32>) -> tensor<2x4x2048x4096xf32> { - %0 = tensor.empty() : tensor<2x4x2048x4096xf32> - %1 = linalg.map { arith.subf } - ins(%arg0, %arg1 : tensor<2x4x2048x4096xf32>, tensor<2x4x2048x4096xf32>) - outs(%0 : tensor<2x4x2048x4096xf32>) - return %1 : tensor<2x4x2048x4096xf32> -} - -// CHECK: 
func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-SAME: [0, 1, 2, 3]] -// CHECK: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG1]] [ -// CHECK-SAME: [0, 1, 2, 3]] -// CHECK: %[[EMPTY:.*]] = tensor.empty() -// CHECK: %[[MAP:.*]] = linalg.map { arith.subf } -// CHECK: ins(%[[COLLAPSED]], %[[COLLAPSED_0]] : tensor<67108864xf32>, tensor<67108864xf32>) -// CHECK: outs(%[[EMPTY]] : tensor<67108864xf32>) -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ -// CHECK-SAME: [0, 1, 2, 3]] -// CHECK: return %[[EXPANDED]] - -// CHECK-1: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG1]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: %[[EMPTY:.*]] = tensor.empty() -// CHECK-1: %[[MAP:.*]] = linalg.map { arith.subf } -// CHECK-1-SAME: ins(%[[COLLAPSED]], %[[COLLAPSED_0]] : tensor<16384x4096xf32>, tensor<16384x4096xf32>) -// CHECK-1-SAME outs(%[[EMPTY]] : tensor<16384x4096xf32>) -// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: return %[[EXPANDED]] - -// CHECK-2: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK-2: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-2-SAME: [0, 1], [2], [3]] -// CHECK-2: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG1]] [ -// CHECK-2-SAME: [0, 1], [2], [3]] -// CHECK-2: %[[EMPTY:.*]] = tensor.empty() -// CHECK-2: %[[MAP:.*]] = linalg.map { arith.subf } -// CHECK-2-SAME: ins(%[[COLLAPSED]], %[[COLLAPSED_0]] : tensor<8x2048x4096xf32>, tensor<8x2048x4096xf32>) -// CHECK-2-SAME outs(%[[EMPTY]] : tensor<8x2048x4096xf32>) -// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ -// CHECK-2-SAME: [0, 1], [2], [3]] -// CHECK-2: return %[[EXPANDED]] - -// CHECK-3: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK-3-NOT: collapse_shape -// CHECK-3-NOT: expand_shape - -// ----- - -func.func @partial_softmax(%arg0: tensor<2x4x2048x4096xf32>) - -> tensor<2x4x2048x4096xf32> { - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<2x4x2048xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4x2048xf32>) - -> tensor<2x4x2048xf32> - %2 = linalg.reduce { arith.maximumf } - ins(%arg0 : tensor<2x4x2048x4096xf32>) - outs(%1 : tensor<2x4x2048xf32>) - dimensions = [3] - %3 = tensor.empty() : tensor<2x4x2048x4096xf32> - %4 = linalg.broadcast - ins(%2 : tensor<2x4x2048xf32>) - outs(%3 : tensor<2x4x2048x4096xf32>) - dimensions = [3] - %5 = linalg.map { arith.subf } - ins(%arg0, %4 : tensor<2x4x2048x4096xf32>, tensor<2x4x2048x4096xf32>) - outs(%3 : tensor<2x4x2048x4096xf32>) - return %5 : tensor<2x4x2048x4096xf32> -} - -// CHECK-1: func.func @partial_softmax(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) -// CHECK-1-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 -// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ -// CHECK-1-SAME: [0, 1, 2], [3]] -// CHECK-1: %[[EMPTY:.*]] = tensor.empty() -// CHECK-1: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<16384xf32>) -// CHECK-1: %[[REDUCE:.*]] = linalg.reduce { arith.maximumf } -// CHECK-1-SAME: ins(%[[COLLAPSED]] : tensor<16384x4096xf32>) -// 
CHECK-1-SAME: outs(%[[FILL]] : tensor<16384xf32>)
-// CHECK-1-SAME: dimensions = [1]
-// CHECK-1: %[[EMPTY_0:.*]] = tensor.empty()
-// CHECK-1: %[[BROADCAST:.*]] = linalg.broadcast
-// CHECK-1-SAME: ins(%[[REDUCE]] : tensor<16384xf32>)
-// CHECK-1-SAME: outs(%[[EMPTY_0]] : tensor<16384x4096xf32>)
-// CHECK-1-SAME: dimensions = [1]
-// CHECK-1: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG0]] [
-// CHECK-1-SAME: [0, 1, 2], [3]]
-// CHECK-1: %[[EMPTY_1:.*]] = tensor.empty()
-// CHECK-1: %[[MAP:.*]] = linalg.map { arith.subf }
-// CHECK-1-SAME: ins(%[[COLLAPSED_0]], %[[BROADCAST]] : tensor<16384x4096xf32>, tensor<16384x4096xf32>)
-// CHECK-1-SAME: outs(%[[EMPTY_1]] : tensor<16384x4096xf32>)
-// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [
-// CHECK-1-SAME: [0, 1, 2], [3]]
-// CHECK-1: return %[[EXPANDED]]
-
-// CHECK-2: func.func @partial_softmax(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>)
-// CHECK-2-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32
-// CHECK-2: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [
-// CHECK-2-SAME: [0, 1], [2], [3]]
-// CHECK-2: %[[EMPTY:.*]] = tensor.empty()
-// CHECK-2: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<8x2048xf32>)
-// CHECK-2: %[[REDUCE:.*]] = linalg.reduce { arith.maximumf }
-// CHECK-2-SAME: ins(%[[COLLAPSED]] : tensor<8x2048x4096xf32>)
-// CHECK-2-SAME: outs(%[[FILL]] : tensor<8x2048xf32>)
-// CHECK-2-SAME: dimensions = [2]
-// CHECK-2: %[[EMPTY_0:.*]] = tensor.empty()
-// CHECK-2: %[[BROADCAST:.*]] = linalg.broadcast
-// CHECK-2-SAME: ins(%[[REDUCE]] : tensor<8x2048xf32>)
-// CHECK-2-SAME: outs(%[[EMPTY_0]] : tensor<8x2048x4096xf32>)
-// CHECK-2-SAME: dimensions = [2]
-// CHECK-2: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG0]] [
-// CHECK-2-SAME: [0, 1], [2], [3]]
-// CHECK-2: %[[EMPTY_1:.*]] = tensor.empty()
-// CHECK-2: %[[MAP:.*]] = linalg.map { arith.subf }
-// CHECK-2-SAME: ins(%[[COLLAPSED_0]], %[[BROADCAST]] : tensor<8x2048x4096xf32>, tensor<8x2048x4096xf32>)
-// CHECK-2-SAME: outs(%[[EMPTY_1]] : tensor<8x2048x4096xf32>)
-// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [
-// CHECK-2-SAME: [0, 1], [2], [3]]
-// CHECK-2: return %[[EXPANDED]]
-
-// CHECK-3: func.func @partial_softmax(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>)
-// CHECK-3-NOT: collapse_shape
-// CHECK-3-NOT: expand_shape
-
-// -----
-
-
-func.func @collapse_shape_of_cwise(%arg0: tensor<2x4xf32>) -> tensor<8xf32> {
-  %0 = tensor.empty() : tensor<2x4xf32>
-  %1 = linalg.map { arith.negf }
-    ins(%arg0 : tensor<2x4xf32>)
-    outs(%0 : tensor<2x4xf32>)
-  %3 = tensor.collapse_shape %1 [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-  return %3 : tensor<8xf32>
-}
-
-// CHECK: func.func @collapse_shape_of_cwise
-// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape {{.*}} [
-// CHECK-SAME: [0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-// CHECK: %[[MAPPED:.*]] = linalg.map
-// CHECK: ins(%[[COLLAPSED]] : tensor<8xf32>)
-
-// CHECK-1: func.func @collapse_shape_of_cwise
-// CHECK-2: func.func @collapse_shape_of_cwise
-// CHECK-3: func.func @collapse_shape_of_cwise
-
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collect_stats.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collect_stats.mlir
deleted file mode 100644
index fd4e1dd18a56f4..00000000000000
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/collect_stats.mlir
+++ /dev/null
@@ -1,77 +0,0 @@
-// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline=stats-detail-level=1 | \
-// RUN: FileCheck %s --check-prefix=CHECK-1
-
-// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline=stats-detail-level=2 | \
-// RUN: FileCheck %s --check-prefix=CHECK-2
-
-// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline=stats-detail-level=3 | \
-// RUN: FileCheck %s --check-prefix=CHECK-3
-
-func.func @foo(%arg0: tensor<2x4xf32>,
-               %arg1: tensor<8x8xf32>,
-               %arg2: tensor<128xf32>) -> tensor<4x2xf32> {
-  %cst = arith.constant 0.0 : f32
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c32 = arith.constant 32 : index
-  %0 = tensor.empty() : tensor<2x4xf32>
-  %1 = linalg.map { arith.negf }
-    ins(%arg0 : tensor<2x4xf32>)
-    outs(%0 : tensor<2x4xf32>)
-  %3 = tensor.collapse_shape %1 [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-  %4 = tensor.empty() : tensor<8xf32>
-  %17 = scf.for %arg13 = %c0 to %c8 step %c32 iter_args(%arg14 = %4)
-    -> (tensor<8xf32>) {
-    %extracted_slice = tensor.extract_slice %arg2[%arg13] [32] [1] :
-      tensor<128xf32> to tensor<32xf32>
-    %expanded_17 = tensor.expand_shape %extracted_slice [[0, 1]] :
-      tensor<32xf32> into tensor<4x8xf32>
-    %reduced_18 = linalg.reduce { arith.addf }
-      ins(%expanded_17 : tensor<4x8xf32>)
-      outs(%arg14 : tensor<8xf32>) dimensions = [0]
-    scf.yield %reduced_18 : tensor<8xf32>
-  }
-  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<8xf32>)
-    -> tensor<8xf32>
-  %6 = linalg.vecmat ins(%3, %arg1 : tensor<8xf32>, tensor<8x8xf32>)
-    outs(%5 : tensor<8xf32>) -> tensor<8xf32>
-  %7 = tensor.expand_shape %6 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
-  %8 = tensor.collapse_shape %7 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
-  %9 = linalg.matvec ins(%arg1, %8 : tensor<8x8xf32>, tensor<8xf32>)
-    outs(%5 : tensor<8xf32>) -> tensor<8xf32>
-  %10 = linalg.map { arith.addf }
-    ins(%17, %9 : tensor<8xf32>, tensor<8xf32>)
-    outs(%5 : tensor<8xf32>)
-  %11 = tensor.expand_shape %10 [[0, 1]] : tensor<8xf32> into tensor<4x2xf32>
-  return %11 : tensor<4x2xf32>
-}
-
-// CHECK-1: *** Tileable ops stats (detail level 1) ***
-// CHECK-1-DAG: 1x linalg.fill
-// CHECK-1-DAG: 2x linalg.map
-// CHECK-1-DAG: 1x linalg.matvec
-// CHECK-1-DAG: 1x linalg.reduce
-// CHECK-1-DAG: 1x linalg.vecmat
-// CHECK-1-DAG: 1x tensor.collapse_shape (degenerate)
-// CHECK-1-DAG: 1x tensor.collapse_shape (non-degenerate)
-// CHECK-1-DAG: 3x tensor.expand_shape
-
-// CHECK-2: *** Tileable ops stats (detail level 2) ***
-// CHECK-2: 1x linalg.fill
-// CHECK-2-NEXT: 1. %{{.*}} = linalg.fill ins({{.*}}) outs({{.*}})
-
-// CHECK-3: *** Tileable ops stats (detail level 3) ***
-// CHECK-3: 2x linalg.map
-// CHECK-3-DAG: %{{.*}} = linalg.map { arith.negf } ins({{.*}}) outs({{.*}})
-// CHECK-3-NEXT: Producers:
-// CHECK-3-NEXT: {{.*}} index: 0
-// CHECK-3-NEXT: tensor.empty
-// CHECK-3-NEXT: Consumers:
-// CHECK-3-NEXT: tensor.collapse_shape
-// CHECK-3-DAG: %{{.*}} = linalg.map { arith.addf } ins({{.*}}) outs({{.*}})
-// CHECK-3-NEXT: Producers:
-// CHECK-3-NEXT: scf.for
-// CHECK-3-NEXT: linalg.matvec
-// CHECK-3-NEXT: linalg.fill
-// CHECK-3-NEXT: Consumers:
-// CHECK-3-NEXT: tensor.expand_shape
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir
deleted file mode 100644
index c9616edb586e15..00000000000000
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-hlo-opt %s --gml-compose-extract-insert-slice --split-input-file \
-// RUN: | FileCheck %s
-
-func.func @compose_slices(%arg: tensor, %i: index, %j: index,
-    %k: index, %n: index, %a: index, %b: index) -> tensor<4x?xf32> {
-  %4 = tensor.extract_slice %arg[%i, %j] [4, 128] [2, %a]
-    : tensor to tensor<4x128xf32>
-  %5 = tensor.extract_slice %4[0, %k] [4, %n] [1, %b]
-    : tensor<4x128xf32> to tensor<4x?xf32>
-  return %5 : tensor<4x?xf32>
-}
-// CHECK-LABEL: @compose_slices
-// CHECK-SAME: %[[ARG:[a-z0-9]+]]: tensor, %[[I:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[J:[a-z0-9]+]]: index, %[[K:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[N:[a-z0-9]+]]: index, %[[A:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
-
-// CHECK-DAG: %[[J_PLUS_AK:.*]] = affine.apply
-// CHECK-DAG: %[[AB:.*]] = affine.apply
-// CHECK-NEXT: %[[RES:.*]] = tensor.extract_slice %[[ARG]]
-// CHECK-SAME: [%[[I]], %[[J_PLUS_AK]]] [4, %[[N]]] [2, %[[AB]]]
-// CHECK-SAME: : tensor
-
-// -----
-
-func.func @compose_extract_of_slice(%arg: tensor, %i: index, %j: index,
-    %k: index, %l: index) -> f32 {
-  %slice = tensor.extract_slice %arg[%i, %j] [4, 128] [2, %l]
-    : tensor to tensor<4x128xf32>
-  %c1 = arith.constant 1 : index
-  %pt = tensor.extract %slice[%c1, %k] : tensor<4x128xf32>
-  return %pt : f32
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 + 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
-
-// CHECK-LABEL: func.func @compose_extract_of_slice
-// CHECK-SAME: (%[[ARG:.*]]: tensor,
-// CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index,
-// CHECK-SAME: %[[L:.*]]: index) -> f32 {
-
-// CHECK: %[[X:.*]] = affine.apply #[[$MAP0]]()[%[[I]]]
-// CHECK: %[[Y:.*]] = affine.apply #[[$MAP1]]()[%[[K]], %[[L]], %[[J]]]
-// CHECK: tensor.extract %[[ARG]][%[[X]], %[[Y]]] : tensor
-
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/batch_matmul.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/batch_matmul.mlir
deleted file mode 100644
index 1bbad52432eaa8..00000000000000
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/batch_matmul.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-hlo-opt %s \
-// RUN: --gml-st-cpu-tiling-pipeline=matmul-tile-sizes=4,4,4 \
-// RUN: | FileCheck %s
-
-func.func @batch_matmul(%lhs: tensor<8x64x32xf32>,
-                        %rhs: tensor<8x32x64xf32>) -> tensor<8x64x64xf32> {
-  %37 = tensor.empty() : tensor<8x64x64xf32>
-  %cst_75 = arith.constant 0.000000e+00 : f32
-  %38 = linalg.fill ins(%cst_75 : f32) outs(%37 :
tensor<8x64x64xf32>) - -> tensor<8x64x64xf32> - %39 = linalg.batch_matmul ins(%lhs, %rhs : tensor<8x64x32xf32>, - tensor<8x32x64xf32>) outs(%38 : tensor<8x64x64xf32>) -> tensor<8x64x64xf32> - - func.return %39 : tensor<8x64x64xf32> -} -// CHECK-LABEL: @batch_matmul - -// CHECK: scf.for -// CHECK-DAG: tensor.collapse_shape -// CHECK-SAME: : tensor<1x64x32xf32> into tensor<64x32xf32> -// CHECK-DAG: tensor.collapse_shape -// CHECK-SAME: : tensor<1x32x64xf32> into tensor<32x64xf32> -// CHECK-DAG: tensor.collapse_shape -// CHECK-SAME: : tensor<1x64x64xf32> into tensor<64x64xf32> -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: vector.contract -// CHECK-SAME: : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> -// CHECK: scf.yield %{{.*}} : vector<4x4xf32> -// CHECK: scf.yield %{{.*}} : tensor<64x64xf32> -// CHECK: scf.yield %{{.*}} : tensor<64x64xf32> -// CHECK: %expanded = tensor.expand_shape -// CHECK: : tensor<64x64xf32> into tensor<1x64x64xf32> -// CHECK: scf.yield %inserted_slice : tensor<8x64x64xf32> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/conv_2d_nhwc_hwcf.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/conv_2d_nhwc_hwcf.mlir deleted file mode 100644 index f365853dbba0a7..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/conv_2d_nhwc_hwcf.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-st-cpu-tiling-pipeline=matmul-tile-sizes=4,4,4 \ -// RUN: | FileCheck %s - -func.func @conv_is_matmul(%input: tensor<1x41x140x1xf32>, - %kernel: tensor<1x140x1x128xf32>) -> tensor<1x41x1x128xf32> { - %empty = tensor.empty() : tensor<1x41x1x128xf32> - - %c0 = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill ins(%c0 : f32) - outs(%empty: tensor<1x41x1x128xf32>) -> tensor<1x41x1x128xf32> - - %conv = linalg.conv_2d_nhwc_hwcf - {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins(%input, %kernel : tensor<1x41x140x1xf32>, tensor<1x140x1x128xf32>) - outs(%fill : tensor<1x41x1x128xf32>) -> tensor<1x41x1x128xf32> - - func.return %conv : tensor<1x41x1x128xf32> -} -// CHECK-LABEL: @conv_is_matmul -// CHECK: scf.for -// CHECK: scf.yield %{{.*}} : tensor<41x128xf32> - -// ----- - -func.func @conv_is_matmul_after_tiling(%input: tensor<1x45x140x1xf32>, - %kernel: tensor<5x140x1x128xf32>) -> tensor<1x41x1x128xf32> { - %empty = tensor.empty() : tensor<1x41x1x128xf32> - - %c0 = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill ins(%c0 : f32) - outs(%empty: tensor<1x41x1x128xf32>) -> tensor<1x41x1x128xf32> - - %conv = linalg.conv_2d_nhwc_hwcf - {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins(%input, %kernel : tensor<1x45x140x1xf32>, tensor<5x140x1x128xf32>) - outs(%fill : tensor<1x41x1x128xf32>) -> tensor<1x41x1x128xf32> - - func.return %conv : tensor<1x41x1x128xf32> -} -// CHECK-LABEL: @conv_is_matmul_after_tiling -// CHECK: scf.for -// CHECK-DAG: tensor.collapse_shape -// CHECK-SAME: : tensor<1x41x140x1xf32> into tensor<41x140xf32> -// CHECK-DAG: tensor.collapse_shape -// CHECK-SAME: : tensor<1x140x1x128xf32> into tensor<140x128xf32> -// CHECK-DAG: tensor.collapse_shape -// CHECK-SAME: : tensor<1x41x1x128xf32> into tensor<41x128xf32> -// CHECK: scf.for -// CHECK: scf.yield %{{.*}} : tensor<41x128xf32> -// CHECK: tensor.expand_shape -// CHECK-SAME: : tensor<41x128xf32> into tensor<1x41x1x128xf32> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/dot.mlir 
b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/dot.mlir deleted file mode 100644 index a8546b4e21e15a..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/dot.mlir +++ /dev/null @@ -1,260 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-st-cpu-tiling-pipeline="matmul-tile-sizes=4,5,6 \ -// RUN: vectorization-size-threshold=1" |\ -// RUN: FileCheck %s - -func.func @matvec(%lhs: tensor<33x17xf32>, %rhs: tensor<17xf32>, - %output: tensor<33xf32>) -> tensor<33xf32> { - %2 = linalg.matvec ins(%lhs, %rhs : tensor<33x17xf32>, tensor<17xf32>) - outs(%output : tensor<33xf32>) -> tensor<33xf32> - return %2 : tensor<33xf32> -} - -// CHECK-LABEL: @matvec -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[C17:.*]] = arith.constant 17 : index -// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index -// CHECK: scf.for {{.*}} %[[C0]] to %[[C32]] step %[[C4]] -// CHECK: scf.for {{.*}} %[[C0]] to %[[C12]] step %[[C6]] -// CHECK: vector.contract {{.*}} vector<4x6xf32> -// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> -// CHECK: vector.contract -// CHECK: vector.transfer_write -// CHECK: scf.for {{.*}} %[[C0]] to %[[C17]] step %[[C6]] -// CHECK: linalg.matvec - -// ----- - -func.func @large_matvec(%lhs: tensor<33x1024xf32>, %rhs: tensor<1024xf32>, - %output: tensor<33xf32>) -> tensor<33xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill ins(%cst : f32) - outs(%output : tensor<33xf32>) -> tensor<33xf32> - %matvec = linalg.matvec ins(%lhs, %rhs : tensor<33x1024xf32>, tensor<1024xf32>) - outs(%fill : tensor<33xf32>) -> tensor<33xf32> - return %matvec : tensor<33xf32> -} -// CHECK-LABEL: @large_matvec - -// CHECK: scf.for -// CHECK: tensor.collapse_shape -// CHECK-SAME: : tensor<1x1024xf32> into tensor<1024xf32> -// CHECK: scf.for -// CHECK: arith.mulf %{{.*}} : vector<32xf32> -// CHECK: vector.multi_reduction -// CHECK: scf.yield %{{.*}} : vector<8xf32> -// CHECK: vector.multi_reduction -// CHECK: scf.yield %{{.*}} : tensor<33xf32> - -// ----- - -func.func @vecmat(%lhs: tensor<17xf32>, %rhs: tensor<17x33xf32>, - %output: tensor<33xf32>) -> tensor<33xf32> { - %2 = linalg.vecmat ins(%lhs, %rhs : tensor<17xf32>, tensor<17x33xf32>) - outs(%output : tensor<33xf32>) -> tensor<33xf32> - return %2 : tensor<33xf32> -} - -// CHECK-LABEL: @vecmat -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[C17:.*]] = arith.constant 17 : index -// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index -// CHECK: scf.for {{.*}} %[[C0]] to %[[C30]] step %[[C5]] -// CHECK: scf.for {{.*}} %[[C0]] to %[[C12]] step %[[C6]] -// CHECK: vector.contract {{.*}} vector<6x5xf32> -// CHECK-NEXT: scf.yield %{{.*}} : vector<5xf32> -// CHECK: vector.contract -// CHECK: vector.transfer_write -// CHECK: scf.for {{.*}} %[[C0]] to %[[C17]] step %[[C6]] -// CHECK: linalg.vecmat - -// ----- - -func.func @dot(%lhs: tensor<19xf32>, %rhs: tensor<19xf32>, - %output: tensor) -> tensor { - %2 = linalg.dot ins(%lhs, %rhs : tensor<19xf32>, tensor<19xf32>) - outs(%output : tensor) -> tensor - return %2 : tensor -} - -// CHECK-LABEL: @dot -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C6:.*]] = 
arith.constant 6 : index -// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index -// CHECK: scf.for {{.*}} %[[C0]] to %[[C18]] step %[[C6]] -// CHECK: vector.contract {{.*}} vector<6xf32> -// CHECK-NEXT: vector.broadcast -// CHECK-NEXT: scf.yield %{{.*}} : vector -// CHECK: arith.mulf -// CHECK: arith.addf - -// ----- - -func.func @large_dot(%lhs: tensor<128xf32>, %rhs: tensor<128xf32>, - %output: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill ins(%cst : f32) - outs(%output : tensor) -> tensor - %dot = linalg.dot ins(%lhs, %rhs : tensor<128xf32>, tensor<128xf32>) - outs(%fill : tensor) -> tensor - return %dot : tensor -} -// CHECK-LABEL: @large_dot - -// CHECK: scf.for -// CHECK: arith.mulf {{.*}} : vector<32xf32> -// CHECK: vector.multi_reduction -// CHECK: : vector<4x8xf32> to vector<8xf32> -// CHECK: scf.yield %{{.*}} : vector<8xf32> -// CHECK: vector.multi_reduction -// CHECK: : vector<8xf32> to f32 - - -// ----- - -func.func @matvec_to_vecmat(%rhs: tensor<2xi32>, - %output: tensor<3xi32>) -> tensor<3xi32> { - %cst = arith.constant dense<[[0, 1], [2, 3], [4, 5]]> : tensor<3x2xi32> - %2 = linalg.matvec ins(%cst, %rhs : tensor<3x2xi32>, tensor<2xi32>) - outs(%output : tensor<3xi32>) -> tensor<3xi32> - return %2 : tensor<3xi32> -} - -// CHECK-LABEL: @matvec_to_vecmat -// CHECK: arith.constant dense<{{\[}}[0, 2, 4], [1, 3, 5]]> : tensor<2x3xi32> -// CHECK: vector.contract {{.*}} : vector<2xi32>, vector<2x3xi32> into vector<3xi32> - -// ----- - -func.func @matvec_addf(%lhs: tensor<33x17xf32>, %rhs: tensor<17xf32>, - %add: tensor<33xf32>) -> tensor<33xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<33xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<33xf32>) -> tensor<33xf32> - %2 = linalg.matvec ins(%lhs, %rhs : tensor<33x17xf32>, tensor<17xf32>) - outs(%1 : tensor<33xf32>) -> tensor<33xf32> - %3 = linalg.map { arith.addf } ins(%2, %add : tensor<33xf32>, tensor<33xf32>) outs(%0 : tensor<33xf32>) - %4 = linalg.map { arith.addf } ins(%3, %add : tensor<33xf32>, tensor<33xf32>) outs(%0 : tensor<33xf32>) - return %4 : tensor<33xf32> -} - -// CHECK-LABEL: @matvec_addf -// CHECK-SAME: (%{{.*}}: {{.*}}, %{{.*}}: {{.*}}, %[[ARG_INIT:.*]]: tensor<33xf32>) -// CHECK: scf.for {{.*}} iter_args(%[[ARG:.*]] = %[[ARG_INIT]] -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]] -// CHECK: %[[READ_INIT:.*]] = vector.transfer_read %[[SLICE]] -// CHECK: %[[FOR:.*]] = scf.for {{.*}} iter_args(%[[ARG_FOR:.*]] = %[[READ_INIT]] -// CHECK: vector.contract {{.*}} %[[ARG_FOR]] : -// CHECK-NEXT: scf.yield -// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}} %[[FOR]] : -// CHECK: vector.transfer_write %[[CONTRACT]], %[[ARG]] -// CHECK: scf.yield -// CHECK: scf.for -// CHECK: linalg.matvec -// CHECK: scf.for -// CHECK: arith.addf -// CHECK-NOT: arith.addf -// CHECK: scf.yield -// CHECK: arith.addf -// CHECK-NOT: arith.addf - -// ----- - -func.func @matvec_no_dominate_addf(%lhs: tensor<33x17xf32>, %rhs: tensor<17xf32>) -> tensor<33xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %cst1 = arith.constant 1.000000e+00 : f32 - %0 = tensor.empty() : tensor<33xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<33xf32>) -> tensor<33xf32> - %2 = linalg.matvec ins(%lhs, %rhs : tensor<33x17xf32>, tensor<17xf32>) - outs(%1 : tensor<33xf32>) -> tensor<33xf32> - %3 = tensor.empty() : tensor<33xf32> - %4 = linalg.fill ins(%cst1 : f32) outs(%0 : tensor<33xf32>) -> tensor<33xf32> - %5 = linalg.map { arith.addf } ins(%2, %4 : tensor<33xf32>, 
tensor<33xf32>) outs(%0 : tensor<33xf32>) - return %5 : tensor<33xf32> -} - -// CHECK-LABEL: @matvec_no_dominate_addf -// CHECK: scf.for -// CHECK: scf.for -// CHECK: vector.contract -// CHECK-NEXT: scf.yield -// CHECK: vector.contract -// CHECK: vector.transfer_write -// CHECK: scf.yield -// CHECK: scf.for -// CHECK: linalg.matvec -// CHECK: scf.for -// CHECK: arith.addf -// CHECK: scf.yield -// CHECK: arith.addf - -// ----- - -func.func @vecmat_addf(%lhs: tensor<17xf32>, %rhs: tensor<17x33xf32>, - %add: tensor<33xf32>) -> tensor<33xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<33xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<33xf32>) -> tensor<33xf32> - %2 = linalg.vecmat ins(%lhs, %rhs : tensor<17xf32>, tensor<17x33xf32>) - outs(%1 : tensor<33xf32>) -> tensor<33xf32> - %3 = linalg.map { arith.addf } ins(%add, %2 : tensor<33xf32>, tensor<33xf32>) outs(%0 : tensor<33xf32>) - %4 = linalg.map { arith.addf } ins(%3, %add : tensor<33xf32>, tensor<33xf32>) outs(%0 : tensor<33xf32>) - return %4 : tensor<33xf32> -} - -// CHECK-LABEL: @vecmat_addf -// CHECK-SAME: (%{{.*}}: {{.*}}, %{{.*}}: {{.*}}, %[[ARG_INIT:.*]]: tensor<33xf32>) -// CHECK: scf.for {{.*}} iter_args(%[[ARG:.*]] = %[[ARG_INIT]] -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]] -// CHECK: %[[READ_INIT:.*]] = vector.transfer_read %[[SLICE]] -// CHECK: %[[FOR:.*]] = scf.for {{.*}} iter_args(%[[ARG_FOR:.*]] = %[[READ_INIT]] -// CHECK: vector.contract {{.*}} %[[ARG_FOR]] : -// CHECK-NEXT: scf.yield -// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}} %[[FOR]] : -// CHECK: vector.transfer_write %[[CONTRACT]], %[[ARG]] -// CHECK: scf.yield -// CHECK: scf.for -// CHECK: linalg.vecmat -// CHECK: scf.for -// CHECK: arith.addf -// CHECK-NOT: arith.addf -// CHECK: scf.yield -// CHECK: arith.addf -// CHECK-NOT: arith.addf - -// ----- - -func.func @vecmat_multiple_uses_addf(%lhs: tensor<17xf32>, %rhs: tensor<17x33xf32>, - %add: tensor<33xf32>) -> tensor<33xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<33xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<33xf32>) -> tensor<33xf32> - %2 = linalg.vecmat ins(%lhs, %rhs : tensor<17xf32>, tensor<17x33xf32>) - outs(%1 : tensor<33xf32>) -> tensor<33xf32> - %3 = linalg.map { arith.addf } ins(%add, %2 : tensor<33xf32>, tensor<33xf32>) outs(%0 : tensor<33xf32>) - %4 = linalg.map { arith.addf } ins(%2, %3 : tensor<33xf32>, tensor<33xf32>) outs(%0 : tensor<33xf32>) - return %4 : tensor<33xf32> -} - -// CHECK-LABEL: @vecmat_multiple_uses_addf -// CHECK: scf.for -// CHECK: scf.for -// CHECK: vector.contract -// CHECK-NEXT: scf.yield -// CHECK: vector.contract -// CHECK: vector.transfer_write -// CHECK: scf.yield -// CHECK: scf.for -// CHECK: linalg.vecmat -// CHECK: scf.for -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: scf.yield -// CHECK: arith.addf -// CHECK: arith.addf diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/duplicate_fusions.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/duplicate_fusions.mlir deleted file mode 100644 index 7b7d998ad07aa1..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/duplicate_fusions.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --gml-fusion-outlining --duplicate-function-elimination | \ -// RUN: FileCheck %s - -func.func @double_bcast_map_reduce(%arg : tensor, - %init_3d : tensor, %init_1d : tensor) -> tensor { - - // Bcast, map, reduce. 
- %0 = gml_st.fusion ins(%arg_ = %arg : tensor, - %init_3d_ = %init_3d : tensor) - inits(%init_1d_ = %init_1d : tensor) { - %broadcasted = linalg.broadcast ins(%arg_ : tensor) - outs(%init_3d_ : tensor) dimensions = [1, 2] - %mapped = linalg.map { math.absf } ins(%broadcasted : tensor) - outs(%init_3d_ : tensor) - %reduced = linalg.reduce { arith.addf } ins(%mapped : tensor) - outs(%init_1d_ : tensor) dimensions = [1, 2] - gml_st.yield %reduced : tensor - } : tensor - - // And again... - %1 = gml_st.fusion ins(%arg_ = %0 : tensor, - %init_3d_ = %init_3d : tensor) - inits(%init_1d_ = %init_1d : tensor) { - %broadcasted = linalg.broadcast ins(%arg_ : tensor) - outs(%init_3d_ : tensor) dimensions = [1, 2] - %mapped = linalg.map { math.absf } ins(%broadcasted : tensor) - outs(%init_3d_ : tensor) - %reduced = linalg.reduce { arith.addf } ins(%mapped : tensor) - outs(%init_1d_ : tensor) dimensions = [1, 2] - gml_st.yield %reduced : tensor - } : tensor - - return %1 : tensor -} - -// CHECK: @[[UNIQUE_OUTLINED_FUSION_FUNC:double_bcast_map_reduce_fusion(_[0-9]+)?]] -// CHECK-SAME: %{{.*}}: tensor, %{{.*}}: tensor, %{{.*}}: tensor -// CHECK-SAME: attributes {fusion} - -// CHECK: @double_bcast_map_reduce -// CHECK-SAME: %[[ARG:.*]]: tensor, %[[INIT_3D:.*]]: tensor, %[[INIT_1D:.*]]: tensor -// CHECK: %[[CALL_0:.*]] = call @[[UNIQUE_OUTLINED_FUSION_FUNC]](%[[ARG]], %[[INIT_3D]], %[[INIT_1D]]) -// CHECK: %[[CALL_1:.*]] = call @[[UNIQUE_OUTLINED_FUSION_FUNC]](%[[CALL_0]], %[[INIT_3D]], %[[INIT_1D]]) -// CHECK: return %[[CALL_1]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fibonacci.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fibonacci.mlir deleted file mode 100644 index 35b5f84dfda885..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fibonacci.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline | FileCheck %s - -func.func @fuse_fibonacci(%init : tensor) -> tensor { - %c0 = arith.constant 0 : i64 - %c1 = arith.constant 1 : i64 - - %0 = linalg.fill ins(%c0 : i64) outs(%init : tensor) -> tensor - %1 = linalg.fill ins(%c1 : i64) outs(%init : tensor) -> tensor - %2 = linalg.map { arith.addi } ins(%0, %1 : tensor, tensor) outs(%init : tensor) - %3 = linalg.map { arith.addi } ins(%1, %2 : tensor, tensor) outs(%init : tensor) - %4 = linalg.map { arith.addi } ins(%2, %3 : tensor, tensor) outs(%init : tensor) - %5 = linalg.map { arith.addi } ins(%3, %4 : tensor, tensor) outs(%init : tensor) - %6 = linalg.map { arith.addi } ins(%4, %5 : tensor, tensor) outs(%init : tensor) - %7 = linalg.map { arith.addi } ins(%5, %6 : tensor, tensor) outs(%init : tensor) - %8 = linalg.map { arith.addi } ins(%6, %7 : tensor, tensor) outs(%init : tensor) - %9 = linalg.map { arith.addi } ins(%7, %8 : tensor, tensor) outs(%init : tensor) - %10 = linalg.map { arith.addi } ins(%8, %9 : tensor, tensor) outs(%init : tensor) - %11 = linalg.map { arith.addi } ins(%9, %10 : tensor, tensor) outs(%init : tensor) - %12 = linalg.map { arith.addi } ins(%10, %11 : tensor, tensor) outs(%init : tensor) - %13 = linalg.map { arith.addi } ins(%11, %12 : tensor, tensor) outs(%init : tensor) - %14 = linalg.map { arith.addi } ins(%12, %13 : tensor, tensor) outs(%init : tensor) - %15 = linalg.map { arith.addi } ins(%13, %14 : tensor, tensor) outs(%init : tensor) - %16 = linalg.map { arith.addi } ins(%14, %15 : tensor, tensor) outs(%init : tensor) - %17 = linalg.map { arith.addi } ins(%15, %16 : tensor, tensor) outs(%init : tensor) - 
%18 = linalg.map { arith.addi } ins(%16, %17 : tensor, tensor) outs(%init : tensor) - %19 = linalg.map { arith.addi } ins(%17, %18 : tensor, tensor) outs(%init : tensor) - %20 = linalg.map { arith.addi } ins(%18, %19 : tensor, tensor) outs(%init : tensor) - %21 = linalg.map { arith.addi } ins(%19, %20 : tensor, tensor) outs(%init : tensor) - %22 = linalg.map { arith.addi } ins(%20, %21 : tensor, tensor) outs(%init : tensor) - %23 = linalg.map { arith.addi } ins(%21, %22 : tensor, tensor) outs(%init : tensor) - %24 = linalg.map { arith.addi } ins(%22, %23 : tensor, tensor) outs(%init : tensor) - %25 = linalg.map { arith.addi } ins(%23, %24 : tensor, tensor) outs(%init : tensor) - %26 = linalg.map { arith.addi } ins(%24, %25 : tensor, tensor) outs(%init : tensor) - %27 = linalg.map { arith.addi } ins(%25, %26 : tensor, tensor) outs(%init : tensor) - %28 = linalg.map { arith.addi } ins(%26, %27 : tensor, tensor) outs(%init : tensor) - %29 = linalg.map { arith.addi } ins(%27, %28 : tensor, tensor) outs(%init : tensor) - %30 = linalg.map { arith.addi } ins(%28, %29 : tensor, tensor) outs(%init : tensor) - %31 = linalg.map { arith.addi } ins(%29, %30 : tensor, tensor) outs(%init : tensor) - %32 = linalg.map { arith.addi } ins(%30, %31 : tensor, tensor) outs(%init : tensor) - %33 = linalg.map { arith.addi } ins(%31, %32 : tensor, tensor) outs(%init : tensor) - %34 = linalg.map { arith.addi } ins(%32, %33 : tensor, tensor) outs(%init : tensor) - %35 = linalg.map { arith.addi } ins(%33, %34 : tensor, tensor) outs(%init : tensor) - %36 = linalg.map { arith.addi } ins(%34, %35 : tensor, tensor) outs(%init : tensor) - %37 = linalg.map { arith.addi } ins(%35, %36 : tensor, tensor) outs(%init : tensor) - %38 = linalg.map { arith.addi } ins(%36, %37 : tensor, tensor) outs(%init : tensor) - %39 = linalg.map { arith.addi } ins(%37, %38 : tensor, tensor) outs(%init : tensor) - func.return %39 : tensor -} -// CHECK-LABEL: @fuse_fibonacci -// CHECK-DAG: %[[SCALAR_RESULT:.*]] = arith.constant 63245986 : i64 -// CHECK-DAG: %[[VECTOR_RESULT:.*]] = arith.constant dense<63245986> : vector<8xi64> - -// CHECK: scf.for -// CHECK: %[[VEC:.*]] = vector.transfer_write %[[VECTOR_RESULT]] -// CHECK: scf.yield %[[VEC]] - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: tensor.insert %[[SCALAR_RESULT]] -// CHECK: tensor.insert_slice diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_outlining.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_outlining.mlir deleted file mode 100644 index 8a85edce3edac4..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_outlining.mlir +++ /dev/null @@ -1,130 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-fusion-outlining | \ -// RUN: FileCheck %s - -func.func @map_fusion(%arg0: tensor, %arg1: tensor) - -> tensor { - %0 = gml_st.fusion ins(%arg2 = %arg0: tensor) - inits(%arg3 = %arg1: tensor) { - %mapped = linalg.map { math.exp } ins(%arg2 : tensor) - outs(%arg3 : tensor) - %mapped_0 = linalg.map { arith.mulf } - ins(%mapped, %mapped : tensor, tensor) - outs(%arg3 : tensor) - %mapped_1 = linalg.map { math.absf } ins(%mapped_0 : tensor) - outs(%arg3 : tensor) - gml_st.yield %mapped_1 : tensor - } : tensor - return %0 : tensor -} - -// CHECK-LABEL: @map_fusion_fusion_0 -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor -// CHECK-SAME: attributes {fusion} -// CHECK: %[[FUSION:.*]] = gml_st.fusion -// CHECK-SAME: ins(%[[ARG2:.*]] = %[[ARG0]]: tensor) -// CHECK-SAME: inits(%[[ARG3:.*]] = 
%[[ARG1]]: tensor) -// CHECK: %[[MAPPED:.*]] = linalg.map { math.exp } ins(%[[ARG2]] : tensor) outs(%[[ARG3]] : tensor) -// CHECK: %[[MAPPED_0:.*]] = linalg.map { arith.mulf } ins(%[[MAPPED]], %[[MAPPED]] : tensor, tensor) outs(%[[ARG3]] : tensor) -// CHECK: %[[MAPPED_1:.*]] = linalg.map { math.absf } ins(%[[MAPPED_0]] : tensor) outs(%[[ARG3]] : tensor) -// CHECK: gml_st.yield %[[MAPPED_1]] -// CHECK: return %[[FUSION]] -// CHECK: @map_fusion(%[[ARG0_0:.*]]: tensor, %[[ARG1_0:.*]]: tensor) -// CHECK: %[[VAL:.*]] = call @map_fusion_fusion_0(%[[ARG0_0]], %[[ARG1_0]]) -// CHECK: return %[[VAL]] - -// ----- - -func.func @multiple_fusions(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %0 = gml_st.fusion ins(%arg3 = %arg0: tensor) - inits(%arg4 = %arg1: tensor) { - %sorted0 = thlo.sort ins(%arg3 : tensor) - outs(%arg4 : tensor) dimension = 0 is_stable = false - (%lhs0: f32, %rhs0: f32) { - %2 = arith.cmpf ogt, %lhs0, %rhs0 : f32 - thlo.yield %2 : i1 - } - gml_st.yield %sorted0 : tensor - } : tensor - %1 = gml_st.fusion ins(%arg3 = %0: tensor) - inits(%arg4 = %arg2: tensor) { - %reduced = linalg.reduce { arith.addf } ins(%arg3 : tensor) - outs(%arg4 : tensor) dimensions = [0] - %mapped = linalg.map { math.exp } ins(%reduced : tensor) - outs(%arg4 : tensor) - gml_st.yield %mapped : tensor - } : tensor - return %1 : tensor -} - -// CHECK-LABEL: @multiple_fusions_fusion_0 -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor -// CHECK-SAME: attributes {fusion} -// CHECK: %[[FUSION:.*]] = gml_st.fusion -// CHECK-SAME: ins(%[[ARG2:.*]] = %[[ARG0]]: tensor) -// CHECK-SAME: inits(%[[ARG3:.*]] = %[[ARG1]]: tensor) -// CHECK: %[[SORTED0:.*]] = thlo.sort ins(%[[ARG2]] : tensor) outs(%[[ARG3]] : tensor) dimension = 0 is_stable = false -// CHECK: (%[[LHS0:.*]]: f32, %[[RHS0:.*]]: f32) -// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[LHS0]], %[[RHS0]] : f32 -// CHECK: thlo.yield %[[CMPF]] : i1 -// CHECK: gml_st.yield %[[SORTED0]] -// CHECK: return %[[FUSION]] -// CHECK: @multiple_fusions_fusion_1 -// CHECK-SAME: %[[ARG0_0:.*]]: tensor, %[[ARG1_0:.*]]: tensor -// CHECK-SAME: attributes {fusion} -// CHECK: %[[FUSION_0:.*]] = gml_st.fusion -// CHECK-SAME: ins(%[[ARG2_0:.*]] = %[[ARG0_0]]: tensor) -// CHECK-SAME: inits(%[[ARG3_0:.*]] = %[[ARG1_0]]: tensor) -// CHECK: %[[REDUCED:.*]] = linalg.reduce { arith.addf } ins(%[[ARG2_0]] : tensor) outs(%[[ARG3_0]] : tensor) dimensions = [0] -// CHECK: %[[MAPPED:.*]] = linalg.map { math.exp } ins(%[[REDUCED]] : tensor) outs(%[[ARG3_0]] : tensor) -// CHECK: gml_st.yield %[[MAPPED]] -// CHECK: return %[[FUSION_0]] -// CHECK: @multiple_fusions -// CHECK-SAME: %[[ARG0_1:.*]]: tensor, %[[ARG1_1:.*]]: tensor, %[[ARG2_1:.*]]: tensor -// CHECK: %[[VAL:.*]] = call @multiple_fusions_fusion_0(%[[ARG0_1]], %[[ARG1_1]]) -// CHECK: %[[VAL_0:.*]] = call @multiple_fusions_fusion_1(%[[VAL]], %[[ARG2_1]]) -// CHECK: return %[[VAL_0]] - -// ----- - -func.func @cst_defined_above() -> tensor<1x10xf32> { - %0 = tensor.empty() : tensor<1x10xf32> - %1 = gml_st.fusion inits(%arg3 = %0 : tensor<1x10xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %2 = linalg.fill ins(%cst : f32) outs(%arg3 : tensor<1x10xf32>) -> tensor<1x10xf32> - gml_st.yield %2 : tensor<1x10xf32> - } { some_attr = 123 } : tensor<1x10xf32> - return %1 : tensor<1x10xf32> -} - -// CHECK-LABEL: @cst_defined_above_fusion_0 -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x10xf32> -// CHECK-SAME: attributes {fusion} -// CHECK: %[[FUSION:.*]] = gml_st.fusion inits(%[[ARG1:.*]] = %[[ARG0]]: tensor<1x10xf32>) { -// CHECK: 
%[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG1]] : tensor<1x10xf32>) -// CHECK: gml_st.yield %[[FILL]] -// CHECK: } {some_attr = 123 : i64} -// CHECK: return %[[FUSION]] -// CHECK: @cst_defined_above -// CHECK: %[[EMPTY:.*]] = tensor.empty() -// CHECK: %[[VAL:.*]] = call @cst_defined_above_fusion_0(%[[EMPTY]]) -// CHECK: return %[[VAL]] - -// ----- - -func.func @reduce_wo_init(%arg0: tensor<2xf64>, %arg1: tensor) - -> tensor { - %0 = gml_st.fusion ins(%arg3 = %arg0: tensor<2xf64>) - inits(%arg4 = %arg1: tensor) { - %reduced = linalg.reduce { arith.maximumf } - ins(%arg3 : tensor<2xf64>) - outs(%arg4 : tensor) - dimensions = [0] - gml_st.yield %reduced : tensor - } : tensor - return %0 : tensor -} - -// CHECK: @reduce_wo_init_fusion_0 -// CHECK: @reduce_wo_init diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_planning_for_cpu.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_planning_for_cpu.mlir deleted file mode 100644 index 27b3dd4e096f27..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/fusion_planning_for_cpu.mlir +++ /dev/null @@ -1,480 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-fusion-planning \ -// RUN: --split-input-file \ -// RUN: | FileCheck %s - -func.func @reverse_reduce_map(%input: tensor, %init0: tensor, - %init1: tensor) -> tensor { - %sorted = thlo.sort - ins(%input: tensor) - outs(%init0: tensor) - dimension = 0 - is_stable = false - (%lhs: f32, %rhs: f32) { - %gt = arith.cmpf ogt, %lhs, %rhs: f32 - thlo.yield %gt : i1 - } - %reduced = linalg.reduce { arith.addf } - ins(%sorted: tensor) - outs(%init1: tensor) - dimensions = [0] - %result = linalg.map { math.exp } - ins(%reduced: tensor) - outs(%init1: tensor) - func.return %result : tensor -} - -// CHECK-LABEL: @reverse_reduce_map -// CHECK-SAME: (%[[INPUT:.*]]: tensor, %[[INIT0:.*]]: tensor -// CHECK-SAME: %[[INIT1:.*]]: tensor - -// CHECK: %[[FUSION0:.*]] = gml_st.fusion -// CHECK-SAME: ins(%[[BB_INPUT:.*]] = %[[INPUT]]: tensor -// CHECK-SAME: inits(%[[BB_INIT0:.*]] = %[[INIT0]]: tensor -// CHECK-NEXT: %[[SORTED:.*]] = thlo.sort -// CHECK-SAME: ins(%[[BB_INPUT]] -// CHECK-SAME: outs(%[[BB_INIT0]] -// CHECK: gml_st.yield %[[SORTED]] - -// CHECK: %[[FUSION1:.*]] = gml_st.fusion -// CHECK-SAME: ins(%[[BB_INPUT:.*]] = %[[FUSION0]]: tensor -// CHECK-SAME: inits(%[[BB_INIT1:.*]] = %[[INIT1]]: tensor -// CHECK: %[[REDUCED:.*]] = linalg.reduce -// CHECK-SAME: ins(%[[BB_INPUT]] -// CHECK-SAME: outs(%[[BB_INIT1]] -// CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK-SAME: ins(%[[REDUCED]] -// CHECK-SAME: outs(%[[BB_INIT1]] -// CHECK: gml_st.yield %[[MAPPED]] - -// CHECK: return %[[FUSION1]] - -// ----- - -func.func @scatter(%indices: tensor<1x1xindex>, - %updates: tensor<1x1x3x4xi64>, - %init: tensor<3x3x4xi64>) -> tensor<3x3x4xi64> { - %res = thlo.scatter ins(%indices : tensor<1x1xindex>, - %updates : tensor<1x1x3x4xi64>) - outs(%init : tensor<3x3x4xi64>) - (%arg5: i64, %arg6: i64) { - thlo.yield %arg5 : i64 - } - func.return %res : tensor<3x3x4xi64> -} - -// CHECK-LABEL: func @scatter -// CHECK: gml_st.fusion -// CHECK: thlo.scatter -// CHECK: gml_st.yield - -// ----- - -func.func @sort(%input: tensor, %init: tensor) - -> tensor { - %res = thlo.sort - ins(%input: tensor) - outs(%init: tensor) - dimension = 0 - is_stable = true - (%lhs: f32, %rhs: f32) { - %0 = arith.cmpf ogt, %lhs, %rhs : f32 - thlo.yield %0 : i1 - } - func.return %res : tensor -} - -// CHECK-LABEL: func 
@sort -// CHECK: gml_st.fusion -// CHECK: thlo.sort -// CHECK: gml_st.yield - -// ----- - -func.func @reverse(%input: tensor, %init: tensor) - -> tensor { - %res = thlo.reverse - ins(%input: tensor) - outs(%init: tensor) - reverse_dimensions = [0, 1] - func.return %res : tensor -} - -// CHECK-LABEL: func @reverse -// CHECK: gml_st.fusion -// CHECK: thlo.reverse -// CHECK: gml_st.yield -// CHECK: parallel_tile_sizes = array -// CHECK-SAME: reduction_tile_sizes = array - -// ----- - -func.func @transpose(%input: tensor, %init: tensor) - -> tensor { - %res = linalg.transpose - ins(%input: tensor) - outs(%init: tensor) - permutation = [1, 0] - func.return %res : tensor -} - -// CHECK-LABEL: func @transpose -// CHECK: gml_st.fusion -// CHECK: linalg.transpose -// CHECK: gml_st.yield - -// ----- - -func.func @map(%input: tensor, %init: tensor) - -> tensor { - %abs = linalg.map { math.absf } ins(%input:tensor) outs(%init:tensor) - func.return %abs : tensor -} - -// CHECK-LABEL: func @map -// CHECK: gml_st.fusion -// CHECK: linalg.map -// CHECK: gml_st.yield -// CHECK: parallel_tile_sizes = array -// CHECK-SAME: reduction_tile_sizes = array - -// ----- - -func.func @map_non_unique_users(%arg: tensor, - %init: tensor) -> tensor { - %exp = linalg.map { math.exp } - ins(%arg: tensor) - outs(%init: tensor) - %mul = linalg.map { arith.mulf } - ins(%exp, %exp: tensor, tensor) - outs(%init: tensor) - %abs = linalg.map { math.absf } - ins(%mul: tensor) - outs(%init: tensor) - func.return %abs : tensor -} - -// CHECK-LABEL: func @map_non_unique_users -// CHECK: gml_st.fusion -// CHECK-COUNT-3: linalg.map -// CHECK: gml_st.yield -// CHECK: parallel_tile_sizes = array -// CHECK-SAME: reduction_tile_sizes = array - -// ----- - -func.func @matmul(%input1: tensor<4x8xf32>, %input2: tensor<8x16xf32>, - %init: tensor<4x16xf32>) -> tensor<4x16xf32> { - %res = linalg.matmul - ins(%input1, %input2 : tensor<4x8xf32>, tensor<8x16xf32>) - outs(%init : tensor<4x16xf32>) -> tensor<4x16xf32> - func.return %res : tensor<4x16xf32> -} - -// CHECK-LABEL: func @matmul -// CHECK: gml_st.fusion -// CHECK: linalg.matmul -// CHECK: gml_st.yield - -// ----- - -func.func @reduce(%input: tensor<100x10xf32>, - %output: tensor<10xf32>) -> tensor<10xf32> { - %res = linalg.reduce { arith.addf } - ins(%input: tensor<100x10xf32>) - outs(%output: tensor<10xf32>) - dimensions = [0] - return %res : tensor<10xf32> -} - -// CHECK-LABEL: func @reduce -// CHECK: gml_st.fusion -// CHECK: linalg.reduce -// CHECK: gml_st.yield - -// ----- - -func.func @fused_matmul(%arg0: tensor<1x32xf32>, %arg1: tensor<32x10xf32>, - %arg2: tensor<10xf32>) -> tensor<1x10xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<1x10xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x10xf32>) -> tensor<1x10xf32> - %2 = linalg.matmul - ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<32x10xf32>) - outs(%1 : tensor<1x10xf32>) -> tensor<1x10xf32> - %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<10xf32> into tensor<1x10xf32> - %mapped = linalg.map { arith.addf } - ins(%2, %expanded : tensor<1x10xf32>, tensor<1x10xf32>) - outs(%0 : tensor<1x10xf32>) - return %mapped : tensor<1x10xf32> -} - -// CHECK-LABEL: func @fused_matmul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<32x10xf32> -// CHECK-SAME: %[[ARG2:.*]]: tensor<10xf32> -// CHECK: %[[EMPTY:.*]] = tensor.empty() -// CHECK: gml_st.fusion -// CHECK-SAME: ins(%[[ARG2_:.*]] = %[[ARG2]]: tensor<10xf32> -// CHECK-SAME: %[[ARG0_:.*]] = %[[ARG0]]: tensor<1x32xf32> -// 
CHECK-SAME: %[[ARG1_:.*]] = %[[ARG1]]: tensor<32x10xf32> -// CHECK-SAME: inits(%[[EMPTY_:.*]] = %[[EMPTY]]: tensor<1x10xf32> -// CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG2_]] -// CHECK: %[[TMP:.*]] = tensor.empty -// CHECK: %[[FILLED:.*]] = linalg.fill -// CHECK-SAME: ins(%[[C0]] : f32) -// CHECK-SAME: outs(%[[TMP]] : tensor<1x10xf32> -// CHECK: %[[MATMUL:.*]] = linalg.matmul -// CHECK-SAME: ins(%[[ARG0_]], %[[ARG1_]] -// CHECK-SAME: outs(%[[FILLED]] -// CHECK: %[[MAP:.*]] = linalg.map -// CHECK-SAME: ins(%[[MATMUL]], %[[EXPANDED]] -// CHECK-SAME: outs(%[[EMPTY_]] -// CHECK: gml_st.yield %[[MAP]] - -// ----- - -func.func @value_used_in_op_region(%arg0: tensor, - %arg1: tensor, %arg2: tensor, %init: tensor) - -> tensor { - %extracted = tensor.extract %arg0[] : tensor - %mapped = linalg.map - ins(%arg1, %arg2 : tensor, tensor) - outs(%init : tensor) - (%in: i64, %in_1: i64) { - %3 = arith.select %extracted, %in, %in_1 : i64 - linalg.yield %3 : i64 - } - return %mapped : tensor -} - -// CHECK-LABEL: func @value_used_in_op_region -// CHECK-SAME: (%[[ARG0:.*]]: tensor -// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG0]] -// CHECK: gml_st.fusion -// CHECK-SAME: %[[EXTRACTED_:[a-zA-Z0-9]*]] = %[[EXTRACTED]]: i1 -// CHECK: linalg.map -// CHECK: arith.select %[[EXTRACTED_]] -// CHECK: parallel_tile_sizes = array -// CHECK-SAME: reduction_tile_sizes = array - -// ----- - -func.func @variadic_fusion(%input1: tensor<16x32x64xf32>, - %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>, - %init2: tensor<16x64xi64>) -> (tensor<16x64xf32>, tensor<16x64xi64>) { - %reduce, %reduce2 = linalg.reduce - ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xi64>) - outs(%init1, %init2 : tensor<16x64xf32>, tensor<16x64xi64>) - dimensions = [1] - (%in1: f32, %in2: i64, %out1: f32, %out2: i64) { - %0 = arith.addf %in1, %out1: f32 - %1 = arith.addi %in2, %out2: i64 - linalg.yield %0, %1: f32, i64 - } - func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<16x64xi64> -} - -// CHECK-LABEL: func @variadic_fusion -// CHECK: %[[FUSION_RESULT:.*]]:2 = gml_st.fusion -// CHECK: %[[REDUCE_RESULT:.*]]:2 = linalg.reduce -// CHECK: gml_st.yield %[[REDUCE_RESULT]]#0, %[[REDUCE_RESULT]]#1 -// CHECK: return %[[FUSION_RESULT]]#0, %[[FUSION_RESULT]]#1 - -// ----- - -func.func @tensor_empty_init(%input: tensor) - -> tensor { - %c0 = arith.constant 0 : index - %d0 = tensor.dim %input, %c0 : tensor - %init = tensor.empty(%d0) : tensor - - %mapped = linalg.map { math.exp } - ins(%input: tensor) - outs(%init: tensor) - - %result = linalg.map { math.exp } - ins(%mapped: tensor) - outs(%init: tensor) - - func.return %result : tensor -} - -// CHECK-LABEL: func @tensor_empty_init -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim -// CHECK: %[[EMPTY:.*]] = tensor.empty -// CHECK: gml_st.fusion -// CHECK-SAME: ins(%[[DIM_:.*]] = %[[DIM]]: index -// CHECK-SAME: %[[ARG0_:.*]] = %[[ARG0]]: tensor -// CHECK-SAME: inits(%[[EMPTY_:.*]] = %[[EMPTY]] -// CHECK: %[[TMP:.*]] = tensor.empty(%[[DIM_]]) -// CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK-SAME: outs(%[[TMP]] -// CHECK: %[[MAPPED0:.*]] = linalg.map -// CHECK-SAME: outs(%[[EMPTY_]] -// CHECK: parallel_tile_sizes = array -// CHECK-SAME: reduction_tile_sizes = array - -// ----- - -func.func @shared_tensor_empty_static(%input: tensor<8xf32>) - -> (tensor<8xf32>, tensor<8xf32>) { - %init = tensor.empty() : tensor<8xf32> - - %exp = linalg.map { math.exp } - 
ins(%input: tensor<8xf32>) - outs(%init: tensor<8xf32>) - %res0 = linalg.map { math.absf } - ins(%exp: tensor<8xf32>) - outs(%init: tensor<8xf32>) - - %abs = linalg.map { math.absf } - ins(%input: tensor<8xf32>) - outs(%init: tensor<8xf32>) - %res1 = linalg.map { math.exp } - ins(%abs: tensor<8xf32>) - outs(%init: tensor<8xf32>) - - func.return %res0, %res1 : tensor<8xf32>, tensor<8xf32> -} - -// CHECK-LABEL: func @shared_tensor_empty_static -// CHECK-COUNT-2: tensor.empty -// CHECK: gml_st.fusion -// CHECK: tensor.empty -// CHECK: linalg.map { math.exp } -// CHECK: linalg.map { math.absf } -// CHECK: gml_st.yield - -// CHECK: gml_st.fusion -// CHECK: tensor.empty -// CHECK: linalg.map { math.absf } -// CHECK: linalg.map { math.exp } -// CHECK: gml_st.yield - -// ----- - -func.func @shared_tensor_empty_dynamic(%input: tensor, %size : index) - -> tensor { - %init1 = tensor.empty(%size) : tensor - %exp = linalg.map { math.exp } - ins(%input: tensor) - outs(%init1: tensor) - - %init2 = tensor.empty(%size) : tensor - %res = linalg.map { math.absf } - ins(%exp: tensor) - outs(%init2: tensor) - return %res : tensor -} - -// CHECK-LABEL: func @shared_tensor_empty_dynamic( -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[SIZE:.*]]: index -// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[SIZE]]) : tensor -// CHECK: %[[FUSION:.*]] = gml_st.fusion -// CHECK-SAME: ins(%[[ARG2:.*]] = %[[SIZE]]: index, -// CHECK-SAME: %[[ARG3:.*]] = %[[ARG0]]: tensor -// CHECK-SAME: inits(%[[ARG4:.*]] = %[[EMPTY]]: tensor -// CHECK: %[[EMPTY_0:.*]] = tensor.empty(%[[ARG2]]) : tensor -// CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK-SAME: ins(%[[ARG3]] {{.*}} outs(%[[EMPTY_0]] -// CHECK: %[[MAPPED_0:.*]] = linalg.map -// CHECK-SAME: ins(%[[MAPPED]] {{.*}} outs(%[[ARG4]] -// CHECK: gml_st.yield %[[MAPPED_0]] : tensor -// CHECK: return %[[FUSION]] : tensor - -// ----- - -func.func @shared_linalg_fill_dynamic(%input: tensor, %size : index) - -> (tensor, tensor) { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty(%size) : tensor - %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor - %res0 = linalg.reduce { arith.addf } - ins(%input: tensor) - outs(%fill: tensor) - dimensions = [0] - - %res1 = linalg.map { math.absf } - ins(%fill : tensor) - outs(%init: tensor) - return %res0, %res1 : tensor, tensor -} - -// CHECK-LABEL: func.func @shared_linalg_fill_dynamic -// CHECK: %[[EMPTY:.*]] = tensor.empty -// CHECK: %[[FUSION:.*]] = gml_st.fusion -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill -// CHECK: %[[REDUCED:.*]] = linalg.reduce {{.*}} outs(%[[FILL]] -// CHECK: %[[FUSION_0:.*]] = gml_st.fusion -// CHECK: %[[EMPTY_0:.*]] = tensor.empty -// CHECK: %[[FILL_0:.*]] = linalg.fill {{.*}} outs(%[[EMPTY_0]] -// CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK: gml_st.yield %[[MAPPED]] : tensor -// CHECK: return %[[FUSION]], %[[FUSION_0]] - -// ----- - -func.func @multiple_users_linalg_fill(%arg0: tensor<2xf64>) - -> (tensor, tensor) { - %cst = arith.constant 0x7FF0000000000000 : f64 - %0 = tensor.empty() : tensor - %1 = linalg.fill ins(%cst : f64) outs(%0 : tensor) -> tensor - %reduced = linalg.reduce { arith.minimumf } - ins(%arg0 : tensor<2xf64>) - outs(%1 : tensor) - dimensions = [0] - return %1, %reduced : tensor, tensor -} - -// CHECK-LABEL: func @multiple_users_linalg_fill -// CHECK: %[[FILL0:.*]] = linalg.fill -// CHECK: %[[RESULT:.*]] = gml_st.fusion -// CHECK: %[[FILL1:.*]] = linalg.fill -// CHECK: linalg.reduce -// CHECK-SAME: outs(%[[FILL1]] -// CHECK: return %[[FILL0]], 
%[[RESULT]] - -// ----- - -func.func @map_for_matmuls(%arg0: tensor, %arg1: tensor, - %arg2: tensor, %init: tensor) - -> tensor { - %matmul0 = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor) - outs(%init : tensor) -> tensor - %matmul1 = linalg.matmul - ins(%arg0, %arg2 : tensor, tensor) - outs(%init : tensor) -> tensor - - %res = linalg.map { arith.addf } - ins(%matmul0, %matmul1 : tensor, tensor) - outs(%init : tensor) - func.return %res : tensor -} - -// CHECK-LABEL: func @map_for_matmuls -// CHECK: gml_st.fusion -// CHECK: linalg.matmul -// CHECK: gml_st.fusion -// CHECK: linalg.matmul -// CHECK: linalg.map - -// ----- - -func.func @do_not_fuse_unsupported_op(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %init = tensor.empty() : tensor<10xf32> - %negated = "mhlo.negate"(%arg0) : (tensor<10xf32>) -> tensor<10xf32> - %mapped = linalg.map { math.exp } - ins(%negated : tensor<10xf32>) - outs(%init : tensor<10xf32>) - return %mapped : tensor<10xf32> -} - -// CHECK-LABEL: func @do_not_fuse_unsupported_op -// CHECK: tensor.empty -// CHECK: mhlo.negate -// CHECK: gml_st.fusion -// CHECK: linalg.map diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/inline_fusion_clusters.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/inline_fusion_clusters.mlir deleted file mode 100644 index 8f51d6c89c407c..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/inline_fusion_clusters.mlir +++ /dev/null @@ -1,74 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-inline-fusion-clusters \ -// RUN: --split-input-file \ -// RUN: | FileCheck %s - -func.func @two_clusters_tensors( - %arg0: tensor, %arg1: tensor, %arg2: tensor) - -> tensor { - %0 = gml_st.fusion ins(%arg3 = %arg0: tensor) - inits(%arg4 = %arg1: tensor) { - %sorted0 = thlo.sort - ins(%arg3 : tensor) - outs(%arg4 : tensor) - dimension = 0 - is_stable = false - (%lhs0: f32, %rhs0: f32) { - %2 = arith.cmpf ogt, %lhs0, %rhs0 : f32 - thlo.yield %2 : i1 - } - gml_st.yield %sorted0 : tensor - } : tensor - %1 = gml_st.fusion ins(%arg3 = %0: tensor) - inits(%arg4 = %arg2: tensor) { - %reduced = linalg.reduce { arith.addf } ins(%arg3 : tensor) outs(%arg4 : tensor) dimensions = [0] - %mapped = linalg.map { math.exp } ins(%reduced : tensor) outs(%arg4 : tensor) - gml_st.yield %mapped : tensor - } : tensor - return %1 : tensor -} - -// CHECK-LABEL: @two_clusters_tensors -// CHECK-NOT: gml_st.fusion -// CHECK: thlo.sort -// CHECK-NOT: gml_st.fusion -// CHECK: linalg.reduce -// CHECK: linalg.map - -// ----- - -func.func @two_clusters_memrefs( - %arg0: memref, %arg1: memref, %arg2: memref) - -> memref { - gml_st.fusion ins(%arg3 = %arg0: memref) - inits(%arg4 = %arg1: memref) { - thlo.sort - ins(%arg3 : memref) - outs(%arg4 : memref) - dimension = 0 - is_stable = false - (%lhs0: f32, %rhs0: f32) { - %2 = arith.cmpf ogt, %lhs0, %rhs0 : f32 - thlo.yield %2 : i1 - } - gml_st.yield %arg4 : memref - } - gml_st.fusion ins(%arg3 = %arg1: memref) - inits(%arg4 = %arg2: memref) { - linalg.reduce { arith.addf } - ins(%arg3 : memref) - outs(%arg4 : memref) - dimensions = [0] - linalg.map { math.exp } - ins(%arg4 : memref) - outs(%arg4 : memref) - gml_st.yield %arg4 : memref - } - return %arg2 : memref -} - -// CHECK-LABEL: @two_clusters_memrefs -// CHECK-NOT: gml_st.fusion -// CHECK: thlo.sort -// CHECK-NOT: gml_st.fusion -// CHECK: linalg.reduce -// CHECK: linalg.map diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir 
b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir
deleted file mode 100644
index 54862c3051829c..00000000000000
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir
+++ /dev/null
@@ -1,37 +0,0 @@
-// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline \
-// RUN: | FileCheck %s
-
-func.func @map_bcast_map(%arg0: tensor, %arg1: tensor,
-    %init0: tensor,
-    %init1: tensor) -> tensor {
-  %abs = linalg.map { math.absf }
-    ins(%arg0:tensor)
-    outs(%init0:tensor)
-
-  %bcast = linalg.broadcast
-    ins(%abs : tensor)
-    outs(%init1 : tensor)
-    dimensions = [1, 2]
-
-  %mapped = linalg.map { arith.addf }
-    ins(%bcast, %arg1 : tensor, tensor)
-    outs(%init1:tensor)
-  func.return %mapped : tensor
-}
-
-// CHECK-LABEL: func.func @map_bcast_map
-
-// CHECK: scf.for
-// CHECK: math.absf %{{.*}} : vector<8xf32>
-// CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x8x8xf32>
-// CHECK: vector.transpose %{{.*}}, [2, 0, 1]
-// CHECK-SAME: : vector<1x8x8xf32> to vector<8x1x8xf32>
-// CHECK: arith.addf %{{.*}} : vector<8x1x8xf32>
-// CHECK: vector.transfer_write
-
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK: math.absf %{{.*}} : f32
-// CHECK: arith.addf %{{.*}} : f32
-// CHECK: tensor.insert
-// CHECK: tensor.insert_slice
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir
deleted file mode 100644
index 7da607be9cd40a..00000000000000
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir
+++ /dev/null
@@ -1,46 +0,0 @@
-// RUN: mlir-hlo-opt %s --split-input-file \
-// RUN: --gml-st-cpu-tiling-pipeline=matmul-tile-sizes=4,4,4 \
-// RUN: | FileCheck %s
-
-func.func @map_matmul(%lhs0: tensor<16x16xf32>, %rhs0: tensor<16x16xf32>,
-    %lhs1: tensor<16x32xf32>, %rhs1: tensor<32x16xf32>) -> tensor<16x16xf32> {
-  %init = tensor.empty() : tensor<16x16xf32>
-
-  %cst = arith.constant 0.000000e+00 : f32
-  %filled = linalg.fill ins(%cst : f32)
-    outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
-
-  %4 = linalg.matmul ins(%lhs0, %rhs0 : tensor<16x16xf32>, tensor<16x16xf32>)
-    outs(%filled : tensor<16x16xf32>) -> tensor<16x16xf32>
-  %5 = linalg.matmul ins(%lhs1, %rhs1 : tensor<16x32xf32>, tensor<32x16xf32>)
-    outs(%filled : tensor<16x16xf32>) -> tensor<16x16xf32>
-  %6 = linalg.map { math.absf }
-    ins(%5 : tensor<16x16xf32>)
-    outs(%init : tensor<16x16xf32>)
-
-  %result = linalg.map { arith.addf }
-    ins(%4, %6 : tensor<16x16xf32>, tensor<16x16xf32>)
-    outs(%init : tensor<16x16xf32>)
-  return %result : tensor<16x16xf32>
-}
-
-// CHECK-LABEL: @map_matmul
-
-// Fuse this linalg.fill.
- -// CHECK-NOT: linalg.fill -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-2: vector.transfer_read -// CHECK: vector.contract -// CHECK: scf.yield -// CHECK: scf.for -// CHECK-COUNT-2: vector.transfer_read -// CHECK: vector.contract -// CHECK: scf.yield -// CHECK: math.absf %{{.*}} : vector<4x4xf32> -// CHECK: vector.transfer_write - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf %{{.*}} : vector<1x8xf32> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir deleted file mode 100644 index e795e00442d3c4..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir +++ /dev/null @@ -1,113 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline |\ -// RUN: FileCheck %s - -func.func @row_reduce_map_fuse_map(%arg0: tensor, - %arg1: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %dim0 = tensor.dim %arg1, %c0 : tensor - %dim1 = tensor.dim %arg1, %c1 : tensor - - %empty_2D = tensor.empty(%dim0, %dim1) : tensor - %reduce_init = tensor.empty(%dim0) : tensor - %mapped = linalg.map { arith.addf } - ins(%arg0, %arg1 : tensor, tensor) - outs(%empty_2D : tensor) - - %c0_f32 = arith.constant 0.0 : f32 - %empty_1D = tensor.empty(%dim1) : tensor - %fill = linalg.fill ins(%c0_f32: f32) - outs(%empty_1D: tensor) -> tensor - - %reduce = linalg.reduce { arith.addf } - ins(%mapped: tensor) - outs(%fill: tensor) - dimensions = [1] - - %res = linalg.map { math.absf } - ins(%reduce: tensor) - outs(%empty_1D : tensor) - return %res : tensor -} -// CHECK-LABEL: @row_reduce_map_fuse_map - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf %{{.*}} : vector<4x4xf32> -// CHECK: vector.multi_reduction -// CHECK: : vector<4x4xf32> to vector<4xf32> -// CHECK: scf.yield %{{.*}} : vector<4xf32> -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf %{{.*}} : vector<4x1xf32> -// CHECK: arith.addf %{{.*}} : vector<4xf32> -// CHECK: scf.yield %{{.*}} : vector<4xf32> -// CHECK: scf.yield %{{.*}} : vector<4xf32> -// CHECK: math.absf %{{.*}} : vector<4xf32> -// CHECK: vector.transfer_write - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf %{{.*}} : f32 -// CHECK: arith.addf %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: math.absf %{{.*}} : f32 -// CHECK: tensor.insert - -// ----- - -func.func @col_reduce_map_fuse_map(%arg0: tensor, - %arg1: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %dim0 = tensor.dim %arg1, %c0 : tensor - %dim1 = tensor.dim %arg1, %c1 : tensor - - %empty_2D = tensor.empty(%dim0, %dim1) : tensor - %mapped = linalg.map { arith.addf } - ins(%arg0, %arg1 : tensor, tensor) - outs(%empty_2D : tensor) - - %c0_f32 = arith.constant 0.0 : f32 - %empty_1D = tensor.empty(%dim1) : tensor - %fill = linalg.fill ins(%c0_f32: f32) - outs(%empty_1D: tensor) -> tensor - - %reduce = linalg.reduce { arith.addf } - ins(%mapped: tensor) - outs(%fill: tensor) - dimensions = [0] - - %res = linalg.map { math.absf } - ins(%reduce: tensor) - outs(%empty_1D : tensor) - return %res : tensor -} -// CHECK-LABEL: @col_reduce_map_fuse_map - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf %{{.*}} : vector<4x4xf32> -// CHECK: vector.multi_reduction -// CHECK: : vector<4x4xf32> to vector<4xf32> -// CHECK: scf.yield %{{.*}} : vector<4xf32> -// CHECK: scf.for -// CHECK: scf.for -// CHECK: 
scf.for -// CHECK: arith.addf %{{.*}} : f32 -// CHECK: arith.addf %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : tensor<4xf32> -// CHECK: scf.yield %{{.*}} : tensor<4xf32> -// CHECK: vector.transfer_write - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf %{{.*}} : f32 -// CHECK: arith.addf %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: math.absf %{{.*}} : f32 -// CHECK: tensor.insert -// CHECK: tensor.insert_slice diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reshape_map.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reshape_map.mlir deleted file mode 100644 index f0ff033d2794d3..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reshape_map.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --gml-st-cpu-tiling-pipeline="fuse-degenerate-reshapes=true" \ -// RUN: | FileCheck %s - -func.func @fuse_reshape_map(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %0 = tensor.empty() : tensor<10x1x1x1xf32> - %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor<10x1x1x1xf32> into tensor<10x1xf32> - - %empty= tensor.empty() : tensor<10x1x4x4x1xf32> - %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3, 4]] : - tensor<10x16xf32> into tensor<10x1x4x4x1xf32> - %neg = linalg.map { arith.negf } - ins(%expanded: tensor<10x1x4x4x1xf32>) - outs(%empty: tensor<10x1x4x4x1xf32>) - %collapsed = tensor.collapse_shape %neg [[0, 1], [2, 3, 4]] : - tensor<10x1x4x4x1xf32> into tensor<10x16xf32> - - %empty_3D = tensor.empty() : tensor<10x1x16xf32> - %expanded0 = tensor.expand_shape %collapsed [[0], [1, 2]] : - tensor<10x16xf32> into tensor<10x1x16xf32> - %abs0 = linalg.map { math.absf } - ins(%expanded0: tensor<10x1x16xf32>) - outs(%empty_3D : tensor<10x1x16xf32>) - %collapsed0 = tensor.collapse_shape %abs0 [[0], [1, 2]] : - tensor<10x1x16xf32> into tensor<10x16xf32> - - %empty_5D = tensor.empty() : tensor<10x16x1x1x1xf32> - %expanded1 = tensor.expand_shape %collapsed0 [[0], [1, 2, 3, 4]] : - tensor<10x16xf32> into tensor<10x16x1x1x1xf32> - %abs1 = linalg.map { math.absf } - ins(%expanded1: tensor<10x16x1x1x1xf32>) - outs(%empty_5D : tensor<10x16x1x1x1xf32>) - %collapsed1 = tensor.collapse_shape %abs1 [[0], [1, 2, 3, 4]] : - tensor<10x16x1x1x1xf32> into tensor<10x16xf32> - - %empty_4D = tensor.empty() : tensor<10x1x16x1xf32> - %expanded2 = tensor.expand_shape %collapsed1 [[0, 1], [2, 3]] : - tensor<10x16xf32> into tensor<10x1x16x1xf32> - %abs2 = linalg.map { math.absf } - ins(%expanded2: tensor<10x1x16x1xf32>) - outs(%empty_4D : tensor<10x1x16x1xf32>) - %collapsed2 = tensor.collapse_shape %abs2 [[0, 1], [2, 3]] : - tensor<10x1x16x1xf32> into tensor<10x16xf32> - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%collapsed2, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - return %add : tensor<10x16xf32> -} - -// CHECK: @fuse_reshape_map(%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<10x16xf32> into tensor<10x1x4x4x1xf32> -// CHECK: %[[RES:.*]] = scf.for {{.*}} (tensor<10x1x4x4x1xf32>) { -// CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[EXPAND]] -// CHECK: arith.negf -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: %[[COLLAPSE:.*]] = 
tensor.collapse_shape %[[RES]] {{.*}} tensor<10x1x4x4x1xf32> into tensor<10x16xf32> - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[EXTRACT0:.*]] = tensor.extract_slice %[[COLLAPSE]] -// CHECK: %[[EXPAND0:.*]] = tensor.expand_shape %[[EXTRACT0]] {{.*}} tensor<1x8xf32> into tensor<1x1x8xf32> -// CHECK: %[[READ0:.*]] = vector.transfer_read %[[EXPAND0]] -// CHECK: %[[ABS0:.*]] = math.absf %[[READ0]] -// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[ABS0]] -// CHECK: %[[COLLAPSE0:.*]] = tensor.collapse_shape %[[WRITE0]] {{.*}} tensor<1x1x8xf32> into tensor<1x8xf32> - -// CHECK: %[[EXPAND1:.*]] = tensor.expand_shape %[[COLLAPSE0]] {{.*}} tensor<1x8xf32> into tensor<1x8x1x1x1xf32> -// CHECK: %[[READ1:.*]] = vector.transfer_read %[[EXPAND1]] -// CHECK: %[[ABS1:.*]] = math.absf %[[READ1]] -// CHECK: %[[WRITE1:.*]] = vector.transfer_write %[[ABS1]] -// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[WRITE1]] {{.*}} tensor<1x8x1x1x1xf32> into tensor<1x8xf32> - -// CHECK: %[[EXPAND2:.*]] = tensor.expand_shape %[[COLLAPSE1]] {{.*}} tensor<1x8xf32> into tensor<1x1x8x1xf32> -// CHECK: %[[READ2:.*]] = vector.transfer_read %[[EXPAND2]] -// CHECK: %[[ABS2:.*]] = math.absf %[[READ2]] -// CHECK: %[[WRITE2:.*]] = vector.transfer_write %[[ABS2]] -// CHECK: %[[COLLAPSE2:.*]] = tensor.collapse_shape %[[WRITE2]] {{.*}} tensor<1x1x8x1xf32> into tensor<1x8xf32> - -// CHECK: %[[READ1:.*]] = vector.transfer_read %[[COLLAPSE2]] -// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG1]] -// CHECK: %[[ADD:.*]] = arith.addf %[[READ1]], %[[READ2]] -// CHECK: vector.transfer_write %[[ADD]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir deleted file mode 100644 index 7a9863ea96c8ed..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir +++ /dev/null @@ -1,180 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-st-cpu-tiling-pipeline=matmul-tile-sizes=4,5,6 | FileCheck %s -// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline="lower-to-mmt4d=true" | \ -// RUN: FileCheck %s --check-prefixes=PACKED - -func.func @matmul_static(%lhs: tensor<128x16xf32>, %rhs: tensor<16x64xf32>, - %output: tensor<128x64xf32>) -> tensor<128x64xf32> { - %0 = linalg.matmul ins(%lhs, %rhs : tensor<128x16xf32>, tensor<16x64xf32>) - outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32> - return %0 : tensor<128x64xf32> -} - -// CHECK-LABEL: @matmul_static - -// CHECK: scf.for -// CHECK: vector.transfer_read -// CHECK-NEXT: scf.for -// CHECK-COUNT-2: vector.transfer_read -// CHECK: vector.contract {{.*}} vector<4x6xf32>, vector<6x5xf32> -// CHECK: scf.yield {{.*}} : vector<4x5xf32> -// CHECK: vector.transfer_write - -// PACKED-LABEL: @matmul_static - -// PACKED: tensor.empty() : tensor<16x16x8x1xf32> -// PACKED-COUNT-2: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield %{{.*}} : tensor<16x16x8x1xf32> -// PACKED: scf.yield %{{.*}} : tensor<16x16x8x1xf32> - -// PACKED: tensor.empty() : tensor<8x16x8x1xf32> -// PACKED-COUNT-2: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> -// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> - -// PACKED: tensor.empty() : tensor<16x8x8x8xf32> -// PACKED-COUNT-2: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield -// PACKED: scf.yield - -// PACKED-COUNT-2: scf.for -// 
PACKED: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_read -// PACKED: vector.contract -// PACKED: scf.yield -// PACKED: scf.yield -// PACKED: scf.yield - -// PACKED: tensor.empty() : tensor<128x64xf32> -// PACKED-COUNT-2: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield %{{.*}} : tensor<128x64xf32> -// PACKED: scf.yield %{{.*}} : tensor<128x64xf32> - -// ----- - -func.func @matmul(%lhs: tensor, - %rhs: tensor) -> tensor { - %c0 = arith.constant 0 : index - %0 = tensor.dim %lhs, %c0 : tensor - %c1 = arith.constant 1 : index - %1 = tensor.dim %rhs, %c1 : tensor - %2 = tensor.empty(%0, %1) : tensor - %cst = arith.constant 0.000000e+00 : f32 - %3 = linalg.fill ins(%cst : f32) - outs(%2 : tensor) -> tensor - %4 = linalg.matmul ins(%lhs, %rhs : tensor, tensor) - outs(%3 : tensor) -> tensor - return %4 : tensor -} -// CHECK-LABEL: @matmul - -// CHECK: scf.for -// CHECK: scf.for -// CHECK-COUNT-2: vector.transfer_read -// CHECK: vector.contract -// CHECK-NEXT: scf.yield %{{.*}} : vector<4x5xf32> -// CHECK: vector.transfer_write - -// CHECK-NEXT: scf.for -// CHECK: linalg.matmul {{.*}} -> tensor<4x5xf32> -// CHECK: scf.yield {{.*}} : tensor<4x5xf32> -// CHECK: tensor.insert_slice - -// CHECK: scf.for -// CHECK: linalg.fill -// CHECK: scf.for -// CHECK: linalg.matmul {{.*}} -> tensor<4x?xf32> -// CHECK: scf.yield {{.*}} : tensor<4x?xf32> -// CHECK: tensor.insert_slice - -// CHECK: scf.for -// CHECK: linalg.fill -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK: scf.yield {{.*}} : tensor -// CHECK: tensor.insert_slice - -// ----- - -func.func @matmul_narrow_static(%lhs: tensor<2x16xf32>, %rhs: tensor<16x64xf32>, - %output: tensor<2x64xf32>) -> tensor<2x64xf32> { - %0 = linalg.matmul ins(%lhs, %rhs : tensor<2x16xf32>, tensor<16x64xf32>) - outs(%output : tensor<2x64xf32>) -> tensor<2x64xf32> - return %0 : tensor<2x64xf32> -} -// CHECK-LABEL: @matmul_narrow_static - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK: scf.yield {{.*}} : tensor<2x5xf32> - -// PACKED-LABEL: @matmul_narrow_static - -// PACKED: tensor.empty() : tensor<1x16x2x1xf32> -// PACKED: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield %{{.*}} : tensor<1x16x2x1xf32> -// PACKED: } - -// PACKED: tensor.empty() : tensor<8x16x8x1xf32> -// PACKED-COUNT: scf.for -// PACKED: vector.transpose -// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> -// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> - -// PACKED: tensor.empty() : tensor<1x8x2x8xf32> -// PACKED: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield %{{.*}} : tensor<1x8x2x8xf32> -// PACKED: scf.for -// PACKED: scf.for -// PACKED: vector.contract -// PACKED: scf.yield %{{.*}} : vector<1x1x2x8xf32> -// PACKED: scf.yield - -// PACKED: tensor.empty() : tensor<2x64xf32> -// PACKED: scf.for -// PACKED: vector.transfer_read -// PACKED: vector.transfer_write -// PACKED: scf.yield %{{.*}} : tensor<2x64xf32> - -// ----- - -func.func @matmul_small_static_peeling(%lhs: tensor<2x4xf32>, - %arg1: tensor<4x6xf32>, %output: tensor<2x6xf32>) -> tensor<2x6xf32> { - %0 = linalg.matmul ins(%lhs, %arg1 : tensor<2x4xf32>, tensor<4x6xf32>) - outs(%output : tensor<2x6xf32>) -> tensor<2x6xf32> - return %0 : tensor<2x6xf32> -} -// CHECK-LABEL: @matmul_small_static_peeling - -// CHECK-NOT: scf.for -// CHECK-NOT: scf.for -// CHECK: vector.contract - -// ----- - -func.func @matvec_static(%lhs: tensor<1x16xf32>, %arg1: 
tensor<16x64xf32>, - %output: tensor<1x64xf32>) -> tensor<1x64xf32> { - %0 = linalg.matmul ins(%lhs, %arg1 : tensor<1x16xf32>, tensor<16x64xf32>) - outs(%output : tensor<1x64xf32>) -> tensor<1x64xf32> - return %0 : tensor<1x64xf32> -} -// CHECK-LABEL: @matvec_static - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK: scf.yield {{.*}} : tensor<1x5xf32> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir deleted file mode 100644 index 8953e8816cc532..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --mlir-print-ir-after-all\ -// RUN: --gml-st-cpu-tiling-pipeline="reduction-1d-tile-size=32 reduction-1d-split-ratio=8" \ -// RUN: | FileCheck %s - -func.func @reduce_1d_static(%arg0: tensor<100xf32>) -> tensor { - %1 = tensor.empty() : tensor - %cst = arith.constant 0.0 : f32 - %init = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %res = linalg.reduce { arith.addf } - ins(%arg0: tensor<100xf32>) outs(%init: tensor) dimensions = [0] - return %res : tensor -} -// CHECK-LABEL: @reduce_1d_static( -// CHECK-SAME: %[[ARG:.*]]: tensor<100xf32> - -// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor -// CHECK: %[[LHS:.*]] = vector.transfer_read %[[ARG]] -// CHECK: %[[RHS:.*]] = vector.transfer_read %[[CST]][] -// CHECK: %[[EXTRACT:.*]] = vector.extractelement %[[RHS]][] -// CHECK: %[[REDUCTION:.*]] = vector.multi_reduction , %[[LHS]], %[[EXTRACT]] -// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[REDUCTION]] -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[BROADCAST]], %[[CST]][] -// CHECK: return %[[WRITE]] - -// ----- - -func.func @reduce_1d_dynamic(%arg0: tensor) -> tensor { - %1 = tensor.empty() : tensor - %cst = arith.constant 0.0 : f32 - %init = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %res = linalg.reduce { arith.addf } - ins(%arg0: tensor) outs(%init: tensor) dimensions = [0] - return %res : tensor -} -// CHECK-LABEL: func @reduce_1d_dynamic - -// CHECK: arith.constant dense<0.000000e+00> : vector<8xf32> - -// CHECK: scf.for -// CHECK: vector.multi_reduction -// CHECK-SAME: : vector<4x8xf32> to vector<8xf32> -// CHECK: scf.yield %{{.*}} : vector<8xf32> - -// CHECK: vector.multi_reduction -// CHECK-SAME: : vector<8xf32> to f32 - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d_map.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d_map.mlir deleted file mode 100644 index 079d1df0c1dd47..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d_map.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --gml-st-cpu-tiling-pipeline="reduction-1d-tile-size=32 reduction-1d-split-ratio=8" \ -// RUN: | FileCheck %s -func.func @reduce_1d_map_aka_dot(%lhs: tensor, - %rhs: tensor) -> tensor { - %c0 = arith.constant 0 : index - %size = tensor.dim %lhs, %c0 : tensor - %init_1d = tensor.empty(%size) : tensor - - %map = linalg.map { arith.mulf } - ins(%lhs, %rhs: tensor, tensor) outs(%init_1d: tensor) - %cst = arith.constant 0.0 : f32 - %init_0d = tensor.empty() : tensor - - %fill = linalg.fill - ins(%cst : f32) outs(%init_0d : tensor) -> tensor - %res = linalg.reduce 
{ arith.addf } - ins(%map: tensor) outs(%fill: tensor) dimensions = [0] - return %res : tensor -} -// CHECK-LABEL: func.func @reduce_1d_map_aka_dot -// CHECK: scf.for -// CHECK: arith.mulf {{.*}} : vector<32xf32> -// CHECK: vector.multi_reduction -// CHECK: : vector<4x8xf32> to vector<8xf32> -// CHECK: scf.yield %{{.*}} : vector<8xf32> -// CHECK: vector.multi_reduction -// CHECK: : vector<8xf32> to f32 -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.mulf {{.*}} : f32 -// CHECK: arith.addf {{.*}} : f32 -// CHECK: scf.yield {{.*}} : f32 -// CHECK: scf.yield {{.*}} : f32 - diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir deleted file mode 100644 index 7af0bdef6b492f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline \ -// RUN: | FileCheck %s - -func.func @col_reduce_static(%input: tensor<100x10xf32>, - %output: tensor<10xf32>) -> tensor<10xf32> { - %res = linalg.reduce { arith.addf } - ins(%input: tensor<100x10xf32>) - outs(%output: tensor<10xf32>) - dimensions = [0] - return %res : tensor<10xf32> -} -// CHECK-LABEL: @col_reduce_static - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: vector.multi_reduction -// CHECK-SAME: : vector<4x4xf32> to vector<4xf32> -// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> -// CHECK: vector.transfer_write - -// ----- - -func.func @row_reduce_dynamic(%input: tensor, - %output: tensor) -> tensor { - %c0 = arith.constant 0 : index - %0 = tensor.dim %output, %c0 : tensor - %1 = tensor.empty(%0) : tensor - %cst = arith.constant 0.000000e+00 : f32 - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %res = linalg.reduce { arith.mulf } - ins(%input: tensor) - outs(%2: tensor) - dimensions = [1] - return %res : tensor -} -// CHECK-LABEL: @row_reduce_dynamic - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: vector.multi_reduction -// CHECK-SAME: : vector<4x4xf32> to vector<4xf32> -// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> - -// CHECK: scf.for -// CHECK: arith.mulf -// CHECK-SAME: : vector<4xf32> -// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> -// CHECK: vector.transfer_write - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.mulf %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: tensor.insert -// CHECK: tensor.insert_slice - -// ----- - -func.func @col_reduce_dynamic(%input: tensor, - %output: tensor) -> tensor { - %c0 = arith.constant 0 : index - %0 = tensor.dim %output, %c0 : tensor - %1 = tensor.empty(%0) : tensor - %cst = arith.constant 0.000000e+00 : f32 - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %res = linalg.reduce { arith.mulf } - ins(%input: tensor) - outs(%2: tensor) - dimensions = [0] - return %res : tensor -} -// CHECK-LABEL: @col_reduce_dynamic - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: vector.multi_reduction -// CHECK-SAME: : vector<4x4xf32> to vector<4xf32> -// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> - -// CHECK: scf.for -// CHECK: arith.mulf %{{.*}} : f32 -// CHECK-NEXT: scf.yield %{{.*}} : f32 -// CHECK: tensor.insert - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.mulf %{{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: tensor.insert -// CHECK: tensor.insert_slice diff --git 
a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_window.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_window.mlir deleted file mode 100644 index 5b39efe87e8f17..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_window.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline -// TODO(b/270534416): Re-enable. -// | FileCheck %s - -func.func @reduce_window(%input: tensor<1xf32>, %window: tensor<32xf32>, - %output: tensor<1x8xf32>) -> tensor<1x8xf32> { - %bcast_init = tensor.empty() : tensor<1x256xf32> - %bcast = linalg.broadcast - ins(%input : tensor<1xf32>) - outs(%bcast_init : tensor<1x256xf32>) - dimensions = [1] - - %abs_init = tensor.empty() : tensor<32xf32> - %abs = linalg.map { math.absf } - ins(%window: tensor<32xf32>) - outs(%abs_init: tensor<32xf32>) - - %cst = arith.constant 0.000000e+00 : f32 - %init = tensor.empty() : tensor<1x8xf32> - %fill = linalg.fill - ins(%cst : f32) outs(%init : tensor<1x8xf32>) -> tensor<1x8xf32> - - %reduce_window = linalg.generic { - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d1 * 32 + d2)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%bcast, %abs : tensor<1x256xf32>, tensor<32xf32>) - outs(%fill : tensor<1x8xf32>) { - ^bb0(%in: f32, %win: f32, %out: f32): - %add = arith.addf %in, %out : f32 - linalg.yield %add : f32 - } -> tensor<1x8xf32> - - - %exp = linalg.map { math.exp } - ins(%reduce_window: tensor<1x8xf32>) - outs(%init: tensor<1x8xf32>) - - func.return %exp : tensor<1x8xf32> -} -// CHECK-LABEL: @reduce_window - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: arith.addf {{.*}} : f32 -// CHECK: scf.yield %{{.*}} : f32 -// CHECK: math.exp %{{.*}} : f32 -// CHECK: tensor.parallel_insert_slice diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir deleted file mode 100644 index 9cc96474589e87..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir +++ /dev/null @@ -1,61 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline \ -// RUN: | FileCheck %s - -func.func @reverse_static_perfect_tiles( - %input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { - %res = thlo.reverse - ins(%input: tensor<64xf32>) - outs(%init: tensor<64xf32>) - reverse_dimensions = [0] - func.return %res : tensor<64xf32> -} - -// CHECK-LABEL: @reverse_static_perfect_tiles - -// CHECK: scf.for -// CHECK: vector.transfer_read -// CHECK: vector.shuffle -// CHECK: vector.transfer_write - -// ----- - -func.func @reverse_dynamic( - %input: tensor, %init: tensor) -> tensor { - %res = thlo.reverse - ins(%input: tensor) - outs(%init: tensor) - reverse_dimensions = [0, 1] - func.return %res : tensor -} - -// CHECK-LABEL: @reverse_dynamic - -// CHECK: scf.for -// CHECK: vector.shuffle -// CHECK: vector.transfer_write - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: tensor.extract_slice -// CHECK: tensor.insert_slice - -// ----- - -func.func @reverse_dynamic_not_last_dim( - %input: tensor, %init: tensor) -> tensor { - %res = thlo.reverse - ins(%input: tensor) - outs(%init: tensor) - reverse_dimensions = [0] - func.return %res : tensor -} - -// CHECK-LABEL: @reverse_dynamic - -// CHECK: scf.for -// CHECK: tensor.extract_slice {{.*}} [1, 8] [1, 1] - -// CHECK: scf.for -// CHECK: 
%[[REM_SIZE:.*]] = affine.apply -// CHECK: tensor.extract_slice {{.*}} [1, %[[REM_SIZE]]] [1, 1] -// CHECK: tensor.insert_slice {{.*}} tensor<1x?xf32> into tensor \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir deleted file mode 100644 index d31b073f350909..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline | \ -// RUN: FileCheck %s - -func.func @scatter_fusion(%indices: tensor, - %updates: tensor, %init: tensor) -> tensor { - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %dim0 = tensor.dim %updates, %c0 : tensor - %dim1 = tensor.dim %updates, %c1 : tensor - %dim2 = tensor.dim %updates, %c2 : tensor - %init0 = tensor.empty(%dim0, %dim1, %dim2) : tensor - %abs = linalg.map { math.absf } - ins(%updates:tensor) - outs(%init0:tensor) - - %result = thlo.scatter - ins (%indices: tensor, %abs: tensor) - outs (%init: tensor) - (%in: f32, %out: f32) { - %0 = arith.addf %in, %out: f32 - thlo.yield %0: f32 - } - return %result : tensor -} -// CHECK-LABEL: @scatter_fusion - -// CHECK: scf.for -// CHECK: scf.if -// CHECK: scf.for -// CHECK: math.absf -// CHECK: scf.for -// CHECK: math.absf -// CHECK: linalg.reduce -// CHECK: scf.yield {{.*}} : tensor - -// ----- - -func.func @scatter_fusion_overwrite(%indices: tensor, - %updates: tensor, %init: tensor) -> tensor { - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %dim0 = tensor.dim %updates, %c0 : tensor - %dim1 = tensor.dim %updates, %c1 : tensor - %dim2 = tensor.dim %updates, %c2 : tensor - %init0 = tensor.empty(%dim0, %dim1, %dim2) : tensor - %abs = linalg.map { math.absf } - ins(%updates:tensor) - outs(%init0:tensor) - - %result = thlo.scatter - ins (%indices: tensor, %abs: tensor) - outs (%init: tensor) - (%in: f32, %out: f32) { - thlo.yield %in: f32 - } - return %result : tensor -} -// CHECK-LABEL: @scatter_fusion_overwrite - -// CHECK: scf.for -// CHECK: scf.if -// CHECK: scf.for -// CHECK: math.absf -// CHECK: scf.for -// CHECK: math.absf -// CHECK: scf.yield {{.*}} : tensor diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir deleted file mode 100644 index e2c289a221aef6..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline --canonicalize \ -// RUN: | FileCheck %s - -func.func @sort(%input1: tensor<64x8x4xf32>, %input2: tensor<64x8x4xf32>, - %init1: tensor<64x8x4xf32>, %init2: tensor<64x8x4xf32>) - -> (tensor<64x8x4xf32>, tensor<64x8x4xf32>) { - %res0, %res1 = thlo.sort - ins(%input1: tensor<64x8x4xf32>, %input2: tensor<64x8x4xf32>) - outs(%init1: tensor<64x8x4xf32>, %init2: tensor<64x8x4xf32>) - dimension = 1 - is_stable = true - (%e11: f32, %e12: f32, %e21: f32, %e22: f32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %res0, %res1: tensor<64x8x4xf32>, tensor<64x8x4xf32> -} -// CHECK-LABEL: func.func @sort( - -// CHECK: scf.for -// CHECK: thlo.sort -// CHECK-SAME: ins(%{{.*}} : tensor<1x8x1xf32>, %{{.*}} : tensor<1x8x1xf32>) -// CHECK-SAME: dimension = 1 -// CHECK: tensor.insert_slice diff 
--git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir deleted file mode 100644 index 065d983139275b..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir +++ /dev/null @@ -1,41 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline \ -// RUN: | FileCheck %s - -func.func @transpose(%input: tensor<16x32x64xf32>, - %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { - %transpose = linalg.transpose - ins(%input:tensor<16x32x64xf32>) - outs(%init:tensor<32x64x16xf32>) - permutation = [1, 2, 0] - func.return %transpose : tensor<32x64x16xf32> -} -// CHECK-LABEL: func.func @transpose - -// CHECK: scf.for -// CHECK: vector.transpose -// CHECK-SAME: [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32> -// CHECK: vector.transfer_write - -// ----- - -func.func @peel_transpose(%input: tensor<16x32x65xf32>, - %init: tensor<32x65x16xf32>) -> tensor<32x65x16xf32> { - %transpose = linalg.transpose - ins(%input:tensor<16x32x65xf32>) - outs(%init:tensor<32x65x16xf32>) - permutation = [1, 2, 0] - func.return %transpose : tensor<32x65x16xf32> -} - -// CHECK-LABEL: @peel_transpose - -// CHECK: scf.for -// CHECK: vector.transpose -// CHECK-SAME: [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32> -// CHECK: vector.transfer_write - -// CHECK: scf.for -// CHECK: scf.for -// CHECK: tensor.extract -// CHECK: tensor.insert -// CHECK: tensor.insert_slice diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir deleted file mode 100644 index f54c4ed7174932..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir +++ /dev/null @@ -1,490 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --test-hlo-transform-dialect-interpreter --canonicalize -cse \ -// RUN: --test-gml-st-greedy-fusion | FileCheck %s - -// CHECK-LABEL: func @fuse_broadcast_map -// CHECK-SAME: (%[[ARG0:.*]]: tensor<16xf32>, %[[ARG1:.*]]: tensor<16x32xf32>) -func.func @fuse_broadcast_map(%arg0: tensor<16xf32>, %arg1: tensor<16x32xf32>) - -> tensor<16x32xf32> { - %init = tensor.empty() : tensor<16x32xf32> - %bcast = linalg.broadcast - ins(%arg0 : tensor<16xf32>) - outs(%init : tensor<16x32xf32>) - dimensions = [1] - - %result = linalg.map { arith.addf } - ins(%bcast, %arg1 : tensor<16x32xf32>, tensor<16x32xf32>) - outs(%init : tensor<16x32xf32>) - func.return %result : tensor<16x32xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %forall_op, %tiled_op = transform.structured.tile_using_forall %0 num_threads [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[RESULT:.*]] = scf.forall -// CHECK-SAME: shared_outs(%[[INIT_:.*]] = %[[INIT]]) -// CHECK-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT]] -// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[BCAST:.*]] = linalg.broadcast -// CHECK-SAME: ins(%[[ARG0_SLICE]] -// CHECK-SAME: outs(%[[INIT_SLICE]] -// CHECK: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]] -// CHECK-DAG: %[[INIT_SLICE_:.*]] = tensor.extract_slice %[[INIT_]] -// CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK-SAME: ins(%[[BCAST]], %[[ARG1_SLICE]] -// CHECK-SAME: outs(%[[INIT_SLICE_]] -// 
CHECK: tensor.parallel_insert_slice %[[MAPPED]] -// CHECK: return %[[RESULT]] - -// ----- - -// CHECK-LABEL: func @do_not_fuse_multiple_uses -func.func @do_not_fuse_multiple_uses(%arg0: tensor, - %arg1: tensor) -> (tensor, tensor) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %dim0 = tensor.dim %arg1, %c0 : tensor - %dim1 = tensor.dim %arg1, %c1 : tensor - %init = tensor.empty(%dim0, %dim1) : tensor - %bcast = linalg.broadcast - ins(%arg0 : tensor) - outs(%init : tensor) - dimensions = [1] - - %result = linalg.map { arith.addf } - ins(%bcast, %arg1 : tensor, tensor) - outs(%init : tensor) - { op_label = "root" } - func.return %result, %bcast : tensor, tensor -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK: tensor.empty -// CHECK: %[[BCAST:.*]] = linalg.broadcast -// CHECK: %[[RESULT:.*]] = scf.forall -// CHECK: linalg.map -// CHECK: scf.forall.in_parallel -// CHECK: return %[[RESULT]], %[[BCAST]] - -// ----- - -// CHECK-LABEL: func @do_not_fuse_map_reduce -// CHECK-SAME: (%[[ARG0:.*]]: tensor<16x32xf32>, %[[ARG1:.*]]: tensor<16xf32>) -func.func @do_not_fuse_map_reduce(%arg0: tensor<16x32xf32>, %arg1: tensor<16xf32>) - -> tensor<16xf32> { - %init = tensor.empty() : tensor<16xf32> - %reduce = linalg.reduce { arith.addf } - ins(%arg0 : tensor<16x32xf32>) - outs(%init : tensor<16xf32>) - dimensions = [1] - - %result = linalg.map { arith.addf } - ins(%reduce, %arg1 : tensor<16xf32>, tensor<16xf32>) - outs(%init : tensor<16xf32>) - func.return %result : tensor<16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[REDUCE:.*]] = linalg.reduce -// CHECK: %[[RESULT:.*]] = scf.forall -// CHECK-SAME: shared_outs(%[[INIT_:.*]] = %[[INIT]]) -// CHECK-DAG: %[[REDUCE_SLICE:.*]] = tensor.extract_slice %[[REDUCE]] -// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]] -// CHECK-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT_]] -// CHECK: %[[MAPPED:.*]] = linalg.map -// CHECK-SAME: ins(%[[REDUCE_SLICE]], %[[ARG1_SLICE]] -// CHECK-SAME: outs(%[[INIT_SLICE]] -// CHECK: tensor.parallel_insert_slice %[[MAPPED]] -// CHECK: return %[[RESULT]] - -// ----- - -// Only basic checks that all maps and fills were fused into scf.forall. -// This test verified that ops are fused in correct order. If something is -// broken, the test will take exponential time and/or memory to finish. 
-// CHECK-LABEL: func @fuse_fibonacci -// CHECK-NOT: linalg.fill -// CHECK-NOT: linalg.map -// CHECK: scf.forall -// CHECK-COUNT-2: linalg.fill -// CHECK-COUNT-38: linalg.map -// CHECK-NOT: linalg.fill -// CHECK-NOT: linalg.map -// CHECK: tensor.parallel_insert_slice -// CHECK: return -func.func @fuse_fibonacci(%init : tensor) -> tensor { - %c0 = arith.constant 0 : i64 - %c1 = arith.constant 1 : i64 - - %0 = linalg.fill ins(%c0 : i64) outs(%init : tensor) -> tensor - %1 = linalg.fill ins(%c1 : i64) outs(%init : tensor) -> tensor - %2 = linalg.map { arith.addi } ins(%0, %1 : tensor, tensor) outs(%init : tensor) - %3 = linalg.map { arith.addi } ins(%1, %2 : tensor, tensor) outs(%init : tensor) - %4 = linalg.map { arith.addi } ins(%2, %3 : tensor, tensor) outs(%init : tensor) - %5 = linalg.map { arith.addi } ins(%3, %4 : tensor, tensor) outs(%init : tensor) - %6 = linalg.map { arith.addi } ins(%4, %5 : tensor, tensor) outs(%init : tensor) - %7 = linalg.map { arith.addi } ins(%5, %6 : tensor, tensor) outs(%init : tensor) - %8 = linalg.map { arith.addi } ins(%6, %7 : tensor, tensor) outs(%init : tensor) - %9 = linalg.map { arith.addi } ins(%7, %8 : tensor, tensor) outs(%init : tensor) - %10 = linalg.map { arith.addi } ins(%8, %9 : tensor, tensor) outs(%init : tensor) - %11 = linalg.map { arith.addi } ins(%9, %10 : tensor, tensor) outs(%init : tensor) - %12 = linalg.map { arith.addi } ins(%10, %11 : tensor, tensor) outs(%init : tensor) - %13 = linalg.map { arith.addi } ins(%11, %12 : tensor, tensor) outs(%init : tensor) - %14 = linalg.map { arith.addi } ins(%12, %13 : tensor, tensor) outs(%init : tensor) - %15 = linalg.map { arith.addi } ins(%13, %14 : tensor, tensor) outs(%init : tensor) - %16 = linalg.map { arith.addi } ins(%14, %15 : tensor, tensor) outs(%init : tensor) - %17 = linalg.map { arith.addi } ins(%15, %16 : tensor, tensor) outs(%init : tensor) - %18 = linalg.map { arith.addi } ins(%16, %17 : tensor, tensor) outs(%init : tensor) - %19 = linalg.map { arith.addi } ins(%17, %18 : tensor, tensor) outs(%init : tensor) - %20 = linalg.map { arith.addi } ins(%18, %19 : tensor, tensor) outs(%init : tensor) - %21 = linalg.map { arith.addi } ins(%19, %20 : tensor, tensor) outs(%init : tensor) - %22 = linalg.map { arith.addi } ins(%20, %21 : tensor, tensor) outs(%init : tensor) - %23 = linalg.map { arith.addi } ins(%21, %22 : tensor, tensor) outs(%init : tensor) - %24 = linalg.map { arith.addi } ins(%22, %23 : tensor, tensor) outs(%init : tensor) - %25 = linalg.map { arith.addi } ins(%23, %24 : tensor, tensor) outs(%init : tensor) - %26 = linalg.map { arith.addi } ins(%24, %25 : tensor, tensor) outs(%init : tensor) - %27 = linalg.map { arith.addi } ins(%25, %26 : tensor, tensor) outs(%init : tensor) - %28 = linalg.map { arith.addi } ins(%26, %27 : tensor, tensor) outs(%init : tensor) - %29 = linalg.map { arith.addi } ins(%27, %28 : tensor, tensor) outs(%init : tensor) - %30 = linalg.map { arith.addi } ins(%28, %29 : tensor, tensor) outs(%init : tensor) - %31 = linalg.map { arith.addi } ins(%29, %30 : tensor, tensor) outs(%init : tensor) - %32 = linalg.map { arith.addi } ins(%30, %31 : tensor, tensor) outs(%init : tensor) - %33 = linalg.map { arith.addi } ins(%31, %32 : tensor, tensor) outs(%init : tensor) - %34 = linalg.map { arith.addi } ins(%32, %33 : tensor, tensor) outs(%init : tensor) - %35 = linalg.map { arith.addi } ins(%33, %34 : tensor, tensor) outs(%init : tensor) - %36 = linalg.map { arith.addi } ins(%34, %35 : tensor, tensor) outs(%init : tensor) - %37 = linalg.map { arith.addi } 
ins(%35, %36 : tensor, tensor) outs(%init : tensor) - %38 = linalg.map { arith.addi } ins(%36, %37 : tensor, tensor) outs(%init : tensor) - %39 = linalg.map { arith.addi } ins(%37, %38 : tensor, tensor) outs(%init : tensor) - {op_label="root"} - func.return %39 : tensor -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// ----- - -func.func @fuse_reshape_middle_unit_dim_map(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty_3D = tensor.empty() : tensor<10x1x16xf32> - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : - tensor<10x16xf32> into tensor<10x1x16xf32> - %abs = linalg.map { math.absf } - ins(%expanded: tensor<10x1x16xf32>) - outs(%empty_3D : tensor<10x1x16xf32>) - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %collapsed = tensor.collapse_shape %abs [[0], [1, 2]] : - tensor<10x1x16xf32> into tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%collapsed, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - {op_label="root"} - return %add : tensor<10x16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @fuse_reshape_middle_unit_dim_map -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK-NOT: tensor.expand_shape -// CHECK-NOT: tensor.collapse_shape -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[EXPAND]] -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] -// CHECK: linalg.map { arith.addf } ins(%[[COLLAPSE]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return - -// ----- - -func.func @fuse_reshape_trailing_unit_dim_map(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty_5D = tensor.empty() : tensor<10x16x1x1x1xf32> - %expanded = tensor.expand_shape %arg0 [[0], [1, 2, 3, 4]] : - tensor<10x16xf32> into tensor<10x16x1x1x1xf32> - %abs = linalg.map { math.absf } - ins(%expanded: tensor<10x16x1x1x1xf32>) - outs(%empty_5D : tensor<10x16x1x1x1xf32>) - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %collapsed = tensor.collapse_shape %abs [[0], [1, 2, 3, 4]] : - tensor<10x16x1x1x1xf32> into tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%collapsed, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - {op_label="root"} - return %add : tensor<10x16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1, 8] : (!transform.any_op) 
-> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @fuse_reshape_trailing_unit_dim_map -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK-NOT: tensor.expand_shape -// CHECK-NOT: tensor.collapse_shape -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[EXPAND]] -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] -// CHECK: linalg.map { arith.addf } ins(%[[COLLAPSE]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return - -// ----- - -func.func @fuse_reshape_leading_unit_dim_map(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty_5D = tensor.empty() : tensor<1x1x1x10x16xf32> - %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4]] : - tensor<10x16xf32> into tensor<1x1x1x10x16xf32> - %abs = linalg.map { math.absf } - ins(%expanded: tensor<1x1x1x10x16xf32>) - outs(%empty_5D : tensor<1x1x1x10x16xf32>) - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %collapsed = tensor.collapse_shape %abs [[0, 1, 2, 3], [4]] : - tensor<1x1x1x10x16xf32> into tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%collapsed, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - {op_label="root"} - return %add : tensor<10x16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @fuse_reshape_leading_unit_dim_map -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK-NOT: tensor.expand_shape -// CHECK-NOT: tensor.collapse_shape -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[EXPAND]] -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] -// CHECK: linalg.map { arith.addf } ins(%[[COLLAPSE]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return - -// ----- - -func.func @fuse_reshape_multiple_unit_dims_map(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty_4D = tensor.empty() : tensor<10x1x16x1xf32> - %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : - tensor<10x16xf32> into tensor<10x1x16x1xf32> - %abs = linalg.map { math.absf } - ins(%expanded: tensor<10x1x16x1xf32>) - outs(%empty_4D : tensor<10x1x16x1xf32>) - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %collapsed = tensor.collapse_shape %abs [[0, 1], [2, 3]] : - tensor<10x1x16x1xf32> into tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%collapsed, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - {op_label="root"} - return %add : tensor<10x16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 
tile_sizes [1, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @fuse_reshape_multiple_unit_dims_map -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK-NOT: tensor.expand_shape -// CHECK-NOT: tensor.collapse_shape -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[EXPAND]] -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] -// CHECK: linalg.map { arith.addf } ins(%[[COLLAPSE]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return - -// ----- - -func.func @fuse_reshape_reassoc_only_unit_dims_map(%arg0: tensor<10x16xf32>) - -> tensor<10x16x1xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty_5D = tensor.empty() : tensor<10x1x16x1x1xf32> - %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3, 4]] : - tensor<10x16xf32> into tensor<10x1x16x1x1xf32> - %abs = linalg.map { math.absf } - ins(%expanded: tensor<10x1x16x1x1xf32>) - outs(%empty_5D : tensor<10x1x16x1x1xf32>) - - %empty_3D = tensor.empty() : tensor<10x16x1xf32> - %collapsed = tensor.collapse_shape %abs [[0, 1], [2], [3, 4]] : - tensor<10x1x16x1x1xf32> into tensor<10x16x1xf32> - %neg = linalg.map { arith.negf } - ins(%collapsed : tensor<10x16x1xf32>) - outs(%empty_3D : tensor<10x16x1xf32>) - {op_label="root"} - return %neg : tensor<10x16x1xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @fuse_reshape_reassoc_only_unit_dims_map -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>) -// CHECK-NOT: tensor.expand_shape -// CHECK-NOT: tensor.collapse_shape -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[EXPAND]] -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] -// CHECK: linalg.map { arith.negf } ins(%[[COLLAPSE]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return - -// ----- - -func.func @do_not_fuse_collapse_shape(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty = tensor.empty() : tensor<10x1x4x4x1xf32> - %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3, 4]] : - tensor<10x16xf32> into tensor<10x1x4x4x1xf32> - %abs = linalg.map { math.absf } - ins(%expanded: tensor<10x1x4x4x1xf32>) - outs(%empty: tensor<10x1x4x4x1xf32>) - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %collapsed = tensor.collapse_shape %abs [[0, 1], [2, 3, 4]] : - tensor<10x1x4x4x1xf32> into tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%collapsed, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - {op_label="root"} - return %add : tensor<10x16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1, 8] : 
(!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @do_not_fuse_collapse_shape -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[ARG0]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[EXPAND]] -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[COLLAPSE]] -// CHECK: linalg.map { arith.addf } ins(%[[EXTRACT]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return - -//%test = tensor.collapse_shape %abs [[0, 1], [2]] : -// tensor<10x16x1xf32> into tensor<160x1xf32> - -// ----- - -func.func @do_not_fuse_expand_shape(%arg0: tensor<10x16xf32>, - %arg1: tensor<10x16xf32>) -> tensor<10x16xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %empty = tensor.empty() : tensor<160xf32> - %collapsed = tensor.collapse_shape %arg0 [[0, 1]] : - tensor<10x16xf32> into tensor<160xf32> - %abs = linalg.map { math.absf } - ins(%collapsed: tensor<160xf32>) - outs(%empty: tensor<160xf32>) - - %empty_2D = tensor.empty() : tensor<10x16xf32> - %expanded = tensor.expand_shape %abs [[0, 1]] : - tensor<160xf32> into tensor<10x16xf32> - %add = linalg.map { arith.addf } - ins(%expanded, %arg1 : tensor<10x16xf32>, tensor<10x16xf32>) - outs(%empty_2D : tensor<10x16xf32>) - {op_label="root"} - return %add : tensor<10x16xf32> -} -transform.sequence failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.map"]} - attributes{op_label="root"} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop, %1 = transform.structured.tile_using_forall %0 tile_sizes [1, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -} - -// CHECK-LABEL: func @do_not_fuse_expand_shape -// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x16xf32>, %[[ARG1:.*]]: tensor<10x16xf32>) -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ARG0]] -// CHECK: %[[ABS:.*]] = linalg.map { math.absf } ins(%[[COLLAPSE]] -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[ABS]] -// CHECK: scf.forall -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[EXPAND]] -// CHECK: linalg.map { arith.addf } ins(%[[EXTRACT]] -// CHECK: tensor.parallel_insert_slice -// CHECK: return diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir deleted file mode 100644 index 31e27a2a74ea63..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: mlir-hlo-opt %s -split-input-file -verify-diagnostics - -func.func @fusion_cluster_not_isolated(%arg0: tensor, - %arg1: tensor, %init: tensor) -> tensor { - %map0 = linalg.map { math.exp } - ins(%arg0 : tensor) - outs(%init : tensor) - // expected-note@+1 {{required by region isolation constraints}} - %0 = gml_st.fusion ins(%a1 = %arg1 : tensor) - inits(%in = %init : tensor) { - // expected-error@+1 {{op using value defined outside the region}} - %map1 = linalg.map { arith.mulf } - ins(%map0, %a1 : tensor, tensor) - outs(%in : tensor) - gml_st.yield %map1 : tensor - } : tensor - func.return %0 : tensor -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir deleted file mode 100644 index 78fd35133dc6c0..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir +++ 
/dev/null @@ -1,219 +0,0 @@ -// RUN: mlir-hlo-opt %s --lower-vectors --split-input-file | FileCheck %s -// RUN: mlir-hlo-opt %s --lower-vectors="flatten=true" --split-input-file | FileCheck %s --check-prefix=FLATTEN - -// CHECK-LABEL: func @vector_row -func.func @vector_row(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> - func.return %0 : vector<2xf32> -} -// CHECK-COUNT-4: arith.mulf - -// ----- - -// CHECK-LABEL: func @vector_col -func.func @vector_col(%arg0: vector<2x4xf32>, %acc: vector<4xf32>) -> vector<4xf32> { - %0 = vector.multi_reduction , %arg0, %acc [0] : vector<2x4xf32> to vector<4xf32> - func.return %0 : vector<4xf32> -} -// CHECK: arith.mulf -// CHECK: arith.mulf - -// ----- - -// CHECK-LABEL: func @vector_1d -func.func @vector_1d(%arg0: vector<4xf32>, %acc: f32) -> f32 { - %0 = vector.multi_reduction , %arg0, %acc [0] : vector<4xf32> to f32 - func.return %0 : f32 -} - -// ----- - -// CHECK: vector.reduction -func.func @lower_vector_contract(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) - -> tensor<8x8xf32> { - %c0 = arith.constant 0 : index - %cst_0 = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<8x8xf32> - %2 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [true, true]} - : tensor<8x8xf32>, vector<8x8xf32> - %3 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [true, true]} - : tensor<8x8xf32>, vector<8x8xf32> - %4 = vector.transfer_read %0[%c0, %c0], %cst_0 {in_bounds = [true, true]} - : tensor<8x8xf32>, vector<8x8xf32> - %5 = vector.contract { - indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], - kind = #vector.kind - } %2, %3, %4 : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32> - %6 = vector.transfer_write %5, %0[%c0, %c0] {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32> - return %6 : tensor<8x8xf32> -} - -// ----- - -// CHECK-LABEL: func @lower_vector_contract -// CHECK-COUNT-8: vector.outerproduct - -// ----- - -func.func @lower_vector_contract_4d(%arg0: tensor<1x1x8x1xf32>, - %arg1: tensor<1x1x8x1xf32>) - -> tensor<1x1x8x8xf32> { - %c0 = arith.constant 0 : index - %4 = tensor.empty() : tensor<1x1x8x8xf32> - %cst = arith.constant 0.000000e+00 : f32 - %20 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true, true, true]} : tensor<1x1x8x1xf32>, - vector<1x1x8x1xf32> - %21 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true, true, true]} : tensor<1x1x8x1xf32>, - vector<1x1x8x1xf32> - %22 = vector.transfer_read %4[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true, true, true]} : tensor<1x1x8x8xf32>, - vector<1x1x8x8xf32> - %23 = vector.contract {indexing_maps = - [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], - iterator_types = ["parallel", "parallel", "reduction", - "parallel", "parallel", "reduction"], - kind = #vector.kind} - %20, %21, %22 : vector<1x1x8x1xf32>, vector<1x1x8x1xf32> - into vector<1x1x8x8xf32> - %14 = vector.transfer_write %23, %4[%c0, %c0, %c0, %c0] - {in_bounds = [true, true, true, true]} : vector<1x1x8x8xf32>, - tensor<1x1x8x8xf32> - return %14 : tensor<1x1x8x8xf32> -} - -// CHECK-LABEL: func @lower_vector_contract_4d -// CHECK: vector.outerproduct - -// ----- 
- -func.func @lower_vector_contract_4d_matvec(%arg0: tensor<1x1x1x1xf32>, - %arg1: tensor<1x1x8x1xf32>) - -> tensor<1x1x1x8xf32> { - %c0 = arith.constant 0 : index - %4 = tensor.empty() : tensor<1x1x1x8xf32> - %cst = arith.constant 0.000000e+00 : f32 - %20 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true, true, true]} : tensor<1x1x1x1xf32>, - vector<1x1x1x1xf32> - %21 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true, true, true]} : tensor<1x1x8x1xf32>, - vector<1x1x8x1xf32> - %22 = vector.transfer_read %4[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true, true, true]} : tensor<1x1x1x8xf32>, - vector<1x1x1x8xf32> - %23 = vector.contract {indexing_maps = - [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], - iterator_types = ["parallel", "parallel", "reduction", - "parallel", "parallel", "reduction"], - kind = #vector.kind} - %20, %21, %22 : vector<1x1x1x1xf32>, vector<1x1x8x1xf32> - into vector<1x1x1x8xf32> - %14 = vector.transfer_write %23, %4[%c0, %c0, %c0, %c0] - {in_bounds = [true, true, true, true]} : vector<1x1x1x8xf32>, - tensor<1x1x1x8xf32> - return %14 : tensor<1x1x1x8xf32> -} - -// CHECK-LABEL: func @lower_vector_contract_4d_matvec -// CHECK: vector.outerproduct - -// ----- - -#map = affine_map<(d0) -> (d0 * 8)> -func.func @optimize_pack_with_transpose(%arg0: memref<1024x1024xf32>) -> - memref<128x1024x8x1xf32> { - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f32 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x1024x8x1xf32> - scf.for %arg2 = %c0 to %c128 step %c1 { - scf.for %arg3 = %c0 to %c1024 step %c1 { - %0 = affine.apply #map(%arg2) - %1 = vector.transfer_read %arg0[%arg3, %0], %cst {in_bounds = [true]} : - memref<1024x1024xf32>, vector<8xf32> - %2 = vector.broadcast %1 : vector<8xf32> to vector<1x8xf32> - %3 = vector.transpose %2, [1, 0] : vector<1x8xf32> to vector<8x1xf32> - vector.transfer_write %3, %alloc_0[%arg2, %arg3, %c0, %c0] - {in_bounds = [true, true]} : - vector<8x1xf32>, memref<128x1024x8x1xf32> - } - } - return %alloc_0 : memref<128x1024x8x1xf32> -} - -// FLATTEN-LABEL: func @optimize_pack_with_transpose( -// FLATTEN-SAME: %[[INPUT:.*]]: memref<1024x1024xf32>) - -// FLATTEN: %[[ALLOC:.*]] = memref.alloc -// FLATTEN: %[[READ:.*]] = vector.transfer_read %[[INPUT]] -// FLATTEN-NOT: vector.broadcast -// FLATTEN-NOT: vector.transpose -// FLATTEN: %[[COLLAPSE:.*]] = memref.collapse_shape %[[ALLOC]] -// FLATTEN-SAME: memref<128x1024x8x1xf32> into memref<128x1024x8xf32> -// FLATTEN: vector.transfer_write %[[READ]], %[[COLLAPSE]] - -// ----- - -#map = affine_map<(d0) -> (d0 * 8)> -func.func @optimize_pack(%arg0: memref<1024x1024xf32>) -> - memref<128x1024x8x1xf32> { - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f32 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x1024x8x1xf32> - scf.for %arg2 = %c0 to %c128 step %c1 { - scf.for %arg3 = %c0 to %c1024 step %c1 { - %0 = affine.apply #map(%arg2) - %1 = vector.transfer_read %arg0[%0, %arg3], %cst - {in_bounds = [true, true]} : - memref<1024x1024xf32>, vector<8x1xf32> - vector.transfer_write %1, %alloc_0[%arg2, %arg3, %c0, %c0] - 
{in_bounds = [true, true]} : - vector<8x1xf32>, memref<128x1024x8x1xf32> - } - } - return %alloc_0 : memref<128x1024x8x1xf32> -} - -// FLATTEN-LABEL: func @optimize_pack( -// FLATTEN-SAME: %[[INPUT:.*]]: memref<1024x1024xf32>) - -// FLATTEN: %[[ALLOC:.*]] = memref.alloc -// FLATTEN: %[[READ:.*]] = vector.transfer_read %[[INPUT]] -// FLATTEN: %[[COLLAPSE:.*]] = memref.collapse_shape %[[ALLOC]] -// FLATTEN-SAME: memref<128x1024x8x1xf32> into memref<128x1024x8xf32> -// FLATTEN: %[[SHAPE_CAST:.*]] = vector.shape_cast -// FLATTEN-SAME: vector<8x1xf32> to vector<8xf32> -// FLATTEN: vector.transfer_write %[[SHAPE_CAST]], %[[COLLAPSE]] - -// ----- - -func.func @no_flatten(%arg0: memref<2x9x10x2xf64>) -> - memref<2x9x10x2xf64> { - %cst = arith.constant 0.000000e+00 : f64 - %c0 = arith.constant 0 : index - %alloca = memref.alloca() : memref<2x9x10x2xf64> - %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<2x9x10x2xf64>, vector<2x9x10x2xf64> - vector.transfer_write %1, %alloca[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<2x9x10x2xf64>, memref<2x9x10x2xf64> - return %alloca : memref<2x9x10x2xf64> -} - - -// CHECK-LABEL: func @no_flatten( - -// CHECK-NOT: memref.collapse_shape -// CHECK-COUNT-180: vector.transfer_read {{.*}} memref<2x9x10x2xf64>, vector<2xf64> -// CHECK-COUNT-180: vector.transfer_write {{.*}} vector<2xf64>, memref<2x9x10x2xf64> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir deleted file mode 100644 index a5bac2b79bb4f6..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir +++ /dev/null @@ -1,110 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-tiling-softmax="tile-sizes=8,16" --canonicalize --cse \ -// RUN: --gml-tiling-softmax="tile-sizes=1,1" --canonicalize --cse | \ -// RUN: FileCheck %s - -func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - %cst = arith.constant -0.000000e+00 : f32 - %cst_0 = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<64xf32> - %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.reduce ins(%arg0 : tensor<64x128xf32>) - outs(%1 : tensor<64xf32>) dimensions = [1] - (%arg1: f32, %arg2: f32) { - %11 = arith.maximumf %arg1, %arg2 : f32 - linalg.yield %11 : f32 - } - %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.broadcast - ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) - dimensions = [1] - %5 = linalg.map ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - (%arg1: f32, %arg2: f32) { - %11 = arith.subf %arg1, %arg2 : f32 - linalg.yield %11 : f32 - } - %6 = linalg.map ins(%5 : tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - (%arg1: f32) { - %11 = math.exp %arg1 : f32 - linalg.yield %11 : f32 - } - %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %8 = linalg.reduce ins(%6 : tensor<64x128xf32>) - outs(%7 : tensor<64xf32>) dimensions = [1] - (%arg1: f32, %arg2: f32) { - %11 = arith.addf %arg2, %arg1 : f32 - linalg.yield %11 : f32 - } - %9 = linalg.broadcast - ins(%8 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) - dimensions = [1] - %10 = linalg.map ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - (%arg1: f32, %arg2: f32) { - %11 = arith.divf %arg1, %arg2 : f32 - linalg.yield %11 : f32 - } - return %10 
: tensor<64x128xf32> -} -// CHECK-LABEL: @softmax -// CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32> -// CHECK-DAG: %[[CST:.*]] = arith.constant -0.000000e+00 -// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<64xf32> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST_0]] : f32) -// CHECK-SAME: outs(%[[EMPTY]] : tensor<64xf32>) -// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<64x128xf32> -// CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CST]] : f32) -// CHECK-SAME: outs(%[[EMPTY]] : tensor<64xf32>) - -// CHECK: %[[PARALLEL:.*]] = scf.forall (%[[ARG1:.*]]) = (0) to (64) step (8) -// CHECK-SAME: shared_outs(%[[EMPTY_:.*]] = %[[EMPTY_0]]) -// CHECK-DAG: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] -// CHECK-DAG: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] -// CHECK-DAG: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[EMPTY_0]][%[[ARG1]], 0] [8, 128] [1, 1] -// CHECK-DAG: %[[MATERIALIZE_3:.*]] = tensor.extract_slice %[[FILL_0]][%[[ARG1]]] [8] [1] -// CHECK-DAG: %[[EMPTY_SUB:.*]] = tensor.extract_slice %[[EMPTY_]] - -// CHECK: %[[PARALLEL_0:.*]] = scf.forall (%[[ARG2:.*]]) in (8) -// CHECK-SAME: shared_outs(%[[EMPTY_SUB_:.*]] = %[[EMPTY_SUB]]) -// CHECK-NEXT: %[[MATERIALIZE_4:.*]] = tensor.extract_slice %[[MATERIALIZE]][%[[ARG2]], 0] [1, 128] [1, 1] -// CHECK-NEXT: %[[MATERIALIZE_5:.*]] = tensor.extract_slice %[[MATERIALIZE_0]][%[[ARG2]]] [1] [1] -// CHECK-NEXT: %[[REDUCE:.*]] = linalg.reduce -// CHECK-SAME: ins(%[[MATERIALIZE_4]] : tensor<1x128xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_5]] : tensor<1xf32>) -// CHECK-SAME: dimensions = [1] - -// CHECK: %[[MATERIALIZE_6:.*]] = tensor.extract_slice %[[MATERIALIZE_1]][%[[ARG2]], 0] [1, 128] [1, 1] -// CHECK-NEXT: %[[BROADCAST:.*]] = linalg.broadcast -// CHECK-SAME: ins(%[[REDUCE]] : tensor<1xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) -// CHECK-SAME: dimensions = [1] - -// CHECK: %[[MAP:.*]] = linalg.map -// CHECK-SAME: ins(%[[MATERIALIZE_4]], %[[BROADCAST]] : tensor<1x128xf32>, tensor<1x128xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) - -// CHECK: %[[MAP_0:.*]] = linalg.map -// CHECK-SAME: ins(%[[MAP]] : tensor<1x128xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) - -// CHECK: %[[MATERIALIZE_8:.*]] = tensor.extract_slice %[[MATERIALIZE_3]][%[[ARG2]]] [1] [1] -// CHECK-NEXT: %[[REDUCE_0:.*]] = linalg.reduce -// CHECK-SAME: ins(%[[MAP_0]] : tensor<1x128xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_8]] : tensor<1xf32>) - -// CHECK: %[[BROADCAST_0:.*]] = linalg.broadcast -// CHECK-SAME: ins(%[[REDUCE_0]] : tensor<1xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) - -// CHECK-NEXT: %[[MATERIALIZE_7:.*]] = tensor.extract_slice %[[EMPTY_SUB_]] -// CHECK: %[[MAP_1:.*]] = linalg.map -// CHECK-SAME: ins(%[[MAP_0]], %[[BROADCAST_0]] : tensor<1x128xf32>, tensor<1x128xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_7]] : tensor<1x128xf32>) -// CHECK: tensor.parallel_insert_slice %[[MAP_1]] into %[[EMPTY_SUB_]][%[[ARG2]], 0] [1, 128] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[PARALLEL_0]] into %[[EMPTY_]][%[[ARG1]], 0] [8, 128] [1, 1] -// CHECK: return %[[PARALLEL]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir deleted file mode 100644 index 4718693639cdf2..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: 
mlir-hlo-opt %s --split-input-file --allow-unregistered-dialect | \ -// RUN: mlir-hlo-opt --verify-diagnostics --split-input-file \ -// RUN: --allow-unregistered-dialect | \ -// RUN: FileCheck %s - -func.func @fusion_cluster(%arg0: tensor, %arg1: tensor, - %init: tensor) -> tensor { - %0 = gml_st.fusion ins(%a0 = %arg0 : tensor, - %a1 = %arg1 : tensor) - inits(%in = %init : tensor) { - %map0 = linalg.map { math.exp } - ins(%a0 : tensor) - outs(%in : tensor) - %map1 = linalg.map { arith.mulf } - ins(%map0, %a1 : tensor, tensor) - outs(%in : tensor) - gml_st.yield %map1 : tensor - } { "some_attr" = 1 } : tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @fusion_cluster -// CHECK: gml_st.fusion -// CHECK: linalg.map -// CHECK: linalg.map -// CHECK: gml_st.yield diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/optimize_linalg_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/optimize_linalg_ops.mlir deleted file mode 100644 index 074642bc9c6f8b..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/optimize_linalg_ops.mlir +++ /dev/null @@ -1,183 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-optimize-linalg-ops-pass \ -// RUN: --split-input-file \ -// RUN: | FileCheck %s - -func.func @map_no_inputs(%arg: tensor<32xf32>) -> tensor<32xf32> { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor<32xf32> - - %res = linalg.map - outs(%init: tensor<32xf32>) - () { - linalg.yield %c0 : f32 - } - func.return %res : tensor<32xf32> -} - -// CHECK-LABEL: @map_no_inputs -// CHECK-DAG: %[[CST:.*]] = arith.constant -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK: linalg.fill -// CHECK-SAME: ins(%[[CST]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @map_dense_constant_operand(%arg: tensor<32xf32>) -> tensor<32xf32> { - %c0 = arith.constant dense<0.0> : tensor<32xf32> - %init = tensor.empty() : tensor<32xf32> - - %res = linalg.map { arith.maximumf } - ins(%arg, %c0: tensor<32xf32>, tensor<32xf32>) - outs(%init: tensor<32xf32>) - func.return %res : tensor<32xf32> -} - -// CHECK-LABEL: @map_dense_constant_operand -// CHECK-SAME: (%[[ARG:.*]]: tensor<32xf32>) -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.0 -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK: linalg.map -// CHECK-SAME: ins(%[[ARG]] -// CHECK-SAME: outs(%[[INIT]] -// CHECK-NEXT: (%[[BBARG:.*]]: f32) -// CHECK-NEXT: arith.maximumf %[[BBARG]], %[[CST]] - -// ----- - -func.func @map_dense_constant_operand_complex(%arg: tensor<32xcomplex>) - -> tensor<32xcomplex> { - %c0 = arith.constant dense<(1.0000e+00,0.0000e+00)> : tensor<32xcomplex> - %init = tensor.empty() : tensor<32xcomplex> - - %res = linalg.map { complex.add } - ins(%arg, %c0: tensor<32xcomplex>, tensor<32xcomplex>) - outs(%init: tensor<32xcomplex>) - func.return %res : tensor<32xcomplex> -} - -// CHECK-LABEL: @map_dense_constant_operand_complex -// CHECK-SAME: (%[[ARG:.*]]: tensor<32xcomplex>) -// CHECK-DAG: %[[CST:.*]] = complex.constant -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK: linalg.map -// CHECK-SAME: ins(%[[ARG]] -// CHECK-SAME: outs(%[[INIT]] -// CHECK-NEXT: (%[[BBARG:.*]]: complex) -// CHECK-NEXT: complex.add %[[BBARG]], %[[CST]] - -// ----- - -func.func @map_fill_operand(%arg: tensor<32xf32>) -> tensor<32xf32> { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor<32xf32> - - %filled = linalg.fill ins(%c0 : f32) - outs(%init: tensor<32xf32>) -> tensor<32xf32> - - %res = linalg.map { arith.maximumf } - ins(%arg, %filled: tensor<32xf32>, tensor<32xf32>) - outs(%init: 
tensor<32xf32>) - func.return %res : tensor<32xf32> -} - -// CHECK-LABEL: @map_fill_operand -// CHECK-SAME: (%[[ARG:.*]]: tensor<32xf32>) -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.0 -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK: linalg.map -// CHECK-SAME: ins(%[[ARG]] -// CHECK-SAME: outs(%[[INIT]] -// CHECK-NEXT: (%[[BBARG:.*]]: f32) -// CHECK-NEXT: arith.maximumf %[[BBARG]], %[[CST]] - -// ----- - -func.func @map_all_constant_operand(%select: i1) -> tensor<32xf32> { - %c0 = arith.constant dense<0.0> : tensor<32xf32> - %c1 = arith.constant 1.0 : f32 - %init = tensor.empty() : tensor<32xf32> - - %filled = linalg.fill ins(%c1 : f32) - outs(%init: tensor<32xf32>) -> tensor<32xf32> - - %res = linalg.map - ins(%c0, %filled: tensor<32xf32>, tensor<32xf32>) - outs(%init: tensor<32xf32>) - (%lhs : f32, %rhs : f32) { - %0 = arith.select %select, %lhs, %rhs : f32 - linalg.yield %0 : f32 - } - func.return %res : tensor<32xf32> -} - -// CHECK-LABEL: @map_all_constant_operand -// CHECK-DAG: %[[C0:.*]] = arith.constant 0.0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1.0 -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK-DAG: %[[VAL:.*]] = arith.select -// CHECK: linalg.fill -// CHECK-SAME: ins(%[[VAL]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @broadcast_of_splat() -> tensor<32x64xf32> { - %c0 = arith.constant dense<0.0> : tensor<32xf32> - %init = tensor.empty() : tensor<32x64xf32> - - %bcast = linalg.broadcast - ins(%c0: tensor<32xf32>) - outs(%init: tensor<32x64xf32>) - dimensions = [1] - func.return %bcast : tensor<32x64xf32> -} -// CHECK-LABEL: @broadcast_of_splat -// CHECK-DAG: %[[CST:.*]] = arith.constant -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK: linalg.fill -// CHECK-SAME: ins(%[[CST]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @broadcast_of_single_element_tensor(%arg: tensor) - -> tensor<32xf32> { - %init = tensor.empty() : tensor<32xf32> - %bcast = linalg.broadcast - ins(%arg: tensor) - outs(%init: tensor<32xf32>) - dimensions = [0] - func.return %bcast : tensor<32xf32> -} -// CHECK-LABEL: @broadcast_of_single_element_tensor -// CHECK-SAME: (%[[ARG:.*]]: tensor) - -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK-DAG: %[[EXTRACT:.*]] = tensor.extract %[[ARG]] -// CHECK: linalg.fill -// CHECK-SAME: ins(%[[EXTRACT]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @slice_of_map(%arg: tensor<32xf32>) -> tensor<8xf32> { - %c0 = arith.constant dense<0.0> : tensor<32xf32> - %init = tensor.empty() : tensor<32xf32> - - %map = linalg.map { arith.maximumf } - ins(%arg, %c0: tensor<32xf32>, tensor<32xf32>) - outs(%init: tensor<32xf32>) - %slice = tensor.extract_slice %map[0] [8] [1] - : tensor<32xf32> to tensor<8xf32> - func.return %slice : tensor<8xf32> -} -// CHECK-LABEL: @slice_of_map -// CHECK-SAME: (%[[ARG:.*]]: tensor<32xf32>) - -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.0 -// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][0] [8] [1] -// CHECK-DAG: %[[INIT:.*]] = tensor.empty -// CHECK: linalg.map -// CHECK-SAME: ins(%[[SLICE]] -// CHECK-SAME: outs(%[[INIT]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_forall_to_for.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_forall_to_for.mlir deleted file mode 100644 index ed52346d7a9572..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_forall_to_for.mlir +++ /dev/null @@ -1,62 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-rewrite-forall-ops --split-input-file \ -// RUN: | FileCheck %s - -func.func @add(%in: 
tensor<3x3xi32>, %out: tensor<3x3xi32>) -> tensor<3x3xi32> { - %c3 = arith.constant 3 : index - - %result = scf.forall (%i, %j) in (%c3, %c3) - shared_outs(%o = %out) -> tensor<3x3xi32> { - %addend = tensor.extract_slice %in[%i, %j][1, 1][1, 1] - : tensor<3x3xi32> to tensor - %augend = tensor.extract_slice %out[%i, %j][1, 1][1, 1] - : tensor<3x3xi32> to tensor - %sum = mhlo.add %augend, %addend : tensor - scf.forall.in_parallel { - tensor.parallel_insert_slice %sum into %o[%i, %j][1, 1][1, 1] - : tensor into tensor<3x3xi32> - } - } {some_attr = "attr_value"} - - return %result : tensor<3x3xi32> -} - -// CHECK-LABEL: @add -// CHECK: %[[RESULT:.*]] = scf.for -// CHECK-NEXT: %[[INNER:.*]] = scf.for -// CHECK-NEXT: tensor.extract_slice -// CHECK-NEXT: tensor.extract_slice -// CHECK-NEXT: mhlo.add -// CHECK-NEXT: %[[INSERTED:.*]] = tensor.insert_slice -// CHECK-NEXT: scf.yield %[[INSERTED]] -// CHECK-NEXT: } {some_attr = "attr_value"} -// CHECK-NEXT: scf.yield %[[INNER]] -// CHECK-NEXT: } {some_attr = "attr_value"} -// CHECK-NEXT: return %[[RESULT]] - -// ----- - -func.func @bufferized_add() -> memref<3xi32> { - %c3 = arith.constant 3 : index - %in = arith.constant dense<[1, 2, 3]> : tensor<3xi32> - %out = arith.constant dense<[4, 5, 6]> : memref<3xi32> - - scf.forall (%i) in (%c3) { - %addend = tensor.extract %in[%i] : tensor<3xi32> - %augend = memref.load %out[%i] : memref<3xi32> - %sum = arith.addi %augend, %addend : i32 - memref.store %sum, %out[%i] : memref<3xi32> - } - - return %out : memref<3xi32> -} - -// CHECK-LABEL: @bufferized_add -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK: scf.for {{.*}} = %[[C0]] to %[[C3]] step %[[C1]] -// CHECK-NEXT: tensor.extract -// CHECK-NEXT: memref.load -// CHECK-NEXT: arith.addi -// CHECK-NEXT: memref.store -// CHECK-NEXT: } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tile_by_one.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tile_by_one.mlir deleted file mode 100644 index b5f9876045ef61..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tile_by_one.mlir +++ /dev/null @@ -1,77 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-tile-by-one | FileCheck %s - -func.func @reverse_dynamic(%arg0: tensor, %arg1: tensor) - -> tensor { - %reversed = thlo.reverse ins(%arg0 : tensor) - outs(%arg1 : tensor) reverse_dimensions = [0, 1] - return %reversed : tensor -} - -// CHECK: @reverse_dynamic -// CHECK: scf.for -// CHECK: scf.for -// CHECK: tensor.extract_slice -// CHECK-SAME: to tensor<1x1xf32> - -// ----- - -func.func @map(%arg0: tensor, %arg1: tensor) - -> tensor { - %mapped = linalg.map { math.absf } ins(%arg0 : tensor) - outs(%arg1 : tensor) - return %mapped : tensor -} - -// CHECK: @map -// CHECK: scf.for -// CHECK: scf.for -// CHECK: tensor.extract_slice -// CHECK-SAME: to tensor<1x1xf32> -// CHECK: linalg.map { math.absf } -// CHECK-SAME: tensor<1x1xf32> - -// ----- - -func.func @dont_tile_scalarlike_map(%arg0: tensor<1x1xf32>, - %arg1: tensor<1x1xf32>) -> tensor<1x1xf32> { - %mapped = linalg.map { math.absf } ins(%arg0 : tensor<1x1xf32>) - outs(%arg1 : tensor<1x1xf32>) - return %mapped : tensor<1x1xf32> -} - -// CHECK: @dont_tile_scalarlike_map -// CHECK-NOT: scf.for -// CHECK-NOT: scf.parallel -// CHECK: linalg.map -// CHECK-SAME: tensor<1x1xf32> -// CHECK-NOT: scf.for -// CHECK-NOT: scf.parallel - -// ----- - -func.func @concat(%init : tensor, %a: tensor, - %b: tensor, %c: tensor) -> tensor { - %concat = 
thlo.concatenate - ins(%a : tensor, %b : tensor, %c : tensor) - outs(%init : tensor) dimension = 1 - func.return %concat : tensor -} - -// CHECK-LABEL: @concat -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.if -// CHECK: tensor.extract_slice %[[ARG1]] -// CHECK: scf.yield -// CHECK: else -// CHECK: scf.if -// CHECK: tensor.extract_slice %[[ARG2]] -// CHECK: scf.yield -// CHECK: else -// CHECK: tensor.extract_slice %[[ARG3]] -// CHECK: scf.yield -// CHECK: scf.yield -// CHECK: scf.yield -// CHECK: scf.yield -// CHECK: return diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir deleted file mode 100644 index 69856f4d89d103..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir +++ /dev/null @@ -1,172 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-tiling-softmax="tile-sizes=8,16" --canonicalize --cse | \ -// RUN: FileCheck %s - -// CHECK-LABEL: @partial_softmax -// CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32> -func.func @partial_softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[CST:.*]] = arith.constant 0xFF800000 - // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<64xf32> - // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<64xf32>) - // CHECK: %[[INIT_0:.*]] = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[PARALLEL:.*]] = scf.forall - // CHECK-SAME: (%[[ARG1:.*]]) = (0) to (64) step (8) - // CHECK-SAME: shared_outs(%[[INIT_0_:.*]] = %[[INIT_0]]) - // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] - // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maximumf } - // CHECK-SAME: ins(%[[MATERIALIZE]] : tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_0]] : tensor<8xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT_0]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[BROADCAST:.*]] = linalg.broadcast - // CHECK-SAME: ins(%[[REDUCE]] : tensor<8xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[INIT_0_SUB:.*]] = tensor.extract_slice %[[INIT_0_]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } - // CHECK-SAME: ins(%[[MATERIALIZE]], %[[BROADCAST]] : tensor<8x128xf32>, tensor<8x128xf32>) - // CHECK-SAME: outs(%[[INIT_0_SUB]] : tensor<8x128xf32>) - - // CHECK: return %[[PARALLEL]] - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<64xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.reduce { arith.maximumf } - ins(%arg0 : tensor<64x128xf32>) - outs(%1 : tensor<64xf32>) - dimensions = [1] - %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.broadcast - ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) - dimensions = [1] - %5 = linalg.map { arith.subf } - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - return %5 : tensor<64x128xf32> -} - -// ----- - -// CHECK-LABEL: @partial_softmax_fusion -// CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32>, %[[ARG1:.*]]: index -func.func @partial_softmax_fusion(%arg0: tensor<64x128xf32>, %arg1: index) - -> tensor<8x128xf32> { - // CHECK-DAG: 
%[[CST:.*]] = arith.constant 0xFF800000 - // CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<64xf32> - // CHECK-DAG: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<64xf32>) - // CHECK-DAG: %[[INIT_0:.*]] = tensor.empty() : tensor<64x128xf32> - // CHECK-DAG: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK-DAG: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] - // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maximumf } - // CHECK-SAME: ins(%[[MATERIALIZE]] : tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_0]] : tensor<8xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK-DAG: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT_0]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[BROADCAST:.*]] = linalg.broadcast - // CHECK-SAME: ins(%[[REDUCE]] : tensor<8xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } - // CHECK-SAME: ins(%[[MATERIALIZE]], %[[BROADCAST]] : tensor<8x128xf32>, tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: return %[[MAP]] - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<64xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.reduce { arith.maximumf } - ins(%arg0 : tensor<64x128xf32>) - outs(%1 : tensor<64xf32>) - dimensions = [1] - %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.broadcast - ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) - dimensions = [1] - %5 = linalg.map { arith.subf } - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - %8 = tensor.extract_slice %5[%arg1, 0] [8, 128] [1, 1] - : tensor<64x128xf32> to tensor<8x128xf32> - return %8 : tensor<8x128xf32> -} - -// ----- - -// CHECK-LABEL: @softmax -// CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32> -func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK-DAG: %[[CST:.*]] = arith.constant -0.000000e+00 - // CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 - // CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<64xf32> - // CHECK-DAG: %[[FILL:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[INIT]] : tensor<64xf32>) - // CHECK-DAG: %[[INIT_0:.*]] = tensor.empty() : tensor<64x128xf32> - // CHECK-DAG: %[[FILL_0:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<64xf32>) - // CHECK: %[[PARALLEL:.*]] = scf.forall - // CHECK-SAME: (%[[ARG1:.*]]) = (0) to (64) step (8) - // CHECK-SAME: shared_outs(%[[INIT_0_:.*]] = %[[INIT_0]]) - // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] - // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maximumf } - // CHECK-SAME: ins(%[[MATERIALIZE]] : tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_0]] : tensor<8xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT_0]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[BROADCAST:.*]] = linalg.broadcast - // CHECK-SAME: ins(%[[REDUCE]] : tensor<8xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } - // CHECK-SAME: ins(%[[MATERIALIZE]], %[[BROADCAST]] : tensor<8x128xf32>, tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // 
CHECK: %[[MAP_0:.*]] = linalg.map { math.exp } - // CHECK-SAME: ins(%[[MAP]] : tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: %[[MATERIALIZE_3:.*]] = tensor.extract_slice %[[FILL_0]][%[[ARG1]]] [8] [1] - // CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.addf } - // CHECK-SAME: ins(%[[MAP_0]] : tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_3]] : tensor<8xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[BROADCAST_0:.*]] = linalg.broadcast - // CHECK-SAME: ins(%[[REDUCE_0]] : tensor<8xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK-SAME: dimensions = [1] - // CHECK: %[[INIT_0_SUB:.*]] = tensor.extract_slice %[[INIT_0_]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[MAP_1:.*]] = linalg.map { arith.divf } - // CHECK-SAME: ins(%[[MAP_0]], %[[BROADCAST_0]] : tensor<8x128xf32>, tensor<8x128xf32>) - // CHECK-SAME: outs(%[[INIT_0_SUB]] : tensor<8x128xf32>) - // CHECK: tensor.parallel_insert_slice %[[MAP_1]] into %[[INIT_0_]][%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: return %[[PARALLEL]] - %cst = arith.constant -0.000000e+00 : f32 - %cst_0 = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<64xf32> - %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.reduce { arith.maximumf } - ins(%arg0 : tensor<64x128xf32>) - outs(%1 : tensor<64xf32>) dimensions = [1] - %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.broadcast - ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) - dimensions = [1] - %5 = linalg.map { arith.subf } - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - %6 = linalg.map { math.exp } - ins(%5 : tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %8 = linalg.reduce { arith.addf } - ins(%6 : tensor<64x128xf32>) - outs(%7 : tensor<64xf32>) - dimensions = [1] - %9 = linalg.broadcast - ins(%8 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) - dimensions = [1] - %10 = linalg.map { arith.divf } - ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) - return %10 : tensor<64x128xf32> -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir deleted file mode 100644 index ad51ae89b20f06..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir +++ /dev/null @@ -1,395 +0,0 @@ -// RUN: mlir-hlo-opt %s --vectorize-for-cpu --split-input-file |\ -// RUN: FileCheck %s - - -func.func @vectorize_tiled_matmul(%lhs: tensor<8x16xf32>, - %rhs: tensor<16x4xf32>, %fill: tensor<8x4xf32>) -> tensor<8x4xf32> { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c16 = arith.constant 16 : index - - %7 = scf.for %i = %c0 to %c16 step %c2 - iter_args (%arg6 = %fill) -> (tensor<8x4xf32>) { - %9 = tensor.extract_slice %lhs[0, %i] [8, 2] [1, 1] : - tensor<8x16xf32> to tensor<8x2xf32> - - %11 = tensor.extract_slice %rhs[%i, 0] [2, 4] [1, 1] : - tensor<16x4xf32> to tensor<2x4xf32> - - %13 = tensor.extract_slice %arg6[0, 0] [8, 4] [1, 1] : - tensor<8x4xf32> to tensor<8x4xf32> - - %14 = linalg.matmul ins(%9, %11 : tensor<8x2xf32>, tensor<2x4xf32>) - outs(%13 : tensor<8x4xf32>) -> tensor<8x4xf32> - - %12 = tensor.insert_slice %14 into %arg6 [0, 0] [8, 4] [1, 1] - : tensor<8x4xf32> into tensor<8x4xf32> - - scf.yield %14 : tensor<8x4xf32> - } 
{__perfectly_tileable_loop_label__} - return %7 : tensor<8x4xf32> -} - -// CHECK-LABEL: func @vectorize_tiled_matmul - -// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT:.*]] -// CHECK: %[[FOR:.*]]:2 = scf.for {{.*}} iter_args(%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = -// CHECK: %[[LHS:.*]] = vector.transfer_read -// CHECK-SAME: : tensor<8x16xf32>, vector<8x2xf32> -// CHECK: %[[RHS:.*]] = vector.transfer_read -// CHECK-SAME: : tensor<16x4xf32>, vector<2x4xf32> -// CHECK: %[[CONTRACT:.*]] = vector.contract -// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ARG1]] -// CHECK: scf.yield %[[ARG0]], %[[CONTRACT]] -// CHECK: vector.transfer_write %[[FOR]]#1, %[[FOR]]#0 - -// ----- - -func.func @vectorize_static_matmul(%lhs: tensor<128x16xf32>, - %rhs: tensor<16x64xf32>, %fill: tensor<128x64xf32>) -> tensor<128x64xf32> { - %c2 = arith.constant 2 : index - %c16 = arith.constant 16 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - %c64 = arith.constant 64 : index - %0 = scf.forall (%i, %j) = (%c0, %c0) to (%c128, %c64) step (%c8, %c4) - shared_outs (%out_ = %fill) -> (tensor<128x64xf32>) { - %2 = tensor.extract_slice %lhs[%i, 0] [8, 16] [1, 1] : - tensor<128x16xf32> to tensor<8x16xf32> - %4 = tensor.extract_slice %rhs[0, %j] [16, 4] [1, 1] : - tensor<16x64xf32> to tensor<16x4xf32> - %6 = tensor.extract_slice %fill[%i, %j] [8, 4] [1, 1] : - tensor<128x64xf32> to tensor<8x4xf32> - %7 = scf.for %k = %c0 to %c16 step %c2 iter_args (%arg6 = %6) -> (tensor<8x4xf32>) { - %9 = tensor.extract_slice %2[0, %k] [8, 2] [1, 1] : - tensor<8x16xf32> to tensor<8x2xf32> - %11 = tensor.extract_slice %4[%k, 0] [2, 4] [1, 1] : - tensor<16x4xf32> to tensor<2x4xf32> - %13 = tensor.extract_slice %arg6[0, 0] [8, 4] [1, 1] : - tensor<8x4xf32> to tensor<8x4xf32> - %14 = linalg.matmul ins(%9, %11 : tensor<8x2xf32>, tensor<2x4xf32>) - outs(%13 : tensor<8x4xf32>) -> tensor<8x4xf32> - scf.yield %14 : tensor<8x4xf32> - } - scf.forall.in_parallel { - tensor.parallel_insert_slice %7 into %out_[%i, %j] [8, 4] [1, 1] : - tensor<8x4xf32> into tensor<128x64xf32> - } - } - return %0 : tensor<128x64xf32> -} -// CHECK-LABEL: func @vectorize_static_matmul - -// CHECK: %[[OUT_READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4xf32>, vector<8x4xf32> -// CHECK: %[[FOR:.*]]:2 = scf.for {{.*}} iter_args(%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %[[OUT_READ]] -// CHECK-NOT: linalg.matmul -// CHECK: %[[LHS:.*]] = vector.transfer_read {{.*}} : tensor<128x16xf32>, vector<8x2xf32> -// CHECK: %[[RHS:.*]] = vector.transfer_read {{.*}} : tensor<16x64xf32>, vector<2x4xf32> -// CHECK-NOT: vector.transfer_read -// CHECK: %[[CONTRACT:.*]] = vector.contract {{{.*}}} %[[LHS]], %[[RHS]], %[[ARG1]] -// CHECK: scf.yield %[[ARG0]], %[[CONTRACT]] -// CHECK: vector.transfer_write %[[FOR]]#1, %[[FOR]]#0 - -// ----- - -func.func @transpose(%input: tensor<4x5x6xf32>, - %init: tensor<5x6x4xf32>) -> tensor<5x6x4xf32> { - %transpose = linalg.transpose - ins(%input:tensor<4x5x6xf32>) - outs(%init:tensor<5x6x4xf32>) - permutation = [1, 2, 0] - func.return %transpose : tensor<5x6x4xf32> -} - -// CHECK-LABEL: func @transpose( -// CHECK-SAME: %[[INPUT:.*]]: tensor<4x5x6xf32> -// CHECK-SAME: %[[INIT:.*]]: tensor<5x6x4xf32> - -// CHECK: %[[READ:.*]] = vector.transfer_read %[[INPUT]] -// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[READ]], [1, 2, 0] -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TRANSPOSE]], %[[INIT]] -// CHECK: return %[[WRITE]] - -// ----- - -func.func 
@simplify_identity_transpose(%input: tensor<1x1xf32>, - %init: tensor<1x1xf32>) -> tensor<1x1xf32> { - %transpose = linalg.transpose - ins(%input:tensor<1x1xf32>) - outs(%init:tensor<1x1xf32>) - permutation = [0, 1] - func.return %transpose : tensor<1x1xf32> -} - -// CHECK-LABEL: func @simplify_identity_transpose( - -// CHECK-NOT: linalg.transpose -// CHECK: return - -// ----- - -func.func @do_not_simplify_transpose(%input: tensor<1x1xf32>, - %init: tensor<1x1xf32>) -> tensor<1x1xf32> { - %transpose = linalg.transpose - ins(%input:tensor<1x1xf32>) - outs(%init:tensor<1x1xf32>) - permutation = [1, 0] - func.return %transpose : tensor<1x1xf32> -} - -// CHECK-LABEL: func @do_not_simplify_transpose( - -// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose -// CHECK: return %[[TRANSPOSE]] - -// ----- - -func.func @perfectly_tiled_reverse_1d(%input: tensor<8xf32>, - %init: tensor<8xf32>) -> tensor<8xf32> { - %res = thlo.reverse - ins(%input: tensor<8xf32>) - outs(%init: tensor<8xf32>) - reverse_dimensions = [0] - func.return %res : tensor<8xf32> -} - -// CHECK-LABEL: func @perfectly_tiled_reverse_1d( -// CHECK-SAME: %[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xf32> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]] -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[READ]] -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]] -// CHECK: return %[[WRITE]] - -// ----- - -func.func @perfectly_tiled_reverse_2d(%input: tensor<1x8xf32>, - %init: tensor<1x8xf32>) -> tensor<1x8xf32> { - %res = thlo.reverse - ins(%input: tensor<1x8xf32>) - outs(%init: tensor<1x8xf32>) - reverse_dimensions = [1] - func.return %res : tensor<1x8xf32> -} - -// CHECK-LABEL: func @perfectly_tiled_reverse_2d( -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x8xf32>, %[[ARG1:.*]]: tensor<1x8xf32> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]] -// CHECK-SAME: : tensor<1x8xf32>, vector<8xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[READ]] -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]] -// CHECK-SAME: : vector<8xf32>, tensor<1x8xf32> -// CHECK: return %[[WRITE]] - -// ----- - -func.func @perfectly_tiled_reverse_4d(%input: tensor<1x1x1x8xf32>, - %init: tensor<1x1x1x8xf32>) -> tensor<1x1x1x8xf32> { - %res = thlo.reverse - ins(%input: tensor<1x1x1x8xf32>) - outs(%init: tensor<1x1x1x8xf32>) - reverse_dimensions = [3] - func.return %res : tensor<1x1x1x8xf32> -} - -// CHECK-LABEL: func @perfectly_tiled_reverse_4d( -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8xf32>, %[[ARG1:.*]]: tensor<1x1x1x8xf32> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]] -// CHECK-SAME: : tensor<1x1x1x8xf32>, vector<8xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[READ]] -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]] -// CHECK-SAME: : vector<8xf32>, tensor<1x1x1x8xf32> -// CHECK: return %[[WRITE]] - -// ----- - -func.func @matvec(%lhs: tensor<33x17xf32>, %rhs: tensor<17xf32>, - %output: tensor<33xf32>) -> tensor<33xf32> { - %2 = linalg.matvec ins(%lhs, %rhs : tensor<33x17xf32>, tensor<17xf32>) - outs(%output : tensor<33xf32>) -> tensor<33xf32> - return %2 : tensor<33xf32> -} - -// CHECK-LABEL: @matvec -// CHECK-SAME: %[[LHS:.*]]: tensor<33x17xf32>, %[[RHS:.*]]: tensor<17xf32>, %[[OUT:.*]]: tensor<33xf32> -// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]] -// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]] -// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT]] -// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}}%[[LHS_READ]], %[[RHS_READ]], 
%[[OUT_READ]] -// CHECK: vector.transfer_write %[[CONTRACT]], %[[OUT]] - -// ----- - -func.func @vecmat(%lhs: tensor<17xf32>, %rhs: tensor<17x33xf32>, - %output: tensor<33xf32>) -> tensor<33xf32> { - %2 = linalg.vecmat ins(%lhs, %rhs : tensor<17xf32>, tensor<17x33xf32>) - outs(%output : tensor<33xf32>) -> tensor<33xf32> - return %2 : tensor<33xf32> -} - -// CHECK-LABEL: @vecmat -// CHECK-SAME: %[[LHS:.*]]: tensor<17xf32>, %[[RHS:.*]]: tensor<17x33xf32>, %[[OUT:.*]]: tensor<33xf32> -// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]] -// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]] -// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT]] -// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}}%[[LHS_READ]], %[[RHS_READ]], %[[OUT_READ]] -// CHECK: vector.transfer_write %[[CONTRACT]], %[[OUT]] - -// ----- - -func.func @dot(%lhs: tensor<17xf32>, %rhs: tensor<17xf32>, - %output: tensor) -> tensor { - %2 = linalg.dot ins(%lhs, %rhs : tensor<17xf32>, tensor<17xf32>) - outs(%output : tensor) -> tensor - return %2 : tensor -} - -// CHECK-LABEL: @dot -// CHECK-SAME: %[[LHS:.*]]: tensor<17xf32>, %[[RHS:.*]]: tensor<17xf32>, %[[OUT:.*]]: tensor -// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]] -// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]] -// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT]] -// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}}%[[LHS_READ]], %[[RHS_READ]] -// CHECK: vector.transfer_write {{.*}}, %[[OUT]] - -// ----- - -func.func @dont_vectorize_any_ite(%arg0: i1, %arg1: tensor<8x1xf32>, - %arg2: tensor<8x1xf32>) -> tensor<8x1xf32> { - %0 = scf.if %arg0 -> (tensor<8x1xf32>) { - scf.yield %arg1 : tensor<8x1xf32> - } else { - scf.yield %arg2 : tensor<8x1xf32> - } - return %0 : tensor<8x1xf32> -} - -// CHECK-LABEL: @dont_vectorize_any_ite -// CHECK: scf.if %{{.*}} -> (tensor<8x1xf32>) - -// ----- - -func.func @vectorize_ite_w_vector_producers(%arg0: i1, %arg1: vector<8x1xf32>, - %arg2: vector<8x1xf32>) -> tensor<8x1xf32> { - %c0 = arith.constant 0 : index - %0 = tensor.empty() : tensor<8x1xf32> - %1 = scf.if %arg0 -> (tensor<8x1xf32>) { - %2 = vector.transfer_write %arg1, %0[%c0, %c0] {in_bounds = [true, true]} - : vector<8x1xf32>, tensor<8x1xf32> - scf.yield %2 : tensor<8x1xf32> - } else { - %2 = vector.transfer_write %arg2, %0[%c0, %c0] {in_bounds = [true, true]} - : vector<8x1xf32>, tensor<8x1xf32> - scf.yield %2 : tensor<8x1xf32> - } - return %1 : tensor<8x1xf32> -} - -// CHECK-LABEL: @vectorize_ite_w_vector_producers -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: vector<8x1xf32>, %[[ARG2:.*]]: vector<8x1xf32> -// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (vector<8x1xf32>) -// CHECK: scf.yield %[[ARG1]] -// CHECK: else -// CHECK: scf.yield %[[ARG2]] -// CHECK: %[[TRANSFER:.*]] = vector.transfer_write %[[IF]] -// CHECK: return %[[TRANSFER]] - -// ----- - -func.func @vectorize_ite_w_vector_users(%arg0: i1, %arg1: tensor<8x1xf32>, - %arg2: tensor<8x1xf32>) -> vector<8x1xf32> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = scf.if %arg0 -> (tensor<8x1xf32>) { - scf.yield %arg1 : tensor<8x1xf32> - } else { - scf.yield %arg2 : tensor<8x1xf32> - } - %1 = vector.transfer_read %0[%c0, %c0], %cst {in_bounds = [true, true]} - : tensor<8x1xf32>, vector<8x1xf32> - return %1 : vector<8x1xf32> -} - -// CHECK-LABEL: @vectorize_ite_w_vector_users -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<8x1xf32>, %[[ARG2:.*]]: tensor<8x1xf32> -// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (vector<8x1xf32>) -// CHECK: 
%[[TRANSFER:.*]] = vector.transfer_read -// CHECK: scf.yield %[[TRANSFER]] : vector<8x1xf32> -// CHECK: else -// CHECK: %[[TRANSFER_0:.*]] = vector.transfer_read -// CHECK: scf.yield %[[TRANSFER_0]] : vector<8x1xf32> -// CHECK: return %[[IF]] - -// ----- - -func.func @dont_vectorize_complex_ite(%arg0: i1, - %arg1: tensor<8x1xcomplex>, %arg2: tensor<8x1xcomplex>) - -> tensor<8x1xcomplex> { - %0 = scf.if %arg0 -> (tensor<8x1xcomplex>) { - scf.yield %arg1 : tensor<8x1xcomplex> - } else { - scf.yield %arg2 : tensor<8x1xcomplex> - } - return %0 : tensor<8x1xcomplex> -} - -// CHECK-LABEL: @dont_vectorize_complex_ite -// CHECK: scf.if %{{.*}} -> (tensor<8x1xcomplex>) - -// ----- - -func.func @vectorize_ite_w_scalar(%arg0: i1, %arg1: tensor<8x1xf32>, %arg2: f32, - %arg3: tensor<8x1xf32>, %arg4: f32) -> (vector<8x1xf32>, f32) { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0:2 = scf.if %arg0 -> (tensor<8x1xf32>, f32) { - scf.yield %arg1, %arg2 : tensor<8x1xf32>, f32 - } else { - scf.yield %arg3, %arg4 : tensor<8x1xf32>, f32 - } - %1 = vector.transfer_read %0#0[%c0, %c0], %cst {in_bounds = [true, true]} - : tensor<8x1xf32>, vector<8x1xf32> - return %1, %0#1 : vector<8x1xf32>, f32 -} - -// CHECK-LABEL: @vectorize_ite_w_scalar -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<8x1xf32>, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: tensor<8x1xf32>, %[[ARG4:.*]]: f32 -// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG0]] -> (vector<8x1xf32>, f32) -// CHECK: %[[TRANSFER:.*]] = vector.transfer_read %[[ARG1]] -// CHECK: scf.yield %[[TRANSFER]], %[[ARG2]] -// CHECK: else -// CHECK: %[[TRANSFER_0:.*]] = vector.transfer_read %[[ARG3]] -// CHECK: scf.yield %[[TRANSFER_0]], %[[ARG4]] -// CHECK: return %[[IF]]#0, %[[IF]]#1 - -// ----- - -func.func @vectorize_ite_w_casts(%arg0: i1, %arg1: tensor<8x1xf32>, - %arg2: tensor<8x1xf32>) -> vector<8x1xf32> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = scf.if %arg0 -> (tensor) { - %cast_0 = tensor.cast %arg1 : tensor<8x1xf32> to tensor - scf.yield %cast_0 : tensor - } else { - %cast_0 = tensor.cast %arg2 : tensor<8x1xf32> to tensor - scf.yield %cast_0 : tensor - } - %cast = tensor.cast %0 : tensor to tensor<8x1xf32> - %1 = vector.transfer_read %cast[%c0, %c0], %cst {in_bounds = [true, true]} - : tensor<8x1xf32>, vector<8x1xf32> - return %1 : vector<8x1xf32> -} - -// ----- - -// CHECK-LABEL: @vectorize_ite_w_casts -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<8x1xf32>, %[[ARG2:.*]]: tensor<8x1xf32> -// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (vector<8x1xf32>) -// CHECK: %[[TRANSFER:.*]] = vector.transfer_read %[[ARG1]] -// CHECK: scf.yield %[[TRANSFER]] -// CHECK: else -// CHECK: %[[TRANSFER_0:.*]] = vector.transfer_read %[[ARG2]] -// CHECK: scf.yield %[[TRANSFER_0]] -// CHECK: return %[[IF]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir deleted file mode 100644 index 18b71ea6e12bd5..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir +++ /dev/null @@ -1,314 +0,0 @@ -// RUN: mlir-hlo-opt %s --legalize-mhlo-to-thlo=enable-experimental=true | FileCheck %s - -// CHECK-LABEL: @dynamic_broadcast_in_dim -// CHECK-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<3xindex> -func.func @dynamic_broadcast_in_dim(%arg : tensor, %shape : tensor<3xindex>) -> tensor { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // 
CHECK-DAG: %[[C2:.*]] = arith.constant 2 - // CHECK-DAG: %[[SHAPE_D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] - // CHECK-DAG: %[[SHAPE_D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] - // CHECK-DAG: %[[SHAPE_D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] - // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[SHAPE_D0]], %[[SHAPE_D1]], %[[SHAPE_D2]]) : tensor - // CHECK-NEXT: %[[BCAST:.*]] = thlo.dynamic_broadcast_in_dim - // CHECK-SAME: ins(%[[ARG]] : tensor) - // CHECK-SAME: outs(%[[INIT]] : tensor) - // CHECK-SAME: broadcast_dimensions = [0, 2] - // CHECK: return %[[BCAST]] - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) - { broadcast_dimensions = dense<[0, 2]> : tensor<2xi64> } - : (tensor, tensor<3xindex>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @dynamic_broadcast_in_dim_expansion_behavior_known -// CHECK-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<3xindex> -func.func @dynamic_broadcast_in_dim_expansion_behavior_known( - %arg : tensor, %shape : tensor<3xindex>) -> tensor { - // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[SHAPE]]) - // CHECK: return %[[BCAST]] - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) { - broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>, - known_expanding_dimensions = dense<[0]> : tensor<1xi64>, - known_nonexpanding_dimensions = dense<[1]> : tensor<1xi64> } - : (tensor, tensor<3xindex>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @dynamic_broadcast_in_dim_with_known_expanding -// CHECK-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<4xindex> -func.func @dynamic_broadcast_in_dim_with_known_expanding(%arg : tensor, %shape : tensor<4xindex>) -> tensor { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 - // CHECK-DAG: %[[SHAPE_D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] - // CHECK-DAG: %[[SHAPE_D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] - // CHECK-DAG: %[[SHAPE_D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] - // CHECK-DAG: %[[SHAPE_D3:.*]] = tensor.extract %[[SHAPE]][%[[C3]]] - // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[SHAPE_D0]], %[[SHAPE_D1]], %[[SHAPE_D2]], %[[SHAPE_D3]]) : tensor - // CHECK-NEXT: %[[BCAST:.*]] = thlo.dynamic_broadcast_in_dim - // CHECK-SAME: ins(%[[ARG]] : tensor) - // CHECK-SAME: outs(%[[INIT]] : tensor) - // CHECK-SAME: broadcast_dimensions = [0, 2, 3] - // CHECK-SAME: {known_expanding_dimensions = array, known_nonexpanding_dimensions = array} - // CHECK: return %[[BCAST]] - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) { - broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, - known_expanding_dimensions = dense<[0]> : tensor<1xi64>, - known_nonexpanding_dimensions = dense<[2]> : tensor<1xi64> } - : (tensor, tensor<4xindex>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @concatenate -// CHECK-SAME: %[[A:.*]]: tensor, %[[B:.*]]: tensor, %[[C:.*]]: tensor -func.func @concatenate(%a: tensor, %b: tensor, %c: tensor) -> tensor { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[A]], %[[C0]] - // CHECK-DAG: %[[CONCAT_DIM_A:.*]] = tensor.dim %[[A]], %[[C1]] - // CHECK-DAG: %[[CONCAT_DIM_B:.*]] = tensor.dim %[[B]], %[[C1]] - // CHECK-DAG: %[[CONCAT_DIM_C:.*]] = tensor.dim %[[C]], %[[C1]] - // CHECK-DAG: %[[CONCAT_DIM_AB:.*]] = arith.addi %[[CONCAT_DIM_A]], %[[CONCAT_DIM_B]] - // CHECK-DAG: %[[CONCAT_DIM_ABC:.*]] = arith.addi %[[CONCAT_DIM_AB]], 
%[[CONCAT_DIM_C]] - // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[D0]], %[[CONCAT_DIM_ABC]]) - // CHECK: %[[CONCATENATE:.*]] = thlo.concatenate - // CHECK-SAME: ins(%[[A]] : tensor, %[[B]] : tensor, %[[C]] : tensor) - // CHECK-SAME: outs(%[[INIT]] : tensor) - // CHECK-SAME: dimension = 1 - // CHECK: return %[[CONCATENATE]] - %concat = "mhlo.concatenate"(%a, %b, %c) { dimension = 1 } : (tensor, tensor, tensor) -> tensor - func.return %concat : tensor -} - -// CHECK-LABEL: @concatenate_with_static_info -// CHECK-SAME: %[[A:.*]]: tensor<64x32xi32>, %[[B:.*]]: tensor<64x16xi32>, %[[C:.*]]: tensor<64x?xi32> -func.func @concatenate_with_static_info(%a: tensor<64x32xi32>, %b: tensor<64x16xi32>, %c: tensor<64x?xi32>) -> tensor<64x?xi32> { - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C48:.*]] = arith.constant 48 - // CHECK-DAG: %[[CONCAT_DIM_C:.*]] = tensor.dim %[[C]], %[[C1]] - // CHECK-DAG: %[[CONCAT_DIM_SUM:.*]] = arith.addi %[[CONCAT_DIM_C]], %[[C48]] - // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[CONCAT_DIM_SUM]]) - // CHECK: %[[CONCAT:.*]] = thlo.concatenate - // CHECK-SAME: ins(%[[A]] : tensor<64x32xi32>, %[[B]] : tensor<64x16xi32>, %[[C]] : tensor<64x?xi32>) - // CHECK-SAME: outs(%[[INIT]] : tensor<64x?xi32>) - // CHECK-SAME: dimension = 1 - // CHECK: return %[[CONCAT]] - %concat = "mhlo.concatenate"(%a, %b, %c) { dimension = 1 } : (tensor<64x32xi32>, tensor<64x16xi32>, tensor<64x?xi32>) -> tensor<64x?xi32> - func.return %concat : tensor<64x?xi32> -} - -func.func @simple_gather(%operand : tensor<3x3xf32>, - %indices: tensor<3x2xi64>) -> tensor<3x1x1xf32> { - %0 = "mhlo.gather"(%operand, %indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 1]> : tensor<2xi64> - } : (tensor<3x3xf32>, tensor<3x2xi64>) -> tensor<3x1x1xf32> - func.return %0 : tensor<3x1x1xf32> -} - -// CHECK-LABEL: @simple_gather -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<3x1x1xf32> -// CHECK: %[[CAST_INIT:.*]] = tensor.empty() : tensor<3x2xindex> -// CHECK: %[[CAST:.*]] = linalg.map { arith.index_cast } -// CHECK: %[[GATHER:.*]] = thlo.gather -// CHECK-SAME: ins(%{{.*}} : tensor<3x3xf32>, %[[CAST]] : tensor<3x2xindex>) -// CHECK-SAME: outs(%[[INIT]] : tensor<3x1x1xf32>) -// CHECK: return %[[GATHER]] - -func.func @simple_gather_unsigned( - %operand : tensor<3x3xui32>, %indices: tensor<3x2xui64>) -> tensor<3x1x1xui32> { - %0 = "mhlo.gather"(%operand, %indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 1]> : tensor<2xi64> - } : (tensor<3x3xui32>, tensor<3x2xui64>) -> tensor<3x1x1xui32> - func.return %0 : tensor<3x1x1xui32> -} -// CHECK-LABEL: @simple_gather_unsigned -// CHECK-DAG: %[[CAST:.*]] = builtin.unrealized_conversion_cast {{.*}} : tensor<3x3xui32> to tensor<3x3xi32> -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x1x1xi32> -// CHECK: %[[INDEX_CAST_INIT:.*]] = tensor.empty() : tensor<3x2xindex> -// CHECK: %[[INDEX_CAST:.*]] = linalg.map { arith.index_castui } -// CHECK: %[[GATHER:.*]] = thlo.gather -// CHECK-SAME: ins(%[[CAST]] : tensor<3x3xi32>, %[[INDEX_CAST]] : tensor<3x2xindex>) -// CHECK-SAME: outs(%[[INIT]] : tensor<3x1x1xi32>) -// CHECK: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[GATHER]] : tensor<3x1x1xi32> to tensor<3x1x1xui32> -// CHECK: return %[[CAST2]] 
- -func.func @gather_with_slices( - %operand : tensor<300x300xi32>, %indices: tensor<3x2xi64>) -> tensor<3x101x102xi32> { - %0 = "mhlo.gather"(%operand, %indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[101, 102]> : tensor<2xi64> - } : (tensor<300x300xi32>, tensor<3x2xi64>) -> tensor<3x101x102xi32> - func.return %0 : tensor<3x101x102xi32> -} -// CHECK-LABEL: @gather_with_slices -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<3x101x102xi32> -// CHECK: thlo.gather -// CHECK-SAME: outs(%[[INIT]] : tensor<3x101x102xi32>) - -func.func @gather_dynamic( - %operand : tensor<300xi32>, %indices: tensor) -> tensor { - %0 = "mhlo.gather"(%operand, %indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [1], - start_index_map = [0] - >, - indices_are_sorted = false, - slice_sizes = dense<[42]> : tensor<1xi64> - } : (tensor<300xi32>, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-LABEL: @gather_dynamic -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim {{.*}} %[[C0]] : tensor -// CHECK: %[[INIT:.*]] = tensor.empty(%dim) : tensor -// CHECK: thlo.gather -// CHECK-SAME: outs(%[[INIT]] : tensor) - -func.func @unsupported_gather(%operand: tensor<3x3xf32>, - %indices: tensor<3x2xi64>) -> tensor<3xf32> { - %0 = "mhlo.gather"(%operand, %indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 1, - offset_dims = [], - start_index_map = [1, 0] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 1]> : tensor<2xi64> - } : (tensor<3x3xf32>, tensor<3x2xi64>) -> tensor<3xf32> - func.return %0 : tensor<3xf32> -} - -// CHECK-LABEL: @unsupported_gather -// CHECK: mhlo.gather - -func.func @simple_scatter(%dst: tensor<3x3xf32>, %indices: tensor<2x2xi32>, - %update: tensor<2x1x3xf32>) -> tensor<3x3xf32> { - %0 = "mhlo.scatter"(%dst, %indices, %update) ({ - ^bb0(%in: tensor, %out: tensor): - %sum = mhlo.add %in, %out : tensor - "mhlo.return"(%sum) : (tensor) -> () - }) { - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1, 2], - inserted_window_dims = [], - scatter_dims_to_operand_dims = [0, 1], - index_vector_dim = 1, - >, - unique_indices = false, - indices_are_sorted = false - } : (tensor<3x3xf32>, tensor<2x2xi32>, tensor<2x1x3xf32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} - -// CHECK-LABEL: @simple_scatter -// CHECK-SAME: (%[[DST:.*]]: tensor<3x3xf32>, %[[INDICES:.*]]: tensor<2x2xi32>, -// CHECK-SAME: %[[UPDATE:.*]]: tensor<2x1x3xf32>) -// CHECK: %[[CAST_INIT:.*]] = tensor.empty() : tensor<2x2xindex> -// CHECK: %[[CAST:.*]] = linalg.map { arith.index_cast } -// CHECK: thlo.scatter -// CHECK-SAME: ins(%[[CAST]] : tensor<2x2xindex>, -// CHECK-SAME: %[[UPDATE]] : tensor<2x1x3xf32>) -// CHECK-SAME: outs(%[[DST]] : tensor<3x3xf32>) -// CHECK-NEXT: (%[[UPD:.*]]: f32, %[[CUR:.*]]: f32) { -// CHECK-NEXT: %[[CUR_T:.*]] = tensor.from_elements %[[CUR]] : tensor -// CHECK-NEXT: %[[UPD_T:.*]] = tensor.from_elements %[[UPD]] : tensor -// CHECK-NEXT: %[[CUR:.*]] = tensor.extract %[[CUR_T]][] : tensor -// CHECK-NEXT: %[[UPD:.*]] = tensor.extract %[[UPD_T]][] : tensor -// CHECK-NEXT: arith.addf %[[CUR]], %[[UPD]] : f32 -// CHECK-NEXT: tensor.from_elements -// CHECK-NEXT: tensor.extract - -// ----- - -// CHECK-LABEL: func @sort -// CHECK-SAME: (%[[IN0:.*]]: tensor<16x16xf32>, %[[IN1:.*]]: 
tensor<16x16xi32>) -func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { - %0:2 = "mhlo.sort"(%input0, %input1) ({ - ^bb0(%arg0: tensor, %arg1: tensor, - %arg2: tensor, %arg3: tensor): - %7 = "mhlo.compare"(%arg0, %arg1) - {comparison_direction = #mhlo} - : (tensor, tensor) -> tensor - "mhlo.return"(%7) : (tensor) -> () - }) {dimension = -1 : i64, is_stable = true} - : (tensor<16x16xf32>, tensor<16x16xi32>) - -> (tensor<16x16xf32>, tensor<16x16xi32>) - func.return -} -// CHECK-DAG: %[[INIT0:.*]] = tensor.empty() : tensor<16x16xf32> -// CHECK-DAG: %[[INIT1:.*]] = tensor.empty() : tensor<16x16xi32> -// CHECK: thlo.sort -// CHECK-SAME: ins(%[[IN0]] : tensor<16x16xf32>, %[[IN1]] : tensor<16x16xi32>) -// CHECK-SAME: outs(%[[INIT0]] : tensor<16x16xf32>, %[[INIT1]] : tensor<16x16xi32>) -// CHECK-SAME: dimension = 1 -// CHECK-SAME: is_stable = true -// CHECK: (%[[FLOAT0:.*]]: f32, %[[FLOAT1:.*]]: f32, %[[INT0:.*]]: i32, %[[INT1:.*]]: i32) -// CHECK-DAG: %[[TENSOR0:.*]] = tensor.from_elements %[[FLOAT0]] : tensor -// CHECK-DAG: %[[TENSOR1:.*]] = tensor.from_elements %[[FLOAT1]] : tensor -// CHECK-DAG: %[[EXTRACTED0:.*]] = tensor.extract %[[TENSOR0]][] : tensor -// CHECK-DAG: %[[EXTRACTED1:.*]] = tensor.extract %[[TENSOR1]][] : tensor -// CHECK: %[[CMPRESULT:.*]] = arith.cmpf ogt, %[[EXTRACTED0]], %[[EXTRACTED1]] : f32 -// CHECK-NEXT: %[[RESULT:.*]] = tensor.from_elements %[[CMPRESULT]] : tensor -// CHECK-NEXT: %[[EXTRACTED_RESULT:.*]] = tensor.extract %[[RESULT]][] : tensor -// CHECK-NEXT: thlo.yield %[[EXTRACTED_RESULT]] : i1 - -func.func @reverse_static(%input: tensor<100xf32>) - -> tensor<100xf32> { - %res = "mhlo.reverse"(%input) {dimensions = dense<[0]> : tensor<1xi64>} : - (tensor<100xf32>) -> tensor<100xf32> - func.return %res : tensor<100xf32> -} - -// CHECK-LABEL: func @reverse_static -// CHECK-SAME: (%[[ARG0:.*]]: tensor<100xf32>) -> tensor<100xf32> -// CHECK: %[[EMPTY:.*]] = tensor.empty -// CHECK: %[[REVERSED:.*]] = thlo.reverse -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[EMPTY]] -// CHECK-SAME: reverse_dimensions = [0] -// CHECK-NEXT: return %[[REVERSED]] - -func.func @reverse_dynamic(%input: tensor) - -> tensor { - %res = "mhlo.reverse"(%input) {dimensions = dense<[0, 1]> : tensor<2xi64>} : - (tensor) -> tensor - func.return %res : tensor -} - -// CHECK-LABEL: func @reverse_dynamic -// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor -// CHECK: %[[C0:.*]] = arith.constant -// CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[C1:.*]] = arith.constant -// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) -// CHECK: %[[REVERSED:.*]] = thlo.reverse -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[EMPTY]] -// CHECK-SAME: reverse_dimensions = [0, 1] -// CHECK-NEXT: return %[[REVERSED]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir deleted file mode 100644 index 5e88a2dd320019..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir +++ /dev/null @@ -1,41 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --computeop-and-func-bufferize \ -// RUN: --allow-unregistered-dialect --final-bufferize=alignment=128 | \ -// RUN: FileCheck %s - -func.func @sort(%input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - 
dimension = 1 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -// CHECK-LABEL: func.func @sort -// CHECK-SAME: (%[[INPUT1:[A-Za-z_0-9]*]]: memref, -// CHECK-SAME: %[[INPUT2:[A-Za-z_0-9]*]]: memref, -// CHECK-SAME: %[[INIT1:[A-Za-z_0-9]*]]: memref, -// CHECK-SAME: %[[INIT2:[A-Za-z_0-9]*]]: memref) -// CHECK-SAME: -> (memref, memref) -// CHECK-DAG: %[[OUTPUT1:.*]] = memref.alloc -// CHECK-DAG: memref.copy %[[INIT1]], %[[OUTPUT1]] -// CHECK-DAG: %[[OUTPUT2:.*]] = memref.alloc -// CHECK-DAG: memref.copy %[[INIT2]], %[[OUTPUT2]] -// CHECK: thlo.sort -// CHECK-SAME: ins(%[[INPUT1]] : memref, -// CHECK-SAME: %[[INPUT2]] : memref) -// CHECK-SAME: outs(%[[OUTPUT1]] : memref, -// CHECK-SAME: %[[OUTPUT2]] : memref) -// CHECK-SAME: dimension = 1 -// CHECK-SAME: is_stable = true -// CHECK-NEXT: (%[[FLOAT1:[A-Za-z_0-9]*]]: f32, %[[FLOAT2:.*]]: f32, -// CHECK-SAME: %[[INT1:[A-Za-z_0-9]*]]: i32, %[[INT2:.*]]: i32) -// CHECK: %[[RESULT:.*]] = arith.cmpf ogt, %[[FLOAT1]], %[[FLOAT2]] : f32 -// CHECK: thlo.yield %[[RESULT]] : i1 -// CHECK: return %[[OUTPUT1]], %[[OUTPUT2]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir deleted file mode 100644 index aaec8cc5b1efd4..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -func.func @reverse_dynamic_fold(%input: tensor<1x?xf32>, %init: tensor<1x?xf32>) - -> tensor<1x?xf32> { - %res = thlo.reverse - ins(%input: tensor<1x?xf32>) - outs(%init: tensor<1x?xf32>) - reverse_dimensions = [0] - func.return %res : tensor<1x?xf32> -} - -// CHECK-LABEL: func @reverse_dynamic_fold -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>, %[[ARG1:.*]]: tensor<1x?xf32> -// CHECK: return %[[ARG0]] \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir deleted file mode 100644 index 8582b40adabda1..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir +++ /dev/null @@ -1,352 +0,0 @@ -// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file - -func.func @concatenate(%arg1: tensor, - %arg2: tensor, - %dst: tensor) -> tensor { - // expected-error @+1 {{thlo.concatenate' op expected element type of input 'i32' to match output element type 'f32'}} - %cat = thlo.concatenate - ins(%arg1: tensor, %arg2: tensor) - outs(%dst: tensor) - dimension = 0 - func.return %cat : tensor -} - -// ----- - -func.func @concatenate_mismatch_rank(%arg1: tensor, - %arg2: tensor, - %dst: tensor) -> tensor { - // expected-error @+1 {{thlo.concatenate' op expected all args to be rank 2, got 3 in arg 1}} - %cat = thlo.concatenate - ins(%arg1: tensor, %arg2: tensor) - outs(%dst: tensor) - dimension = 0 - func.return %cat : tensor -} - -// ----- - -func.func @concatenate_mismatch_shape(%arg1: tensor, - %arg2: tensor, - %dst: tensor) -> tensor { - // expected-error @+1 {{thlo.concatenate' op shape of input arg 1: 'tensor' doesn't match expected shape 'tensor'}} - %cat = thlo.concatenate - ins(%arg1: tensor, %arg2: tensor) - outs(%dst: tensor) - dimension = 0 - func.return %cat : tensor -} - -// ----- - -func.func @yield_op_inside_mhlo_reduce( - %arg0: tensor<5x4xf32>, %arg1: tensor) -> 
tensor<5xf32> { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%init: tensor, %arg3: tensor): - %1 = mhlo.add %init, %arg3 : tensor - // expected-error @+1{{'thlo.yield' op expects parent op to be one of}} - thlo.yield %1: tensor - }) {dimensions = dense<1> : tensor<1xi64>} : - (tensor<5x4xf32>, tensor) -> tensor<5xf32> - func.return %0 : tensor<5xf32> -} - -// ----- - -func.func @scatter_indices_wrong_rank(%indices: tensor<2x2x2xindex>, - %updates: tensor<2x1x3xf32>, %init: tensor<3x3xf32>) -> tensor<3x3xf32> { - // expected-error@+1{{expected `indices` to be a 2D tensor}} - %0 = thlo.scatter ins(%indices : tensor<2x2x2xindex>, - %updates : tensor<2x1x3xf32>) - outs(%init : tensor<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xf32> -} - -// ----- - -func.func @scatter_updates_indices_major_dim_mismatch( - %indices: tensor<2x2xindex>, %updates: tensor<3x1x3xf32>, - %init: tensor<3x3xf32>) -> tensor<3x3xf32> { - // expected-error@+1{{expected major dimension of `indices` to match major dimension of `updates`}} - %0 = thlo.scatter ins(%indices : tensor<2x2xindex>, - %updates : tensor<3x1x3xf32>) - outs(%init : tensor<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xf32> -} - -// ----- - -func.func @scatter_indices_dynamic_index_vector_dim( - %indices: tensor<2x?xindex>, %updates: tensor<2x1x3xf32>, - %init: tensor<3x3xf32>) -> tensor<3x3xf32> { - // expected-error@+1{{expected index vector dimension size to be static}} - %0 = thlo.scatter ins(%indices : tensor<2x?xindex>, - %updates : tensor<2x1x3xf32>) - outs(%init : tensor<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xf32> -} - -// ----- - -func.func @scatter_indices_index_vector_dim_too_big( - %indices: tensor<2x9xindex>, %updates: tensor<2x1x3xf32>, - %init: tensor<3x3xf32>) -> tensor<3x3xf32> { - // expected-error@+1{{expected index vector dimension size = 9 to be smaller or equal than `init` rank = 2}} - %0 = thlo.scatter ins(%indices : tensor<2x9xindex>, - %updates : tensor<2x1x3xf32>) - outs(%init : tensor<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xf32> -} - -// ----- - -func.func @scatter_updates_init_rank_mismatch(%indices: tensor<2x2xindex>, - %updates: tensor<2x3xf32>, %init: tensor<3x3xf32>) -> tensor<3x3xf32> { - // expected-error@+1{{expected `updates` rank + 1 to match `init` rank}} - %0 = thlo.scatter ins(%indices : tensor<2x2xindex>, - %updates : tensor<2x3xf32>) - outs(%init : tensor<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xf32> -} - -// ----- - -func.func @scatter_updates_init_element_type_mismatch( - %indices: tensor<2x2xindex>, %updates: tensor<2x1x3xf32>, - %init: tensor<3x3xi32>) -> tensor<3x3xi32> { - // expected-error@+1{{expected `updates` element type to match `init` element type}} - %0 = thlo.scatter ins(%indices : tensor<2x2xindex>, - %updates : tensor<2x1x3xf32>) - outs(%init : tensor<3x3xi32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xi32> -} - -// ----- - -func.func @gather_output_result_mismatch( - %arg: tensor<100xf32>, %indices: tensor<42x1xindex>, %dst: tensor<42xf32>) - -> tensor<42xf64> { - // expected-error@+1{{'thlo.gather' op expected type 
of operand #2 ('tensor<42xf32>') to match type of corresponding result ('tensor<42xf64>')}} - %gather = "thlo.gather"(%arg, %indices, %dst) : - (tensor<100xf32>, tensor<42x1xindex>, tensor<42xf32>) -> (tensor<42xf64>) - func.return %gather : tensor<42xf64> -} - -// ----- - -func.func @gather_invalid_dynamic_indices( - %arg: tensor<100xf32>, %indices: tensor<42x?xindex>, %dst: tensor<42xf32>) - -> tensor<42xf64> { - // expected-error@+1{{'thlo.gather' op expected type of operand #2 ('tensor<42xf32>') to match type of corresponding result ('tensor<42xf64>')}} - %gather = "thlo.gather"(%arg, %indices, %dst) : - (tensor<100xf32>, tensor<42x?xindex>, tensor<42xf32>) -> (tensor<42xf64>) - func.return %gather : tensor<42xf64> -} - -// ----- - -func.func @gather_invalid_indices_shape( - %arg: tensor<100xf32>, %indices: tensor<42xindex>, %dst: tensor<42xf32>) - -> tensor<42xf64> { - // expected-error@+1{{'thlo.gather' op expected `indices` to be a 2D tensor}} - %gather = "thlo.gather"(%arg, %indices, %dst) : - (tensor<100xf32>, tensor<42xindex>, tensor<42xf32>) -> (tensor<42xf64>) - func.return %gather : tensor<42xf64> -} - -// ----- - -func.func @gather_indices_dst_mismatch( - %arg: tensor<100xf32>, %indices: tensor<42x1xindex>, %dst: tensor<43xf32>) - -> tensor<43xf64> { - // expected-error@+1{{'thlo.gather' op expected major dimension of `startIndices` to match major dimension of `init`}} - %gather = "thlo.gather"(%arg, %indices, %dst) : - (tensor<100xf32>, tensor<42x1xindex>, tensor<43xf32>) -> (tensor<43xf64>) - func.return %gather : tensor<43xf64> -} - -// ----- - -func.func @gather_invalid_dst_shape( - %arg: tensor<100xf32>, %indices: tensor<42x1xindex>, %dst: tensor<42x?xf32>) - -> tensor<42x?xf64> { - // expected-error@+1{{'thlo.gather' op only the major dimenion of `init` may be dynamic}} - %gather = "thlo.gather"(%arg, %indices, %dst) : - (tensor<100xf32>, tensor<42x1xindex>, tensor<42x?xf32>) -> (tensor<42x?xf64>) - func.return %gather : tensor<42x?xf64> -} - -// ----- - -func.func @sort_mismatched_number_of_inputs_and_outputs( - %input1: tensor, %input2: tensor, - %init1: tensor) - -> tensor { - // expected-error@+1{{'thlo.sort' op expected the number of inputs 2 to match the number of outputs 1}} - %sorted = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted : tensor -} - -// ----- - -func.func @sort_mismatched_number_of_inputs_and_comparator_arguments( - %input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - // expected-error@+1{{'thlo.sort' op expected the number of block arguments 3 to be twice the number of inputs (2*2)}} - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -// ----- - -func.func @sort_mismatched_input_and_comparator_type( - %input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - // expected-error@+1{{'thlo.sort' op expected element type of input 1 to match type of the corresponding arguments to the comparison function but got 'i32' and ('i32', 'f32')}} - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - 
dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: f32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -// ----- - -func.func @sort_comparator_yields_different_than_one_output( - %input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - // expected-error@+1{{'thlo.yield' op expects number of tensor output args = 1 to match the number of yield operands = 2}} - thlo.yield %gt, %gt : i1, i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -// ----- - -func.func @sort_comparator_yields_non_boolean( - %input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - // expected-error@+1{{'thlo.yield' op expects yield operand 0 with type = 'f32' to match output arg element type = 'i1'}} - thlo.yield %e11 : f32 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -// ----- - -func.func @sort_inputs_have_different_shapes( - %input1: tensor<64x32xf32>, %input2: tensor<32x32xi32>, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - // expected-error@+1{{'thlo.sort' op expected all inputs to have the same shape (64, 32) but input 1 has shape (32, 32)}} - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor<64x32xf32>, %input2: tensor<32x32xi32>) - outs(%init1: tensor, %init2: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -// ----- - -func.func @sort_output_has_different_shape_from_inputs( - %input1: tensor<64x32xf32>, %input2: tensor<64x32xi32>, - %init1: tensor<32x64xf32>, %init2: tensor) - -> (tensor<32x64xf32>, tensor) { - // expected-error@+1{{'thlo.sort' op expected outputs to have shape (64, 32) but output 0 has shape (32, 64)}} - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor<64x32xf32>, %input2: tensor<64x32xi32>) - outs(%init1: tensor<32x64xf32>, %init2: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor<32x64xf32>, tensor -} - -// ----- - -func.func @sort_dimension_is_incompatible_with_rank_of_inputs( - %input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - // expected-error@+1{{'thlo.sort' op sorting dimension must be in range [0, 2) but got 2}} - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 2 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir deleted file mode 100644 index d4f3709d718a22..00000000000000 --- 
a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir +++ /dev/null @@ -1,203 +0,0 @@ -// RUN: mlir-hlo-opt -thlo-legalize-sort -canonicalize %s | FileCheck %s - -func.func @sort(%input1: memref, %input2: memref, - %init1: memref, %init2: memref) { - thlo.sort - ins(%input1: memref, %input2: memref) - outs(%init1: memref, %init2: memref) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return -} - -// CHECK-LABEL: func.func @sort( -// CHECK-SAME: %[[INPUT1:[A-Za-z0-9]*]]: memref, -// CHECK-SAME: %[[INPUT2:[A-Za-z0-9]*]]: memref, -// CHECK-SAME: %[[INIT1:[A-Za-z0-9]*]]: memref, -// CHECK-SAME: %[[INIT2:[A-Za-z0-9]*]]: memref) { -// CHECK: %[[CTRUE:.*]] = arith.constant true -// CHECK: %[[C16:.*]] = arith.constant 16 : index -// CHECK: %[[CFALSE:.*]] = arith.constant false -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[SORT_DIM:.*]] = memref.dim %[[INPUT1]], %[[C0]] -// CHECK: %[[DYN_DIM0:.*]] = memref.dim %[[INPUT1]], %[[C0]] -// CHECK: %[[DYN_DIM1:.*]] = memref.dim %[[INPUT1]], %[[C1]] -// CHECK: %[[SCRATCH1:.*]] = memref.alloc(%[[DYN_DIM0]], %[[DYN_DIM1]]) -// CHECK: %[[SCRATCH2:.*]] = memref.alloc(%[[DYN_DIM0]], %[[DYN_DIM1]]) -// CHECK: %[[BATCH_DIM_SIZE:.*]] = memref.dim %[[INPUT1]], %[[C1]] -// CHECK: %[[PARITY:.*]] = scf.for -// CHECK-SAME: %[[SUBVIEW_INDEX:.*]] = %[[C0]] to %[[BATCH_DIM_SIZE]] -// CHECK-SAME: step %[[C1]] -// CHECK-SAME: iter_args(%[[ARG5:.*]] = %[[CFALSE]]) -> (i1) { -// CHECK: %[[SUBVIEW_INPUT1:.*]] = memref.subview -// CHECK-SAME: %[[INPUT1]][0, %[[SUBVIEW_INDEX]]] -// CHECK-SAME [%[[SORT_DIM]], 1] [1, 1] -// CHECK: %[[SUBVIEW_INPUT2:.*]] = memref.subview -// CHECK-SAME: %[[INPUT2]][0, %[[SUBVIEW_INDEX]]] -// CHECK-SAME [%[[SORT_DIM]], 1] [1, 1] -// CHECK: %[[SUBVIEW_INIT1:.*]] = memref.subview -// CHECK-SAME: %[[INIT1]][0, %[[SUBVIEW_INDEX]]] -// CHECK-SAME [%[[SORT_DIM]], 1] [1, 1] -// CHECK: %[[SUBVIEW_INIT2:.*]] = memref.subview -// CHECK-SAME: %[[INIT2]][0, %[[SUBVIEW_INDEX]]] -// CHECK-SAME [%[[SORT_DIM]], 1] [1, 1] -// CHECK: %[[SUBVIEW_SCRATCH1:.*]] = memref.subview -// CHECK-SAME: %[[SCRATCH1]][0, %[[SUBVIEW_INDEX]]] -// CHECK-SAME [%[[SORT_DIM]], 1] [1, 1] -// CHECK: %[[SUBVIEW_SCRATCH2:.*]] = memref.subview -// CHECK-SAME: %[[SCRATCH2]][0, %[[SUBVIEW_INDEX]]] -// CHECK-SAME [%[[SORT_DIM]], 1] [1, 1] -// COM: // We first sort ELEMs in groups of 16 using an -// COM: // insertion sort. -// CHECK: scf.for %[[LO:.*]] = %[[C0]] to %[[SORT_DIM]] -// CHECK-SAME: step %[[C16]] { -// CHECK: %[[UPPER_BOUND:.*]] = arith.addi %[[LO]], %[[C16]] -// CHECK: %[[END:.*]] = arith.minsi %[[UPPER_BOUND]], %[[SORT_DIM]] -// CHECK: %[[LO_IN1:.*]] = memref.load %[[SUBVIEW_INPUT1]][%[[LO]]] -// CHECK: %[[LO_IN2:.*]] = memref.load %[[SUBVIEW_INPUT2]][%[[LO]]] -// CHECK: memref.store %[[LO_IN1]], %[[SUBVIEW_INIT1]][%[[LO]]] -// CHECK: memref.store %[[LO_IN2]], %[[SUBVIEW_INIT2]][%[[LO]]] -// CHECK: %[[LO_PLUS_1:.*]] = arith.addi %[[LO]], %[[C1]] -// CHECK: scf.for %[[START:.*]] = %[[LO_PLUS_1]] to %[[END]] -// CHECK-SAME: step %[[C1]] { -// CHECK: %[[PIVOT1:.*]] = memref.load %[[SUBVIEW_INPUT1]][%[[START]]] -// CHECK: %[[PIVOT2:.*]] = memref.load %[[SUBVIEW_INPUT2]][%[[START]]] -// COM: // Binary search of the insertion point. 
-// CHECK: %[[LR:.*]]:2 = scf.while
-// CHECK-SAME: (%[[LEFT:.*]] = %[[LO]], %[[RIGHT:.*]] = %[[START]])
-// CHECK-SAME: : (index, index) -> (index, index) {
-// CHECK: %[[L_LT_R:.*]] = arith.cmpi slt, %[[LEFT]], %[[RIGHT]]
-// CHECK: scf.condition(%[[L_LT_R]]) %[[LEFT]], %[[RIGHT]]
-// CHECK: } do {
-// CHECK: ^bb0(%[[LEFT_:.*]]: index, %[[RIGHT_:.*]]: index):
-// CHECK: %[[SUM_LR:.*]] = arith.addi %[[LEFT_]], %[[RIGHT_]]
-// CHECK: %[[MID:.*]] = arith.shrui %[[SUM_LR]], %[[C1]]
-// CHECK: %[[MID_PLUS_1:.*]] = arith.addi %[[MID]], %[[C1]]
-// CHECK: %[[MEDIAN:.*]] = memref.load %[[SUBVIEW_INIT1]][%[[MID]]]
-// CHECK: %[[CMP_PIVOT_MEDIAN:.*]] = arith.cmpf ogt, %[[PIVOT1]], %[[MEDIAN]] : f32
-// CHECK: %[[NEW_LEFT:.*]] = arith.select %[[CMP_PIVOT_MEDIAN]], %[[LEFT_]], %[[MID_PLUS_1]]
-// CHECK: %[[NEW_RIGHT:.*]] = arith.select %[[CMP_PIVOT_MEDIAN]], %[[MID]], %[[RIGHT_]]
-// CHECK: scf.yield %[[NEW_LEFT]], %[[NEW_RIGHT]]
-// CHECK: }
-// COM: // Move the n ELEMs that are larger than the pivot
-// COM: // once to the right.
-// CHECK: %[[N:.*]] = arith.subi %[[START]], %[[LR:.*]]#0
-// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[N]] step %[[C1]] {
-// CHECK: %[[CUR_IX:.*]] = arith.subi %[[START]], %[[I]]
-// CHECK: %[[CUR_IX_MINUS_1:.*]] = arith.subi %[[CUR_IX]], %[[C1]] : index
-// CHECK: %[[ELEM_TO_MOVE1:.*]] = memref.load %[[SUBVIEW_INIT1]][%[[CUR_IX_MINUS_1]]]
-// CHECK: %[[ELEM_TO_MOVE2:.*]] = memref.load %[[SUBVIEW_INIT2]][%[[CUR_IX_MINUS_1]]]
-// CHECK: memref.store %[[ELEM_TO_MOVE1]], %[[SUBVIEW_INIT1]][%[[CUR_IX]]]
-// CHECK: memref.store %[[ELEM_TO_MOVE2]], %[[SUBVIEW_INIT2]][%[[CUR_IX]]]
-// CHECK: }
-// CHECK: memref.store %[[PIVOT1]], %[[SUBVIEW_INIT1]][%[[LR]]#0]
-// CHECK: memref.store %[[PIVOT2]], %[[SUBVIEW_INIT2]][%[[LR]]#0]
-// CHECK: }
-// CHECK: }
-// COM: // Merge subarrays of each input together until the final
-// COM: // sorted array is computed.
-// CHECK: %[[MERGE_RESULTS:.*]]:2 = scf.while
-// CHECK-SAME: (%[[SUBARRAY_SIZE:[A-Za-z0-9]*]] = %[[C16]],
-// CHECK-SAME: %[[PARITY_:[A-Za-z0-9]*]] = %[[CFALSE]])
-// CHECK: %[[ARE_ALL_SUBARRAYS_MERGED:.*]] = arith.cmpi slt, %[[SUBARRAY_SIZE]], %[[SORT_DIM]]
-// CHECK: scf.condition(%[[ARE_ALL_SUBARRAYS_MERGED]]) %[[SUBARRAY_SIZE]], %[[PARITY_]]
-// CHECK: } do {
-// CHECK: ^bb0(%[[SUBARRAY_SIZE_:[A-Za-z0-9]*]]: index,
-// CHECK-SAME: %[[PARITY__:[A-Za-z0-9]*]]: i1):
-// CHECK: %[[DOUBLE_SUBARRAY_SIZE:.*]] = arith.addi %[[SUBARRAY_SIZE_]], %[[SUBARRAY_SIZE_]]
-// COM: // Merge all successive pairs of subarrays of maximum
-// COM: // size SUBARRAY_SIZE.
-// CHECK: scf.if %[[PARITY_]] {
-// CHECK: scf.for
-// CHECK-SAME: %[[DOUBLE_SUBARRAY_START:.*]] = %[[C0]] to %[[SORT_DIM]]
-// CHECK-SAME: step %[[DOUBLE_SUBARRAY_SIZE]] {
-// CHECK: %[[SUBARRAY1_UPPER_BOUND:.*]] = arith.addi %[[DOUBLE_SUBARRAY_START]], %[[SUBARRAY_SIZE_]]
-// CHECK: %[[SUBARRAY1_END:.*]] = arith.minsi %[[SORT_DIM]], %[[SUBARRAY1_UPPER_BOUND]]
-// CHECK: %[[SUBARRAY2_UPPER_BOUND:.*]] = arith.addi %[[DOUBLE_SUBARRAY_START]], %[[DOUBLE_SUBARRAY_SIZE]]
-// CHECK: %[[SUBARRAY2_END:.*]] = arith.minsi %[[SORT_DIM]], %[[SUBARRAY2_UPPER_BOUND]]
-// COM: // Merge two subarrays together.
-// CHECK: %[[POST_MERGE_INDICES:.*]]:3 = scf.while
-// CHECK-SAME: (%[[OUTPUT_INDEX:[A-Za-z0-9]*]] = %[[DOUBLE_SUBARRAY_START]],
-// CHECK-SAME: %[[SUBARRAY1_INDEX:[A-Za-z0-9]*]] = %[[DOUBLE_SUBARRAY_START]],
-// CHECK-SAME: %[[SUBARRAY2_INDEX:[A-Za-z0-9]*]] = %[[SUBARRAY1_END]])
-// CHECK: %[[SUBARRAY1_IS_CONSUMED:.*]] = arith.cmpi slt, %[[SUBARRAY1_INDEX]], %[[SUBARRAY1_END]]
-// CHECK: %[[SUBARRAY2_IS_CONSUMED:.*]] = arith.cmpi slt, %[[SUBARRAY2_INDEX]], %[[SUBARRAY2_END]]
-// CHECK: %[[IS_MERGE_OVER:.*]] = arith.andi %[[SUBARRAY1_IS_CONSUMED]], %[[SUBARRAY2_IS_CONSUMED]] : i1
-// CHECK: scf.condition(%[[IS_MERGE_OVER]]) %[[OUTPUT_INDEX]], %[[SUBARRAY1_INDEX]], %[[SUBARRAY2_INDEX]]
-// CHECK: } do {
-// CHECK: ^bb0(%[[OUTPUT_INDEX_:[A-Za-z0-9]*]]: index,
-// CHECK-SAME: %[[SUBARRAY1_INDEX_:[A-Za-z0-9]*]]: index,
-// CHECK-SAME: %[[SUBARRAY2_INDEX_:[A-Za-z0-9]*]]: index):
-// CHECK: %[[RHS_ELEM1:.*]] = memref.load %[[SUBVIEW_INIT1]][%[[SUBARRAY1_INDEX_]]]
-// CHECK: %[[RHS_ELEM2:.*]] = memref.load %[[SUBVIEW_INIT2]][%[[SUBARRAY1_INDEX_]]]
-// CHECK: %[[LHS_ELEM1:.*]] = memref.load %[[SUBVIEW_INIT1]][%[[SUBARRAY2_INDEX_]]]
-// CHECK: %[[LHS_ELEM2:.*]] = memref.load %[[SUBVIEW_INIT2]][%[[SUBARRAY2_INDEX_]]]
-// CHECK: %[[COMPARATOR_RESULT:.*]] = arith.cmpf ogt, %[[LHS_ELEM1]], %[[RHS_ELEM1]] : f32
-// CHECK: %[[LEFT_ELEM1:.*]] = arith.select %[[COMPARATOR_RESULT]], %[[LHS_ELEM1]], %[[RHS_ELEM1]] : f32
-// CHECK: %[[LEFT_ELEM2:.*]] = arith.select %[[COMPARATOR_RESULT]], %[[LHS_ELEM2]], %[[RHS_ELEM2]] : i32
-// CHECK: memref.store %[[LEFT_ELEM1]], %[[SUBVIEW_SCRATCH1]][%[[OUTPUT_INDEX_]]]
-// CHECK: memref.store %[[LEFT_ELEM2]], %[[SUBVIEW_SCRATCH2]][%[[OUTPUT_INDEX_]]]
-// CHECK: %[[SUBARRAY1_INDEX__PLUS_1:.*]] = arith.addi %[[SUBARRAY1_INDEX_]], %[[C1]]
-// CHECK: %[[NEW_SUBARRAY1_INDEX:.*]] = arith.select %[[COMPARATOR_RESULT]], %[[SUBARRAY1_INDEX_]], %[[SUBARRAY1_INDEX__PLUS_1]]
-// CHECK: %[[SUBARRAY2_INDEX__PLUS_1:.*]] = arith.addi %[[SUBARRAY2_INDEX_]], %[[C1]]
-// CHECK: %[[NEW_SUBARRAY2_INDEX:.*]] = arith.select %[[COMPARATOR_RESULT]], %[[SUBARRAY2_INDEX__PLUS_1]], %[[SUBARRAY2_INDEX_]]
-// CHECK: %[[NEW_OUTPUT_INDEX:.*]] = arith.addi %[[OUTPUT_INDEX_]], %[[C1]]
-// CHECK: scf.yield %[[NEW_OUTPUT_INDEX]], %[[NEW_SUBARRAY1_INDEX]], %[[NEW_SUBARRAY2_INDEX]]
-// CHECK: }
-// COM: // After the merge, exactly one of the two subarrays
-// COM: // contains unprocessed (and sorted) ELEMs. This
-// COM: // appends the corresponding ELEMs to the result
-// COM: // array.
-// CHECK: %[[IS_SUBARRAY1_CONSUMED:.*]] = arith.cmpi slt, %[[POST_MERGE_INDICES]]#1, %[[SUBARRAY1_END]] -// CHECK: %[[INDEX_TO_UNPROCESSED_ELEMS:.*]] = arith.select %[[IS_SUBARRAY1_CONSUMED]], %[[POST_MERGE_INDICES]]#1, %[[POST_MERGE_INDICES]]#2 -// CHECK: %[[UNPROCESSED_SUBARRAY_END:.*]] = arith.select %[[IS_SUBARRAY1_CONSUMED]], %[[SUBARRAY1_END]], %[[SUBARRAY2_END]] -// CHECK: %[[NUMBER_OF_UNPROCESSED_ELEMS:.*]] = arith.subi %[[UNPROCESSED_SUBARRAY_END]], %[[INDEX_TO_UNPROCESSED_ELEMS]] -// CHECK: scf.for -// CHECK-SAME: %[[I_:.*]] = %[[C0]] to %[[NUMBER_OF_UNPROCESSED_ELEMS]] -// CHECK-SAME: step %[[C1]] { -// CHECK: %[[UNPROCESSED_ELEM_INDEX:.*]] = arith.addi %[[INDEX_TO_UNPROCESSED_ELEMS]], %[[I_]] -// CHECK: %[[OUTPUT_INDEX__:.*]] = arith.addi %[[POST_MERGE_INDICES]]#0, %[[I_]] -// CHECK: %[[UNPROCESSED_ELEM1:.*]] = memref.load %[[SUBVIEW_INIT1]][%[[UNPROCESSED_ELEM_INDEX]]] -// CHECK: %[[UNPROCESSED_ELEM2:.*]] = memref.load %[[SUBVIEW_INIT2]][%[[UNPROCESSED_ELEM_INDEX]]] -// CHECK: memref.store %[[UNPROCESSED_ELEM1]], %[[SUBVIEW_SCRATCH1]][%[[OUTPUT_INDEX__]]] -// CHECK: memref.store %[[UNPROCESSED_ELEM2]], %[[SUBVIEW_SCRATCH2]][%[[OUTPUT_INDEX__]]] -// CHECK: } -// CHECK: } -// COM: // Else block as above, but with read and write buffers -// COM: // swapped. -// CHECK: } -// CHECK: %[[NEW_PARITY:.*]] = arith.subi %[[CTRUE]], %[[PARITY__]] : i1 -// CHECK: scf.yield %[[DOUBLE_SUBARRAY_SIZE]], %[[NEW_PARITY]] -// CHECK: } -// CHECK: scf.yield %[[MERGE_RESULTS]]#1 : i1 -// CHECK: } -// CHECK: scf.if %[[PARITY]] { -// CHECK: memref.copy %[[SCRATCH1]], %[[INIT1]] -// CHECK: memref.copy %[[SCRATCH2]], %[[INIT2]] -// CHECK: } -// CHECK: memref.dealloc %[[SCRATCH1]] -// CHECK: memref.dealloc %[[SCRATCH2]] -// CHECK: return -// CHECK: } - -// ----- - -// CHECK-LABEL: @sort_strided -func.func @sort_strided(%input: memref<47x1xf32, strided<[7, 1], offset: ?>>, - %init: memref<47x1xf32, strided<[1, 7], offset: ?>>) { - thlo.sort - ins(%input : memref<47x1xf32, strided<[7, 1], offset: ?>>) - outs(%init : memref<47x1xf32, strided<[1, 7], offset: ?>>) - dimension = 0 - is_stable = true - (%lhs: f32, %rhs: f32) { - %gt = arith.cmpf ogt, %lhs, %rhs: f32 - thlo.yield %gt : i1 - } - func.return -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir deleted file mode 100644 index cab86c18cdc569..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir +++ /dev/null @@ -1,179 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --allow-unregistered-dialect | \ -// RUN: mlir-hlo-opt --verify-diagnostics --split-input-file \ -// RUN: --allow-unregistered-dialect | \ -// RUN: FileCheck %s - -func.func @concatenate(%arg1: tensor, - %arg2: tensor, - %dst: tensor) -> tensor { - %cat = thlo.concatenate - ins(%arg1: tensor, %arg2: tensor) - outs(%dst: tensor) - dimension = 0 - func.return %cat : tensor -} -// CHECK-LABEL: func @concatenate - -// ----- - -func.func @concatenate_result_number(%dst: tensor) -> tensor { - %a:2 = "test.op"() : () -> (tensor, tensor) - %cat = thlo.concatenate - ins(%a#0: tensor, %a#1: tensor) - outs(%dst: tensor) - dimension = 0 - func.return %cat : tensor -} -// CHECK-LABEL: func @concatenate_result_number - -// ----- - -func.func @concatenate_memref(%arg1: memref, - %arg2: memref, - %dst: memref) { - thlo.concatenate - ins(%arg1: memref, %arg2: memref) - outs(%dst: memref) - dimension = 0 - func.return -} -// CHECK-LABEL: func @concatenate_memref - -// ----- - -func.func 
@dynamic_broadcast_in_dim(%arg: tensor, - %dst: tensor) { - %bcast = thlo.dynamic_broadcast_in_dim - ins(%arg: tensor) - outs(%dst: tensor) - broadcast_dimensions = [0, 2] - func.return -} -// CHECK-LABEL: func @dynamic_broadcast_in_dim - -// ----- - -func.func @dynamic_broadcast_in_dim_memref(%arg: memref, - %dst: memref) { - thlo.dynamic_broadcast_in_dim - ins(%arg: memref) - outs(%dst: memref) - broadcast_dimensions = [0, 2] - func.return -} -// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref - -// ----- - -func.func @gather(%arg: tensor<100xf32>, - %indices: tensor<42x1xindex>, - %dst: tensor<42xf32>) -> tensor<42xf32> { - %gather = thlo.gather - ins(%arg: tensor<100xf32>, %indices: tensor<42x1xindex>) - outs(%dst: tensor<42xf32>) - func.return %gather : tensor<42xf32> -} -// CHECK-LABEL: func @gather - -// ----- - -func.func @gather_memref(%arg: memref<100xf32>, - %indices: memref<42x1xindex>, - %dst: memref<42xf32>) { - thlo.gather - ins(%arg: memref<100xf32>, %indices: memref<42x1xindex>) - outs(%dst: memref<42xf32>) - func.return -} -// CHECK-LABEL: func @gather_memref - -// ----- - -func.func @scatter(%indices: tensor<2x2xindex>, %updates: tensor<2x1x3xf32>, - %init: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = thlo.scatter ins(%indices : tensor<2x2xindex>, - %updates : tensor<2x1x3xf32>) - outs(%init : tensor<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - return %0 : tensor<3x3xf32> -} -// CHECK-LABEL: func @scatter - -// ----- - -func.func @scatter_memref(%indices: memref<2x2xindex>, - %updates: memref<2x1x3xf32>, %init: memref<3x3xf32>) { - thlo.scatter ins(%indices : memref<2x2xindex>, %updates : memref<2x1x3xf32>) - outs(%init : memref<3x3xf32>) - (%in: f32, %out: f32) { - %sum = arith.addf %in, %out : f32 - thlo.yield %sum : f32 - } - func.return -} -// CHECK-LABEL: func @scatter_memref - -// ----- - -func.func @sort(%input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} -// CHECK-LABEL: func @sort -// CHECK: %[[RES1:sorted0]], %[[RES2:sorted1]] = thlo.sort -// CHECK: %[[LHS0:lhs0: f32]], %[[RHS0:rhs0: f32]], -// CHECK-SAME: %[[LHS1:lhs1: i32]], %[[RHS1:rhs1: i32]] - -// ----- - -func.func @sort_memref(%input1: memref, %input2: memref, - %init1: memref, %init2: memref) { - thlo.sort - ins(%input1: memref, %input2: memref) - outs(%init1: memref, %init2: memref) - dimension = 0 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return -} -// CHECK-LABEL: func @sort_memref - -// ----- - -func.func @reverse_static(%input: tensor<100xf32>, %init: tensor<100xf32>) - -> tensor<100xf32> { - %res = thlo.reverse - ins(%input: tensor<100xf32>) - outs(%init: tensor<100xf32>) - reverse_dimensions = [0] - func.return %res : tensor<100xf32> -} -// CHECK-LABEL: func @reverse_static - -// ----- - -func.func @reverse_dynamic(%input: tensor, %init: tensor) - -> tensor { - %res = thlo.reverse - ins(%input: tensor) - outs(%init: tensor) - reverse_dimensions = [0, 1] - func.return %res : tensor -} -// CHECK-LABEL: func @reverse_dynamic \ No newline at end of file diff --git 
a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/tiling.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/tiling.mlir deleted file mode 100644 index 73d0339b196211..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/thlo/tiling.mlir +++ /dev/null @@ -1,416 +0,0 @@ -// RUN: mlir-hlo-opt %s -test-hlo-transform-dialect-interpreter -cse \ -// RUN: -split-input-file | FileCheck %s - -func.func @dynamic_broadcast_in_dim_at_tile(%init : tensor, - %arg : tensor) -> tensor { - %bcast = thlo.dynamic_broadcast_in_dim ins(%arg: tensor) - outs(%init: tensor) broadcast_dimensions = [0, 2] - func.return %bcast : tensor -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.dynamic_broadcast_in_dim"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_using_for %0 [256, 512] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: @dynamic_broadcast_in_dim_at_tile -// CHECK-SAME: %[[INIT:.*]]: tensor, %[[ARG:.*]]: tensor - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-DAG: %[[C256:.*]] = arith.constant 256 -// CHECK-DAG: %[[C512:.*]] = arith.constant 512 -// CHECK-DAG: %[[INIT_DIM_0:.*]] = tensor.dim %[[INIT]], %[[C0]] -// CHECK-DAG: %[[INIT_DIM_1:.*]] = tensor.dim %[[INIT]], %[[C1]] -// CHECK-DAG: %[[INIT_DIM_2:.*]] = tensor.dim %[[INIT]], %[[C2]] -// CHECK-DAG: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[INIT_DIM_0]] step %[[C256]] iter_args(%[[INIT_ARG0:.*]] = %[[INIT]]) -// CHECK-DAG: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[I]])[%[[INIT_DIM_0]]] -// CHECK: %[[INNER_FOR:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[INIT_DIM_1]] -// CHECK-SAME: step %[[C512]] -// CHECK-SAME: iter_args(%[[OUT:.*]] = %[[INIT_ARG0]]) -// CHECK: %[[MIN_0:.*]] = affine.min #map{{[0-9]*}}(%[[J]])[%[[INIT_DIM_1]]] -// CHECK: %[[ARG_DIM_0:.*]] = tensor.dim %[[ARG]], %[[C0]] -// CHECK: %[[ARG_DIM_1:.*]] = tensor.dim %[[ARG]], %[[C1]] -// CHECK: %[[ARG_DIM_2:.*]] = tensor.dim %[[OUT]], %[[C0]] -// CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[ARG_DIM_0]], %[[ARG_DIM_2]] -// CHECK: %[[ARG_DIM_3:.*]] = tensor.dim %[[OUT]], %[[C2]] -// CHECK: %[[CMPI_0:.*]] = arith.cmpi ne, %[[ARG_DIM_1]], %[[ARG_DIM_3]] -// CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[C0]], %[[I]] -// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_0]], %[[C0]], %[[C0]] -// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI]], %[[C1]], %[[MIN]] -// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_0]], %[[C1]], %[[INIT_DIM_2]] -// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[OUT]] -// CHECK-SAME: [%[[I]], %[[J]], %[[C0]]] [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] [1, 1, 1] -// CHECK: %[[EXTRACT_0:.*]] = tensor.extract_slice %[[ARG]] -// CHECK-SAME: [%[[SELECT]], %[[SELECT_0]]] [%[[SELECT_1]], %[[SELECT_2]]] [1, 1] -// CHECK: %[[DYNAMIC:.*]] = thlo.dynamic_broadcast_in_dim -// CHECK-SAME: ins(%[[EXTRACT_0]] -// CHECK-SAME: outs(%[[EXTRACT]] -// CHECK-SAME: broadcast_dimensions = [0, 2] -// CHECK: %[[INSERTED:.*]] = tensor.insert_slice %[[DYNAMIC]] -// CHECK-SAME: into %[[OUT]][%[[I]], %[[J]], %[[C0]]] -// CHECK-SAME: [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK: scf.yield %[[INSERTED]] -// CHECK: scf.yield %[[INNER_FOR]] -// CHECK: return %[[FOR]] - -// ----- - -func.func @scatter_i64(%indices: tensor, - %updates: tensor, %init: tensor) -> tensor { - %result = thlo.scatter - ins 
(%indices: tensor, %updates: tensor) - outs (%init: tensor) - (%in: i64, %out: i64) { - %0 = arith.addi %in, %out: i64 - thlo.yield %0: i64 - } - return %result : tensor -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.scatter"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loop = transform.structured.tile_using_for %0 [1] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: func.func @scatter_i64( -// CHECK-SAME: %[[INDICES:.*]]: tensor, -// CHECK-SAME: %[[UPDATES:.*]]: tensor, -// CHECK-SAME: %[[INIT:.*]]: tensor - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[INDICES_COUNT:.*]] = tensor.dim %[[INDICES]], %c0 - -// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[INDICES_COUNT]] step %[[C1]] -// CHECK-SAME: iter_args(%[[INIT_:.*]] = %[[INIT]]) - -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[UPDATE_SUB:.*]] = tensor.extract_slice %[[UPDATES]][%[[I]] -// CHECK-SAME: : tensor -// CHECK: %[[INDICES_SUB:.*]] = tensor.extract_slice %[[INDICES]][%[[I]] -// CHECK-SAME: : tensor -// CHECK-DAG: %[[INIT_DIM_0:.*]] = tensor.dim %[[INIT_]], %[[C0]] -// CHECK-DAG: %[[INIT_DIM_1:.*]] = tensor.dim %[[INIT_]], %[[C1]] -// CHECK: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT_]][0, 0] -// CHECK-SAME: [%[[INIT_DIM_0]], %[[INIT_DIM_1]]] [1, 1] - -// CHECK: %[[SCATTER:.*]] = thlo.scatter -// CHECK-SAME: ins(%[[INDICES_SUB]] : tensor<1x2xindex>, -// CHECK-SAME: %[[UPDATE_SUB]] : tensor<1x?x?xi64>) -// CHECK-SAME: outs(%[[INIT_SUB]] : tensor) -// CHECK: arith.addi -// CHECK: thlo.yield -// CHECK: %[[INSERTED:.*]] = tensor.insert_slice %[[SCATTER]] -// CHECK-SAME: into %[[INIT_]][0, 0] -// CHECK: scf.yield %[[INSERTED:.*]] - -// ----- - -func.func @gather(%operand: tensor, %indices: tensor, - %init: tensor) -> tensor { - %result = thlo.gather - ins (%operand: tensor, %indices: tensor) - outs (%init: tensor) - return %result : tensor -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.gather"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loop = transform.structured.tile_using_for %0 [1] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: @gather -// CHECK-SAME: %[[OPERAND:.*]]: tensor -// CHECK-SAME: %[[INDICES:.*]]: tensor -// CHECK-SAME: %[[INIT:.*]]: -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 -// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 -// CHECK: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[ZERO]] to -// CHECK-SAME: (%[[INIT_:[a-z0-9]+]] = %[[INIT]]) - -// CHECK: %[[INDEX_SLICE:.*]] = tensor.extract_slice %[[INDICES]] -// CHECK-SAME: [%[[I]], 0] [1, 4] [1, 1] - -// CHECK: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT_]] -// CHECK-SAME: [%[[I]], 0] [1, 10] [1, 1] -// CHECK: %[[GATHER_SLICE:.*]] = thlo.gather -// CHECK-SAME: ins(%[[OPERAND]] : tensor, -// CHECK-SAME: %[[INDEX_SLICE]] : tensor<1x4xindex>) -// CHECK-SAME: outs(%[[INIT_SLICE]] : tensor<1x10xf64>) -// CHECK: %[[INSERTED:.*]] = tensor.insert_slice %[[GATHER_SLICE]] -// CHECK-SAME: into %[[INIT_]][%[[I]], 0] [1, 10] -// CHECK: scf.yield %[[INSERTED]] - -// ----- - -func.func @concatenate_at_tile(%init : tensor, %a: tensor, - %b: tensor, %c: tensor) - -> tensor { - %concat = thlo.concatenate - ins(%a : tensor, %b : tensor, %c : tensor) - outs(%init : tensor) - dimension = 1 - func.return %concat : tensor -} - -transform.sequence 
failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.concatenate"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_using_for %0 [256, 512] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: @concatenate_at_tile -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C256:.*]] = arith.constant 256 -// CHECK-DAG: %[[C512:.*]] = arith.constant 512 -// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[FOR:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[DIM]] step %[[C256]] iter_args(%[[INIT_:.*]] = %[[ARG0]]) -// CHECK-DAG: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[ARG4]])[%[[DIM]]] -// CHECK: %[[INNER_FOR:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[DIM_0]] step %[[C512]] -// CHECK-SAME: iter_args(%[[ARG6:.*]] = %[[INIT_]]) -// CHECK: %[[MIN_0:.*]] = affine.min #map{{[0-9]*}}(%[[ARG5]])[%[[DIM_0]]] -// CHECK: %[[DIM_4:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[MINUI:.*]] = arith.minui %[[ARG5]], %[[DIM_4]] -// CHECK: %[[SUBI:.*]] = arith.subi %[[DIM_4]], %[[MINUI]] -// CHECK: %[[MINUI_0:.*]] = arith.minui %[[SUBI]], %[[MIN_0]] -// CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG1]] -// CHECK-SAME: [%[[ARG4]], %[[MINUI]]] [%[[MIN]], %[[MINUI_0]]] [1, 1] -// CHECK: %[[CMPI:.*]] = arith.cmpi ule, %[[ARG5]], %[[DIM_4]] -// CHECK: %[[SUBI_0:.*]] = arith.subi %[[ARG5]], %[[DIM_4]] -// CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[C0]], %[[SUBI_0]] -// CHECK: %[[DIM_5:.*]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[MINUI_1:.*]] = arith.minui %[[SELECT]], %[[DIM_5]] -// CHECK: %[[SUBI_1:.*]] = arith.subi %[[DIM_5]], %[[MINUI_1]] -// CHECK: %[[MINUI_2:.*]] = arith.minui %[[SUBI_1]], %[[MIN_0]] -// CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG2]] -// CHECK-SAME: [%[[ARG4]], %[[MINUI_1]]] [%[[MIN]], %[[MINUI_2]]] [1, 1] -// CHECK: %[[CMPI_0:.*]] = arith.cmpi ule, %[[SELECT]], %[[DIM_5]] -// CHECK: %[[SUBI_2:.*]] = arith.subi %[[SELECT]], %[[DIM_5]] -// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_0]], %[[C0]], %[[SUBI_2]] -// CHECK: %[[DIM_6:.*]] = tensor.dim %[[ARG3]], %[[C1]] -// CHECK: %[[MINUI_3:.*]] = arith.minui %[[SELECT_0]], %[[DIM_6]] -// CHECK: %[[SUBI_3:.*]] = arith.subi %[[DIM_6]], %[[MINUI_3]] -// CHECK: %[[MINUI_4:.*]] = arith.minui %[[SUBI_3]], %[[MIN_0]] -// CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[ARG3]] -// CHECK-SAME: [%[[ARG4]], %[[MINUI_3]]] [%[[MIN]], %[[MINUI_4]]] [1, 1] -// CHECK: %[[MATERIALIZE_2:.*]] = tensor.extract_slice %[[ARG6]] -// CHECK: [%[[ARG4]], %[[ARG5]]] [%[[MIN]], %[[MIN_0]]] [1, 1] -// CHECK: %[[CONCATENATE:.*]] = thlo.concatenate -// CHECK-SAME: ins(%[[MATERIALIZE]] : tensor, %[[MATERIALIZE_0]] : tensor, %[[MATERIALIZE_1]] : tensor) -// CHECK-SAME: outs(%[[MATERIALIZE_2]] : tensor) -// CHECK-SAME: dimension = 1 -// CHECK: %[[INSERTED:.*]] = tensor.insert_slice %[[CONCATENATE]] -// CHECK-SAME: into %[[ARG6]][%[[ARG4]], %[[ARG5]]] [%[[MIN]], %[[MIN_0]]] [1, 1] -// CHECK: scf.yield %[[INSERTED]] -// CHECK: return %[[FOR]] - -// CHECK-PARALLEL-LABEL: @concatenate_at_tile - -// ----- - -func.func @sort(%input1: tensor, %input2: tensor, - %init1: tensor, %init2: tensor) - -> (tensor, tensor) { - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, 
%input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 1 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return %sorted1, %sorted2 : tensor, tensor -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.sort"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_using_for %0 [256, 512] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: func.func @sort -// CHECK-SAME: (%[[IN0:[a-zA-Z_0-9]*]]: tensor, -// CHECK-SAME: %[[IN1:[a-zA-Z_0-9]*]]: tensor, -// CHECK-SAME: %[[INIT0:[a-zA-Z_0-9]*]]: tensor, -// CHECK-SAME: %[[INIT1:[a-zA-Z_0-9]*]]: tensor) -// CHECK-DAG: %[[C0:[a-zA-Z_0-9]*]] = arith.constant 0 -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[INIT0]], %[[C0]] -// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[INIT0]], %[[C2]] -// CHECK: scf.for -// CHECK-SAME: %[[START0:.*]] = %[[C0]] to %[[DIM0]] -// CHECK-SAME: iter_args(%[[INIT0_OUTER:.*]] = %[[INIT0]], -// CHECK-SAME: %[[INIT1_OUTER:.*]] = %[[INIT1]]) -// CHECK-DAG: %[[TILE_SIZE0:.*]] = affine.min #map{{[0-9]*}}(%[[START0]])[%[[DIM0]]] -// CHECK: scf.for -// CHECK-SAME: %[[START2:.*]] = %[[C0]] to %[[DIM2]] -// CHECK-SAME: iter_args(%[[INIT0_:.*]] = %[[INIT0_OUTER]], -// CHECK-SAME: %[[INIT1_:.*]] = %[[INIT1_OUTER]]) -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[TILE_SIZE2:.*]] = affine.min #map{{[0-9]*}}(%[[START2]])[%[[DIM2]]] -// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[IN0]], %[[C1]] -// CHECK-DAG: %[[IN0_SUB:.*]] = tensor.extract_slice %[[IN0]] -// CHECK-SAME: [%[[START0]], 0, %[[START2]]] -// CHECK-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK-DAG: %[[IN1_SUB:.*]] = tensor.extract_slice %[[IN1]] -// CHECK-SAME: [%[[START0]], 0, %[[START2]]] -// CHECK-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK-DAG: %[[INIT0_SUB:.*]] = tensor.extract_slice %[[INIT0_]] -// CHECK-SAME: [%[[START0]], 0, %[[START2]]] -// CHECK-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK-DAG: %[[INIT1_SUB:.*]] = tensor.extract_slice %[[INIT1_]] -// CHECK-SAME: [%[[START0]], 0, %[[START2]]] -// CHECK-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK: %[[SORTED0:.*]], %[[SORTED1:.*]] = thlo.sort -// CHECK-SAME: ins(%[[IN0_SUB]] : tensor, %[[IN1_SUB]] : tensor) -// CHECK-SAME: outs(%[[INIT0_SUB]] : tensor, %[[INIT1_SUB]] : tensor) -// CHECK: %[[INSERTED0:.*]] = tensor.insert_slice %[[SORTED0]] -// CHECK-SAME: %[[INIT0_]][%[[START0]], 0, %[[START2]]] -// CHECK-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK: %[[INSERTED1:.*]] = tensor.insert_slice %[[SORTED1]] -// CHECK-SAME: %[[INIT1_]][%[[START0]], 0, %[[START2]]] -// CHECK-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] -// CHECK-SAME: [1, 1, 1] -// CHECK: scf.yield %[[INSERTED0]], %[[INSERTED1]] - -// ----- - -func.func @sort2(%input1: tensor<1024x2048x4096xf32>, - %input2: tensor<1024x2048x4096xi32>, - %init1: tensor<1024x2048x4096xf32>, - %init2: tensor<1024x2048x4096xi32>) - -> (tensor<1024x2048x4096xf32>, tensor<1024x2048x4096xi32>) { - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor<1024x2048x4096xf32>, - %input2: tensor<1024x2048x4096xi32>) - outs(%init1: tensor<1024x2048x4096xf32>, - %init2: 
tensor<1024x2048x4096xi32>) - dimension = 1 - is_stable = true - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } - func.return - %sorted1, %sorted2 : tensor<1024x2048x4096xf32>, tensor<1024x2048x4096xi32> -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.sort"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_using_for %0 [256, 512] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: func.func @sort2 - -// ----- - -func.func @reverse_static(%input: tensor<100xf32>, %init: tensor<100xf32>) - -> tensor<100xf32> { - %res = thlo.reverse - ins(%input: tensor<100xf32>) - outs(%init: tensor<100xf32>) - reverse_dimensions = [0] - func.return %res : tensor<100xf32> -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.reverse"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loop = transform.structured.tile_using_for %0 [10] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: func @reverse_static -// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf32>, %[[ARG1:.*]]: tensor<100xf32> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK-DAG: %[[C100:.*]] = arith.constant 100 : index -// CHECK: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[ARG3:.*]] = %[[ARG1]]) -// CHECK: %[[TEMP_SUB_RES:.*]] = arith.subi %[[C100]], %[[I]] -// CHECK: %[[IN_TILE_DIM:.*]] = arith.subi %[[TEMP_SUB_RES]], %[[C10]] -// CHECK-DAG: %[[IN_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IN_TILE_DIM]]] -// CHECK-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[ARG3]][%[[I]]] -// CHECK: %[[REVERSED:.*]] = thlo.reverse ins(%[[IN_SLICE]] -// CHECK: outs(%[[INIT_SLICE]] -// CHECK: %[[INSERTED:.*]] = tensor.insert_slice %[[REVERSED]] into %[[ARG3]][%[[I]] -// CHECK: scf.yield %[[INSERTED]] -// CHECK: return %[[FOR]] - -// ----- - -func.func @reverse_dynamic(%input: tensor, %init: tensor) - -> tensor { - %res = thlo.reverse - ins(%input: tensor) - outs(%init: tensor) - reverse_dimensions = [0, 1] - func.return %res : tensor -} - -transform.sequence failures(propagate) { - ^bb0(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["thlo.reverse"]} in %arg1 - : (!pdl.operation) -> !pdl.operation - %1, %loops:2 = transform.structured.tile_using_for %0 [256, 512] - : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) -} - -// CHECK-LABEL: func @reverse_dynamic( -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] -// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] -// CHECK-SAME: iter_args(%[[ARG4:.*]] = %[[ARG1]]) -// CHECK-DAG: %[[AFFINE_MIN1:.*]] = affine.min -// CHECK: %[[INNER_FOR:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[DIM0]] -// CHECK-SAME: iter_args(%[[INIT_:.*]] = %[[ARG4]]) -// CHECK-DAG: %[[AFFINE_MIN2:.*]] = affine.min -// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[TEMP_SUB_RES0:.*]] = arith.subi %[[DIM1]], %[[I]] -// CHECK-DAG: %[[IN_TILE_DIM0:.*]] = arith.subi %[[TEMP_SUB_RES0]], 
%[[AFFINE_MIN1]] -// CHECK-DAG: %[[TEMP_SUB_RES1:.*]] = arith.subi %[[DIM2]], %[[J]] -// CHECK-DAG: %[[IN_TILE_DIM1:.*]] = arith.subi %[[TEMP_SUB_RES1]], %[[AFFINE_MIN2]] -// CHECK-DAG: %[[IN_SLICE:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK-SAME: [%[[IN_TILE_DIM0]], %[[IN_TILE_DIM1]]] -// CHECK-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT_]] -// CHECK-SAME: [%[[I]], %[[J]]] -// CHECK: %[[REVERSED:.*]] = thlo.reverse ins(%[[IN_SLICE]] -// CHECK-SAME: outs(%[[INIT_SLICE]] -// CHECK: %[[INSERTED:.*]] = tensor.insert_slice %[[REVERSED]] -// CHECK-SAME: into %[[INIT_]][%[[I]], %[[J]] -// CHECK: scf.yield %[[INSERTED]] -// CHECK: return %[[FOR]] diff --git a/third_party/xla/xla/mlir_hlo/tests/scalarization.mlir b/third_party/xla/xla/mlir_hlo/tests/scalarization.mlir deleted file mode 100644 index d2e644a9c16dc8..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/scalarization.mlir +++ /dev/null @@ -1,586 +0,0 @@ -// RUN: mlir-hlo-opt %s --scalarize --split-input-file | FileCheck %s - -#map = affine_map<() -> ()> - -func.func @zero_rank(%lhs: tensor, %rhs: tensor) -> tensor { - %0 = tensor.empty() : tensor - %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} - ins(%lhs, %rhs: tensor, tensor) - outs(%0: tensor) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %2 = arith.addf %arg3, %arg4: f32 - linalg.yield %2: f32 - } -> tensor - return %1: tensor -} -// CHECK-LABEL: func @zero_rank -// CHECK-SAME: (%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) -// CHECK-DAG: %[[LHS_VAL:.*]] = tensor.extract %[[LHS]] -// CHECK-DAG: %[[RHS_VAL:.*]] = tensor.extract %[[RHS]] -// CHECK: %[[RES:.*]] = arith.addf %[[LHS_VAL]], %[[RHS_VAL]] -// CHECK: %[[NEW_TENSOR_RES:.*]] = tensor.from_elements %[[RES]] -// CHECK: return %[[NEW_TENSOR_RES]] - -// ----- - -func.func @linalg_index(%arg0: tensor<1xf64>) -> tensor<1xf64> { - %0 = tensor.empty() : tensor<1xf64> - %1 = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - outs(%0 : tensor<1xf64>) { - ^bb0(%arg1: f64): - %2 = linalg.index 0 : index - %3 = tensor.extract %arg0[%2] : tensor<1xf64> - linalg.yield %3 : f64 - } -> tensor<1xf64> - return %1 : tensor<1xf64> -} -// CHECK-LABEL: func @linalg_index -// CHECK-SAME: (%[[ARG:.*]]: tensor<1xf64>) -// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 -// CHECK-NEXT: %[[ELEM:.*]] = tensor.extract %[[ARG]][%[[C0]]] -// CHECK-NEXT: tensor.from_elements %[[ELEM]] - -// ----- - - -func.func @nonzero_rank(%lhs: tensor<1xf32>, %rhs: tensor<1x1xf32>) - -> tensor<1x1x1xf32> { - %0 = tensor.empty() : tensor<1x1x1xf32> - %1 = linalg.generic {indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0)>, - affine_map<(d0, d1, d2) -> (d0, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%lhs, %rhs: tensor<1xf32>, tensor<1x1xf32>) - outs(%0: tensor<1x1x1xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %2 = arith.addf %arg3, %arg4: f32 - linalg.yield %2: f32 - } -> tensor<1x1x1xf32> - return %1: tensor<1x1x1xf32> -} -// CHECK-LABEL: func @nonzero_rank -// CHECK-SAME: (%[[LHS:.*]]: tensor<1xf32>, %[[RHS:.*]]: tensor<1x1xf32>) -// CHECK-DAG: %[[LHS_VAL:.*]] = tensor.extract %[[LHS]] -// CHECK-DAG: %[[RHS_VAL:.*]] = tensor.extract %[[RHS]] -// CHECK: %[[RES:.*]] = arith.addf %[[LHS_VAL]], %[[RHS_VAL]] -// CHECK: %[[NEW_TENSOR_RES:.*]] = tensor.from_elements %[[RES]] -// CHECK: return %[[NEW_TENSOR_RES]] - -// ----- - -#map = affine_map<() -> ()> - -func.func @op_sequence(%lhs: tensor, %rhs: tensor) -> 
tensor { - %0 = tensor.empty() : tensor - %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} - ins(%lhs, %rhs: tensor, tensor) - outs(%0: tensor) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %2 = arith.addf %arg3, %arg4: f32 - linalg.yield %2: f32 - } -> tensor - - %3 = tensor.empty() : tensor - %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} - ins(%lhs, %1: tensor, tensor) - outs(%3: tensor) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %5 = arith.mulf %arg3, %arg4: f32 - linalg.yield %5: f32 - } -> tensor - - %6 = tensor.empty() : tensor - %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} - ins(%1, %4: tensor, tensor) - outs(%6: tensor) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %5 = arith.divf %arg3, %arg4: f32 - linalg.yield %5: f32 - } -> tensor - - return %7: tensor -} -// CHECK-LABEL: func @op_sequence -// CHECK-SAME: (%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) -// CHECK-DAG: %[[LHS_VAL:.*]] = tensor.extract %[[LHS]] -// CHECK-DAG: %[[RHS_VAL:.*]] = tensor.extract %[[RHS]] -// CHECK: %[[RES:.*]] = arith.addf %[[LHS_VAL]], %[[RHS_VAL]] -// CHECK-DAG: %[[LHS_VAL_:.*]] = tensor.extract %[[LHS]] -// CHECK: %[[RES2:.*]] = arith.mulf %[[LHS_VAL_]], %[[RES]] -// CHECK: %[[RES3:.*]] = arith.divf %[[RES]], %[[RES2]] -// CHECK: %[[NEW_TENSOR_RES:.*]] = tensor.from_elements %[[RES3]] -// CHECK: return %[[NEW_TENSOR_RES]] - -// ----- - -#map = affine_map<() -> ()> - -func.func @multiple_ops(%lhs: tensor, %rhs: tensor) -> tensor { - %0 = tensor.empty() : tensor - %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} - ins(%lhs, %rhs: tensor, tensor) - outs(%0: tensor) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %2 = arith.addf %arg3, %arg4: f32 - %3 = arith.mulf %2, %arg4: f32 - linalg.yield %3: f32 - } -> tensor - return %1: tensor -} -// CHECK-LABEL: func @multiple_ops -// CHECK-SAME: (%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) -// CHECK-DAG: %[[LHS_VAL:.*]] = tensor.extract %[[LHS]] -// CHECK-DAG: %[[RHS_VAL:.*]] = tensor.extract %[[RHS]] -// CHECK: %[[RES:.*]] = arith.addf %[[LHS_VAL]], %[[RHS_VAL]] -// CHECK: %[[RES2:.*]] = arith.mulf %[[RES]], %[[RHS_VAL]] -// CHECK: %[[NEW_TENSOR_RES:.*]] = tensor.from_elements %[[RES2]] -// CHECK: return %[[NEW_TENSOR_RES]] - -// ----- - -func.func @outside_yield() -> tensor<1x1xi1> { - %true = arith.constant true - %0 = tensor.empty() : tensor<1x1xi1> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - outs(%0 : tensor<1x1xi1>) { - ^bb0(%arg1: i1): - linalg.yield %true : i1 - } -> tensor<1x1xi1> - return %1: tensor<1x1xi1> -} - -// CHECK-LABEL: func @outside_yield -// CHECK: %[[CST:.*]] = arith.constant dense : tensor<1x1xi1> -// CHECK: return %[[CST]] - -// ----- - -#map0 = affine_map<(d0) -> ()> -#map1 = affine_map<(d0) -> (d0)> -func.func @extra_argument(%arg0: tensor<4xf64>, %arg2: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f64 - %0 = tensor.empty() : tensor - %1 = linalg.fill ins(%cst : f64) outs(%0 : tensor) -> tensor - %2 = linalg.generic { - indexing_maps = [affine_map<(d0) -> ()>, - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> ()>], - iterator_types = ["reduction"]} - ins(%arg2, %arg0 : tensor, tensor<4xf64>) outs(%1 : tensor) { - ^bb0(%arg3: i1, %arg4: f64, %arg5: f64): - %3 = arith.cmpf une, %arg4, %arg4 : f64 - %4 = arith.select %3, %cst, %arg4 : f64 - %5 = arith.select %arg3, %4, %cst : f64 - %6 = arith.addf %arg5, %5 : f64 - 
linalg.yield %6 : f64 - } -> tensor - return %2 : tensor -} - -// CHECK-LABEL: func @extra_argument - -// ----- - -func.func @scatter_f32_with_update_computation(%indices: tensor<1x2xindex>, - %updates: tensor<1x?x?xf32>, %init: tensor) -> tensor { - %0 = thlo.scatter ins(%indices: tensor<1x2xindex>, %updates: tensor<1x?x?xf32>) - outs(%init: tensor) - (%in: f32, %out: f32) { - %1 = arith.addf %in, %out: f32 - thlo.yield %1: f32 - } - return %0: tensor -} -// CHECK-LABEL: func.func @scatter_f32_with_update_computation( -// CHECK-SAME: %[[INDICES:.*]]: tensor<1x2xindex>, -// CHECK-SAME: %[[UPDATES:.*]]: tensor<1x?x?xf32>, -// CHECK-SAME: %[[INIT:.*]]: tensor) -> tensor { - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 - -// CHECK-DAG: %[[UPDATES_DIM_1:.*]] = tensor.dim %[[UPDATES]], %[[C1]] -// CHECK-DAG: %[[UPDATES_DIM_2:.*]] = tensor.dim %[[UPDATES]], %[[C2]] -// CHECK-DAG: %[[INIT_DIM_0:.*]] = tensor.dim %[[INIT]], %[[C0]] -// CHECK-DAG: %[[INIT_DIM_1:.*]] = tensor.dim %[[INIT]], %[[C1]] - -// Extract scatter indices from `indices` arg. -// CHECK-DAG: %[[INDEX_0:.*]] = tensor.extract %[[INDICES]][%[[C0]], -// CHECK-DAG: %[[INDEX_1:.*]] = tensor.extract %[[INDICES]][%[[C0]], - -// CHECK-COUNT-7: arith.andi - -// CHECK: scf.if -// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract_slice %[[INIT]][%[[INDEX_0]], -// CHECK-SAME: %[[INDEX_1]]] [%[[UPDATES_DIM_1]], %[[UPDATES_DIM_2]]] [%[[C1]], -// CHECK-SAME: %[[C1]]] : tensor to tensor - -// CHECK-NEXT: %[[UPDATES_SLICE:.*]] = tensor.extract_slice %[[UPDATES]] - -// CHECK-NEXT: %[[SUM:.*]] = linalg.reduce -// CHECK-SAME: ins(%[[UPDATES_SLICE]] : tensor<1x?x?xf32>) -// CHECK-SAME: outs(%[[EXTRACTED]] : tensor) dimensions = [0] -// CHECK-NEXT: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) { -// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 -// CHECK-NEXT: linalg.yield %[[ADD]] : f32 -// CHECK-NEXT: } - -// CHECK-NEXT: %[[INSERTED:.*]] = tensor.insert_slice %[[SUM]] into %[[INIT]] -// CHECK-SAME: [%[[INDEX_0]], %[[INDEX_1]]] [%[[UPDATES_DIM_1]], -// CHECK-SAME: %[[UPDATES_DIM_2]]] [%[[C1]], %[[C1]]] : tensor -// CHECK-SAME: into tensor -// CHECK-NEXT: scf.yield %[[INSERTED]] : tensor -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[INIT]] : tensor -// CHECK-NEXT: } -// CHECK-NEXT: return - -// ----- - -func.func @scatter_i64_no_update_computation(%indices: tensor<1x1xindex>, - %updates: tensor<1x1x3x4xi64>, - %init: tensor<3x3x4xi64>) -> tensor<3x3x4xi64> { - %0 = thlo.scatter ins(%indices : tensor<1x1xindex>, - %updates : tensor<1x1x3x4xi64>) - outs(%init : tensor<3x3x4xi64>) - (%arg5: i64, %arg6: i64) { - thlo.yield %arg5 : i64 - } - func.return %0 : tensor<3x3x4xi64> -} -// CHECK-LABEL: func.func @scatter_i64_no_update_computation( -// CHECK-SAME: %[[INDICES:.*]]: tensor<1x1xindex>, -// CHECK-SAME: %[[UPDATES:.*]]: tensor<1x1x3x4xi64>, -// CHECK-SAME: %[[INIT:.*]]: tensor<3x3x4xi64>) -> tensor<3x3x4xi64> { - -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C3:.*]] = arith.constant 3 : index - -// CHECK: %[[INDEX_0:.*]] = tensor.extract %[[INDICES]]{{\[}}%[[C0]], -// CHECK-SAME: %[[C0]]] : tensor<1x1xindex> - -// CHECK: scf.if -// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[UPDATES]] -// CHECK-SAME: [0, 1], [2], [3]] -// CHECK-SAME: : tensor<1x1x3x4xi64> into tensor<1x3x4xi64> -// CHECK-NEXT: %[[INSERTED:.*]] = tensor.insert_slice %[[COLLAPSED]] into -// CHECK-SAME: 
%[[INIT]][%[[INDEX_0]], %[[C0]], %[[C0]]] [1, 3, 4] -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]] -// CHECK-SAME: : tensor<1x3x4xi64> into tensor<3x3x4xi64> -// CHECK-NEXT: scf.yield %[[INSERTED]] : tensor<3x3x4xi64> -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[INIT]] : tensor<3x3x4xi64> -// CHECK-NEXT: } -// CHECK-NEXT: return - -// ----- - -func.func @gather(%indices: tensor<1x2xindex>, - %operand: tensor<5x6x7xi64>, - %init: tensor<1x3xi64>) -> tensor<1x3xi64> { - %0 = thlo.gather ins(%operand : tensor<5x6x7xi64>, - %indices : tensor<1x2xindex>) - outs(%init : tensor<1x3xi64>) - func.return %0 : tensor<1x3xi64> -} - -// CHECK-LABEL: func.func @gather( -// CHECK-SAME: %[[INDICES:.*]]: tensor<1x2xindex> -// CHECK-SAME: %[[OPERAND:.*]]: tensor<5x6x7xi64> -// CHECK-SAME: %[[INIT:.*]]: tensor<1x3xi64> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK-DAG: %[[C5:.*]] = arith.constant 5 -// CHECK-DAG: %[[INDEX0:.*]] = tensor.extract %[[INDICES]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[INDEX1:.*]] = tensor.extract %[[INDICES]][%[[C0]], %[[C1]]] -// CHECK-DAG: %[[CLAMPED_INDEX0:.*]] = arith.minsi %[[INDEX0]], %[[C2]] -// CHECK-DAG: %[[CLAMPED_INDEX0_:.*]] = arith.maxsi %[[CLAMPED_INDEX0]], %[[C0]] -// CHECK-DAG: %[[CLAMPED_INDEX1:.*]] = arith.minsi %[[INDEX1]], %[[C5]] -// CHECK-DAG: %[[CLAMPED_INDEX1_:.*]] = arith.maxsi %[[CLAMPED_INDEX1]], %[[C0]] -// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C3]] -// CHECK: %[[OFFSET_J:.*]] = arith.addi %[[J]], %[[CLAMPED_INDEX0_]] - -// CHECK: %[[VAL:.*]] = tensor.extract %[[OPERAND]] -// CHECK-SAME: [%[[OFFSET_J]], %[[CLAMPED_INDEX1_]], %[[C0]]] -// CHECK-NEXT: %[[UPDATED:.*]] = tensor.insert %[[VAL]] -// CHECK: scf.yield %[[UPDATED]] - -// ----- - -func.func @fold_from_elements_into_insert_slice(%elem: f32, - %out: tensor<8x2xf32>) -> tensor<8x2xf32> { - %elem_tensor = tensor.from_elements %elem : tensor<1x1xf32> - %updated = tensor.insert_slice %elem_tensor into %out[0, 1] [1, 1] [1, 1] - : tensor<1x1xf32> into tensor<8x2xf32> - - func.return %updated: tensor<8x2xf32> -} -// CHECK-LABEL: func @fold_from_elements_into_insert_slice -// CHECK-SAME: %[[ELEM:.*]]: f32, %[[OUT:.*]]: tensor<8x2xf32> - -// CHECK: %[[UPDATE:.*]] = tensor.insert %[[ELEM]] into %[[OUT]] -// CHECK-NEXT: return %[[UPDATE]] - -// ----- - -func.func @dynamic_broadcast_in_dim(%arg : tensor<1x1xf32>, - %init: tensor<1x1x1xf32>) - -> tensor<1x1x1xf32> { - %0 = thlo.dynamic_broadcast_in_dim ins(%arg : tensor<1x1xf32>) - outs(%init : tensor<1x1x1xf32>) - broadcast_dimensions = [0, 2] - func.return %0 : tensor<1x1x1xf32> -} -// CHECK-LABEL: @dynamic_broadcast_in_dim( -// CHECK-SAME: %[[ARG:.*]]: tensor<1x1xf32>, %[[INIT:.*]]: tensor<1x1x1xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[ELEM:.*]] = tensor.extract %[[ARG]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[UPDATED:.*]] = tensor.from_elements %[[ELEM]] - -// ----- - -func.func @concatenate( - %arg0: tensor, %arg1: tensor, - %arg2: tensor, %init: tensor) -> tensor { - %cat = thlo.concatenate - ins(%arg0: tensor, - %arg1: tensor, - %arg2: tensor) - outs(%init: tensor) - dimension = 1 - func.return %cat : tensor -} - -// CHECK-LABEL: func @concatenate( -// CHECK-SAME: %[[ARG_0:[0-9a-zA-Z]*]]: tensor, -// CHECK-SAME: %[[ARG_1:[0-9a-zA-Z]*]]: tensor, -// CHECK-SAME: %[[ARG_2:[0-9a-zA-Z]*]]: tensor, -// CHECK-SAME: %[[INIT:[0-9a-zA-Z]*]]: tensor) - -// CHECK-DAG: %[[C0:.*]] = 
arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 - -// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]] -// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]] - - -// Extract elements from arg0 is it's not empty. -// CHECK-NEXT: %[[DIM_ARG_0:.*]] = tensor.dim %[[ARG_0]], %[[C1]] -// CHECK-NEXT: %[[CMP_0:.*]] = arith.cmpi ne, %[[DIM_ARG_0]], %[[C0]] -// CHECK: %[[RESULT:.*]] = scf.if %[[CMP_0]] -// CHECK: %[[MAT_0:.*]] = tensor.extract_slice %[[ARG_0]] -// CHECK-SAME: [0, 0, 0] [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] -// CHECK: %[[RES_0:.*]] = tensor.insert_slice %[[MAT_0]] into %[[INIT]] -// CHECK-NEXT: scf.yield %[[RES_0]] -// CHECK-NEXT: } else { - -// Else check arg1 and extracts element if it's not empty. -// CHECK-NEXT: %[[DIM_ARG_1:.*]] = tensor.dim %[[ARG_1]], %[[C1]] -// CHECK-NEXT: %[[CMP_1:.*]] = arith.cmpi ne, %[[DIM_ARG_1]], %[[C0]] -// CHECK-NEXT: %[[RESULT_1:.*]] = scf.if %[[CMP_1]] -// CHECK-NEXT: %[[MAT_1:.*]] = tensor.extract_slice %[[ARG_1]] -// CHECK-SAME: [0, 0, 0] [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] -// CHECK-NEXT: %[[RES_1:.*]] = tensor.insert_slice %[[MAT_1]] into %[[INIT]] -// CHECK-NEXT: scf.yield %[[RES_1]] -// CHECK-NEXT: } else { - -// Otherwise extract elements from arg2, because arg0 and arg1 are empty. -// CHECK-NEXT: %[[MAT_2:.*]] = tensor.extract_slice %[[ARG_2]] -// CHECK-SAME: [0, 0, 0] [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] -// CHECK-NEXT: %[[RES_2:.*]] = tensor.insert_slice %[[MAT_2]] into %[[INIT]] -// CHECK-NEXT: scf.yield %[[RES_2]] -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[RESULT_1]] -// CHECK-NEXT: } - -// CHECK-NEXT: return %[[RESULT]] : tensor - -// ----- - -func.func @linalg_map(%lhs : tensor<1x1xf32>, - %rhs: tensor<1x1xf32>, - %init: tensor<1x1xf32>) - -> tensor<1x1xf32> { - %add = linalg.map - ins(%lhs, %rhs : tensor<1x1xf32>, tensor<1x1xf32>) - outs(%init: tensor<1x1xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { - %0 = arith.addf %lhs_elem, %rhs_elem: f32 - linalg.yield %0: f32 - } - func.return %add : tensor<1x1xf32> -} - -// CHECK-LABEL: @linalg_map( -// CHECK-SAME: %[[LHS:.*]]: tensor<1x1xf32>, %[[RHS:.*]]: tensor<1x1xf32>, %[[INIT:.*]]: tensor<1x1xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[L_ELEM:.*]] = tensor.extract %[[LHS]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[R_ELEM:.*]] = tensor.extract %[[RHS]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[L_ELEM]], %[[R_ELEM]] -// CHECK-NEXT: tensor.from_elements %[[ADD]] - -// ----- - -func.func @linalg_reduce(%ins: tensor<1x1x1xf32>, - %outs: tensor<1x1xf32>) - -> tensor<1x1xf32> { - %reduce = linalg.reduce - ins(%ins: tensor<1x1x1xf32>) - outs(%outs: tensor<1x1xf32>) - dimensions = [1] - (%in: f32, %out: f32) { - %0 = arith.addf %in, %out: f32 - linalg.yield %0: f32 - } - func.return %reduce : tensor<1x1xf32> -} - -// CHECK-LABEL: @linalg_reduce( -// CHECK-SAME: %[[INS:.*]]: tensor<1x1x1xf32>, %[[OUTS:.*]]: tensor<1x1xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[L_ELEM:.*]] = tensor.extract %[[INS]][%[[C0]], %[[C0]], %[[C0]]] -// CHECK-NEXT: %[[R_ELEM:.*]] = tensor.extract %[[OUTS]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[L_ELEM]], %[[R_ELEM]] -// CHECK-NEXT: tensor.from_elements %[[ADD]] - -// ----- - -func.func @linalg_transpose(%ins: tensor<1x1xf32>, - %outs: tensor<1x1xf32>) - -> tensor<1x1xf32> { - %transpose = linalg.transpose - ins(%ins: tensor<1x1xf32>) - outs(%outs: tensor<1x1xf32>) - permutation = [1, 0] - func.return 
%transpose : tensor<1x1xf32> -} - -// CHECK-LABEL: @linalg_transpose( -// CHECK-SAME: %[[INS:.*]]: tensor<1x1xf32>, %[[OUTS:.*]]: tensor<1x1xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[INS]][%[[C0]], %[[C0]]] -// CHECK-NEXT: tensor.from_elements %[[EXTRACTED]] - -// ----- - -func.func @linalg_matmul(%lhs: tensor<1x1xf32>, - %rhs: tensor<1x1xf32>, - %out : tensor<1x1xf32>) -> tensor<1x1xf32> { - %0 = linalg.matmul - ins(%lhs, %rhs : tensor<1x1xf32>, tensor<1x1xf32>) - outs(%out : tensor<1x1xf32>) -> tensor<1x1xf32> - return %0 : tensor<1x1xf32> -} - -// CHECK-LABEL: @linalg_matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor<1x1xf32>, %[[RHS:.*]]: tensor<1x1xf32>, %[[OUT:.*]]: tensor<1x1xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[LHS_ELEM:.*]] = tensor.extract %[[LHS]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[RHS_ELEM:.*]] = tensor.extract %[[RHS]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[OUT_ELEM:.*]] = tensor.extract %[[OUT]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[LHS_ELEM]], %[[RHS_ELEM]] -// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT_ELEM]], %[[MUL]] -// CHECK-NEXT: tensor.from_elements %[[ADD]] - -// ----- - -func.func @thlo_reverse(%arg : tensor<1x1xf32>, %init: tensor<1x1xf32>) - -> tensor<1x1xf32> { - %0 = thlo.reverse ins(%arg : tensor<1x1xf32>) - outs(%init : tensor<1x1xf32>) - reverse_dimensions = [0, 1] - func.return %0 : tensor<1x1xf32> -} - -// CHECK-LABEL: @thlo_reverse( -// CHECK-SAME: %[[ARG:.*]]: tensor<1x1xf32>, %[[INIT:.*]]: tensor<1x1xf32>) -// CHECK: return %[[ARG]] - -// ----- - -func.func @ite_1d(%arg0: i1, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) - -> tensor<1xf32> { - %0 = scf.if %arg0 -> (tensor<1xf32>) { - scf.yield %arg2 : tensor<1xf32> - } else { - scf.yield %arg1 : tensor<1xf32> - } - return %0 : tensor<1xf32> -} - -// CHECK: func.func @ite_1d(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<1xf32>, %[[ARG2:.*]]: tensor<1xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (f32) -// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG2]][%[[C0]]] -// CHECK: scf.yield %[[EXTRACTED]] : f32 -// CHECK: else -// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG1]][%[[C0]]] -// CHECK: scf.yield %[[EXTRACTED_0]] : f32 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[IF]] -// CHECK: return %[[FROM_ELEMENTS]] - -// ----- - -func.func @ite_2d(%arg0: i1, %arg1: tensor<1x1xf32>, %arg2: tensor<1x1xf32>) - -> tensor<1x1xf32> { - %0 = scf.if %arg0 -> (tensor<1x1xf32>) { - scf.yield %arg2 : tensor<1x1xf32> - } else { - scf.yield %arg1 : tensor<1x1xf32> - } - return %0 : tensor<1x1xf32> -} - -// CHECK: func.func @ite_2d(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<1x1xf32>, %[[ARG2:.*]]: tensor<1x1xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (f32) -// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG2]][%[[C0]], %[[C0]]] -// CHECK: scf.yield %[[EXTRACTED]] : f32 -// CHECK: else -// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG1]][%[[C0]], %[[C0]]] -// CHECK: scf.yield %[[EXTRACTED_0]] : f32 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[IF]] -// CHECK: return %[[FROM_ELEMENTS]] - - -// ----- - -func.func @scalarize_for_op(%initValue: f32, %input: tensor<10xf32>) -> f32 { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - - %initTensor = tensor.from_elements %initValue : tensor<1x1xf32> - - %sum = scf.for 
%i = %c0 to %c10 step %c1 - iter_args(%acc = %initTensor) -> (tensor<1x1xf32>) { - %input_elem = tensor.extract %input[%i] : tensor<10xf32> - - %acc_elem = tensor.extract %acc[%c0, %c0] : tensor<1x1xf32> - %add = arith.addf %acc_elem, %input_elem : f32 - %from_elements = tensor.from_elements %add : tensor<1x1xf32> - - scf.yield %from_elements : tensor<1x1xf32> - } - %sum_elem = tensor.extract %sum[%c0, %c0] : tensor<1x1xf32> - func.return %sum_elem : f32 -} -// CHECK-LABEL: @scalarize_for_op - -// CHECK: scf.for %[[I:[a-z0-9]+]] = -// CHECK-NEXT: %[[ELEM:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<10xf32> -// CHECK-NEXT: %[[ADD:.*]] = arith.addf %{{.*}}, %[[ELEM]] : f32 -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: return - - diff --git a/third_party/xla/xla/mlir_hlo/thlo/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/thlo/CMakeLists.txt deleted file mode 100644 index c3f701100ccab5..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_subdirectory(IR) -add_subdirectory(transforms) -add_subdirectory(interfaces) diff --git a/third_party/xla/xla/mlir_hlo/thlo/IR/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/thlo/IR/CMakeLists.txt deleted file mode 100644 index 5a4f76478c2e78..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/IR/CMakeLists.txt +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -set(LLVM_TARGET_DEFINITIONS thlo_ops.td) -mlir_tablegen(thlo_ops.h.inc -gen-op-decls) -mlir_tablegen(thlo_ops.cc.inc -gen-op-defs) -mlir_tablegen(thlo_dialect.h.inc -gen-dialect-decls) -mlir_tablegen(thlo_dialect.cc.inc -gen-dialect-defs) - -add_public_tablegen_target(MLIRthlo_opsIncGen) -add_dependencies(mlir-headers MLIRthlo_opsIncGen) - - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_dialect_library(THLODialect - thlo_ops.cc - - DEPENDS - MLIRthlo_opsIncGen - - LINK_LIBS PUBLIC - GmlStDialect - MLIRDestinationStyleOpInterface - MLIRIR - MLIRMemRefDialect - MLIRSideEffectInterfaces - MLIRSupport - MLIRTensorDialect -) diff --git a/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.cc b/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.cc deleted file mode 100644 index ed83f7123fcc80..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.cc +++ /dev/null @@ -1,1363 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "thlo/IR/thlo_ops.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Transforms/InliningUtils.h" - -namespace mlir { -namespace { - -Value materializeSlice(OpBuilder &b, Location loc, Value valueToTile, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - return b.create(loc, valueToTile, offsets, sizes, - strides); -} - -Value materializeSlice(OpBuilder &b, Location loc, Value valueToTile, - ArrayRef offsets, - ArrayRef sizes) { - SmallVector strides(offsets.size(), b.getIndexAttr(1)); - return materializeSlice(b, loc, valueToTile, offsets, sizes, strides); -} - -//===----------------------------------------------------------------------===// -// Destination-style ops tools -//===----------------------------------------------------------------------===// - -LogicalResult verifyDestinationStyleOp(Operation *op) { - auto dstStyleOp = cast(*op); - if (dstStyleOp.hasBufferSemantics()) return success(op->getNumResults() == 0); - - if (!dstStyleOp.hasTensorSemantics()) - return op->emitOpError("expected either buffer or tensor semantics"); - - return success(); -} - -template -void printDstStyleOp( - DstOpTy op, OpAsmPrinter &p, - function_ref(DstOpTy op, OpAsmPrinter &)> 
- printAttrsFn = nullptr) { - if (op.getNumDpsInputs() != 0) { - p << " ins("; - llvm::interleaveComma( - op.getOperands().take_front(op.getNumDpsInputs()), p, - [&](Value input) { p << input << " : " << input.getType(); }); - p << ")"; - } - p << " outs("; - llvm::interleaveComma( - op.getOperands().take_back(op.getNumDpsInits()), p, - [&](Value output) { p << output << " : " << output.getType(); }); - p << ")"; - - // Print attributes with custom printing logic. - SmallVector elidedAttrs; - if (printAttrsFn) { - p << ' '; - elidedAttrs = printAttrsFn(op, p); - } - - p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); -} - -ParseResult parseKeywordOperandListWithTypes( - OpAsmParser &parser, OperationState &result, StringRef keyword, - SmallVectorImpl *operandTypes) { - SmallVector operands; - if (succeeded(parser.parseOptionalKeyword(keyword))) { - SMLoc operandsOperandsLoc = parser.getCurrentLocation(); - - if (parser.parseCommaSeparatedList( - AsmParser::Delimiter::Paren, [&]() -> ParseResult { - if (parser.parseOperand(operands.emplace_back(), - /*allowResultNumber=*/true) || - parser.parseColon() || - parser.parseType(operandTypes->emplace_back())) { - return failure(); - } - return success(); - })) - return failure(); - - if (parser.resolveOperands(operands, *operandTypes, operandsOperandsLoc, - result.operands)) - return failure(); - } - return success(); -} - -ParseResult parseDstStyleOp( - OpAsmParser &parser, OperationState &result, - function_ref parseAttrsFn = - nullptr) { - // Parse `ins` and `outs`. - SmallVector inputTypes, outputTypes; - if (parseKeywordOperandListWithTypes(parser, result, "ins", &inputTypes) || - parseKeywordOperandListWithTypes(parser, result, "outs", &outputTypes)) - return failure(); - - // Add result types. - for (Type outputType : outputTypes) { - if (outputType.isa()) result.addTypes(outputType); - } - - // Parse required attributes. - if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes))) - return failure(); - - // Parse optional attributes. 
- if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - return success(); -} - -ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, - NamedAttrList &attributes, - StringRef attributeName) { - if (parser.parseKeyword(attributeName) || parser.parseEqual()) - return failure(); - - attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{})); - return success(); -} - -void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, - ArrayRef attributeValue) { - p << attributeName << " = [" << attributeValue << "] "; -} - -SmallVector getParallelIteratorTypes(int64_t dimCount) { - return SmallVector(dimCount, - utils::IteratorType::parallel); -} - -SmallVector getIterationDomainForTensor(OpBuilder &b, Location loc, - Value tensor, - int64_t dimCount = -1) { - auto dimValues = tensor::getMixedSizes(b, loc, tensor); - if (dimCount >= 0) dimValues.resize(dimCount); - return llvm::to_vector(llvm::map_range(dimValues, [&](OpFoldResult d) { - return Range{b.getIndexAttr(0), d, b.getIndexAttr(1)}; - })); -} - -static void getDstStyleOpEffectsImpl( - SmallVectorImpl> - &effects, - ValueRange results, ValueRange inputOperands, ValueRange outputOperands) { - for (auto operand : inputOperands) { - if (!operand.getType().isa()) continue; - effects.emplace_back(MemoryEffects::Read::get(), operand, - SideEffects::DefaultResource::get()); - } - for (auto operand : outputOperands) { - if (!operand.getType().isa()) continue; - effects.emplace_back(MemoryEffects::Read::get(), operand, - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), operand, - SideEffects::DefaultResource::get()); - } -} - -} // namespace -} // namespace mlir - -//===----------------------------------------------------------------------===// -// THLO Dialect Interfaces -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace { - -struct THLOInlinerInterface : public mlir::DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - // Operations in THLO dialect are always legal to inline. - bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { - return true; - } - // Handle the given inlined terminator by replacing it with a new operation - // as necessary. Required when the region has only one block. - void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {} -}; - -} // namespace -} // namespace mlir - -//===----------------------------------------------------------------------===// -// THLODialect -//===----------------------------------------------------------------------===// - -// Generated dialect definitions. 
-#include "thlo/IR/thlo_dialect.cc.inc" - -namespace mlir { -namespace thlo { - -void THLODialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "thlo/IR/thlo_ops.cc.inc" - >(); - - addInterfaces(); -} - -//===----------------------------------------------------------------------===// -// YieldOp -//===----------------------------------------------------------------------===// - -LogicalResult checkYieldOutputs(YieldOp yieldOp, - TypeRange expectedElementTypes) { - uint64_t numOutputs = expectedElementTypes.size(); - if (yieldOp.getValues().size() != numOutputs) { - return yieldOp.emitOpError("expects number of tensor output args = ") - << numOutputs << " to match the number of yield operands = " - << yieldOp.getValues().size(); - } - - for (const auto &item : llvm::enumerate( - llvm::zip(expectedElementTypes, yieldOp.getOperandTypes()))) { - Type outputElementType, resultType; - unsigned index = item.index(); - std::tie(outputElementType, resultType) = item.value(); - if (outputElementType != resultType) - return yieldOp.emitOpError("expects yield operand ") - << index << " with type = " << resultType - << " to match output arg element type = " << outputElementType; - } - - return success(); -} - -LogicalResult YieldOp::verify() { return success(); } - -//===----------------------------------------------------------------------===// -// ConcatenateOp -//===----------------------------------------------------------------------===// - -SmallVector ConcatenateOp::getLoopIteratorTypes() { - return getParallelIteratorTypes(getInit().getType().getRank()); -} - -SmallVector ConcatenateOp::getIterationDomain(OpBuilder &b) { - return getIterationDomainForTensor(b, getLoc(), getInit()); -} - -namespace { - -Value getSingleOperandTiledImplementationForConcatRecursively( - OpBuilder &b, Location loc, int64_t concatDim, ValueRange remainingOperands, - SmallVector &remainingOffsets, ArrayRef sizes) { - assert(!remainingOperands.empty() && "expect at least one remaining operand"); - assert(sizes[concatDim].get().cast().getInt() == 1 && - "expect unit size in concat dim"); - - // Terminal case of exactly one operand. - Value leadingOperand = remainingOperands.front(); - if (remainingOperands.size() == 1) { - return materializeSlice(b, loc, leadingOperand, remainingOffsets, sizes); - } - - // For more than one operand, distinguish between the leading operand and the - // remainder. 
- assert(remainingOperands.size() > 1 && - "expect more than one operand at this point"); - Value leadingOperandSizeInConcatDim = - b.createOrFold(loc, leadingOperand, concatDim); - Value remainingOffsetInConcatDim = - getValueOrCreateConstantIndexOp(b, loc, remainingOffsets[concatDim]); - Value leadingOperandPredicate = b.createOrFold( - loc, arith::CmpIPredicate::ult, remainingOffsetInConcatDim, - leadingOperandSizeInConcatDim); - auto ifOp = b.create( - loc, leadingOperandPredicate, - [&](OpBuilder &b, Location loc) { - Value tiledConcat = - getSingleOperandTiledImplementationForConcatRecursively( - b, loc, concatDim, {leadingOperand}, remainingOffsets, sizes); - b.create(loc, tiledConcat); - }, - [&](OpBuilder &b, Location loc) { - remainingOffsets[concatDim] = getAsOpFoldResult( - b.createOrFold(loc, remainingOffsetInConcatDim, - leadingOperandSizeInConcatDim)); - Value tiledConcat = - getSingleOperandTiledImplementationForConcatRecursively( - b, loc, concatDim, remainingOperands.drop_front(), - remainingOffsets, sizes); - b.create(loc, tiledConcat); - }); - return ifOp.getResults().front(); -} - -Value getSingleOperandTiledImplementationForConcat( - ConcatenateOp op, OpBuilder &b, Location loc, - ArrayRef offsets, ArrayRef sizes) { - int64_t concatDim = op.getDimension().getSExtValue(); - SmallVector remainingOffsets(offsets); - return getSingleOperandTiledImplementationForConcatRecursively( - b, loc, concatDim, op.getInputs(), remainingOffsets, sizes); -} - -Value getGenericTiledImplementationForConcat(ConcatenateOp op, OpBuilder &b, - Location loc, - ArrayRef offsets, - ArrayRef sizes) { - // Create a basis for the tile offsets and sizes. These hold the shared values - // in all non-concat dimensions and are amended in the concat dimension to - // create the individual operand tiles. Also, create the shared tile strides, - // which are the exact same for every operand tile. - SmallVector operandTileOffsetsBase(offsets); - SmallVector operandTileSizesBase(sizes); - SmallVector operandTileStrides(sizes.size(), b.getIndexAttr(1)); - - // Some shared values. - Value zeroCst = b.create(loc, 0); - int64_t concatDim = op.getDimension().getSExtValue(); - Value concatDimCst = b.create(loc, concatDim); - Value maxTileSizeInConcatDim = - getValueOrCreateConstantIndexOp(b, loc, sizes[concatDim]); - - // The remaining tile offset in the concat dimension is subtracted by each - // operand's size in that dimension. We maintain the invariant - // remainingTileOffsetInConcatDim >= 0. - Value remainingTileOffsetInConcatDim = - getValueOrCreateConstantIndexOp(b, loc, offsets[concatDim]); - - // Create the relevant subsets per operand. These tiles can be empty at - // runtime. - SmallVector tiledOperands; - tiledOperands.reserve(op.getNumDpsInputs()); - for (Value operand : op.getInputs()) { - // Find the current operand's tile offset in the concat dimension. This is - // the remaining offset clamped into the bounds of the operand. Note that - // the remaining offset is always >= 0. - Value operandSizeInConcatDim = - b.createOrFold(loc, operand, concatDimCst); - Value operandTileOffsetInConcatDim = b.createOrFold( - loc, remainingTileOffsetInConcatDim, operandSizeInConcatDim); - operandTileOffsetsBase[concatDim] = - getAsOpFoldResult(operandTileOffsetInConcatDim); - - // Find the current operand's tile size in the concat dimension. 
- Value remainingOperandSizeInConcatDim = b.createOrFold( - loc, operandSizeInConcatDim, operandTileOffsetInConcatDim); - operandTileSizesBase[concatDim] = - getAsOpFoldResult(b.createOrFold( - loc, remainingOperandSizeInConcatDim, maxTileSizeInConcatDim)); - - // Create the operand tile and materialize the subset for this operand. - tiledOperands.push_back( - materializeSlice(b, loc, operand, operandTileOffsetsBase, - operandTileSizesBase, operandTileStrides)); - - // Unless it is the last operand, update the remaining tile offset in the - // concat dimension. The remaining offset is subtracted by the operand's - // size but must remain >= 0. - if (operand != op.getInputs().back()) { - Value cmp = b.createOrFold(loc, arith::CmpIPredicate::ule, - remainingTileOffsetInConcatDim, - operandSizeInConcatDim); - Value sub = b.createOrFold( - loc, remainingTileOffsetInConcatDim, operandSizeInConcatDim); - remainingTileOffsetInConcatDim = - b.createOrFold(loc, cmp, zeroCst, sub); - } - } - - // Create the tiled concat op. - Value tiledInit = materializeSlice(b, loc, op.getInit(), offsets, sizes); - auto tiledConcat = - b.create(loc, tiledInit.getType(), tiledOperands, - tiledInit, b.getIndexAttr(concatDim)); - return tiledConcat.getResults().front(); -} - -Value getTiledImplementationForConcat(ConcatenateOp op, OpBuilder &b, - Location loc, - ArrayRef offsets, - ArrayRef sizes) { - // If the tile is of unit size in the concatenation dimension, we can generate - // the tiled implementation based on a single operand. - int64_t concatDim = op.getDimension().getSExtValue(); - OpFoldResult tileSizeInConcatDim = sizes[concatDim]; - if (tileSizeInConcatDim.is() && - tileSizeInConcatDim.get().cast().getInt() == 1) { - return getSingleOperandTiledImplementationForConcat(op, b, loc, offsets, - sizes); - } - - // Otherwise, rely on the generic implementation. - return getGenericTiledImplementationForConcat(op, b, loc, offsets, sizes); -} - -} // namespace - -FailureOr ConcatenateOp::getTiledImplementation( - OpBuilder &b, ArrayRef offsets, - ArrayRef sizes) { - auto tiled = - getTiledImplementationForConcat(*this, b, getLoc(), offsets, sizes); - return TilingResult{{tiled.getDefiningOp()}, {tiled}}; -} - -LogicalResult ConcatenateOp::getResultTilePosition( - OpBuilder & /*b*/, unsigned /*resultNumber*/, - ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets = llvm::to_vector(offsets); - resultSizes = llvm::to_vector(sizes); - return success(); -} - -FailureOr ConcatenateOp::generateResultTileValue( - OpBuilder &b, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes) { - assert(resultNumber == 0 && "expect unique result idx"); - FailureOr tilingResult = - getTiledImplementation(b, offsets, sizes); - if (failed(tilingResult)) return failure(); - return tilingResult.value(); -} - -LogicalResult ConcatenateOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - Location loc = getLoc(); - Value init = getInit(); - - // Assume unique result. - if (getNumResults() != 1) return failure(); - SmallVector &shape = reifiedReturnShapes.emplace_back(); - - // Derive shape from init operand. 
- int64_t rank = init.getType().cast().getRank();
- shape.reserve(rank);
- for (int64_t i = 0; i < rank; ++i) {
- shape.push_back(b.create(loc, init, i).getResult());
- }
-
- return success();
-}
-
-ParseResult ConcatenateOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseDstStyleOp(
- parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
- int64_t dimension = 0;
- if (parser.parseKeyword("dimension") || parser.parseEqual() ||
- parser.parseInteger(dimension))
- return failure();
-
- attributes.set("dimension",
- parser.getBuilder().getIndexAttr(dimension));
- return success();
- });
-}
-
-void ConcatenateOp::print(OpAsmPrinter &p) {
- printDstStyleOp(
- *this, p,
- [](ConcatenateOp op, OpAsmPrinter &p) -> SmallVector {
- p << op.getDimensionAttrName().str() << " = " << op.getDimension();
-
- return {op.getDimensionAttrName()};
- });
-}
-
-LogicalResult ConcatenateOp::verify() {
- int64_t concatDim = getDimension().getSExtValue();
-
- ShapedType inputType =
- getDpsInputOperand(0)->get().getType().cast();
- int64_t rank = inputType.getRank();
- auto inputShape = inputType.getShape();
-
- Type outputElementType =
- getDpsInitOperand(0)->get().getType().cast().getElementType();
-
- for (const auto &en : llvm::enumerate(getInputs())) {
- ShapedType inputArgShapedType = en.value().getType().cast();
- auto inputArgShape = inputArgShapedType.getShape();
-
- if (inputArgShapedType.getElementType() != outputElementType)
- return emitOpError() << "expected element type of input "
- << inputArgShapedType.getElementType()
- << " to match output element type "
- << outputElementType;
-
- if (inputArgShapedType.getRank() != rank)
- return emitOpError() << "expected all args to be rank " << rank
- << ", got " << inputArgShapedType.getRank()
- << " in arg " << en.index();
-
- // Make sure that all dimensions, except for the concatenation dim, in the
- // input arg are equal.
- // TODO(shyshkov): Also check output dims once tiling is fixed for
- // ConcatenateOp. 
- for (int64_t i = 0; i < rank; ++i) { - if (i == concatDim) continue; - - if (inputShape[i] != inputArgShape[i]) - return emitOpError() - << "shape of input arg " << en.index() << ": " - << inputArgShapedType << " doesn't match expected shape " - << inputType << " (all dims except concat dim(" << concatDim - << ") should match exactly)"; - } - } - - return verifyDestinationStyleOp(getOperation()); -} - -void ConcatenateOp::getEffects( - SmallVectorImpl> - &effects) { - getDstStyleOpEffectsImpl(effects, getOperation()->getResults(), - getDpsInputs(), getDpsInits()); -} - -//===----------------------------------------------------------------------===// -// DynamicBroadcastInDimOp -//===----------------------------------------------------------------------===// - -ParseResult DynamicBroadcastInDimOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseDstStyleOp(parser, result, - [&](OpAsmParser &parser, NamedAttrList &attributes) { - return parseDenseI64ArrayAttr( - parser, attributes, "broadcast_dimensions"); - }); -} - -void DynamicBroadcastInDimOp::print(OpAsmPrinter &p) { - printDstStyleOp( - *this, p, - [](DynamicBroadcastInDimOp op, - OpAsmPrinter &p) -> SmallVector { - printDenseI64ArrayAttr(p, op.getBroadcastDimensionsAttrName(), - op.getBroadcastDimensions()); - return {op.getBroadcastDimensionsAttrName()}; - }); -} - -LogicalResult DynamicBroadcastInDimOp::verify() { - return verifyDestinationStyleOp(getOperation()); -} - -SmallVector -DynamicBroadcastInDimOp::getLoopIteratorTypes() { - return getParallelIteratorTypes(getInit().getType().getRank()); -} - -SmallVector DynamicBroadcastInDimOp::getIterationDomain(OpBuilder &b) { - return getIterationDomainForTensor(b, getLoc(), getInit()); -} - -FailureOr DynamicBroadcastInDimOp::getTiledImplementation( - OpBuilder &b, ArrayRef offsets, - ArrayRef sizes) { - // Create tile subset. - auto loc = getLoc(); - auto initRank = getInit().getType().cast().getRank(); - - DenseMap localIndexConstants; - - DenseSet dimensionsThatStay(getBroadcastDimensions().begin(), - getBroadcastDimensions().end()); - - // Materialize operand space. - auto operandTy = getOperand().getType().cast(); - auto dynamicDims = tensor::createDynamicDimValues(b, loc, getOperand()); - - // Materialize operand dimensions. - SmallVector operandDims; - int64_t dynamicDimsIdx = 0; - operandDims.reserve(operandTy.getRank()); - for (const auto &it : llvm::enumerate(operandTy.getShape())) { - int64_t d = it.value(); - Value dim = d == ShapedType::kDynamic - ? dynamicDims[dynamicDimsIdx++] - : b.create(loc, d); - operandDims.push_back(dim); - } - - // Find the expanding dimensions. If corresponding operand and result - // dimensions are different then the dimension is expanding. - // TODO(frgossen): Use info from known expanding and known non-expanding - // dimensions here. - SmallVector operandExpandingDims; - for (const auto &it : llvm::enumerate(getBroadcastDimensions())) { - auto operandDim = operandDims[it.index()]; - auto resultDim = b.create( - loc, getInit(), b.create(loc, it.value())); - operandExpandingDims.push_back(b.create( - loc, arith::CmpIPredicate::ne, operandDim, resultDim)); - } - - // Compute operand tile offsets. 
- auto tileOpOffsets = getValueOrCreateConstantIndexOp(b, loc, offsets); - int64_t operandRank = operandTy.getRank(); - auto staticOffsets = SmallVector(operandRank, ShapedType::kDynamic); - SmallVector operandOffsets; - Value zero = b.create(loc, 0); - for (int initId = 0, operandId = 0; initId < initRank; ++initId) { - if (!dimensionsThatStay.contains(initId)) continue; - Value isExpanding = operandExpandingDims[operandId++]; - Value collapsedSubsetOffset = tileOpOffsets[initId]; - operandOffsets.push_back(b.create(loc, isExpanding, zero, - collapsedSubsetOffset)); - } - - // Compute operand tile sizes. - auto staticTileSizes = - SmallVector(operandRank, ShapedType::kDynamic); - SmallVector tileSizes; - Value one = b.create(loc, 1); - auto tileOpSizes = getValueOrCreateConstantIndexOp(b, loc, sizes); - for (int initId = 0, operandId = 0; initId < initRank; ++initId) { - if (!dimensionsThatStay.contains(initId)) continue; - Value isExpanding = operandExpandingDims[operandId++]; - Value tileSize = tileOpSizes[initId]; - tileSizes.push_back( - b.create(loc, isExpanding, one, tileSize)); - } - - // Create operand tile. - auto staticTileStrides = SmallVector(operandRank, 1); - SmallVector tileStrides = {}; - - // Materialize operand tiles. - Value tiledInit = materializeSlice(b, loc, getInit(), offsets, sizes); - Value tiledOperand = materializeSlice( - b, loc, getOperand(), getMixedValues(staticOffsets, operandOffsets, b), - getMixedValues(staticTileSizes, tileSizes, b), - getMixedValues(staticTileStrides, tileStrides, b)); - - // Finally, materialize tiled broadcast. - auto resultTy = getType(0).cast(); - auto tiledResultTy = - RankedTensorType::get(tiledInit.getType().cast().getShape(), - resultTy.getElementType()); - auto tiledOp = b.create( - loc, TypeRange{tiledResultTy}, tiledOperand, tiledInit, - getBroadcastDimensionsAttr(), getKnownExpandingDimensionsAttr(), - getKnownNonexpandingDimensionsAttr()); - return TilingResult{{tiledOp}, {tiledOp.getResult()}}; -} - -LogicalResult DynamicBroadcastInDimOp::getResultTilePosition( - OpBuilder & /*b*/, unsigned /*resultNumber*/, - ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets = llvm::to_vector(offsets); - resultSizes = llvm::to_vector(sizes); - return success(); -} - -FailureOr DynamicBroadcastInDimOp::generateResultTileValue( - OpBuilder &b, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes) { - assert(resultNumber == 0 && "expect unique result idx"); - FailureOr tilingResult = - getTiledImplementation(b, offsets, sizes); - if (failed(tilingResult)) return failure(); - return tilingResult.value(); -} - -void DynamicBroadcastInDimOp::getEffects( - SmallVectorImpl> - &effects) { - getDstStyleOpEffectsImpl(effects, getOperation()->getResults(), - getDpsInputs(), getDpsInits()); -} - -//===----------------------------------------------------------------------===// -// ScatterOp -//===----------------------------------------------------------------------===// - -ParseResult ScatterOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseDstStyleOp(parser, result)) return failure(); - - SmallVector regionArgs; - if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) { - return failure(); - } - - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) return failure(); - - return success(); -} - -void ScatterOp::print(OpAsmPrinter &p) { - printDstStyleOp(*this, p); - - 
p.increaseIndent(); - p.printNewline(); - p << "("; - llvm::interleaveComma(getUpdateComputation().getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ") "; - - p.printRegion(getUpdateComputation(), /*printEntryBlockArgs=*/false); - p.decreaseIndent(); -} - -LogicalResult ScatterOp::verify() { - if (failed(verifyDestinationStyleOp(getOperation()))) return failure(); - - auto indicesType = getIndices().getType().cast(); - int64_t indicesRank = indicesType.getRank(); - - if (indicesRank != 2) - return emitOpError() << "expected `indices` to be a 2D tensor"; - - auto updatesType = getUpdates().getType(); - int64_t updatesRank = updatesType.getRank(); - - if (updatesType.getDimSize(0) != indicesType.getDimSize(0)) { - return emitOpError() << "expected major dimension of `indices` to match " - "major dimension of `updates`"; - } - - int64_t indexVectorDim = indicesType.getDimSize(1); - if (ShapedType::isDynamic(indexVectorDim)) - return emitOpError() << "expected index vector dimension size to be static"; - - auto initType = getInit().getType(); - int64_t initRank = initType.getRank(); - - if (indexVectorDim > initRank) { - return emitOpError() << "expected index vector dimension size = " - << indexVectorDim - << " to be smaller or equal than `init` rank = " - << initRank; - } - - if (updatesRank - 1 != initRank) - return emitOpError() << "expected `updates` rank + 1 to match `init` rank"; - - if (updatesType.getElementType() != initType.getElementType()) { - return emitOpError() - << "expected `updates` element type to match `init` element type"; - } - - // The update computation should yield exactly 1 result. - auto updateTerminator = cast(getBody()->getTerminator()); - Type outputElementType = - getDpsInitOperand(0)->get().getType().cast().getElementType(); - if (!succeeded(checkYieldOutputs(updateTerminator, outputElementType))) - return failure(); - - return success(); -} - -SmallVector ScatterOp::getLoopIteratorTypes() { - return {utils::IteratorType::reduction}; -} - -SmallVector ScatterOp::getIterationDomain(OpBuilder &b) { - Value indicesCount = b.create(getLoc(), getIndices(), 0); - return {Range{b.getIndexAttr(0), indicesCount, b.getIndexAttr(1)}}; -} - -FailureOr ScatterOp::getTiledImplementation( - OpBuilder &b, ArrayRef offsets, - ArrayRef sizes) { - Location loc = getLoc(); - IntegerAttr zeroAttr = b.getIndexAttr(0); - - OpFoldResult tileOffset = offsets.front(); - OpFoldResult tileSize = sizes.front(); - - // Tile outer dimension of updates. - Value update = this->getUpdates(); - auto updateType = update.getType().cast(); - - SmallVector updateOffsets(updateType.getRank(), zeroAttr); - updateOffsets.front() = tileOffset; - SmallVector updateSizes = tensor::getMixedSizes(b, loc, update); - updateSizes.front() = tileSize; - - Value updateSlice = - materializeSlice(b, loc, update, updateOffsets, updateSizes); - - // Tile outer dimension of indices. - Value indices = this->getIndices(); - - SmallVector indicesOffsets{offsets.front(), zeroAttr}; - indicesOffsets.front() = tileOffset; - SmallVector indicesSizes = - tensor::getMixedSizes(b, loc, indices); - indicesSizes.front() = tileSize; - - Value indicesSlice = - materializeSlice(b, loc, indices, indicesOffsets, indicesSizes); - - // Get full space of the `init` tensor. We use an extract_slice op because - // otherwise, tileUsingSCFForOp won't replace the arg with the bbarg. 
- int64_t initRank = getInit().getType().getRank();
- Value init = materializeSlice(b, loc, this->getInit(),
- SmallVector(initRank, zeroAttr),
- tensor::getMixedSizes(b, loc, this->getInit()));
-
- Operation *tiledOp =
- mlir::clone(b, this->getOperation(), TypeRange{init.getType()},
- ValueRange{indicesSlice, updateSlice, init});
- return TilingResult{{tiledOp}, {tiledOp->getResult(0)}};
-}
-
-LogicalResult ScatterOp::getResultTilePosition(
- OpBuilder &b, unsigned /*resultNumber*/, ArrayRef /*offsets*/,
- ArrayRef /*sizes*/, SmallVector &resultOffsets,
- SmallVector &resultSizes) {
- ScatterOp scatterOp = cast(this->getOperation());
- auto init = scatterOp.getInit();
- resultOffsets =
- SmallVector(init.getType().getRank(), b.getIndexAttr(0));
- resultSizes = tensor::getMixedSizes(b, scatterOp.getLoc(), init);
- return success();
-}
-
-FailureOr ScatterOp::generateResultTileValue(
- OpBuilder &b, unsigned resultNumber, ArrayRef offsets,
- ArrayRef sizes) {
- assert(resultNumber == 0 && "variadic scatter is not implemented");
- FailureOr tilingResult =
- getTiledImplementation(b, offsets, sizes);
- if (failed(tilingResult)) return failure();
- return tilingResult;
-}
-
-void ScatterOp::getEffects(
- SmallVectorImpl>
- &effects) {
- getDstStyleOpEffectsImpl(effects, getOperation()->getResults(),
- getDpsInputs(), getDpsInits());
-}
-
-//===----------------------------------------------------------------------===//
-// GatherOp
-//===----------------------------------------------------------------------===//
-
-ParseResult GatherOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseDstStyleOp(parser, result);
-}
-
-void GatherOp::print(OpAsmPrinter &p) { printDstStyleOp(*this, p); }
-
-LogicalResult GatherOp::verify() {
- auto indicesType = getStartIndices().getType();
- int64_t indicesRank = indicesType.getRank();
-
- if (indicesRank != 2)
- return emitOpError() << "expected `indices` to be a 2D tensor";
-
- auto initType = getInit().getType();
- if (indicesType.getDimSize(0) != getInit().getType().getDimSize(0)) {
- return emitOpError()
- << "expected major dimension of `startIndices` to match "
- "major dimension of `init`";
- }
-
- if (initType.getNumDynamicDims() > 1 ||
- (initType.getNumDynamicDims() == 1 && !initType.isDynamicDim(0))) {
- return emitOpError() << "only the major dimension of `init` may be dynamic";
- }
-
- if (indicesType.isDynamic(1)) {
- return emitOpError()
- << "the minor dimensions of `startIndices` must be static";
- }
-
- return verifyDestinationStyleOp(getOperation());
-}
-
-SmallVector GatherOp::getLoopIteratorTypes() {
- return {utils::IteratorType::parallel};
-}
-
-SmallVector GatherOp::getIterationDomain(OpBuilder &b) {
- Value indicesCount = b.create(getLoc(), getStartIndices(), 0);
- return {Range{b.getIndexAttr(0), indicesCount, b.getIndexAttr(1)}};
-}
-
-FailureOr GatherOp::getTiledImplementation(
- OpBuilder &b, ArrayRef offsets,
- ArrayRef sizes) {
- SmallVector startIndexOffsets{offsets.front(),
- b.getIndexAttr(0)};
- SmallVector startIndexSizes{
- sizes.front(),
- b.getIndexAttr(getStartIndices().getType().getShape().back())};
- auto subStartIndices = materializeSlice(b, getLoc(), getStartIndices(),
- startIndexOffsets, startIndexSizes);
-
- int64_t initRank = getInit().getType().getRank();
- SmallVector initOffsets(initRank, b.getIndexAttr(0));
- initOffsets[0] = offsets.front();
- auto initSizes = tensor::getMixedSizes(b, getLoc(), getInit());
- initSizes[0] = sizes.front();
- Value initSlice =
- materializeSlice(b, getLoc(), 
getInit(), initOffsets, initSizes); - - auto gatherOp = - b.create(getLoc(), TypeRange{initSlice.getType()}, - ValueRange{getOperand(), subStartIndices, initSlice}); - return TilingResult{{gatherOp}, {gatherOp.getResult()}}; -} - -LogicalResult GatherOp::getResultTilePosition( - OpBuilder &b, unsigned /*resultNumber*/, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - GatherOp gatherOp = cast(this->getOperation()); - auto init = gatherOp.getInit(); - resultOffsets = - SmallVector(init.getType().getRank(), b.getIndexAttr(0)); - resultOffsets.front() = offsets.front(); - resultSizes = tensor::getMixedSizes(b, gatherOp.getLoc(), init); - resultSizes.front() = sizes.front(); - return success(); -} - -FailureOr GatherOp::generateResultTileValue( - OpBuilder &b, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes) { - assert(resultNumber == 0 && "resultNumber > 0 not implemented"); - FailureOr tilingResult = - getTiledImplementation(b, offsets, sizes); - if (failed(tilingResult)) return failure(); - return tilingResult.value(); -} - -void GatherOp::getEffects( - SmallVectorImpl> - &effects) { - getDstStyleOpEffectsImpl(effects, getOperation()->getResults(), - getDpsInputs(), getDpsInits()); -} - -//===----------------------------------------------------------------------===// -// SortOp -//===----------------------------------------------------------------------===// - -void SortOp::getAsmResultNames(function_ref setNameFn) { - ResultRange results = getResults(); - for (size_t i = 0; i < results.size(); i++) { - setNameFn(results[i], "sorted" + std::to_string(i)); - } -} - -void SortOp::getAsmBlockArgumentNames(Region ®ion, - OpAsmSetValueNameFn setNameFn) { - for (int i = 0, e = region.getNumArguments(); i < e; i += 2) { - setNameFn(region.getArgument(i), "lhs" + std::to_string(i / 2)); - setNameFn(region.getArgument(i + 1), "rhs" + std::to_string(i / 2)); - } -} - -ParseResult SortOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseDstStyleOp( - parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { - int64_t dimension = 0; - int64_t isStable = 0; - if (parser.parseKeyword("dimension") || parser.parseEqual() || - parser.parseInteger(dimension) || - parser.parseKeyword("is_stable") || parser.parseEqual() || - parser.parseInteger(isStable)) - return failure(); - - auto b = parser.getBuilder(); - attributes.set("dimension", b.getIndexAttr(dimension)); - attributes.set("is_stable", b.getBoolAttr(isStable != 0)); - return success(); - })) - return failure(); - - SmallVector regionArgs; - if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) { - return failure(); - } - - Region *comparator = result.addRegion(); - if (parser.parseRegion(*comparator, regionArgs)) return failure(); - - return success(); -} - -void SortOp::print(OpAsmPrinter &p) { - printDstStyleOp( - *this, p, [](SortOp op, OpAsmPrinter &p) -> SmallVector { - p << op.getDimensionAttrName().str() << " = " << op.getDimension() - << ' ' << op.getIsStableAttrName().str() << " = " << op.getIsStable(); - return {op.getDimensionAttrName(), op.getIsStableAttrName()}; - }); - - p.increaseIndent(); - p.printNewline(); - p << "("; - llvm::interleaveComma(getComparator().getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ") "; - - p.printRegion(getComparator(), /*printEntryBlockArgs=*/false); - p.decreaseIndent(); -} - -LogicalResult SortOp::verify() { - auto *comparatorBlock = 
getBody(); - auto comparatorArgs = comparatorBlock->getArguments(); - - // Checks that the arity of the comparator is equal to twice the number of - // inputs. - int64_t numInputs = getNumDpsInputs(); - int64_t numOutputs = getNumDpsInits(); - if (getNumDpsInits() != numInputs) { - return emitOpError() << "expected the number of inputs " << numInputs - << " to match the number of outputs " << numOutputs; - } - if (static_cast(comparatorArgs.size()) != numInputs * 2) { - return emitOpError() << "expected the number of block arguments " - << comparatorArgs.size() << " to be twice the number " - << "of inputs (2*" << numInputs << ")"; - } - // Checks that the comparator's arguments match the element type of the - // inputs. - TypeRange inputTypes = TypeRange{getInputs()}; - TypeRange comparatorArgElementTypes = comparatorBlock->getArgumentTypes(); - for (size_t i = 0; i < getInputs().size(); ++i) { - Type inputArgElemType = inputTypes[i].cast().getElementType(), - comparatorArgElemType1 = comparatorArgElementTypes[2 * i], - comparatorArgElemType2 = comparatorArgElementTypes[2 * i + 1]; - if (comparatorArgElemType1 != inputArgElemType || - comparatorArgElemType2 != inputArgElemType) - return emitOpError() << "expected element type of input " << i - << " to match type of the corresponding " - "arguments to the comparison function but got " - << inputArgElemType << " and (" - << comparatorArgElemType1 << ", " - << comparatorArgElemType2 << ")"; - } - - // Checks that the comparator yields exactly one boolean output. - YieldOp comparatorTerminator = - cast(comparatorBlock->getTerminator()); - if (!succeeded( - checkYieldOutputs(comparatorTerminator, - TypeRange({IntegerType::get(getContext(), 1)})))) - return failure(); - - // Checks that the inputs all have the same shape. - ArrayRef referenceShape = - getInputs().front().getType().cast().getShape(); - - for (const auto &item : llvm::enumerate(TypeRange{getInputs()})) { - ArrayRef shape = item.value().cast().getShape(); - if (shape != referenceShape) { - return emitOpError() << "expected all inputs to have the same shape (" - << referenceShape << ") but input " << item.index() - << " has shape (" << shape << ")"; - } - } - - // Checks that the outputs have the same shape as the inputs. - for (const auto &item : llvm::enumerate(getInits())) { - ArrayRef shape = - item.value().getType().cast().getShape(); - if (shape != referenceShape) { - return emitOpError() << "expected outputs to have shape (" - << referenceShape << ") but output " << item.index() - << " has shape (" << shape << ")"; - } - } - - // Checks that the rank of the reference shape is larger than the absolute - // value of the sorting dimension. This is enough to ensure that the dimension - // is valid, since all inputs are known to have the same shape. `getDimension` - // returns an unsigned int, so no need to check for negative values. 
- size_t referenceRank = referenceShape.size(); - if (getDimension().getSExtValue() >= (int64_t)referenceRank) { - return emitOpError() << "sorting dimension must be in range [0, " - << referenceRank << ") but got " - << getDimension().getSExtValue(); - } - - return verifyDestinationStyleOp(getOperation()); -} - -SmallVector SortOp::getLoopIteratorTypes() { - return getParallelIteratorTypes(getType(0).cast().getRank() - 1); -} - -SmallVector SortOp::getIterationDomain(OpBuilder &b) { - Location loc = getLoc(); - auto oneInit = getInits().front(); - auto operandsRank = oneInit.getType().cast().getRank(); - - SmallVector iterationDomain(operandsRank - 1); - - IntegerAttr zero = b.getIndexAttr(0); - IntegerAttr one = b.getIndexAttr(1); - int64_t sortDimension = getDimension().getSExtValue(); - - for (auto axis : llvm::seq(0, operandsRank - 1)) { - int64_t operandAxis = (axis >= sortDimension) ? axis + 1 : axis; - iterationDomain[axis].offset = zero; - iterationDomain[axis].size = - b.createOrFold(loc, oneInit, operandAxis); - iterationDomain[axis].stride = one; - } - return iterationDomain; -} - -FailureOr SortOp::getTiledImplementation( - OpBuilder &b, ArrayRef offsets, - ArrayRef sizes) { - auto loc = getLoc(); - SmallVector tileOffsets = llvm::to_vector(offsets); - SmallVector tileSizes = llvm::to_vector(sizes); - - size_t numOutputs = getNumDpsInits(); - int64_t sortDimension = getDimension().getSExtValue(); - - Value oneInput = getInputs().front(); - - // Capture the entire sorting axis in each tile. - tileOffsets.insert(tileOffsets.begin() + sortDimension, b.getIndexAttr(0)); - - OpFoldResult sortDimensionSize = - b.createOrFold(loc, oneInput, sortDimension); - tileSizes.insert(tileSizes.begin() + sortDimension, sortDimensionSize); - - // Materialize the tile for each input and init. 
- SmallVector tiledInputsAndInits; - SmallVector tiledResultTypes; - tiledInputsAndInits.reserve(numOutputs * 2); - tiledResultTypes.reserve(numOutputs); - - for (const auto &input : getInputs()) { - tiledInputsAndInits.push_back( - materializeSlice(b, loc, input, tileOffsets, tileSizes)); - auto tileShape = - tiledInputsAndInits.back().getType().cast().getShape(); - tiledResultTypes.push_back(RankedTensorType::get( - tileShape, input.getType().cast().getElementType())); - } - - for (const auto &init : getInits()) { - tiledInputsAndInits.push_back( - materializeSlice(b, loc, init, tileOffsets, tileSizes)); - } - - Operation *tiledOp = mlir::clone(b, this->getOperation(), tiledResultTypes, - tiledInputsAndInits); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; -} - -LogicalResult SortOp::getResultTilePosition( - OpBuilder &b, unsigned /*resultNumber*/, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - SortOp sortOp = cast(this->getOperation()); - resultOffsets = llvm::to_vector(offsets); - resultSizes = llvm::to_vector(sizes); - - int64_t sortDimIndex = sortOp.getDimension().getSExtValue(); - Value sortDimValue = b.create( - sortOp.getLoc(), sortOp.getInputs().front(), sortDimIndex); - resultOffsets.insert(resultOffsets.begin() + sortDimIndex, b.getIndexAttr(0)); - resultSizes.insert(resultSizes.begin() + sortDimIndex, sortDimValue); - return success(); -} - -FailureOr SortOp::generateResultTileValue( - OpBuilder &b, unsigned /*resultNumber*/, ArrayRef offsets, - ArrayRef sizes) { - FailureOr tilingResult = - getTiledImplementation(b, offsets, sizes); - if (failed(tilingResult)) return failure(); - return tilingResult.value(); -} - -void SortOp::getEffects( - SmallVectorImpl> - &effects) { - getDstStyleOpEffectsImpl(effects, getOperation()->getResults(), - getDpsInputs(), getDpsInits()); -} - -//===----------------------------------------------------------------------===// -// ReverseOp -//===----------------------------------------------------------------------===// - -ParseResult ReverseOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDstStyleOp( - parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { - return parseDenseI64ArrayAttr(parser, attributes, "reverse_dimensions"); - }); -} - -void ReverseOp::print(OpAsmPrinter &p) { - printDstStyleOp( - *this, p, [](ReverseOp op, OpAsmPrinter &p) -> SmallVector { - printDenseI64ArrayAttr(p, op.getReverseDimensionsAttrName(), - op.getReverseDimensions()); - return {op.getReverseDimensionsAttrName()}; - }); -} - -LogicalResult ReverseOp::verify() { - return verifyDestinationStyleOp(getOperation()); -} - -void ReverseOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(getResult(), "reversed"); -} - -SmallVector ReverseOp::getLoopIteratorTypes() { - int64_t rank = getType().cast().getRank(); - return getParallelIteratorTypes(rank); -} - -SmallVector ReverseOp::getIterationDomain(OpBuilder &b) { - return getIterationDomainForTensor(b, getLoc(), getInit()); -} - -namespace { -SmallVector getInputTileOffsetsForReverse( - OpBuilder &b, Location loc, ArrayRef offsets, - ArrayRef tileSizes, ArrayRef reverseDimensions, - TypedValue &input) { - auto tileOpOffsets = getValueOrCreateConstantIndexOp(b, loc, offsets); - auto sizes = getValueOrCreateConstantIndexOp(b, loc, tileSizes); - SmallVector inputTileOffsets; - for (size_t i = 0; i < tileOpOffsets.size(); ++i) { - if (llvm::is_contained(reverseDimensions, i)) { - 
inputTileOffsets.push_back(OpFoldResult{b.createOrFold( - loc, - b.createOrFold( - loc, b.createOrFold(loc, input, i), - Value(tileOpOffsets[i])), - sizes[i])}); - } else { - inputTileOffsets.push_back(tileOpOffsets[i]); - } - } - - return inputTileOffsets; -} -} // namespace - -FailureOr ReverseOp::getTiledImplementation( - OpBuilder &b, ArrayRef offsets, - ArrayRef sizes) { - auto loc = getLoc(); - auto input = getInput(); - SmallVector inputTileOffsets = getInputTileOffsetsForReverse( - b, loc, offsets, sizes, getReverseDimensions(), input); - - // Materialize the tile for input and init. - SmallVector tiledInputsAndInits; - - tiledInputsAndInits.push_back( - materializeSlice(b, loc, input, inputTileOffsets, sizes)); - tiledInputsAndInits.push_back( - materializeSlice(b, loc, getInit(), offsets, sizes)); - auto tileShape = - tiledInputsAndInits.back().getType().cast().getShape(); - auto tiledResultType = RankedTensorType::get( - tileShape, input.getType().cast().getElementType()); - - Operation *tiledOp = mlir::clone(b, this->getOperation(), tiledResultType, - tiledInputsAndInits); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; -} - -LogicalResult ReverseOp::getResultTilePosition( - OpBuilder & /*b*/, unsigned /*resultNumber*/, - ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets = llvm::to_vector(offsets); - resultSizes = llvm::to_vector(sizes); - return success(); -} - -FailureOr ReverseOp::generateResultTileValue( - OpBuilder &b, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes) { - FailureOr tilingResult = - getTiledImplementation(b, offsets, sizes); - if (failed(tilingResult)) return failure(); - return tilingResult.value(); -} - -OpFoldResult ReverseOp::fold( - ReverseOpGenericAdaptor>) /*operands*/ { - auto inputType = getInput().getType(); - for (unsigned i = 0; i < getReverseDimensions().size(); ++i) { - if (inputType.getDimSize(getReverseDimensions()[i]) != 1) return nullptr; - } - return getInput(); -} - -void ReverseOp::getEffects( - SmallVectorImpl> - &effects) { - getDstStyleOpEffectsImpl(effects, getOperation()->getResults(), - getDpsInputs(), getDpsInits()); -} - -} // namespace thlo -} // namespace mlir - -// Generated op classes. -#define GET_OP_CLASSES -#include "thlo/IR/thlo_ops.cc.inc" diff --git a/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.h b/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.h deleted file mode 100644 index 9f6ad3b7701389..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the operations used in the THLO dialect. 
- -#ifndef MLIR_HLO_THLO_IR_THLO_OPS_H -#define MLIR_HLO_THLO_IR_THLO_OPS_H - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/TilingInterface.h" - -// Generated dialect declarations. -#include "thlo/IR/thlo_dialect.h.inc" - -// Generated operation classes. -#define GET_OP_CLASSES -#include "thlo/IR/thlo_ops.h.inc" - -#endif // MLIR_HLO_THLO_IR_THLO_OPS_H diff --git a/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.td b/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.td deleted file mode 100644 index 5346d413842e51..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/IR/thlo_ops.td +++ /dev/null @@ -1,346 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THLO_OPS -#define THLO_OPS - -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/TilingInterface.td" - -def TensorOrMemref : - AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; - -class TensorOrMemrefOf allowedTypes> : - AnyTypeOf<[MemRefOf, RankedTensorOf], - "", "::mlir::ShapedType">; - -def THLO_Dialect : Dialect { - let name = "thlo"; - let cppNamespace = "::mlir::thlo"; - let usePropertiesForAttributes = 0; -} - -class THLO_Op traits> : - Op { - let hasVerifier = 1; -} - -class THLO_DstStyleOp traits> : THLO_Op, - DestinationStyleOpInterface] # traits> { - let hasCustomAssemblyFormat = 1; -} - -def THLO_ConcatenateOp : THLO_DstStyleOp<"concatenate", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods - ]> { - let summary = "Destination-style twin for `mhlo.concatenate`"; - let description = [{ - tHLO ConcatenateOp composes a tensor or a memref from multiple tensors or - memrefs. - - Example: - ``` - %concat = thlo.concatenate - ins(%T1 : tensor<100x?xf32>, %T2 : tensor<300x?xf32>) - outs(%init : tensor<400x?xf32>) - dimension = 0 - ``` - - See https://www.tensorflow.org/xla/operation_semantics#concatenate - }]; - - let arguments = (ins - Variadic:$inputs, - TensorOrMemref:$init, - IndexAttr:$dimension - ); - let results = (outs Variadic:$result); - - let extraClassDeclaration = [{ - // Implement method necessary for DestinationStyleOpInterface. 
- mlir::MutableOperandRange getDpsInitsMutable() { - return getInitMutable(); - } - }]; -} - -def THLO_DynamicBroadcastInDimOp : THLO_DstStyleOp<"dynamic_broadcast_in_dim", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Destination-style twin for `mhlo.dynamic_broadcast_in_dim`"; - let description = [{ - tHLO DynamicBroadcastInDimOp specifies a map how to broadcast input - dimensions. It also supports broadcasting size-1 dimensions. - - Example: - ``` - %dyn_bcast = thlo.dynamic_broadcast_in_dim - ins(%input : tensor) - outs(%init : tensor) - broadcast_dimensions = [0, 2] - ``` - - See https://www.tensorflow.org/xla/operation_semantics#broadcastindim - }]; - - let arguments = (ins - // Input args - TensorOrMemref:$operand, - // Output arg - TensorOrMemref:$init, - - DenseI64ArrayAttr:$broadcast_dimensions, - OptionalAttr:$known_expanding_dimensions, - OptionalAttr:$known_nonexpanding_dimensions - ); - - let results = (outs Variadic:$result); - - let extraClassDeclaration = [{ - // Implement method necessary for DestinationStyleOpInterface. - mlir::MutableOperandRange getDpsInitsMutable() { - return getInitMutable(); - } - }]; -} - -def THLO_GatherOp : THLO_DstStyleOp<"gather", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Destination-style twin for `mhlo.gather`"; - let description = [{ - tHLO GatherOp corresponds to the canonicalized mHLO GatherOp, i.e. - - - start_indices is a two-dimensional tensor. - - index_vector_dim is 1 - - offset_dims is [1, 2, ...] - - collapsed_slice_dims is [] - - start_index_map is range(start_indices.shape[1]) - - Example: - ``` - %gathered = thlo.gather - ins(%input : tensor<100xf32>, %indices : tensor<42x1xindex>) - outs(%init : tensor<42xf32>) - ``` - - See https://www.tensorflow.org/xla/operation_semantics#gather. - }]; - let arguments = (ins - // Input args - TensorOrMemref:$operand, - TensorOrMemrefOf<[Index]>:$start_indices, - // Output arg - TensorOrMemref:$init - ); - let results = (outs Variadic:$result); - - let extraClassDeclaration = [{ - // Implement method necessary for DestinationStyleOpInterface. - mlir::MutableOperandRange getDpsInitsMutable() { - return getInitMutable(); - } - }]; -} - -def THLO_ScatterOp : THLO_DstStyleOp<"scatter", [ - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"YieldOp"> - ]> { - let summary = "Destination-style twin for `mhlo.scatter`"; - let description = [{ - tHLO ScatterOp corresponds to the canonicalized mHLO ScatterOp, i.e. - - - update_window_dims is range(1, rank(update_window_dims)) - - inserted_window_dims is [] - - scatter_dims_to_operand_dims is range(0, rank(indices)) - - index_vector_dim is rank(indices) - 1 - - At the moment, the variadic case is not supported. - - Example: - ``` - %scattered = thlo.scatter - ins(%indices : tensor<2x2xindex>, %input : tensor<2x1x3xf32>) - outs(%init : tensor<3x3xf32>) - (%arg3: f32, %arg4: f32) { - %0 = arith.addf %arg3, %arg4 : f32 - thlo.yield %0 : f32 - } - ``` - - See https://www.tensorflow.org/xla/operation_semantics#scatter. - }]; - - let arguments = (ins - // Input args - TensorOrMemrefOf<[Index]>:$indices, - TensorOrMemref:$updates, - // Output arg - TensorOrMemref:$init - ); - - let results = (outs Variadic:$result); - - let regions = (region SizedRegion<1>:$update_computation); - - let extraClassDeclaration = [{ - // Returns index vector dimension size, which is always statically-known. - int64_t getIndexVectorDimSize() { - return getIndices().getType().getDimSize(1); - } - - // Returns the number of indices, i.e. 
number of scalar/tensor updates. - int64_t getIndicesCount() { return getIndices().getType().getDimSize(0); } - - // Implement method necessary for DestinationStyleOpInterface. - mlir::MutableOperandRange getDpsInitsMutable() { - return getInitMutable(); - } - }]; -} - -def THLO_SortOp : THLO_DstStyleOp<"sort", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - SameVariadicOperandSize, - SingleBlockImplicitTerminator<"YieldOp"> - ]> { - let summary = "Destination-style twin for the `mhlo.sort`"; - let description = [{ - Sorts the given `operands` along the given `dimension` using the given - `comparator`. - - Example: - ``` - %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - dimension = 0 - is_stable = true - (%lhs0: f32, %rhs0: f32, %lhs1: i32, %rhs1: i32) { - %0 = arith.cmpf ogt, %lhs0, %rhs0 : f32 - thlo.yield %0 : i1 - } - ``` - - See https://www.tensorflow.org/xla/operation_semantics#sort. - }]; - - let arguments = (ins - // Input args - Variadic:$inputs, - // Output args - Variadic:$inits, - - IndexAttr:$dimension, - BoolAttr:$is_stable - ); - - let results = (outs Variadic:$result); - let regions = (region SizedRegion<1>:$comparator); - - let extraClassDeclaration = [{ - // Implement method necessary for DestinationStyleOpInterface. - mlir::MutableOperandRange getDpsInitsMutable() { - return getInitsMutable(); - } - }]; -} - -def THLO_ReverseOp : THLO_DstStyleOp<"reverse", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods,]> { - let summary = "Destination-style twin for the `mhlo.reverse`"; - let description = [{ - Reverses the specified dimensions of `input` according to the given - `dimensions`. - - See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. - }]; - - let arguments = (ins - TensorOrMemref:$input, - TensorOrMemref:$init, - DenseI64ArrayAttr:$reverse_dimensions - ); - - let results = (outs TensorOrMemref:$result); - - let hasFolder = 1; - - let extraClassDeclaration = [{ - // Implement method necessary for DestinationStyleOpInterface. - mlir::MutableOperandRange getDpsInitsMutable() { - return getInitMutable(); - } - }]; -} - -def THLO_YieldOp : THLO_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["ScatterOp", "SortOp"]>]>, - Arguments<(ins Variadic:$values)> { - let summary = "Yield operation for tHLO ops with regions."; - let assemblyFormat = "attr-dict $values `:` type($values)"; - let hasVerifier = 1; -} - -#endif // THLO_OPS diff --git a/third_party/xla/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt deleted file mode 100644 index 6ee9828d2a5a52..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_library(ThloBufferizableOpInterface - bufferizable_op_interface_impl.cc - - LINK_LIBS PUBLIC - THLODialect - MLIRBufferizationDialect - MLIRDestinationStyleOpInterface -) diff --git a/third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc b/third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc deleted file mode 100644 index 98ddebd8230f1c..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "thlo/interfaces/bufferizable_op_interface_impl.h" - -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "thlo/IR/thlo_ops.h" - -namespace mlir { -namespace thlo { -namespace { - -using mlir::bufferization::AliasingOpOperandList; -using mlir::bufferization::AliasingValueList; -using mlir::bufferization::AnalysisState; -using mlir::bufferization::BufferizableOpInterface; -using mlir::bufferization::BufferizationOptions; -using mlir::bufferization::BufferRelation; - -// We can reuse the upstream implementation when DestinationStyleOpInterface -// is moved out of linalg. -static LogicalResult bufferizeDestinationStyleOpInterface( - RewriterBase &rewriter, DestinationStyleOpInterface op, - const BufferizationOptions &options) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(op); - - // Nothing to do. This op is already bufferized. - if (op.hasBufferSemantics()) return success(); - - if (!op.hasTensorSemantics()) - return op->emitError() << "expected either buffer or tensor semantics"; - - size_t numOutputs = op.getNumDpsInits(); - - // New operands for the cloned op. - SmallVector newOperands; - newOperands.reserve(op.getNumDpsInputs() + numOutputs); - - for (OpOperand *opOperand : op.getDpsInputOperands()) { - if (op.isScalar(opOperand)) { - newOperands.push_back(opOperand->get()); - continue; - } - FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); - if (failed(buffer)) return failure(); - newOperands.push_back(*buffer); - } - - // New output operands for the cloned op. - SmallVector newOutputs; - newOutputs.reserve(numOutputs); - - for (OpResult opResult : op->getOpResults()) { - OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); - FailureOr resultBuffer = - getBuffer(rewriter, opOperand->get(), options); - if (failed(resultBuffer)) return failure(); - newOutputs.push_back(*resultBuffer); - } - - newOperands.append(newOutputs.begin(), newOutputs.end()); - - // Set insertion point now that potential alloc/dealloc are introduced. - rewriter.setInsertionPoint(op); - - // Clone the op, but use the new operands. 
Move the existing block into the - // new op. Since the new op does not have any tensor results, it does not - // return anything. - auto newOp = cast(cloneWithoutRegions( - rewriter, op, /*resultTypes=*/TypeRange{}, newOperands)); - - assert(op->getNumRegions() <= 1); - if (op->getNumRegions() == 1) { - rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), - newOp->getRegion(0).begin()); - } - - // Replace the results of the old op with the new output buffers. - bufferization::replaceOpWithBufferizedValues(rewriter, op, newOutputs); - - return success(); -} - -struct ThloSortOpBufferizationModel - : public BufferizableOpInterface::ExternalModel< - ThloSortOpBufferizationModel, SortOp> { - bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/, - const AnalysisState & /*state*/) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState & /*state*/) const { - return cast(op).isDpsInit(&opOperand); - } - - AliasingOpOperandList getAliasingOpOperands( - Operation *op, Value value, const AnalysisState & /*state*/) const { - auto opResult = value.dyn_cast(); - if (!opResult) return {}; - auto dstStyleOp = cast(op); - - // The i-th OpResult aliases with the i-th "out" tensor. - return {{dstStyleOp.getDpsInitOperand(opResult.getResultNumber()), - BufferRelation::Equivalent}}; - } - - AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, - const AnalysisState & /*state*/) const { - auto dstStyleOp = cast(op); - - // The i-th "out" tensor aliases with the i-th OpResult. - if (dstStyleOp.isDpsInit(&opOperand)) - return { - {dstStyleOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}; - return {}; - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - return bufferizeDestinationStyleOpInterface( - rewriter, cast(op), options); - } -}; - -} // namespace - -} // namespace thlo -} // namespace mlir - -void mlir::thlo::registerBufferizableOpInterfaceExternalModels( - DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, thlo::THLODialect * /*dialect*/) { - SortOp::attachInterface(*ctx); - }); -} diff --git a/third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h b/third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h deleted file mode 100644 index ee35b031ac5aea..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef MLIR_HLO_THLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H -#define MLIR_HLO_THLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H - -namespace mlir { -class DialectRegistry; - -namespace thlo { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); - -} // namespace thlo -} // namespace mlir - -#endif // MLIR_HLO_THLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H diff --git a/third_party/xla/xla/mlir_hlo/thlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/thlo/transforms/CMakeLists.txt deleted file mode 100644 index d582d86a724ee3..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/transforms/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -set(LLVM_TARGET_DEFINITIONS thlo_passes.td) -mlir_tablegen(thlo_passes.h.inc -gen-pass-decls -name AllThlo) -add_public_tablegen_target(MLIRThloPassIncGen) - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_library(ThloPasses - legalize_sort/legalize_sort.cc - - DEPENDS - MLIRThloPassIncGen - - LINK_LIBS PUBLIC - MLIRArithDialect - MLIRArithUtils - MLIRFuncDialect - MLIRMemRefDialect - MLIRPass - MLIRSCFDialect - MLIRTransforms -) diff --git a/third_party/xla/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc b/third_party/xla/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc deleted file mode 100644 index 3f3d75fc5ad51d..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc +++ /dev/null @@ -1,561 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "thlo/IR/thlo_ops.h" -#include "thlo/transforms/passes.h" - -namespace mlir { -namespace thlo { - -#define GEN_PASS_DEF_THLOLEGALIZESORTPASS -#include "thlo/transforms/thlo_passes.h.inc" - -namespace { - -using ::mlir::arith::AddIOp; -using ::mlir::arith::MinSIOp; -using ::mlir::arith::SelectOp; - -constexpr uint64_t kInsertionSortSize = 16; - -// Inlines the `comparator` region (without terminator) at the current insertion -// point, replacing the arguments with the given values from `lhs` and `rhs`. -Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, - SmallVector& rhs, Region& comparator) { - assert(comparator.hasOneBlock() && "Comparator must have only one block."); - Block& block = comparator.front(); - assert(block.getTerminator()->getOperands().size() == 1 && - "Comparator must return a single value"); - - IRMapping mapping; - for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { - Value value = idx % 2 == 0 ? lhs[idx / 2] : rhs[idx / 2]; - mapping.map(arg, value); - } - - for (Operation& op : block.without_terminator()) b.clone(op, mapping); - Value result = mapping.lookup(block.getTerminator()->getOperand(0)); - - return result; -} - -// Emits a binary search of `pivots` in `arrayMemrefs` (all rank 1) in the range -// [`left`;`right`). `arrayMemrefs` must be sorted according to `comparator`. 
-Value emitBinarySearch(ImplicitLocOpBuilder& b, Value leftInit, Value rightInit, - SmallVector& pivots, ValueRange arrayMemrefs, - Region& comparator) { - SmallVector types{leftInit.getType(), rightInit.getType()}; - ArithBuilder arith(b, b.getLoc()); - - // while ( - auto whileOp = b.create( - types, SmallVector{leftInit, rightInit}, - [&](OpBuilder& beforeBuilder, Location beforeLoc, ValueRange args) { - // left < right) { - Value left = args[0], right = args[1]; - beforeBuilder.create(beforeLoc, - arith.slt(left, right), args); - }, - [&](OpBuilder& afterBuilder, Location afterLoc, ValueRange args) { - ImplicitLocOpBuilder impLocAfterBuilder = - ImplicitLocOpBuilder(afterLoc, afterBuilder); - Value left = args[0], right = args[1]; - // int mid = (left + right) >> 1; - Value one = impLocAfterBuilder.create(1); - Value mid = impLocAfterBuilder.create( - arith.add(left, right), one); - Value midPlusOne = impLocAfterBuilder.create(mid, one); - - auto arraysAtMid = llvm::to_vector( - llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { - return impLocAfterBuilder.create(arrayMemref, - mid); - })); - - Value cond = - emitComparison(impLocAfterBuilder, pivots, arraysAtMid, comparator); - // if (comparator(pivot, array[mid])) - // right = mid; - // else - // left = mid + 1; - Value newLeft = arith.select(cond, left, midPlusOne); - Value newRight = arith.select(cond, mid, right); - - // } - impLocAfterBuilder.create(ValueRange{newLeft, newRight}); - }); - - return whileOp.getResult(0); -} - -SmallVector loadMemrefElements(ImplicitLocOpBuilder& b, - ValueRange memrefs, Value index) { - return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { - Type type = memref.getType().cast().getElementType(); - return b.create(type, memref, index); - })); -} - -void storeMemrefElements(ImplicitLocOpBuilder& b, ValueRange memrefs, - Value index, ValueRange values) { - for (auto [value, memref] : llvm::zip(values, memrefs)) { - b.create(value, memref, index); - } -} - -// Insertion sorts `inputMemrefs` in the range [`lo`; `hi`), storing the results -// in `outputMemrefs`. `inputMemrefs` and `outputMemrefs` must all be rank 1 and -// of identical size. -void emitInsertionSort(ImplicitLocOpBuilder& b, Value lo, Value hi, - ValueRange inputMemrefs, ValueRange outputMemrefs, - mlir::Region& comparator) { - ArithBuilder arith(b, b.getLoc()); - Value zero = b.create(0); - Value one = b.create(1); - - // array[lo] = inputs[lo]; - storeMemrefElements(b, outputMemrefs, lo, - loadMemrefElements(b, inputMemrefs, lo)); - - // for (int start = lo + 1; start < hi; ++start) - { - auto forOp = b.create(arith.add(lo, one), hi, one); - OpBuilder::InsertionGuard outerGuard(b); - b.setInsertionPointToStart(forOp.getBody()); - Value start = forOp.getInductionVar(); - - // T pivot = inputs[start]; - auto pivots = loadMemrefElements(b, inputMemrefs, start); - - // int index = binarySearch(lo, start, pivot, array, comparator); - auto index = - emitBinarySearch(b, lo, start, pivots, outputMemrefs, comparator); - - // int n = start - index; // The number of elements to move - Value n = arith.sub(start, index); - - // memmove(&array[index + 1], &array[index], n * sizeof(T)) - // memref::CopyOp would be nice to use here, but: - // 1. It lowers to a quite inefficient library call in the general case - // (strides != 1). - // 2. It implements memcpy semantics, but we need memmove here. - // So we go with a loop instead. 
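// A minimal scalar sketch of the routine the builder code here emits: a
// binary search for the insertion point followed by an explicit backward
// shift in place of memmove. It assumes a single in-place float array and an
// std::function comparator; both are illustrative simplifications of the
// memref-based lowering and its multi-operand comparator region, and the
// names below are not taken from this file.
#include <cstddef>
#include <functional>
#include <vector>

// Insertion-sorts a[lo, hi); less(x, y) is true when x must precede y.
inline void InsertionSortSketch(std::vector<float>& a, std::size_t lo,
                                std::size_t hi,
                                const std::function<bool(float, float)>& less) {
  for (std::size_t start = lo + 1; start < hi; ++start) {
    float pivot = a[start];
    // Binary search: first index in [lo, start) where less(pivot, a[mid]).
    std::size_t left = lo, right = start;
    while (left < right) {
      std::size_t mid = (left + right) >> 1;
      if (less(pivot, a[mid])) {
        right = mid;
      } else {
        left = mid + 1;
      }
    }
    // The "memmove": shift a[left, start) one slot to the right, then place
    // the pivot at its insertion point.
    for (std::size_t i = start; i > left; --i) a[i] = a[i - 1];
    a[left] = pivot;
  }
}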
- auto copyForOp = b.create(zero, n, one); - { - OpBuilder::InsertionGuard innerGuard(b); - b.setInsertionPointToStart(copyForOp.getBody()); - Value copyLoopIndex = copyForOp.getInductionVar(); - - Value dstIndex = arith.sub(start, copyLoopIndex); - Value srcIndex = arith.sub(dstIndex, one); - storeMemrefElements(b, outputMemrefs, dstIndex, - loadMemrefElements(b, outputMemrefs, srcIndex)); - } - // array[index] = pivot; - storeMemrefElements(b, outputMemrefs, index, pivots); - } -} - -void emitMerge(ImplicitLocOpBuilder& b, Value lo, Value mid, Value hi, - ValueRange readBufs, ValueRange writeBufs, - mlir::Region& comparator) { - ArithBuilder arith(b, b.getLoc()); - // The while loop runs until we reach the end of either interval. It has three - // loop-carried variables: - // 1. current output index - // 2. current read index for interval 1 - // 3. current read index for interval 2 - SmallVector whileArgTypes{lo.getType(), lo.getType(), mid.getType()}; - SmallVector whileInitArgs{lo, lo, mid}; - SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); - - // while( - auto whileOp = b.create( - whileArgTypes, whileInitArgs, - [&](OpBuilder& beforeBuilder, Location beforeLoc, ValueRange args) { - Value i0 = args[1], i1 = args[2]; - - // i0 < mid && i1 < hi) { - Value inbounds0 = arith.slt(i0, mid); - Value inbounds1 = arith.slt(i1, hi); - beforeBuilder.create( - beforeLoc, arith._and(inbounds0, inbounds1), args); - }, - [&](OpBuilder& afterBuilder, Location afterLoc, ValueRange args) { - ImplicitLocOpBuilder impLocAfterBuilder(afterLoc, afterBuilder); - Value iOut = args[0], i0 = args[1], i1 = args[2]; - - // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; - SmallVector vals0 = - loadMemrefElements(impLocAfterBuilder, readBufs, i0); - SmallVector vals1 = - loadMemrefElements(impLocAfterBuilder, readBufs, i1); - - // writeBufs[iOut] = comparator(vals1, vals0) - // ? readBufs[i1++] : readBufs[i0++]; - Value cmp = - emitComparison(impLocAfterBuilder, vals1, vals0, comparator); - SmallVector pickedVals; - for (auto [val0, val1] : llvm::zip(vals0, vals1)) { - pickedVals.push_back( - impLocAfterBuilder.create(cmp, val1, val0)); - } - storeMemrefElements(impLocAfterBuilder, writeBufs, iOut, pickedVals); - Value one = impLocAfterBuilder.create(1); - Value nexti0 = - impLocAfterBuilder.create(cmp, i0, arith.add(i0, one)); - Value nexti1 = - impLocAfterBuilder.create(cmp, arith.add(i1, one), i1); - - // ++iOut; - Value nextIOut = impLocAfterBuilder.create(iOut, one); - impLocAfterBuilder.create( - ValueRange{nextIOut, nexti0, nexti1}); - }); - - // At this point, exactly one of the input ranges will have leftover elements. - Value iOut = whileOp->getResult(0); - Value i0 = whileOp->getResult(1); - Value i1 = whileOp->getResult(2); - - // We could use memref::CopyOp here, but typically, there aren't many leftover - // elements for randomly shuffled inputs. 
- Value leftoverIn0 = arith.slt(i0, mid); - Value start = arith.select(leftoverIn0, i0, i1); - Value end = arith.select(leftoverIn0, mid, hi); - Value n = arith.sub(end, start); - - Value zero = b.create(0); - Value one = b.create(1); - auto forOp = b.create(zero, n, one); - b.setInsertionPointToStart(forOp.getBody()); - Value copyIndex = forOp.getInductionVar(); - - Value srcIndex = arith.add(start, copyIndex); - Value dstIndex = arith.add(iOut, copyIndex); - storeMemrefElements(b, writeBufs, dstIndex, - loadMemrefElements(b, readBufs, srcIndex)); -} - -Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, - int64_t staticSortDimSize, ValueRange inputMemrefs, - ValueRange outputs0, ValueRange outputs1, - mlir::Region& comparator) { - ArithBuilder arith(b, b.getLoc()); - Value size = arith.sub(hi, lo); - - Value zero = b.create(0); - Value insertionSortSize = - b.create(kInsertionSortSize); - - // Run insertion sort on blocks of size kInsertionSortSize. - { - auto forBody = [&](OpBuilder& ob, Location loc, Value start, ValueRange) { - ImplicitLocOpBuilder b = ImplicitLocOpBuilder(loc, ob); - Value end = arith.add( - b.create(arith.add(start, insertionSortSize), size), lo); - emitInsertionSort(b, start, end, inputMemrefs, outputs0, comparator); - b.create(ValueRange{}); - }; - b.create(/*lowerBound=*/zero, /*upperBound=*/size, - /*step=*/insertionSortSize, /*iterArgs=*/std::nullopt, - forBody); - } - - Value initParity = b.create(/*value=*/0, /*width=*/1); - if (staticSortDimSize >= 0 && - staticSortDimSize < static_cast(kInsertionSortSize)) { - return initParity; - } - - // The while arguments are: - // 1. the current size - // 2. a boolean stating whether we are reading from outputs0 or outputs1 - // - // 1 gets doubled each iteration, 2 gets negated. - // int currentSize = kInsertionSortSize; - SmallVector whileInitArgs{insertionSortSize, initParity}; - // First we read from `outputs0` (initialized by the insertion sort above). 
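// A minimal scalar sketch of the ping-pong scheme built here: sort fixed-size
// runs, then merge runs of doubling size back and forth between two equally
// sized buffers, flipping a parity bit that records which buffer currently
// holds the data. std::sort, std::merge, the buffer names, and the assumption
// that the unsorted data starts in buf0 are illustrative simplifications of
// the memref-based code and are not taken from this file.
#include <algorithm>
#include <cstddef>
#include <vector>

// Returns false if the sorted data ends up in buf0, true if it ends in buf1.
// Precondition: buf0.size() == buf1.size().
inline bool BottomUpMergeSortSketch(std::vector<float>& buf0,
                                    std::vector<float>& buf1,
                                    std::size_t run_size = 16) {
  const std::size_t n = buf0.size();
  // Sort blocks of run_size elements (the lowering uses insertion sort here).
  for (std::size_t lo = 0; lo < n; lo += run_size) {
    std::sort(buf0.begin() + lo, buf0.begin() + std::min(lo + run_size, n));
  }
  bool parity = false;  // false: the current data lives in buf0.
  for (std::size_t size = run_size; size < n; size *= 2) {
    std::vector<float>& src = parity ? buf1 : buf0;
    std::vector<float>& dst = parity ? buf0 : buf1;
    // Merge adjacent runs [lo, mid) and [mid, hi) from src into dst.
    for (std::size_t lo = 0; lo < n; lo += 2 * size) {
      std::size_t mid = std::min(lo + size, n);
      std::size_t hi = std::min(lo + 2 * size, n);
      std::merge(src.begin() + lo, src.begin() + mid, src.begin() + mid,
                 src.begin() + hi, dst.begin() + lo);
    }
    parity = !parity;  // the merged data now lives in the other buffer.
  }
  return parity;
}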
- - SmallVector whileArgTypes; - for (auto val : whileInitArgs) whileArgTypes.push_back(val.getType()); - - SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); - - // while ( - auto whileOp = b.create( - whileArgTypes, whileInitArgs, - [&](OpBuilder& beforeBuilder, Location beforeLoc, ValueRange args) { - // currentSize < totalSize) - Value currentSize = args[0]; - beforeBuilder.create( - beforeLoc, arith.slt(currentSize, size), args); - }, - [&](OpBuilder& afterBuilder, Location afterLoc, ValueRange args) { - ImplicitLocOpBuilder impLocAfterBuilder = - ImplicitLocOpBuilder(afterLoc, afterBuilder); - - // { - Value currentSize = args[0], parity = args[1]; - Value twoCurrentSize = arith.add(currentSize, currentSize); - - // emitMergeLoop(readBufs, writeBufs) { - // for (int start = 0; start < size; start += 2*currentSize) { - auto emitMergeLoop = [&](OpBuilder& builder, Location loc, - ValueRange readBufs, ValueRange writeBufs) { - ImplicitLocOpBuilder localImpLocBuilder(loc, builder); - ArithBuilder localArithBuilder(localImpLocBuilder, loc); - - auto forOp = - localImpLocBuilder.create(zero, size, twoCurrentSize); - OpBuilder::InsertionGuard guard(localImpLocBuilder); - localImpLocBuilder.setInsertionPointToStart(forOp.getBody()); - Value start = forOp.getInductionVar(); - - Value mid = localImpLocBuilder.create( - size, localArithBuilder.add(start, currentSize)); - Value end = localImpLocBuilder.create( - size, localArithBuilder.add(start, twoCurrentSize)); - emitMerge(localImpLocBuilder, start, mid, end, readBufs, writeBufs, - comparator); - return; - }; - // } - // } - - // if (parity) - // emitMergeLoop(outputs1, outputs0) - // else - // emitMergeLoop(outputs0, outputs1) - impLocAfterBuilder.create( - /*cond=*/parity, - /*thenBuilder=*/ - [&](OpBuilder& builder, Location loc) { - emitMergeLoop(builder, loc, outputs1, outputs0); - builder.create(loc, ValueRange{}); - }, - /*elseBuilder=*/ - [&](OpBuilder& builder, Location loc) { - emitMergeLoop(builder, loc, outputs0, outputs1); - builder.create(loc, ValueRange{}); - }); - - // parity = !parity; - Value one = impLocAfterBuilder.create(1, 1); - Value notParity = arith.sub(one, parity); - // currentSize *= 2; - SmallVector nextWhileArgs{twoCurrentSize, notParity}; - impLocAfterBuilder.create(nextWhileArgs); - }); - // } - - // The result is the parity bit. 
- return whileOp.getResult(1); -} - -struct Slicer { - Slicer(OpBuilder& b, int64_t sortDim, Value sortDimSize, - ValueRange inductionVariables) - : sizes(inductionVariables.size() + 1, b.getI64IntegerAttr(1)), - strides(inductionVariables.size() + 1, b.getI64IntegerAttr(1)) { - sizes[sortDim] = sortDimSize; - for (size_t i = 0; i < inductionVariables.size() + 1; ++i) { - if ((int64_t)i == sortDim) { - offsets.push_back(b.getI64IntegerAttr(0)); - } else { - offsets.push_back( - inductionVariables[i - static_cast((int64_t)i > sortDim)]); - } - } - } - - Value slice(ImplicitLocOpBuilder& b, Value input) { - auto ty = input.getType().cast(); - auto slicedType = - memref::SubViewOp::inferRankReducedResultType( - {ShapedType::kDynamic} /*1D output*/, ty, offsets, sizes, strides) - .cast(); - return b - .create(slicedType, input, offsets, sizes, strides) - .getResult(); - } - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -SmallVector sliceMemrefs(ImplicitLocOpBuilder& b, - SmallVector& inductionVariables, - Value sortDimSize, ValueRange memrefs, - SortOp op) { - if (inductionVariables.empty()) return memrefs; - - SmallVector slices; - Slicer slicer(b, op.getDimension().getSExtValue(), sortDimSize, - inductionVariables); - - for (Value out : memrefs) slices.push_back(slicer.slice(b, out)); - - return slices; -} - -struct SortOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SortOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Lowering thlo to our merge sort implementation necessarily happens after - // bufferization. - if (!op.hasBufferSemantics()) - return op->emitError() << "expected buffer semantics"; - - // Note: the output memrefs aren't necessarily the ones that we return - ValueRange outputMemrefs = op.getInits(); - SmallVector scratchMemrefs; - scratchMemrefs.reserve(outputMemrefs.size()); - - Value firstInput = op.getOperand(0); - auto firstInputType = firstInput.getType().cast(); - int64_t inputRank = firstInputType.getRank(); - - int64_t sortDim = op.getDimension().getSExtValue(); - Value sortDimSize = b.createOrFold( - firstInput, b.create(sortDim)); - int64_t staticSortDimSize = firstInputType.getDimSize(sortDim); - - SmallVector dynamicDims; - for (int i = 0; i < inputRank; ++i) { - if (!firstInputType.isDynamicDim(i)) continue; - auto index = b.createOrFold(i); - Value dimOp = b.create(firstInput, index); - dynamicDims.push_back(dimOp); - } - - // Allocate scratch memrefs. If the size of the sort dimension is - // statically known to be <= kInsertionSortSize, `scratchMemrefs` are unused - // and will be cleaned up later. 
- for (auto input : op.getInputs()) { - auto inputType = input.getType().cast(); - auto memRefType = - MemRefType::get(inputType.getShape(), inputType.getElementType()); - scratchMemrefs.emplace_back( - b.create(memRefType, dynamicDims)); - } - - b.setInsertionPoint(op); - Value zero = b.create(0); - Value one = b.create(1); - - Value forInitArg = b.create(/*value=*/0, /*width=*/1); - SmallVector forOps; - SmallVector inductionVariables; - forOps.reserve(inputRank - 1); - inductionVariables.reserve(inputRank - 1); - for (int64_t i = 0; i < inputRank; ++i) { - if (i != sortDim) { - Value dim = b.create(i); - Value upperBound = b.create(firstInput, dim); - scf::ForOp& forOp = forOps.emplace_back(b.create( - zero, upperBound, one, ValueRange{forInitArg})); - inductionVariables.push_back(forOp.getInductionVar()); - b.setInsertionPointToStart(forOp.SingleBlock::getBody()); - } - } - SmallVector inputs = - sliceMemrefs(b, inductionVariables, sortDimSize, op.getInputs(), op); - SmallVector outputs = - sliceMemrefs(b, inductionVariables, sortDimSize, outputMemrefs, op); - SmallVector scratches = - sliceMemrefs(b, inductionVariables, sortDimSize, scratchMemrefs, op); - - Value parity = - emitBottomUpMergeSort(b, zero, sortDimSize, staticSortDimSize, inputs, - outputs, scratches, op.getRegion()); - - // Pass the parity bit through the for loops. - for (auto i = static_cast(forOps.size() - 1); i >= 0; --i) { - b.setInsertionPointToEnd(&forOps[i].getRegion().front()); - b.create(ValueRange{parity}); - parity = forOps[i]->getResult(0); - } - b.setInsertionPoint(op); - - // If the results are in the scratch memrefs, copy them to the output - // memrefs. - auto thenBlock = [&](OpBuilder& ob, Location loc) { - ImplicitLocOpBuilder b = ImplicitLocOpBuilder(loc, ob); - for (auto [target, source] : llvm::zip(outputMemrefs, scratchMemrefs)) { - b.create(source, target); - } - b.create(ValueRange{}); - }; - - rewriter.replaceOpWithNewOp(op, /*cond=*/parity, - /*thenBuilder=*/thenBlock, - /*elseBuilder=*/nullptr); - - for (Value scratchMemref : scratchMemrefs) { - b.create(scratchMemref); - } - - return success(); - } -}; - -struct LegalizeSortPass - : public impl::ThloLegalizeSortPassBase { - // Perform the lowering to MLIR control flow. - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext* ctx = f.getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(ctx); - - mlir::ConversionTarget target(*ctx); - target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); - target.addIllegalOp(); - - if (failed(applyPartialConversion(f, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -} // namespace thlo -} // namespace mlir - -std::unique_ptr> -mlir::thlo::createLegalizeSortPass() { - return std::make_unique(); -} diff --git a/third_party/xla/xla/mlir_hlo/thlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/thlo/transforms/passes.h deleted file mode 100644 index 7ac8499f714742..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/transforms/passes.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_THLO_TRANSFORMS_PASSES_H -#define MLIR_HLO_THLO_TRANSFORMS_PASSES_H - -#include - -#include "mlir/Pass/Pass.h" - -namespace mlir { - -template -class OperationPass; - -namespace func { -class FuncOp; -} // namespace func - -namespace thlo { - -#define GEN_PASS_DECL_THLOLEGALIZESORTPASS -#include "thlo/transforms/thlo_passes.h.inc" - -/// Lowers sort to Arith, MemRef, and SCF -std::unique_ptr> createLegalizeSortPass(); - -#define GEN_PASS_REGISTRATION -#include "thlo/transforms/thlo_passes.h.inc" - -} // namespace thlo -} // namespace mlir - -#endif // MLIR_HLO_THLO_TRANSFORMS_PASSES_H diff --git a/third_party/xla/xla/mlir_hlo/thlo/transforms/thlo_passes.td b/third_party/xla/xla/mlir_hlo/thlo/transforms/thlo_passes.td deleted file mode 100644 index be0bdf43816dbe..00000000000000 --- a/third_party/xla/xla/mlir_hlo/thlo/transforms/thlo_passes.td +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -include "mlir/Pass/PassBase.td" - -def ThloLegalizeSortPass : Pass<"thlo-legalize-sort", "func::FuncOp"> { - let summary = - "Legalize from THLO sort with buffer semantics to SCF control flow."; - let constructor = "createLegalizeSortPass()"; - let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect", - "scf::SCFDialect"]; -} \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt index 0e6b331876d738..0f20301d11c76d 100644 --- a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -22,13 +22,9 @@ set(LIBS ${extension_libs} MLIROptLib - AllGmlStPasses AllMhloPasses - AllThloPasses DeallocationDialect DeallocationPasses - GmlStDialect - GmlStPasses LmhloDialect LmhloGPUDialect LmhloPasses @@ -43,7 +39,6 @@ add_llvm_executable(mlir-hlo-opt mlir-hlo-opt.cc DEPENDS MLIRLmhloPassIncGen MLIRMhloPassIncGen - MLIRThloPassIncGen LMHLOTransformsPassIncGen LMHLOGPUTransformsPassIncGen ) diff --git a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc index 1f2890c164da53..db46cfdfd34f5f 100644 --- a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc +++ b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc @@ -15,9 +15,6 @@ limitations under the License. 
#include "deallocation/IR/deallocation_ops.h" #include "deallocation/transforms/passes.h" -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/transforms/passes.h" -#include "gml_st/transforms/test_passes.h" #include "lhlo/IR/lhlo_ops.h" #include "lhlo/transforms/passes.h" #include "lhlo_gpu/IR/lhlo_gpu_ops.h" @@ -28,8 +25,6 @@ limitations under the License. #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "stablehlo/dialect/Register.h" -#include "thlo/IR/thlo_ops.h" -#include "thlo/transforms/passes.h" #include "transforms/gpu_passes.h" #include "transforms/passes.h" @@ -38,25 +33,10 @@ using namespace mlir; int main(int argc, char** argv) { registerAllPasses(); deallocation::registerDeallocationPasses(); - gml_st::registerGmlStPasses(); - gml_st::registerGmlStTestPasses(); hlo::registerLMHLOTransformsPasses(); lmhlo::registerAllLmhloPasses(); mhlo::registerAllMhloPasses(); registerLMHLOGPUTransformsPasses(); - thlo::registerAllThloPasses(); - - PassPipelineRegistration - gmlStCpuTilingPipeline("gml-st-cpu-tiling-pipeline", - "Tiles, fuses, vectorizes tileable ops for CPU", - gml_st::addCPUTilingPipeline); - - PassPipelineRegistration<> defaultGmlStCpuTilingPipeline( - "default-gml-st-cpu-tiling-pipeline", - "Tiles, fuses, vectorizes tileable ops for CPU with default parameters", - [](OpPassManager& pm) { - gml_st::addDefaultCPUTilingPipeline(pm, /*cpuName=*/""); - }); DialectRegistry registry; registerAllDialects(registry); @@ -64,10 +44,6 @@ int main(int argc, char** argv) { mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); registry.insert(); - - registerTestHloTransformDialectEraseSchedulePass(); - registerTestHloTransformDialectInterpreterPass(); + lmhlo_gpu::LmhloGpuDialect>(); return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt index 5187d9707ea74e..c011ce3a724224 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/transforms/CMakeLists.txt @@ -48,8 +48,6 @@ add_mlir_library(MLIRBufferTransforms LINK_LIBS PUBLIC ChloOps - GmlStBufferizableOpInterface - GmlStDialect MLIRGPUDialect MLIRHLOAnalysis MLIRIR @@ -62,8 +60,6 @@ add_mlir_library(MLIRBufferTransforms MLIRX86VectorDialect MLIRX86VectorTransforms MhloDialect - THLODialect - ThloBufferizableOpInterface ) add_mlir_library(MLIRHLOGPUTransforms @@ -77,7 +73,6 @@ add_mlir_library(MLIRHLOGPUTransforms Core LINK_LIBS PUBLIC - GmlStPasses MLIRArithTransforms MLIRGPUDialect MLIRHLOAnalysis diff --git a/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc b/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc index 9c2d3e4b9c69f9..057e7289ca3068 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc @@ -21,8 +21,6 @@ limitations under the License. #include #include -#include "gml_st/IR/gml_st_ops.h" -#include "gml_st/interfaces/bufferizable_op_interface_impl.h" #include "lhlo/IR/lhlo_ops.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" @@ -72,8 +70,6 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/ChloOps.h" -#include "thlo/IR/thlo_ops.h" -#include "thlo/interfaces/bufferizable_op_interface_impl.h" #include "transforms/passes.h" #include "transforms/rewriters.h" @@ -138,13 +134,12 @@ struct ComputeOpAndFuncBufferizePass : public impl::ComputeOpAndFuncBufferizePassBase< ComputeOpAndFuncBufferizePass> { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry + .insert(); linalg::registerBufferizableOpInterfaceExternalModels(registry); mhlo::registerBufferizableOpInterfaceExternalModels(registry); - thlo::registerBufferizableOpInterfaceExternalModels(registry); shape::registerBufferizableOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); } @@ -159,7 +154,7 @@ struct ComputeOpAndFuncBufferizePass options.opFilter.allowDialect(); + vector::VectorDialect>(); if (failed(bufferization::bufferizeOp(getOperation(), options))) { signalPassFailure(); @@ -176,11 +171,10 @@ struct ComputeOpAndFuncBufferizePass RewritePatternSet patterns(&getContext()); auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); + target.addLegalDialect< + affine::AffineDialect, arith::ArithDialect, complex::ComplexDialect, + func::FuncDialect, lmhlo::LmhloDialect, math::MathDialect, + memref::MemRefDialect, tensor::TensorDialect, vector::VectorDialect>(); target.addLegalOp(); target.addIllegalDialect(); @@ -219,20 +213,18 @@ struct OneShotBufferizePass : public impl::OneShotBufferizeBase { // TODO(b/173201243): Move to tablegen. void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry + .insert(); arith::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); mhlo::registerBufferizableOpInterfaceExternalModels(registry); - gml_st::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); shape::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); - thlo::registerBufferizableOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); } @@ -270,16 +262,15 @@ struct FinalBufferizePass public: void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry + .insert(); arith::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); shape::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); - thlo::registerBufferizableOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); if (dialectsCallback) dialectsCallback(registry); } @@ -305,7 +296,7 @@ struct FinalBufferizePass options.opFilter.allowDialect< arith::ArithDialect, bufferization::BufferizationDialect, linalg::LinalgDialect, func::FuncDialect, shape::ShapeDialect, - tensor::TensorDialect, thlo::THLODialect, vector::VectorDialect>(); + tensor::TensorDialect, vector::VectorDialect>(); if (failed(bufferization::bufferizeOp(getOperation(), options))) { signalPassFailure(); return; @@ -325,8 +316,7 @@ struct FinalBufferizePass 
cf::ControlFlowDialect, complex::ComplexDialect, memref::MemRefDialect, func::FuncDialect, scf::SCFDialect, tensor::TensorDialect, affine::AffineDialect, shape::ShapeDialect, lmhlo::LmhloDialect, - linalg::LinalgDialect, math::MathDialect, thlo::THLODialect, - vector::VectorDialect>(); + linalg::LinalgDialect, math::MathDialect, vector::VectorDialect>(); target.addLegalOp(); target.addIllegalDialect(); diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 9f30361fa77afe..336069e8c9baf5 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -243,7 +243,6 @@ cc_library( "//xla/mlir/runtime/transforms:jit_compiler", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:gml_st_passes", "//xla/mlir_hlo:lhlo", "//xla/mlir_hlo:mhlo_passes", "//xla/mlir_hlo:transforms_passes", @@ -477,9 +476,7 @@ cc_library( "//xla/mlir/backends/cpu/transforms:passes", "//xla/mlir/runtime/transforms:compiler", "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:gml_st_bufferizable_op_interface", "//xla/mlir_hlo:mhlo_passes", - "//xla/mlir_hlo:thlo_bufferizable_op_interface", "//xla/mlir_hlo:transforms_passes", "//xla/runtime:compiler", "@llvm-project//mlir:ArithTransforms", diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 383437acf4eeef..5125dd46256764 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -259,8 +259,7 @@ void LoadMLIRDialects(mlir::MLIRContext& context) { xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( llvm::Triple target_triple, llvm::StringRef cpu_name) { xla::cpu::HloXlaRuntimePipelineOptions options; - options.enable_tiling_and_fusion = - xla::GetDebugOptionsFromFlags().xla_cpu_enable_mlir_tiling_and_fusion(); + options.enable_tiling_and_fusion = false; if (xla::GetDebugOptionsFromFlags().xla_cpu_enable_custom_matmul_tiling()) { options.matmul_tile_sizes = { xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_m_dim(), diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index fe402498d76d1e..5d96212cfba1d2 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -48,12 +48,8 @@ limitations under the License. #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/runtime/transforms/compiler.h" #include "xla/mlir_hlo/deallocation/transforms/passes.h" -#include "xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h" -#include "xla/mlir_hlo/gml_st/transforms/passes.h" #include "xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h" -#include "xla/mlir_hlo/thlo/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" #include "xla/status.h" #include "tsl/platform/errors.h" @@ -100,7 +96,6 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, } // Sparsification set up. 
pm.addNestedPass(mlir::createLinalgGeneralizationPass()); - pm.addNestedPass(mlir::gml_st::createRewriteFromElementsOpPass()); pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass()); pm.addPass(mlir::createSparsificationAndBufferizationPass( GetBufferizationOptions(new_deallocator), sparsification_options, @@ -184,21 +179,13 @@ static Status CreateHloXlaPipeline( pm.addNestedPass(mlir::mhlo::createHloCanonicalizeDotPass()); pm.addNestedPass(mlir::mhlo::createGroupReductionDimensionsPass()); pm.addNestedPass( - mlir::mhlo::createLegalizeMHLOToTHLOPass()); - pm.addNestedPass( - mlir::mhlo::createLegalizeHloToLinalgPass( - options.enable_tiling_and_fusion)); + mlir::mhlo::createLegalizeHloToLinalgPass()); // Lower index cast on tensors to tensor.generate. pm.addNestedPass(mlir::createLowerIndexCastPass()); pm.addPass(mlir::mhlo::createConvertToSignlessPass()); - // Tile tHLO ops to 1. - if (!options.enable_tiling_and_fusion) { - pm.addNestedPass(mlir::gml_st::createTileByOnePass()); - } - // Lower shape dialect to standard to enable linalg canonicalizations (e.g. // use linalg inputs instead of outputs for memref.dim operations). pm.addNestedPass(mlir::mhlo::createShapeSimplification()); @@ -212,17 +199,7 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass( - mlir::gml_st::createOptimizeLinalgOpsPass()); - if (options.enable_tiling_and_fusion) { - mlir::gml_st::GmlStCPUTilingOptions opts = - mlir::gml_st::getDefaultCPUPipelineOptions(options.cpu_name); - opts.matmulTileSizes = options.matmul_tile_sizes; - opts.inlineFusionClusters = false; - mlir::gml_st::addCPUTilingPipeline(pm, opts); - } else { - pm.addNestedPass( - mlir::createLinalgElementwiseOpFusionPass()); - } + mlir::createLinalgElementwiseOpFusionPass()); pm.addPass(mlir::createReconcileUnrealizedCastsPass()); pm.addPass(mlir::createConvertTensorToLinalgPass()); @@ -236,8 +213,6 @@ static Status CreateHloXlaPipeline( return tsl::errors::Internal("Failed to set up detensorize pass."); } pm.addNestedPass(std::move(detensorize)); - pm.addNestedPass(mlir::gml_st::createScalarizationPass()); - pm.addNestedPass(mlir::gml_st::createRewriteFromElementsOpPass()); pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass()); pm.addNestedPass( mlir::bufferization::createEmptyTensorToAllocTensorPass()); @@ -260,18 +235,18 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::hlo::createOneShotBufferizePass()); } pm.addNestedPass(createRewriteReallocToAllocPass()); - - if (options.enable_fusion_outlining) { - pm.addPass(mlir::gml_st::createFusionOutliningPass()); - pm.addPass(mlir::func::createDuplicateFunctionEliminationPass()); - } - pm.addNestedPass(mlir::gml_st::createInlineFusionClustersPass()); - pm.addNestedPass(mlir::createVectorizeCopyPass()); pm.addNestedPass(mlir::createNaiveCopyRemovalPass()); - // Handle framework specific requirements for buffers and then insert - // deallocations for temporary buffers. - pm.addNestedPass(mlir::createConvertLinalgToLoopsPass()); + + // This should be unified. It exists, because the async runtime tests expect + // parallel loops. 
+ if (options.sparse_bufferization) { + pm.addNestedPass( + mlir::createConvertLinalgToLoopsPass()); + } else { + pm.addNestedPass( + mlir::createConvertLinalgToParallelLoopsPass()); + } pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); mlir::bufferization::BufferResultsToOutParamsOptions out_params_options; @@ -305,7 +280,6 @@ static Status CreateHloXlaPipeline( xla::cpu::createRemoveCopiesToOutParamsPass()); } } - pm.addNestedPass(mlir::thlo::createLegalizeSortPass()); // Specialize linalg.matmul to linalg.dot, linalg.matvec or linalg.vecmat, // and immediately canonicalize to clean up not taken branches. @@ -318,9 +292,6 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); - - pm.addNestedPass( - mlir::gml_st::createLowerVectorsPass(options.enable_avx2)); pm.addNestedPass(xla::cpu::createLegalizeI1VectorTransferOpsPass()); pm.addNestedPass( xla::cpu::createConvertXlaCpuMemRefElementCastToLLVMPass()); @@ -342,12 +313,10 @@ void RegisterHloXlaRuntimePipelineDialects(mlir::DialectRegistry& dialects) { mlir::arith::registerBufferizableOpInterfaceExternalModels(dialects); mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( dialects); - mlir::gml_st::registerBufferizableOpInterfaceExternalModels(dialects); mlir::linalg::registerBufferizableOpInterfaceExternalModels(dialects); mlir::linalg::registerTilingInterfaceExternalModels(dialects); mlir::mhlo::registerBufferizableOpInterfaceExternalModels(dialects); mlir::scf::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::thlo::registerBufferizableOpInterfaceExternalModels(dialects); mlir::shape::registerBufferizableOpInterfaceExternalModels(dialects); mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(dialects); mlir::tensor::registerBufferizableOpInterfaceExternalModels(dialects); From 8d07dddc09f07f9ac2b319c4211465a360e5180a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 25 Nov 2023 00:44:04 -0800 Subject: [PATCH 069/381] Internal Code Change PiperOrigin-RevId: 585216825 --- tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index 03558438ac6f6b..90ab3af857c542 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -1,7 +1,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", "//tensorflow/core/tfrt:__subpackages__", From 3382bddb8569be8f34c5f36776b53bede542012f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 25 Nov 2023 01:02:02 -0800 Subject: [PATCH 070/381] compat: Update forward compatibility horizon to 2023-11-25 PiperOrigin-RevId: 585219170 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 85ae86d88d993d..480bfeb99d19a8 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. 
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 25) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 0a95eea143dfd14ef7d17c4a91dc75cf4c628b3a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 25 Nov 2023 01:02:10 -0800 Subject: [PATCH 071/381] Update GraphDef version to 1691. PiperOrigin-RevId: 585219200 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 7f560250ce2e94..f1c3542b1641d0 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1690 // Updated: 2023/11/24 +#define TF_GRAPH_DEF_VERSION 1691 // Updated: 2023/11/25 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From d308eb690530df95e8fb10dc7fb109d62181fc1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 25 Nov 2023 21:55:13 -0800 Subject: [PATCH 072/381] Adds unique_id property to SparseCoreStackedTableTrackable PiperOrigin-RevId: 585365755 --- tensorflow/python/tpu/tpu_embedding_v3_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/python/tpu/tpu_embedding_v3_utils.py b/tensorflow/python/tpu/tpu_embedding_v3_utils.py index 08ed796c54b492..276731051be54f 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v3_utils.py @@ -170,6 +170,12 @@ def __init__(self, stacked_layouts, table_to_config): shape=variable_shape, dtype=dtypes.float32, ) + # TODO(b/312743130): This is a workaround. During checkpoint restoration + # optimizer expects the trackable to provide a `_unique_id` or equivalent. + # Remove this when the bug is fixed. + @property + def _unique_id(self): + return self.vars[self._stacked_layouts[0].table_name]._unique_id def _serialize_to_tensors(self) -> Any: return { From 6a5aec429f0158a6837fb558f13e0a11ccfcfc1d Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Sun, 26 Nov 2023 00:04:56 -0800 Subject: [PATCH 073/381] [XLA:GPU] NFC: use matchers in PriorityFusionTest. PiperOrigin-RevId: 585379423 --- third_party/xla/xla/service/gpu/BUILD | 2 ++ .../xla/service/gpu/priority_fusion_test.cc | 34 +++++++++++-------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index f07b8878c36b6f..d39a1181e080db 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2143,6 +2143,8 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", ], ) diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index fc03003a983dbe..16cbcdf9e0cd65 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include +#include #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -33,9 +34,15 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/status_matchers.h" namespace m = ::xla::match; +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + namespace xla { namespace gpu { @@ -51,8 +58,8 @@ class PriorityFusionTest : public HloTestBase { std::vector RunAndGetFusionKinds( absl::string_view hlo) { auto module = ParseAndReturnVerifiedModule(hlo).value(); - EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); - TF_CHECK_OK(module->RemoveUnusedComputations()); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(module->RemoveUnusedComputations(), IsOk()); std::vector kinds; for (auto computation : module->computations()) { if (!computation->FusionInstruction()) continue; @@ -90,7 +97,7 @@ TEST_F(PriorityFusionTest, FuseWithSharedArgument) { })") .value(); - EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Fusion())); @@ -237,10 +244,10 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { } )"; - EXPECT_THAT(RunAndGetFusionKinds(kHlo), - ::testing::UnorderedElementsAre( - HloFusionAnalysis::EmitterFusionKind::kLoop, - HloFusionAnalysis::EmitterFusionKind::kReduction)); + EXPECT_THAT( + RunAndGetFusionKinds(kHlo), + UnorderedElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop, + HloFusionAnalysis::EmitterFusionKind::kReduction)); RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( CHECK: ENTRY @@ -317,9 +324,9 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { using Kind = HloFusionAnalysis::EmitterFusionKind; EXPECT_THAT(RunAndGetFusionKinds(kHlo), - ::testing::UnorderedElementsAre( - Kind::kLoop, Kind::kReduction, Kind::kReduction, - Kind::kTranspose, Kind::kTranspose, Kind::kTranspose)); + UnorderedElementsAre(Kind::kLoop, Kind::kReduction, + Kind::kReduction, Kind::kTranspose, + Kind::kTranspose, Kind::kTranspose)); } TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) { @@ -486,8 +493,7 @@ TEST_F(PriorityFusionTest, SingleTransposeFusion) { })"; using Kind = HloFusionAnalysis::EmitterFusionKind; - EXPECT_THAT(RunAndGetFusionKinds(kHlo), - ::testing::ElementsAre(Kind::kTranspose)); + EXPECT_THAT(RunAndGetFusionKinds(kHlo), ElementsAre(Kind::kTranspose)); } TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) { @@ -516,7 +522,7 @@ TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) { ROOT add = s32[3,3] add(scatter, scatter) })"); - EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* fusion = nullptr; @@ -645,7 +651,7 @@ TEST_F(PriorityFusionTest, EpilogueFusionFails) { ROOT fusion = f32[28672]{0} fusion(f,p1), kind=kLoop, calls=%fused_computation.2 })"); - EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } } // namespace gpu From 41223388f0bb5f5e0258a90ae26f4a707e519877 Mon Sep 17 00:00:00 2001 From: 
Christian Sigg Date: Sun, 26 Nov 2023 00:50:45 -0800 Subject: [PATCH 074/381] [XLA:GPU] NFC: make IsReadCoalesced() slightly easier to read. PiperOrigin-RevId: 585386076 --- .../gpu/model/gpu_performance_model.cc | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 30b35ec03e1a4a..1abceb893dcc56 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -268,22 +268,26 @@ bool IsReadCoalesced(const std::optional& fusion_analysis, const HloInstruction* consumer = nullptr) { if (!config.consider_coalescing) return true; - bool coalesced = (fusion_analysis && - fusion_analysis->GetEmitterFusionKind() == - HloFusionAnalysis::EmitterFusionKind::kTranspose) || - (!TransposesMinorDimension(producer) && - !(consumer && TransposesMinorDimension(consumer))); - - if (consumer) { - // Fusing two row reductions breaks coalescing. - coalesced &= (fusion_analysis && - fusion_analysis->GetEmitterFusionKind() != - HloFusionAnalysis::EmitterFusionKind::kReduction) || - !IsInputFusibleReduction(*producer) || - !IsInputFusibleReduction(*consumer); - } - - return coalesced; + auto analyzed_kind_or_reduction = + fusion_analysis ? fusion_analysis->GetEmitterFusionKind() + : HloFusionAnalysis::EmitterFusionKind::kReduction; + + // Transposing minor dimension breaks coalescing. + if (analyzed_kind_or_reduction != + HloFusionAnalysis::EmitterFusionKind::kTranspose) { + if (TransposesMinorDimension(producer)) return false; + if (consumer && TransposesMinorDimension(consumer)) return false; + } + + // Fusing two row reductions breaks coalescing. + if (analyzed_kind_or_reduction == + HloFusionAnalysis::EmitterFusionKind::kReduction && + IsInputFusibleReduction(*producer) && consumer && + IsInputFusibleReduction(*consumer)) { + return false; + } + + return true; } } // namespace From 8efd58cc9ef1f575b69e0348251c111f70f3199c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 26 Nov 2023 01:02:11 -0800 Subject: [PATCH 075/381] compat: Update forward compatibility horizon to 2023-11-26 PiperOrigin-RevId: 585387882 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 480bfeb99d19a8..9c937609ef9c3b 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 25) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 26) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 339d1da0acb87d63ee085d588367292c5a37a467 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 26 Nov 2023 01:02:16 -0800 Subject: [PATCH 076/381] Update GraphDef version to 1692. 
PiperOrigin-RevId: 585387902 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index f1c3542b1641d0..d55e052e604047 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1691 // Updated: 2023/11/25 +#define TF_GRAPH_DEF_VERSION 1692 // Updated: 2023/11/26 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 7b8dd63247d2d62254c3021c1f1188a60eb5eb85 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Sun, 26 Nov 2023 19:50:57 -0800 Subject: [PATCH 077/381] (1/N) Support fused ops in quantize_composite_functions. - Adds static bias fusion. - Also merges test_matmul_ptq_model for testing stablehlo fused operations. Follow up cl's will cover the following: - activation fusion - convolution handling - dynamic shapes handling PiperOrigin-RevId: 585513542 --- .../mlir/quantization/stablehlo/BUILD | 1 + .../passes/quantize_composite_functions.cc | 150 +++++++++++++++--- .../integration_test/quantize_model_test.py | 2 +- .../tests/quantize_composite_functions.mlir | 143 +++++++++++------ .../utils/stablehlo_type_utils_test.cc | 21 ++- 5 files changed, 233 insertions(+), 84 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 08e9b917c4c5a5..f8babf008f107c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -92,6 +92,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index cf0c44f779a9ae..5d391b3e858c16 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -20,9 +20,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -50,6 +52,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#define DEBUG_TYPE "quantize-composite-functions" + namespace mlir::quant::stablehlo { #define GEN_PASS_DEF_QUANTIZECOMPOSITEFUNCTIONSPASS @@ -58,7 +62,9 @@ namespace mlir::quant::stablehlo { namespace { using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::mlir::stablehlo::AddOp; using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::DynamicBroadcastInDimOp; using ::mlir::stablehlo::UniformQuantizeOp; using ::tensorflow::quantization::RunPassesOnModuleOp; @@ -85,15 +91,57 @@ bool IsQuantizedTensorType(const Type type) { type.cast().getElementType().isa(); } +// Returns true if an op has adjacent bias or activation that can be fused +// together into the quantization function. +// TODO: b/307620428 - Consider using matchAndRewrite to check and apply +// patterns at the same time. Also add check for fusible activation or +// fusible patterns with dynamic shape. +bool HasFusibleQuantizationPattern(Operation& op) { + if (isa(op.getNextNode())) { + return true; + } + return false; +} + +// Returns dynamically broadcasted user op of an input op. Returns null if +// the op is used multiple times or the user op is not dynamically broadcasted. +// Dynamic shapes usually has the following pattern. In the example below, +// the input operand would be stablehlo.dot_general op, and return value would +// be stablehlo.add op. +// +// ``` +// %2 = stablehlo.dot_general(%0, %1) +// %3 = shape.shape_of %2 +// %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 +// %5 = stablehlo.add %2, %4 +// ``` +Operation* GetDynamicallyBroadcastedUserOp(Operation& op) { + if (!op.hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "Target op is used multiple times and will not be checked " + "for dynamic shape case.\n"); + return nullptr; + } + Operation& shapeof_op = *op.getNextNode(); + if (!isa(shapeof_op)) { + return nullptr; + } + Operation& broadcast_in_dims_op = *shapeof_op.getNextNode(); + if (!isa(broadcast_in_dims_op)) { + return nullptr; + } + return broadcast_in_dims_op.getNextNode(); +} + // Checks if all inputs and outputs are quantized. -bool HasQuantizedOperandOrOutput(Operation* call_op) { +bool HasQuantizedOperandOrOutput(Operation& call_op) { SmallVector arg_types; - for (const Value arg : call_op->getOperands()) { + for (const Value arg : call_op.getOperands()) { arg_types.push_back(arg.getType()); } SmallVector output_types; - for (const Value output : call_op->getResults()) { + for (const Value output : call_op.getResults()) { output_types.push_back(output.getType()); } @@ -116,7 +164,7 @@ std::string GetQuantizedFunctionName(const StringRef func_name) { // 3. It should also have the `kEntryFuncAttrName` attribute, which points to // the function that `xla_call_module_op` represents. bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { - return HasQuantizedOperandOrOutput(xla_call_module_op) && + return HasQuantizedOperandOrOutput(*xla_call_module_op) && xla_call_module_op->hasAttr(kQuantTraitAttrName) && xla_call_module_op->hasAttr(kEntryFuncAttrName); } @@ -124,7 +172,7 @@ bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { // Returns the entry function, i.e. the callee of `xla_call_module_op`. 
func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, SymbolTable symbol_table) { - auto entry_function_symbol_ref = + const auto entry_function_symbol_ref = xla_call_module_op->getAttrOfType(kEntryFuncAttrName); // Don't match if there are no DotGeneralOp. @@ -162,6 +210,17 @@ void SetQuantizedFunctionType(PatternRewriter& rewriter, } } +// Creates a UniformQuantize op and sets it as return op. +void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, + func::FuncOp entry_func_op, + const Type func_result_type) { + // Add i32 -> i8 requantization. + UniformQuantizeOp uniform_quant_op = rewriter.create( + op.getLoc(), func_result_type, op.getResults()); + cast(entry_func_op.getBody().front().getTerminator()) + .setOperand(0, uniform_quant_op); +} + // An interface representing patterns that quantizes an entry function's body. // The entry function's signatures should have already been quantized at the // point of rewriting. @@ -184,54 +243,93 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeDotGeneralOpPattern(MLIRContext& ctx) : ctx_(&ctx) {} + // Match for all dot_general op and check for possible fusions. LogicalResult match(func::FuncOp entry_func_op) const override { + // function must have input, filter, and optionally bias. auto& operations = entry_func_op.getBody().front().getOperations(); - return success(operations.size() == 2 && - isa(operations.front())); + if (operations.size() != 2 && operations.size() != 3) { + return failure(); + } + if (!isa(operations.front())) { + return failure(); + } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { + LLVM_DEBUG(llvm::dbgs() + << "Currently dot_general quantization only supports static " + " shapes.\n"); + return failure(); + } + return success(); } void rewrite(func::FuncOp entry_func_op, PatternRewriter& rewriter) const override { // Update the output type of the dot_general op. - auto dot_general_op = *entry_func_op.getOps().begin(); + DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); const Type input_type = entry_func_op.getArgumentTypes()[0]; - const Type rhs_type = entry_func_op.getArgumentTypes()[1]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; const Type func_result_type = entry_func_op.getResultTypes()[0]; const double input_scale = getElementTypeOrSelf(input_type) .cast() .getScale(); - const double rhs_scale = - getElementTypeOrSelf(rhs_type).cast().getScale(); + const double filter_scale = getElementTypeOrSelf(filter_type) + .cast() + .getScale(); + const double result_scale = input_scale * filter_scale; // Define the intermediate output type, which is an i32 quantized type. // This is intermediate because the final output type of the entry_func_op // should be an i8 quantized type. 
- const UniformQuantizedType output_quantized_element_type = + const UniformQuantizedType dot_general_quantized_element_type = CreateI32F32UniformQuantizedType(dot_general_op->getLoc(), *ctx_, - input_scale * rhs_scale, + result_scale, /*zero_point=*/0); Value dot_general_op_result = dot_general_op->getResult(0); - const auto dot_general_op_result_type = + auto dot_general_op_result_type = dot_general_op_result.getType().cast(); - const ArrayRef shape = dot_general_op_result_type.getShape(); + const ArrayRef dot_general_shape = + dot_general_op_result_type.getShape(); const TensorType new_dot_general_op_result_type = - dot_general_op_result_type.cloneWith(shape, - output_quantized_element_type); + dot_general_op_result_type.cloneWith( + dot_general_shape, dot_general_quantized_element_type); dot_general_op_result.setType(new_dot_general_op_result_type); - // Add i32 -> i8 requantization. rewriter.setInsertionPointAfter(dot_general_op); - auto uniform_quant_op = rewriter.create( - dot_general_op->getLoc(), func_result_type, - dot_general_op->getResults()); - auto return_op = - cast(entry_func_op.getBody().front().getTerminator()); - return_op.setOperand(0, uniform_quant_op); + Operation& next_op = *dot_general_op->getNextNode(); + + // If an op is used multiple times, do not apply quantization of fused + // patterns to prevent removal of dependee ops. + const bool should_quantize_without_fusion = + HasFusibleQuantizationPattern(*dot_general_op.getOperation()) && + !dot_general_op->hasOneUse(); + + // TODO: b/307620428 - Add support for dynamic shapes. + if (should_quantize_without_fusion || !isa(next_op)) { + // no bias + CreateAndReturnUniformQuantizeOp(rewriter, *dot_general_op, entry_func_op, + func_result_type); + return; + } + // bias fusion + Value bias_op = next_op.getOperand(1); + Value add_op_result = next_op.getResult(0); + const auto add_op_result_type = + add_op_result.getType().cast(); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. + const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, dot_general_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = rewriter.create(dot_general_op->getLoc(), + dot_general_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); } private: @@ -264,7 +362,7 @@ void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( MLIRContext& ctx, PatternRewriter& rewriter, TF::XlaCallModuleOp xla_call_module_op, const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - auto module_op = xla_call_module_op->getParentOfType(); + ModuleOp module_op = xla_call_module_op->getParentOfType(); SymbolTable symbol_table(module_op); func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); @@ -298,7 +396,7 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { : OpRewritePattern(&ctx) {} LogicalResult match(TF::XlaCallModuleOp op) const override { - auto module_op = op->getParentOfType(); + ModuleOp module_op = op->getParentOfType(); SymbolTable symbol_table(module_op); // Ignore unquantized ops. 
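The fused bias handling added above hinges on one piece of scale arithmetic: the quantized dot_general accumulates into an i32 type whose scale is input_scale * filter_scale with zero point 0, the bias operand must carry that same i32 quantized type so the stablehlo.add stays in a single quantized domain, and only the add result is requantized to the i8 output type. The standalone C++ sketch below works through that arithmetic for a single element; the scale values, the bias value, and the round-and-clamp requantization step are illustrative assumptions, not code from this patch.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative walk-through of the scale propagation used by the fused
// dot_general + bias quantization. All numeric values are made up.
int main() {
  const double input_scale = 3.0e-3;   // assumed per-tensor scale of the input
  const double filter_scale = 2.0e-3;  // assumed per-tensor scale of the filter
  const double output_scale = 5.0e-3;  // assumed scale of the final i8 output

  // The i32 accumulator produced by the quantized dot_general uses
  // scale = input_scale * filter_scale and zero point 0.
  const double acc_scale = input_scale * filter_scale;

  // The bias is pre-quantized with the accumulator scale so that the add
  // operates on identically quantized i32 values.
  const float bias = 0.25f;
  const int32_t bias_q = static_cast<int32_t>(std::lround(bias / acc_scale));

  // Dot product of the quantized inputs (example value), plus quantized bias.
  const int32_t acc = 91234 + bias_q;

  // Requantize the i32 result to i8: rescale, round, clamp.
  long q = std::lround(acc * acc_scale / output_scale);
  if (q > 127) q = 127;
  if (q < -128) q = -128;
  const int8_t result_q = static_cast<int8_t>(q);

  std::printf("acc_scale=%g bias_q=%d result_q=%d\n", acc_scale, bias_q,
              result_q);
  return 0;
}
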
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 8a894708e0e6ed..6071ded7ed39c5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -50,7 +50,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): @parameterized.parameters( parameter_combinations([{ 'activation_fn': [None], - 'has_bias': [False], + 'has_bias': [True, False], 'batch_sizes': [([], []), ([10], [10]), ([2, 3], [2, 3])], }]) ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir index 97ea1f30be81ba..b82d8177c68537 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir @@ -5,10 +5,10 @@ module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. // expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @quantize_dot_general(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + func.func private @quantize_dot_general(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } @@ -16,99 +16,140 @@ module attributes {tf_saved_model.semantics} { // calls the quantized entry function. 
// CHECK-LABEL: func.func private @quantize_dot_general -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} -// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3x3xi8>} : () -> tensor<3x3x!quant.uniform:f32, {{.*}}> -// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x!quant.uniform>, tensor<3x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } // Checks that the entry function is quantized for dot_general. Quantized // dot_general outputs an i32 quantized tensor, followed by requantization to // i8 quantized tensor. -// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_2:.*]]: tensor<1x3x!quant.uniform>, %[[ARG_3:.*]]: tensor<3x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> // CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> } // ----- -// Tests error when there are no corresponding entry function to quantize -// (@composite_dot_general_fn). +// Tests that fused bias pattern is properly quantized. module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. 
// TODO - b/305469508: Fix the QuantizePass to avoid this warning. // expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @error_when_no_entry_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> -// expected-error @+2 {{Failed to find a valid entry function}} -// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + func.func private @quantize_dot_general_with_bias(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, _original_entry_function = "composite_dot_general_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } + +// CHECK-LABEL: func.func private @quantize_dot_general_with_bias +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK: func.func private @quantized_dot_general_with_bias_fn(%[[ARG_2:.*]]: tensor<1x2x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_4:.*]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private 
@composite_dot_general_with_bias_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_4]] : tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + } // ----- -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. +// Tests that fused bias pattern with dynamic shape is not quantized. +// TODO: b/307620428 - Add support for fused bias with dynamic shapes. module attributes {tf_saved_model.semantics} { - func.func private @not_quantized_without_stats(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} + func.func private @quantize_dot_general_with_bias_dynamic(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + // expected-error@+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values, but got}} + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor } -// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is -// not quantized. 
- -// CHECK-LABEL: func.func private @not_quantized_without_stats -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<3x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] - func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> + func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor<2xindex> + %2 = stablehlo.dynamic_broadcast_in_dim %arg2, %1, dims = [1] : (tensor<3xf32>, tensor<2xindex>) -> tensor + %3 = stablehlo.add %0, %2 : tensor + return %3 : tensor } -// Check that the composite_dot_general_fn is untouched. - -// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x3xf32>, %[[ARG_3:.*]]: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] -// CHECK: return %[[DOT_GENERAL]] } // ----- -// Tests that a fusion pattern for dot_general is not yet supported. Further op -// coverage will be provided in the future. -// TODO - b/307620428: Increase op coverage to cover this test case. +// Tests error when there are no corresponding entry function to quantize +// (@composite_dot_general_fn). module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @dot_general_fn_fusion_not_quantized(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> +// expected-error @+2 {{Failed to find a valid entry function}} // expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } +} - func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> +// ----- + +// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. + +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %1 : tensor<1x3xf32> } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. 
+ +// CHECK-LABEL: func.func private @not_quantized_without_stats +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[XLA_CALL_MODULE_0]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Check that the composite_dot_general_fn is untouched. + +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2xf32>, %[[ARG_3:.*]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] +// CHECK: return %[[DOT_GENERAL]] } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc index 4dcdb637e1b430..a864ee556ff5af 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc @@ -24,14 +24,23 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -TEST(UtilsTest, IsStablehloOp) { - MLIRContext ctx; - OpBuilder b(&ctx); - ctx.loadDialect(); +using ::testing::Test; +class StablehloTypeUtilsTest : public Test { + protected: + StablehloTypeUtilsTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(StablehloTypeUtilsTest, ValidStablehloOpSucceeds) { mlir::stablehlo::ConstantOp constant_op = - b.create(b.getUnknownLoc(), - b.getI32IntegerAttr(0)); + builder_.create( + builder_.getUnknownLoc(), builder_.getI32IntegerAttr(0)); EXPECT_TRUE(IsStablehloOp(constant_op)); constant_op->erase(); } From 007c07220187e8def0959d67c44618958a5a6147 Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Sun, 26 Nov 2023 23:14:10 -0800 Subject: [PATCH 078/381] Add broadcast_in_dim, gather and slice to GetStableHloQuantScaleSpec PiperOrigin-RevId: 585541595 --- .../stablehlo/ops/stablehlo_op_quant_spec.cc | 11 ++- .../stablehlo/tests/quantize_same_scale.mlir | 81 +++++++++++++++++++ 2 files changed, 86 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc index 39a2a96313079f..1a20e3d6d995f8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -69,12 +69,11 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { auto scale_spec = std::make_unique(); - // TODO - b/307619822: Add below ops to the spec with unit tests. 
- // mlir::stablehlo::GatherOp, mlir::stablehlo::SliceOp, - // mlir::stablehlo::BroadcastInDimOp - if (llvm::isa(op)) { + if (llvm::isa(op)) { scale_spec->has_same_scale_requirement = true; } return scale_spec; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir index ff294bb4eb7031..1c8e53134e2beb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir @@ -178,3 +178,84 @@ func.func @composite_and_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>) - %9 = "quantfork.dcast"(%8) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> return %9 : tensor<1x3xf32> } + +// ----- + +// CHECK-LABEL: composite_and_broadcast_in_dim +func.func @composite_and_broadcast_in_dim() -> tensor<2x3x2xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[BROADCAST:.*]] = "stablehlo.broadcast_in_dim"(%[[CALL]]) + // CHECK-SAME: (tensor<1x3x!quant.uniform>) -> tensor<2x3x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[BROADCAST]]) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = "stablehlo.broadcast_in_dim"(%2) { + broadcast_dimensions = dense<[2, 1]>: tensor<2xi64> + } : (tensor<1x3xf32>) -> tensor<2x3x2xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<2x3x2xf32>) -> tensor<2x3x2x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + return %5 : tensor<2x3x2xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_gather +// CHECK: %[[ARG0:.*]]: tensor<2x3x2xi64> +func.func @composite_and_gather(%arg0: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<3x4x2x!quant.uniform> + // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[CALL]], %[[ARG0]]) + // CHECK-SAME: (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi64>) -> tensor<2x3x2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[GATHER]]) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<3x4x2>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], 
version = 5 : i64} : () -> tensor<3x4x2xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2xf32> + %3 = "stablehlo.gather"(%2, %arg0) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + return %5 : tensor<2x3x2x2xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_slice +func.func @composite_and_slice() -> tensor<2x2xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<3x4x!quant.uniform> + // CHECK: %[[SLICE:.*]] = "stablehlo.slice"(%[[CALL]]) + // CHECK-SAME: (tensor<3x4x!quant.uniform>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[SLICE]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<3x4>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<3x4xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> + %3 = "stablehlo.slice"(%2) { + start_indices = dense<[1, 2]> : tensor<2xi64>, + limit_indices = dense<[3, 4]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : (tensor<3x4xf32>) -> tensor<2x2xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + return %5 : tensor<2x2xf32> +} From 491f1b269f7b074ef5efc90b381a071b07ad6092 Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Mon, 27 Nov 2023 00:47:51 -0800 Subject: [PATCH 079/381] Allow gather and slice for MHLO uq to int lowering PiperOrigin-RevId: 585559331 --- .../bridge/convert_mhlo_quant_to_int.cc | 7 ++-- .../bridge/convert-mhlo-quant-to-int.mlir | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index be20c98d6c0f62..2c518b9b8fbc49 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -1269,9 +1269,10 @@ class ConvertGenericOp : public ConversionPattern { Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // This pattern only handle selected ops. 
- if (!isa(op)) { + if (!isa(op)) { return failure(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 9885d3b8f8d6fa..06d09a2a4341a0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -1912,3 +1912,45 @@ func.func @transpose( ) -> tensor<1x3x!quant.uniform> return %0 : tensor<1x3x!quant.uniform> } + +// ----- + +// CHECK-LABEL: func @gather +func.func @gather( + %arg0: tensor<3x4x2x!quant.uniform>, + %arg1: tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> { + // CHECK: mhlo.gather + // CHECK-SAME: (tensor<3x4x2xi8>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi8> + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : ( + tensor<3x4x2x!quant.uniform>, + tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> + return %0 : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @slice +func.func @slice( + %arg0: tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: mhlo.slice + // CHECK-SAME: (tensor<3x4xi8>) -> tensor<2x2xi8> + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 2]> : tensor<2xi64>, + limit_indices = dense<[3, 4]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : ( + tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} From e0b7d2d0c0c29e6d68c66b0b8eab99a0ed063306 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 01:02:18 -0800 Subject: [PATCH 080/381] Update GraphDef version to 1693. PiperOrigin-RevId: 585563003 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d55e052e604047..8a73ca7c587313 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1692 // Updated: 2023/11/26 +#define TF_GRAPH_DEF_VERSION 1693 // Updated: 2023/11/27 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From bac823c67bf1982ea3bb8450aa173a0a94940f71 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 01:02:19 -0800 Subject: [PATCH 081/381] compat: Update forward compatibility horizon to 2023-11-27 PiperOrigin-RevId: 585563008 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 9c937609ef9c3b..2941ade58de207 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 26) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 27) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 499b55e5ca9653c4d904d289e4544908f5f6e138 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 27 Nov 2023 01:25:14 -0800 Subject: [PATCH 082/381] Make cudnn_fused_conv_rewriter_test work in OSS. PiperOrigin-RevId: 585568014 --- third_party/xla/xla/service/gpu/BUILD | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index d39a1181e080db..e455c4a3a931e0 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3755,19 +3755,22 @@ cc_library( ]), ) -xla_cc_test( +xla_test( name = "cudnn_fused_conv_rewriter_test", srcs = ["cudnn_fused_conv_rewriter_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - shard_count = 10, - tags = [ + backend_tags = { + "gpu": [ + "requires-gpu-nvidia", + "requires-gpu-sm80-only", + "noasan", + "nomsan", + ], + }, + backends = [ "gpu", - "no_oss", - "noasan", - "nomsan", - # This test runs some fusions that are only supported on Ampere+. - "requires-gpu-sm80", ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + shard_count = 10, deps = [ ":backend_configs_cc", ":cublas_cudnn", From a6ef29d96abae7df50eb7ff92e970a9d46fb17a5 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 27 Nov 2023 02:00:12 -0800 Subject: [PATCH 083/381] NFC: Reduce number of function arguments in reduction.cc. There's still room for improvement here. Ideally I think we probably want something at the level of individual reductions. Maybe we can move the ReductionCodegenState class to reduction.cc and move all functions that currently take it as an argument there. PiperOrigin-RevId: 585575236 --- .../xla/xla/service/gpu/fusions/reduction.cc | 608 ++++++++++-------- .../xla/service/gpu/kernel_mapping_scheme.h | 23 +- 2 files changed, 344 insertions(+), 287 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index 4be05dec779e25..f053aa46667cc7 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -43,7 +43,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/fusion_emitter.h" @@ -71,7 +70,6 @@ limitations under the License. #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -89,16 +87,6 @@ using ReductionOutputMap = using ExtraOutputGensMap = ConstHloInstructionMap; -// For a row reduction, returns the number of rows we can process in parallel -// per warp. 
-int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { - return 1; - } - return WarpSize() / reduced_dimension_size; -} - int GetNumOutputs(const Shape& shape) { if (shape.IsTuple()) { return shape.tuple_shapes_size(); @@ -106,6 +94,123 @@ int GetNumOutputs(const Shape& shape) { return 1; } +llvm::Type* GetIndexType(const HloFusionInstruction& fusion, + const TilingScheme& tiling_scheme, + llvm::IRBuilder<>* builder) { + return GetIndexTypeForKernel(&fusion, + tiling_scheme.GetNumThreadsPerBlockPhysical() * + tiling_scheme.GetNumberOfBlocksPhysical(), + builder); +} + +class ReductionEmitter { + public: + ReductionEmitter(HloFusionAnalysis& analysis, + IrEmitterContext& ir_emitter_context, + ElementalIrEmitter& elemental_emitter, + mlir::lmhlo::FusionOp fusion_op, + const HloFusionInstruction& fusion, + KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) + : analysis_(analysis), + ir_emitter_context_(ir_emitter_context), + elemental_emitter_(elemental_emitter), + fusion_op_(fusion_op), + fusion_(fusion), + kernel_cache_(kernel_cache), + builder_(builder), + index_ty_(GetIndexType( + fusion, analysis.GetReductionCodegenInfo()->GetTilingScheme(), + builder)) {} + + StatusOr Emit(); + + private: + StatusOr> BuildKernelThunkForFusion( + const LaunchDimensions& launch_dimensions, + absl::string_view discriminator, + std::function, + std::vector)> + kernel_builder_fn); + + StatusOr> BuildFusedInitializerThunk( + const HloInstruction* fusion_root, mlir::Value dest, + BufferAllocation::Slice dest_slice, int output_index); + + Status EmitIRForReduction( + absl::Span instr_index_group, + FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, + const Shape& input_shape); + + void EmitReductionOutputForRowReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx); + + void EmitReductionOutputForColumnReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx); + + void EmitFullWarpShuffleDownLoopForReduce( + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp); + + void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, + const HloInstruction* root, int partial_result_idx, + absl::Span values); + + ReductionCodegenState GenerateReductionCodegenState( + absl::Span reduce_instr_index_group, + FusedIrEmitter& fused_emitter); + + llvm_ir::IrArray::Index GetOutputIndexForReduction( + int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, + const HloReduceInstruction* reduction, const HloInstruction* root, + int output_idx); + + void GenerateElementForReducer( + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const ReductionCodegenState& codegen_state, + const llvm_ir::IrArray::Index& index_without_linear, + const llvm_ir::IrArray::Index& input_index, int num_partial_results, + const ReductionOutputMap& result_ir_arrays); + + void MaybeEmitFenceForAMDGPU(); + void EmitSyncThreads(); + // For a row 
reduction, returns the number of rows we can process in parallel + // per warp. + int RowReductionGetRowsPerWarp() const { + int reduced_dimension_size = ReducedDimensionSize(); + if (WarpSize() % reduced_dimension_size != 0 || + reduced_dimension_size >= WarpSize()) { + return 1; + } + return WarpSize() / reduced_dimension_size; + } + + int ReducedDimensionSize() const { + return analysis_.GetReductionCodegenInfo() + ->GetTilingScheme() + .GetDimsInElems()[2]; + } + + HloFusionAnalysis& analysis_; + IrEmitterContext& ir_emitter_context_; + ElementalIrEmitter& elemental_emitter_; + mlir::lmhlo::FusionOp fusion_op_; + const HloFusionInstruction& fusion_; + KernelReuseCache& kernel_cache_; + llvm::IRBuilder<>* builder_; + llvm::Type* index_ty_; +}; + // Allocates a shared tile of given dimensions, applying scaling specified in // tilng_scheme as a major-most dimension to avoid collisions. llvm::GlobalVariable* AllocateShared( @@ -125,16 +230,16 @@ llvm::GlobalVariable* AllocateShared( // Creates accumulator alloca's, populates them with initial values, generates // __shared__ caches and returns the populated object. -ReductionCodegenState GenerateReductionCodegenState( - llvm::IRBuilder<>* builder, const HloFusionInstruction* fusion, - const ReductionCodegenInfo& reduction_info, +ReductionCodegenState ReductionEmitter::GenerateReductionCodegenState( absl::Span reduce_instr_index_group, FusedIrEmitter& fused_emitter) { - ReductionCodegenState reduction_codegen_state(reduction_info); - VLOG(10) << "Emit prologue for reduction: " << fusion->ToString(); + const ReductionCodegenInfo& reduction_info = + *analysis_.GetReductionCodegenInfo(); + ReductionCodegenState reduction_codegen_state; + VLOG(10) << "Emit prologue for reduction: " << fusion_.ToString(); for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { - int num_partial_results = reduction_codegen_state.GetNumPartialResults(); + int num_partial_results = reduction_info.GetNumPartialResults(); for (int op_result_idx = 0; op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { Shape result_shape = reduce_hlo->shape().IsTuple() @@ -142,51 +247,50 @@ ReductionCodegenState GenerateReductionCodegenState( : reduce_hlo->shape(); llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( - result_shape.element_type(), builder->GetInsertBlock()->getModule()); + result_shape.element_type(), builder_->GetInsertBlock()->getModule()); llvm::AllocaInst* reduction_input_address = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "reduction_input_address", builder); + element_type, "reduction_input_address", builder_); llvm::AllocaInst* partial_result_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( element_type, - /*element_count=*/builder->getInt32(num_partial_results), - "partial_reduction_result", builder); + /*element_count=*/builder_->getInt32(num_partial_results), + "partial_reduction_result", builder_); const HloInstruction* init_value = reduce_hlo->init_values()[op_result_idx]; // Initialize the partial result with the initial value of the reduction. 
llvm::Value* init_ir_value = (*fused_emitter.GetGenerator( - *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) + *init_value))(llvm_ir::IrArray::Index(builder_->getInt32Ty())) .value(); for (int i = 0; i < num_partial_results; ++i) { - builder->CreateStore( - init_ir_value, builder->CreateInBoundsGEP( - partial_result_address->getAllocatedType(), - partial_result_address, {builder->getInt32(i)})); + builder_->CreateStore( + init_ir_value, + builder_->CreateInBoundsGEP( + partial_result_address->getAllocatedType(), + partial_result_address, {builder_->getInt32(i)})); } - const TilingScheme& tiling_scheme = - reduction_codegen_state.GetTilingScheme(); - int64_t num_threads_x = - tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { - if (reduction_codegen_state.IsRowReduction()) { + if (reduction_info.IsRowReduction()) { // Multi-row reductions do not use shared memory. - if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > - 1) { + if (RowReductionGetRowsPerWarp() > 1) { return nullptr; } // Allocate __shared__ // cache[num_partial_results][num_warps][scaling_factor]. CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); - return AllocateShared(builder, tiling_scheme, element_type, + return AllocateShared(builder_, tiling_scheme, element_type, {num_partial_results, num_warps}, "shared_cache"); } else { + int64_t num_threads_x = + tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); // Allocate __shared__ // cache[num_threads][num_threads + 1], where // num_threads == num_threads_x == num_threads_y. The "+1" is used to @@ -196,7 +300,7 @@ ReductionCodegenState GenerateReductionCodegenState( // don't need that much cache: Only one result is live at a time.) CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(TilingScheme::DimY)); - return AllocateShared(builder, tiling_scheme, element_type, + return AllocateShared(builder_, tiling_scheme, element_type, {num_threads_x, num_threads_x + 1}, "shared_cache"); } @@ -214,22 +318,20 @@ ReductionCodegenState GenerateReductionCodegenState( return reduction_codegen_state; } -void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - auto* module = builder->GetInsertBlock()->getModule(); +void ReductionEmitter::MaybeEmitFenceForAMDGPU() { + auto* module = builder_->GetInsertBlock()->getModule(); if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( + ir_emitter_context_.rocm_compute_capability().gcn_arch_name().substr( 0, 6) == "gfx90a") { - builder->CreateFence( + builder_->CreateFence( llvm::AtomicOrdering::SequentiallyConsistent, - builder->getContext().getOrInsertSyncScopeID("workgroup")); + builder_->getContext().getOrInsertSyncScopeID("workgroup")); } } -void EmitSyncThreads(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); +void ReductionEmitter::EmitSyncThreads() { + MaybeEmitFenceForAMDGPU(); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder_); } // Builds a thunk that calls a new or reused kernel for a fusion operation. 
@@ -253,31 +355,30 @@ void EmitSyncThreads(llvm::IRBuilder<>* builder, // ...)); // AddThunkToThunkSequence(std::move(thunk)) // ``` -StatusOr> BuildKernelThunkForFusion( - IrEmitterContext& ir_emitter_context, KernelReuseCache& kernel_cache, - const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, - const HloComputation* fused_computation, +StatusOr> ReductionEmitter::BuildKernelThunkForFusion( const LaunchDimensions& launch_dimensions, absl::string_view discriminator, std::function, std::vector)> - kernel_builder_fn, - llvm::IRBuilder<>* builder) { - std::string suggested_kernel_name = std::string(fusion->name()); + kernel_builder_fn) { + const HloComputation* fused_computation = + fusion_.fused_instructions_computation(); + std::string suggested_kernel_name = std::string(fusion_.name()); - TF_ASSIGN_OR_RETURN(auto kernel_arguments, - ir_emitter_context.emit_ir_from_hlo() - ? KernelArguments::Create( - ir_emitter_context.buffer_assignment(), fusion) - : KernelArguments::Create( - ir_emitter_context.allocations(), fusion_op)); + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + ir_emitter_context_.emit_ir_from_hlo() + ? KernelArguments::Create(ir_emitter_context_.buffer_assignment(), + &fusion_) + : KernelArguments::Create(ir_emitter_context_.allocations(), + fusion_op_)); auto kernel_builder_status = OkStatus(); - auto [entry, cached] = kernel_cache.Get( + auto [entry, cached] = kernel_cache_.Get( fused_computation, kernel_arguments.args(), discriminator, [&]() -> KernelReuseCache::Entry { auto [kernel, input_arrays, output_arrays] = BuildKernelPrototype( - ir_emitter_context, suggested_kernel_name, kernel_arguments.args(), - fusion->operand_count(), launch_dimensions, builder); + ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(), + fusion_.operand_count(), launch_dimensions, builder_); kernel_builder_status = kernel_builder_fn(input_arrays, output_arrays); return {kernel->getName().str(), launch_dimensions}; }); @@ -287,15 +388,15 @@ StatusOr> BuildKernelThunkForFusion( << entry.kernel_name; } - if (ir_emitter_context.emit_ir_from_hlo()) { + if (ir_emitter_context_.emit_ir_from_hlo()) { return std::make_unique( - fusion, entry.kernel_name, kernel_arguments.args(), launch_dimensions, + &fusion_, entry.kernel_name, kernel_arguments.args(), launch_dimensions, // Shared memory is allocated statically. /*shmem_bytes=*/0); } return std::make_unique( - fusion_op, entry.kernel_name, kernel_arguments.args(), launch_dimensions, + fusion_op_, entry.kernel_name, kernel_arguments.args(), launch_dimensions, // Shared memory is allocated statically. 
/*shmem_bytes=*/0); } @@ -340,13 +441,9 @@ Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder, return OkStatus(); } -StatusOr> BuildFusedInitializerThunk( - IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, - mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation, +StatusOr> ReductionEmitter::BuildFusedInitializerThunk( const HloInstruction* fusion_root, mlir::Value dest, - BufferAllocation::Slice dest_slice, ElementalIrEmitter& elemental_emitter, - KernelReuseCache& kernel_cache, int output_index, - llvm::IRBuilder<>* builder) { + BufferAllocation::Slice dest_slice, int output_index) { const HloReduceInstruction* reduce = DynCast(fusion_root); TF_RET_CHECK(reduce); @@ -354,8 +451,8 @@ StatusOr> BuildFusedInitializerThunk( const HloInstruction* init_value = reduce->init_values()[0]; TF_ASSIGN_OR_RETURN( std::optional> constant_init_thunk, - BuildConstantInitializerThunk(ir_emitter_context, fusion_op, fusion_root, - init_value, dest, dest_slice)); + BuildConstantInitializerThunk(ir_emitter_context_, fusion_op_, + fusion_root, init_value, dest, dest_slice)); if (constant_init_thunk) { return *std::move(constant_init_thunk); } @@ -364,15 +461,18 @@ StatusOr> BuildFusedInitializerThunk( TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, CalculateLaunchDimensions( - dest_shape, ir_emitter_context.gpu_device_info())); + dest_shape, ir_emitter_context_.gpu_device_info())); + const HloComputation* fused_computation = + fusion_.fused_instructions_computation(); auto builder_fn = [&](std::vector inputs, std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter); + FusedIrEmitter fused_emitter(elemental_emitter_); for (int i = 0; i < fused_computation->num_parameters(); i++) { fused_emitter.BindGenerator( *fused_computation->parameter_instruction(i), - [builder, input = inputs[i]](llvm_ir::IrArray::Index index) { + [builder = builder_, + input = inputs[i]](llvm_ir::IrArray::Index index) { return input.EmitReadArrayElement(index, builder); }); } @@ -386,16 +486,15 @@ StatusOr> BuildFusedInitializerThunk( TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(*instr->operand(1))); TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]}, - launch_dimensions, builder) - .EmitLoop(fusion.name())); + launch_dimensions, builder_) + .EmitLoop(fusion_.name())); return OkStatus(); }; - return BuildKernelThunkForFusion( - ir_emitter_context, kernel_cache, &fusion, fusion_op, fused_computation, - launch_dimensions, - /*discriminator=*/ - absl::StrCat("init_", output_index), builder_fn, builder); + return BuildKernelThunkForFusion(launch_dimensions, + /*discriminator=*/ + absl::StrCat("init_", output_index), + builder_fn); } // Gets the output offset as calculated from thread_id.x (to be applied to the @@ -418,8 +517,7 @@ static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, // // Multiple partial_result_address inputs happen when doing variadic // reduction: each one should get the output value. 
-void EmitFullWarpShuffleDownLoopForReduce( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, +void ReductionEmitter::EmitFullWarpShuffleDownLoopForReduce( const HloComputation* reducer, absl::Span partial_result_addresses, int threads_per_block, int num_results_per_warp) { @@ -441,63 +539,62 @@ void EmitFullWarpShuffleDownLoopForReduce( partial_result_addresses) { int bit_width = llvm_ir::GetSizeInBits(element_type); llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "result_from_other_lane", builder); + element_type, "result_from_other_lane", builder_); reduction_params.push_back(result_from_other_lane); // Bitcast cannot be applied to aggregate types (even packed ones), so // we bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffled_value_type = element_type->isStructTy() - ? builder->getIntNTy(bit_width) + ? builder_->getIntNTy(bit_width) : element_type; auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { return ptr; }; - llvm::Value* partial_result = builder->CreateLoad( + llvm::Value* partial_result = builder_->CreateLoad( shuffled_value_type, convert_pointer_for_shuffle(partial_result_address), "partial_reduction_result"); - builder->CreateStore( - EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), - builder), + builder_->CreateStore( + EmitFullWarpShuffleDown(partial_result, builder_->getInt32(distance), + builder_), convert_pointer_for_shuffle(result_from_other_lane)); } StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, + CallNestedComputationWithScalarAddrs(builder_, ir_emitter_context_, *reducer, reduction_params); TF_CHECK_OK(returned_scalars.status()); for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(/*Val=*/returned_scalars->at(i), - /*Ptr=*/partial_result_addresses[i].first); + builder_->CreateStore(/*Val=*/returned_scalars->at(i), + /*Ptr=*/partial_result_addresses[i].first); } } } -llvm_ir::IrArray::Index GetOutputIndexForReduction( - llvm::IRBuilder<>* builder, int partial_result_idx, llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, +llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( + int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, const HloReduceInstruction* reduction, const HloInstruction* root, int output_idx) { auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); + return llvm::ConstantInt::get(index_ty_, c); }; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; llvm_ir::IrArray::Index start_offset = [&] { llvm::Value* x_loc = thread_id_info.thread_id_x; llvm::Value* y_loc = thread_id_info.thread_id_y; - if (!reduction_codegen_state.IsRowReduction()) { + if (!reduction_info.IsRowReduction()) { std::swap(x_loc, y_loc); } llvm::Value* start_offset_x = - GetStartOffsetX(tiling_scheme, x_loc, index_ty, builder); + GetStartOffsetX(tiling_scheme, x_loc, index_ty_, builder_); return tiling_kernel_info.tile_origin - .AddOffsetToDim(y_loc, TilingScheme::DimY, builder) - .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder); + .AddOffsetToDim(y_loc, TilingScheme::DimY, builder_) + 
.AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder_); }(); const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); @@ -509,8 +606,8 @@ llvm_ir::IrArray::Index GetOutputIndexForReduction( // the input shape with the dimensions being reduced moved. llvm::Value* untransposed_output_linear_address = [&] { const llvm_ir::IrArray::Index index = start_offset.AddOffsetToDim( - constant(partial_result_idx), TilingScheme::DimX, builder); - if (reduction_codegen_state.IsRowReduction()) { + constant(partial_result_idx), TilingScheme::DimX, builder_); + if (reduction_info.IsRowReduction()) { // For row-reduction, y-coordinate determines which row we write into. return index[TilingScheme::DimY]; } @@ -519,8 +616,8 @@ llvm_ir::IrArray::Index GetOutputIndexForReduction( llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); llvm::Value* x_block_offset = - builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); - return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); + builder_->CreateMul(index[TilingScheme::DimZ], x_dim_size); + return builder_->CreateAdd(x_block_offset, index[TilingScheme::DimX]); }(); // A reduction is allowed to transpose its output. For example, suppose @@ -533,7 +630,7 @@ llvm_ir::IrArray::Index GetOutputIndexForReduction( // the correct output element. llvm_ir::IrArray::Index element_index( /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, builder); + reduction_kept_element_shape, builder_); const Shape& output_shape = !reduction->shape().IsTuple() ? reduction->shape() : reduction->shape().tuple_shapes(output_idx); @@ -546,7 +643,7 @@ llvm_ir::IrArray::Index GetOutputIndexForReduction( ShapeUtil::EqualIgnoringElementType(output_shape, root->shape())) ? 
output_index : output_index.SourceIndexOfBitcast(output_shape, root->shape(), - builder); + builder_); } llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, @@ -558,41 +655,36 @@ llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, name); } -void WriteReductionOutput(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, - const HloInstruction* root, int partial_result_idx, - const absl::Span values, - ElementalIrEmitter& elemental_emitter) { +void ReductionEmitter::WriteReductionOutput( + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx, const absl::Span values) { + const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); const HloComputation* reducer = reduction->to_apply(); for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { auto [output_ptr, type] = typed_ptr; llvm_ir::IrArray::Index output_index = GetOutputIndexForReduction( - builder, partial_result_idx, index_ty, reduction_codegen_state, - tiling_kernel_info, reduction, root, oidx); + partial_result_idx, tiling_kernel_info, reduction, root, oidx); llvm::Value* output_address = output_arrays.at(root)[oidx].EmitArrayElementAddress( - output_index, builder, "output_element_address"); - if (reduction_codegen_state.IsRaceFree()) { - FusedIrEmitter fused_emitter(elemental_emitter); - llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output"); + output_index, builder_, "output_element_address"); + if (reduction_info.IsRaceFree()) { + FusedIrEmitter fused_emitter(elemental_emitter_); + llvm::Value* loaded = builder_->CreateLoad(type, output_ptr, "output"); fused_emitter.BindGenerator( *reduction, [&](const llvm_ir::IrArray::Index& index) { return loaded; }); llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root); llvm::Value* generated = *gen(output_index); - builder->CreateStore(generated, output_address); + builder_->CreateStore(generated, output_address); } else { CHECK_EQ(values.size(), 1); CHECK_EQ(reduction, root) << "output fusion is not allowed for racing reductions"; TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - builder, ir_emitter_context, *reducer, output_address, output_ptr, + builder_, ir_emitter_context_, *reducer, output_address, output_ptr, type)); } } @@ -600,59 +692,56 @@ void WriteReductionOutput(llvm::IRBuilder<>* builder, // `current_output`: the value the tile has calculated. // `output_address`: address where the output value has to be written. 
-void EmitReductionOutputForRowReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, +void ReductionEmitter::EmitReductionOutputForRowReduction( const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, const ReductionOutputMap& output_arrays, const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx, ElementalIrEmitter& elemental_emitter) { + int partial_result_idx) { const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); + return llvm::ConstantInt::get(index_ty_, c); }; auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); + return builder_->CreateICmpEQ(value, constant(0)); }; int num_outputs = reducer->num_parameters() / 2; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); absl::InlinedVector current_outputs; for (int output_idx = 0; output_idx < num_outputs; output_idx++) { const ReductionCodegenState::ReductionCalculationState& state = reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); current_outputs.push_back( - {builder->CreateInBoundsGEP( + {builder_->CreateInBoundsGEP( state.partial_result_address->getAllocatedType(), state.partial_result_address, {constant(partial_result_idx)}, "current_output"), state.partial_result_address->getAllocatedType()}); } - int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; - int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); + const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + int num_rows_per_warp = RowReductionGetRowsPerWarp(); EmitFullWarpShuffleDownLoopForReduce( - builder, ir_emitter_context, reducer, absl::MakeSpan(current_outputs), + reducer, absl::MakeSpan(current_outputs), tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); - KernelSupportLibrary ksl(builder); + KernelSupportLibrary ksl(builder_); llvm::Value* warp_id = - builder->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); + builder_->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); auto emit_write_output = [&](llvm::Value* write_condition, const absl::Span values) { ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(builder, ir_emitter_context, index_ty, - reduction_codegen_state, tiling_kernel_info, - output_arrays, reduction, root, partial_result_idx, - values, elemental_emitter); + WriteReductionOutput(tiling_kernel_info, output_arrays, reduction, root, + partial_result_idx, values); }); }; if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( - thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); + llvm::Value* is_writing_thread = is_zero(builder_->CreateAnd( + thread_id_info.thread_id_x, constant(ReducedDimensionSize() - 1))); emit_write_output(is_writing_thread, current_outputs); return; } @@ -662,23 +751,24 @@ void EmitReductionOutputForRowReduction( const ReductionCodegenState::ReductionCalculationState& state = reduction_codegen_state.GetCalculationStateFor(reduction, oidx); llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, {constant(partial_result_idx), 
warp_id}); - builder->CreateStore(builder->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - shmem_output_addr); + builder_, state.shared_cache, + {constant(partial_result_idx), warp_id}); + builder_->CreateStore(builder_->CreateLoad(current_outputs[oidx].second, + current_outputs[oidx].first), + shmem_output_addr); } }); // TODO(cheshire): Don't we want to sync it once for everything in the // output? Not once per each? - EmitSyncThreads(builder, ir_emitter_context); + EmitSyncThreads(); ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { absl::InlinedVector selected_values; for (int oidx = 0; oidx < num_outputs; oidx++) { const ReductionCodegenState::ReductionCalculationState& state = reduction_codegen_state.GetCalculationStateFor(reduction, oidx); llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, + builder_, state.shared_cache, {constant(partial_result_idx), thread_id_info.lane_id}); llvm::Type* element_type = @@ -686,18 +776,18 @@ void EmitReductionOutputForRowReduction( // Ensure initial value address is in generic, not scratch. llvm::Value* initial_value_addr = - CastSharedToGlobal(builder, + CastSharedToGlobal(builder_, llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "initial_value_addr", builder), + element_type, "initial_value_addr", builder_), element_type, /*name=*/""); - builder->CreateStore(state.initial_value, initial_value_addr); + builder_->CreateStore(state.initial_value, initial_value_addr); - llvm::Value* warp_exists = builder->CreateICmpULT( + llvm::Value* warp_exists = builder_->CreateICmpULT( thread_id_info.thread_id_x, constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX) / WarpSize())); - llvm::Value* selected_value = builder->CreateSelect( + llvm::Value* selected_value = builder_->CreateSelect( warp_exists, block_accum_addr, initial_value_addr); selected_values.push_back({selected_value, element_type}); @@ -710,7 +800,7 @@ void EmitReductionOutputForRowReduction( // also unnecessary and should be removed. if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { EmitFullWarpShuffleDownLoopForReduce( - builder, ir_emitter_context, reducer, absl::MakeSpan(selected_values), + reducer, absl::MakeSpan(selected_values), tiling_scheme.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); } @@ -719,31 +809,31 @@ void EmitReductionOutputForRowReduction( } // Same arguments as EmitReductionOutputForRowReduction. 
-void EmitReductionOutputForColumnReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, +void ReductionEmitter::EmitReductionOutputForColumnReduction( const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, const ReductionOutputMap& output_arrays, const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx, ElementalIrEmitter& elemental_emitter) { - KernelSupportLibrary ksl(builder); + int partial_result_idx) { + KernelSupportLibrary ksl(builder_); const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); + return llvm::ConstantInt::get(index_ty_, c); }; auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); + return builder_->CreateICmpEQ(value, constant(0)); }; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); int num_outputs = reducer->num_parameters() / 2; // Wait for reads from shmem in the last iteration to complete. (If this is // slow, we could "double-buffer" by having two shmem buffers and switching // between them.) if (partial_result_idx > 0) { - EmitSyncThreads(builder, ir_emitter_context); + EmitSyncThreads(); } // Store the transpose in shared memory. @@ -753,20 +843,20 @@ void EmitReductionOutputForColumnReduction( llvm::GlobalVariable* shared_cache = state.shared_cache; llvm::AddrSpaceCastInst* shmem_output_addr = llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, shared_cache, + builder_, shared_cache, {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, "shmem_output_address")); - llvm::Value* current_output = builder->CreateInBoundsGEP( + llvm::Value* current_output = builder_->CreateInBoundsGEP( state.partial_result_address->getAllocatedType(), state.partial_result_address, {constant(partial_result_idx)}, "current_output"); - llvm::Value* current_output_value = builder->CreateLoad( + llvm::Value* current_output_value = builder_->CreateLoad( state.partial_result_address->getAllocatedType(), current_output); - builder->CreateStore(current_output_value, shmem_output_addr); + builder_->CreateStore(current_output_value, shmem_output_addr); } - EmitSyncThreads(builder, ir_emitter_context); + EmitSyncThreads(); // Get transposed element from shared memory. 
absl::InlinedVector shmem_transposed_addrs; @@ -775,7 +865,7 @@ void EmitReductionOutputForColumnReduction( reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); llvm::AddrSpaceCastInst* shmem_transposed_addr = llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, + builder_, state.shared_cache, {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, "shmem_transposed_addr")); shmem_transposed_addrs.push_back( @@ -784,34 +874,32 @@ void EmitReductionOutputForColumnReduction( ->getResultElementType()}); } - EmitFullWarpShuffleDownLoopForReduce(builder, ir_emitter_context, reducer, + EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(shmem_transposed_addrs), tiling_scheme.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); // Some warps in the block are completely outside of the bound of the // tensor, so they should not write any output at all. - llvm::Value* has_output = builder->CreateAnd( - builder->CreateICmpULT( - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty, - builder), + llvm::Value* has_output = builder_->CreateAnd( + builder_->CreateICmpULT( + GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty_, + builder_), tiling_kernel_info.output_tile_bounds[1]), - builder->CreateICmpULT(thread_id_info.thread_id_x, - tiling_kernel_info.output_tile_bounds[0])); + builder_->CreateICmpULT(thread_id_info.thread_id_x, + tiling_kernel_info.output_tile_bounds[0])); ksl.If("reduction_write_output", - builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput( - builder, ir_emitter_context, index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, root, - partial_result_idx, shmem_transposed_addrs, elemental_emitter); + builder_->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + WriteReductionOutput(tiling_kernel_info, output_arrays, reduction, + root, partial_result_idx, + shmem_transposed_addrs); }); } // Generate a single element of the tile (update the accumulator state) for a // given reducer of index `i`. -void GenerateElementForReducer( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, +void ReductionEmitter::GenerateElementForReducer( const HloReduceInstruction* reduction, llvm::Value* partial_result_index, const ReductionCodegenState& codegen_state, const llvm_ir::IrArray::Index& index_without_linear, @@ -831,8 +919,8 @@ void GenerateElementForReducer( state.partial_result_address; llvm::Value* const input_ir_value = *state.input_gen( num_partial_results > 1 ? index_without_linear : input_index); - builder->CreateStore(input_ir_value, input_address); - llvm::Value* partial_result_address = builder->CreateInBoundsGEP( + builder_->CreateStore(input_ir_value, input_address); + llvm::Value* partial_result_address = builder_->CreateInBoundsGEP( partial_reduction_result_address->getAllocatedType(), partial_reduction_result_address, {partial_result_index}); reduction_accumulators.push_back(partial_result_address); @@ -855,23 +943,21 @@ void GenerateElementForReducer( // those pointers, and we have returned values on the stack (as well // as pointers to them). 
StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, + CallNestedComputationWithScalarAddrs(builder_, ir_emitter_context_, *reducer, reduction_params); TF_CHECK_OK(returned_scalars.status()); for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); + builder_->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); } } // Emits code for reductions in the output_instructions. -Status EmitIRForReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const HloFusionInstruction* fusion, +Status ReductionEmitter::EmitIRForReduction( absl::Span instr_index_group, FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, - const ReductionCodegenInfo& reduction_info, const Shape& input_shape, - ElementalIrEmitter& elemental_emitter) { + const Shape& input_shape) { + const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); std::vector roots; std::vector heroes; ExtraOutputGensMap extra_output_gens; @@ -890,42 +976,36 @@ Status EmitIRForReduction( CHECK(!heroes.empty()) << " expect at least one reduce instructions."; const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); - llvm::Type* index_ty = - GetIndexTypeForKernel(fusion, - tiling_scheme.GetNumThreadsPerBlockPhysical() * - tiling_scheme.GetNumberOfBlocksPhysical(), - builder); - ReductionCodegenState codegen_state = GenerateReductionCodegenState( - builder, fusion, reduction_info, heroes, fused_emitter); + ReductionCodegenState codegen_state = + GenerateReductionCodegenState(heroes, fused_emitter); EmitTileElementFunction emit_reduction_element = [&](const TilingThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc) { llvm_ir::IrArray::Index input_index = GetUnnormalizedIndex( - index, input_shape, builder, - codegen_state.GetTilingScheme().GetDimsInElems()); + index, input_shape, builder_, + reduction_info.GetTilingScheme().GetDimsInElems()); llvm::Value* partial_result_index = - codegen_state.IsRowReduction() - ? builder->getInt32(0) - : builder->CreateSub( + reduction_info.IsRowReduction() + ? builder_->getInt32(0) + : builder_->CreateSub( x_loc, GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, - index_ty, builder)); + index_ty_, builder_)); // Clear the linear index field of the llvm_ir::IrArray::Index to enable // the use of GetElementPointer with array types. This enables the // vectorization of the computation for different partial results. Use // this index if 'num_partial_results > 1'. - int num_partial_results = codegen_state.GetNumPartialResults(); + int num_partial_results = reduction_info.GetNumPartialResults(); llvm_ir::IrArray::Index index_without_linear{ input_index.multidim(), input_shape, input_index.GetType()}; // Emit code to generate the input and perform the reduction computation // for each reduction instruction. for (const HloReduceInstruction* reduce : heroes) { - GenerateElementForReducer(builder, ir_emitter_context, reduce, - partial_result_index, codegen_state, + GenerateElementForReducer(reduce, partial_result_index, codegen_state, index_without_linear, input_index, num_partial_results, result_ir_arrays); } @@ -933,36 +1013,34 @@ Status EmitIRForReduction( // Emit code to generate the output for the non-reduction instructions // in the fusion, if any. 
TF_CHECK_OK(EmitExtraOutputsForReduce( - builder, input_shape, result_ir_arrays, input_index, reduction_info, - extra_output_gens)); + builder_, input_shape, result_ir_arrays, input_index, + reduction_info, extra_output_gens)); }; TF_ASSIGN_OR_RETURN( TilingKernelInfo tiling_kernel_info, - EmitTilingKernel(builder, tiling_scheme, index_ty, + EmitTilingKernel(builder_, tiling_scheme, index_ty_, [&](const TilingThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index, std::array tile_dimensions) { - EmitTile(builder, codegen_state.GetTilingScheme(), + EmitTile(builder_, reduction_info.GetTilingScheme(), index, thread_id_info, tile_dimensions, emit_reduction_element); })); - KernelSupportLibrary ksl(builder); + KernelSupportLibrary ksl(builder_); for (auto [reduce, root] : llvm::zip(heroes, roots)) { for (int partial_result_idx = 0; partial_result_idx < reduction_info.GetNumPartialResults(); ++partial_result_idx) { - if (codegen_state.IsRowReduction()) { - EmitReductionOutputForRowReduction( - builder, ir_emitter_context, tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, root, partial_result_idx, - elemental_emitter); + if (reduction_info.IsRowReduction()) { + EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state, + result_ir_arrays, reduce, root, + partial_result_idx); } else { - EmitReductionOutputForColumnReduction( - builder, ir_emitter_context, tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, root, partial_result_idx, - elemental_emitter); + EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state, + result_ir_arrays, reduce, root, + partial_result_idx); } } } @@ -970,20 +1048,15 @@ Status EmitIRForReduction( return OkStatus(); } -} // namespace - -StatusOr ReductionFusion::Emit( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { +StatusOr ReductionEmitter::Emit() { auto* reduction_codegen_info = analysis_.GetReductionCodegenInfo(); TF_ASSIGN_OR_RETURN(auto launch_dimensions, analysis_.GetLaunchDimensions()); FusionEmissionResult result; - VLOG(3) << "Launch dimensions of " << fusion.name() << ": " + VLOG(3) << "Launch dimensions of " << fusion_.name() << ": " << launch_dimensions.ToString(); const HloComputation* fused_computation = - fusion.fused_instructions_computation(); + fusion_.fused_instructions_computation(); if (!reduction_codegen_info->IsRaceFree()) { // We need to get the dest slice by traversing the slice assigned to // fusion, because instructions inside fusion don't have buffer assignment. @@ -1001,17 +1074,17 @@ StatusOr ReductionFusion::Emit( // Therefore we can get the ordered slices by calling ForEachSubshape on the // result shape. 
std::vector slices; - if (ir_emitter_context.emit_ir_from_hlo()) { + if (ir_emitter_context_.emit_ir_from_hlo()) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion.shape(), [&](const Shape& subshape, ShapeIndex index) { - if (!ShapeUtil::IsLeafIndex(fusion.shape(), index)) { + fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) { + if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) { return OkStatus(); } TF_ASSIGN_OR_RETURN( BufferAllocation::Slice slice, - ir_emitter_context.buffer_assignment().GetUniqueSlice(&fusion, - index)); + ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_, + index)); slices.push_back(slice); return OkStatus(); })); @@ -1022,42 +1095,40 @@ StatusOr ReductionFusion::Emit( for (int i = 0; i < fusion_roots.size(); ++i) { const HloInstruction* fusion_root = fusion_roots[i]; - mlir::Value dest = ir_emitter_context.emit_ir_from_hlo() + mlir::Value dest = ir_emitter_context_.emit_ir_from_hlo() ? nullptr - : fusion_op.getOutputBuffers()[i]; + : fusion_op_.getOutputBuffers()[i]; BufferAllocation::Slice dest_slice; - if (ir_emitter_context.emit_ir_from_hlo()) { + if (ir_emitter_context_.emit_ir_from_hlo()) { dest_slice = slices[i]; } else { TF_ASSIGN_OR_RETURN( dest_slice, - GetAllocationSlice(dest, ir_emitter_context.allocations())); + GetAllocationSlice(dest, ir_emitter_context_.allocations())); } if (IsReductionFromOrToContiguousDimensions(*fusion_root)) { TF_ASSIGN_OR_RETURN( result.thunks.emplace_back(), - BuildFusedInitializerThunk(ir_emitter_context, fusion, fusion_op, - fused_computation, fusion_root, dest, - dest_slice, elemental_emitter, - kernel_cache, i, builder)); + BuildFusedInitializerThunk(fusion_root, dest, dest_slice, i)); } } } auto builder_fn = [&, this](std::vector inputs, std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter); + FusedIrEmitter fused_emitter(elemental_emitter_); for (int i = 0; i < fused_computation->num_parameters(); i++) { HloInstruction* fused_operand = fused_computation->parameter_instruction(i); - fused_emitter.BindGenerator(*fused_operand, - [builder, input = inputs[i], fused_operand]( - const llvm_ir::IrArray::Index& index) { - return input.EmitReadArrayElement( - index, builder, fused_operand->name()); - }); + fused_emitter.BindGenerator( + *fused_operand, + [builder = builder_, input = inputs[i], + fused_operand](const llvm_ir::IrArray::Index& index) { + return input.EmitReadArrayElement(index, builder, + fused_operand->name()); + }); } // Get outputs. @@ -1071,7 +1142,7 @@ StatusOr ReductionFusion::Emit( ir_arrays_idx += get_num_results; } - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); + KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll); // Use raw block_id_y to select the i-th parallel reduction to run. 
Using // block_id_y instead of block_id_x simplifies the index calculation @@ -1083,17 +1154,15 @@ StatusOr ReductionFusion::Emit( reduction_codegen_info->GetReduceOperandShape(); llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder); + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_); llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), llvm::cast(raw_block_id_y)); for (int i = 0; i < instr_index_groups.size(); ++i) { TF_RETURN_IF_ERROR(ksl.IfWithStatus( absl::StrCat("reduce-group-", i), - builder->CreateICmpEQ(raw_block_id_y, builder->getInt32(i)), [&] { - return EmitIRForReduction(builder, ir_emitter_context, &fusion, - instr_index_groups[i], fused_emitter, - result_ir_arrays, *reduction_codegen_info, - reduce_operand_shape, elemental_emitter); + builder_->CreateICmpEQ(raw_block_id_y, builder_->getInt32(i)), [&] { + return EmitIRForReduction(instr_index_groups[i], fused_emitter, + result_ir_arrays, reduce_operand_shape); })); } @@ -1102,11 +1171,20 @@ StatusOr ReductionFusion::Emit( TF_ASSIGN_OR_RETURN( result.thunks.emplace_back(), - BuildKernelThunkForFusion(ir_emitter_context, kernel_cache, &fusion, - fusion_op, fused_computation, launch_dimensions, - "", builder_fn, builder)); + BuildKernelThunkForFusion(launch_dimensions, "", builder_fn)); return result; } +} // namespace + +StatusOr ReductionFusion::Emit( + IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, + mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, + KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { + return ReductionEmitter(analysis_, ir_emitter_context, elemental_emitter, + fusion_op, fusion, kernel_cache, builder) + .Emit(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h index f7b51c42c6beaf..3120c9a6680003 100644 --- a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h +++ b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h @@ -193,11 +193,10 @@ class ReductionCodegenInfo { } int GetNumPartialResults() const { return num_partial_results_; } + bool IsRowReduction() const { return is_row_reduction_; } bool IsRaceFree() const { return is_race_free_; } private: - friend class ReductionCodegenState; - TilingScheme tiling_scheme_; int num_partial_results_; bool is_row_reduction_; @@ -216,24 +215,6 @@ class ReductionCodegenState { llvm_ir::ElementGenerator input_gen; }; - explicit ReductionCodegenState( - const ReductionCodegenInfo& reduction_codegen_info) - : reduction_codegen_info_(reduction_codegen_info) {} - - const TilingScheme& GetTilingScheme() const { - return reduction_codegen_info_.tiling_scheme_; - } - - int GetNumPartialResults() const { - return reduction_codegen_info_.num_partial_results_; - } - - bool IsRowReduction() const { - return reduction_codegen_info_.is_row_reduction_; - } - - bool IsRaceFree() const { return reduction_codegen_info_.IsRaceFree(); } - const ReductionCalculationState& GetCalculationStateFor( const HloInstruction* instruction, int operand_idx) const { const ReductionOpState& op_state = state_.at(instruction); @@ -250,8 +231,6 @@ class ReductionCodegenState { } private: - ReductionCodegenInfo reduction_codegen_info_; - // One state per reduction operand. 
using ReductionOpState = absl::InlinedVector; From 6b1d42257f01c4ad4960ea3b118ec19e2ad3b5e6 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 27 Nov 2023 03:15:53 -0800 Subject: [PATCH 084/381] Enable tests that require v100 in OSS. PiperOrigin-RevId: 585594284 --- third_party/xla/xla/service/gpu/BUILD | 40 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index e455c4a3a931e0..fb21f0c31f3e2c 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -577,7 +577,7 @@ xla_test( ], tags = [ "large", - "no_oss", + "no_oss", # requires-mem:16g tag doesn't work in open source "nomac", "requires-mem:16g", ], @@ -1564,21 +1564,25 @@ cc_library( ]), ) -xla_cc_test( +xla_test( name = "gemm_algorithm_picker_test", - srcs = ["gemm_algorithm_picker_test.cc"], - tags = [ + srcs = if_gpu_is_configured(["gemm_algorithm_picker_test.cc"]), + backend_tags = { + "gpu": [ + "requires-gpu-nvidia", + "noasan", + "nomsan", + "requires-gpu-sm70-only", + ], + }, + backends = [ "gpu", - "no_oss", - "noasan", - "nomsan", - "requires-gpu-sm70", + "gpu_v100", ], deps = [ ":backend_configs_cc", ":gemm_algorithm_picker", ":gemm_rewriter", - "//xla/service:gpu_plugin", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", @@ -1774,20 +1778,24 @@ cc_library( ]), ) -xla_cc_test( +xla_test( name = "conv_algorithm_picker_test", srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]), - tags = [ + backend_tags = { + "gpu": [ + "requires-gpu-nvidia", + "noasan", + "nomsan", + "requires-gpu-sm70-only", + ], + }, + backends = [ "gpu", - "no_oss", - "noasan", - "nomsan", - "requires-gpu-sm70", + "gpu_v100", ], deps = [ ":conv_algorithm_picker", ":gpu_conv_rewriter", - "//xla/service:gpu_plugin", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:tuple_simplifier", From 116db986b33247628ddc299d4959273b86466023 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 27 Nov 2023 03:16:06 -0800 Subject: [PATCH 085/381] Move ReductionCodegenState to reduction.cc and rename to ReductionGroupEmitter. This class doesn't need to be public. Also, it encapsulates part of the state that's needed to generate groups of reductions, so we might as well move the functions there. 
PiperOrigin-RevId: 585594330 --- .../xla/xla/service/gpu/fusions/reduction.cc | 422 ++++++++++-------- .../xla/service/gpu/kernel_mapping_scheme.h | 33 -- 2 files changed, 229 insertions(+), 226 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index f053aa46667cc7..5ccc9f88d509e2 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -125,6 +125,8 @@ class ReductionEmitter { StatusOr Emit(); private: + friend class ReductionGroupEmitter; + StatusOr> BuildKernelThunkForFusion( const LaunchDimensions& launch_dimensions, absl::string_view discriminator, @@ -141,49 +143,9 @@ class ReductionEmitter { FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, const Shape& input_shape); - void EmitReductionOutputForRowReduction( - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx); - - void EmitReductionOutputForColumnReduction( - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx); - - void EmitFullWarpShuffleDownLoopForReduce( - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp); - - void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, - const HloInstruction* root, int partial_result_idx, - absl::Span values); - - ReductionCodegenState GenerateReductionCodegenState( - absl::Span reduce_instr_index_group, - FusedIrEmitter& fused_emitter); - - llvm_ir::IrArray::Index GetOutputIndexForReduction( - int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int output_idx); - - void GenerateElementForReducer( - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const ReductionCodegenState& codegen_state, - const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, int num_partial_results, - const ReductionOutputMap& result_ir_arrays); - void MaybeEmitFenceForAMDGPU(); void EmitSyncThreads(); + // For a row reduction, returns the number of rows we can process in parallel // per warp. 
int RowReductionGetRowsPerWarp() const { @@ -211,6 +173,84 @@ class ReductionEmitter { llvm::Type* index_ty_; }; +class ReductionGroupEmitter { + public: + struct ReductionCalculationState { + llvm::GlobalVariable* shared_cache; + llvm::Value* initial_value; + llvm::AllocaInst* partial_result_address; + llvm::AllocaInst* input_address; + llvm_ir::ElementGenerator input_gen; + }; + + ReductionGroupEmitter( + ReductionEmitter& reduction_emitter, + absl::Span reduce_instr_index_group, + const ReductionOutputMap& result_ir_arrays, + FusedIrEmitter& fused_emitter); + + const ReductionCalculationState& GetCalculationStateFor( + const HloInstruction* instruction, int operand_idx) const { + const ReductionOpState& op_state = state_.at(instruction); + CHECK_LT(operand_idx, op_state.size()); + return op_state[operand_idx]; + } + + void SetCalculationStateFor( + const ReductionCalculationState& calculation_state, + const HloInstruction* instruction, int operand_idx) { + ReductionOpState& op_state = state_[instruction]; + CHECK_EQ(operand_idx, op_state.size()); + op_state.push_back(calculation_state); + } + + void EmitReductionOutputForRowReduction( + const TilingKernelInfo& tiling_kernel_info, + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx) const; + + void EmitReductionOutputForColumnReduction( + const TilingKernelInfo& tiling_kernel_info, + const HloReduceInstruction* reduction, const HloInstruction* root, + int partial_result_idx) const; + + void EmitFullWarpShuffleDownLoopForReduce( + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp) const; + + void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info, + const HloReduceInstruction* reduction, + const HloInstruction* root, int partial_result_idx, + absl::Span values) const; + + llvm_ir::IrArray::Index GetOutputIndexForReduction( + int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, + const HloReduceInstruction* reduction, const HloInstruction* root, + int output_idx) const; + + void GenerateElementForReducer( + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const llvm_ir::IrArray::Index& index_without_linear, + const llvm_ir::IrArray::Index& input_index, + int num_partial_results) const; + + Status EmitExtraOutputsForReduce( + const Shape& reduction_operand_shape, + const llvm_ir::IrArray::Index& index, + const ExtraOutputGensMap& extra_output_gens) const; + + private: + ReductionEmitter& reduction_emitter_; + const ReductionOutputMap& result_ir_arrays_; + + // One state per reduction operand. + using ReductionOpState = absl::InlinedVector; + + // HloInstruction -> operand_idx -> cache + absl::flat_hash_map state_; +}; + // Allocates a shared tile of given dimensions, applying scaling specified in // tilng_scheme as a major-most dimension to avoid collisions. llvm::GlobalVariable* AllocateShared( @@ -230,14 +270,18 @@ llvm::GlobalVariable* AllocateShared( // Creates accumulator alloca's, populates them with initial values, generates // __shared__ caches and returns the populated object. 
-ReductionCodegenState ReductionEmitter::GenerateReductionCodegenState( +ReductionGroupEmitter::ReductionGroupEmitter( + ReductionEmitter& reduction_emitter, absl::Span reduce_instr_index_group, - FusedIrEmitter& fused_emitter) { + const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter) + : reduction_emitter_(reduction_emitter), + result_ir_arrays_(result_ir_arrays) { const ReductionCodegenInfo& reduction_info = - *analysis_.GetReductionCodegenInfo(); - ReductionCodegenState reduction_codegen_state; - VLOG(10) << "Emit prologue for reduction: " << fusion_.ToString(); + *reduction_emitter_.analysis_.GetReductionCodegenInfo(); + VLOG(10) << "Emit prologue for reduction: " + << reduction_emitter_.fusion_.ToString(); + auto* builder = reduction_emitter_.builder_; for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { int num_partial_results = reduction_info.GetNumPartialResults(); for (int op_result_idx = 0; @@ -247,45 +291,44 @@ ReductionCodegenState ReductionEmitter::GenerateReductionCodegenState( : reduce_hlo->shape(); llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( - result_shape.element_type(), builder_->GetInsertBlock()->getModule()); + result_shape.element_type(), builder->GetInsertBlock()->getModule()); llvm::AllocaInst* reduction_input_address = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "reduction_input_address", builder_); + element_type, "reduction_input_address", builder); llvm::AllocaInst* partial_result_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( element_type, - /*element_count=*/builder_->getInt32(num_partial_results), - "partial_reduction_result", builder_); + /*element_count=*/builder->getInt32(num_partial_results), + "partial_reduction_result", builder); const HloInstruction* init_value = reduce_hlo->init_values()[op_result_idx]; // Initialize the partial result with the initial value of the reduction. llvm::Value* init_ir_value = (*fused_emitter.GetGenerator( - *init_value))(llvm_ir::IrArray::Index(builder_->getInt32Ty())) + *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) .value(); for (int i = 0; i < num_partial_results; ++i) { - builder_->CreateStore( - init_ir_value, - builder_->CreateInBoundsGEP( - partial_result_address->getAllocatedType(), - partial_result_address, {builder_->getInt32(i)})); + builder->CreateStore( + init_ir_value, builder->CreateInBoundsGEP( + partial_result_address->getAllocatedType(), + partial_result_address, {builder->getInt32(i)})); } const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { if (reduction_info.IsRowReduction()) { // Multi-row reductions do not use shared memory. - if (RowReductionGetRowsPerWarp() > 1) { + if (reduction_emitter_.RowReductionGetRowsPerWarp() > 1) { return nullptr; } // Allocate __shared__ // cache[num_partial_results][num_warps][scaling_factor]. CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); - return AllocateShared(builder_, tiling_scheme, element_type, + return AllocateShared(builder, tiling_scheme, element_type, {num_partial_results, num_warps}, "shared_cache"); } else { @@ -300,7 +343,7 @@ ReductionCodegenState ReductionEmitter::GenerateReductionCodegenState( // don't need that much cache: Only one result is live at a time.) 
CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(TilingScheme::DimY)); - return AllocateShared(builder_, tiling_scheme, element_type, + return AllocateShared(builder, tiling_scheme, element_type, {num_threads_x, num_threads_x + 1}, "shared_cache"); } @@ -308,14 +351,12 @@ ReductionCodegenState ReductionEmitter::GenerateReductionCodegenState( llvm_ir::ElementGenerator input_gen = *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); - reduction_codegen_state.SetCalculationStateFor( + SetCalculationStateFor( {shared_cache, init_ir_value, partial_result_address, reduction_input_address, input_gen}, reduce_hlo, op_result_idx); } } - - return reduction_codegen_state; } void ReductionEmitter::MaybeEmitFenceForAMDGPU() { @@ -401,16 +442,14 @@ StatusOr> ReductionEmitter::BuildKernelThunkForFusion( /*shmem_bytes=*/0); } -Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder, - const Shape& reduction_operand_shape, - const ReductionOutputMap& result_ir_arrays, - const llvm_ir::IrArray::Index& index, - const ReductionCodegenInfo& reduction_info, - const ExtraOutputGensMap& extra_output_gens) { +Status ReductionGroupEmitter::EmitExtraOutputsForReduce( + const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index, + const ExtraOutputGensMap& extra_output_gens) const { if (extra_output_gens.empty()) { return OkStatus(); } + auto* builder = reduction_emitter_.builder_; // Compute all extra output values before writing them. This avoids // overwriting aliased input/output buffers before all reads occurred. std::vector> @@ -432,11 +471,12 @@ Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder, } for (const auto& [instr, generator] : extra_output_ir_values) { - absl::Span result_ir = result_ir_arrays.at(instr); + absl::Span result_ir = result_ir_arrays_.at(instr); CHECK_EQ(result_ir.size(), 1); result_ir[0].EmitWriteArrayElement( get_index(instr), generator, builder, /*use_linear_index=*/ - reduction_info.GetNumPartialResults() == 1); + reduction_emitter_.analysis_.GetReductionCodegenInfo() + ->GetNumPartialResults() == 1); } return OkStatus(); } @@ -517,17 +557,17 @@ static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, // // Multiple partial_result_address inputs happen when doing variadic // reduction: each one should get the output value. -void ReductionEmitter::EmitFullWarpShuffleDownLoopForReduce( +void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce( const HloComputation* reducer, absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp) { + int threads_per_block, int num_results_per_warp) const { // This only works when the block size is a multiple of 32 threads. - // We check this here as a mistake in the number of threads per // block is very hard to detect. 
CHECK_EQ(threads_per_block % 32, 0); CHECK_EQ(WarpSize() % num_results_per_warp, 0); + auto* builder = reduction_emitter_.builder_; for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { absl::InlinedVector reduction_params; @@ -539,48 +579,51 @@ void ReductionEmitter::EmitFullWarpShuffleDownLoopForReduce( partial_result_addresses) { int bit_width = llvm_ir::GetSizeInBits(element_type); llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "result_from_other_lane", builder_); + element_type, "result_from_other_lane", builder); reduction_params.push_back(result_from_other_lane); // Bitcast cannot be applied to aggregate types (even packed ones), so // we bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffled_value_type = element_type->isStructTy() - ? builder_->getIntNTy(bit_width) + ? builder->getIntNTy(bit_width) : element_type; auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { return ptr; }; - llvm::Value* partial_result = builder_->CreateLoad( + llvm::Value* partial_result = builder->CreateLoad( shuffled_value_type, convert_pointer_for_shuffle(partial_result_address), "partial_reduction_result"); - builder_->CreateStore( - EmitFullWarpShuffleDown(partial_result, builder_->getInt32(distance), - builder_), + builder->CreateStore( + EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), + builder), convert_pointer_for_shuffle(result_from_other_lane)); } StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder_, ir_emitter_context_, - *reducer, reduction_params); + CallNestedComputationWithScalarAddrs( + builder, reduction_emitter_.ir_emitter_context_, *reducer, + reduction_params); TF_CHECK_OK(returned_scalars.status()); for (int i = 0; i < returned_scalars->size(); i++) { - builder_->CreateStore(/*Val=*/returned_scalars->at(i), - /*Ptr=*/partial_result_addresses[i].first); + builder->CreateStore(/*Val=*/returned_scalars->at(i), + /*Ptr=*/partial_result_addresses[i].first); } } } -llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( +llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction( int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, const HloReduceInstruction* reduction, const HloInstruction* root, - int output_idx) { + int output_idx) const { auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty_, c); + return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); }; - const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + auto* builder = reduction_emitter_.builder_; + const auto& reduction_info = + *reduction_emitter_.analysis_.GetReductionCodegenInfo(); const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; @@ -590,11 +633,11 @@ llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( if (!reduction_info.IsRowReduction()) { std::swap(x_loc, y_loc); } - llvm::Value* start_offset_x = - GetStartOffsetX(tiling_scheme, x_loc, index_ty_, builder_); + llvm::Value* start_offset_x = GetStartOffsetX( + tiling_scheme, x_loc, reduction_emitter_.index_ty_, builder); return tiling_kernel_info.tile_origin - .AddOffsetToDim(y_loc, TilingScheme::DimY, builder_) - .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder_); + .AddOffsetToDim(y_loc, TilingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder); }(); const 
Shape& operand_shape = reduction->inputs()[output_idx]->shape(); @@ -606,7 +649,7 @@ llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( // the input shape with the dimensions being reduced moved. llvm::Value* untransposed_output_linear_address = [&] { const llvm_ir::IrArray::Index index = start_offset.AddOffsetToDim( - constant(partial_result_idx), TilingScheme::DimX, builder_); + constant(partial_result_idx), TilingScheme::DimX, builder); if (reduction_info.IsRowReduction()) { // For row-reduction, y-coordinate determines which row we write into. return index[TilingScheme::DimY]; @@ -616,8 +659,8 @@ llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); llvm::Value* x_block_offset = - builder_->CreateMul(index[TilingScheme::DimZ], x_dim_size); - return builder_->CreateAdd(x_block_offset, index[TilingScheme::DimX]); + builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); + return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); }(); // A reduction is allowed to transpose its output. For example, suppose @@ -630,7 +673,7 @@ llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( // the correct output element. llvm_ir::IrArray::Index element_index( /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, builder_); + reduction_kept_element_shape, builder); const Shape& output_shape = !reduction->shape().IsTuple() ? reduction->shape() : reduction->shape().tuple_shapes(output_idx); @@ -643,7 +686,7 @@ llvm_ir::IrArray::Index ReductionEmitter::GetOutputIndexForReduction( ShapeUtil::EqualIgnoringElementType(output_shape, root->shape())) ? output_index : output_index.SourceIndexOfBitcast(output_shape, root->shape(), - builder_); + builder); } llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, @@ -655,12 +698,13 @@ llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, name); } -void ReductionEmitter::WriteReductionOutput( +void ReductionGroupEmitter::WriteReductionOutput( const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx, const absl::Span values) { - const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + int partial_result_idx, const absl::Span values) const { + auto* builder = reduction_emitter_.builder_; + const auto& reduction_info = + *reduction_emitter_.analysis_.GetReductionCodegenInfo(); const HloComputation* reducer = reduction->to_apply(); for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { auto [output_ptr, type] = typed_ptr; @@ -668,107 +712,106 @@ void ReductionEmitter::WriteReductionOutput( partial_result_idx, tiling_kernel_info, reduction, root, oidx); llvm::Value* output_address = - output_arrays.at(root)[oidx].EmitArrayElementAddress( - output_index, builder_, "output_element_address"); + result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress( + output_index, builder, "output_element_address"); if (reduction_info.IsRaceFree()) { - FusedIrEmitter fused_emitter(elemental_emitter_); - llvm::Value* loaded = builder_->CreateLoad(type, output_ptr, "output"); + FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_); + llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output"); fused_emitter.BindGenerator( *reduction, [&](const llvm_ir::IrArray::Index& index) { return loaded; }); 
llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root); llvm::Value* generated = *gen(output_index); - builder_->CreateStore(generated, output_address); + builder->CreateStore(generated, output_address); } else { CHECK_EQ(values.size(), 1); CHECK_EQ(reduction, root) << "output fusion is not allowed for racing reductions"; TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - builder_, ir_emitter_context_, *reducer, output_address, output_ptr, - type)); + builder, reduction_emitter_.ir_emitter_context_, *reducer, + output_address, output_ptr, type)); } } } // `current_output`: the value the tile has calculated. // `output_address`: address where the output value has to be written. -void ReductionEmitter::EmitReductionOutputForRowReduction( +void ReductionGroupEmitter::EmitReductionOutputForRowReduction( const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, - const ReductionOutputMap& output_arrays, const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx) { + int partial_result_idx) const { const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty_, c); + return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); }; + + auto* builder = reduction_emitter_.builder_; auto is_zero = [&](llvm::Value* value) { - return builder_->CreateICmpEQ(value, constant(0)); + return builder->CreateICmpEQ(value, constant(0)); }; int num_outputs = reducer->num_parameters() / 2; absl::InlinedVector current_outputs; for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + const ReductionGroupEmitter::ReductionCalculationState& state = + GetCalculationStateFor(reduction, output_idx); current_outputs.push_back( - {builder_->CreateInBoundsGEP( + {builder->CreateInBoundsGEP( state.partial_result_address->getAllocatedType(), state.partial_result_address, {constant(partial_result_idx)}, "current_output"), state.partial_result_address->getAllocatedType()}); } - const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + const auto& reduction_info = + *reduction_emitter_.analysis_.GetReductionCodegenInfo(); const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - int num_rows_per_warp = RowReductionGetRowsPerWarp(); + int num_rows_per_warp = reduction_emitter_.RowReductionGetRowsPerWarp(); EmitFullWarpShuffleDownLoopForReduce( reducer, absl::MakeSpan(current_outputs), tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); - KernelSupportLibrary ksl(builder_); + KernelSupportLibrary ksl(builder); llvm::Value* warp_id = - builder_->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); + builder->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); auto emit_write_output = [&](llvm::Value* write_condition, const absl::Span values) { ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(tiling_kernel_info, output_arrays, reduction, root, + WriteReductionOutput(tiling_kernel_info, reduction, root, partial_result_idx, values); }); }; if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder_->CreateAnd( - thread_id_info.thread_id_x, constant(ReducedDimensionSize() - 1))); + llvm::Value* is_writing_thread = 
is_zero(builder->CreateAnd( + thread_id_info.thread_id_x, + constant(reduction_emitter_.ReducedDimensionSize() - 1))); emit_write_output(is_writing_thread, current_outputs); return; } ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + const auto& state = GetCalculationStateFor(reduction, oidx); llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( - builder_, state.shared_cache, - {constant(partial_result_idx), warp_id}); - builder_->CreateStore(builder_->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - shmem_output_addr); + builder, state.shared_cache, {constant(partial_result_idx), warp_id}); + builder->CreateStore(builder->CreateLoad(current_outputs[oidx].second, + current_outputs[oidx].first), + shmem_output_addr); } }); // TODO(cheshire): Don't we want to sync it once for everything in the // output? Not once per each? - EmitSyncThreads(); + reduction_emitter_.EmitSyncThreads(); ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { absl::InlinedVector selected_values; for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + const auto& state = GetCalculationStateFor(reduction, oidx); llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( - builder_, state.shared_cache, + builder, state.shared_cache, {constant(partial_result_idx), thread_id_info.lane_id}); llvm::Type* element_type = @@ -776,18 +819,18 @@ void ReductionEmitter::EmitReductionOutputForRowReduction( // Ensure initial value address is in generic, not scratch. llvm::Value* initial_value_addr = - CastSharedToGlobal(builder_, + CastSharedToGlobal(builder, llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "initial_value_addr", builder_), + element_type, "initial_value_addr", builder), element_type, /*name=*/""); - builder_->CreateStore(state.initial_value, initial_value_addr); + builder->CreateStore(state.initial_value, initial_value_addr); - llvm::Value* warp_exists = builder_->CreateICmpULT( + llvm::Value* warp_exists = builder->CreateICmpULT( thread_id_info.thread_id_x, constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX) / WarpSize())); - llvm::Value* selected_value = builder_->CreateSelect( + llvm::Value* selected_value = builder->CreateSelect( warp_exists, block_accum_addr, initial_value_addr); selected_values.push_back({selected_value, element_type}); @@ -809,23 +852,23 @@ void ReductionEmitter::EmitReductionOutputForRowReduction( } // Same arguments as EmitReductionOutputForRowReduction. 
-void ReductionEmitter::EmitReductionOutputForColumnReduction( +void ReductionGroupEmitter::EmitReductionOutputForColumnReduction( const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, - const ReductionOutputMap& output_arrays, const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx) { - KernelSupportLibrary ksl(builder_); + int partial_result_idx) const { + auto* builder = reduction_emitter_.builder_; + KernelSupportLibrary ksl(builder); const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty_, c); + return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); }; auto is_zero = [&](llvm::Value* value) { - return builder_->CreateICmpEQ(value, constant(0)); + return builder->CreateICmpEQ(value, constant(0)); }; - const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); + const auto& reduction_info = + *reduction_emitter_.analysis_.GetReductionCodegenInfo(); const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); int num_outputs = reducer->num_parameters() / 2; @@ -833,39 +876,37 @@ void ReductionEmitter::EmitReductionOutputForColumnReduction( // slow, we could "double-buffer" by having two shmem buffers and switching // between them.) if (partial_result_idx > 0) { - EmitSyncThreads(); + reduction_emitter_.EmitSyncThreads(); } // Store the transpose in shared memory. for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + const auto& state = GetCalculationStateFor(reduction, output_idx); llvm::GlobalVariable* shared_cache = state.shared_cache; llvm::AddrSpaceCastInst* shmem_output_addr = llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder_, shared_cache, + builder, shared_cache, {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, "shmem_output_address")); - llvm::Value* current_output = builder_->CreateInBoundsGEP( + llvm::Value* current_output = builder->CreateInBoundsGEP( state.partial_result_address->getAllocatedType(), state.partial_result_address, {constant(partial_result_idx)}, "current_output"); - llvm::Value* current_output_value = builder_->CreateLoad( + llvm::Value* current_output_value = builder->CreateLoad( state.partial_result_address->getAllocatedType(), current_output); - builder_->CreateStore(current_output_value, shmem_output_addr); + builder->CreateStore(current_output_value, shmem_output_addr); } - EmitSyncThreads(); + reduction_emitter_.EmitSyncThreads(); // Get transposed element from shared memory. 
absl::InlinedVector shmem_transposed_addrs; for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + const auto& state = GetCalculationStateFor(reduction, output_idx); llvm::AddrSpaceCastInst* shmem_transposed_addr = llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder_, state.shared_cache, + builder, state.shared_cache, {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, "shmem_transposed_addr")); shmem_transposed_addrs.push_back( @@ -881,46 +922,43 @@ void ReductionEmitter::EmitReductionOutputForColumnReduction( // Some warps in the block are completely outside of the bound of the // tensor, so they should not write any output at all. - llvm::Value* has_output = builder_->CreateAnd( - builder_->CreateICmpULT( - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty_, - builder_), + llvm::Value* has_output = builder->CreateAnd( + builder->CreateICmpULT( + GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, + reduction_emitter_.index_ty_, builder), tiling_kernel_info.output_tile_bounds[1]), - builder_->CreateICmpULT(thread_id_info.thread_id_x, - tiling_kernel_info.output_tile_bounds[0])); + builder->CreateICmpULT(thread_id_info.thread_id_x, + tiling_kernel_info.output_tile_bounds[0])); ksl.If("reduction_write_output", - builder_->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput(tiling_kernel_info, output_arrays, reduction, - root, partial_result_idx, - shmem_transposed_addrs); + builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + WriteReductionOutput(tiling_kernel_info, reduction, root, + partial_result_idx, shmem_transposed_addrs); }); } // Generate a single element of the tile (update the accumulator state) for a // given reducer of index `i`. -void ReductionEmitter::GenerateElementForReducer( +void ReductionGroupEmitter::GenerateElementForReducer( const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const ReductionCodegenState& codegen_state, const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, int num_partial_results, - const ReductionOutputMap& result_ir_arrays) { + const llvm_ir::IrArray::Index& input_index, int num_partial_results) const { HloComputation* reducer = reduction->to_apply(); + auto* builder = reduction_emitter_.builder_; CHECK_EQ(reducer->num_parameters() % 2, 0); absl::InlinedVector reduction_accumulators; absl::InlinedVector reduction_input_value; for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - codegen_state.GetCalculationStateFor(reduction, red_idx); + const auto& state = GetCalculationStateFor(reduction, red_idx); llvm::AllocaInst* input_address = state.input_address; llvm::AllocaInst* partial_reduction_result_address = state.partial_result_address; llvm::Value* const input_ir_value = *state.input_gen( num_partial_results > 1 ? 
index_without_linear : input_index); - builder_->CreateStore(input_ir_value, input_address); - llvm::Value* partial_result_address = builder_->CreateInBoundsGEP( + builder->CreateStore(input_ir_value, input_address); + llvm::Value* partial_result_address = builder->CreateInBoundsGEP( partial_reduction_result_address->getAllocatedType(), partial_reduction_result_address, {partial_result_index}); reduction_accumulators.push_back(partial_result_address); @@ -943,12 +981,13 @@ void ReductionEmitter::GenerateElementForReducer( // those pointers, and we have returned values on the stack (as well // as pointers to them). StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder_, ir_emitter_context_, - *reducer, reduction_params); + CallNestedComputationWithScalarAddrs( + builder, reduction_emitter_.ir_emitter_context_, *reducer, + reduction_params); TF_CHECK_OK(returned_scalars.status()); for (int i = 0; i < returned_scalars->size(); i++) { - builder_->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); + builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); } } @@ -976,8 +1015,8 @@ Status ReductionEmitter::EmitIRForReduction( CHECK(!heroes.empty()) << " expect at least one reduce instructions."; const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); - ReductionCodegenState codegen_state = - GenerateReductionCodegenState(heroes, fused_emitter); + ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays, + fused_emitter); EmitTileElementFunction emit_reduction_element = [&](const TilingThreadIdInfo& thread_id_info, @@ -1005,16 +1044,15 @@ Status ReductionEmitter::EmitIRForReduction( // Emit code to generate the input and perform the reduction computation // for each reduction instruction. for (const HloReduceInstruction* reduce : heroes) { - GenerateElementForReducer(reduce, partial_result_index, codegen_state, - index_without_linear, input_index, - num_partial_results, result_ir_arrays); + group_emitter.GenerateElementForReducer( + reduce, partial_result_index, index_without_linear, input_index, + num_partial_results); } // Emit code to generate the output for the non-reduction instructions // in the fusion, if any. 
- TF_CHECK_OK(EmitExtraOutputsForReduce( - builder_, input_shape, result_ir_arrays, input_index, - reduction_info, extra_output_gens)); + TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce( + input_shape, input_index, extra_output_gens)); }; TF_ASSIGN_OR_RETURN( @@ -1034,13 +1072,11 @@ Status ReductionEmitter::EmitIRForReduction( partial_result_idx < reduction_info.GetNumPartialResults(); ++partial_result_idx) { if (reduction_info.IsRowReduction()) { - EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state, - result_ir_arrays, reduce, root, - partial_result_idx); + group_emitter.EmitReductionOutputForRowReduction( + tiling_kernel_info, reduce, root, partial_result_idx); } else { - EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state, - result_ir_arrays, reduce, root, - partial_result_idx); + group_emitter.EmitReductionOutputForColumnReduction( + tiling_kernel_info, reduce, root, partial_result_idx); } } } diff --git a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h index 3120c9a6680003..92ad74b634c03e 100644 --- a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h +++ b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h @@ -205,39 +205,6 @@ class ReductionCodegenInfo { const HloInstruction* first_reduce_; }; -class ReductionCodegenState { - public: - struct ReductionCalculationState { - llvm::GlobalVariable* shared_cache; - llvm::Value* initial_value; - llvm::AllocaInst* partial_result_address; - llvm::AllocaInst* input_address; - llvm_ir::ElementGenerator input_gen; - }; - - const ReductionCalculationState& GetCalculationStateFor( - const HloInstruction* instruction, int operand_idx) const { - const ReductionOpState& op_state = state_.at(instruction); - CHECK_LT(operand_idx, op_state.size()); - return op_state[operand_idx]; - } - - void SetCalculationStateFor( - const ReductionCalculationState& calculation_state, - const HloInstruction* instruction, int operand_idx) { - ReductionOpState& op_state = state_[instruction]; - CHECK_EQ(operand_idx, op_state.size()); - op_state.push_back(calculation_state); - } - - private: - // One state per reduction operand. - using ReductionOpState = absl::InlinedVector; - - // HloInstruction -> operand_idx -> cache - absl::flat_hash_map state_; -}; - } // end namespace gpu } // end namespace xla From 9542da9a805631e507f216f65e682a52c1d3b8ab Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 27 Nov 2023 03:33:23 -0800 Subject: [PATCH 086/381] [XLA:GPU] Add can_fuse cache. Avoid unnecessary calls to can_fuse_ to save a significant amount of compile time. 
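For intuition, the caching pattern looks roughly like the sketch below. This is a
simplified illustration only: it uses std containers, a plain mutex and a toy
decision type in place of the absl containers, HloInstructionAdaptor and
FusionDecision used by the actual change, and every name in it is hypothetical.

#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>

// Toy stand-ins: an integer node id and a string decision, where an empty
// string means "can fuse".
using NodeId = int64_t;
using Decision = std::string;

class FusionQueueSketch {
 public:
  // Memoizes an expensive pairwise query, keyed by producer and then by
  // consumer.
  Decision CanFuseCached(NodeId producer, NodeId consumer) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      auto producer_it = cache_.find(producer);
      if (producer_it != cache_.end()) {
        auto it = producer_it->second.find(consumer);
        if (it != producer_it->second.end()) return it->second;
      }
    }
    // Computed outside the lock; the same pair is never queried concurrently,
    // so at worst a decision is recomputed once, never raced on.
    Decision decision = CanFuse(producer, consumer);
    {
      std::lock_guard<std::mutex> lock(mutex_);
      cache_[producer][consumer] = decision;
    }
    return decision;
  }

  // When a node changes, drop every entry keyed by it as a producer ...
  void InvalidateAsProducer(NodeId node) {
    std::lock_guard<std::mutex> lock(mutex_);
    cache_.erase(node);
  }
  // ... and its entry in the map of any producer that still references it.
  void InvalidateAsConsumerOf(NodeId producer, NodeId consumer) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(producer);
    if (it != cache_.end()) it->second.erase(consumer);
  }

 private:
  // Placeholder for the real, expensive decision function.
  Decision CanFuse(NodeId, NodeId) { return ""; }

  std::mutex mutex_;
  std::unordered_map<NodeId, std::unordered_map<NodeId, Decision>> cache_;
};

The decision is deliberately computed outside the lock: holding the mutex only
around the lookup and the store keeps the critical sections short, and the worst
case is a duplicate computation rather than contention on the expensive call.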
PiperOrigin-RevId: 585598642 --- .../xla/xla/service/gpu/priority_fusion.cc | 48 ++++++++++++++++++- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 491170ff67df00..06136e64333f9c 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -185,6 +185,9 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } + HloInstructionAdaptor fusion_adaptor(*fusion); + can_fuse_cache_.erase(fusion_adaptor); + fusion_analysis_cache_.Invalidate(*fusion); fusion_analysis_cache_.Invalidate(*original_producer); @@ -219,6 +222,9 @@ class GpuPriorityFusionQueue : public FusionQueue { continue; } producer_user_count_[operand] = operand->user_count(); + + HloInstructionAdaptor operand_adaptor(*operand); + can_fuse_cache_[operand_adaptor].erase(fusion_adaptor); to_update_priority_.insert(operand); } to_update_priority_.insert(fusion); @@ -314,14 +320,44 @@ class GpuPriorityFusionQueue : public FusionQueue { run_times.time_fused); } - FusionDecision CanFuseWithAllUsers(HloInstruction* producer) const { + FusionDecision CanFuseCached(HloInstruction* producer, + HloInstruction* consumer) { + HloInstructionAdaptor producer_adaptor(*producer); + HloInstructionAdaptor consumer_adaptor(*consumer); + + { + absl::MutexLock lock(&can_fuse_cache_mutex_); + auto& producer_cache = can_fuse_cache_[producer_adaptor]; + + auto it = producer_cache.find(consumer_adaptor); + if (it != producer_cache.end()) { + return it->second; + } + } + + auto fusion_decision = + can_fuse_(consumer, consumer->operand_index(producer)); + + // The lock is required, because writing to a flat_hash_map is not + // thread-safe even for different keys. We never call this computation + // concurrently for the same producer, so it's guaranteed that we don't + // override any value. + { + absl::MutexLock lock(&can_fuse_cache_mutex_); + can_fuse_cache_[producer_adaptor][consumer_adaptor] = fusion_decision; + } + + return fusion_decision; + } + + FusionDecision CanFuseWithAllUsers(HloInstruction* producer) { if (producer->users().size() == 0) { return "No users to fuse"; } FusionDecision result; for (const auto& user : producer->users()) { - if (auto fusion_decision = can_fuse_(user, user->operand_index(producer)); + if (auto fusion_decision = CanFuseCached(producer, user); !fusion_decision) { VLOG(10) << "Cannot fuse " << producer->name() << " with " << user->name() << ", because: " << fusion_decision.Explain(); @@ -376,6 +412,14 @@ class GpuPriorityFusionQueue : public FusionQueue { tsl::thread::ThreadPool* thread_pool_; HloFusionAnalysisCache& fusion_analysis_cache_; + + // Caches result of can_fuse for a (producer, consumer) pair. A cache entry is + // invalidated if producer or consumer is modified. 
+  absl::flat_hash_map<
+      HloInstructionAdaptor,
+      absl::flat_hash_map>
+      can_fuse_cache_;
+  absl::Mutex can_fuse_cache_mutex_;
 };

 }  // namespace

From d9b8e78ad1f70f71861884f302d3ffebd04b38aa Mon Sep 17 00:00:00 2001
From: Dmitri Gribenko
Date: Mon, 27 Nov 2023 04:22:02 -0800
Subject: [PATCH 087/381] Integrate LLVM at llvm/llvm-project@5e5a22caf88a

Updates LLVM usage to match
[5e5a22caf88a](https://github.com/llvm/llvm-project/commit/5e5a22caf88a)

PiperOrigin-RevId: 585609371
---
 third_party/llvm/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index d7f3a809369277..ad08dfe0605b8e 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
 def repo(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "af7a1453526a88a0e242baf156244aa4ae42ae4b"
-    LLVM_SHA256 = "f9f75e4823c2f09a8141ab4db40ee2c79aef96017782a9338e26621ee547d3d5"
+    LLVM_COMMIT = "5e5a22caf88ac1ccfa8dc5720295fdeba0ad9372"
+    LLVM_SHA256 = "9d9ae8ae30f6262ca0823493893398ea2ab6fbd49027e338e06ac7c25bb8caf4"

     tf_http_archive(
         name = name,
From b144c46f2252cc3655b0992620d3428eb61ee96d Mon Sep 17 00:00:00 2001
From: Oleg Shyshkov
Date: Mon, 27 Nov 2023 04:44:40 -0800
Subject: [PATCH 088/381] [XLA:GPU] (NFC) Remove producer_user_count_.

It's no longer used. The heuristic that used the user count was removed in
cl/573215172.

PiperOrigin-RevId: 585613833
---
 third_party/xla/xla/service/gpu/priority_fusion.cc | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc
index 06136e64333f9c..79e6e20235c3f5 100644
--- a/third_party/xla/xla/service/gpu/priority_fusion.cc
+++ b/third_party/xla/xla/service/gpu/priority_fusion.cc
@@ -110,7 +110,6 @@ class GpuPriorityFusionQueue : public FusionQueue {
           std::make_pair(priority, instruction->unique_id()), instruction);
       CHECK(emplace_result.second);
       reverse_map_.emplace(instruction, emplace_result.first);
-      producer_user_count_[instruction] = instruction->user_count();
     }
   }

@@ -221,7 +220,6 @@ class GpuPriorityFusionQueue : public FusionQueue {
       if (!operand->IsFusible()) {
        continue;
       }
-      producer_user_count_[operand] = operand->user_count();

       HloInstructionAdaptor operand_adaptor(*operand);
       can_fuse_cache_[operand_adaptor].erase(fusion_adaptor);
@@ -270,7 +268,6 @@ class GpuPriorityFusionQueue : public FusionQueue {
   // Removes data for the instruction.
   void RemoveInstruction(HloInstruction* instruction) override {
     to_update_priority_.erase(instruction);
-    producer_user_count_.erase(instruction);
     fusion_analysis_cache_.Invalidate(*instruction);

     auto reverse_it = reverse_map_.find(instruction);
@@ -393,10 +390,6 @@ class GpuPriorityFusionQueue : public FusionQueue {
   // and the producer is given as the consumer's operand index.
   CanFuseCallback can_fuse_;

-  // The user counts of producers, used to determine whether we update their
-  // priorities when fusion happens.
-  absl::flat_hash_map producer_user_count_;
-
   // The set of producers whose priorities need to be updated. Their
   // priorities are changed because their neighbors got fused, but we delay
   // the priority updates until current_consumers_ becomes empty.
This is to From c929b154111be1b0ad766f08bf3c732381ae5be6 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 27 Nov 2023 05:09:27 -0800 Subject: [PATCH 089/381] [XLA:GPU][NFC] Clean up tile analysis tests to use `TF_ASSERT_OK_AND_ASSIGN`. Previously, the tests assigned the wrapped map to a variable before calling `ASSERT_IS_OK` on the resulting variable, which deviates from the recommended pattern. PiperOrigin-RevId: 585618767 --- .../service/gpu/model/tile_analysis_test.cc | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 917bea1851f3dc..7010fc24703dbb 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -69,17 +69,17 @@ class TileAnalysisTest : public HloTestBase { }; TEST_F(TileAnalysisTest, ElementwiseOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[10, 20] parameter(0) p1 = f32[10, 20] parameter(1) ROOT add0 = f32[10, 20] add(p0, p1) } - )"); - ASSERT_IS_OK(input_indexing_or); + )")); EXPECT_THAT( - input_indexing_or->operand_indexing_maps, + input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))), @@ -89,22 +89,23 @@ TEST_F(TileAnalysisTest, ElementwiseOp) { } TEST_F(TileAnalysisTest, BroadcastOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[20] parameter(0) ROOT bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} } - )"); - ASSERT_IS_OK(input_indexing_or); - EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m f { p0 = f32[100] parameter(0) @@ -116,10 +117,9 @@ TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { p1 = f32[100] parameter(1) ROOT fusion = f32[100] fusion(p0, p1), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); + )")); EXPECT_THAT( - input_indexing_or->operand_indexing_maps, + input_indexing.operand_indexing_maps, UnorderedElementsAre( MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( "(d0) -> (d0)", std::vector{}))), @@ -128,7 +128,8 @@ TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { } TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m f { p0 = f32[1000, 1000] parameter(0) @@ -139,10 +140,9 @@ TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { p0 = f32[1000,1000] parameter(0) ROOT fusion = f32[1000,1000] fusion(p0), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); + )")); EXPECT_THAT( - input_indexing_or->operand_indexing_maps, + input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, UnorderedElementsAre( @@ 
-151,7 +151,8 @@ TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { } TEST_F(TileAnalysisTest, FusionExponentialDuplication) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4] parameter(0) @@ -163,10 +164,9 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { slice2.0 = f32[2] slice(add1), slice={[0:2]} slice2.1 = f32[2] slice(add1), slice={[1:3]} ROOT add2 = f32[2] add(slice2.0, slice2.1) - })"); - ASSERT_IS_OK(input_indexing_or); + })")); EXPECT_THAT( - input_indexing_or->operand_indexing_maps, + input_indexing.operand_indexing_maps, ElementsAre( MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( "(d0) -> (d0)", std::vector{}))), @@ -175,7 +175,8 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { } TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -195,9 +196,8 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { p0_init = f32[] constant(-inf) ROOT fusion = f32[10] fusion(p0, p0_init), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); - EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0)[s0, s1, s2] -> (s0, s2, d0, s1)", @@ -205,7 +205,8 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { } TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -225,16 +226,16 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { p0_init = f32[] constant(-inf) ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); - EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap("(d0, d1)[s0] -> (d0, s0)", std::vector{20}))))); } TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m f { p0 = f32[20, 10, 50] parameter(0) @@ -257,17 +258,17 @@ TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { p0 = f32[20, 10, 50] parameter(0) ROOT fusion = f32[10, 50, 20] fusion(p0), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); + )")); EXPECT_THAT( - input_indexing_or->operand_indexing_maps, + input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d2, d0, d1)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -287,9 +288,8 @@ TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { p0_init = f32[] constant(-inf) ROOT fusion = f32[32] fusion(p0, p0_init), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); - 
EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)", @@ -297,7 +297,8 @@ TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { } TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m f { p0 = f32[150, 64, 1024] parameter(0) @@ -310,10 +311,9 @@ TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { p0 = f32[150, 64, 1024] parameter(0) ROOT fusion = f32[7, 9, 24] fusion(p0), kind=kLoop, calls=f } - )"); - ASSERT_IS_OK(input_indexing_or); + )")); EXPECT_THAT( - input_indexing_or->operand_indexing_maps, + input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65)", @@ -321,7 +321,8 @@ TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { } TEST_F(TileAnalysisTest, ReduceOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -334,9 +335,8 @@ TEST_F(TileAnalysisTest, ReduceOp) { ROOT reduce = f32[150, 10] reduce(p0, p0_init), dimensions={3, 1}, to_apply=max } - )"); - ASSERT_IS_OK(input_indexing_or); - EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1)[s0, s1] -> (d0, s0, d1, s1)", @@ -393,15 +393,15 @@ TEST_F(TileAnalysisTest, VariadicReduceOp) { } TEST_F(TileAnalysisTest, ReverseOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[1, 17, 9, 9] parameter(0) ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} } - )"); - ASSERT_IS_OK(input_indexing_or); - EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, -d1 + 17, -d2 + 9, d3)", @@ -409,16 +409,16 @@ TEST_F(TileAnalysisTest, ReverseOp) { } TEST_F(TileAnalysisTest, SliceOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[10, 20, 50] parameter(0) ROOT slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0), slice={[5:10:1], [3:20:7], [0:50:2]} } - )"); - ASSERT_IS_OK(input_indexing_or); - EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)", @@ -426,16 +426,16 @@ TEST_F(TileAnalysisTest, SliceOp) { } TEST_F(TileAnalysisTest, TransposeOp) { - auto input_indexing_or = GetIndexingMapsForEntryComputation(R"( + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { p0 = f16[1, 8, 1536, 512] parameter(0) ROOT transpose = f16[1, 8, 512, 1536]{2, 3, 1, 0} transpose(p0), dimensions={0, 1, 3, 2} } - )"); - ASSERT_IS_OK(input_indexing_or); - 
EXPECT_THAT(input_indexing_or->operand_indexing_maps, + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, d1, d3, d2)", @@ -443,7 +443,7 @@ TEST_F(TileAnalysisTest, TransposeOp) { } TEST_F(TileAnalysisTest, DotOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing_or, + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { @@ -455,7 +455,7 @@ TEST_F(TileAnalysisTest, DotOp) { } )")); EXPECT_THAT( - input_indexing_or.operand_indexing_maps, + input_indexing.operand_indexing_maps, ElementsAre( MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " From f7348e30271727dcda9672fe976467ae179c6d9a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 05:11:12 -0800 Subject: [PATCH 090/381] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/58f2ec4dc891dc0bc0815a8c2d1caf196bfc13d5. PiperOrigin-RevId: 585619228 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index ada8843278851c..e501613187629a 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "4347953799d066962cb1897814de77c8e195499d" - TFRT_SHA256 = "26af1f500eab6aa22f47e05a36253faeee8786208d18e4f0ee385f9ac04f21bf" + TFRT_COMMIT = "58f2ec4dc891dc0bc0815a8c2d1caf196bfc13d5" + TFRT_SHA256 = "7525f0bb63fc3c0cf2df7ce1b09949510b42ffa669cc38ed83f6948a078d6633" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index ada8843278851c..e501613187629a 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "4347953799d066962cb1897814de77c8e195499d" - TFRT_SHA256 = "26af1f500eab6aa22f47e05a36253faeee8786208d18e4f0ee385f9ac04f21bf" + TFRT_COMMIT = "58f2ec4dc891dc0bc0815a8c2d1caf196bfc13d5" + TFRT_SHA256 = "7525f0bb63fc3c0cf2df7ce1b09949510b42ffa669cc38ed83f6948a078d6633" tf_http_archive( name = "tf_runtime", From ff22936fb42e5e1f73951da1722494c2b0832954 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Mon, 27 Nov 2023 06:52:08 -0800 Subject: [PATCH 091/381] Add lint changes to ensure TPU and Non-TPU Bridge passes stay the same PiperOrigin-RevId: 585638550 --- tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc | 4 ++++ .../mlir/tf2xla/internal/clustering_bridge_passes.cc | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 24de1be6fe97dc..e0cb8685552e3f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -142,6 +142,10 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, } void CreateClusteringPipeline(OpPassManager &pm, llvm::StringRef module_name) { + // Since the internal bridge clustering passes are shared among TF1/TF2 + // TF2-only passes should go here. However, this should be very rare and + // new passes generally should go into the internal + // AddBridgeClusteringPipelinePasses. pm.addPass(mlir::TFTPU::CreateTPUValidateInputsPass()); pm.addNestedPass( mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index 2628d9f17b59cb..e59aaf81f7d0ac 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -34,6 +34,8 @@ namespace internal { using mlir::OpPassManager; using mlir::func::FuncOp; +// LINT.IfChange(tpu_bridge_passes) + // Adds Bridge clustering pipeline passes to the given pass_manager. Does not // run them. void AddBridgeClusteringPipelinePasses(OpPassManager& pm, @@ -169,9 +171,11 @@ void AddBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addNestedPass( tensorflow::tf2xla::internal::CreateVerifyClusteringPass()); } +// LINT.ThenChange(:non_tpu_bridge_passes) void NoCanonicalization(OpPassManager& pm) {} +// LINT.IfChange(non_tpu_bridge_passes) void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { // The following ops must be preserved regardless of reachability. Ideally, // all graphs should have control dependencies to enforce this. @@ -247,6 +251,7 @@ void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { pm.addNestedPass( tensorflow::tf2xla::internal::CreateVerifyClusteringPass()); } +// LINT.ThenChange(:tpu_bridge_passes) }; // namespace internal }; // namespace tf2xla From d1ad87b61c28834a48d84d709f21aeedbd9c6521 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Mon, 27 Nov 2023 07:47:05 -0800 Subject: [PATCH 092/381] [XLA:GPU][NFC] Transform dimension propagation to the functional paradigm This intends to make it easier to implement new fusion strategies, such as "Separate multiple uses of nodes within one scope when they are incompatible in Triton GEMM fusion". 
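For intuition, the overall shape this moves toward can be sketched as below.
These are simplified stand-in types, not the actual XLA classes: propagation
returns requirements as values, and a separate combine step either merges them
or reports a conflict.

#include <cstdint>
#include <string>
#include <variant>

// Simplified stand-ins for FusionDecision / DotRequirements.
using Error = std::string;
constexpr int64_t kNoSplitRequirement = 1;

struct DotRequirements {
  int64_t splittable_dimension_major_part_size = kNoSplitRequirement;
};
using RequirementsOrError = std::variant<DotRequirements, Error>;

// Pure combine step: no state is mutated, conflicts become explicit errors.
RequirementsOrError CombineDotRequirements(DotRequirements a,
                                           DotRequirements b) {
  const int64_t sa = a.splittable_dimension_major_part_size;
  const int64_t sb = b.splittable_dimension_major_part_size;
  if (sa == sb || sb == kNoSplitRequirement) return DotRequirements{sa};
  if (sa == kNoSplitRequirement) return DotRequirements{sb};
  return Error("Conflicting splits of splittable dimension");
}

// The caller owns the accumulated requirements and decides whether to accept
// what a propagation step returned.
struct FusionPlanSketch {
  DotRequirements requirements;

  bool CombineAndAccept(const DotRequirements& update) {
    RequirementsOrError combined =
        CombineDotRequirements(requirements, update);
    if (std::holds_alternative<Error>(combined)) return false;
    requirements = std::get<DotRequirements>(combined);
    return true;
  }
};

Nothing is mutated while propagating; a conflicting split surfaces as an
explicit error only when the requirements are combined, which is what makes it
easier to try out alternative fusion strategies.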
"Renames": AnalyzeForFusion -> GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible RequireSupportedInstruction -> GetPropagatedDimOrdersAndRequirements HandleInstruction -> GetPropagatedDimOrders RequireSupportedDimOrder -> GetRequirementsIfSupportedOrder RequireSupportedDimOrders -> GetRequirementsIfSupportedOrders DimOrderUpdates -> DimOrdersAndReqs Notable logic changes: I split out the splittable_dimension_major_part_size from DotProperties to DotRequirements, because it's not really a property of the dot, but rather a requirement which can be imposed by the instructions of the fusion. I explicitly return an error if a dimension split would be needed for Softmax in GetRequirementsIfSupportedOrder. I don't check IsSupportedSplittableDimensionMajorPartSize in GetRequirementsIfSupportedOrder anymore, I just check that in CombineDimOrdersAndReqs after the propagation is done. PiperOrigin-RevId: 585650283 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gemm_rewriter_triton.cc | 727 ++++++++++-------- 2 files changed, 428 insertions(+), 300 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index fb21f0c31f3e2c..dd6e91db77af96 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1298,6 +1298,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 0eb2cc6cb0b542..5f4a28605c6b89 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" @@ -106,14 +108,14 @@ bool IsDistributiveOverAddition(const HloInstruction& hlo) { namespace { -FusionDecision RequireTritonFusibleConvert( - const HloInstruction* input, se::GpuComputeCapability gpu_version) { +FusionDecision IsConversionWorthFusing(const HloInstruction& input, + se::GpuComputeCapability gpu_version) { // TODO(b/266862494): Can pick up almost any // convert, but if it's reducing the data volume it should rather be fused // to the output of the producer kernel. However not all operations support // output fusion - then it should be fused here anyway! 
- if (ShapeUtil::ByteSizeOf(input->operand(0)->shape()) > - ShapeUtil::ByteSizeOf(input->shape())) { + if (ShapeUtil::ByteSizeOf(input.operand(0)->shape()) > + ShapeUtil::ByteSizeOf(input.shape())) { return "Narrowing conversion."; } return FusionDecision{}; @@ -248,11 +250,89 @@ using Fragment = DimensionOrder::Fragment; using Fragments = DimensionOrder::Fragments; using FragmentOrders = DimensionOrder::FragmentOrders; using DimOrderMap = absl::flat_hash_map; +using DimOrderMapOrError = std::variant; -struct DimOrderUpdates { - DimOrderMap map; - int64_t splittable_dimension_major_part_size = 0; +// This represents an invalid dimension index. +constexpr int kNoDimensionIndex = -1; +struct DotProperties { + // Index of dot dimension that can be split. + // Currently typically LHS non-contracting one. + const int splittable_dimension_index; +}; +struct SoftmaxProperties { + const int softmax_reduction_dimension; + const int softmax_batch_dimension; +}; +// HeroProperties depend only on the hero op and they don't change as we +// change the fusion. +using HeroProperties = std::variant; + +// A special value for splittable_dimension_major_part_size. +constexpr int kNoSplitRequirement = 1; +struct DotRequirements { + explicit DotRequirements(int64_t splittable_dimension_major_part_size) + : splittable_dimension_major_part_size( + splittable_dimension_major_part_size) { + CHECK_GE(splittable_dimension_major_part_size, 1); + } + // If not kNoSplitRequirement, then the major part size of the splittable + // dimension must be the given value. + int64_t splittable_dimension_major_part_size; }; +struct SoftmaxRequirements {}; +// Requirements can change depending on what we fuse. +using Requirements = std::variant; +using RequirementsOrError = std::variant; + +// The dimension orders and requirements resulting from propagating the +// dimension orders through an HLO. +struct DimOrdersAndReqs { + DimOrderMap dim_orders; + Requirements requirements; +}; +using DimOrdersAndReqsOrError = std::variant; + +using Int64OrError = std::variant; +Int64OrError CombineSplitDimMajorPartSizeReqs(int64_t a, int64_t b) { + if (a == b || b == kNoSplitRequirement) { + return a; + } + if (a == kNoSplitRequirement) { + return b; + } + return FusionDecision("Conflicting splits of splittable dimension"); +} + +RequirementsOrError CombineDotRequirements(DotRequirements a, + DotRequirements b) { + Int64OrError combined_size_req = + CombineSplitDimMajorPartSizeReqs(a.splittable_dimension_major_part_size, + b.splittable_dimension_major_part_size); + if (std::holds_alternative(combined_size_req)) { + return std::get(combined_size_req); + } + return DotRequirements(std::get(combined_size_req)); +} + +RequirementsOrError CombineSoftmaxRequirements(SoftmaxRequirements a, + SoftmaxRequirements b) { + // SoftmaxRequirements is an empty class for now. 
+ return a; +} + +RequirementsOrError CombineRequirements(Requirements a, + RequirementsOrError b_or_error) { + if (std::holds_alternative(b_or_error)) { + return b_or_error; + } + const Requirements& b = std::get(b_or_error); + if (std::holds_alternative(b)) { + return CombineDotRequirements(std::get(a), + std::get(b)); + } + return CombineSoftmaxRequirements(std::get(a), + std::get(b)); +} TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { @@ -316,33 +396,12 @@ bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { enum class TransformDirection { kInputToOutput, kOutputToInput }; -using DimOrderUpdatesOrError = std::variant; using OldToNewHloMap = absl::flat_hash_map; class FusionContext { - struct DotProperties { - int splittable_dimension; - int64_t splittable_dimension_supported_major_part_size; - }; - struct SoftmaxProperties { - int softmax_reduction_dimension; - int softmax_batch_dimension; - }; - - explicit FusionContext(DotProperties properties) : properties_(properties) {} - - explicit FusionContext(SoftmaxProperties properties) - : properties_(properties) {} - - DimOrderUpdatesOrError HandleElementwise(const HloInstruction* hlo, - const DimOrderMap& dim_orders) const; - DimOrderUpdatesOrError HandleBitcast(const HloInstruction* hlo, - const DimOrderMap& dim_orders, - TransformDirection direction) const; - DimOrderUpdatesOrError HandleDimensionAlteringOp( - const HloInstruction* hlo, const DimOrderMap& dim_orders, - TransformDirection direction) const; + FusionContext(HeroProperties properties, Requirements requirements) + : properties_(properties), requirements_(requirements) {} public: // Create fusion context from a dot operand according to @@ -353,41 +412,52 @@ class FusionContext { // Create fusion context from dot's output. static FusionContext FromDotOutput( const HloInstruction& dot, int split_k, - int64_t splittable_dimension_supported_major_part_size); + int64_t splittable_dimension_major_part_size); static FusionContext FromSoftmaxRoot(const HloInstruction&); - DimOrderUpdatesOrError HandleInstruction(const HloInstruction* hlo, - const DimOrderMap& dim_orders, - TransformDirection direction) const; - - // Tells if the dimension order is supported by the triton emitters. - // Only the dimension indicated by SplittableDimensionIndex() can be split - // physically once by other dimensions. Other ones can be only split - // logically. All subdimensions within a dimension have to be ordered. - // Return major part of splittable dimension in split_dim_major_part if a - // supported split is detected. - FusionDecision RequireSupportedDimOrder(const DimensionOrder& order, - int64_t& split_dim_major_part) const; - // Apply RequireSupportedDimOrder() to all known dimension orders - // around `hlo`. - FusionDecision RequireSupportedDimOrders(const HloInstruction& hlo, - DimOrderUpdates& updates) const; - // Try to calculate transformations of dimensions defined by the - // instruction, then check that the resulting dimension orders are supported. - DimOrderUpdatesOrError RequireSupportedInstruction( + // If possible, propagates `src_dim_order` (describing one side of `hlo`) to + // the other side and returns those dim orders. 
+ static DimOrderMapOrError GetPropagatedDimOrders( + const HloInstruction& hlo, TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties); + + // If the dimension order is supported by the triton emitters, this returns + // which requirements does this order impose on the fusion. + // + // All subdimensions within a dimension have to be ordered. + static RequirementsOrError GetRequirementsIfSupportedOrder( + const DimensionOrder& order, const HeroProperties& properties); + // Apply GetRequirementsIfSupportedOrder() to all known + // dimension orders around `hlo` and combine the result. + static RequirementsOrError GetRequirementsIfSupportedOrders( const HloInstruction& hlo, const DimOrderMap& dim_orders, - TransformDirection direction) const; - // Checks if the instruction is possible and profitable to fuse. - // If so tries to transform dim_order describing one side of `hlo` into - // description(s) of its other side if it is supported. - DimOrderUpdatesOrError AnalyzeForFusion( + const HeroProperties& properties); + // If fusing the instruction is possible then it propagates + // the `src_dim_order` (describing one side of `hlo`) to the other side and + // returns those dim orders and the requirements that they impose on the + // fusion. + static DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements( + const HloInstruction& hlo, const DimensionOrder& src_dim_order, + TransformDirection direction, const HeroProperties& properties); + // If fusing the instruction is possible *and profitable* then it propagates + // the `src_dim_order` (describing one side of `hlo`) to the other side and + // returns those dim orders and the requirements that they impose on the + // fusion. + // + // `src_operand_index` must be set iff `transform_direction` is + // kInputToOutput. + static DimOrdersAndReqsOrError + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( const HloInstruction& hlo, TransformDirection transform_direction, - OldToNewHloMap& old_to_new_map, - se::GpuComputeCapability gpu_version) const; - // Add dimension orders from `updates` to `dim_orders_` and update the - // splittable dimension ratio if all of them are compatible. - bool MergeUpdates(const DimOrderUpdates& updates); + const std::optional& src_operand_index, + const DimensionOrder& src_dim_order, + const se::GpuComputeCapability& gpu_version, + const HeroProperties& properties); + + // Add dimension orders from `update` to `dim_orders_` and update + // `requirements_` if all of them are compatible. + bool CombineDimOrdersAndReqs(const DimOrdersAndReqs& update); // Fuse an instruction with all its fusible inputs. // If an input is not fusible stop there and make a parameter of the new // fusion, otherwise put it onto stack and check its own inputs first. @@ -403,44 +473,27 @@ class FusionContext { const HloInstruction& origin, ConstHloInstructionSet& parameters, ConstHloInstructionMap& iter_specs); - // Index of dot dimension that can be split. - // Currently typically LHS non-contracting one. - int64_t SplittableDimensionIndex() const { - CHECK(std::holds_alternative(properties_)); - return std::get(properties_).splittable_dimension; - } - // Tells whether `size` major part of a dimension can be physically split. - bool IsSupportedSplittableDimensionMajorPartSize(const int64_t size) const { - CHECK_NE(size, 0); - CHECK(std::holds_alternative(properties_)); - // 0 means no specific size requirement. 
- return std::get(properties_) - .splittable_dimension_supported_major_part_size == 0 || - std::get(properties_) - .splittable_dimension_supported_major_part_size == size; - } - int SplittableDimensionMajorPartSize() const { - CHECK(std::holds_alternative(properties_)); - return std::get(properties_) - .splittable_dimension_supported_major_part_size; - } - const DimOrderMap& DimOrders() const { return dim_orders_; } - - private: - DimOrderUpdatesOrError AnalyzeForFusionImpl( - const HloInstruction& hlo, TransformDirection transform_direction, - OldToNewHloMap& old_to_new_map, const DimOrderMap& dim_orders, - se::GpuComputeCapability gpu_version) const; - bool SetSplittableDimensionMajorPartSize(const int64_t size) { - if (IsSupportedSplittableDimensionMajorPartSize(size)) { - std::get(properties_) - .splittable_dimension_supported_major_part_size = size; - return true; - } - return false; + int64_t splittable_dimension_major_part_size() const { + CHECK(std::holds_alternative(requirements_)); + return std::get(requirements_) + .splittable_dimension_major_part_size; } + const HeroProperties& hero_properties() const { return properties_; } + const DimOrderMap& dim_orders() const { return dim_orders_; } - std::variant properties_; + private: + static DimOrderMap GetPropagatedDimOrdersForElementwise( + const HloInstruction& hlo, TransformDirection direction, + const DimensionOrder& src_dim_order); + static DimOrderMapOrError GetPropagatedDimOrdersForBitcast( + const HloInstruction& hlo, TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties); + static DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( + const HloInstruction& hlo, TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties); + + const HeroProperties properties_; + Requirements requirements_; DimOrderMap dim_orders_; }; @@ -449,12 +502,12 @@ FusionContext FusionContext::FromDotOperand(const HloInstruction& dot, const int split_k) { // There can be either none or one split-K batch dimension. const int num_split_k_batch_dims = split_k > 1; - int split_k_dimension_index = -1; + int split_k_dimension_index = kNoDimensionIndex; if (split_k > 1) { split_k_dimension_index = ContractingDimensionIndex(dot, operand_number) - 1; } - int splittable_dimension_index = -1; + int splittable_dimension_index = kNoDimensionIndex; // LHS non-contracting dimension can be split if non-splitK batch is absent. 
if (operand_number == 0 && dot.dot_dimension_numbers().lhs_batch_dimensions_size() - @@ -463,9 +516,8 @@ FusionContext FusionContext::FromDotOperand(const HloInstruction& dot, splittable_dimension_index = NonContractingDimensionIndex(dot, operand_number); } - FusionContext context(FusionContext::DotProperties{ - splittable_dimension_index, - /*splittable_dimension_supported_major_size=*/0}); + FusionContext context(DotProperties{splittable_dimension_index}, + DotRequirements(kNoSplitRequirement)); context.dim_orders_[dot.operand(operand_number)] = DimensionOrder::FromDotOperandOrOutput(*dot.operand(operand_number), split_k_dimension_index); @@ -474,34 +526,35 @@ FusionContext FusionContext::FromDotOperand(const HloInstruction& dot, FusionContext FusionContext::FromDotOutput( const HloInstruction& dot, const int split_k, - const int64_t splittable_dimension_supported_major_part_size) { + const int64_t splittable_dimension_major_part_size) { // Allow non-contracting dimension originating from LHS to split if // this dimension is split at the output at the same ratio as // at the input. - int splittable_dimension_index = -1; - if (splittable_dimension_supported_major_part_size > 1) { + int splittable_dimension_index = kNoDimensionIndex; + if (splittable_dimension_major_part_size > 1) { // Split-K dimension is the first one in the output if present; // LHS non-contracting follows (batch is absent in this case). splittable_dimension_index = (split_k > 1) ? 1 : 0; } - FusionContext context(FusionContext::DotProperties{ - splittable_dimension_index, - splittable_dimension_supported_major_part_size}); + FusionContext context(DotProperties{splittable_dimension_index}, + DotRequirements(splittable_dimension_major_part_size)); context.dim_orders_[&dot] = DimensionOrder::FromDotOperandOrOutput(dot); return context; } FusionContext FusionContext::FromSoftmaxRoot(const HloInstruction& root) { - FusionContext context(FusionContext::SoftmaxProperties{ - DimensionOrder::kSoftmaxReductionDimension, - DimensionOrder::kSoftmaxBatchDimension}); + FusionContext context( + SoftmaxProperties{DimensionOrder::kSoftmaxReductionDimension, + DimensionOrder::kSoftmaxBatchDimension}, + SoftmaxRequirements{}); context.dim_orders_[&root] = DimensionOrder::FromSoftmaxRoot(root); return context; } -FusionDecision FusionContext::RequireSupportedDimOrder( - const DimensionOrder& order, int64_t& split_dim_major_part) const { +/*static*/ RequirementsOrError FusionContext::GetRequirementsIfSupportedOrder( + const DimensionOrder& order, const HeroProperties& properties) { VLOG(8) << order.ToString(); + int64_t split_dim_major_part = kNoSplitRequirement; const Fragments& tensor_dim_fragments = order.TensorFragmentsOrder(); for (const auto& [dim_index, dim_fragments] : order.DimFragmentsOrders()) { CHECK(!dim_fragments.empty()); @@ -536,10 +589,17 @@ FusionDecision FusionContext::RequireSupportedDimOrder( ++group_counter; if (group_counter > 1) { - if (dim_index == SplittableDimensionIndex() && - IsSupportedSplittableDimensionMajorPartSize(grouped_size)) { + if (!std::holds_alternative(properties)) { + return "Splitting a dimension is not supported for Softmax."; + } + // Only the dimension indicated by `splittable_dimension_index` (if any) + // can be split physically once by other dimensions. Other ones can be + // only split logically. 
+ const int splittable_dimension_index = + std::get(properties).splittable_dimension_index; + if (dim_index == splittable_dimension_index) { if (group_counter == 2) { - if (split_dim_major_part != 0 && + if (split_dim_major_part != kNoSplitRequirement && split_dim_major_part != grouped_size) { return "Conflicting splits of splittable dimension"; } @@ -556,68 +616,110 @@ FusionDecision FusionContext::RequireSupportedDimOrder( ++fragment_it; } } - return FusionDecision{}; + + if (std::holds_alternative(properties)) { + return DotRequirements(split_dim_major_part); + } + return SoftmaxRequirements{}; } -FusionDecision FusionContext::RequireSupportedDimOrders( - const HloInstruction& hlo, DimOrderUpdates& updates) const { - auto check_if_present = [&](const HloInstruction* instr) { - if (auto it = updates.map.find(instr); it != updates.map.end()) { - return RequireSupportedDimOrder( - it->second, updates.splittable_dimension_major_part_size); +/*static*/ RequirementsOrError FusionContext::GetRequirementsIfSupportedOrders( + const HloInstruction& hlo, const DimOrderMap& dim_orders, + const HeroProperties& properties) { + const Requirements empty_requirements = + std::holds_alternative(properties) + ? Requirements(DotRequirements(kNoSplitRequirement)) + : Requirements(SoftmaxRequirements{}); + auto get_requirements = + [&](const HloInstruction& instr) -> RequirementsOrError { + if (auto it = dim_orders.find(&instr); it != dim_orders.end()) { + return GetRequirementsIfSupportedOrder(it->second, properties); } - return FusionDecision{}; + return empty_requirements; }; + + Requirements requirements = empty_requirements; for (const HloInstruction* operand : hlo.operands()) { - if (auto result = check_if_present(operand); !result) { - return result; + RequirementsOrError requirements_or_error = + CombineRequirements(requirements, get_requirements(*operand)); + if (std::holds_alternative(requirements_or_error)) { + return requirements_or_error; } + requirements = std::get(requirements_or_error); } - return check_if_present(&hlo); + + return CombineRequirements(requirements, get_requirements(hlo)); } -DimOrderUpdatesOrError FusionContext::HandleElementwise( - const HloInstruction* hlo, const DimOrderMap& dim_orders) const { - // The output and all the input dimension orders of `hlo` have to be the same. - const HloInstruction* src = nullptr; - const DimensionOrder* src_dim_order; - // Try using the output as a reference if it's already described, otherwise - // scan through all operands. - if (auto it = dim_orders.find(hlo); it != dim_orders.cend()) { - src = it->first; - src_dim_order = &it->second; - } else { - for (const HloInstruction* operand : hlo->operands()) { - if (auto it = dim_orders.find(operand); it != dim_orders.cend()) { - src = it->first; - src_dim_order = &it->second; - break; - } +/*static*/ DimOrderMap FusionContext::GetPropagatedDimOrdersForElementwise( + const HloInstruction& hlo, TransformDirection direction, + const DimensionOrder& src_dim_order) { + if (direction == TransformDirection::kOutputToInput) { + DimOrderMap map; + for (const HloInstruction* operand : hlo.operands()) { + map.insert({operand, src_dim_order}); } - CHECK_NE(src, nullptr); + return map; + } + + DimOrderMap map; + map.insert({&hlo, src_dim_order}); + // TODO(tdanyluk): For now, the "input to output" direction of this function + // also returns the dim orders for the operands, not just the output. 
This is + // needed to propagate the dim order of one input to the other(s) when fusing + // elementwise ops to the output. Perhaps we can separate the "input to + // output" and "output to input" directions of that in a later CL. + for (const HloInstruction* operand : hlo.operands()) { + map.insert({operand, src_dim_order}); + } + return map; +} + +const HloInstruction& GetSourceHlo(const HloInstruction& hlo, + TransformDirection direction) { + CHECK_GE(hlo.operand_count(), 1); + + if (direction == TransformDirection::kOutputToInput) { + return hlo; + } + return *hlo.operand(0); +} + +using ConstInstructionVector = absl::InlinedVector; +ConstInstructionVector GetDestHlos(const HloInstruction& hlo, + TransformDirection direction) { + if (direction == TransformDirection::kInputToOutput) { + return {&hlo}; } - DimOrderUpdates result; - result.map.insert({hlo, DimensionOrder(*src_dim_order)}); - for (const HloInstruction* operand : hlo->operands()) { - result.map.insert({operand, DimensionOrder(dim_orders.at(src))}); + ConstInstructionVector hlos; + hlos.reserve(hlo.operands().size()); + for (const HloInstruction* operand : hlo.operands()) { + hlos.push_back(operand); + } + return hlos; +} + +const HloInstruction& GetDestHlo(const HloInstruction& hlo, + TransformDirection direction) { + CHECK_EQ(hlo.operand_count(), 1); + + if (direction == TransformDirection::kInputToOutput) { + return hlo; } - return result; + + return *hlo.operand(0); } -DimOrderUpdatesOrError FusionContext::HandleBitcast( - const HloInstruction* hlo, const DimOrderMap& dim_orders, - const TransformDirection direction) const { - const HloInstruction* src = - (direction == TransformDirection::kOutputToInput) ? hlo : hlo->operand(0); - const HloInstruction* dst = - (direction == TransformDirection::kOutputToInput) ? hlo->operand(0) : hlo; - const Shape& dst_shape = dst->shape(); - const Fragments& src_fragments_order = - dim_orders.at(src).TensorFragmentsOrder(); - DimOrderUpdates result; +/*static*/ DimOrderMapOrError FusionContext::GetPropagatedDimOrdersForBitcast( + const HloInstruction& hlo, const TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties) { + const HloInstruction& dst = GetDestHlo(hlo, direction); + const Shape& dst_shape = dst.shape(); + const Fragments& src_fragments_order = src_dim_order.TensorFragmentsOrder(); + DimOrderMap dst_dim_orders; DimensionOrder& dst_dim_order = - result.map.insert({dst, DimensionOrder()}).first->second; + dst_dim_orders.insert({&dst, DimensionOrder()}).first->second; Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder(); // Size of not yet assigned part of current target dimension. int64_t dst_remaining_size = 1; @@ -634,9 +736,9 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast( dst_fragments_order.push_back(fragment); src_to_dst[&*src_dim].push_back(dst_fragments_order.size() - 1); }; - if (std::holds_alternative(properties_) && + if (std::holds_alternative(properties) && src_dim->dst_dim_number() == - std::get(properties_).softmax_batch_dimension) { + std::get(properties).softmax_batch_dimension) { // Special handling for softmax batch dimension: allow arbitrary reshapes // on it because it's guaranteed by the construction of the fusion to have // no physical alterations like transposes. 
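// [Editor's note — illustrative sketch, not part of the patch.] The bitcast
// propagation above walks the source dimension fragments in minor-to-major
// order and accumulates their sizes until each destination dimension is
// covered, splitting a fragment when the sizes do not line up. The
// self-contained C++ sketch below (hypothetical names, simplified: no fragment
// splitting and no softmax special case) shows just that core matching step.
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Groups `src_fragment_sizes` so that the product of each group equals the
// corresponding entry of `dst_dim_sizes`. Returns std::nullopt when an exact
// match would require splitting a fragment.
std::optional<std::vector<std::vector<int64_t>>> GroupFragments(
    const std::vector<int64_t>& src_fragment_sizes,
    const std::vector<int64_t>& dst_dim_sizes) {
  std::vector<std::vector<int64_t>> groups;
  std::size_t i = 0;
  for (const int64_t dst_size : dst_dim_sizes) {
    int64_t accumulated = 1;
    std::vector<int64_t> group;
    while (accumulated < dst_size && i < src_fragment_sizes.size()) {
      accumulated *= src_fragment_sizes[i];
      group.push_back(src_fragment_sizes[i++]);
    }
    if (accumulated != dst_size) return std::nullopt;  // Would need a split.
    groups.push_back(std::move(group));
  }
  if (i != src_fragment_sizes.size()) return std::nullopt;  // Leftover fragments.
  return groups;
}
// [End of editor's note.]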
@@ -731,7 +833,7 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast( FragmentOrders& dst_dim_fragment_orders = dst_dim_order.DimFragmentsOrders(); for (const auto& [dim_index, dim_sequence] : - dim_orders.at(src).DimFragmentsOrders()) { + src_dim_order.DimFragmentsOrders()) { std::vector& dst = dst_dim_fragment_orders[dim_index]; dst.reserve(dim_sequence.size()); for (const int src : dim_sequence) { @@ -741,15 +843,16 @@ DimOrderUpdatesOrError FusionContext::HandleBitcast( } } - return result; + return dst_dim_orders; } // Handle copy, transpose, broadcast or reduce. // Common between them is that they alter the tensor dimensions or their order // and the way to handle layouts. -DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( - const HloInstruction* hlo, const DimOrderMap& dim_orders, - const TransformDirection direction) const { +/*static*/ DimOrderMapOrError +FusionContext::GetPropagatedDimOrdersForDimAlteringOp( + const HloInstruction& hlo, const TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties) { // Temporary storage for new fragments local to this function. // Please keep this as the first local variable of this function, with type // std::list to make sure that all pointers to elements of this remain valid @@ -757,13 +860,12 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( // unnecessarily big for a typical size of 1. std::list new_fragments; - const HloInstruction* src = - (direction == TransformDirection::kOutputToInput) ? hlo : hlo->operand(0); + const HloInstruction& src = GetSourceHlo(hlo, direction); // Note: copying instead of using a const reference because // some operations (slice) will modify fragment properties in-place. - Fragments src_fragments_order = dim_orders.at(src).TensorFragmentsOrder(); - if (hlo->opcode() == HloOpcode::kSlice && - ShapeUtil::IsEffectiveScalar(hlo->shape())) { + Fragments src_fragments_order = src_dim_order.TensorFragmentsOrder(); + if (hlo.opcode() == HloOpcode::kSlice && + ShapeUtil::IsEffectiveScalar(hlo.shape())) { return FusionDecision("Slice to scalar is not implemented yet."); } // Every HLO dimension can correspond to a group of subdimensions in @@ -772,10 +874,10 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( // Group subdimensions by iterating over them in the same order as over // full dimensions and matching by total size. std::vector> src_physical; - src_physical.reserve(src->shape().rank()); + src_physical.reserve(src.shape().rank()); auto src_fragment_it = src_fragments_order.begin(); - for (int64_t dim_index : src->shape().layout().minor_to_major()) { - const int64_t dim_size = src->shape().dimensions(dim_index); + for (int64_t dim_index : src.shape().layout().minor_to_major()) { + const int64_t dim_size = src.shape().dimensions(dim_index); int64_t subdim_size_accumulator = 1; std::vector subdim_group; do { @@ -792,21 +894,17 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( std::vector> src_logical; src_logical.resize(src_physical.size()); for (int i = 0; i < src_physical.size(); ++i) { - src_logical[src->shape().layout().minor_to_major(i)] = src_physical[i]; + src_logical[src.shape().layout().minor_to_major(i)] = src_physical[i]; } - HloInstruction::InstructionVector output; - output.push_back(const_cast(hlo)); - DimOrderUpdates result; - for (const HloInstruction* dst : - (direction == TransformDirection::kInputToOutput) ? 
output - : hlo->operands()) { + DimOrderMap dst_dim_orders; + for (const HloInstruction* dst : GetDestHlos(hlo, direction)) { DimensionOrder& dst_dim_order = - result.map.insert({dst, DimensionOrder()}).first->second; + dst_dim_orders.insert({dst, DimensionOrder()}).first->second; // Source logical -> destination logical. std::vector> dst_logical; - if (hlo->opcode() == HloOpcode::kTranspose) { - const auto* transpose = Cast(hlo); + if (hlo.opcode() == HloOpcode::kTranspose) { + const auto* transpose = Cast(&hlo); std::vector permutation(transpose->dimensions().cbegin(), transpose->dimensions().cend()); if (direction == TransformDirection::kInputToOutput) { @@ -816,18 +914,18 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( for (int i = 0; i < permutation.size(); ++i) { dst_logical[permutation[i]] = src_logical[i]; } - } else if (hlo->opcode() == HloOpcode::kBroadcast) { - const auto* broadcast = Cast(hlo); + } else if (hlo.opcode() == HloOpcode::kBroadcast) { + const auto* broadcast = Cast(&hlo); dst_logical.resize(broadcast->dimensions().size()); for (int i = 0; i < broadcast->dimensions().size(); ++i) { dst_logical[i] = src_logical[broadcast->dimensions()[i]]; } - } else if (hlo->opcode() == HloOpcode::kReduce) { + } else if (hlo.opcode() == HloOpcode::kReduce) { // Operand 1 (the neutral value) has to be a scalar. - if (dst != hlo && hlo->operand_index(dst) == 1) { + if (dst != &hlo && hlo.operand_index(dst) == 1) { continue; } - const auto* reduce = Cast(hlo); + const auto* reduce = Cast(&hlo); dst_logical.resize(src_logical.size() + reduce->dimensions().size()); if (reduce->dimensions().size() != 1) { return FusionDecision("Unsupported reduction."); @@ -838,23 +936,23 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( // softmax fusions with known patterns for now. Generally a reduction // should create a new tiled dimension. dst_logical[i] = {&new_fragments.emplace_back( - std::get(properties_) + std::get(properties) .softmax_reduction_dimension, reduce->operand(0)->shape().dimensions(i))}; } else { dst_logical[i] = src_logical[i]; } } - } else if (hlo->opcode() == HloOpcode::kCopy) { + } else if (hlo.opcode() == HloOpcode::kCopy) { // Copy preserves the logical shape, just permutes the layout. - CHECK(ShapeUtil::SameDimensions(src->shape(), dst->shape())); + CHECK(ShapeUtil::SameDimensions(src.shape(), dst->shape())); dst_logical = src_logical; - } else if (hlo->opcode() == HloOpcode::kPad) { + } else if (hlo.opcode() == HloOpcode::kPad) { // Operand 1 (the padding value) has to be a scalar. 
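// [Editor's note] As with the neutral value of a reduction above, a scalar
// operand carries no dimension order of its own, so the loop below simply
// skips it rather than propagating an order to it.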
- if (dst != hlo && hlo->operand_index(dst) == 1) { + if (dst != &hlo && hlo.operand_index(dst) == 1) { continue; } - const auto* pad = Cast(hlo); + const auto* pad = Cast(&hlo); dst_logical.resize(src_logical.size()); for (int i = 0; i < src_logical.size(); ++i) { // This only handles the padding added by @@ -883,8 +981,8 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( dst_logical[i] = {&new_fragments.back()}; } } - } else if (hlo->opcode() == HloOpcode::kSlice) { - const auto slice = Cast(hlo); + } else if (hlo.opcode() == HloOpcode::kSlice) { + const auto slice = Cast(&hlo); dst_logical.resize(src_logical.size()); for (int dim = 0; dim < src_logical.size(); ++dim) { dst_logical[dim] = src_logical[dim]; @@ -923,11 +1021,11 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( } } for (const auto& [dim_index, dim_sequence] : - dim_orders.at(src).DimFragmentsOrders()) { + src_dim_order.DimFragmentsOrders()) { for (const int fragment_number : dim_sequence) { const auto it = src_to_dst.find(&src_fragments_order[fragment_number]); if (it == src_to_dst.cend()) { - if (hlo->opcode() == HloOpcode::kBroadcast && + if (hlo.opcode() == HloOpcode::kBroadcast && src_fragments_order[fragment_number].full_size() > 1 && dim_numbers_present_in_dst.contains(dim_index)) { return FusionDecision("Unsupported broadcast"); @@ -938,55 +1036,63 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp( } } } - return result; + return dst_dim_orders; } // Infers DimensionOrders of all unknown sides (output, operands) // of `hlo` from the known ones. -DimOrderUpdatesOrError FusionContext::HandleInstruction( - const HloInstruction* hlo, const DimOrderMap& dim_orders, - const TransformDirection direction) const { - VLOG(7) << "Analyzing " << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter || - hlo_query::IsScalarConstant(hlo)) { - return DimOrderUpdates{}; - } else if (hlo->opcode() == HloOpcode::kTranspose || - hlo->opcode() == HloOpcode::kCopy) { - return HandleDimensionAlteringOp(hlo, dim_orders, direction); - } else if (hlo->opcode() == HloOpcode::kBroadcast) { +/*static*/ DimOrderMapOrError FusionContext::GetPropagatedDimOrders( + const HloInstruction& hlo, const TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties) { + VLOG(7) << "Analyzing " << hlo.ToString(); + if (hlo.opcode() == HloOpcode::kParameter || + hlo_query::IsScalarConstant(&hlo)) { + CHECK(direction == TransformDirection::kOutputToInput); + return DimOrderMap{}; + } else if (hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kCopy) { + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kBroadcast) { if (direction != TransformDirection::kOutputToInput) { return "Unsupported broadcast direction."; } - return HandleDimensionAlteringOp(hlo, dim_orders, direction); - } else if (hlo->opcode() == HloOpcode::kReduce) { - if (!std::holds_alternative(properties_)) { + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kReduce) { + if (!std::holds_alternative(properties)) { return "Reductions are not supported in GEMM fusions yet."; } if (direction != TransformDirection::kOutputToInput) { return "Unsupported direction of reduction."; } - return HandleDimensionAlteringOp(hlo, dim_orders, direction); - } else if (hlo->opcode() == HloOpcode::kPad) { + return 
GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kPad) { if (direction != TransformDirection::kOutputToInput) { return "Unsupported pad direction."; } - return HandleDimensionAlteringOp(hlo, dim_orders, direction); - } else if (hlo->operand_count() > 0 && + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.operand_count() > 0 && IsTritonSupportedElementwise( - hlo->opcode(), hlo->operand(0)->shape().element_type())) { - return HandleElementwise(hlo, dim_orders); - } else if (hlo->opcode() == HloOpcode::kBitcast) { - return HandleBitcast(hlo, dim_orders, direction); - } else if (hlo->opcode() == HloOpcode::kSlice) { + hlo.opcode(), hlo.operand(0)->shape().element_type())) { + return GetPropagatedDimOrdersForElementwise(hlo, direction, src_dim_order); + } else if (hlo.opcode() == HloOpcode::kBitcast) { + return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kSlice) { if (direction != TransformDirection::kOutputToInput) { return "Unsupported slice direction."; } - return HandleDimensionAlteringOp(hlo, dim_orders, direction); - } else if (hlo->opcode() == HloOpcode::kReshape) { - if (!ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape())) { + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kReshape) { + if (!ShapeUtil::ReshapeIsBitcast(hlo.operand(0)->shape(), hlo.shape())) { return "Non-bitcast reshape."; } - return HandleBitcast(hlo, dim_orders, direction); + return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, + properties); } return "Unimplemented instruction."; } @@ -1033,34 +1139,36 @@ bool IsOutputWorthFusing(const HloInstruction& hlo) { InputMinusOutputBytes(hlo) >= -kIoToleranceBytes; } -DimOrderUpdatesOrError FusionContext::RequireSupportedInstruction( - const HloInstruction& hlo, const DimOrderMap& dim_orders, - const TransformDirection transform_direction) const { - auto result = HandleInstruction(&hlo, dim_orders, transform_direction); - if (!std::holds_alternative(result)) { - return std::get(result); - } - - if (FusionDecision supported = - RequireSupportedDimOrders(hlo, std::get(result)); - !supported) { - return supported; - } - return std::get(result); +/*static*/ DimOrdersAndReqsOrError +FusionContext::GetPropagatedDimOrdersAndRequirements( + const HloInstruction& hlo, const DimensionOrder& src_dim_order, + TransformDirection direction, const HeroProperties& properties) { + DimOrderMapOrError propagated_dim_orders_or_error = + GetPropagatedDimOrders(hlo, direction, src_dim_order, properties); + if (std::holds_alternative(propagated_dim_orders_or_error)) { + return std::get(propagated_dim_orders_or_error); + } + DimOrderMap propagated_dim_orders = + std::move(std::get(propagated_dim_orders_or_error)); + RequirementsOrError requirements_or_error = + GetRequirementsIfSupportedOrders(hlo, propagated_dim_orders, properties); + if (std::holds_alternative(requirements_or_error)) { + return std::get(requirements_or_error); + } + return DimOrdersAndReqs{propagated_dim_orders, + std::get(requirements_or_error)}; } -DimOrderUpdatesOrError FusionContext::AnalyzeForFusion( - const HloInstruction& hlo, const TransformDirection transform_direction, - OldToNewHloMap& old_to_new_map, - const se::GpuComputeCapability gpu_version) const { - return AnalyzeForFusionImpl(hlo, 
transform_direction, old_to_new_map, - dim_orders_, gpu_version); -} +/*static*/ DimOrdersAndReqsOrError +FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + const HloInstruction& hlo, TransformDirection transform_direction, + const std::optional& src_operand_index, + const DimensionOrder& src_dim_order, + const se::GpuComputeCapability& gpu_version, + const HeroProperties& properties) { + CHECK_EQ(transform_direction == TransformDirection::kInputToOutput, + src_operand_index.has_value()); -DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl( - const HloInstruction& hlo, const TransformDirection transform_direction, - OldToNewHloMap& old_to_new_map, const DimOrderMap& dim_orders, - const se::GpuComputeCapability gpu_version) const { if (hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { return "Unsupported instruction."; @@ -1080,11 +1188,14 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl( if (!IsTritonSupportedDataType(hlo.shape().element_type(), gpu_version)) { return "Unsupported output data type."; } - DimOrderUpdatesOrError result = - RequireSupportedInstruction(hlo, dim_orders, transform_direction); - if (!std::holds_alternative(result)) { - return result; + DimOrdersAndReqsOrError result_or_error = + GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order, + transform_direction, properties); + if (!std::holds_alternative(result_or_error)) { + return result_or_error; } + DimOrdersAndReqs dim_orders_and_requirements = + std::move(std::get(result_or_error)); int fusion_level = hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level(); if (!std::get(gpu_version) @@ -1094,8 +1205,7 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl( if (transform_direction == TransformDirection::kOutputToInput) { if (fusion_level < 2) { if (hlo.opcode() == HloOpcode::kConvert) { - if (FusionDecision decision = - RequireTritonFusibleConvert(&hlo, gpu_version); + if (FusionDecision decision = IsConversionWorthFusing(hlo, gpu_version); !decision) { return decision; } @@ -1113,9 +1223,13 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl( if (operand->opcode() == HloOpcode::kBroadcast && (operand->operand(0)->opcode() == HloOpcode::kParameter || operand->operand(0)->opcode() == HloOpcode::kConstant) && - std::holds_alternative(AnalyzeForFusionImpl( - *operand, transform_direction, old_to_new_map, - std::get(result).map, gpu_version))) { + std::holds_alternative( + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + *operand, TransformDirection::kOutputToInput, + /*src_operand_index=*/std::nullopt, + /*src_dim_order=*/ + dim_orders_and_requirements.dim_orders.at(operand), + gpu_version, properties))) { accepted = true; break; } @@ -1129,9 +1243,10 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl( if (fusion_level < 2) { return "Skipping fusing outputs at low fusion levels."; } - for (const HloInstruction* operand : hlo.operands()) { - // Skip already fused operands. - if (old_to_new_map.contains(operand)) { + for (int i = 0; i < hlo.operand_count(); ++i) { + const HloInstruction* operand = hlo.operand(i); + // Skip source operand. 
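// [Editor's note] The source operand is the one the propagation arrived from;
// its dimension order is already captured in `src_dim_order`, so only the
// remaining operands have to pass the cheap-to-fuse checks below.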
+ if (i == *src_operand_index) { continue; } // Currently only broadcasts of scalar constants or parameters @@ -1147,7 +1262,7 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl( return "Not obviously profitable to fuse as output."; } } - return std::get(result); + return dim_orders_and_requirements; } // Gets the fused HLO corresponding to `hlo` or adds a new parameter if not @@ -1223,21 +1338,24 @@ int64_t NumAddedParameters(const HloInstruction& hlo) { return hlo.operand_count() - 1; } -bool FusionContext::MergeUpdates(const DimOrderUpdates& updates) { +bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { // First check that all updates to insert are compatible to avoid // incomplete merges. - for (const auto& [key, value] : updates.map) { + for (const auto& [key, value] : update.dim_orders) { auto it = dim_orders_.find(key); if (it != dim_orders_.cend() && !it->second.IsPhysicallyEquivalent(value)) { return false; } } - if (updates.splittable_dimension_major_part_size > 1 && - !SetSplittableDimensionMajorPartSize( - updates.splittable_dimension_major_part_size)) { + + RequirementsOrError requirements_or_error = + CombineRequirements(requirements_, update.requirements); + if (std::holds_alternative(requirements_or_error)) { return false; } - dim_orders_.insert(updates.map.begin(), updates.map.end()); + + requirements_ = std::move(std::get(requirements_or_error)); + dim_orders_.insert(update.dim_orders.begin(), update.dim_orders.end()); return true; } @@ -1270,10 +1388,13 @@ void FusionContext::TryToFuseWithInputsRecursively( continue; } num_requeued = 0; - const DimOrderUpdatesOrError result = AnalyzeForFusion( - *hlo, TransformDirection::kOutputToInput, old_to_new_map, gpu_version); - if (!std::holds_alternative(result) || - !MergeUpdates(std::get(result))) { + const DimOrdersAndReqsOrError result = + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + *hlo, TransformDirection::kOutputToInput, + /*src_operand_index=*/std::nullopt, dim_orders_.at(hlo), + gpu_version, properties_); + if (!std::holds_alternative(result) || + !CombineDimOrdersAndReqs(std::get(result))) { continue; } if (hlo->opcode() != HloOpcode::kParameter) { @@ -1369,7 +1490,7 @@ StatusOr FuseDot(HloInstruction& dot, // These describe _outputs_ of corresponding HLOs. 
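// [Editor's note — illustrative sketch, not part of the patch.] The
// CombineDimOrdersAndReqs()/CombineRequirements() pair introduced above only
// accepts an update whose dimension orders are physically equivalent to the
// ones already recorded and whose split requirement does not contradict the
// accumulated one. A minimal stand-alone version of the split-requirement
// merge, assuming (hypothetically) that the value 1 encodes "no split
// required":
#include <cstdint>
#include <optional>

// Returns the merged splittable-dimension major-part size, or std::nullopt on
// a conflict between two different physical splits.
std::optional<int64_t> CombineSplitMajorPartSize(int64_t a, int64_t b) {
  constexpr int64_t kNoSplit = 1;
  if (a == kNoSplit) return b;
  if (b == kNoSplit || a == b) return a;
  return std::nullopt;
}
// [End of editor's note.]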
auto context = FusionContext::FromDotOutput( - dot, /*split_k=*/1, lhs_context.SplittableDimensionMajorPartSize()); + dot, /*split_k=*/1, lhs_context.splittable_dimension_major_part_size()); HloInstruction* fusion_output = &dot; bool output_changed = true; while (output_changed) { @@ -1381,13 +1502,17 @@ StatusOr FuseDot(HloInstruction& dot, if (!IsDistributiveOverAddition(*user)) { break; } - auto result = - context.AnalyzeForFusion(*user, TransformDirection::kInputToOutput, - output_old_to_new_map, gpu_version); - if (!std::holds_alternative(result)) { + DimOrdersAndReqsOrError result = + FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + *user, TransformDirection::kInputToOutput, + user->operand_index(fusion_output), + context.dim_orders().at(fusion_output), gpu_version, + context.hero_properties()); + if (!std::holds_alternative(result)) { continue; } - TF_RET_CHECK(context.MergeUpdates(std::get(result))); + TF_RET_CHECK( + context.CombineDimOrdersAndReqs(std::get(result))); for (HloInstruction* operand : user->operands()) { if (!output_old_to_new_map.contains(operand)) { context.TryToFuseWithInputsRecursively(*operand, gpu_version, @@ -1509,12 +1634,11 @@ Status FusionContext::PropagateDimensionOrdersToParameters( TF_RET_CHECK(parameters.insert(hlo).second); VLOG(5) << hlo->ToString(); } - auto result = - HandleInstruction(hlo, dim_orders_, TransformDirection::kOutputToInput); - TF_RET_CHECK(std::holds_alternative(result)); - TF_RET_CHECK( - RequireSupportedDimOrders(*hlo, std::get(result))); - TF_RET_CHECK(MergeUpdates(std::get(result))); + DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( + *hlo, dim_orders_.at(hlo), TransformDirection::kOutputToInput, + properties_); + TF_RET_CHECK(std::holds_alternative(result)); + TF_RET_CHECK(CombineDimOrdersAndReqs(std::get(result))); iter_specs[hlo] = DimensionOrderToTensorIterationSpec(dim_orders_.at(hlo)); for (const HloInstruction* operand : hlo->operands()) { if (!visited.insert(operand).second) { @@ -1531,7 +1655,7 @@ Status FusionContext::PropagateDimensionOrdersToParameters( return OkStatus(); } -} // anonymous namespace +} // namespace // Data types that are supported by the Triton emitters. bool IsTritonSupportedDataType(PrimitiveType type, @@ -1643,14 +1767,15 @@ Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, const int split_k) { - int64_t lhs_nc_split_major_part_size = -1; + int64_t lhs_nc_split_major_part_size = kNoSplitRequirement; for (const Scope scope : {Scope::LHS, Scope::RHS}) { const int operand_number = static_cast(scope); auto context = FusionContext::FromDotOperand(dot, operand_number, split_k); TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( *dot.operand(operand_number), parameters_[scope], iter_specs_[scope])); - if (scope == Scope::LHS && context.SplittableDimensionMajorPartSize() > 1) { - lhs_nc_split_major_part_size = context.SplittableDimensionMajorPartSize(); + if (scope == Scope::LHS) { + lhs_nc_split_major_part_size = + context.splittable_dimension_major_part_size(); } } @@ -1661,17 +1786,19 @@ Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, // Propagate dimension order from dot to root.
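// [Editor's note] The loop that follows applies the same two-step pattern used
// throughout this change: ask the static helper for the propagated dimension
// orders plus the requirements they impose, then merge them into the context
// with CombineDimOrdersAndReqs(). Because this analysis runs on an
// already-built fusion, both steps are expected to succeed, hence the
// TF_RET_CHECKs instead of graceful rejection.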
while (!output->IsRoot()) { TF_RET_CHECK(output->user_count() == 1); + const HloInstruction* input = output; output = output->users()[0]; - auto result = context.HandleInstruction(output, context.DimOrders(), - TransformDirection::kInputToOutput); - TF_RET_CHECK(std::holds_alternative(result)); - TF_RET_CHECK(context.RequireSupportedDimOrders( - *output, std::get(result))); - TF_RET_CHECK(context.MergeUpdates(std::get(result))); + DimOrdersAndReqsOrError result = + context.GetPropagatedDimOrdersAndRequirements( + *output, context.dim_orders().at(input), + TransformDirection::kInputToOutput, context.hero_properties()); + TF_RET_CHECK(std::holds_alternative(result)); + TF_RET_CHECK( + context.CombineDimOrdersAndReqs(std::get(result))); } TF_RET_CHECK(iter_specs_[Scope::OUTPUT] .insert({output, DimensionOrderToTensorIterationSpec( - context.DimOrders().at(output))}) + context.dim_orders().at(output))}) .second); if (output != &dot) { // Propagate back to parameters of the output fusion. From b6aff0cf370ab97c6538d3982709b2a0ef46aecb Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Mon, 27 Nov 2023 07:52:41 -0800 Subject: [PATCH 093/381] Add define to force using latest ops in XNNPack delegate PiperOrigin-RevId: 585651614 --- tensorflow/lite/delegates/xnnpack/BUILD | 9 +++++++++ tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index affa0600b1ed6d..ad5755f6c990a3 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -27,6 +27,12 @@ config_setting( define_values = {"xnnpack_force_float_precision": "fp16"}, ) +# Force XNNPACK to use all operators in the delegate. +config_setting( + name = "xnnpack_use_latest_ops_explicit", + define_values = {"xnnpack_use_latest_ops": "true"}, +) + # Enable offloading of quantized 8-bit signed operators to XNNPACK delegate config_setting( name = "tflite_with_xnnpack_qs8_explicit_true", @@ -214,6 +220,9 @@ cc_library( copts = tflite_copts() + select({ ":xnnpack_force_float_precision_explicit_fp16": ["-DXNNPACK_DELEGATE_FORCE_PRECISION_FP16=1"], "//conditions:default": [], + }) + select({ + ":xnnpack_use_latest_ops_explicit": ["-DXNNPACK_DELEGATE_USE_LATEST_OPS=1"], + "//conditions:default": [], }), linkstatic = True, deps = [ diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index ca7bdfc88804ed..3e31df1398148b 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -532,8 +532,12 @@ class Delegate { } bool enable_latest_operators() const { +#ifdef XNNPACK_DELEGATE_USE_LATEST_OPS + return true; +#else return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS) != 0; +#endif } bool support_variable_ops() const { From 9692c9a5224c558fd88276234528c53e2c959c23 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Mon, 27 Nov 2023 08:30:40 -0800 Subject: [PATCH 094/381] Remove additional empty line --- tensorflow/python/ops/image_ops_impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index acc16d5bf2b69c..265293d1536775 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -2027,7 +2027,6 @@ def random_brightness(image, max_delta, seed=None): with `tf.image.random_*` ops, 
`tf.image.stateless_random_*` ops guarantee the same results given the same seed independent of how many times the function is called, and independent of global seed settings (e.g. tf.random.set_seed). - Args: image: An image or images to adjust. From 417a92f664750c16ef49931b9e3e06527b0f65bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Mon, 27 Nov 2023 09:03:01 -0800 Subject: [PATCH 095/381] [XLA:GPU] Fix for "Transform dimension propagation to the functional paradigm " I think that now it's perhaps possible that FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible succeeds, but context.CombineDimOrdersAndReqs fails because of a "splittable_dimension_major_part_size" requirement. Also continue is the same as break in this specific context, so I changed it to break. PiperOrigin-RevId: 585669068 --- third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 5f4a28605c6b89..573ea2cdc4d999 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -1508,11 +1508,10 @@ StatusOr FuseDot(HloInstruction& dot, user->operand_index(fusion_output), context.dim_orders().at(fusion_output), gpu_version, context.hero_properties()); - if (!std::holds_alternative(result)) { - continue; + if (!std::holds_alternative(result) || + !context.CombineDimOrdersAndReqs(std::get(result))) { + break; } - TF_RET_CHECK( - context.CombineDimOrdersAndReqs(std::get(result))); for (HloInstruction* operand : user->operands()) { if (!output_old_to_new_map.contains(operand)) { context.TryToFuseWithInputsRecursively(*operand, gpu_version, From 2493906ddc8d3bc44cbd7582ea538776e92d115e Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Mon, 27 Nov 2023 09:04:49 -0800 Subject: [PATCH 096/381] [XLA] Collective pipeliner goes through small reduces as they do not increase memory pressure much. PiperOrigin-RevId: 585669608 --- third_party/xla/xla/service/collective_pipeliner.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 3d5ddf374d32b6..6d5e2f5f2a63ec 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -303,8 +303,9 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, return false; } if (i->opcode() == HloOpcode::kReduce && - ShapeUtil::ElementsIn(i->shape()) == - ShapeUtil::ElementsIn(instr->operand(0)->shape())) { + (ShapeUtil::ElementsIn(i->shape()) == + ShapeUtil::ElementsIn(instr->operand(0)->shape()) || + ShapeUtil::ElementsIn(instr->operand(0)->shape()) < 1024)) { return true; } return HloPredicateIsOp Date: Mon, 27 Nov 2023 09:30:08 -0800 Subject: [PATCH 097/381] #tf-data Cap the pipeline processing time to avoid integer overflows. PiperOrigin-RevId: 585675706 --- tensorflow/core/framework/model.cc | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 7524050825fca1..449520172bbd15 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -2376,11 +2377,19 @@ void Model::Optimize(AutotuneAlgorithm algorithm, pipeline_processing_usec = 0; break; } - int64_t root_total_time_usec = root_timing->total_time_nsec * - root_timing->pipeline_ratio / - EnvTime::kMicrosToNanos; - pipeline_processing_usec = - std::max(pipeline_processing_usec, root_total_time_usec); + + double root_total_time_usec = root_timing->total_time_nsec * + root_timing->pipeline_ratio / + EnvTime::kMicrosToNanos; + + // Cap the computed value to ensure that there is no integer overflow. + if (root_total_time_usec < std::numeric_limits::max()) { + pipeline_processing_usec = + std::max(pipeline_processing_usec, + static_cast(root_total_time_usec)); + } else { + pipeline_processing_usec = std::numeric_limits::max(); + } } // Only updates the pipeline processing time when it is greater than 0. // If it is zero, we assume the pipeline processing time is the same From 907196b966ba53551b292f8c3c3d0bb0539825ba Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Mon, 27 Nov 2023 09:53:58 -0800 Subject: [PATCH 098/381] Don't produce layout info for TPU Embedding Ops in the old bridge to match behavior of the MLIR Bridge. XLA produces the information. PiperOrigin-RevId: 585681525 --- tensorflow/core/tpu/kernels/tpu_embedding_ops.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc index 52b83ae6be78b2..97fe019201e4cb 100644 --- a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc @@ -25,7 +25,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/xla_builder.h" +#include "xla/layout_util.h" #include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/c_api_decl.h" @@ -252,7 +254,11 @@ class SendTPUEmbeddingGradientsOp : public XlaOpKernel { auto builder = ctx->builder(); gradient_shapes.reserve(gradients.size()); for (xla::XlaOp op : gradients) { - gradient_shapes.push_back(builder->GetShape(op).value()); + // Gradient layout information is added by XLA, so we can just create + // default layout information. + xla::Shape gradient_shape = builder->GetShape(op).value(); + xla::LayoutUtil::SetToDefaultLayout(&gradient_shape); + gradient_shapes.push_back(gradient_shape); } std::vector learning_rates; From 2c5b1a264ecdb0d67f01504b0af6cc13191b7b91 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 10:19:28 -0800 Subject: [PATCH 099/381] Re-enable layering_check for target. 
PiperOrigin-RevId: 585689170 --- tensorflow/core/common_runtime/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index d0349fd275813a..69822584487fcd 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -384,13 +384,13 @@ cc_library( srcs = ["buf_rendezvous.cc"], hdrs = ["buf_rendezvous.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":device", ":device_mgr", ":process_util", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) From 2b4a3082dccbca2417a1ff55120333505b271da8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 10:31:30 -0800 Subject: [PATCH 100/381] Support 8-bit numbers for XlaRngBitGenerator PiperOrigin-RevId: 585692766 --- .../compiler/mlir/tensorflow/ir/tf_generated_ops.td | 2 +- tensorflow/compiler/tests/xla_ops_test.py | 12 ++++++------ .../tf2xla/kernels/stateless_random_ops_v2.cc | 2 +- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index c35ff30bff0a9f..e870c838538d0a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -22008,7 +22008,7 @@ a u64[2] and for PHILOX a u64[3].}]>:$initial_state, let results = (outs TF_Uint64Tensor:$output_key, - TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + TensorOf<[TF_Int32, TF_Int64, TF_Int8, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<2>; diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 47cc309b45452b..46f192648ecaa6 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -1270,13 +1270,13 @@ def assert_output_shapes(output, expected_shape): ): reduce_with_shapes((None, 4, 5), (3, None, 5), (13, 4, 5)) - @parameterized.parameters( - random_ops_util.Algorithm.THREEFRY, - random_ops_util.Algorithm.PHILOX, - random_ops_util.Algorithm.AUTO_SELECT, + @parameterized.product( + algorithm=[random_ops_util.Algorithm.THREEFRY, + random_ops_util.Algorithm.PHILOX, + random_ops_util.Algorithm.AUTO_SELECT], + dtype=[np.uint8, np.uint64], ) - def testRngBitGenerator(self, algorithm): - dtype = np.uint64 + def testRngBitGenerator(self, algorithm, dtype): initial_state = array_ops.placeholder(np.uint64, shape=(2,)) shape = (2, 3) res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype) diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index 098ecf39792e21..cc0cdfc2036fa7 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -428,7 +428,7 @@ REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp); REGISTER_XLA_OP(Name("XlaRngBitGenerator") .CompileTimeConstantInput("algorithm") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", {DT_UINT32, DT_UINT64}), + .TypeConstraint("dtype", {DT_UINT8, DT_UINT32, DT_UINT64}), MlirXlaOpKernel); } // namespace diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc 
b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 480dc474410359..edb2a40f4d332b 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -816,7 +816,7 @@ REGISTER_OP("XlaRngBitGenerator") .Input("shape: Tshape") .Output("output_key: uint64") .Output("output: dtype") - .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64") + .Attr("dtype: {uint8, int8, int32, int64, uint32, uint64} = DT_UINT64") .Attr("Tshape: {int32, int64} = DT_INT32") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle algorithm; From 8639f55d944ef71ff272c2860e144a5a1706396c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 11:34:50 -0800 Subject: [PATCH 101/381] Adds back the solver parameter string. PiperOrigin-RevId: 585712342 --- .../xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index e86ca08613a246..0c71a705118e5a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -217,6 +217,7 @@ AutoShardingSolverResult CallORToolsSolver( "share_binary_clauses:false,random_seed:1,interleave_" "search:true,num_workers:", num_workers); + solver->SetSolverSpecificParametersAsString(solver_parameter_str); } #endif // Create variables From 131efc28a7f07bc27103a453a8fd88d57329e07f Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Mon, 27 Nov 2023 11:36:13 -0800 Subject: [PATCH 102/381] Add presubmit check to enforce single source targets. PiperOrigin-RevId: 585712712 --- tensorflow/BUILD | 8 ++++++++ tensorflow/build_cleaner_spec.textproto | 14 ++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 tensorflow/build_cleaner_spec.textproto diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 58b26fb5e79a61..289f37ef902c63 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -41,8 +41,11 @@ load( ) # copybara:uncomment_begin +# # buildifier: disable=out-of-order-load +# load("//devtools/build_cleaner/skylark:action_config_test.bzl", "action_config_test") # load("//devtools/copybara/rules:copybara.bzl", "copybara_config_test") # load("//tools/build_defs/license:license.bzl", "license") +# # buildifier: enable=out-of-order-load # copybara:uncomment_end # copybara:comment_begin(oss-only) @@ -183,6 +186,11 @@ package( # ], # deps = [":copybara_config"], # ) +# +# action_config_test( +# name = "build_cleaner_spec_test", +# src = "build_cleaner_spec.textproto", +# ) # copybara:uncomment_end licenses(["notice"]) diff --git a/tensorflow/build_cleaner_spec.textproto b/tensorflow/build_cleaner_spec.textproto new file mode 100644 index 00000000000000..4c593a0ec4fd4c --- /dev/null +++ b/tensorflow/build_cleaner_spec.textproto @@ -0,0 +1,14 @@ +# proto-file: devtools/build_cleaner/proto/actions.proto +# proto-message: ActionSpecs + +# Rules (except for the allowlist) should not have more than one source file. 
+action_spec { + action: CHECK_FILE_COUNT + file_count_params { + rule_selector { + rule_kind_regex: "^(?!filegroup|genrule|_policy_filegroup).*$" + generator_function_regex: "^(?!boq_header)$" + } + max_source_count: 1 + } +} \ No newline at end of file From a068df0f5b032066ff439aa9ff8708bac85eacdb Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 27 Nov 2023 11:44:23 -0800 Subject: [PATCH 103/381] Create a `pywrap_quantization` pybind11 module and add a function for static-range PTQ. PiperOrigin-RevId: 585714763 --- .../mlir/quantization/stablehlo/python/BUILD | 32 ++- .../integration_test/quantize_model_test.py | 6 + .../stablehlo/python/pywrap_quantization.cc | 201 ++++++++++++++++++ .../stablehlo/python/pywrap_quantization.pyi | 33 +++ .../stablehlo/python/quantization.py | 59 ++++- .../mlir/quantization/tensorflow/python/BUILD | 4 +- 6 files changed, 327 insertions(+), 8 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index ebe8d44d60aa12..996af258bac493 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -1,5 +1,9 @@ load("//tensorflow:pytype.default.bzl", "pytype_strict_library") -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") +load( + "//tensorflow:tensorflow.default.bzl", + "tf_py_strict_test", + "tf_python_pybind_extension", +) load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") package_group( @@ -26,8 +30,14 @@ pytype_strict_library( name = "quantization", srcs = ["quantization.py"], deps = [ + ":pywrap_quantization", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib_py", "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:save_model", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/saved_model:loader", ], ) @@ -69,3 +79,23 @@ tf_py_strict_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_python_pybind_extension( + name = "pywrap_quantization", + srcs = ["pywrap_quantization.cc"], + pytype_srcs = ["pywrap_quantization.pyi"], + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:env", + "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_abseil//pybind11_abseil:status_casters", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py 
b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 6071ded7ed39c5..9a796dedfe8e54 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -108,6 +108,9 @@ def data_gen() -> repr_dataset.RepresentativeDataset: tfrecord_file_path=dataset_path ) }, + calibration_options=quant_opts_pb2.CalibrationOptions( + calibration_method=quant_opts_pb2.CalibrationOptions.CALIBRATION_METHOD_MIN_MAX + ), ) quantization.quantize_saved_model( self._input_saved_model_path, @@ -190,6 +193,9 @@ def data_gen() -> repr_dataset.RepresentativeDataset: tfrecord_file_path=dataset_path ) }, + calibration_options=quant_opts_pb2.CalibrationOptions( + calibration_method=quant_opts_pb2.CalibrationOptions.CALIBRATION_METHOD_MIN_MAX + ), ) quantization.quantize_saved_model( self._input_saved_model_path, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc new file mode 100644 index 00000000000000..76453b285eca66 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -0,0 +1,201 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "pybind11/detail/common.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "pybind11/stl.h" // from @pybind11 // IWYU pragma: keep +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil +#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tsl/platform/env.h" + +namespace py = pybind11; + +namespace { + +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::DebuggerOptions; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::QuantizationOptions; + +// TODO: b/307624867 - Factor out this function to a separate file. +// Creates a temporary directory and returns its path. 
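// [Editor's note] In static_range_ptq() below, this helper provides the
// scratch locations for the intermediate SavedModels written between the
// pre-calibration, calibration, and post-calibration steps.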
+std::string CreateTmpDir() { + tsl::Env* const env = tsl::Env::Default(); + + std::string tmp_dir; + env->LocalTempFilename(&tmp_dir); + if (!env->RecursivelyCreateDir(tmp_dir).ok()) { + throw py::value_error( + absl::StrFormat("Failed to create tmp dir: '%s'", tmp_dir)); + } + + return tmp_dir; +} + +// TODO: b/312371048 - Factor out this function to a separate file. +// Enables debugging on `exported_model` by updating the `DumpTensor` ops. +// +// Saves the current model to `debugger_options.unquantized_dump_model_path()` +// if the debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required because +// in whole-model debugging mode the `DumpTensor` ops for the unquantized +// tensors are only inserted in the unquantized model whereas `DumpTensor` ops +// for the quantized tensors are only inserted in the quantized model. Both +// models are required to be able to dump both quantized and unquantized tensors +// and compare them offline. +ExportedModel EnableDebugging( + const ExportedModel& exported_model, + const DebuggerOptions& debugger_options, + const PyFunctionLibrary& py_function_library, + const absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& signature_def_map) { + ExportedModel debugger_enabled_exported_model = exported_model; + *debugger_enabled_exported_model.mutable_graph_def() = + py_function_library.EnableDumpTensor(exported_model.graph_def()); + if (debugger_options.debugger_type() == + DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { + // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. + // TODO: b/296916287 - Create a separate function for saving unquantized + // dump model. + py_function_library.SaveExportedModel( + debugger_options.unquantized_dump_model_path(), + debugger_enabled_exported_model, src_saved_model_path, tags, + signature_def_map); + + *debugger_enabled_exported_model.mutable_graph_def() = + py_function_library.ChangeDumpTensorFileName( + debugger_enabled_exported_model.graph_def()); + } + + return debugger_enabled_exported_model; +} + +} // namespace + +PYBIND11_MODULE(pywrap_quantization, m) { + // Supports absl::Status type conversions. + pybind11::google::ImportStatusModule(); + + m.doc() = "StableHLO Quantization APIs."; + + m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. 
+ // LINT.IfChange + "static_range_ptq", + [](const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + const QuantizationOptions& quantization_options, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const absl::flat_hash_map& function_aliases, + const PyFunctionLibrary& py_function_library, + py::object representative_dataset) -> absl::Status { + // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) + std::unordered_set tags; + tags.insert(quantization_options.tags().begin(), + quantization_options.tags().end()); + + const absl::StatusOr exported_model = + QuantizePtqModelPreCalibration(src_saved_model_path, signature_keys, + tags, quantization_options, + function_aliases); + if (!exported_model.ok()) return exported_model.status(); + + const ExportedModel exported_model_ids_assigned = + py_function_library.AssignIdsToCustomAggregatorOps(*exported_model); + + const std::string precalibrated_saved_model_dir = CreateTmpDir(); + + py_function_library.SaveExportedModel( + precalibrated_saved_model_dir, exported_model_ids_assigned, + src_saved_model_path, tags, signature_def_map); + + ExportedModel calibrated_exported_model = + py_function_library.RunCalibration( + precalibrated_saved_model_dir, signature_keys, tags, + exported_model_ids_assigned, + quantization_options.calibration_options(), + quantization_options.force_graph_mode_calibration(), + representative_dataset); + + if (quantization_options.has_debugger_options()) { + calibrated_exported_model = EnableDebugging( + calibrated_exported_model, + quantization_options.debugger_options(), py_function_library, + src_saved_model_path, tags, signature_def_map); + } + + const std::string calibrated_saved_model_path = CreateTmpDir(); + + py_function_library.SaveExportedModel( + calibrated_saved_model_path, calibrated_exported_model, + src_saved_model_path, tags, signature_def_map); + + const absl::flat_hash_map + function_aliases_after_calibration( + calibrated_exported_model.function_aliases().begin(), + calibrated_exported_model.function_aliases().end()); + + const absl::StatusOr post_calibrated_exported_model = + QuantizePtqModelPostCalibration( + calibrated_saved_model_path, signature_keys, tags, + quantization_options, function_aliases_after_calibration); + if (!post_calibrated_exported_model.ok()) { + return post_calibrated_exported_model.status(); + } + + py_function_library.SaveExportedModel( + dst_saved_model_path, *post_calibrated_exported_model, + calibrated_saved_model_path, tags, signature_def_map); + + return absl::OkStatus(); + }, + R"pbdoc( + Runs static-range post-training quantization (PTQ) on a SavedModel at + `src_saved_model_path` and saves the resulting model to + `dst_saved_model_path`. + + The user should pass a serialized `QuantizationOptions` for the + `quantization_options_serialized` argument, and a signature key -> + serialized `SignatureDef` mapping for the `signature_def_map_serialized` + argument. + + `function_aliases` maps actual function names to the function aliases, as + defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the + input SavedModel. + + Raises `StatusNotOk` exception if when the run was unsuccessful. 
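+
+      A minimal usage sketch from the Python wrapper layer (mirroring
+      `quantization.py` in this change); the saved-model paths and the
+      prepared `config`, `signature_def_map_serialized`, `function_aliases`,
+      and `representative_dataset` values are illustrative assumptions, not
+      part of this API:
+
+        pywrap_quantization.static_range_ptq(
+            '/tmp/src_saved_model',  # assumed example path
+            '/tmp/dst_saved_model',  # assumed example path
+            quantization_options_serialized=config.SerializeToString(),
+            signature_keys=list(config.signature_keys),
+            signature_def_map_serialized=signature_def_map_serialized,
+            function_aliases=dict(function_aliases),
+            py_function_library=py_function_lib.PyFunctionLibrary(),
+            representative_dataset=representative_dataset,
+        )
+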
+ )pbdoc", + py::arg("saved_model_path"), py::arg("dst_saved_model_path"), + py::arg("quantization_options_serialized"), py::kw_only(), + py::arg("signature_keys"), py::arg("signature_def_map_serialized"), + py::arg("function_aliases"), py::arg("py_function_library"), + py::arg("representative_dataset")); +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi new file mode 100644 index 00000000000000..1870115a4aa847 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi @@ -0,0 +1,33 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Any + +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd + +# LINT.IfChange(static_range_ptq) +def static_range_ptq( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], + function_aliases: dict[str, str], + py_function_library: py_function_lib.PyFunctionLibrary, + representative_dataset: rd.RepresentativeDatasetOrMapping, +) -> Any: ... # Status + +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py index 5eefe2b94bbb7e..fab36e2005110f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py @@ -13,8 +13,36 @@ # limitations under the License. # ============================================================================== """StableHLO Quantizer.""" +from typing import Mapping + +from tensorflow.compiler.mlir.quantization.stablehlo.python import pywrap_quantization from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd +from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.saved_model import loader_impl + +# Mapping of signature def key -> SignatureDef. +_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] + + +def _serialize_signature_def_map( + signature_def_map: _SignatureDefMap, +) -> dict[str, bytes]: + """Serializes SignatureDef values in `signature_def_map`. + + Args: + signature_def_map: Signature key -> SignatureDef mapping. 
+ + Returns: + Signature def map where the values (`SignatureDef`) are serialized. + """ + signature_def_map_serialized = {} + for key, signature_def in signature_def_map.items(): + signature_def_map_serialized[key] = signature_def.SerializeToString() + + return signature_def_map_serialized # TODO: b/310594193 - Export API to pip package. @@ -44,6 +72,29 @@ def quantize_saved_model( ' single signature.' ) - # TODO: b/307624867 - Remove TF Quantizer dependency and replace it with - # StableHLO Quantizer components. - quantize_model.quantize(src_saved_model_path, dst_saved_model_path, config) + signature_def_map = save_model.get_signatures_from_saved_model( + src_saved_model_path, + list(config.signature_keys), + set(config.tags), + ) + + loader = loader_impl.SavedModelLoader(src_saved_model_path) + function_aliases = loader.get_meta_graph_def_from_tags( + config.tags + ).meta_info_def.function_aliases + + representative_dataset = rd.RepresentativeDatasetLoader( + config.representative_datasets + ).load() + + signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) + pywrap_quantization.static_range_ptq( + src_saved_model_path, + dst_saved_model_path, + quantization_options_serialized=config.SerializeToString(), + signature_keys=list(config.signature_keys), + signature_def_map_serialized=signature_def_map_serialized, + function_aliases=dict(function_aliases), + py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset=representative_dataset, + ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 202e5357bcf702..c1b0040dc78f30 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -13,6 +13,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", "//tensorflow/python:__subpackages__", ], @@ -101,7 +102,6 @@ cc_library( pytype_strict_library( name = "py_function_lib_py", srcs = ["py_function_lib.py"], - visibility = ["//visibility:private"], deps = [ ":pywrap_function_lib", ":representative_dataset", @@ -149,7 +149,6 @@ cc_library( "-use_header_modules", "-parse_headers", ], - visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -166,7 +165,6 @@ cc_library( cc_library( name = "py_function_lib", hdrs = ["py_function_lib.h"], - visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", From 49fbda812b69208ed7d99e4e9532fe1a3cd60e7c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 11:45:24 -0800 Subject: [PATCH 104/381] Temporarily disable the XLA sub-target for linear_operator_tridiag_test. Due to the shard count increase, there are now some empty shards in the target, since with XLA, some tests aren't dynamically added to the test classes, which causes NO TESTS RAN failures on 3.12. This only affects one test, as all others already require XLA to not be present. 
PiperOrigin-RevId: 585715034 --- tensorflow/python/kernel_tests/linalg/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 2298216f4aaed9..936e4204d3ace8 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -521,7 +521,9 @@ cuda_py_strict_test( "no_windows_gpu", "optonly", ], - xla_enable_strict_auto_jit = True, + # TODO(b/313470344): XLA temporarily disabled due to empty shards on 3.12. + xla_enable_strict_auto_jit = False, + xla_enabled = False, deps = [ "//tensorflow/python/framework:config", "//tensorflow/python/framework:test_lib", From bc4218606b7ca4f9b6ae927b76b077590d8e0299 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 27 Nov 2023 11:51:36 -0800 Subject: [PATCH 105/381] Consolidate `type_caster` implementations for serialized protobuf objects into a template class. Quantization library's `type_caster.h` library contained separate pybind11 `type_caster` implementations for each protobuf message types in use. This change consolidates similar implementations of them into a single template class `SerializedProtobufCaster`. This simplifies the `type_caster.h` library and makes the affected protobuf message types explicit. PiperOrigin-RevId: 585716588 --- .../python/pywrap_quantize_model.cc | 2 - .../tensorflow/python/type_casters.h | 190 ++++++------------ 2 files changed, 60 insertions(+), 132 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 756b1b0fe94214..8109e422e5da5d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -109,8 +109,6 @@ ExportedModel EnableDebugging( PYBIND11_MODULE(pywrap_quantize_model, m) { // Supports absl::StatusOr type conversions. pybind11::google::ImportStatusModule(); - // TODO - b/308532051: Make protobuf objects work without serialization - // overhead. pybind11_protobuf::ImportNativeProtoCasters(); m.def( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h index 7c7d1ae46b42f9..6632fd571927ee 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h @@ -16,8 +16,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TYPE_CASTERS_H_ #include +#include #include +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 @@ -28,7 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/python/lib/core/pybind11_lib.h" -#include "tsl/platform/protobuf.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace pybind11::detail { namespace internal { @@ -41,170 +43,98 @@ std::string Serialize(const tsl::protobuf::Message& protobuf_object) { // Empty string means it failed to serialize the protobuf with an error. See // the docstring for SerializeAsString for details. 
   if (serialized.empty()) {
-    throw py::value_error("Failed to serialize protobuf object.");
+    // Show the name of the protobuf message type to provide more information
+    // and easier debugging.
+    const std::string descriptor_name =
+        protobuf_object.GetDescriptor() == nullptr
+            ? "unknown"
+            : protobuf_object.GetDescriptor()->full_name();
+    throw py::value_error(absl::StrFormat(
+        "Failed to serialize protobuf object: %s.", descriptor_name));
   }
 
   return serialized;
 }
 
-}  // namespace internal
-
-// Handles `ExportedModel` (c++) <-> `bytes` (python) conversion. The `bytes`
-// object in the python layer is a serialization of `ExportedModel`.
+// Handles `ProtoT` (c++) <-> `bytes` (python) conversion. The `bytes`
+// object in the python layer is a serialization of `ProtoT`.
 //
-// See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html for
-// further details on how custom type conversions work for pybind11.
-template <>
-struct type_caster<tensorflow::quantization::ExportedModel> {
+// The caller of c++ interfaces should make sure to pass valid serialized
+// `ProtoT` objects as arguments. Failing to do so results in raising a
+// `ValueError`. Similarly, the python implementation of a c++ virtual member
+// function that return an `ProtoT` should return a valid serialized `ProtoT`.
+//
+// See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html
+template <typename ProtoT,
+          typename = std::enable_if_t<
+              std::is_base_of_v<tsl::protobuf::Message, ProtoT>>>
+struct SerializedProtobufCaster {
  public:
-  PYBIND11_TYPE_CASTER(tensorflow::quantization::ExportedModel,
-                       const_name("ExportedModel"));
+  PYBIND11_TYPE_CASTER(ProtoT, const_name<ProtoT>());
 
-  // Loads an `ExportedModel` instance from a python `bytes` object (`src`).
+  // Loads an `ProtoT` instance from a python `bytes` object (`src`).
   bool load(handle src, const bool convert) {
     auto caster = make_caster<absl::string_view>();
     // Make sure the user passed a valid python string.
-    if (!caster.load(src, convert)) {
-      return false;
-    }
+    if (!caster.load(src, convert)) return false;
 
-    const absl::string_view exported_model_serialized =
+    const absl::string_view serialized_proto =
         cast_op<absl::string_view>(std::move(caster));
 
     // NOLINTNEXTLINE: Explicit std::string conversion required for OSS.
-    return value.ParseFromString(std::string(exported_model_serialized));
+    return value.ParseFromString(std::string(serialized_proto));
   }
 
-  // Constructs a `bytes` object after serializing `src`.
-  static handle cast(tensorflow::quantization::ExportedModel&& src,
-                     return_value_policy policy, handle parent) {
+  // Constructs a `bytes` object by serializing `src`.
+  static handle cast(ProtoT&& src, return_value_policy policy, handle parent) {
     // release() prevents the reference count from decreasing upon the
     // destruction of py::bytes and returns a raw python object handle.
-    return py::bytes(internal::Serialize(src)).release();
+    return py::bytes(Serialize(src)).release();
   }
 
-  // Constructs a `bytes` object after serializing `src`.
-  static handle cast(const tensorflow::quantization::ExportedModel& src,
-                     return_value_policy policy, handle parent) {
+  // Constructs a `bytes` object by serializing `src`.
+  static handle cast(const ProtoT& src, return_value_policy policy,
+                     handle parent) {
     // release() prevents the reference count from decreasing upon the
     // destruction of py::bytes and returns a raw python object handle.
-    return py::bytes(internal::Serialize(src)).release();
+    return py::bytes(Serialize(src)).release();
   }
 };
 
-// Handles type conversion for `QuantizationOptions`.
-template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(tensorflow::quantization::QuantizationOptions, - const_name("QuantizationOptions")); - - // Python -> C++. Converts a serialized protobuf string and deserializes into - // an instance of `QuantizationOptions`. - bool load(handle src, const bool convert) { - auto caster = make_caster(); - // The user should have passed a valid python string. - if (!caster.load(src, convert)) { - return false; - } - - const absl::string_view quantization_opts_serialized = - cast_op(std::move(caster)); - - // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. - return value.ParseFromString(std::string(quantization_opts_serialized)); - } +} // namespace internal - // C++ -> Python. Constructs a `bytes` object after serializing `src`. - static handle cast(const tensorflow::quantization::QuantizationOptions& src, - return_value_policy policy, handle parent) { - // release() prevents the reference count from decreasing upon the - // destruction of py::bytes and returns a raw python object handle. - return py::bytes(internal::Serialize(src)).release(); - } -}; +// The following explicit specializations of protobuf `type_caster`s for +// specific protobuf message types are there to have higher priority over those +// defined in `native_proto_caster.h` during the resolution process. This is +// because the type casters in `native_proto_caster.h`, which allow seamlessly +// exchanging protobuf messages across c++-python boundaries, potentially +// without serialization, fail in the open-source environment. +// Explicitly-specialized type casters for serialized protobufs are added on an +// on-demand basis for quantization library. +// TODO: b/308532051 - Make `native_proto_caster.h` work in the open-source +// environment. -// Handles type conversion for `CalibrationOptions`. template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(tensorflow::quantization::CalibrationOptions, - const_name("CalibrationOptions")); - - // Python -> C++. Converts a serialized protobuf string and deserializes into - // an instance of `CalibrationOptions`. - bool load(handle src, const bool convert) { - auto caster = make_caster(); - // The user should have passed a valid python string. - if (!caster.load(src, convert)) { - return false; - } - - const absl::string_view calibration_opts_serialized = - cast_op(std::move(caster)); - - // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. - return value.ParseFromString(std::string(calibration_opts_serialized)); - } - - // C++ -> Python. Constructs a `bytes` object after serializing `src`. - static handle cast(const tensorflow::quantization::CalibrationOptions& src, - return_value_policy policy, handle parent) { - // release() prevents the reference count from decreasing upon the - // destruction of py::bytes and returns a raw python object handle. - return py::bytes(internal::Serialize(src)).release(); - } -}; +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::ExportedModel> {}; template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(tensorflow::SignatureDef, const_name("SignatureDef")); - - // Python->C++ conversion. Accepts a serialized `SignatureDef` string from the - // python side. 
- bool load(handle src, const bool convert) { - auto caster = make_caster(); - if (!caster.load(src, convert)) return false; - - const absl::string_view signature_def_serialized = - cast_op(std::move(caster)); - - // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. - return value.ParseFromString(std::string(signature_def_serialized)); - } - - // C++->Python conversion. Returns a serialized `SignatureDef` string. - static handle cast(const tensorflow::SignatureDef& src, - return_value_policy policy, handle parent) { - return py::bytes(internal::Serialize(src)).release(); - } -}; +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::QuantizationOptions> {}; template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(tensorflow::GraphDef, const_name("GraphDef")); - - // Python->C++ conversion. Accepts a serialized `GraphDef` string from the - // python side. - bool load(handle src, const bool convert) { - auto caster = make_caster(); - if (!caster.load(src, convert)) return false; +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::CalibrationOptions> {}; - const absl::string_view signature_def_serialized = - cast_op(std::move(caster)); - - // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. - return value.ParseFromString(std::string(signature_def_serialized)); - } +template <> +struct type_caster + : public internal::SerializedProtobufCaster {}; - // C++->Python conversion. Returns a serialized `GraphDef` string. - static handle cast(const tensorflow::GraphDef& src, - return_value_policy policy, handle parent) { - return py::bytes(internal::Serialize(src)).release(); - } -}; +template <> +struct type_caster + : public internal::SerializedProtobufCaster {}; } // namespace pybind11::detail From 77acca40bc72ec95a7b1da12c397ad0c77d0f74c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 11:53:22 -0800 Subject: [PATCH 106/381] [stream_executor] Add Case conditional command to CommandBuffer #6973 PiperOrigin-RevId: 585717031 --- .../xla/xla/stream_executor/command_buffer.cc | 7 + .../xla/xla/stream_executor/command_buffer.h | 9 ++ .../cuda/cuda_command_buffer_test.cc | 96 +++++++++++ .../cuda/cuda_conditional_kernels.cu.cc | 37 +++++ .../stream_executor/gpu/gpu_command_buffer.cc | 153 ++++++++++-------- .../stream_executor/gpu/gpu_command_buffer.h | 29 +++- .../rocm/hip_conditional_kernels.cu.cc | 3 + .../stream_executor_internal.h | 9 ++ 8 files changed, 274 insertions(+), 69 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 94ce1f67a4d846..669c0ce9543026 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" @@ -127,6 +128,12 @@ tsl::Status CommandBuffer::IfElse(StreamExecutor* executor, std::move(else_builder)); } +tsl::Status CommandBuffer::Case(StreamExecutor* executor, + DeviceMemory index, + std::vector branches) { + return implementation_->Case(executor, index, std::move(branches)); +} + CommandBuffer::Mode CommandBuffer::mode() const { return implementation_->mode(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 2698293af7261a..0ba382256bdd0d 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/functional/any_invocable.h" #include "xla/stream_executor/device_memory.h" @@ -137,6 +138,14 @@ class CommandBuffer { tsl::Status IfElse(StreamExecutor* executor, DeviceMemory pred, Builder then_builder, Builder else_builder); + // Adds a conditional operation that will run a command buffer constructed by + // the `branches` builder at `index`. If `index` is out of range, then it will + // run a conditional command buffer constructed by the last builder. + // + // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + tsl::Status Case(StreamExecutor* executor, DeviceMemory index, + std::vector branches); + //--------------------------------------------------------------------------// // Finalizes command buffer and makes it executable. Once command buffer is diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index cc1f823e3608db..5cf467a8ff3109 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -414,6 +414,102 @@ TEST(CudaCommandBufferTest, ConditionalIfElse) { ASSERT_EQ(dst, expected_mul); } +TEST(CudaCommandBufferTest, ConditionalCase) { + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + if (!CommandBuffer::SupportsConditionalCommands(platform)) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + AddI32Kernel add(executor); + MulI32Kernel mul(executor); + + { // Load addition kernel. + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); + TF_ASSERT_OK(executor->GetKernel(spec, &add)); + } + + { // Load multiplication kernel. 
+ MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetMulI32CudaKernel(), "mul"); + TF_ASSERT_OK(executor->GetKernel(spec, &mul)); + } + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=2, b=3, c=0, index=0 + DeviceMemory index = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + stream.ThenMemset32(&index, 0, sizeof(int32_t)); + stream.ThenMemset32(&a, 2, byte_length); + stream.ThenMemset32(&b, 3, byte_length); + stream.ThenMemZero(&c, byte_length); + + // if (index == 0) c = a + b + CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { + return branch0_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); + }; + + // if (index == 1) c = a * b + CommandBuffer::Builder branch1 = [&](CommandBuffer* branch1_cmd) { + return branch1_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.Case(executor, index, {branch0, branch1})); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `c` data back to host. + std::vector dst(4, 42); + stream.ThenMemcpy(dst.data(), c, byte_length); + + std::vector expected_add = {5, 5, 5, 5}; + ASSERT_EQ(dst, expected_add); + + // Set index to `1` + stream.ThenMemset32(&index, 1, sizeof(int32_t)); + + // Submit the same command buffer, but this time it should multiply inputs. + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + stream.ThenMemcpy(dst.data(), c, byte_length); + std::vector expected_mul = {6, 6, 6, 6}; + ASSERT_EQ(dst, expected_mul); + + // Set index to `-1` (out of bound index value). + stream.ThenMemset32(&index, -1, sizeof(int32_t)); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + stream.ThenMemcpy(dst.data(), c, byte_length); + ASSERT_EQ(dst, expected_mul); + + // Set index to `2` (out of bound index value). + stream.ThenMemset32(&index, 2, sizeof(int32_t)); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + stream.ThenMemcpy(dst.data(), c, byte_length); + ASSERT_EQ(dst, expected_mul); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc index ad49ecade31c8e..c8eff5ed7dfe29 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "third_party/gpus/cuda/include/cuda.h" namespace stream_executor { @@ -43,10 +46,40 @@ __global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, } } +__global__ void SetCaseCondition( + cudaGraphConditionalHandle h0, cudaGraphConditionalHandle h1, + cudaGraphConditionalHandle h2, cudaGraphConditionalHandle h3, + cudaGraphConditionalHandle h4, cudaGraphConditionalHandle h5, + cudaGraphConditionalHandle h6, cudaGraphConditionalHandle h7, + int32_t* index, int32_t num_handles) { + // Only handles in [0, num_handles) range are valid. + // + // We can't define a device function with dynamic number of handle arguments, + // so we always pass 8 handles, but only some of them are valid. Size 8 picked + // as a reasonable (but random) upper bound for what we see in XLA uses. + std::array handles = {h0, h1, h2, h3, + h4, h5, h6, h7}; + + // If branch index is out of range activate the last valid handle. + int32_t branch_index = *index; + if (branch_index < 0 || branch_index >= num_handles) { + branch_index = num_handles - 1; + } + + for (int32_t i = 0; i < num_handles; ++i) { + if (branch_index == i) { + cudaGraphSetConditional(handles[i], 1); + } else { + cudaGraphSetConditional(handles[i], 0); + } + } +} + #else // CUDA graph conditionals are not available __global__ void SetIfCondition() {} __global__ void SetIfElseCondition() {} +__global__ void SetCaseCondition() {} #endif @@ -63,6 +96,10 @@ void* GetSetIfElseConditionKernel() { return reinterpret_cast(&cuda::SetIfElseCondition); } +void* GetSetCaseConditionKernel() { + return reinterpret_cast(&cuda::SetCaseCondition); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 09143ffbb5fcad..557c9c2a6afa7f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -129,12 +129,6 @@ GpuCommandBuffer::ScopedGpuGraphExec::~ScopedGpuGraphExec() { cmd_buffer->is_owned_graph_exec_ = restore_is_owned; } -void GpuCommandBuffer::ConditionalCommandBuffers::Add( - GpuGraphConditionalHandle handle, CommandBuffer command_buffer) { - handles.push_back(handle); - command_buffers.push_back(std::move(command_buffer)); -} - static GpuDevicePtr AsDevicePtr(const DeviceMemoryBase& mem) { return reinterpret_cast(const_cast(mem.opaque())); } @@ -319,12 +313,12 @@ GpuCommandBuffer::CreateConditionalNodes( return conditional_graphs; } -tsl::StatusOr +tsl::StatusOr> GpuCommandBuffer::CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, absl::Span builders) { - ConditionalCommandBuffers cond_cmd_buffers; + std::vector cmd_buffers; // Conditional command buffers always created in nested mode and with // underlying graphs owned by a conditional node. 
@@ -339,10 +333,10 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( TF_RETURN_IF_ERROR(builders[i](&command_buffer)); TF_RETURN_IF_ERROR(command_buffer.Finalize()); - cond_cmd_buffers.Add(handles[i], std::move(command_buffer)); + cmd_buffers.push_back(std::move(command_buffer)); } - return cond_cmd_buffers; + return cmd_buffers; } tsl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( @@ -360,35 +354,26 @@ tsl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( return tsl::OkStatus(); } -tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder) { - DCHECK(executor->implementation() == parent_); // NOLINT - - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `If`. - SetIfConditionKernel set_if_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/2); - spec.AddInProcessSymbol(gpu::GetSetIfConditionKernel(), "set_if_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_condition)); - } - - std::array builders = {std::move(then_builder)}; +tsl::Status GpuCommandBuffer::CreateConditionalCommand( + SetConditionFn set_condition, + absl::Span builders) { + // Every conditional command buffer is controlled by its own handle. + size_t num_handles = builders.size(); if (state_ == State::kCreate) { - TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(1)); + TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(num_handles)); - // Add a kernel to update conditional handle value based on a predicate. - TF_RETURN_IF_ERROR(Launch(set_if_condition, ThreadDim(), BlockDim(), - handles[0], predicate)); + // Add a kernel to update conditional handles values. + TF_RETURN_IF_ERROR(set_condition(handles)); - // Create conditional command buffer for then branch. + // Create conditional command buffer for each builder. TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(handles)); - TF_ASSIGN_OR_RETURN( - conditional_command_buffers_.emplace_back(), - CreateConditionalCommandBuffers(handles, graphs, builders)); + TF_ASSIGN_OR_RETURN(auto cmd_buffers, CreateConditionalCommandBuffers( + handles, graphs, builders)); + + // Keep track of created conditional handles and command buffers. + conditional_command_buffers_.emplace_back(std::move(handles), + std::move(cmd_buffers)); return tsl::OkStatus(); } @@ -398,14 +383,13 @@ tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, conditional_command_buffers_[update_state_.conditional_idx++]; // Sanity check that we got the correct conditional command buffers. - TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, 1)); + TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, num_handles)); - // Update a kernel that updates conditional handle based on a predicate. - TF_RETURN_IF_ERROR(Launch(set_if_condition, ThreadDim(), BlockDim(), - cond_cmd_buffers.handles[0], predicate)); + // Update a kernel that updates conditional handles values. + TF_RETURN_IF_ERROR(set_condition(cond_cmd_buffers.handles)); // Skip updating conditional nodes. 
- update_state_.node_idx += cond_cmd_buffers.handles.size(); + update_state_.node_idx += num_handles; return UpdateConditionalCommandBuffers( absl::MakeSpan(cond_cmd_buffers.command_buffers), builders); @@ -414,6 +398,31 @@ tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, return UnsupportedStateError(state_); } +tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, + DeviceMemory predicate, + CommandBuffer::Builder then_builder) { + DCHECK(executor->implementation() == parent_); // NOLINT + + // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on + // every call to `If`. + SetIfConditionKernel set_if_condition(executor); + + { // Load kernels that updates condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddInProcessSymbol(gpu::GetSetIfConditionKernel(), "set_if_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_condition)); + } + + auto set_cond_fn = [&](absl::Span handles) { + return Launch(set_if_condition, ThreadDim(), BlockDim(), handles[0], + predicate); + }; + + std::array builders = {std::move(then_builder)}; + + return CreateConditionalCommand(set_cond_fn, builders); +} + tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder, @@ -421,7 +430,7 @@ tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, DCHECK(executor->implementation() == parent_); // NOLINT // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `If`. + // every call to `IfElse`. SetIfElseConditionKernel set_if_else_condition(executor); { // Load kernels that updates condition handle value. @@ -431,45 +440,55 @@ tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_else_condition)); } + auto set_cond_fn = [&](absl::Span handles) { + return Launch(set_if_else_condition, ThreadDim(), BlockDim(), handles[0], + handles[1], predicate); + }; + std::array builders = {std::move(then_builder), std::move(else_builder)}; - if (state_ == State::kCreate) { - TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(2)); - - // Add a kernel to update conditional handle value based on a predicate. - TF_RETURN_IF_ERROR(Launch(set_if_else_condition, ThreadDim(), BlockDim(), - handles[0], handles[1], predicate)); + return CreateConditionalCommand(set_cond_fn, builders); +} - // Create conditional command buffers for then/else branches. - TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(handles)); - TF_ASSIGN_OR_RETURN( - conditional_command_buffers_.emplace_back(), - CreateConditionalCommandBuffers(handles, graphs, builders)); +tsl::Status GpuCommandBuffer::Case( + StreamExecutor* executor, DeviceMemory index, + std::vector branches) { + DCHECK(executor->implementation() == parent_); // NOLINT - return tsl::OkStatus(); + // TODO(ezhulenev): Relax this constraint, we can launch multiple back to back + // kernels to update conditional handles in batches of size 8. + if (branches.size() > 8) { + return absl::InvalidArgumentError(absl::StrCat( + "Case command supports only up to 8 branches, got: ", branches.size())); } - if (state_ == State::kUpdate) { - ConditionalCommandBuffers& cond_cmd_buffers = - conditional_command_buffers_[update_state_.conditional_idx++]; + // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on + // every call to `Case`. 
+ SetCaseConditionKernel set_case_condition(executor); - // Sanity check that we got the correct conditional command buffers. - TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, 2)); + { // Load kernels that updates condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/10); + spec.AddInProcessSymbol(gpu::GetSetCaseConditionKernel(), + "set_case_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_case_condition)); + } - // Update a kernel that updates conditional handles based on a predicate. - TF_RETURN_IF_ERROR(Launch(set_if_else_condition, ThreadDim(), BlockDim(), - cond_cmd_buffers.handles[0], - cond_cmd_buffers.handles[0], predicate)); + auto set_cond_fn = [&](absl::Span handles) { + int32_t num_handles = handles.size(); - // Skip updating conditional nodes. - update_state_.node_idx += cond_cmd_buffers.handles.size(); + // Pad handles up to size 8 with a default initialized handle. + std::vector padded_handles(handles.begin(), + handles.end()); + padded_handles.resize(8); - return UpdateConditionalCommandBuffers( - absl::MakeSpan(cond_cmd_buffers.command_buffers), builders); - } + return Launch(set_case_condition, ThreadDim(), BlockDim(), + padded_handles[0], padded_handles[1], padded_handles[2], + padded_handles[3], padded_handles[4], padded_handles[5], + padded_handles[6], padded_handles[7], index, num_handles); + }; - return UnsupportedStateError(state_); + return CreateConditionalCommand(set_cond_fn, branches); } tsl::Status GpuCommandBuffer::Finalize() { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index ed048d24f6675a..146e5e3b43c019 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -18,7 +18,9 @@ limitations under the License. #include #include +#include #include +#include #include #include "absl/container/inlined_vector.h" @@ -64,6 +66,9 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { CommandBuffer::Builder then_builder, CommandBuffer::Builder else_builder) override; + tsl::Status Case(StreamExecutor* executor, DeviceMemory index, + std::vector branches) override; + tsl::Status Finalize() override; tsl::Status Update() override; @@ -107,10 +112,22 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // A signature of a device kernels updating conditional handle(s). using SetIfConditionKernel = TypedKernel>; + using SetIfElseConditionKernel = TypedKernel>; + using SetCaseConditionKernel = + TypedKernel, int32_t>; + + // A callback to launch a kernel that updates conditional handles state. + using SetConditionFn = + std::function)>; + // Overwrites the `exec_` handle in a Gpu command buffer by `exec`, and // restores to the original handle when destroyed. This allows us updating // primary graph executable using nested command buffers (command buffers that @@ -128,7 +145,10 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // For each conditional node in the Gpu graph we keep a record of conditional // command buffers attached to a node, so we can apply updates to them. 
struct ConditionalCommandBuffers { - void Add(GpuGraphConditionalHandle handle, CommandBuffer command_buffer); + ConditionalCommandBuffers(std::vector handles, + std::vector command_buffers) + : handles(std::move(handles)), + command_buffers(std::move(command_buffers)) {} std::vector handles; std::vector command_buffers; @@ -140,7 +160,7 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { tsl::StatusOr> CreateConditionalNodes( absl::Span handles); - tsl::StatusOr CreateConditionalCommandBuffers( + tsl::StatusOr> CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, absl::Span builders); @@ -149,6 +169,10 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { absl::Span command_buffers, absl::Span builders); + tsl::Status CreateConditionalCommand( + SetConditionFn set_condition, + absl::Span builders); + // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a // dependency between all nodes added to a command buffer. We need a // concept of a barrier at a command buffer level. @@ -229,6 +253,7 @@ inline tsl::Status GpuCommandBuffer::Launch( void* GetSetIfConditionKernel(); void* GetSetIfElseConditionKernel(); +void* GetSetCaseConditionKernel(); } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc index a31990bf61b989..669ac500841e96 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc @@ -31,6 +31,9 @@ void* GetSetIfConditionKernel() { void* GetSetIfElseConditionKernel() { return reinterpret_cast(&rocm::SetCondition); } +void* GetSetCaseConditionKernel() { + return reinterpret_cast(&rocm::SetCondition); +} } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 522186c4bad7e9..d8a8fd0b9920ec 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -158,6 +158,15 @@ class CommandBufferInterface { CommandBuffer::Builder then_builder, CommandBuffer::Builder else_builder) = 0; + // Adds a conditional operation that will run a command buffer constructed by + // the `branches` builder at `index`. If `index` is out of range, then it will + // run a conditional command buffer constructed by the last builder. + // + // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + virtual tsl::Status Case(StreamExecutor* executor, + DeviceMemory index, + std::vector branches) = 0; + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. 
virtual tsl::Status Finalize() = 0; From 5625b42ae69507c15bfe99dd4a05e9d3bfba9c9c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 12:49:29 -0800 Subject: [PATCH 107/381] [stream_executor] Add For conditional command to CommandBuffer PiperOrigin-RevId: 585733291 --- .../xla/xla/stream_executor/command_buffer.cc | 7 ++ .../xla/xla/stream_executor/command_buffer.h | 5 + .../cuda/cuda_command_buffer_test.cc | 54 +++++++++++ .../cuda/cuda_conditional_kernels.cu.cc | 15 +++ .../xla/stream_executor/cuda/cuda_driver.cc | 14 +++ .../stream_executor/gpu/gpu_command_buffer.cc | 95 +++++++++++++++---- .../stream_executor/gpu/gpu_command_buffer.h | 30 +++++- .../xla/xla/stream_executor/gpu/gpu_driver.h | 2 +- .../rocm/hip_conditional_kernels.cu.cc | 3 + .../stream_executor_internal.h | 6 ++ 10 files changed, 207 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 669c0ce9543026..b27ac616cda856 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -134,6 +134,13 @@ tsl::Status CommandBuffer::Case(StreamExecutor* executor, return implementation_->Case(executor, index, std::move(branches)); } +tsl::Status CommandBuffer::For(StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_index, + Builder body_builder) { + return implementation_->For(executor, num_iteration, loop_index, + std::move(body_builder)); +} + CommandBuffer::Mode CommandBuffer::mode() const { return implementation_->mode(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 0ba382256bdd0d..1c064460b369e8 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -146,6 +146,11 @@ class CommandBuffer { tsl::Status Case(StreamExecutor* executor, DeviceMemory index, std::vector branches); + // Adds a conditional operation that will run a command buffer constructed by + // the `body_builder` exactly `num_iteration` times. + tsl::Status For(StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_index, Builder body_builder); + //--------------------------------------------------------------------------// // Finalizes command buffer and makes it executable. Once command buffer is diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 5cf467a8ff3109..8322fa0a672263 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -510,6 +510,60 @@ TEST(CudaCommandBufferTest, ConditionalCase) { ASSERT_EQ(dst, expected_mul); } +TEST(CudaCommandBufferTest, ConditionalFor) { + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + if (!CommandBuffer::SupportsConditionalCommands(platform)) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + AddI32Kernel add(executor); + + { // Load addition kernel. 
+ MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); + TF_ASSERT_OK(executor->GetKernel(spec, &add)); + } + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=0, loop_index=0 + DeviceMemory loop_index = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + + stream.ThenMemset32(&loop_index, 0, sizeof(int32_t)); + stream.ThenMemset32(&a, 1, byte_length); + stream.ThenMemZero(&b, byte_length); + + // Loop body: b = a + b + CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { + return body_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, b); + }; + + int32_t num_iters = 10; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.For(executor, num_iters, loop_index, body_builder)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 42); + stream.ThenMemcpy(dst.data(), b, byte_length); + + std::vector expected = {10, 10, 10, 10}; + ASSERT_EQ(dst, expected); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc index c8eff5ed7dfe29..8796c17461febb 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -75,11 +75,22 @@ __global__ void SetCaseCondition( } } +__global__ void SetForCondition(cudaGraphConditionalHandle handle, + int32_t* loop_index, int32_t num_iterations) { + if (*loop_index < num_iterations) { + cudaGraphSetConditional(handle, 1); + } else { + cudaGraphSetConditional(handle, 0); + } + *loop_index += 1; +} + #else // CUDA graph conditionals are not available __global__ void SetIfCondition() {} __global__ void SetIfElseCondition() {} __global__ void SetCaseCondition() {} +__global__ void SetForCondition() {} #endif @@ -100,6 +111,10 @@ void* GetSetCaseConditionKernel() { return reinterpret_cast(&cuda::SetCaseCondition); } +void* GetSetForConditionKernel() { + return reinterpret_cast(&cuda::SetForCondition); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 79ee95668a6f5c..e9effc95c21025 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -736,6 +736,16 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return ::tsl::OkStatus(); } +static std::string ConditionalTypeToString( + GpuDriver::GpuGraphConditionalNodeParams::Type type) { + switch (type) { + case GpuDriver::GpuGraphConditionalNodeParams::Type::kIf: + return "IF"; + case GpuDriver::GpuGraphConditionalNodeParams::Type::kWhile: + return "WHILE"; + } +} + /* static */ tsl::StatusOr GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, absl::Span deps, @@ -744,6 +754,7 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, // Add conditional node to a graph. 
if (auto* conditional = std::get_if(¶ms)) { VLOG(2) << "Add conditional node to a graph " << graph + << "; type: " << ConditionalTypeToString(conditional->type) << "; deps: " << deps.size(); CUgraphNodeParams cu_params; @@ -758,6 +769,9 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, case GpuDriver::GpuGraphConditionalNodeParams::Type::kIf: cu_params.conditional.type = CU_GRAPH_COND_TYPE_IF; break; + case GpuDriver::GpuGraphConditionalNodeParams::Type::kWhile: + cu_params.conditional.type = CU_GRAPH_COND_TYPE_WHILE; + break; } RETURN_IF_CUDA_RES_ERROR( diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 557c9c2a6afa7f..7afd7c11c1cd42 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -276,6 +277,14 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, // Command buffer condtitional commands API //--------------------------------------------------------------------------// +/*static*/ GpuCommandBuffer::ConditionBuilder +GpuCommandBuffer::ToConditionBuilder(CommandBuffer::Builder builder) { + return [builder = std::move(builder)](CommandBuffer* cmd_buffer, + GpuGraphConditionalHandle) { + return builder(cmd_buffer); + }; +} + tsl::StatusOr> GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) { std::vector handles; @@ -288,7 +297,7 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) { tsl::StatusOr> GpuCommandBuffer::CreateConditionalNodes( - absl::Span handles) { + ConditionType type, absl::Span handles) { std::vector conditional_graphs; using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; @@ -299,7 +308,7 @@ GpuCommandBuffer::CreateConditionalNodes( GpuGraphNodeHandle* node = &nodes_.emplace_back(); ConditionalParams params; - params.type = ConditionalParams::Type::kIf; + params.type = type; params.handle = handle; params.context = parent_->gpu_context(); @@ -317,7 +326,7 @@ tsl::StatusOr> GpuCommandBuffer::CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, - absl::Span builders) { + absl::Span builders) { std::vector cmd_buffers; // Conditional command buffers always created in nested mode and with @@ -330,7 +339,8 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( nested, graphs[i], is_owned_graph); auto command_buffer = CommandBuffer::Wrap(std::move(command_buffer_impl)); - TF_RETURN_IF_ERROR(builders[i](&command_buffer)); + + TF_RETURN_IF_ERROR(builders[i](&command_buffer, handles[i])); TF_RETURN_IF_ERROR(command_buffer.Finalize()); cmd_buffers.push_back(std::move(command_buffer)); @@ -340,23 +350,24 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( } tsl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( + absl::Span handles, absl::Span command_buffers, - absl::Span builders) { + absl::Span builders) { for (size_t i = 0; i < command_buffers.size(); ++i) { // Use parent graph executable for conditional command buffer update. ScopedGpuGraphExec scoped_exec(Cast(&command_buffers[i]), exec_); // Update command buffer using user-provided builder callback. 
TF_RETURN_IF_ERROR(command_buffers[i].Update()); - TF_RETURN_IF_ERROR(builders[i](&command_buffers[i])); + TF_RETURN_IF_ERROR(builders[i](&command_buffers[i], handles[i])); TF_RETURN_IF_ERROR(command_buffers[i].Finalize()); } return tsl::OkStatus(); } tsl::Status GpuCommandBuffer::CreateConditionalCommand( - SetConditionFn set_condition, - absl::Span builders) { + ConditionType type, SetConditionFn set_condition, + absl::Span builders) { // Every conditional command buffer is controlled by its own handle. size_t num_handles = builders.size(); @@ -367,7 +378,7 @@ tsl::Status GpuCommandBuffer::CreateConditionalCommand( TF_RETURN_IF_ERROR(set_condition(handles)); // Create conditional command buffer for each builder. - TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(handles)); + TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(type, handles)); TF_ASSIGN_OR_RETURN(auto cmd_buffers, CreateConditionalCommandBuffers( handles, graphs, builders)); @@ -392,6 +403,7 @@ tsl::Status GpuCommandBuffer::CreateConditionalCommand( update_state_.node_idx += num_handles; return UpdateConditionalCommandBuffers( + cond_cmd_buffers.handles, absl::MakeSpan(cond_cmd_buffers.command_buffers), builders); } @@ -401,7 +413,7 @@ tsl::Status GpuCommandBuffer::CreateConditionalCommand( tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) { - DCHECK(executor->implementation() == parent_); // NOLINT + DCHECK(executor->implementation() == parent_); // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on // every call to `If`. @@ -418,16 +430,17 @@ tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, predicate); }; - std::array builders = {std::move(then_builder)}; + std::array builders = { + ToConditionBuilder(std::move(then_builder))}; - return CreateConditionalCommand(set_cond_fn, builders); + return CreateConditionalCommand(ConditionType::kIf, set_cond_fn, builders); } tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder, CommandBuffer::Builder else_builder) { - DCHECK(executor->implementation() == parent_); // NOLINT + DCHECK(executor->implementation() == parent_); // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on // every call to `IfElse`. @@ -445,16 +458,17 @@ tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, handles[1], predicate); }; - std::array builders = {std::move(then_builder), - std::move(else_builder)}; + std::array builders = { + ToConditionBuilder(std::move(then_builder)), + ToConditionBuilder(std::move(else_builder))}; - return CreateConditionalCommand(set_cond_fn, builders); + return CreateConditionalCommand(ConditionType::kIf, set_cond_fn, builders); } tsl::Status GpuCommandBuffer::Case( StreamExecutor* executor, DeviceMemory index, std::vector branches) { - DCHECK(executor->implementation() == parent_); // NOLINT + DCHECK(executor->implementation() == parent_); // TODO(ezhulenev): Relax this constraint, we can launch multiple back to back // kernels to update conditional handles in batches of size 8. @@ -488,7 +502,52 @@ tsl::Status GpuCommandBuffer::Case( padded_handles[6], padded_handles[7], index, num_handles); }; - return CreateConditionalCommand(set_cond_fn, branches); + // Wrap all branches into conditional command buffer builders. 
+ absl::InlinedVector builders; + builders.reserve(branches.size()); + for (auto& branch : branches) { + builders.push_back(ToConditionBuilder(std::move(branch))); + } + + return CreateConditionalCommand(ConditionType::kIf, set_cond_fn, builders); +} + +tsl::Status GpuCommandBuffer::For(StreamExecutor* executor, + int32_t num_iteration, + DeviceMemory loop_index, + CommandBuffer::Builder body_builder) { + DCHECK(executor->implementation() == parent_); + + // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on + // every call to `For`. + SetForConditionKernel set_for_condition(executor); + + { // Load kernels that updates condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(gpu::GetSetForConditionKernel(), + "set_for_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_for_condition)); + } + + // TODO(ezhulenev): We currently assume that `loop_index` initialized to + // zero, instead we should explicitly add a memset to clear it. + + auto set_cond_fn = [&](absl::Span handles) { + return Launch(set_for_condition, ThreadDim(), BlockDim(), handles[0], + loop_index, num_iteration); + }; + + auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { + TF_RETURN_IF_ERROR(body_builder(body)); + + // Decide if we want to continue loop iteration. + return body->Launch(set_for_condition, ThreadDim(), BlockDim(), handle, + loop_index, num_iteration); + }; + + std::array builders = {std::move(body)}; + + return CreateConditionalCommand(ConditionType::kWhile, set_cond_fn, builders); } tsl::Status GpuCommandBuffer::Finalize() { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 146e5e3b43c019..d00bb2ab69408d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" @@ -69,6 +70,10 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { tsl::Status Case(StreamExecutor* executor, DeviceMemory index, std::vector branches) override; + tsl::Status For(StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_index, + CommandBuffer::Builder body_builder) override; + tsl::Status Finalize() override; tsl::Status Update() override; @@ -124,10 +129,23 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { GpuGraphConditionalHandle, GpuGraphConditionalHandle, DeviceMemory, int32_t>; + using SetForConditionKernel = + TypedKernel, int32_t>; + // A callback to launch a kernel that updates conditional handles state. using SetConditionFn = std::function)>; + // An extension of `CommandBuffer::Builder` for building conditional command + // buffers tied to conditional handles. + using ConditionBuilder = + std::function; + + // Wraps a regular command buffer builder into condition builder. + static ConditionBuilder ToConditionBuilder(CommandBuffer::Builder builder); + + using ConditionType = typename GpuDriver::GpuGraphConditionalNodeParams::Type; + // Overwrites the `exec_` handle in a Gpu command buffer by `exec`, and // restores to the original handle when destroyed. 
This allows us updating // primary graph executable using nested command buffers (command buffers that @@ -158,20 +176,21 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { CreateConditionalHandles(size_t num_handles); tsl::StatusOr> CreateConditionalNodes( - absl::Span handles); + ConditionType type, absl::Span handles); tsl::StatusOr> CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, - absl::Span builders); + absl::Span builders); tsl::Status UpdateConditionalCommandBuffers( + absl::Span handles, absl::Span command_buffers, - absl::Span builders); + absl::Span builders); tsl::Status CreateConditionalCommand( - SetConditionFn set_condition, - absl::Span builders); + ConditionType type, SetConditionFn set_condition, + absl::Span builders); // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a // dependency between all nodes added to a command buffer. We need a @@ -254,6 +273,7 @@ inline tsl::Status GpuCommandBuffer::Launch( void* GetSetIfConditionKernel(); void* GetSetIfElseConditionKernel(); void* GetSetCaseConditionKernel(); +void* GetSetForConditionKernel(); } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 888d00f218dc96..109a64ace754fa 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -438,7 +438,7 @@ class GpuDriver { struct GpuGraphConditionalNodeParams { // Conditional node type. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g04ade961d0263336423eb216fbe514da - enum class Type { kIf }; + enum class Type { kIf, kWhile }; // A struct for returning output arguments back to the caller. struct Result { diff --git a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc index 669ac500841e96..a8cbe5376ae0b9 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc @@ -34,6 +34,9 @@ void* GetSetIfElseConditionKernel() { void* GetSetCaseConditionKernel() { return reinterpret_cast(&rocm::SetCondition); } +void* GetSetForConditionKernel() { + return reinterpret_cast(&rocm::SetCondition); +} } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index d8a8fd0b9920ec..af5f14427f2b4f 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -167,6 +167,12 @@ class CommandBufferInterface { DeviceMemory index, std::vector branches) = 0; + // Adds a conditional operation that will run a command buffer constructed by + // the `body_builder` exactly `num_iteration` times. + virtual tsl::Status For(StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_index, + CommandBuffer::Builder body_builder) = 0; + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. 
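+  // A minimal usage sketch for the `For` declaration above, assuming a public
+  // `CommandBuffer` wrapper that forwards to this interface and a `kernel`
+  // already loaded by the caller; `loop_index` is assumed to start at zero:
+  //
+  //   TF_RETURN_IF_ERROR(command_buffer.For(
+  //       executor, /*num_iteration=*/4, loop_index,
+  //       [&](CommandBuffer* body) {
+  //         return body->Launch(kernel, ThreadDim(), BlockDim(), loop_index);
+  //       }));
+  //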
virtual tsl::Status Finalize() = 0; From 53c0900fdbb023ef48d582be01633083d98cadf4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 13:28:45 -0800 Subject: [PATCH 108/381] [stream_executor] NFC: Do not leak internal stream executor header PiperOrigin-RevId: 585744507 --- third_party/xla/xla/backends/interpreter/compiler.cc | 2 +- third_party/xla/xla/pjrt/metrics.cc | 2 +- third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc | 2 +- third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h | 1 - .../xla/xla/service/gpu/runtime3/command_buffer_thunk.cc | 2 +- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc index 864d98a8269082..3b89c3b6054de1 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.cc +++ b/third_party/xla/xla/backends/interpreter/compiler.cc @@ -49,7 +49,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/metrics.cc b/third_party/xla/xla/pjrt/metrics.cc index 4dd49508bb5bca..06aaa18121c115 100644 --- a/third_party/xla/xla/pjrt/metrics.cc +++ b/third_party/xla/xla/pjrt/metrics.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/lib/monitoring/counter.h" #include "tsl/lib/monitoring/gauge.h" diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 26039ed5fba41f..97c6e20b551ede 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index 7d9dccf861af76..8477c093c6e6cb 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -35,7 +35,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_pimpl.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc index 08d5cdb9ca74d4..9dca8e0be18fb7 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "xla/statusor.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" From 37efaf3fe7dba6e43970021a689ac668d6cc9963 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 27 Nov 2023 13:44:02 -0800 Subject: [PATCH 109/381] [XLA:GPU] Triton GEMM: enable fusion of inputs concatenated along non-contracting dimensions. PiperOrigin-RevId: 585748647 --- third_party/xla/xla/service/gpu/BUILD | 3 + .../xla/service/gpu/autotuner_compile_util.cc | 2 +- .../xla/service/gpu/gemm_rewriter_triton.cc | 87 +++++- .../xla/service/gpu/gemm_rewriter_triton.h | 1 + .../service/gpu/gemm_rewriter_triton_test.cc | 147 +++++++++ .../xla/xla/service/gpu/ir_emitter_triton.cc | 281 ++++++++++++++---- .../xla/service/gpu/ir_emitter_triton_test.cc | 26 ++ 7 files changed, 488 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index dd6e91db77af96..a5c4e042163ad0 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -474,7 +474,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Linker", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc index fc7d93ae85b572..4350dc235423af 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc @@ -136,7 +136,7 @@ StatusOr> AutotunerCompileUtil::Compile( GenerateModuleFn extractor) { StatusOr> new_hlo_module = extractor(opts_); if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) { - // Incompatible value of split-k is an expected failure. + // Incompatible value of split-k is an example of an expected failure. return std::unique_ptr(); } else if (!new_hlo_module.status().ok()) { return new_hlo_module.status(); diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 573ea2cdc4d999..0737a0a9ce717a 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -255,6 +255,7 @@ using DimOrderMapOrError = std::variant; // This represents an invalid dimension index. constexpr int kNoDimensionIndex = -1; struct DotProperties { + const int noncontracting_dimension; // Index of dot dimension that can be split. // Currently typically LHS non-contracting one. const int splittable_dimension_index; @@ -394,6 +395,25 @@ bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { DimensionOrderToTensorIterationSpec(other); } +// Logical index of a dimension in `shape` labeled with `label` in the +// `dim_order` describing the shape. 
+std::optional LogicalIndexOfLabeledDimension( + const Shape& shape, const DimensionOrder& dim_order, const int label) { + auto fragment_it = dim_order.TensorFragmentsOrder().cbegin(); + for (int dim : shape.layout().minor_to_major()) { + const int64_t dim_size = shape.dimensions()[dim]; + int64_t fragments_size = 1; + while (fragments_size < dim_size) { + fragments_size *= fragment_it->full_size(); + if (fragment_it->dst_dim_number() == label) { + return dim; + } + ++fragment_it; + } + } + return std::nullopt; +} + enum class TransformDirection { kInputToOutput, kOutputToInput }; using OldToNewHloMap = @@ -516,8 +536,11 @@ FusionContext FusionContext::FromDotOperand(const HloInstruction& dot, splittable_dimension_index = NonContractingDimensionIndex(dot, operand_number); } - FusionContext context(DotProperties{splittable_dimension_index}, - DotRequirements(kNoSplitRequirement)); + FusionContext context( + DotProperties{ + static_cast(NonContractingDimensionIndex(dot, operand_number)), + splittable_dimension_index}, + DotRequirements(kNoSplitRequirement)); context.dim_orders_[dot.operand(operand_number)] = DimensionOrder::FromDotOperandOrOutput(*dot.operand(operand_number), split_k_dimension_index); @@ -536,7 +559,8 @@ FusionContext FusionContext::FromDotOutput( // LHS non-contracting follows (batch is absent in this case). splittable_dimension_index = (split_k > 1) ? 1 : 0; } - FusionContext context(DotProperties{splittable_dimension_index}, + FusionContext context(DotProperties{/*noncontracting_dimension=*/-1, + splittable_dimension_index}, DotRequirements(splittable_dimension_major_part_size)); context.dim_orders_[&dot] = DimensionOrder::FromDotOperandOrOutput(dot); return context; @@ -943,6 +967,18 @@ FusionContext::GetPropagatedDimOrdersForDimAlteringOp( dst_logical[i] = src_logical[i]; } } + } else if (hlo.opcode() == HloOpcode::kConcatenate) { + dst_logical.resize(src_logical.size()); + for (int i = 0; i < src_logical.size(); ++i) { + dst_logical[i] = src_logical[i]; + if (i == hlo.concatenate_dimension()) { + if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { + return FusionDecision("Unsupported concatenation."); + } + dst_logical[i][0]->set_size(dst->shape().dimensions(i)); + dst_logical[i][0]->set_slice(0, dst->shape().dimensions(i)); + } + } } else if (hlo.opcode() == HloOpcode::kCopy) { // Copy preserves the logical shape, just permutes the layout. 
CHECK(ShapeUtil::SameDimensions(src.shape(), dst->shape())); @@ -1045,6 +1081,13 @@ FusionContext::GetPropagatedDimOrdersForDimAlteringOp( const HloInstruction& hlo, const TransformDirection direction, const DimensionOrder& src_dim_order, const HeroProperties& properties) { VLOG(7) << "Analyzing " << hlo.ToString(); + if (hlo.opcode() != HloOpcode::kParameter && + direction == TransformDirection::kOutputToInput && + absl::c_any_of(hlo.users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kConcatenate; + })) { + return "No fusion into concatenations"; + } if (hlo.opcode() == HloOpcode::kParameter || hlo_query::IsScalarConstant(&hlo)) { CHECK(direction == TransformDirection::kOutputToInput); @@ -1093,6 +1136,40 @@ FusionContext::GetPropagatedDimOrdersForDimAlteringOp( } return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, properties); + } else if (hlo.opcode() == HloOpcode::kConcatenate && + direction == TransformDirection::kOutputToInput) { + if (!std::holds_alternative(properties)) { + return "Concatenations for now are only supported in GEMM fusions."; + } + auto dim = LogicalIndexOfLabeledDimension( + hlo.shape(), src_dim_order, + std::get(properties).noncontracting_dimension); + if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { + return "Unsupported concatenation."; + } + if (absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) { + return operand->user_count() > 1; + })) { + return FusionDecision( + "Concatenation has to be the only user of its inputs."); + } + if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { + // In the current simple implementation of concatenation the size of + // each of its inputs along the concatenated dimension has to be + // divisible by the tile size used for this dimension. Concatenations + // with any operand not divisible by kMinConcatFragmentSize will not + // be fused; tiling configurations with tile size for this dimension + // larger than kMinConcatFragmentSize will not be emitted. + constexpr int kMinConcatFragmentSize = 128; + return operand->shape().dimensions(hlo.concatenate_dimension()) % + kMinConcatFragmentSize != + 0; + })) { + return FusionDecision( + "One or more operands of concatenation can not be perfectly tiled."); + } + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); } return "Unimplemented instruction."; } @@ -1636,7 +1713,9 @@ Status FusionContext::PropagateDimensionOrdersToParameters( DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( *hlo, dim_orders_.at(hlo), TransformDirection::kOutputToInput, properties_); - TF_RET_CHECK(std::holds_alternative(result)); + if (std::holds_alternative(result)) { + LOG(FATAL) << std::get(result).Explain(); + } TF_RET_CHECK(CombineDimOrdersAndReqs(std::get(result))); iter_specs[hlo] = DimensionOrderToTensorIterationSpec(dim_orders_.at(hlo)); for (const HloInstruction* operand : hlo->operands()) { diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h index 2f15dadaa883a0..a95b2f02f920ea 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h @@ -16,6 +16,7 @@ limitations under the License. 
#define XLA_SERVICE_GPU_GEMM_REWRITER_TRITON_H_ #include +#include #include #include diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index ab19d7809e98e9..9a9c0bef826420 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -1408,6 +1408,153 @@ ENTRY e { /*subfragments=*/ElementsAre(7)))); } +TEST_F(GemmRewriterTritonLevel2Test, FusedConcatenationIsAnalyzedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[153,1536] parameter(0) + p1 = s8[153,128] parameter(1) + p2 = s8[153,256] parameter(2) + cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1} + cvt = bf16[153,1920] convert(cat) + p3 = bf16[16,153] parameter(3) + ROOT d = bf16[16,1920] dot(p3, cvt), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter())))); + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); + + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(0), 0), + ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153, + /*slice_start=*/0, /*slice_limit=*/153, + /*subfragments=*/ElementsAre(153)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(0), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536, + /*slice_start=*/0, /*slice_limit=*/1536, + /*subfragments=*/ElementsAre(1536)))); + + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(1), 0), + ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153, + /*slice_start=*/0, /*slice_limit=*/153, + /*subfragments=*/ElementsAre(153)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(1), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128, + /*slice_start=*/0, /*slice_limit=*/128, + /*subfragments=*/ElementsAre(128)))); + + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(2), 0), + ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153, + /*slice_start=*/0, /*slice_limit=*/153, + /*subfragments=*/ElementsAre(153)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(2), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256, + /*slice_start=*/0, /*slice_limit=*/256, + /*subfragments=*/ElementsAre(256)))); +} + +TEST_F(GemmRewriterTritonLevel2Test, IndivisibleConcatenationIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[124,1024] parameter(0) + p1 = s8[124,1001] parameter(1) + cat = s8[124,2025] concatenate(p0, p1), dimensions={1} + cvt = f16[124,2025] convert(cat) + p2 = f16[123,124] parameter(2) + ROOT d = f16[2025,123] dot(cvt, p2), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + 
.value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); +} + +TEST_F(GemmRewriterTritonLevel2Test, ConcatenationOfContractingIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[124,1024] parameter(0) + p1 = s8[124,1024] parameter(1) + cat = s8[124,2048] concatenate(p0, p1), dimensions={1} + cvt = f16[124,2048] convert(cat) + p2 = f16[123,2048] parameter(2) + ROOT d = f16[124,123] dot(cvt, p2), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); +} + +TEST_F(GemmRewriterTritonLevel2Test, ConcatenationOfBatchIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[124,1024,50] parameter(0) + p1 = s8[124,1024,50] parameter(1) + cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1} + cvt = f16[124,2048,50] convert(cat) + p2 = f16[123,2048,50] parameter(2) + ROOT d = f16[2048,124,123] dot(cvt, p2), + lhs_batch_dims={1}, rhs_batch_dims={1}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} +})")); + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); +} + +TEST_F(GemmRewriterTritonLevel2Test, + TwoConcatenationsOfSameParametersAreNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[128,2] parameter(0) + p1 = s8[128,2] parameter(1) + cat0 = s8[256,2] concatenate(p0, p1), dimensions={0} + cvt0 = f16[256,2] convert(cat0) + cat1 = s8[256,2] concatenate(p1, p0), dimensions={0} + n1 = s8[256,2] negate(cat1) + cvt1 = f16[256,2] convert(n1) + a = f16[256,2] add(cvt1, cvt0) + p2 = f16[2,18] parameter(2) + ROOT d = f16[18,256] dot(p2, a), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Concatenate(), + m::Parameter())))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index ab8d1c8cf59b6c..7f5d4ff2b412f9 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -32,7 +32,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" @@ -666,8 +669,14 @@ StatusOr EmitScope( absl::flat_hash_map& values) { for (const HloInstruction* hlo : instructions) { Value result; - if (hlo->opcode() == HloOpcode::kParameter) { - // Parameter loads are handled outside EmitScope. 
+ if (hlo->opcode() == HloOpcode::kConcatenate) { + // Parameter loads and their concatenations are handled outside EmitScope. + TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); + continue; + } else if (hlo->opcode() == HloOpcode::kParameter) { + if (hlo->users()[0]->opcode() == HloOpcode::kConcatenate) { + continue; + } TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; } else if (hlo->opcode() == HloOpcode::kConstant) { @@ -1072,6 +1081,33 @@ struct Side { std::optional batch_dim_idx; }; +// if (index < limits[0]) { +// return choices[0]; +// } else if (index < limits[1]) { +// return choices[1]; +// } else if (...) { +// ... +// } else { +// return choices.back(); +// } +Value EmitMultiSelect(ImplicitLocOpBuilder b, Value index, ValueRange limits, + ValueRange choices) { + CHECK_EQ(choices.size() - 1, limits.size()); + Value result = choices[0]; + for (int i = 0; i < choices.size() - 1; ++i) { + result = b.create( + b.create(ma::CmpIPredicate::slt, index, limits[i]), result, + choices[i + 1]); + } + return result; +} + +Status UncompilableMatmul(absl::string_view explanation) { + Status s = absl::CancelledError(explanation); + s.SetPayload(kUncompilableFusion, absl::Cord(explanation)); + return s; +} + class MatMulEmitterHelper { public: MatMulEmitterHelper(absl::string_view libdevice_path, @@ -1148,49 +1184,140 @@ class MatMulEmitterHelper { values); } - Value EmitTensorPointer(const HloInstruction* hlo, const Side& side, - Value base, Value pid_k, - std::vector& boundary_checks) { - auto pid_batch = - b_.create(launch_config_.batch_program_id_dim); - + StatusOr EmitTensorPointer(const HloInstruction* hlo, const Side& side, + ValueRange bases, Value pid_k, + std::vector& boundary_checks) { + // Parameters of MakeTensorPtrOp to be generated by this function. + Value base; std::vector bounds; std::vector strides; // Offsets from tensor origin, same for all thread blocks. std::vector tensor_offsets; - // Offsets for a given thread block, typically pid * block size. - std::vector block_offsets; std::vector block_dims; std::vector dim_order; + // Offsets for a given thread block, typically pid * block size. + // Used in a one-off AdvanceOp applied to the generated MakeTensorPtrOp. + std::vector block_offsets; + + // Concatenations of parameters are handled during generation of block + // pointers because of a limitation of implementation of block pointers + // in the Triton compiler: block pointers are not supported inside + // conditionals. + // Therefore instead of directly using a conditional to emit a concatenation + // and emitting its inputs inside the cases a single block pointer is + // emitted for all inputs, but all its properties (base, strides etc) get + // generated conditionally on the position of the current thread block + // within the concatenated dimension. + + // Index of concatenated dimension if present, -1 otherwise. + int concat_dim_idx; + // Offsets along the concatenated dimension at which operands change. + std::vector concat_boundaries; + // Block index along the concatenated dimension * block size. + Value concat_dim_pid_offset; + + if (hlo->opcode() == HloOpcode::kConcatenate) { + // For now only non-contracting dimension can be concatenated. + concat_dim_idx = (side.scope == TritonFusionAnalysis::Scope::LHS) + ? 
dims_.lhs_noncontracting_dim_idx + : dims_.rhs_noncontracting_dim_idx; + const DimProperties& properties = [&] { + for (const DimProperties& dim : side.tiled_dims) { + if (dim.index == concat_dim_idx) { + return dim; + } + } + LOG(FATAL) << "Missing dimension."; + }(); + CHECK_EQ(bases.size(), hlo->operand_count()); + + concat_boundaries.reserve(hlo->operand_count() - 1); + int64_t accumulated_size = 0; + for (int i = 0; i < hlo->operand_count() - 1; ++i) { + const int64_t operand_size = + analysis_.IterSpec(side.scope, hlo->operand(i), concat_dim_idx) + ->at(0) + .count; + if (operand_size % properties.block_size != 0) { + return UncompilableMatmul( + "Operand is not divisible by the block size."); + } + accumulated_size += operand_size; + concat_boundaries.push_back(Cst32(accumulated_size)); + } + + concat_dim_pid_offset = + b_.create(properties.pid, Cst32(properties.block_size)); + base = + EmitMultiSelect(b_, concat_dim_pid_offset, concat_boundaries, bases); + } else { + concat_dim_idx = -1; + base = bases[0]; + } + auto add_dim = [&](const DimProperties& properties) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis_.IterSpec(side.scope, hlo, properties.index); - if (spec == nullptr) { + if (analysis_.IterSpec(side.scope, hlo, properties.index) == nullptr) { return; } - int64_t count = spec->at(0).count; - if (side.scope == TritonFusionAnalysis::Scope::OUTPUT && - properties.index == dims_.out_lhs_noncontracting_dim_idx && - spec->size() == 1 && dims_.lhs_noncontracting_split.has_value()) { - // Dimension of the output produced by the non-contracting LHS one - // is logically split, major part is addressed using pid_batch. - count /= *dims_.lhs_noncontracting_split; - } - if (count % (properties.block_size * properties.split_value) != 0) { - boundary_checks.push_back(bounds.size()); - } - bounds.push_back(Cst64(count)); - strides.push_back(Cst64(spec->at(0).stride)); - block_offsets.push_back( + Value pid_offset = (properties.pid == nullptr) ? Cst32(0) : b_.create(properties.pid, - Cst32(properties.block_size))); - tensor_offsets.push_back(Cst32(spec->at(0).slice_start)); + Cst32(properties.block_size)); + std::vector inputs; + if (hlo->opcode() == HloOpcode::kConcatenate) { + inputs.insert(inputs.end(), hlo->operands().cbegin(), + hlo->operands().cend()); + } else { + inputs = {hlo}; + } + std::vector specs; + std::vector input_strides; + std::vector input_offsets; + std::vector input_bounds; + specs.reserve(inputs.size()); + input_strides.reserve(inputs.size()); + input_offsets.reserve(inputs.size()); + input_bounds.reserve(inputs.size()); + for (const HloInstruction* input : inputs) { + specs.push_back( + analysis_.IterSpec(side.scope, input, properties.index)); + input_strides.push_back(Cst64(specs.back()->at(0).stride)); + input_offsets.push_back(b_.create( + pid_offset, input_offsets.empty() + ? 
Cst32(0) + : concat_boundaries[input_offsets.size() - 1])); + input_bounds.push_back(Cst64(specs.back()->at(0).count)); + } + strides.push_back(EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, input_strides)); + if (properties.index == concat_dim_idx) { + block_offsets.push_back( + EmitMultiSelect(b_, pid_offset, concat_boundaries, input_offsets)); + bounds.push_back( + EmitMultiSelect(b_, pid_offset, concat_boundaries, input_bounds)); + } else { + block_offsets.push_back(pid_offset); + int64_t count = specs.back()->at(0).count; + if (side.scope == TritonFusionAnalysis::Scope::OUTPUT && + properties.index == dims_.out_lhs_noncontracting_dim_idx && + specs.back()->size() == 1 && + dims_.lhs_noncontracting_split.has_value()) { + // Dimension of the output produced by the non-contracting LHS one + // is logically split, major part is addressed using pid_batch. + count /= *dims_.lhs_noncontracting_split; + } + bounds.push_back(Cst64(count)); + if (count % (properties.block_size * properties.split_value) != 0) { + boundary_checks.push_back(bounds.size() - 1); + } + } + tensor_offsets.push_back(Cst32(specs.back()->at(0).slice_start)); block_dims.push_back(properties.block_size); dim_order.emplace(dim_order.begin(), dim_order.size()); }; + for (const DimProperties& dim : side.tiled_dims) { add_dim(dim); } @@ -1225,6 +1352,8 @@ class MatMulEmitterHelper { } } if (stride_batch != 0) { + Value pid_batch = + b_.create(launch_config_.batch_program_id_dim); Value pid_offset_batch = b_.create( b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), Cst(stride_batch)); @@ -1243,6 +1372,7 @@ class MatMulEmitterHelper { } if (block_dims.empty()) { + // Load of a scalar. return base; } auto tensor_ptr = @@ -1295,6 +1425,46 @@ LaunchDimensions GetMatMulLaunchDimensions(const TritonFusionAnalysis& analysis, return launch_config.launch_dims; } +SmallVector GetArguments(mlir::triton::FuncOp fn, + const HloInstruction& input) { + if (input.opcode() == HloOpcode::kParameter) { + return {fn.getArgument(input.parameter_number())}; + } else if (input.opcode() == HloOpcode::kConcatenate) { + SmallVector result; + for (const HloInstruction* operand : input.operands()) { + result.push_back(fn.getArgument(operand->parameter_number())); + } + return result; + } + LOG(FATAL) << "Unexpected opcode: " << input.opcode(); +} + +// Concatenations can currently only be applied directly to parameters; +// all concatenated parameters share the same block pointer. This function +// returns all inputs of a kernel: concatenations of parameters and standalone +// parameters. +ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, + const TritonFusionAnalysis::Scope scope) { + ConstHloInstructionSet result; + for (const HloInstruction* parameter : analysis.ScopeParameters(scope)) { + if (absl::c_any_of(parameter->users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kConcatenate; + })) { + // Concatenation is always the only user of its parameters by + // construction. + CHECK_EQ(parameter->users().size(), 1); + for (const HloInstruction* operand : parameter->users()[0]->operands()) { + // All operands of a concatenation have to be computation parameters. + CHECK_EQ(operand->opcode(), HloOpcode::kParameter); + } + result.insert(parameter->users()[0]); + } else { + result.insert(parameter); + } + } + return result; +} + // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. 
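+// Worked example (a sketch based on the shapes used in the accompanying
+// tests): for operands s8[153,1536], s8[153,128] and s8[153,256] concatenated
+// along the non-contracting dimension into s8[153,1920], `concat_boundaries`
+// holds the running sizes {1536, 1664}, so
+// EmitMultiSelect(pid * block_size, {1536, 1664}, {base0, base1, base2})
+// yields base0 while the block offset is below 1536, base1 below 1664, and
+// base2 otherwise. The block size along this dimension must divide every
+// operand size (here 1536, 128 and 256); otherwise UncompilableMatmul is
+// returned.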
Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, const TritonFusionAnalysis& analysis, @@ -1365,7 +1535,7 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, // Parameters are passed to the loop in non-trivial order, these maps help // finding them and their attributes. - absl::flat_hash_map iter_args_to_parameters; + absl::flat_hash_map iter_args_to_inputs; absl::flat_hash_map> iter_args_to_boundary_checks; Side lhs{ @@ -1399,11 +1569,11 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, // Load tiles of all parameters of LHS and RHS scopes and advance pointers. for (int i = 0; i < iter_args.size() - 1; ++i) { const bool is_lhs = - i < analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(); + i < ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); Side& side = is_lhs ? lhs : rhs; auto& values = is_lhs ? values_lhs : values_rhs; - const HloInstruction* param_hlo = iter_args_to_parameters[i]; + const HloInstruction* param_hlo = iter_args_to_inputs[i]; Type param_ty = TritonType(b, param_hlo->shape().element_type()); Type param_storage_ty = StorageType(b, param_ty); Value param_value = @@ -1416,8 +1586,8 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, CHECK(values.insert({param_hlo, param_value}).second); SmallVector increments; for (const DimProperties& dim : side.tiled_dims) { - const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( - side.scope, iter_args_to_parameters[i], dim.index); + const TensorIterationSpec::DimIterationSpec* spec = + analysis.IterSpec(side.scope, iter_args_to_inputs[i], dim.index); if (spec == nullptr || spec->at(0).stride == 0) { continue; } @@ -1494,21 +1664,21 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, b.create(iter_args_next); }; - // Pointers to parameters of LHS scope, then RHS, then the accumulator + // Pointers to inputs of LHS scope, then RHS, then the accumulator // that change with every loop iteration and are passed between them. - // LHS and RHS can use same HLO computation parameters, but because they use - // different pointers they have to be stored separately for each scope. 
SmallVector iter_args; iter_args.reserve( - analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size() + - analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size() + 1); + ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size() + + ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size() + 1); for (const Side& side : {lhs, rhs}) { - for (const HloInstruction* param : analysis.ScopeParameters(side.scope)) { - CHECK(iter_args_to_parameters.insert({iter_args.size(), param}).second); - iter_args.push_back(emitter.EmitTensorPointer( - param, side, fn.getArgument(param->parameter_number()), pid_k, - iter_args_to_boundary_checks[iter_args.size()])); + for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { + CHECK(iter_args_to_inputs.insert({iter_args.size(), input}).second); + TF_ASSIGN_OR_RETURN(Value tensor_ptr, + emitter.EmitTensorPointer( + input, side, GetArguments(fn, *input), pid_k, + iter_args_to_boundary_checks[iter_args.size()])); + iter_args.push_back(tensor_ptr); } } @@ -1527,14 +1697,15 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, if (std::vector to_emit = emitter.EpiloguePostOrderTransitiveOperands(root); !to_emit.empty()) { - for (const HloInstruction* parameter : - analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT)) { + for (const HloInstruction* input : + ScopeInputs(analysis, TritonFusionAnalysis::Scope::OUTPUT)) { std::vector boundary_checks; - Value tensor_pointer = emitter.EmitTensorPointer( - parameter, out, fn.getArgument(parameter->parameter_number()), pid_k, - boundary_checks); + TF_ASSIGN_OR_RETURN( + Value tensor_pointer, + emitter.EmitTensorPointer(input, out, GetArguments(fn, *input), pid_k, + boundary_checks)); CHECK(values_out - .insert({parameter, + .insert({input, EmitParameterLoad(b, tensor_pointer, boundary_checks)}) .second); } @@ -1550,10 +1721,12 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, const HloInstruction* producer = root->shape().IsTuple() ? 
root->operand(i) : root; std::vector boundary_checks; - Value tensor_pointer = emitter.EmitTensorPointer( - producer, out, - fn.getArgument(i + dot_instr->parent()->num_parameters()), pid_k, - boundary_checks); + TF_ASSIGN_OR_RETURN( + Value tensor_pointer, + emitter.EmitTensorPointer( + producer, out, + {fn.getArgument(i + dot_instr->parent()->num_parameters())}, pid_k, + boundary_checks)); b.create(tensor_pointer, values_out[producer], boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 534305831c2a7c..890ddb3ece7381 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -1745,6 +1745,32 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); } +TEST_F(TritonGemmLevel2Test, FuseConcatenation) { + const std::string kHloText = R"( +e { + p0 = s8[153,1536] parameter(0) + p1 = s8[153,128] parameter(1) + p2 = s8[153,128] parameter(2) + cat = s8[153,1792] concatenate(p0, p1, p2), dimensions={1} + cvt = bf16[153,1792] convert(cat) + p3 = bf16[16,153] parameter(3) + ROOT d = bf16[16,1792] dot(p3, cvt), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), + m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheLeft) { constexpr absl::string_view kHloText = R"( HloModule t From bf9bc93141c1f15a56e00a36350f61de9f717997 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 13:48:08 -0800 Subject: [PATCH 110/381] Deduplicate some code by using a refactored portion of MaybeFollowInsStrategyGroup to generate strategies for elementwise ops as well. PiperOrigin-RevId: 585749608 --- .../auto_sharding/auto_sharding.cc | 130 +++++++++--------- 1 file changed, 62 insertions(+), 68 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 3e5411a6a685f9..6750ec3e374ba9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -231,6 +231,59 @@ GenerateReshardingCostsAndShardingsForAllOperands( return {resharding_costs, input_shardings_optional}; } +// When computing resharding costs for inputs, this function assumes that the +// shape of the input is the same as the shape of the output (ie. the `shape` +// operand to the function) +void FollowArrayOrTokenStrategyGroup( + const StrategyGroup& src_strategy_group, const Shape& shape, + size_t instruction_id, bool have_memory_cost, + const ClusterEnvironment& cluster_env, + StableHashMap>& + pretrimmed_strategy_map, + StrategyGroup& strategy_group) { + CHECK(shape.IsArray() || shape.IsToken()); + + // Only follows the given strategy when there is no other strategy to be + // restored. 
+ if (!pretrimmed_strategy_map.contains(src_strategy_group.node_idx)) { + strategy_group.following = &src_strategy_group; + } + strategy_group.strategies.reserve(src_strategy_group.strategies.size()); + // Creates the sharding strategies and restores trimmed strategies, if any. + std::vector& pretrimmed_strategies = + pretrimmed_strategy_map[src_strategy_group.node_idx]; + for (int64_t sid = 0; sid < src_strategy_group.strategies.size() + + pretrimmed_strategies.size(); + ++sid) { + const HloSharding* output_spec; + if (sid < src_strategy_group.strategies.size()) { + output_spec = &src_strategy_group.strategies[sid].output_sharding; + } else { + output_spec = + &pretrimmed_strategies[sid - src_strategy_group.strategies.size()] + .output_sharding; + VLOG(1) << "Adding outspec from the trimmed strategy map: " + << output_spec->ToString(); + } + std::string name = ToStringSimple(*output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = + have_memory_cost ? GetBytes(shape) / output_spec->NumTiles() : 0; + size_t num_in_nodes = strategy_group.in_nodes.size(); + std::vector> input_shardings(num_in_nodes, + *output_spec); + std::vector> resharding_costs; + for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) { + resharding_costs.push_back(ReshardingCostVector( + strategy_group.in_nodes[i], shape, *output_spec, cluster_env)); + } + + strategy_group.strategies.push_back( + ShardingStrategy({name, *output_spec, compute_cost, communication_cost, + memory_cost, resharding_costs, input_shardings})); + } +} + std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, size_t instruction_id, bool have_memory_cost, @@ -252,49 +305,12 @@ std::unique_ptr MaybeFollowInsStrategyGroup( strategy_group->childs.push_back(std::move(child_strategies)); } } else { - CHECK(shape.IsArray() || shape.IsToken()); strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); strategy_group->in_nodes.push_back(src_strategy_group); - // Only follows the given strategy when there is no other strategy to be - // restored. - if (!pretrimmed_strategy_map.contains(src_strategy_group->node_idx)) { - strategy_group->following = src_strategy_group; - } - strategy_group->strategies.reserve(src_strategy_group->strategies.size()); - // Creates the sharding strategies and restores the trimmed strategies if - // there is any. - for (int64_t sid = 0; - sid < src_strategy_group->strategies.size() + - pretrimmed_strategy_map[src_strategy_group->node_idx].size(); - ++sid) { - const HloSharding* output_spec; - if (sid < src_strategy_group->strategies.size()) { - output_spec = &src_strategy_group->strategies[sid].output_sharding; - } else { - output_spec = - &pretrimmed_strategy_map[src_strategy_group->node_idx] - [sid - - src_strategy_group->strategies.size()] - .output_sharding; - VLOG(1) << "Adding outspec from the trimmed strategy map: " - << output_spec->ToString(); - } - std::string name = ToStringSimple(*output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = - have_memory_cost ? 
GetBytes(shape) / output_spec->NumTiles() : 0; - auto resharding_costs = ReshardingCostVector(src_strategy_group, shape, - *output_spec, cluster_env); - strategy_group->strategies.push_back( - ShardingStrategy({name, - *output_spec, - compute_cost, - communication_cost, - memory_cost, - {std::move(resharding_costs)}, - {*output_spec}})); - } + FollowArrayOrTokenStrategyGroup(*src_strategy_group, shape, instruction_id, + have_memory_cost, cluster_env, + pretrimmed_strategy_map, *strategy_group); } return strategy_group; } @@ -1528,7 +1544,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, - const StableHashMap>& + StableHashMap>& pretrimmed_strategy_map, int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs) { @@ -1555,37 +1571,15 @@ std::unique_ptr CreateElementwiseOperatorStrategies( continue; } - auto process_src_strategy_group = - [&](const std::vector& src_strategies) { - for (int64_t sid = 0; sid < src_strategies.size(); ++sid) { - HloSharding output_spec = src_strategies[sid].output_sharding; - std::string name = ToStringSimple(output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = - GetBytes(ins->shape()) / output_spec.NumTiles(); - std::vector> resharding_costs; - std::vector> input_shardings; - for (int64_t k = 0; k < ins->operand_count(); ++k) { - resharding_costs.push_back(ReshardingCostVector( - strategy_map.at(ins->operand(k)).get(), - ins->operand(k)->shape(), output_spec, cluster_env)); - input_shardings.push_back(output_spec); - } - - strategy_group->strategies.push_back(ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, - memory_cost, std::move(resharding_costs), input_shardings})); - } - }; StrategyGroup* src_strategy_group = strategy_map.at(ins->operand(i)).get(); CHECK(!src_strategy_group->is_tuple); - process_src_strategy_group(src_strategy_group->strategies); - if (pretrimmed_strategy_map.contains(src_strategy_group->node_idx)) { - process_src_strategy_group( - pretrimmed_strategy_map.at(src_strategy_group->node_idx)); - } + FollowArrayOrTokenStrategyGroup(*src_strategy_group, ins->shape(), + instruction_id, + /* have_memory_cost */ true, cluster_env, + pretrimmed_strategy_map, *strategy_group); } + if (ins->opcode() == HloOpcode::kAdd) { // Adjust the resharding costs for AllReduceReassociate pass. // The AllReduceReassociate pass can simplify From 6150be321640fe12c5706adc9ddb2352f9892ff3 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 27 Nov 2023 13:49:22 -0800 Subject: [PATCH 111/381] Factor out `CreateTmpDir` out to separate file from TF Quantizer and StableHLO Quantizer. 
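The new helper returns absl::StatusOr<std::string> instead of throwing, so
callers propagate the error explicitly. A minimal usage sketch (the SaveTo
call is a hypothetical consumer, not part of this change):

    #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h"

    using ::stablehlo::quantization::io::CreateTmpDir;

    // CreateTmpDir() uses tsl::Env::Default(); pass a tsl::Env* explicitly to
    // override the environment (e.g. in tests).
    const absl::StatusOr<std::string> tmp_dir = CreateTmpDir();
    if (!tmp_dir.ok()) {
      return tmp_dir.status();
    }
    SaveTo(*tmp_dir);  // Hypothetical consumer of the created directory.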
PiperOrigin-RevId: 585749900 --- .../mlir/quantization/stablehlo/cc/BUILD | 30 +++++ .../mlir/quantization/stablehlo/cc/io.cc | 42 +++++++ .../mlir/quantization/stablehlo/cc/io.h | 37 ++++++ .../mlir/quantization/stablehlo/cc/io_test.cc | 113 ++++++++++++++++++ .../mlir/quantization/stablehlo/python/BUILD | 1 + .../stablehlo/python/pywrap_quantization.cc | 43 ++++--- .../mlir/quantization/tensorflow/python/BUILD | 3 +- .../python/pywrap_quantize_model.cc | 44 ++++--- 8 files changed, 265 insertions(+), 48 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index ec2bfa892e0c66..f63ffe76399657 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load( "//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", @@ -22,3 +23,32 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "io", + srcs = ["io.cc"], + hdrs = ["io.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:env", + ], +) + +tf_cc_test( + name = "io_test", + srcs = ["io_test.cc"], + deps = [ + ":io", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:types", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc new file mode 100644 index 00000000000000..5760f571db830b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "tsl/platform/env.h" + +namespace stablehlo::quantization::io { + +absl::StatusOr CreateTmpDir(tsl::Env* const env) { + std::string tmp_dir; + env->LocalTempFilename(&tmp_dir); + if (!env->RecursivelyCreateDir(tmp_dir).ok()) { + return absl::InternalError( + absl::StrFormat("Failed to create tmp dir: '%s'", tmp_dir)); + } + + return tmp_dir; +} + +absl::StatusOr CreateTmpDir() { + // The overloaded function uses the default env. + return CreateTmpDir(tsl::Env::Default()); +} + +} // namespace stablehlo::quantization::io diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h new file mode 100644 index 00000000000000..11c2e3c949f92b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ + +#include + +#include "absl/status/statusor.h" +#include "tsl/platform/env.h" + +namespace stablehlo::quantization::io { + +// Creates a temporary directory on an environment defined by the implementation +// of `tsl::Env` and returns its path. Returns an InternalError status if +// failed. +absl::StatusOr CreateTmpDir(tsl::Env* env); + +// Creates a temporary directory and returns its path. Returns an InternalError +// status if failed. The file system used will be the default environment +// returned by `tsl::Env::Default`. +absl::StatusOr CreateTmpDir(); + +} // namespace stablehlo::quantization::io + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc new file mode 100644 index 00000000000000..72f004a05edbd8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" + +#include +#include +#include + +#include +#include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/status.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/types.h" + +namespace { + +using ::stablehlo::quantization::io::CreateTmpDir; +using ::testing::HasSubstr; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +// A test-only derived class of `tsl::Env` which is broken. Used to cause +// failure for the `CreateTmpDir` function. Each of the overridden member +// functions implements a dummy functionality just to be able to create an +// instance of this class. +class TestEnvBrokenFileSystem : public tsl::Env { + public: + TestEnvBrokenFileSystem() = default; + + bool MatchPath(const tsl::string& path, const tsl::string& pattern) override { + return false; + } + + void SleepForMicroseconds(int64_t micros) override {} + + tsl::string GetRunfilesDir() override { return tsl::string("dummy_path"); } + + int32_t GetCurrentThreadId() override { return 0; } + + tsl::Thread* StartThread(const tsl::ThreadOptions& thread_options, + const tsl::string& name, + absl::AnyInvocable fn) override { + return nullptr; + } + + bool GetCurrentThreadName(tsl::string* name) override { return false; } + + void SchedClosure(absl::AnyInvocable closure) override {} + + void SchedClosureAfter(int64_t micros, + absl::AnyInvocable closure) override {} + + absl::Status LoadDynamicLibrary(const char* library_filename, + void** handle) override { + return tsl::OkStatus(); + } + + absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) override { + return tsl::OkStatus(); + } + + tsl::string FormatLibraryFileName(const tsl::string& name, + const tsl::string& version) override { + return tsl::string("dummy_path"); + } + + absl::Status GetFileSystemForFile(const std::string& fname, + tsl::FileSystem** result) override { + return absl::InternalError("Broken file system"); + } + + private: + // This is the part that essentially breaks the `CreateTmpDir` function + // because it doesn't provide any available temp dirs. 
+ void GetLocalTempDirectories(std::vector* list) override {} +}; + +TEST(IoTest, CreateTmpDirReturnsValidTmpPath) { + absl::StatusOr tmp_dir = CreateTmpDir(); + + ASSERT_THAT(tmp_dir, IsOk()); + + auto* const env = tsl::Env::Default(); + EXPECT_THAT(env->FileExists(*tmp_dir), IsOk()); +} + +TEST(IoTest, CreateTmpDirWhenInvalidPathReturnsInternalError) { + TestEnvBrokenFileSystem test_env{}; + absl::StatusOr tmp_dir = CreateTmpDir(&test_env); + + EXPECT_THAT(tmp_dir, StatusIs(absl::StatusCode::kInternal, + HasSubstr("Failed to create tmp dir"))); +} + +} // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 996af258bac493..4ef4ad4ce1a720 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -85,6 +85,7 @@ tf_python_pybind_extension( srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index 76453b285eca66..ce66277ddd0e5d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -25,6 +25,7 @@ limitations under the License. #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" @@ -36,27 +37,13 @@ namespace py = pybind11; namespace { +using ::stablehlo::quantization::io::CreateTmpDir; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::QuantizationOptions; -// TODO: b/307624867 - Factor out this function to a separate file. -// Creates a temporary directory and returns its path. -std::string CreateTmpDir() { - tsl::Env* const env = tsl::Env::Default(); - - std::string tmp_dir; - env->LocalTempFilename(&tmp_dir); - if (!env->RecursivelyCreateDir(tmp_dir).ok()) { - throw py::value_error( - absl::StrFormat("Failed to create tmp dir: '%s'", tmp_dir)); - } - - return tmp_dir; -} - // TODO: b/312371048 - Factor out this function to a separate file. // Enables debugging on `exported_model` by updating the `DumpTensor` ops. 
// @@ -131,15 +118,21 @@ PYBIND11_MODULE(pywrap_quantization, m) { const ExportedModel exported_model_ids_assigned = py_function_library.AssignIdsToCustomAggregatorOps(*exported_model); - const std::string precalibrated_saved_model_dir = CreateTmpDir(); + const absl::StatusOr precalibrated_saved_model_dir = + CreateTmpDir(); + if (!precalibrated_saved_model_dir.ok()) { + throw py::value_error(absl::StrFormat( + "Failed to create tmp dir for precalibrated saved model: %s", + precalibrated_saved_model_dir.status().ToString())); + } py_function_library.SaveExportedModel( - precalibrated_saved_model_dir, exported_model_ids_assigned, + *precalibrated_saved_model_dir, exported_model_ids_assigned, src_saved_model_path, tags, signature_def_map); ExportedModel calibrated_exported_model = py_function_library.RunCalibration( - precalibrated_saved_model_dir, signature_keys, tags, + *precalibrated_saved_model_dir, signature_keys, tags, exported_model_ids_assigned, quantization_options.calibration_options(), quantization_options.force_graph_mode_calibration(), @@ -152,10 +145,16 @@ PYBIND11_MODULE(pywrap_quantization, m) { src_saved_model_path, tags, signature_def_map); } - const std::string calibrated_saved_model_path = CreateTmpDir(); + const absl::StatusOr calibrated_saved_model_path = + CreateTmpDir(); + if (!calibrated_saved_model_path.ok()) { + throw py::value_error(absl::StrFormat( + "Failed to create tmp dir for calibrated saved model: %s", + calibrated_saved_model_path.status().ToString())); + } py_function_library.SaveExportedModel( - calibrated_saved_model_path, calibrated_exported_model, + *calibrated_saved_model_path, calibrated_exported_model, src_saved_model_path, tags, signature_def_map); const absl::flat_hash_map @@ -165,7 +164,7 @@ PYBIND11_MODULE(pywrap_quantization, m) { const absl::StatusOr post_calibrated_exported_model = QuantizePtqModelPostCalibration( - calibrated_saved_model_path, signature_keys, tags, + *calibrated_saved_model_path, signature_keys, tags, quantization_options, function_aliases_after_calibration); if (!post_calibrated_exported_model.ok()) { return post_calibrated_exported_model.status(); @@ -173,7 +172,7 @@ PYBIND11_MODULE(pywrap_quantization, m) { py_function_library.SaveExportedModel( dst_saved_model_path, *post_calibrated_exported_model, - calibrated_saved_model_path, tags, signature_def_map); + *calibrated_saved_model_path, tags, signature_def_map); return absl::OkStatus(); }, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index c1b0040dc78f30..b9cc102aa33b1b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -203,6 +203,7 @@ tf_python_pybind_extension( deps = [ ":py_function_lib", ":type_casters", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", @@ -211,8 +212,6 @@ tf_python_pybind_extension( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:env", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:import_status_module", diff --git 
a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 8109e422e5da5d..bf56d8607d9058 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 @@ -30,6 +29,7 @@ limitations under the License. #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" @@ -37,11 +37,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/python/lib/core/pybind11_lib.h" -#include "tsl/platform/env.h" namespace { -using ::tensorflow::GraphDef; +using ::stablehlo::quantization::io::CreateTmpDir; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; @@ -53,20 +52,6 @@ using ::tensorflow::quantization::QuantizePtqModelPreCalibration; using ::tensorflow::quantization::QuantizeQatModel; using ::tensorflow::quantization::QuantizeWeightOnly; -// Creates a temporary directory and returns its path. -std::string CreateTmpDir() { - tsl::Env* const env = tsl::Env::Default(); - - std::string tmp_dir; - env->LocalTempFilename(&tmp_dir); - if (!env->RecursivelyCreateDir(tmp_dir).ok()) { - throw py::value_error( - absl::StrFormat("Failed to create tmp dir: '%s'", tmp_dir)); - } - - return tmp_dir; -} - // Enables debugging on `exported_model` by updating the `DumpTensor` ops. 
// // Saves the current model to `debugger_options.unquantized_dump_model_path()` @@ -283,15 +268,20 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { const ExportedModel exported_model_ids_assigned = py_function_library.AssignIdsToCustomAggregatorOps(*exported_model); - const std::string precalibrated_saved_model_dir = CreateTmpDir(); + const absl::StatusOr precalibrated_saved_model_dir = + CreateTmpDir(); + if (!precalibrated_saved_model_dir.ok()) { + throw py::value_error( + precalibrated_saved_model_dir.status().ToString()); + } py_function_library.SaveExportedModel( - precalibrated_saved_model_dir, exported_model_ids_assigned, + *precalibrated_saved_model_dir, exported_model_ids_assigned, src_saved_model_path, tags, signature_def_map); ExportedModel calibrated_exported_model = py_function_library.RunCalibration( - precalibrated_saved_model_dir, signature_keys, tags, + *precalibrated_saved_model_dir, signature_keys, tags, exported_model_ids_assigned, quantization_options.calibration_options(), quantization_options.force_graph_mode_calibration(), @@ -304,9 +294,15 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { src_saved_model_path, tags, signature_def_map); } - const std::string calibrated_saved_model_path = CreateTmpDir(); + const absl::StatusOr calibrated_saved_model_path = + CreateTmpDir(); + if (!calibrated_saved_model_path.ok()) { + throw py::value_error( + calibrated_saved_model_path.status().ToString()); + } + py_function_library.SaveExportedModel( - calibrated_saved_model_path, calibrated_exported_model, + *calibrated_saved_model_path, calibrated_exported_model, src_saved_model_path, tags, signature_def_map); const absl::flat_hash_map @@ -316,14 +312,14 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { const absl::StatusOr post_calibrated_exported_model = QuantizePtqModelPostCalibration( - calibrated_saved_model_path, signature_keys, tags, + *calibrated_saved_model_path, signature_keys, tags, quantization_options, function_aliases_after_calibration); if (!post_calibrated_exported_model.ok()) return post_calibrated_exported_model.status(); py_function_library.SaveExportedModel( dst_saved_model_path, *post_calibrated_exported_model, - calibrated_saved_model_path, tags, signature_def_map); + *calibrated_saved_model_path, tags, signature_def_map); return absl::OkStatus(); }, From dc9ba8cd2226aac23360f9f7cac105618b1cb8c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 14:13:15 -0800 Subject: [PATCH 112/381] Search for possible predecessors instead of only caching one hops. Some hlo computations were getting scheduling loops because of the alias checking. 
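The change below replaces the one-hop dependency cache with a search over all transitive predecessors before a new edge is added. A minimal, self-contained sketch of that idea (a generic Node type standing in for the scheduler's graph nodes, not the actual XLA API):

  #include <unordered_set>
  #include <vector>

  struct Node {
    std::vector<Node*> predecessors;
  };

  // Returns true if `target` can already be reached from `node` by following
  // predecessor edges, i.e. `target` is a transitive predecessor of `node`.
  // A scheduler can use this kind of check to avoid adding an edge in the
  // opposite direction, which would create a cycle.
  bool IsTransitivePredecessor(const Node* node, const Node* target) {
    std::unordered_set<const Node*> visited;
    std::vector<const Node*> stack = {node};
    while (!stack.empty()) {
      const Node* curr = stack.back();
      stack.pop_back();
      if (curr == target) return true;
      if (!visited.insert(curr).second) continue;  // Already explored.
      for (const Node* pred : curr->predecessors) stack.push_back(pred);
    }
    return false;
  }
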
PiperOrigin-RevId: 585756812 --- .../xla/service/latency_hiding_scheduler.cc | 45 +++++++++----- .../xla/service/latency_hiding_scheduler.h | 4 ++ .../service/latency_hiding_scheduler_test.cc | 61 +++++++++++++++++++ 3 files changed, 95 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index bc2c49bb0c6805..f76d438aeed193 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -1385,6 +1385,28 @@ std::string HloEdge::ToString() const { " latency: ", Latency(), "\n"); } +bool HloScheduleGraph::IsPredecessorTransitively( + const HloGraphNode* node, const HloGraphNode* possible_predecessor) { + absl::flat_hash_set visited = {possible_predecessor}; + std::vector to_visit_queue = {node}; + while (!to_visit_queue.empty()) { + const HloGraphNode* curr = to_visit_queue.back(); + to_visit_queue.pop_back(); + if (curr == possible_predecessor) { + return true; + } + if (visited.contains(curr)) { + continue; + } + visited.insert(curr); + for (const auto& edge : curr->GetPredecessors()) { + auto user_node_it = nodes_.find(&edge.Target().GetInstr()); + to_visit_queue.push_back(user_node_it->second.get()); + } + } + return false; +} + HloScheduleGraph::HloScheduleGraph( const std::vector* post_order_instructions, HloAliasAnalysis* alias_analysis, const LatencyEstimator* latency_estimator, @@ -1413,14 +1435,8 @@ HloScheduleGraph::HloScheduleGraph( async_tracker->GetOccupiedShareableResourcesFromVector( new_node_it->second->GetResources()); } - // Cache used to detect if we already added a dependency between two nodes - // to avoid duplicates in the predecessors/successors lists. - absl::flat_hash_map> - dependencies_set; - auto add_dependency_helper = [&dependencies_set, latency_estimator, - async_tracker](HloGraphNode* from, - HloGraphNode* to) { + auto add_dependency_helper = [latency_estimator](HloGraphNode* from, + HloGraphNode* to) { // Get the latency between these two instructions for this edge. const LatencyEstimator::TimeCost latency = latency_estimator->GetLatencyBetween(*from, *to); @@ -1430,9 +1446,6 @@ HloScheduleGraph::HloScheduleGraph( to->predecessors_.push_back(HloEdge(latency, from)); ++to->indegree_; ++from->outdegree_; - if (async_tracker->IsSupportedAsyncStart(to->GetInstr())) { - dependencies_set[&to->GetInstr()].insert(&from->GetInstr()); - } }; // Add dependencies edges between each of the graph nodes. for (const HloInstruction* instr : *post_order_instructions) { @@ -1473,11 +1486,8 @@ HloScheduleGraph::HloScheduleGraph( // The instruction itself and later ones might be // identified as use.instruction. Add checks here to avoid // adding dependencies for these instructions. - // Also don't add the dependency if it has been already added. - auto dep_it = dependencies_set.find(async_start); if (use.instruction == async_start || - reachability->IsReachable(instr, use.instruction) || - dep_it->second.contains(use.instruction)) { + reachability->IsReachable(instr, use.instruction)) { continue; } auto it = nodes_.find(use.instruction); @@ -1486,6 +1496,11 @@ HloScheduleGraph::HloScheduleGraph( it = nodes_.find(async_start); CHECK(it != nodes_.end()); HloGraphNode* start_node = it->second.get(); + // If there is already a transitive link between the nodes the + // other way then skip adding this one. 
+ if (IsPredecessorTransitively(pred_node, start_node)) { + continue; + } pred_node->successors_.push_back(HloEdge(1, start_node)); start_node->predecessors_.push_back(HloEdge(1, pred_node)); ++pred_node->outdegree_; diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 7775912d0afde9..32d229a31a12fa 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -499,6 +499,10 @@ class HloScheduleGraph { // List containing the original order (before scheduling) of the // instructions). std::vector original_order_; + // Searches through node's predecessors to see if + // possible_predecessor can be found. + bool IsPredecessorTransitively(const HloGraphNode* node, + const HloGraphNode* possible_predecessor); }; // Tracks data about HloBuffers like where the first definition is in the diff --git a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc index 58d18e910ed685..4edcc1a682c8b3 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc @@ -2916,4 +2916,65 @@ TEST_F(LatencyHidingSchedulerTest, RerunWithSmallerMemoryLimit) { EXPECT_LT(PositionInVector(new_instruction_sequence, s), PositionInVector(new_instruction_sequence, cps)); } + +TEST_F(LatencyHidingSchedulerTest, MultipleAsyncDoneOperationsDoNotCreateLoop) { + absl::string_view hlo_string = R"( +HloModule multiple_async_done_scheduler_test, is_scheduled=true + +called_computation { + ROOT %param = s32[<=4096]{0:T(8)M(1024)} parameter(0) +} + +ENTRY main { + %while_body_forward_pass_input_tuple = (s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}) parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR"} + + %get-tuple-element.0 = s32[<=4096]{0:T(8)M(1024)} get-tuple-element( + (s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}) %while_body_forward_pass_input_tuple), + index=0, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR"} + + %get-tuple-element.1 = s32[<=4096]{0:T(8)M(1024)} get-tuple-element( + (s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}) %while_body_forward_pass_input_tuple), + index=1, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR"} + + %call-start.1 = ((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) + call-start(s32[<=4096]{0:T(8)M(1024)} %get-tuple-element.1), + async_group_id=17, async_execution_thread="sparsecore", to_apply=%called_computation + + %call-done.1 = s32[<=4096]{0:T(8)M(1024)} + call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.1), + async_group_id=17, async_execution_thread="sparsecore", to_apply=%called_computation + + %call-start.2 = ((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) + call-start(s32[<=4096]{0:T(8)M(1024)} %call-done.1), + async_group_id=27, async_execution_thread="sparsecore", to_apply=%called_computation + + %call-done.2 = s32[<=4096]{0:T(8)M(1024)} + call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.2), + async_group_id=27, async_execution_thread="sparsecore", 
to_apply=%called_computation + + %call-start.3 = ((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) + call-start(s32[<=4096]{0:T(8)M(1024)} %get-tuple-element.0), + async_group_id=14, async_execution_thread="sparsecore", to_apply=%called_computation + + %call-done.3 = s32[<=4096]{0:T(8)M(1024)} + call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.3), + async_group_id=14, async_execution_thread="sparsecore", to_apply=%called_computation + + ROOT %tuple.6 = (s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}) + tuple(s32[<=4096]{0:T(8)M(1024)} %call-done.2, s32[<=4096]{0:T(8)M(1024)} %call-done.3), + backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR"} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + auto sched_config = GetDefaultSchedConfig(); + // The double indirection of the buffer aliasing in the module above should + // not create a failure of scheduling by the async done checks. + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); +} } // namespace xla From 29227bd331842323bf9bcb0fc2c52166022b85de Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Mon, 27 Nov 2023 14:26:28 -0800 Subject: [PATCH 113/381] Apply pass-by-reference for non-null values in StableHLO Quantizer directory. PiperOrigin-RevId: 585760434 --- .../stablehlo/passes/bridge/verify_quant_legalization.cc | 4 ++-- .../quantization/stablehlo/passes/quantization_pattern.h | 8 ++++---- .../mlir/quantization/stablehlo/passes/quantize.cc | 8 ++++---- .../stablehlo/passes/unwrap_xla_call_module_op.cc | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc index 361d98c7775abe..2825195addea12 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc @@ -57,7 +57,7 @@ bool IsQuantType(Type type) { IsTFQintType(element_type); } -bool IsMhloUniformQuantizedOp(Operation* op) { +bool IsMhloUniformQuantizedOp(Operation& op) { return llvm::isa(op); } @@ -68,7 +68,7 @@ void VerifyQuantLegalization::runOnOperation() { // Verify all uq and qint types are lowered. 
if (llvm::any_of(op->getOperandTypes(), IsQuantType) || llvm::any_of(op->getResultTypes(), IsQuantType) || - IsTFUniformQuantizedOp(op) || IsMhloUniformQuantizedOp(op)) { + IsTFUniformQuantizedOp(op) || IsMhloUniformQuantizedOp(*op)) { op->emitOpError("is illegal as it is a UQ op or contains uq/qint types"); LOG(ERROR) << "Found illegal op containing uq/qint type: " << op->getName().getStringRef().str(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h index 23564308da8d6d..133bafab2978fe 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h @@ -67,8 +67,8 @@ bool IsOpQuantizableStableHlo(Operation* op); // quantized results. The concrete pattern should define the following two // functions: // -// bool AllowDynamicRangeQuantizedOperand(Operation *) const -// bool AllowDynamicRangeQuantizedResult(Operation *) const +// bool AllowDynamicRangeQuantizedOperand(Operation&) const +// bool AllowDynamicRangeQuantizedResult(Operation&) const // // Full integer quantization disallows "DynamicRangeQuantized" operands or // results. Dynamic range quantization allows "DynamicRangeQuantized" operands @@ -146,7 +146,7 @@ class StableHloQuantizationPattern : public RewritePattern { if (!IsOpQuantizableStableHlo(quantizing_op) && !static_cast(this)->IsQuantizableCustomOp( - quantizing_op, custom_map)) { + *quantizing_op, custom_map)) { return failure(); } @@ -229,7 +229,7 @@ class StableHloQuantizationPattern : public RewritePattern { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); } else if (static_cast(this) - ->AllowDynamicRangeQuantizedResult(quantizing_op, + ->AllowDynamicRangeQuantizedResult(*quantizing_op, custom_map)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 811c53125056cd..fffa0ad782c8d9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -55,22 +55,22 @@ struct StableHloQuantizationBase /*VerifierT=*/void, RootOpT>( ctx, quant_params) {} - static bool IsQuantizableCustomOp(Operation* op, + static bool IsQuantizableCustomOp(Operation& op, const CustomMap& custom_op_map) { return false; } static bool AllowDynamicRangeQuantizedOperand( - Operation* quantized_op, const CustomMap& custom_op_map) { + Operation& quantized_op, const CustomMap& custom_op_map) { return false; } - static bool AllowDynamicRangeQuantizedResult(Operation* quantized_op, + static bool AllowDynamicRangeQuantizedResult(Operation& quantized_op, const CustomMap& custom_op_map) { return false; } - static bool IsWeightOnlyOp(Operation* quantized_op, + static bool IsWeightOnlyOp(Operation& quantized_op, absl::flat_hash_set& ops_blocklist, bool weight_only_quantization, const CustomMap& custom_op_map) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc index 1c1b1249558bd4..a65694a7a7287f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc @@ -94,9 +94,9 @@ void UnwrapXlaCallModuleOp(TF::XlaCallModuleOp call_op, continue; } - Operation* new_op = builder.clone(op, arg_mapper); + Operation& new_op = *builder.clone(op, arg_mapper); for (auto [result, new_result] : - llvm::zip_equal(op.getResults(), new_op->getResults())) { + llvm::zip_equal(op.getResults(), new_op.getResults())) { new_op_mapper.map(result, new_result); } } From 2d8c4caf674435ea8e63f0d7aa7b8fd7a36f3859 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2023 14:44:15 -0800 Subject: [PATCH 114/381] Use tf2xla bridge to legalize XlaSendTPUEmbeddingGradientsOp PiperOrigin-RevId: 585765195 --- .../compiler/mlir/tf2xla/transforms/legalization_op_config.cc | 2 ++ .../mlir/tf2xla/transforms/legalization_op_config_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index 6e5f8285a0a928..979ec3f97e629e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -336,6 +336,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -481,6 +482,7 @@ bool IsOpTypeAllowedTf2XlaPreferred(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index ade2b5faa73c8a..7084f98b28568e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -131,8 +131,8 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 315); - EXPECT_EQ(non_categorized_count, 422); + EXPECT_EQ(tf2xla_fallback_count, 316); + EXPECT_EQ(non_categorized_count, 421); } // Just a counter test to see which ops have duplicate lowerings. This isn't a From 50299228e5df92b486548ee1cb856e79de69ad43 Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Mon, 27 Nov 2023 14:54:26 -0800 Subject: [PATCH 115/381] PR #7269: Fix an incorrect static_assert Imported from GitHub PR https://github.com/openxla/xla/pull/7269 The size of a uint8_t is 1 so a static_assert to check that it is 0 makes no sense, fix it. Also fix a couple of warnings about lack of typename Copybara import of the project: -- e72164f0a16ec5e548927ca3583d1a31edbd95d8 by Andrew Goodbody : Fix an incorrect static_assert The size of a uint8_t is 1 so a static_assert to check that it is 0 will always be false. The original intent was to have the assert only trigger if the struct was instantiated but the standard deems it ill formed if it can never be true and allows compilers to reject it. Adopt a different workaround that avoids this by allowing the possibility of an evaluation to true. 
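In other words, the workaround makes the static_assert condition depend on the template parameter, so the assertion can only evaluate to false for a primary template that is actually instantiated with an unsupported type. A stripped-down illustration of the pattern (simplified names, not the exact xla::ffi types):

  #include <type_traits>

  template <typename T>
  struct always_false : std::false_type {};

  // Primary template: the assert is dependent on T, so the program is only
  // ill-formed if this template is instantiated with an unsupported type.
  template <typename T>
  struct PtrTypeSketch {
    static_assert(always_false<T>::value, "unsupported data type");
  };

  // Supported types specialize the mapping and never hit the assert.
  template <>
  struct PtrTypeSketch<float> {
    using Type = float;
  };
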
Also fix a couple of warnings about lack of typename Merging this change closes #7269 PiperOrigin-RevId: 585767871 --- third_party/xla/xla/ffi/api/ffi.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index ed168f91488b5c..9982335fe12b47 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -108,9 +108,14 @@ class Error { namespace internal { +// A workaround for the fact that a static_assertion can be evaluated +// whether or not the template is instantiated +template +struct always_false : std::false_type {}; + template struct PtrType { - static_assert(sizeof(dtype) == 0, "unsupported data type"); + static_assert(always_false::value, "unsupported data type"); }; // clang-format off @@ -133,7 +138,7 @@ template <> struct PtrType { using Type = std::uint16_t; }; template struct BufferBase { - internal::PtrType::Type* data; + typename internal::PtrType::Type* data; Span dimensions; }; @@ -149,7 +154,8 @@ struct ArgDecoding> { auto* buf = reinterpret_cast(arg); // TODO(slebedev): Emit a user-friendly error instead. if (static_cast(buf->dtype) != dtype) return std::nullopt; - auto* data = static_cast::Type*>(buf->data); + auto* data = + static_cast::Type*>(buf->data); return BufferBase{data, Span(buf->dims, buf->rank)}; } From b8494a08185e0ea1021e4859d862dfaff654c16b Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:59:45 -0800 Subject: [PATCH 116/381] PR #7277: [ROCM] rocm 6.0 fixes followup Imported from GitHub PR https://github.com/openxla/xla/pull/7277 This is a follow-up PR for rocm-6.0 platform support. I have also adapted two unit tests which were failing on specific architectures / platform version. @xla-rotation: would you take a look, please? 
Copybara import of the project: -- 12e609064b6c8eda36866602874d9fdbfb89ef74 by Pavel Emeliyanenko : another fixes for rocm-6.0 platform -- 9c82a7af6cac49c73bd4cfd61dbc1acf23b89a92 by Pavel Emeliyanenko : added checks for the failing tests -- 853e3a59dcf4b730ed255faab5aef648d606346c by Pavel Emeliyanenko : fixing cuda compile -- 2d29cc872b35b9e2c9f895cfeb3cc02d3a2424e0 by Pavel Emeliyanenko : addressing reviewer comments Merging this change closes #7277 PiperOrigin-RevId: 585769317 --- third_party/xla/xla/service/gpu/BUILD | 5 +- .../xla/xla/service/gpu/determinism_test.cc | 24 +++++++-- .../xla/stream_executor/device_description.h | 49 ++++++++++++------- third_party/xla/xla/tests/BUILD | 3 ++ .../xla/tests/array_elementwise_ops_test.cc | 9 ++++ 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index a5c4e042163ad0..1a2c311dead69c 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -4893,7 +4893,10 @@ xla_cc_test( xla_cc_test( name = "determinism_test", srcs = ["determinism_test.cc"], - tags = tf_cuda_tests_tags(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags(), deps = [ ":autotuner_util", "//xla:literal", diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc index 6a2f476bdb696e..ab362dba4ad42b 100644 --- a/third_party/xla/xla/service/gpu/determinism_test.cc +++ b/third_party/xla/xla/service/gpu/determinism_test.cc @@ -89,6 +89,16 @@ ENTRY e { ROOT d = f32[128,128] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; +#if TENSORFLOW_USE_ROCM + auto rocm = backend() + .default_stream_executor() + ->GetDeviceDescription() + .rocm_compute_capability(); + if (!rocm.has_hipblaslt()) { + GTEST_SKIP() << "No hipblas-lt support on this architecture!"; + } +#endif // TENSORFLOW_USE_ROCM + debug_options_.set_xla_gpu_triton_fusion_level(0); MatchOptimizedHlo(kHloText, R"(; CHECK: custom_call_target="__cublas$gemm")"); AssertDeterminism(kHloText); @@ -100,13 +110,17 @@ ENTRY e { } TEST_F(DeterminismTest, TritonDot) { - se::CudaComputeCapability compute_capability = backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - if (!compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA)) { +#if GOOGLE_CUDA + auto comp = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (!comp.IsAtLeast(se::CudaComputeCapability::VOLTA)) { GTEST_SKIP() << "Triton not used on pre-Volta GPUs"; } +#elif TENSORFLOW_USE_ROCM + GTEST_SKIP() << "Triton Gemm rewriter is not yet supported on ROCM"; +#endif // TENSORFLOW_USE_ROCM constexpr absl::string_view kHloText = R"( ENTRY e { diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index ed16f66703e68d..85cac5a9ee075d 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -159,33 +159,43 @@ class RocmComputeCapability { return absl::StrJoin(kSupportedGfxVersions, ", "); } - bool has_nhwc_layout_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a"}; + bool gfx9_mi100_or_later() const { + static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", + "gfx941", "gfx942"}; 
return absl::c_count(kList, gfx_version()) != 0; } - bool has_bf16_dtype_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a"}; + bool gfx9_mi200_or_later() const { + static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941", + "gfx942"}; return absl::c_count(kList, gfx_version()) != 0; } + bool navi21() const { return gfx_version() == "gfx1030"; } + + bool navi31() const { return gfx_version() == "gfx1100"; } + + bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } + + bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } + bool has_fast_fp16_support() const { - static constexpr absl::string_view kList[] = {"gfx906", "gfx908", "gfx90a", - "gfx1030"}; - return absl::c_count(kList, gfx_version()) != 0; + return gfx9_mi100_or_later() || navi21() || navi31(); } - bool has_mfma_instr_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a"}; - return absl::c_count(kList, gfx_version()) != 0; - } + bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } bool has_fp16_atomics_support() const { // TODO(rocm): Check. This should be the same as has_fast_fp16_support(). - static constexpr absl::string_view kList[] = {"gfx90a"}; - return absl::c_count(kList, gfx_version()) != 0; + return gfx9_mi200_or_later(); } + bool fence_before_barrier() const { + return gfx_version() != "gfx900" && gfx_version() != "gfx906"; + } + + bool has_hipblaslt() const { return gfx9_mi200_or_later(); } + RocmComputeCapabilityProto ToProto() const { RocmComputeCapabilityProto proto; proto.set_gcn_arch_name(gcn_arch_name_); @@ -200,10 +210,11 @@ class RocmComputeCapability { std::string gcn_arch_name_ = "gfx000"; // default to invalid arch. static constexpr absl::string_view kSupportedGfxVersions[]{ - "gfx900", // MI25 - "gfx906", // MI50 / MI60 - "gfx908", // MI100 - "gfx90a", // MI200 + "gfx900", // MI25 + "gfx906", // MI50 / MI60 + "gfx908", // MI100 + "gfx90a", // MI200 + "gfx940", "gfx941", "gfx942", "gfx1030", // Navi21 "gfx1100" // Navi31 }; @@ -362,10 +373,10 @@ class DeviceDescription { static const char *kUndefinedString; private: - DeviceDescription(); - friend class internal::DeviceDescriptionBuilder; + DeviceDescription(); + // For description of the following members, see the corresponding accessor // above. // diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 447e018675735c..1b32e39e73b18a 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -777,6 +777,9 @@ xla_test( xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), shard_count = 25, deps = [ ":client_library_test_base", diff --git a/third_party/xla/xla/tests/array_elementwise_ops_test.cc b/third_party/xla/xla/tests/array_elementwise_ops_test.cc index d932c5c540566a..5a2e0add1eebb8 100644 --- a/third_party/xla/xla/tests/array_elementwise_ops_test.cc +++ b/third_party/xla/xla/tests/array_elementwise_ops_test.cc @@ -42,6 +42,10 @@ limitations under the License. 
#include "xla/tests/test_macros.h" #include "xla/types.h" +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif + namespace xla { namespace { @@ -1590,6 +1594,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION == 50700 + GTEST_SKIP() + << "This test fails on rocm-5.7.0 platform due to a compiler bug"; +#endif + SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto eps = std::numeric_limits::epsilon(); From 5dc15de1e1d6dc4727b37e47d42e1e5c58a3f031 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Mon, 27 Nov 2023 15:05:04 -0800 Subject: [PATCH 117/381] Unify metric version for CPU/GPU graphs There are currently two metric versions for CPU/GPU graphs. This unifies the naming. PiperOrigin-RevId: 585770848 --- tensorflow/compiler/tf2xla/mlir_bridge_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 0cfcc0a5a7a78a..c8a4984c356359 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -237,13 +237,13 @@ MlirOptimizationPassState GetPassStateImpl( return MlirOptimizationPassState::FallbackEnabled; case MlirBridgeRolloutPolicy::kDisabledByUser: VLOG(1) << "Skipping MLIR CPU/GPU Bridge, disabled by user."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "tfxla", false, + metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, "disabled_by_user"); return MlirOptimizationPassState::Disabled; default: // This case should never be hit. Added here to be consistent with OSS // implementation. - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "ftxla", false, + metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, "invalid_graph"); return MlirOptimizationPassState::Disabled; } From 276df5c71019dba4e6ee747b86eca7bdbf43fdc4 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 27 Nov 2023 16:33:06 -0800 Subject: [PATCH 118/381] Add c-api support for GetCompiledMemoryStats. PiperOrigin-RevId: 585793365 --- third_party/xla/xla/pjrt/c/CHANGELOG.md | 3 +++ third_party/xla/xla/pjrt/c/pjrt_c_api.h | 25 ++++++++++++++++++- .../xla/xla/pjrt/c/pjrt_c_api_helpers.cc | 17 +++++++++++++ .../xla/xla/pjrt/c/pjrt_c_api_helpers.h | 3 +++ .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 16 ++++++++++++ .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 4 +++ third_party/xla/xla/pjrt/pjrt_c_api_client.h | 8 ++++++ 7 files changed, 75 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index fa60ae26859c31..5bd553df96126c 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,8 @@ # PJRT C API changelog +## 0.40 (Nov 27, 2023) +* Added PJRT_Executable_GetCompiledMemoryStats. + ## 0.39 (Nov 16, 2023) * Add non_donatable_input_indices and num_non_donatable_input_indices to PJRT_ExecuteOptions. 
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index 3cc1a26bd4460e..e249630374a13b 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 39 +#define PJRT_API_MINOR 40 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -1376,6 +1376,27 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCostAnalysis_Args, properties); typedef PJRT_Error* PJRT_Executable_GetCostAnalysis( PJRT_Executable_GetCostAnalysis_Args* args); +struct PJRT_Executable_GetCompiledMemoryStats_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + + // Mirrors xla::CompiledMemoryStats. + int64_t generated_code_size_in_bytes; // out + int64_t argument_size_in_bytes; // out + int64_t output_size_in_bytes; // out + // How much argument is reused for output. + int64_t alias_size_in_bytes; // out + int64_t temp_size_in_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCompiledMemoryStats_Args, + temp_size_in_bytes); + +// Return memory stats that allow callers to estimate device memory usage +// when running this executable. +typedef PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( + PJRT_Executable_GetCompiledMemoryStats_Args* args); + struct PJRT_Executable_OutputElementTypes_Args { size_t struct_size; void* priv; @@ -2136,6 +2157,8 @@ typedef struct { _PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint); _PJRT_API_STRUCT_FIELD(PJRT_Client_TopologyDescription); + + _PJRT_API_STRUCT_FIELD(PJRT_Executable_GetCompiledMemoryStats); } PJRT_Api; enum { diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index 60af425b398346..5e960c3cd60d37 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -942,4 +942,21 @@ absl::Span DeviceDescriptions( return {args.descriptions, args.num_descriptions}; } +absl::StatusOr GetCompiledMemoryStats( + const PJRT_Api* api, PJRT_Executable* executable) { + PJRT_Executable_GetCompiledMemoryStats_Args args; + args.struct_size = PJRT_Executable_GetCompiledMemoryStats_Args_STRUCT_SIZE; + args.priv = 0; + args.executable = executable; + RETURN_STATUS_IF_PJRT_ERROR( + api->PJRT_Executable_GetCompiledMemoryStats(&args), api); + xla::CompiledMemoryStats results; + results.generated_code_size_in_bytes = args.generated_code_size_in_bytes; + results.argument_size_in_bytes = args.argument_size_in_bytes; + results.output_size_in_bytes = args.output_size_in_bytes; + results.alias_size_in_bytes = args.alias_size_in_bytes; + results.temp_size_in_bytes = args.temp_size_in_bytes; + return results; +} + } // namespace pjrt diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h index 49d5755d3d42ff..f727a00709713f 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h @@ -269,6 +269,9 @@ absl::string_view PlatformName(const PJRT_Api* api, absl::Span DeviceDescriptions( const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc); +absl::StatusOr GetCompiledMemoryStats( + const PJRT_Api* api, PJRT_Executable* executable); + } // namespace pjrt #endif // 
XLA_PJRT_C_PJRT_C_API_HELPERS_H_ diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index e782cf61331902..3deaf7e0841e4e 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -1480,6 +1480,22 @@ PJRT_Error* PJRT_Executable_Serialize(PJRT_Executable_Serialize_Args* args) { return nullptr; } +PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( + PJRT_Executable_GetCompiledMemoryStats_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Executable_Serialize_Args", + PJRT_Executable_Serialize_Args_STRUCT_SIZE, args->struct_size)); + PJRT_ASSIGN_OR_RETURN(auto memory_stats, + args->executable->executable->GetCompiledMemoryStats()); + args->generated_code_size_in_bytes = + memory_stats.generated_code_size_in_bytes; + args->argument_size_in_bytes = memory_stats.argument_size_in_bytes; + args->output_size_in_bytes = memory_stats.output_size_in_bytes; + args->alias_size_in_bytes = memory_stats.alias_size_in_bytes; + args->temp_size_in_bytes = memory_stats.temp_size_in_bytes; + return nullptr; +} + PJRT_Error* PJRT_Executable_DeserializeAndLoad( PJRT_Executable_DeserializeAndLoad_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 44d2f4cdd9ed12..2ce5220719452e 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -286,6 +286,8 @@ PJRT_Error* PJRT_Executable_OutputMemoryKinds( PJRT_Error* PJRT_Executable_OptimizedProgram( PJRT_Executable_OptimizedProgram_Args* args); PJRT_Error* PJRT_Executable_Serialize(PJRT_Executable_Serialize_Args* args); +PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( + PJRT_Executable_GetCompiledMemoryStats_Args* args); PJRT_Error* PJRT_LoadedExecutable_Destroy( PJRT_LoadedExecutable_Destroy_Args* args); @@ -592,6 +594,8 @@ constexpr PJRT_Api CreatePjrtApi( /*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint, /*PJRT_Client_TopologyDescription= */ pjrt::PJRT_Client_TopologyDescription, + /*PJRT_Executable_GetCompiledMemoryStats= */ + pjrt::PJRT_Executable_GetCompiledMemoryStats, }; } diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index d147dee6ee2b03..24c31bd9e089bf 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -573,6 +573,10 @@ class PjRtCApiExecutable : public PjRtExecutable { StatusOr>> GetHloModules() const override; + StatusOr GetCompiledMemoryStats() const override { + return pjrt::GetCompiledMemoryStats(c_api_, executable_.get()); + } + StatusOr> GetOutputShapes() const override { LOG(FATAL) << "PjRtExecutable::GetOutputShapes() not implemented in PJRT C " "API. 
Please use PjRtExecutable::GetOutputElementTypes() or " @@ -638,6 +642,10 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable { return executable_->GetHloModules(); } + StatusOr GetCompiledMemoryStats() const override { + return executable_->GetCompiledMemoryStats(); + } + StatusOr> GetOutputShapes() const override { LOG(FATAL) << "PjRtLoadedExecutable::GetOutputShapes() not implemented in PJRT C " From af5b061dc4066b06a815b64b91bd76323b8ec7c5 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Mon, 27 Nov 2023 16:41:31 -0800 Subject: [PATCH 119/381] Add missing #includes that define symbols referenced by simple_*delegate* Also remove some unused #includes from the .cc files. Also use "= default" syntax for destructor. PiperOrigin-RevId: 585795461 --- tensorflow/lite/delegates/utils/BUILD | 5 ++++- tensorflow/lite/delegates/utils/simple_delegate.cc | 6 +++++- tensorflow/lite/delegates/utils/simple_delegate.h | 5 ++++- tensorflow/lite/delegates/utils/simple_delegate_test.cc | 3 ++- .../lite/delegates/utils/simple_opaque_delegate.cc | 7 ++++--- tensorflow/lite/delegates/utils/simple_opaque_delegate.h | 3 +++ .../lite/delegates/utils/simple_opaque_delegate_test.cc | 9 +++++++++ 7 files changed, 31 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/delegates/utils/BUILD b/tensorflow/lite/delegates/utils/BUILD index 924eb3cb0c02a4..9bb7079a8a9386 100644 --- a/tensorflow/lite/delegates/utils/BUILD +++ b/tensorflow/lite/delegates/utils/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -19,6 +19,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/lite:array", "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/c:common", @@ -52,6 +53,7 @@ cc_library_with_tflite( "//tensorflow/lite/c:common", ], deps = [ + "//tensorflow/lite:array", "//tensorflow/lite:builtin_ops", "//tensorflow/lite:minimal_logging", "//tensorflow/lite:util", @@ -75,6 +77,7 @@ cc_test( data = [":c_api_test_builtin_op_models"], deps = [ ":simple_opaque_delegate", + "//tensorflow/lite:framework_stable", "//tensorflow/lite/c:c_api", "//tensorflow/lite/c:c_api_experimental", "//tensorflow/lite/c:c_api_types", diff --git a/tensorflow/lite/delegates/utils/simple_delegate.cc b/tensorflow/lite/delegates/utils/simple_delegate.cc index 7b3e9647a052a7..6b0401ba0385d9 100644 --- a/tensorflow/lite/delegates/utils/simple_delegate.cc +++ b/tensorflow/lite/delegates/utils/simple_delegate.cc @@ -14,16 +14,20 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/utils/simple_delegate.h" +#include +#include + #include #include #include #include +#include "tensorflow/lite/array.h" #include "tensorflow/lite/builtin_ops.h" -#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" namespace tflite { diff --git a/tensorflow/lite/delegates/utils/simple_delegate.h b/tensorflow/lite/delegates/utils/simple_delegate.h index 0fb67abc5f7621..8b65fa892d4004 100644 --- a/tensorflow/lite/delegates/utils/simple_delegate.h +++ b/tensorflow/lite/delegates/utils/simple_delegate.h @@ -29,7 +29,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ +#include + #include +#include #include "tensorflow/lite/core/c/common.h" @@ -42,7 +45,7 @@ using TfLiteDelegateUniquePtr = // Each instance represents a single part of the graph (subgraph). class SimpleDelegateKernelInterface { public: - virtual ~SimpleDelegateKernelInterface() {} + virtual ~SimpleDelegateKernelInterface() = default; // Initializes a delegated subgraph. // The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace diff --git a/tensorflow/lite/delegates/utils/simple_delegate_test.cc b/tensorflow/lite/delegates/utils/simple_delegate_test.cc index f2d9d352551601..fc589f1fb2dfc7 100644 --- a/tensorflow/lite/delegates/utils/simple_delegate_test.cc +++ b/tensorflow/lite/delegates/utils/simple_delegate_test.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include #include -#include #include #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc b/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc index db22cbe891f7c3..8d0b7aba519e45 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc @@ -14,18 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include +#include +#include + #include -#include #include +#include "tensorflow/lite/array.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_opaque.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/util.h" namespace tflite { namespace { diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate.h b/tensorflow/lite/delegates/utils/simple_opaque_delegate.h index 817186c01a1cc3..b6f8e91946ee58 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate.h +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate.h @@ -31,7 +31,10 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_OPAQUE_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_OPAQUE_DELEGATE_H_ +#include + #include +#include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc b/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc index dd6155b4b52667..953f348d734b41 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc @@ -14,20 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" +#include +#include +#include + #include #include #include #include #include +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_opaque.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" namespace tflite { From f463362b5f5e02e2986bfe988d1cd93f7d36e13b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 27 Nov 2023 17:04:31 -0800 Subject: [PATCH 120/381] Disable uploading macOS Arm64 Libtensorflow build artifacts These are not yet ready. Also, this removes the resultstore config since Mac VMs do not yet have the right permissions to upload to resultstore and fixes a small typo in the `TFCI_LIB_SUFFIX` env var. PiperOrigin-RevId: 585800536 --- ci/official/envs/nightly_libtensorflow_macos_arm64 | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/official/envs/nightly_libtensorflow_macos_arm64 b/ci/official/envs/nightly_libtensorflow_macos_arm64 index 9147ba2a6c288b..be992e2b19f422 100644 --- a/ci/official/envs/nightly_libtensorflow_macos_arm64 +++ b/ci/official/envs/nightly_libtensorflow_macos_arm64 @@ -1,7 +1,7 @@ -source ci/official/envs/ci_nightly_uploads -TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +# Disable arm64 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_DOCKER_ENABLE=0 -TFCI_LIB_SUFFIX="-cpu-macos-arm64" +TFCI_LIB_SUFFIX="-cpu-darwin-arm64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 -TFCI_UPLOAD_WHL_GCS_URI=1 From cd785d12ddce13d98b315be0c16eead4101a9d55 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Mon, 27 Nov 2023 17:24:55 -0800 Subject: [PATCH 121/381] Apply c++ unit test best practices for StableHLO Quantizer directory. 
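The bulk of the change renames test cases so each name spells out the scenario under test and the expected behavior, alongside a local `using ::testing::Test;` alias. A tiny, hypothetical example of the resulting shape (fixture and test names are made up):

  #include <gtest/gtest.h>

  using ::testing::Test;

  class QuantizePassTest : public Test {};

  // Old style: TEST_F(QuantizePassTest, UniformQuantize)
  // New style names the input state and the expected outcome:
  TEST_F(QuantizePassTest, UniformQuantizeToValidGraph) {
    SUCCEED();  // Body elided; only the naming convention matters here.
  }
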
PiperOrigin-RevId: 585804196 --- .../mlir/quantization/stablehlo/BUILD | 3 - .../convert_tf_quant_to_mhlo_int_test.cc | 34 ++++++---- .../bridge/convert_tf_quant_types_test.cc | 3 +- .../passes/bridge/legalize_tf_quant_test.cc | 4 +- .../tests/stablehlo_op_quant_spec_test.cc | 42 ++++++------ .../stablehlo/uniform_quantized_types_test.cc | 68 ++++++++++--------- .../stablehlo/utils/tf_type_utils_test.cc | 14 ++-- 7 files changed, 89 insertions(+), 79 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index f8babf008f107c..4773cc2c884d1e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -364,11 +364,8 @@ cc_library( ":fill_quantization_options", ":passes", ":quantization_options_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/core/platform:path", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", ], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index f00e3402e04902..1987b607392379 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -62,7 +62,9 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -class ConvertTfQuantToMhloIntTest : public ::testing::Test { +using ::testing::Test; + +class ConvertTfQuantToMhloIntTest : public Test { protected: void SetUp() override { DialectRegistry dialects; @@ -281,7 +283,7 @@ class ConvertTfQuantToMhloIntTest : public ::testing::Test { absl::BitGen bitgen_; }; -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAndDequantize) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAndDequantizeToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { %scale = "tf.Const"() { value = dense<0.347> : tensor } : () -> tensor @@ -306,7 +308,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { kProgram, {&arg0}, /*tf_program=*/std::nullopt, /*error_tolerance=*/0.35); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizePerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizePerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %arg0: tensor<10x10xf32>, %scale: tensor<10xf32>, %zp: tensor<10xi32> @@ -330,7 +332,7 @@ func.func @main( /*error_tolerance=*/1.0); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformDequantizePerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformDequantizePerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %arg0: tensor<10x10xi8>, %scale: tensor<10xf32>, %zp: tensor<10xi32> @@ -350,7 +352,7 @@ func.func @main( ExecuteAndCompareResultsWithTfKernel(kProgram, {&arg0, &scale, &zp}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolution) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%input: tensor<1x9x9x9xi8>, %filter: 
tensor<3x3x9x10xi8>) -> tensor<1x9x9x10xi32> { %input_scale = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor @@ -389,7 +391,8 @@ func.func @main(%input: tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>) -> ten ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionPerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, + UniformQuantizeConvolutionPerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %input: tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>, %scale: tensor<10xf32> @@ -428,7 +431,8 @@ func.func @main( ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter, &scale}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionHybrid) { +TEST_F(ConvertTfQuantToMhloIntTest, + UniformQuantizeConvolutionHybridToValidGraph) { constexpr absl::string_view kTfProgram = R"mlir( func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) -> tensor<2x10x10x20xf32> { %filter_scale = "tf.Const"() { value = dense<0.047> : tensor } : () -> tensor @@ -476,7 +480,7 @@ func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) - ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDot) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%input: tensor<8x9xi8>, %filter: tensor<9x10xi8>) -> tensor<8x10xi32> { %input_scale = "tf.Const"() { value = dense<0.588> : tensor } : () -> tensor @@ -513,7 +517,7 @@ func.func @main(%input: tensor<8x9xi8>, %filter: tensor<9x10xi8>) -> tensor<8x10 ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotHybrid) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotHybridToValidGraph) { constexpr absl::string_view kTfProgram = R"mlir( func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x10xf32> { %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () -> tensor @@ -550,7 +554,7 @@ func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x1 ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantize) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizeToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> { %input_scale = "tf.Const"() { value = dense<0.2235> : tensor } : () -> tensor @@ -579,7 +583,7 @@ func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> { ExecuteAndCompareResultsWithTfKernel(kProgram, {&input}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %input: tensor<10x10xi8>, %input_scale: tensor<10xf32>, @@ -621,7 +625,8 @@ func.func @main( /*error_tolerance=*/1.0); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerTensorToPerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, + UniformRequantizePerTensorToPerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %input: tensor<10x10xi8>, %input_scale: tensor, %input_zp: tensor, @@ -661,7 +666,8 @@ func.func @main( /*error_tolerance=*/1.0); } -TEST_F(ConvertTfQuantToMhloIntTest, 
UniformRequantizePerChannelToPerTensor) { +TEST_F(ConvertTfQuantToMhloIntTest, + UniformRequantizePerChannelToPerTensorToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %input: tensor<10x10xi8>, %input_scale: tensor<10xf32>, @@ -701,7 +707,7 @@ func.func @main( /*error_tolerance=*/1.0); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAdd) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAddToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%lhs: tensor<10x10xi32>, %rhs: tensor<10x10xi32>) -> tensor<10x10xi32> { %lhs_scale = "tf.Const"() { value = dense<0.518> : tensor } : () -> tensor diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc index 856bbd49930341..9a5e6c53d3d1d6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc @@ -38,11 +38,12 @@ using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::OwningOpRef; using ::tensorflow::monitoring::testing::CellReader; +using ::testing::Test; static constexpr char kMetricsName[] = "/tensorflow/core/tf2xla/tf_quant_op_count"; -class LegalizeTfTypesTest : public ::testing::Test { +class LegalizeTfTypesTest : public Test { protected: void CreateModule(const char* module_string) { DialectRegistry mlir_registry; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc index 1fd1a0b6bab721..4c20b6bebdcdad 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc @@ -34,7 +34,9 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -class LegalizeTFQuantTest : public ::testing::Test { +using ::testing::Test; + +class LegalizeTFQuantTest : public Test { protected: void TestBridgeLowering(llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc index 281bfc996a1c77..4f1c16c9fd0347 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc @@ -34,7 +34,9 @@ limitations under the License. 
namespace mlir::quant::stablehlo { namespace { -class IsOpQuantizableStableHloTest : public ::testing::Test { +using ::testing::Test; + +class IsOpQuantizableStableHloTest : public Test { protected: IsOpQuantizableStableHloTest() { ctx_.loadDialect module_op_ref = ParseModuleOpString(module_constant_add); func::FuncOp test_func = @@ -124,7 +126,7 @@ TEST_F(IsOpQuantizableStableHloTest, ConstantOp) { EXPECT_TRUE(is_constant_quantizable); } -TEST_F(IsOpQuantizableStableHloTest, TerminatorOp) { +TEST_F(IsOpQuantizableStableHloTest, TerminatorOpNotQuantizable) { OwningOpRef module_op_ref = ParseModuleOpString(module_constant_add); func::FuncOp test_func = @@ -136,7 +138,20 @@ TEST_F(IsOpQuantizableStableHloTest, TerminatorOp) { EXPECT_FALSE(is_return_quantizable); } -TEST_F(IsOpQuantizableStableHloTest, NonSameScaleStableHloOp) { +TEST_F(IsOpQuantizableStableHloTest, SameScaleOpQuantizable) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* reshape_op = + FindOperationOfType(test_func); + bool is_reshape_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(reshape_op); + + EXPECT_TRUE(is_reshape_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, NonSameScaleOpNotQuantizable) { OwningOpRef module_op_ref = ParseModuleOpString(module_constant_add); func::FuncOp test_func = @@ -148,7 +163,7 @@ TEST_F(IsOpQuantizableStableHloTest, NonSameScaleStableHloOp) { EXPECT_FALSE(is_add_quantizable); } -TEST_F(IsOpQuantizableStableHloTest, QuantizableXlaCallModuleOp) { +TEST_F(IsOpQuantizableStableHloTest, ValidXlaCallModuleOpQuantizable) { OwningOpRef module_op_ref = ParseModuleOpString(module_composite_same_scale); func::FuncOp test_func = @@ -161,7 +176,7 @@ TEST_F(IsOpQuantizableStableHloTest, QuantizableXlaCallModuleOp) { EXPECT_TRUE(is_xla_call_module_quantizable); } -TEST_F(IsOpQuantizableStableHloTest, NonQuantizableXlaCallModuleOp) { +TEST_F(IsOpQuantizableStableHloTest, InvalidXlaCallModuleOpNotQuantizable) { OwningOpRef module_op_ref = ParseModuleOpString(module_composite_no_attr); func::FuncOp test_func = @@ -174,20 +189,7 @@ TEST_F(IsOpQuantizableStableHloTest, NonQuantizableXlaCallModuleOp) { EXPECT_FALSE(is_xla_call_module_quantizable); } -TEST_F(IsOpQuantizableStableHloTest, SameScaleStableHloOp) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* reshape_op = - FindOperationOfType(test_func); - bool is_reshape_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(reshape_op); - - EXPECT_TRUE(is_reshape_quantizable); -} - -TEST_F(IsOpQuantizableStableHloTest, QuantizeDequantizeOp) { +TEST_F(IsOpQuantizableStableHloTest, QuantizeDequantizeOpNotQuantizable) { OwningOpRef module_op_ref = ParseModuleOpString(module_composite_same_scale); func::FuncOp test_func = diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc index ab1ca261a4075f..f33b322cfbd9e4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc @@ -32,8 +32,9 @@ namespace { using ::testing::ElementsAreArray; using ::testing::NotNull; +using ::testing::Test; -class 
CreateI8F32UniformQuantizedTypeTest : public ::testing::Test { +class CreateI8F32UniformQuantizedTypeTest : public Test { protected: CreateI8F32UniformQuantizedTypeTest() : ctx_() { ctx_.loadDialect(); @@ -42,7 +43,7 @@ class CreateI8F32UniformQuantizedTypeTest : public ::testing::Test { MLIRContext ctx_; }; -TEST_F(CreateI8F32UniformQuantizedTypeTest, HasI8StorageType) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, I8StorageTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -50,7 +51,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasI8StorageType) { EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, HasF32ExpressedType) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, F32ExpressedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -58,7 +59,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasF32ExpressedType) { EXPECT_TRUE(quantized_type.getExpressedType().isF32()); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, SignedQuantizedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -66,7 +67,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) { EXPECT_TRUE(quantized_type.isSigned()); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, StrageTypeMinMaxEqualToI8MinMax) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, StorageTypeMinMaxEqualToI8MinMax) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -84,7 +85,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { EXPECT_EQ(quantized_type.getZeroPoint(), 99); } -class CreateI32F32UniformQuantizedTypeTest : public ::testing::Test { +class CreateI32F32UniformQuantizedTypeTest : public Test { protected: CreateI32F32UniformQuantizedTypeTest() : ctx_() { ctx_.loadDialect(); @@ -93,7 +94,7 @@ class CreateI32F32UniformQuantizedTypeTest : public ::testing::Test { MLIRContext ctx_; }; -TEST_F(CreateI32F32UniformQuantizedTypeTest, HasI32StorageType) { +TEST_F(CreateI32F32UniformQuantizedTypeTest, I32StorageTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -101,7 +102,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasI32StorageType) { EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); } -TEST_F(CreateI32F32UniformQuantizedTypeTest, HasF32ExpressedType) { +TEST_F(CreateI32F32UniformQuantizedTypeTest, F32ExpressedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -109,7 +110,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasF32ExpressedType) { EXPECT_TRUE(quantized_type.getExpressedType().isF32()); } -TEST_F(CreateI32F32UniformQuantizedTypeTest, IsSigned) { +TEST_F(CreateI32F32UniformQuantizedTypeTest, SignedQuantizedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -118,7 +119,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, IsSigned) { } 
TEST_F(CreateI32F32UniformQuantizedTypeTest, - SotrageTypeMinMaxEqualToI32MinMax) { + StorageTypeMinMaxEqualToI32MinMax) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -138,7 +139,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { EXPECT_EQ(quantized_type.getZeroPoint(), 1111); } -class CreateI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { +class CreateI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: CreateI8F32UniformQuantizedPerAxisTypeTest() : ctx_() { ctx_.loadDialect(); @@ -147,7 +148,7 @@ class CreateI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { MLIRContext ctx_; }; -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasI8StorageType) { +TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, I8StorageTypeSucceeds) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, @@ -158,7 +159,7 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasI8StorageType) { EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasF32ExpressedType) { +TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, F32ExpressedTypeSucceeds) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, @@ -169,7 +170,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasF32ExpressedType) { EXPECT_TRUE(quantized_type.getExpressedType().isF32()); } -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, IsSigned) { +TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, + SignedQuantizedTypeSucceeds) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, @@ -218,7 +220,7 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); } -class IsI8F32UniformQuantizedTypeTest : public ::testing::Test { +class IsI8F32UniformQuantizedTypeTest : public Test { protected: IsI8F32UniformQuantizedTypeTest() { ctx_.loadDialect(); @@ -228,35 +230,35 @@ class IsI8F32UniformQuantizedTypeTest : public ::testing::Test { OpBuilder builder_{&ctx_}; }; -TEST_F(IsI8F32UniformQuantizedTypeTest, IsI8F32UniformQuantizedType) { +TEST_F(IsI8F32UniformQuantizedTypeTest, I8F32UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_TRUE(IsI8F32UniformQuantizedType(qi8_type)); } -TEST_F(IsI8F32UniformQuantizedTypeTest, IsQuantizedType) { +TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); } -TEST_F(IsI8F32UniformQuantizedTypeTest, IsStorageTypeI8) { +TEST_F(IsI8F32UniformQuantizedTypeTest, StorageTypeI8Succeeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_TRUE(IsStorageTypeI8(qi8_type)); } 
-TEST_F(IsI8F32UniformQuantizedTypeTest, IsExpressedTypeF32) { +TEST_F(IsI8F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_TRUE(IsExpressedTypeF32(qi8_type)); } -class IsI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { +class IsI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: IsI8F32UniformQuantizedPerAxisTypeTest() { ctx_.loadDialect(); @@ -267,7 +269,7 @@ class IsI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { }; TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, - IsI8F32UniformQuantizedPerAxisType) { + I8F32UniformQuantizedPerAxisTypeSucceeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), @@ -278,7 +280,7 @@ TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, EXPECT_FALSE(IsI8F32UniformQuantizedType(qi8_per_axis_type)); } -TEST_F(IsI8F32UniformQuantizedTypeTest, IsQuantizedPerAxisType) { +TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), @@ -289,7 +291,7 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, IsQuantizedPerAxisType) { NotNull()); } -TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, IsStorageTypeI8) { +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), @@ -299,7 +301,7 @@ TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, IsStorageTypeI8) { EXPECT_TRUE(IsStorageTypeI8(qi8_per_axis_type)); } -TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, IsExpressedTypeF32) { +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), @@ -309,7 +311,7 @@ TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, IsExpressedTypeF32) { EXPECT_TRUE(IsExpressedTypeF32(qi8_per_axis_type)); } -class IsI32F32UniformQuantizedTypeTest : public ::testing::Test { +class IsI32F32UniformQuantizedTypeTest : public Test { protected: IsI32F32UniformQuantizedTypeTest() { ctx_.loadDialect(); @@ -319,28 +321,28 @@ class IsI32F32UniformQuantizedTypeTest : public ::testing::Test { OpBuilder builder_{&ctx_}; }; -TEST_F(IsI32F32UniformQuantizedTypeTest, IsI32F32UniformQuantizedType) { +TEST_F(IsI32F32UniformQuantizedTypeTest, I32F32UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); } -TEST_F(IsI32F32UniformQuantizedTypeTest, IsQuantizedType) { +TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); } -TEST_F(IsI32F32UniformQuantizedTypeTest, 
IsStorageTypeI32) { +TEST_F(IsI32F32UniformQuantizedTypeTest, StorageTypeI32Succeeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); EXPECT_TRUE(IsStorageTypeI32(qi32_type)); } -TEST_F(IsI32F32UniformQuantizedTypeTest, IsExpressedTypeF32) { +TEST_F(IsI32F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedType qi32_per_axis_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), @@ -349,7 +351,7 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, IsExpressedTypeF32) { EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); } -class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public ::testing::Test { +class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public Test { protected: IsSupportedByTfliteQuantizeOrDequantizeOpsTest() { ctx_.loadDialect(); @@ -359,7 +361,7 @@ class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public ::testing::Test { OpBuilder builder_{&ctx_}; }; -TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsI8) { +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI8Succeeds) { auto qi8_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/true), builder_.getF32Type(), /*scale=*/1.0, @@ -368,7 +370,7 @@ TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsI8) { dyn_cast_or_null(qi8_type.getStorageType()))); } -TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsI16) { +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI16Succeeds) { auto qi16_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getIntegerType(16, /*isSigned=*/true), builder_.getF32Type(), /*scale=*/1.0, @@ -377,7 +379,7 @@ TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsI16) { dyn_cast_or_null(qi16_type.getStorageType()))); } -TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsUI8) { +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeUI8Succeeds) { auto qi8_type = quant::UniformQuantizedType::get( /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/false), builder_.getF32Type(), /*scale=*/1.0, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc index 03495d3ddae7aa..87d71438cf4e7c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc @@ -90,7 +90,7 @@ std::unique_ptr CreateContext() { return context; } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToUQ8) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToUQ8Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get( {2, 2}, quant::UniformQuantizedType::get( @@ -109,7 +109,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToUQ8) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToInt8) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToInt8Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get({2, 2}, IntegerType::get(context.get(), 8)); @@ -125,7 +125,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToInt8) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToUQ32) { 
+TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToUQ32Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get( {2, 2}, @@ -145,7 +145,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToUQ32) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToInt32) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToInt32Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get({2, 2}, IntegerType::get(context.get(), 32)); @@ -161,7 +161,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToInt32) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, UnsupportedQint16) { +TEST(GetDenseAttrFromTensorProtoAttrTest, UnsupportedQint16Fails) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get({2, 2}, IntegerType::get(context.get(), 16)); @@ -170,7 +170,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, UnsupportedQint16) { GetDenseAttrFromTensorProtoAttr(GetQint16Tensor(), result_tensor_type))); } -TEST(IsTFQintTypeTest, IsTFQintType) { +TEST(IsTFQintTypeTest, ValidTFQintTypeSucceeds) { auto context = CreateContext(); EXPECT_TRUE(IsTFQintType(TF::Qint8Type::get(context.get()))); @@ -183,7 +183,7 @@ TEST(IsTFQintTypeTest, IsTFQintType) { EXPECT_FALSE(IsTFQintType(TF::Float8E5M2RefType::get(context.get()))); } -TEST(GetIntTypeFromTFQintTest, GetIntTypeFromTFQint) { +TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) { auto context = CreateContext(); auto type = GetIntTypeFromTFQint(TF::Qint8Type::get(context.get())); From 160645ecf9c6e51e107b2a4baf6d9e2d60d43e95 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Mon, 27 Nov 2023 18:52:14 -0800 Subject: [PATCH 122/381] [xla:gpu] NFC: move thunk->CMD conversion to runtime3 #6528 PiperOrigin-RevId: 585818812 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 28 +-------- .../xla/xla/service/gpu/runtime3/BUILD | 15 +++++ .../runtime3/command_buffer_cmd_emitter.cc | 61 +++++++++++++++++++ .../gpu/runtime3/command_buffer_cmd_emitter.h | 30 +++++++++ 5 files changed, 108 insertions(+), 27 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc create mode 100644 third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 1a2c311dead69c..dde94e3aee7bbb 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -332,6 +332,7 @@ cc_library( "//xla/service/gpu/kernels:custom_fusion", "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/runtime3:command_buffer_cmd", + "//xla/service/gpu/runtime3:command_buffer_cmd_emitter", "//xla/service/gpu/runtime3:command_buffer_thunk", "//xla/service/gpu/runtime3:custom_call_thunk", "//xla/service/gpu/runtime3:fft_thunk", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index cd3c05a63d369d..9e25ad8bddd53e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -132,6 +132,7 @@ limitations under the License. 
#include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/replica_id_thunk.h" #include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/runtime3/command_buffer_cmd_emitter.h" #include "xla/service/gpu/runtime3/command_buffer_thunk.h" #include "xla/service/gpu/runtime3/custom_call_thunk.h" #include "xla/service/gpu/runtime3/fft_thunk.h" @@ -372,33 +373,6 @@ int DeriveNumWarpsFromTritonSoftmaxComputation( return num_warps; } -StatusOr> ConvertToCommand( - const Thunk& thunk) { - switch (thunk.kind()) { - // TODO(anlunx): Support other thunk kinds. - case Thunk::Kind::kKernel: { - auto& kernel_thunk = static_cast(thunk); - auto kernel_cmd = std::make_unique( - kernel_thunk.kernel_name(), kernel_thunk.arguments(), - kernel_thunk.launch_dimensions(), kernel_thunk.shmem_bytes()); - return kernel_cmd; - } - default: - return InternalError("Unsupported thunk kind"); - } -} - -StatusOr ConvertToCommands( - const ThunkSequence& sequence) { - CommandBufferCmdSequence cmd_sequence; - for (const std::unique_ptr& thunk : sequence) { - TF_ASSIGN_OR_RETURN(std::unique_ptr cmd, - ConvertToCommand(*thunk)); - cmd_sequence.Append(std::move(cmd)); - } - return cmd_sequence; -} - } // namespace IrEmitterUnnested::IrEmitterUnnested(IrEmitterContext* ir_emitter_context) diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index 28cadbf3d4f503..45e0951d39e004 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -45,6 +45,21 @@ cc_library( ], ) +cc_library( + name = "command_buffer_cmd_emitter", + srcs = ["command_buffer_cmd_emitter.cc"], + hdrs = ["command_buffer_cmd_emitter.h"], + visibility = ["//visibility:public"], + deps = [ + ":command_buffer_cmd", + "//xla:statusor", + "//xla:util", + "//xla/service/gpu:gpu_executable", + "//xla/service/gpu:thunk", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_test( name = "command_buffer_cmd_test", srcs = if_gpu_is_configured(["command_buffer_cmd_test.cc"]), diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc new file mode 100644 index 00000000000000..23e2d5407f3e74 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc @@ -0,0 +1,61 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/runtime3/command_buffer_cmd_emitter.h" + +#include +#include + +#include "xla/service/gpu/kernel_thunk.h" +#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/thunk.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +namespace { + +StatusOr> ConvertToCommand( + const Thunk& thunk) { + switch (thunk.kind()) { + // TODO(anlunx): Support other thunk kinds. + case Thunk::Kind::kKernel: { + auto& kernel_thunk = static_cast(thunk); + auto kernel_cmd = std::make_unique( + kernel_thunk.kernel_name(), kernel_thunk.arguments(), + kernel_thunk.launch_dimensions(), kernel_thunk.shmem_bytes()); + return kernel_cmd; + } + default: + return InternalError("Unsupported thunk kind"); + } +} + +} // namespace + +StatusOr ConvertToCommands( + const ThunkSequence& sequence) { + CommandBufferCmdSequence cmd_sequence; + for (const std::unique_ptr& thunk : sequence) { + TF_ASSIGN_OR_RETURN(std::unique_ptr cmd, + ConvertToCommand(*thunk)); + cmd_sequence.Append(std::move(cmd)); + } + return cmd_sequence; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h new file mode 100644 index 00000000000000..f5abc2cd5d3f1a --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ +#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ + +#include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include "xla/service/gpu/thunk.h" +#include "xla/statusor.h" + +namespace xla::gpu { + +StatusOr ConvertToCommands( + const ThunkSequence& sequence); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ From 17e9908e2549fe2c024a4ecf0533f835f586a0ce Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 19:21:33 -0800 Subject: [PATCH 123/381] [stream_executor] NFC: Do not leak internal stream executor header PiperOrigin-RevId: 585823405 --- third_party/xla/xla/stream_executor/cuda/BUILD | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index c1261666286d42..2ea95309142e5e 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -79,9 +79,9 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "//xla/stream_executor", # buildcleaner: keep + "//xla/stream_executor", + "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/platform", "@local_tsl//tsl/platform:errors", @@ -363,6 +363,7 @@ cc_library( "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor", + "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor:fft", "//xla/stream_executor:plugin_registry", "//xla/stream_executor/gpu:gpu_executor_header", @@ -416,7 +417,7 @@ cc_library( "@local_config_cuda//cuda:cudnn_header", "//xla/stream_executor:dnn", "//xla/stream_executor:plugin_registry", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_timer_header", From e19cad0b8aa85ffcb45a80b0efe17a22953eece0 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Mon, 27 Nov 2023 19:36:32 -0800 Subject: [PATCH 124/381] [stream_executor] Do not use KernelArgs directly in GpuExecutor::Launch PiperOrigin-RevId: 585825723 --- .../xla/xla/stream_executor/cuda/cuda_executor.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 643435a17ee2b8..500ac1789e495d 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -444,10 +444,11 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, CHECK_EQ(kernel.Arity() + (packed.number_of_shared_bytes() > 0), packed.number_of_arguments()); void** params = const_cast(packed.argument_addresses().data()); - return GpuDriver::LaunchKernel( - context_, kernel.name(), cufunc, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - args.number_of_shared_bytes(), custream, params, nullptr /* = extra */); + return GpuDriver::LaunchKernel(context_, kernel.name(), cufunc, + block_dims.x, 
block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), custream, + params, nullptr /* = extra */); }; // If arguments are already packed we can just launch the kernel. From 3e5d6460e63262eee6cbe59e8eea5939e78ba5a8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 19:41:31 -0800 Subject: [PATCH 125/381] [stream_executor] Add While conditional command to CommandBuffer #6973 PiperOrigin-RevId: 585826616 --- .../xla/xla/stream_executor/command_buffer.cc | 7 ++ .../xla/xla/stream_executor/command_buffer.h | 35 ++++++--- .../cuda/cuda_command_buffer_test.cc | 73 ++++++++++++++++++- .../cuda/cuda_conditional_kernels.cu.cc | 8 ++ .../cuda/cuda_test_kernels.cu.cc | 8 ++ .../stream_executor/cuda/cuda_test_kernels.h | 4 + .../stream_executor/gpu/gpu_command_buffer.cc | 38 ++++++++++ .../stream_executor/gpu/gpu_command_buffer.h | 8 ++ .../rocm/hip_conditional_kernels.cu.cc | 3 + .../stream_executor_internal.h | 16 ++++ 10 files changed, 189 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index b27ac616cda856..28fc7533fe2a49 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -141,6 +141,13 @@ tsl::Status CommandBuffer::For(StreamExecutor* executor, int32_t num_iteration, std::move(body_builder)); } +tsl::Status CommandBuffer::While(StreamExecutor* executor, + DeviceMemory pred, Builder cond_builder, + Builder body_builder) { + return implementation_->While(executor, pred, std::move(cond_builder), + std::move(body_builder)); +} + CommandBuffer::Mode CommandBuffer::mode() const { return implementation_->mode(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 1c064460b369e8..0a69a463167e30 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -127,30 +127,45 @@ class CommandBuffer { // Command buffer condtitional commands API //--------------------------------------------------------------------------// - // Adds a conditional operation that will run a command buffer constructed by - // `then_builder` if `predicate` value is `true`. + // Adds a conditional operation that will execute a command buffer constructed + // by `then_builder` if `pred` value is `true`. tsl::Status If(StreamExecutor* executor, DeviceMemory pred, Builder then_builder); - // Adds a conditional operation that will run a command buffer constructed by - // `then_builder` if `predicate` value is `true`, or a command buffer - // constructed by `else_builder` if `predicate` is `false`. + // Adds a conditional operation that will execute a command buffer constructed + // by `then_builder` if `pred` value is `true`, or a command buffer + // constructed by `else_builder` if `pred` is `false`. tsl::Status IfElse(StreamExecutor* executor, DeviceMemory pred, Builder then_builder, Builder else_builder); - // Adds a conditional operation that will run a command buffer constructed by - // the `branches` builder at `index`. If `index` is out of range, then it will - // run a conditional command buffer constructed by the last builder. + // Adds a conditional operation that will execute a command buffer constructed + // by the `branches` builder at `index`. 
If `index` is out of range, then it + // will run a conditional command buffer constructed by the last builder. // // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case tsl::Status Case(StreamExecutor* executor, DeviceMemory index, std::vector branches); - // Adds a conditional operation that will run a command buffer constructed by - // the `body_builder` exactly `num_iteration` times. + // Adds a conditional operation that will execute a command buffer constructed + // by the `body_builder` exactly `num_iteration` times. tsl::Status For(StreamExecutor* executor, int32_t num_iteration, DeviceMemory loop_index, Builder body_builder); + // Adds a conditional operation that will execute a command buffer constructed + // by the `cond_builder` that must update `pred` value, and then depending on + // the value might execute command buffer constructed by `body_builder` and + // `cond_builder`. Will continue while `pred` value is `true`. + // + // In pseudocode: + // + // cond_builder() + // while(pred): + // body_builder() + // cond_builder() + // + tsl::Status While(StreamExecutor* executor, DeviceMemory pred, + Builder cond_builder, Builder body_builder); + //--------------------------------------------------------------------------// // Finalizes command buffer and makes it executable. Once command buffer is diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 8322fa0a672263..3f9b94a0e39ae2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "absl/log/check.h" -#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_test_kernels.h" #include "xla/stream_executor/kernel.h" @@ -38,6 +37,8 @@ using AddI32Kernel = TypedKernel, DeviceMemory, DeviceMemory>; using MulI32Kernel = TypedKernel, DeviceMemory, DeviceMemory>; +using IncAndCmpKernel = + TypedKernel, DeviceMemory, int32_t>; using AddI32Ptrs3 = TypedKernel>; @@ -564,6 +565,76 @@ TEST(CudaCommandBufferTest, ConditionalFor) { ASSERT_EQ(dst, expected); } +TEST(CudaCommandBufferTest, ConditionalWhile) { + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + if (!CommandBuffer::SupportsConditionalCommands(platform)) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + AddI32Kernel add(executor); + IncAndCmpKernel inc_and_cmp(executor); + + { // Load addition kernel. + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); + TF_ASSERT_OK(executor->GetKernel(spec, &add)); + } + + { // Load inc_and_cmp kernel. 
+ MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetIncAndCmpCudaKernel(), "inc_and_cmp"); + TF_ASSERT_OK(executor->GetKernel(spec, &inc_and_cmp)); + } + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=0, loop_index=1, pred=true + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory loop_index = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + + static constexpr bool kTrue = true; + stream.ThenMemcpy(&pred, &kTrue, 1); + stream.ThenMemset32(&loop_index, 1, sizeof(int32_t)); + stream.ThenMemset32(&a, 1, byte_length); + stream.ThenMemZero(&b, byte_length); + + int32_t num_iters = 10; + + // Loop cond: loop_index++ < num_iters; + CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { + return cond_cmd->Launch(inc_and_cmp, ThreadDim(), BlockDim(), loop_index, + pred, num_iters); + }; + + // Loop body: b = a + b + CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { + return body_cmd->Launch(add, ThreadDim(), BlockDim(length), a, b, b); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.While(executor, pred, cond_builder, body_builder)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 42); + stream.ThenMemcpy(dst.data(), b, byte_length); + + std::vector expected = {10, 10, 10, 10}; + ASSERT_EQ(dst, expected); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc index 8796c17461febb..09475192d460bf 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -22,6 +22,9 @@ namespace stream_executor { namespace cuda { namespace { +// In all kernels defined above we set conditional handle value to `1` when we +// want to execute a CUDA graph tied to it, and to `0` otherwise. + #if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) && \ CUDA_VERSION >= 12030 @@ -115,6 +118,11 @@ void* GetSetForConditionKernel() { return reinterpret_cast(&cuda::SetForCondition); } +void* GetSetWhileConditionKernel() { + // While condition kernel is the same as an `If` with a single branch. 
+ return reinterpret_cast(&cuda::SetIfCondition); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc index 171deb030e0b83..84b4e0a8d4d5c0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc @@ -29,6 +29,12 @@ __global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) { c[index] = a[index] * b[index]; } +__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t value) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + pred[index] = counter[index] < value; + counter[index] += 1; +} + __global__ void AddI32Ptrs3(Ptrs3 ptrs) { int index = threadIdx.x + blockIdx.x * blockDim.x; ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; @@ -38,6 +44,8 @@ void* GetAddI32CudaKernel() { return reinterpret_cast(&AddI32); } void* GetMulI32CudaKernel() { return reinterpret_cast(&MulI32); } +void* GetIncAndCmpCudaKernel() { return reinterpret_cast(&IncAndCmp); } + void* GetAddI32Ptrs3CudaKernel() { return reinterpret_cast(&AddI32Ptrs3); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h index c02682b22f9d13..94014f5d76092f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h @@ -88,6 +88,10 @@ void* GetAddI32CudaKernel(); // Returns a pointer to device kernel doing multiplication instead of addition. void* GetMulI32CudaKernel(); +// Returns a pointer to device kernel doing increment and compare, intended for +// testing on-device while loops. +void* GetIncAndCmpCudaKernel(); + // Returns a pointer to device kernel compiled from the CUDA C++ but with all // three pointers passed to argument as an instance of `Ptr3` template to test // StreamExecutor arguments packing for custom C++ types. diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 7afd7c11c1cd42..7737c73aa29058 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -550,6 +550,44 @@ tsl::Status GpuCommandBuffer::For(StreamExecutor* executor, return CreateConditionalCommand(ConditionType::kWhile, set_cond_fn, builders); } +tsl::Status GpuCommandBuffer::While(StreamExecutor* executor, + DeviceMemory pred, + CommandBuffer::Builder cond_builder, + CommandBuffer::Builder body_builder) { + DCHECK(executor->implementation() == parent_); + + // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on + // every call to `While`. + SetWhileConditionKernel set_while_condition(executor); + + { // Load kernels that updates condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddInProcessSymbol(gpu::GetSetWhileConditionKernel(), + "set_while_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_while_condition)); + } + + // TODO(ezhulenev): We assume that `pred` already has a value that decides if + // we should go into the first loop iteration. Instead we should run + // `cond_builder` to update primary command buffer. 
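+  // The `set_cond_fn` below seeds the conditional handle from the current
+  // value of `pred` before the first iteration. Each recorded body iteration
+  // then runs `body_builder`, re-runs `cond_builder` (which is expected to
+  // update `pred`), and launches `set_while_condition` once more so the
+  // handle reflects the refreshed predicate before the next iteration.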
+ + auto set_cond_fn = [&](absl::Span handles) { + return Launch(set_while_condition, ThreadDim(), BlockDim(), handles[0], + pred); + }; + + auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { + TF_RETURN_IF_ERROR(body_builder(body)); + TF_RETURN_IF_ERROR(cond_builder(body)); + return body->Launch(set_while_condition, ThreadDim(), BlockDim(), handle, + pred); + }; + + std::array builders = {std::move(body)}; + + return CreateConditionalCommand(ConditionType::kWhile, set_cond_fn, builders); +} + tsl::Status GpuCommandBuffer::Finalize() { TF_RETURN_IF_ERROR(CheckNotFinalized()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index d00bb2ab69408d..4ac4329e53df02 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -74,6 +74,10 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { DeviceMemory loop_index, CommandBuffer::Builder body_builder) override; + tsl::Status While(StreamExecutor* executor, DeviceMemory pred, + CommandBuffer::Builder cond_builder, + CommandBuffer::Builder body_builder) override; + tsl::Status Finalize() override; tsl::Status Update() override; @@ -132,6 +136,9 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { using SetForConditionKernel = TypedKernel, int32_t>; + using SetWhileConditionKernel = + TypedKernel>; + // A callback to launch a kernel that updates conditional handles state. using SetConditionFn = std::function)>; @@ -274,6 +281,7 @@ void* GetSetIfConditionKernel(); void* GetSetIfElseConditionKernel(); void* GetSetCaseConditionKernel(); void* GetSetForConditionKernel(); +void* GetSetWhileConditionKernel(); } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc index a8cbe5376ae0b9..654b9f02c4879e 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc @@ -37,6 +37,9 @@ void* GetSetCaseConditionKernel() { void* GetSetForConditionKernel() { return reinterpret_cast(&rocm::SetCondition); } +void* GetSetWhileConditionKernel() { + return reinterpret_cast(&rocm::SetCondition); +} } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index af5f14427f2b4f..e21adebc37665e 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -173,6 +173,22 @@ class CommandBufferInterface { DeviceMemory loop_index, CommandBuffer::Builder body_builder) = 0; + // Adds a conditional operation that will execute a command buffer constructed + // by the `cond_builder` that must update `pred` value, and then depending on + // the value might execute command buffer constructed by `body_builder` and + // `cond_builder`. Will continue while `pred` value is `true`. + // + // In pseudocode: + // + // cond_builder() + // while(pred): + // body_builder() + // cond_builder() + // + virtual tsl::Status While(StreamExecutor* executor, DeviceMemory pred, + CommandBuffer::Builder cond_builder, + CommandBuffer::Builder body_builder) = 0; + // Finalizes command buffer and makes it executable. 
Once command buffer is // finalized no commands can be added to it. virtual tsl::Status Finalize() = 0; From 1f5111710d31cec6ab819eb8ed0f0332a8ac0835 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 20:09:05 -0800 Subject: [PATCH 126/381] [stream_executor] NFC: Do not leak internal stream executor header PiperOrigin-RevId: 585832869 --- third_party/xla/xla/stream_executor/BUILD | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 638fbaca0278ae..66c87288ee136d 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -75,7 +75,6 @@ filegroup( "stream.h", "stream_executor.h", "stream_executor_internal.h", # TODO(ezhulenev): Remove private header - "stream_executor_pimpl.h", # TODO(ezhulenev): Remove private header "temporary_device_memory.h", "temporary_memory_manager.h", "trace_listener.h", @@ -133,7 +132,9 @@ cc_library( ":stream_executor_plugin_headers", ], visibility = ["//visibility:public"], - deps = STREAM_EXECUTOR_DEPENDENCIES + if_static([ + deps = STREAM_EXECUTOR_DEPENDENCIES + [ + ":stream_executor_pimpl", + ] + if_static([ ":stream_executor_impl", "@com_google_protobuf//:protobuf", # indirectly-used by dnn.h ]), @@ -387,6 +388,7 @@ cc_library( cc_library( name = "stream_executor_headers", hdrs = [ + "stream_executor_pimpl.h", # TODO(ezhulenev): Remove internal header ":stream_executor_api_headers", ":stream_executor_plugin_headers", ], From 3e93ecc2ab5555bb8e71917982cf579ef77346bb Mon Sep 17 00:00:00 2001 From: Yishuang Pang Date: Mon, 27 Nov 2023 20:41:14 -0800 Subject: [PATCH 127/381] Adds tests for AbsOp, DivOp, ExpOp, MaxOp, MulOp, PowOp, SubOp in `xla_builder_test` and `shape_inference_test`. The tests test shape inference for unbounded dynamism. cl/580710576 adds the necessary changes in XlaBuilder API, verifier, and shape inference following StableHLO rules for unbounded dynamism. Implicit broadcasting support in XlaBuilder API will be addressed in a follow up CL. 
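The expected shapes in these tests encode the per-dimension rule for mixing an unbounded dynamic dimension (`?`) with static and bounded (`<=B`) dimensions when the operands have the same rank and `broadcast_dimensions` is empty. As a rough, stand-alone illustration only (this is not XLA's shape-inference code, and it handles just the pairings the tests exercise):

  #include <cassert>
  #include <cstdint>
  #include <optional>

  // One dimension of a possibly-dynamic shape: a set `size` means static,
  // a set `bound` means bounded dynamic (<=bound), neither means unbounded (?).
  struct Dim {
    std::optional<int64_t> size;
    std::optional<int64_t> bound;
  };

  Dim Static(int64_t n) { return {n, std::nullopt}; }
  Dim Bounded(int64_t b) { return {std::nullopt, b}; }
  Dim Unbounded() { return {std::nullopt, std::nullopt}; }

  // ? paired with 1 stays ? (the 1 may still broadcast); ? paired with a
  // static N>1 or a bound <=B takes the known side; ? with ? stays ?.
  Dim MergeDim(const Dim& a, const Dim& b) {
    auto unbounded = [](const Dim& d) { return !d.size && !d.bound; };
    if (unbounded(a) && unbounded(b)) return Unbounded();
    const Dim& known = unbounded(a) ? b : a;
    if (known.size && *known.size == 1) return Unbounded();
    return known;
  }

  int main() {
    // Mirrors expected result shapes such as f32[?, ?, 2, 2, <=2, <=2, ?]
    // in the binary-op tests below.
    assert(!MergeDim(Static(1), Unbounded()).size);         // -> ?
    assert(*MergeDim(Unbounded(), Static(2)).size == 2);    // -> 2
    assert(*MergeDim(Bounded(2), Unbounded()).bound == 2);  // -> <=2
    assert(!MergeDim(Unbounded(), Unbounded()).bound);      // -> ?
    return 0;
  }

Compile with any C++17 compiler and run; the asserts cover the same pairings the passing tests check, while the rank-mismatched implicit-broadcast cases are the ones the tests expect to fail with "Unbounded dynamic shapes not supported".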
PiperOrigin-RevId: 585838056 --- .../xla/xla/client/xla_builder_test.cc | 190 ++++++++++++++++++ .../xla/xla/service/shape_inference_test.cc | 145 +++++++++++++ 2 files changed, 335 insertions(+) diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index 2fc57142837115..8979a6f2b8e461 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -1558,6 +1558,21 @@ TEST_F(XlaBuilderTest, TopKDimensions) { EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(1), k); } +TEST_F(XlaBuilderTest, UnboundedAbs) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr expected = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Abs(Parameter(&b, 0, operand.value(), "operand")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + TEST_F(XlaBuilderTest, UnboundedAdd) { XlaBuilder b(TestName()); StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); @@ -1590,5 +1605,180 @@ TEST_F(XlaBuilderTest, UnboundedAddUnsupportedImplicitBroadcast) { HasSubstr("Unbounded dynamic shapes not supported")); } +TEST_F(XlaBuilderTest, UnboundedDiv) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Div(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedDivUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Div(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedExp) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr expected = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Exp(Parameter(&b, 0, operand.value(), "operand")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedMax) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, 
?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Max(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedMaxUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Max(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedMul) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Mul(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedMulUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Mul(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedPow) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Pow(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedPowUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + 
ASSERT_IS_OK(rhs.status()); + Pow(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedSub) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Sub(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedSubUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Sub(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 229a1a02736747..2068bb0ec3661b 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -114,6 +114,10 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { class UnboundedBinaryOpShapeInferenceTest : public ::testing::TestWithParam> {}; +// Subclass for testing unbounded dynamic unary ops +class UnboundedUnaryOpShapeInferenceTest + : public ::testing::TestWithParam> {}; + TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status = @@ -3742,6 +3746,19 @@ INSTANTIATE_TEST_SUITE_P(All, ScatterShapeInferenceTest, BF16}), ScatterTestName()); +TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedAbs) { + StatusOr operand = ParseShape(GetParam()[0]); + StatusOr expected = ParseShape(GetParam()[1]); + ASSERT_IS_OK(operand.status()); + StatusOr inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kExp, operand.value()); + ASSERT_IS_OK(expected.status()); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { StatusOr lhs = ParseShape(GetParam()[0]); StatusOr rhs = ParseShape(GetParam()[1]); @@ -3762,6 +3779,119 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { } } +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { + auto lhs = ParseShape(GetParam()[0]); + auto rhs = ParseShape(GetParam()[1]); + auto expected = ParseShape(GetParam()[2]); + ASSERT_IS_OK(lhs.status()); + 
ASSERT_IS_OK(rhs.status()); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kDivide, lhs.value(), rhs.value(), + /*broadcast_dimensions=*/{}); + if (inferred_status.ok()) { + ASSERT_IS_OK(expected.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); + } else { + EXPECT_THAT(inferred_status.status().message(), + HasSubstr("Binary op divide with incompatible shapes")); + } +} + +TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedExp) { + auto operand = ParseShape(GetParam()[0]); + auto expected = ParseShape(GetParam()[1]); + ASSERT_IS_OK(operand.status()); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kExp, operand.value()); + ASSERT_IS_OK(expected.status()); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { + auto lhs = ParseShape(GetParam()[0]); + auto rhs = ParseShape(GetParam()[1]); + auto expected = ParseShape(GetParam()[2]); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kMaximum, lhs.value(), rhs.value(), + /*broadcast_dimensions=*/{}); + if (inferred_status.ok()) { + ASSERT_IS_OK(expected.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); + } else { + EXPECT_THAT(inferred_status.status().message(), + HasSubstr("Binary op maximum with incompatible shapes")); + } +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { + auto lhs = ParseShape(GetParam()[0]); + auto rhs = ParseShape(GetParam()[1]); + auto expected = ParseShape(GetParam()[2]); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kMultiply, lhs.value(), rhs.value(), + /*broadcast_dimensions=*/{}); + if (inferred_status.ok()) { + ASSERT_IS_OK(expected.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); + } else { + EXPECT_THAT(inferred_status.status().message(), + HasSubstr("Binary op multiply with incompatible shapes")); + } +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { + auto lhs = ParseShape(GetParam()[0]); + auto rhs = ParseShape(GetParam()[1]); + auto expected = ParseShape(GetParam()[2]); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, lhs.value(), rhs.value(), + /*broadcast_dimensions=*/{}); + if (inferred_status.ok()) { + ASSERT_IS_OK(expected.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); + } else { + EXPECT_THAT(inferred_status.status().message(), + HasSubstr("Binary op power with incompatible shapes")); + } +} + 
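+// UnboundedSub below mirrors the Div/Max/Mul/Pow cases above: parse the
+// parameterized shapes, run InferBinaryOpShape, and then either check the
+// inferred shape against the expected one or expect the "incompatible
+// shapes" error when inference fails.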
+TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { + auto lhs = ParseShape(GetParam()[0]); + auto rhs = ParseShape(GetParam()[1]); + auto expected = ParseShape(GetParam()[2]); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kSubtract, lhs.value(), rhs.value(), + /*broadcast_dimensions=*/{}); + if (inferred_status.ok()) { + ASSERT_IS_OK(expected.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); + } else { + EXPECT_THAT(inferred_status.status().message(), + HasSubstr("Binary op subtract with incompatible shapes")); + } +} + INSTANTIATE_TEST_SUITE_P( UnboundedDynamism, UnboundedBinaryOpShapeInferenceTest, ::testing::Values( @@ -3783,5 +3913,20 @@ INSTANTIATE_TEST_SUITE_P( // ?,2 | ?,3 | error std::vector({"f32[?,2]", "f32[?,3]", ""}))); +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, UnboundedUnaryOpShapeInferenceTest, + ::testing::Values( + // OPERAND | Result + // 1 | 1 + std::vector({"f32[1]", "f32[1]"}), + // 2 | 2 + std::vector({"f32[2]", "f32[2]"}), + // <=2 | <=2 + std::vector({"f32[<=2]", "f32[<=2]"}), + // ? | ? + std::vector({"f32[?]", "f32[?]"}), + // ?,3 | ?,3 + std::vector({"f32[?,3]", + "f32[?,3]"}))); + } // namespace } // namespace xla From 5ad0df07de341cd8160e04549fb008048097178b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 27 Nov 2023 22:45:01 -0800 Subject: [PATCH 128/381] [stream_executor] NFC: Restore private visibility of stream executor private targets PiperOrigin-RevId: 585856635 --- third_party/xla/xla/stream_executor/cuda/BUILD | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 2ea95309142e5e..dc91bfeb6eef1a 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -583,10 +583,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_absl//absl/strings:str_format", - "//xla/stream_executor:command_buffer", - "//xla/stream_executor:kernel", + "//xla/stream_executor", "//xla/stream_executor:plugin_registry", - "//xla/stream_executor:stream_executor", "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/gpu:asm_compiler", "//xla/stream_executor/gpu:gpu_command_buffer", From ebcfc558856ff3253d90f5724fc5f352f7cfc5a9 Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Mon, 27 Nov 2023 22:53:04 -0800 Subject: [PATCH 129/381] Add integration tests for gather and slice PiperOrigin-RevId: 585857981 --- .../python/integration_test/quantize_model_test.py | 2 ++ .../integration_test/quantize_model_test_base.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 9a796dedfe8e54..e92b996a1c5110 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -135,9 +135,11 @@ def data_gen() -> repr_dataset.RepresentativeDataset: parameter_combinations([{ 'same_scale_op': [ 'concatenate', + 'gather', 'pad', 'reshape', 'select', + 
'slice', 'transpose', ], }]) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index cd96da54209b5b..9600754c67e7d1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -181,6 +181,8 @@ def matmul_and_same_scale( if self.same_scale_op == 'concatenate': ones = array_ops.ones_like(out) out = array_ops.concat([out, ones], 0) + elif self.same_scale_op == 'gather': + out = array_ops.gather(out, indices=[0], axis=0) elif self.same_scale_op == 'pad': paddings = array_ops.ones( (array_ops.rank(out), 2), dtype=dtypes.int32 @@ -195,6 +197,14 @@ def matmul_and_same_scale( ) ones = array_ops.ones_like(out) out = math_ops.select(condition, out, ones) + elif self.same_scale_op == 'slice': + begin = array_ops.zeros( + (array_ops.rank(out)), dtype=dtypes.int32 + ) + size = array_ops.ones( + (array_ops.rank(out)), dtype=dtypes.int32 + ) + out = array_ops.slice(out, begin, size) elif self.same_scale_op == 'transpose': out = array_ops.transpose(out) else: From da795cd6073f56fcf6aa8b011e033f66a2ed37fd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 00:21:33 -0800 Subject: [PATCH 130/381] [XLA/GPU]Add methods to construct GpuCompilationEnvironment. -Create GpuCompEnv from flags vector, from env var, and with default values. -Extend command_line_flags to parse string flags from a vector. -Implement ProcessNewEnv for GpuCompEnv which merge added GpuCompEnv proto and EnvVar flags, and initializes missing flags with default values. -Added a dummy flag to implement tests. PiperOrigin-RevId: 585874452 --- .../tsl/tsl/util/command_line_flags.cc | 24 +++ .../tsl/tsl/util/command_line_flags.h | 5 + third_party/xla/xla/service/BUILD | 25 +++ .../service/gpu_compilation_environment.cc | 93 ++++++++++- .../xla/service/gpu_compilation_environment.h | 11 +- .../gpu_compilation_environment_test.cc | 151 ++++++++++++++++++ third_party/xla/xla/xla.proto | 6 +- 7 files changed, 305 insertions(+), 10 deletions(-) create mode 100644 third_party/xla/xla/service/gpu_compilation_environment_test.cc diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc index 4e69ea5638b8be..5e316e9ae9fc6a 100644 --- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc +++ b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tsl/util/command_line_flags.h" +#include #include #include #include @@ -291,6 +292,29 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { return result && (*argc < 2 || strcmp(argv[1], "--help") != 0); } +/*static*/ bool Flags::Parse(std::vector& flags, + const std::vector& flag_list) { + bool result = true; + std::vector unknown_flags; + for (auto& flag : flags) { + for (const Flag& flag_object : flag_list) { + bool value_parsing_ok; + bool was_found = flag_object.Parse(flag, &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + // Clear parsed flags, these empty entries are removed later. 
+ if (was_found) { + flag.clear(); + break; + } + } + } + auto IsEmpty = [](const std::string& flag) { return flag.empty(); }; + flags.erase(std::remove_if(flags.begin(), flags.end(), IsEmpty), flags.end()); + return result; +} + /*static*/ string Flags::Usage(const string& cmdline, const std::vector& flag_list) { string usage_text; diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h index 6553bc887c853e..2710de5753cd01 100644 --- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h +++ b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h @@ -132,6 +132,11 @@ class Flags { // first remaining argument is not "--help". static bool Parse(int* argc, char** argv, const std::vector& flag_list); + // Similar as above, but accepts a mutable vector of strings in place of + // argc and argv. Doesn't ignore the first flag, and return the unknown flags + // back in flags vector. + static bool Parse(std::vector& flags, + const std::vector& flag_list); // Return a usage message with command line cmdline, and the // usage_text strings in flag_list[]. static string Usage(const string& cmdline, diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 0f615afd3fad88..18750c8c4f9a89 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -7444,8 +7444,33 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":compilation_environments", + "//xla:parse_flags_from_env", + "//xla:statusor", + "//xla:util", "//xla:xla_proto_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/util:command_line_flags", + ], +) + +xla_cc_test( + name = "gpu_compilation_environment_test", + size = "small", + srcs = ["gpu_compilation_environment_test.cc"], + deps = [ + ":compilation_environments", + ":gpu_compilation_environment", + "//xla:parse_flags_from_env", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu_compilation_environment.cc b/third_party/xla/xla/service/gpu_compilation_environment.cc index 0b0fbfb6bb9986..8f5b0b3c31d128 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment.cc +++ b/third_party/xla/xla/service/gpu_compilation_environment.cc @@ -15,17 +15,67 @@ limitations under the License. #include "xla/service/gpu_compilation_environment.h" +#include #include +#include +#include +#include "absl/strings/str_join.h" +#include "xla/parse_flags_from_env.h" #include "xla/service/compilation_environments.h" -#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" +#include "tsl/util/command_line_flags.h" namespace xla { -// TODO(b/284274097): Create flags with default values when flags -// are moved from DebugOptions to GpuCompilationEnvironment. 
-std::unique_ptr CreateDefaultGpuCompEnv() { - return std::make_unique(); +void InitializeFlagsForGpuCompEnv(std::vector* flag_list, + GpuCompilationEnvironment* gpu_comp_env) { + auto int64_setter_for = + [gpu_comp_env]( + void (GpuCompilationEnvironment::*member_setter)(int64_t)) { + return [gpu_comp_env, member_setter](int64_t value) { + (gpu_comp_env->*member_setter)(value); + return true; + }; + }; + flag_list->push_back(tsl::Flag( + "dummy_flag", + int64_setter_for(&GpuCompilationEnvironment::set_dummy_flag), + gpu_comp_env->dummy_flag(), "Dummy flag to demonstrate the flow")); +} + +StatusOr CreateGpuCompEnvFromFlagStrings( + std::vector& flags, bool strict) { + GpuCompilationEnvironment gpu_comp_env; + std::vector flag_objects; + InitializeFlagsForGpuCompEnv(&flag_objects, &gpu_comp_env); + bool result = tsl::Flags::Parse(flags, flag_objects); + if (!result || (strict && !flags.empty())) { + return InvalidArgument("Could not parse flags: %s", + absl::StrJoin(flags, ", ")); + } + return gpu_comp_env; +} + +StatusOr CreateGpuCompEnvFromEnvVar() { + GpuCompilationEnvironment env; + std::vector flag_objects; + InitializeFlagsForGpuCompEnv(&flag_objects, &env); + bool result = ParseFlagsFromEnvAndIgnoreUnknown("XLA_FLAGS", flag_objects); + if (!result) { + return InvalidArgument("Could not parse XLA_FLAGS."); + } + return env; +} + +GpuCompilationEnvironment CreateGpuCompEnvWithDefaultValues() { + GpuCompilationEnvironment env; + env.set_dummy_flag(1); + return env; } namespace { @@ -36,10 +86,39 @@ namespace { // // The implementation returns Default env if one doesn't exist already. // NOLINTNEXTLINE -std::unique_ptr ProcessNewGpuCompilationEnvironment( +StatusOr> +ProcessNewGpuCompilationEnvironment( std::unique_ptr env) { // NOLINT if (!env) { - return xla::CreateDefaultGpuCompEnv(); + env = std::make_unique(); + } + TF_ASSIGN_OR_RETURN(GpuCompilationEnvironment from_env, + CreateGpuCompEnvFromEnvVar()); + + auto default_env = CreateGpuCompEnvWithDefaultValues(); + + auto reflection = env->GetReflection(); + auto reflection_from_env = from_env.GetReflection(); + auto descriptor = GpuCompilationEnvironment::descriptor(); + std::vector missing_fields; + + for (int j = 0; j < descriptor->field_count(); ++j) { + const tsl::protobuf::FieldDescriptor* field = descriptor->field(j); + if (reflection->HasField(*env, field) && + reflection_from_env->HasField(from_env, field)) { + return InvalidArgument( + "Flag %s is set in both XLA_FLAGS env var and " + "GpuCompilationEnvironment.", + field->name()); + } else if (!reflection->HasField(*env, field) && + !reflection_from_env->HasField(from_env, field)) { + missing_fields.push_back(field); + } + } + env->MergeFrom(from_env); + + if (!missing_fields.empty()) { + reflection->SwapFields(env.get(), &default_env, missing_fields); } return env; } diff --git a/third_party/xla/xla/service/gpu_compilation_environment.h b/third_party/xla/xla/service/gpu_compilation_environment.h index ade7536a78261a..99d93c185042ce 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment.h +++ b/third_party/xla/xla/service/gpu_compilation_environment.h @@ -15,13 +15,20 @@ limitations under the License. 
#ifndef XLA_SERVICE_GPU_COMPILATION_ENVIRONMENT_H_ #define XLA_SERVICE_GPU_COMPILATION_ENVIRONMENT_H_ -#include +#include +#include +#include "xla/statusor.h" #include "xla/xla.pb.h" namespace xla { -std::unique_ptr CreateDefaultGpuCompEnv(); +StatusOr CreateGpuCompEnvFromFlagStrings( + std::vector& flags, bool strict); + +StatusOr CreateGpuCompEnvFromEnvVar(); + +GpuCompilationEnvironment CreateGpuCompEnvWithDefaultValues(); } // namespace xla #endif // XLA_SERVICE_GPU_COMPILATION_ENVIRONMENT_H_ diff --git a/third_party/xla/xla/service/gpu_compilation_environment_test.cc b/third_party/xla/xla/service/gpu_compilation_environment_test.cc new file mode 100644 index 00000000000000..22efaa4a317d66 --- /dev/null +++ b/third_party/xla/xla/service/gpu_compilation_environment_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu_compilation_environment.h" + +#include +#include +#include +#include + +#include +#include +#include "xla/parse_flags_from_env.h" +#include "xla/service/compilation_environments.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using ::tsl::testing::StatusIs; + +void set_xla_flags_env_var(const std::string& xla_flags) { + int* pargc; + std::vector* pargv; + ResetFlagsFromEnvForTesting("XLA_FLAGS", &pargc, &pargv); + tsl::setenv("XLA_FLAGS", xla_flags.c_str(), true /*overwrite*/); +} + +TEST(CreateGpuCompEnvFromFlagStringsTest, ValidFlags) { + std::vector flags = {"--dummy_flag=2"}; + + TF_ASSERT_OK_AND_ASSIGN( + GpuCompilationEnvironment gpu_comp_env, + CreateGpuCompEnvFromFlagStrings(flags, /*strict=*/true)); + + ASSERT_EQ(gpu_comp_env.dummy_flag(), 2); + ASSERT_TRUE(flags.empty()); +} + +TEST(CreateGpuCompEnvFromFlagStringsTest, EmptyFlags) { + std::vector flags; + + TF_ASSERT_OK_AND_ASSIGN( + GpuCompilationEnvironment gpu_comp_env, + CreateGpuCompEnvFromFlagStrings(flags, /*strict=*/true)); +} + +TEST(CreateGpuCompEnvFromFlagStringsTest, InvalidFlagName) { + std::vector flags = {"--xla_gpu_invalid_flag=2"}; + + EXPECT_THAT(CreateGpuCompEnvFromFlagStrings(flags, /*strict=*/true), + StatusIs(tsl::error::INVALID_ARGUMENT)); + + TF_ASSERT_OK_AND_ASSIGN( + GpuCompilationEnvironment gpu_comp_env, + CreateGpuCompEnvFromFlagStrings(flags, /*strict=*/false)); + ASSERT_EQ(flags.size(), 1); +} + +TEST(CreateGpuCompEnvFromFlagStringsTest, InvalidFlagValue) { + std::vector flags = {"--dummy_flag=foo"}; + + EXPECT_THAT(CreateGpuCompEnvFromFlagStrings(flags, /*strict=*/false), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + +TEST(CreateGpuCompEnvFromEnvVarTest, ValidFlags) { + set_xla_flags_env_var("--dummy_flag=4"); + + TF_ASSERT_OK_AND_ASSIGN(GpuCompilationEnvironment gpu_comp_env, + CreateGpuCompEnvFromEnvVar()); + + 
ASSERT_EQ(gpu_comp_env.dummy_flag(), 4); +} + +TEST(CreateGpuCompEnvFromEnvVarTest, InvalidFlagValue) { + set_xla_flags_env_var("--dummy_flag=foo"); + + EXPECT_THAT(CreateGpuCompEnvFromEnvVar(), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + +TEST(ProcessNewEnvTest, BothProtoAndEnvVarUnset) { + set_xla_flags_env_var(""); + CompilationEnvironments envs; + + const auto& env = envs.GetEnv(); + + EXPECT_EQ(env.dummy_flag(), 1); +} + +TEST(ProcessNewEnvTest, ProtoSetButEnvVarUnset) { + set_xla_flags_env_var(""); + CompilationEnvironments envs; + { + auto env = std::make_unique(); + env->set_dummy_flag(2); + TF_ASSERT_OK(envs.AddEnv(std::move(env))); + } + const auto& env = envs.GetEnv(); + + EXPECT_EQ(env.dummy_flag(), 2); +} + +TEST(ProcessNewEnvTest, ProtoUnsetButEnvVarSet) { + set_xla_flags_env_var("--dummy_flag=4"); + CompilationEnvironments envs; + const auto& env = envs.GetEnv(); + + EXPECT_EQ(env.dummy_flag(), 4); +} + +TEST(ProcessNewEnvTest, BothProtoAndEnvVarSetButNoConflict) { + set_xla_flags_env_var("--dummy_flag=4"); + CompilationEnvironments envs; + { + auto env = std::make_unique(); + TF_ASSERT_OK(envs.AddEnv(std::move(env))); + } + const auto& env = envs.GetEnv(); + EXPECT_EQ(env.dummy_flag(), 4); +} + +TEST(ProcessNewEnvTest, BothProtoAndEnvVarSetWithConflict) { + set_xla_flags_env_var("--dummy_flag=4"); + + CompilationEnvironments envs; + auto env = std::make_unique(); + env->set_dummy_flag(2); + EXPECT_THAT(envs.AddEnv(std::move(env)), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 4a9b819ef1f90a..71f56835c3b112 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -695,7 +695,11 @@ message DebugOptions { // Contains flags which affects the GPU compilation result. // These flags are part of Debug Options as of now, and will be migrated to // this proto. -message GpuCompilationEnvironment {} +message GpuCompilationEnvironment { + // Temporary dummy flag is added to test the flow. + // To be removed when we add flags here. + int64 dummy_flag = 1; +} message ShardableValueUpdatePairProto { int64 input_parameter_number = 1; From bb8148cdfdc1254c29a0e3f31992f15588776bff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 01:02:05 -0800 Subject: [PATCH 131/381] Update GraphDef version to 1694. PiperOrigin-RevId: 585882366 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 8a73ca7c587313..d766021c9e5a68 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1693 // Updated: 2023/11/27 +#define TF_GRAPH_DEF_VERSION 1694 // Updated: 2023/11/28 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 99926785f7c9eaf53d94343916f14f300965ae72 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 28 Nov 2023 01:02:07 -0800 Subject: [PATCH 132/381] compat: Update forward compatibility horizon to 2023-11-28 PiperOrigin-RevId: 585882378 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 2941ade58de207..221e88fa1b8b0a 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 27) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 28) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From b37d189391243ecebb9f9b930862e5239f4f6296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Tue, 28 Nov 2023 01:32:02 -0800 Subject: [PATCH 133/381] [XLA:GPU] Remove xla_gpu_single_wave_autotuning flag That is the default behavior now. PiperOrigin-RevId: 585888948 --- third_party/xla/xla/debug_options_flags.cc | 8 ----- .../xla/xla/service/gpu/triton_autotuner.cc | 36 ++++++------------- third_party/xla/xla/xla.proto | 4 +-- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index b8673b2a43dd67..4a1d998e5d76a4 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -201,7 +201,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_collect_cost_model_stats(false); opts.set_xla_gpu_enable_split_k_autotuning(true); - opts.set_xla_gpu_single_wave_autotuning(true); opts.set_xla_gpu_enable_reduction_epilogue_fusion(true); opts.set_xla_gpu_enable_nccl_clique_optimization(false); opts.set_xla_gpu_cublas_fallback(true); @@ -1354,13 +1353,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_split_k_autotuning), debug_options->xla_gpu_enable_split_k_autotuning(), "Enable split_k autotuning for triton gemms.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_single_wave_autotuning", - bool_setter_for(&DebugOptions::set_xla_gpu_single_wave_autotuning), - debug_options->xla_gpu_single_wave_autotuning(), - "Enable single \"wave\" autotuning. 
This uses more memory for " - "compilation, but utilizes CPU cores better, so compilation can be " - "faster.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_reduction_epilogue_fusion", diff --git a/third_party/xla/xla/service/gpu/triton_autotuner.cc b/third_party/xla/xla/service/gpu/triton_autotuner.cc index 5909fe611e419b..8c18f253933c31 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner.cc @@ -801,10 +801,10 @@ StatusOr Execute(const AutotuneConfig& config, return best_triton; } -Status DumpAutotunedFusions(const AutotuneConfig& config, - AutotunerCompileUtil& util, - const AutotuneResult result, - const HloFusionInstruction* fusion, int fusion_id) { +Status DumpAutotunedFusion(const AutotuneConfig& config, + AutotunerCompileUtil& util, + const AutotuneResult result, + const HloFusionInstruction* fusion, int fusion_id) { TF_ASSIGN_OR_RETURN( std::unique_ptr module, util.ExtractModule([&](const DebugOptions& debug_opts) { @@ -831,8 +831,7 @@ Status Autotune(const AutotuneConfig& config, AutotunerCompileUtil& util, tsl::thread::ThreadPool* thread_pool, const DebugOptions& debug_opts, const absl::flat_hash_map& gemm_config_sets, - int& fusion_id_for_dump) { + GemmConfigSet>& gemm_config_sets) { absl::flat_hash_map executable_sets; TF_ASSIGN_OR_RETURN( @@ -850,6 +849,7 @@ Status Autotune(const AutotuneConfig& config, AutotunerCompileUtil& util, }); } + int fusion_id = 0; for (const auto& key_value : executable_sets) { const HloFusionInstruction* fusion = key_value.first; const ExecutableSet& executable_set = key_value.second; @@ -858,8 +858,8 @@ Status Autotune(const AutotuneConfig& config, AutotunerCompileUtil& util, fusion, executable_set)); if (debug_opts.xla_gpu_dump_autotuned_triton_fusions()) { - TF_RETURN_IF_ERROR(DumpAutotunedFusions(config, util, result, fusion, - fusion_id_for_dump)); + TF_RETURN_IF_ERROR( + DumpAutotunedFusion(config, util, result, fusion, fusion_id++)); } const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config); @@ -869,8 +869,6 @@ Status Autotune(const AutotuneConfig& config, AutotunerCompileUtil& util, LOG(WARNING) << "AutotunerUtil::AddResult already existed: " << key.ToString(); } - - fusion_id_for_dump += 1; } return OkStatus(); @@ -946,22 +944,8 @@ StatusOr TritonAutotuner::Run( VLOG(1) << "Autotuning " << gemm_config_sets.size() << " fusions " << correctness_check_str << "."; - int fusion_id_for_dump = 0; - if (debug_options.xla_gpu_single_wave_autotuning()) { - // Tune all fusions at once to save time. - TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, - debug_options, gemm_config_sets, - fusion_id_for_dump)); - } else { - // Tune each fusion separately to avoid running out of memory. 
- for (const auto& key_value : gemm_config_sets) { - absl::flat_hash_map - single_element_map({key_value}); - TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, - debug_options, single_element_map, - fusion_id_for_dump)); - } - } + TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, + debug_options, gemm_config_sets)); VLOG(1) << "Done autotuning."; } } diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 71f56835c3b112..39cd91b4940b0b 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -641,7 +641,6 @@ message DebugOptions { bool xla_gpu_enable_split_k_autotuning = 241; - bool xla_gpu_single_wave_autotuning = 242; // Whether reduction epilogue fusion is enabled in fusion passes. bool xla_gpu_enable_reduction_epilogue_fusion = 243; // Allow early return when acquiring NCCL cliques. @@ -689,7 +688,8 @@ message DebugOptions { // xla_gpu_allow_all_reduce_kernel // xla_gpu_enable_experimental_block_size // xla_gpu_graph_level - reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194; + // xla_gpu_single_wave_autotuning + reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242; } // Contains flags which affects the GPU compilation result. From d4fc2ec8b46c10e4ee9ff277558e95d236b850fa Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 28 Nov 2023 01:52:04 -0800 Subject: [PATCH 134/381] Add a vector->SCF pass to hlo_xla_runtime_pipeline. Without this pass some of the vector.transfers are not unrolled/converted and the pipeline can fail. PiperOrigin-RevId: 585893540 --- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 336069e8c9baf5..8d05c8ce160edb 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -500,6 +500,7 @@ cc_library( "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorToLLVM", + "@llvm-project//mlir:VectorToSCF", "@llvm-project//mlir:VectorTransforms", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index 5d96212cfba1d2..c6e7792e1cfe0a 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" // from @llvm-project @@ -292,6 +293,7 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(mlir::createConvertVectorToSCFPass()); pm.addNestedPass(xla::cpu::createLegalizeI1VectorTransferOpsPass()); pm.addNestedPass( xla::cpu::createConvertXlaCpuMemRefElementCastToLLVMPass()); From e0af6df361f29dc1b8c2d99ee9284b9aa59d74d7 Mon Sep 17 00:00:00 2001 From: sushreebarsa <84765720+sushreebarsa@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:45:50 +0530 Subject: [PATCH 135/381] Update image_ops_impl.py --- tensorflow/python/ops/image_ops_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 265293d1536775..6e19d47460ce0c 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -2030,7 +2030,7 @@ def random_brightness(image, max_delta, seed=None): Args: image: An image or images to adjust. - max_delta: float, must be non-negative. The max_delta parameter controls the maximum relative change in brightness. This means that the actual change in brightness will depend on the range of values in the input image. + max_delta: float, must be non-negative. This parameter controls the maximum relative change in brightness. seed: A Python integer. Used to create a random seed. See `tf.compat.v1.set_random_seed` for behavior. From 891be6245c80db066696834bed7ecac7118ac9a8 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 28 Nov 2023 03:36:52 -0800 Subject: [PATCH 136/381] [XLA:GPU] Add cache of runtime data for unfused instruction. 
PiperOrigin-RevId: 585916356 --- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../gpu/model/gpu_performance_model.cc | 50 +++++++++++++++++-- .../service/gpu/model/gpu_performance_model.h | 35 ++++++++++++- .../gpu/model/gpu_performance_model_test.cc | 10 ++-- .../xla/xla/service/gpu/priority_fusion.cc | 7 ++- 5 files changed, 92 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 7ef915f9fcaf6f..d8a78bfc3515d6 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -187,6 +187,7 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/time", ] + if_cuda_is_configured(xla_nvml_deps()), diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 1abceb893dcc56..0f759abd33e60e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -292,6 +292,30 @@ bool IsReadCoalesced(const std::optional& fusion_analysis, } // namespace +std::optional GpuPerformanceModelCache::Get( + const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + + auto it = instruction_runtime_data_.find(HloInstructionAdaptor(instruction)); + if (it != instruction_runtime_data_.end()) { + return it->second; + } + return std::nullopt; +} + +void GpuPerformanceModelCache::Set(const HloInstruction& instruction, + const EstimateRunTimeData& runtime_data) { + absl::MutexLock lock(&mutex_); + + instruction_runtime_data_[HloInstructionAdaptor(instruction)] = runtime_data; +} + +void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + + instruction_runtime_data_.erase(HloInstructionAdaptor(instruction)); +} + /*static*/ EstimateRunTimeData GpuPerformanceModel::EstimateRunTimeForInstruction( const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, @@ -337,6 +361,26 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( return {flops, bytes_written, num_threads, write_time, exec_time}; } +/*static*/ EstimateRunTimeData +GpuPerformanceModel::EstimateRunTimeForInstructionCached( + const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config) { + if (config.gpu_performance_model_cache) { + if (auto cached_result = config.gpu_performance_model_cache->Get(*instr)) { + return *cached_result; + } + } + + auto runtime_data = + EstimateRunTimeForInstruction(instr, cost_analysis, config); + + if (config.gpu_performance_model_cache) { + config.gpu_performance_model_cache->Set(*instr, runtime_data); + } + + return runtime_data; +} + // Returns utilization of operand by instruction. Returns 0, if the operand is // not used by the instruction. 
float GetOperandUtilization(const GpuHloCostAnalysis* cost_analysis, @@ -669,14 +713,14 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( } EstimateRunTimeData producer_runtime = - EstimateRunTimeForInstruction(producer, cost_analysis, config); + EstimateRunTimeForInstructionCached(producer, cost_analysis, config); std::vector consumer_runtimes; if (config.calculate_full_priority) { consumer_runtimes.reserve(fused_consumers.size()); for (auto* consumer : fused_consumers) { consumer_runtimes.push_back( - EstimateRunTimeForInstruction(consumer, cost_analysis, config)); + EstimateRunTimeForInstructionCached(consumer, cost_analysis, config)); } } @@ -715,7 +759,7 @@ void GpuPerformanceModel::RecordEstimatedRunTime( DCHECK(cost_analysis != nullptr) << "expected cost analysis"; EstimateRunTimeData data = - EstimateRunTimeForInstruction(instruction, cost_analysis, config); + EstimateRunTimeForInstructionCached(instruction, cost_analysis, config); double cycles = absl::ToDoubleNanoseconds(data.exec_time) * cost_analysis->device_info_->clock_rate_ghz(); diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index 5d62ab4351f2fe..8962a5d6b6d176 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -19,9 +19,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/time/time.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/stream_executor/device_description.h" @@ -55,6 +57,26 @@ struct EstimateRunTimeData { absl::Duration exec_time; }; +class GpuPerformanceModelCache { + public: + // Returns cached runtime data for the instruction. Returns nullopt if there + // is no data in cache. + std::optional Get(const HloInstruction& instruction); + + // Sets cache value for the instruction. + void Set(const HloInstruction& instruction, + const EstimateRunTimeData& runtime_data); + + // Removes all cache entries for this instruction. + void Invalidate(const HloInstruction& instruction); + + private: + absl::Mutex mutex_; + + absl::flat_hash_map + instruction_runtime_data_; +}; + struct GpuPerformanceModelOptions { // Whether to attempt to model the effect of uncoalesced reads. bool consider_coalescing = false; @@ -70,23 +92,28 @@ struct GpuPerformanceModelOptions { // If present, use this to retrieve fusion analyses. HloFusionAnalysisCache* fusion_analysis_cache = nullptr; + GpuPerformanceModelCache* gpu_performance_model_cache = nullptr; + static GpuPerformanceModelOptions Default() { return GpuPerformanceModelOptions(); } static GpuPerformanceModelOptions PriorityFusion( - HloFusionAnalysisCache* fusion_analysis_cache) { + HloFusionAnalysisCache* fusion_analysis_cache, + GpuPerformanceModelCache* gpu_performance_model_cache) { GpuPerformanceModelOptions config; config.consider_coalescing = true; config.first_read_from_dram = true; config.calculate_full_priority = true; config.fusion_analysis_cache = fusion_analysis_cache; + config.gpu_performance_model_cache = gpu_performance_model_cache; return config; } static GpuPerformanceModelOptions ForModule(const HloModule* module) { return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? 
PriorityFusion(nullptr) // Only cache within priority fusion. + ? PriorityFusion(nullptr, + nullptr) // Only cache within priority fusion. : Default(); } }; @@ -102,6 +129,10 @@ class GpuPerformanceModel { const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); + static EstimateRunTimeData EstimateRunTimeForInstructionCached( + const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config); + // TODO(shyshkov): Unify interface with EstimateRunTimeForInstruction. static absl::Duration EstimateRunTimeForFusion( const HloInstruction* producer, const HloInstruction* consumer, diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index 0fed03cbba652e..0ed4f664ef1f64 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -360,8 +360,8 @@ ENTRY fusion { std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(nullptr), - consumers); + producer, &analysis_, + GpuPerformanceModelOptions::PriorityFusion(nullptr, nullptr), consumers); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_fused), 514, 10); @@ -499,8 +499,8 @@ ENTRY e2 { const HloInstruction* producer = consumer->operand(0); GpuPerformanceModel::RunTimes t1 = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(nullptr), - {consumer}); + producer, &analysis_, + GpuPerformanceModelOptions::PriorityFusion(nullptr, nullptr), {consumer}); HloComputation* computation_with_fusion = module->GetComputationWithName("e2"); @@ -510,7 +510,7 @@ ENTRY e2 { GpuPerformanceModel::RunTimes t2 = GpuPerformanceModel::EstimateRunTimes( root_with_fusion, &analysis_, - GpuPerformanceModelOptions::PriorityFusion(nullptr), {}); + GpuPerformanceModelOptions::PriorityFusion(nullptr, nullptr), {}); EXPECT_EQ(t1.time_fused, t2.time_unfused); } diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 79e6e20235c3f5..8b07d6c329df82 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -187,6 +187,8 @@ class GpuPriorityFusionQueue : public FusionQueue { HloInstructionAdaptor fusion_adaptor(*fusion); can_fuse_cache_.erase(fusion_adaptor); + gpu_performance_model_cache_.Invalidate(*fusion); + fusion_analysis_cache_.Invalidate(*fusion); fusion_analysis_cache_.Invalidate(*original_producer); @@ -300,7 +302,8 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, &cost_analysis_, - GpuPerformanceModelOptions::PriorityFusion(&fusion_analysis_cache_), + GpuPerformanceModelOptions::PriorityFusion( + &fusion_analysis_cache_, &gpu_performance_model_cache_), producer->users()); if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); @@ -413,6 +416,8 @@ class GpuPriorityFusionQueue : public FusionQueue { absl::flat_hash_map> can_fuse_cache_; absl::Mutex can_fuse_cache_mutex_; + + GpuPerformanceModelCache gpu_performance_model_cache_; }; } // namespace 
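The caching introduced in the patch above boils down to a check-compute-store wrapper around the per-instruction runtime estimate, guarded by a mutex so concurrent callers can share it. The sketch below illustrates only that structure; RuntimeData, RuntimeDataCache, and EstimateCached are simplified stand-ins for the patch's EstimateRunTimeData, GpuPerformanceModelCache (keyed by HloInstructionAdaptor), and EstimateRunTimeForInstructionCached, not the actual XLA types.

#include <cstdint>
#include <optional>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"

// Stand-in for the per-instruction estimate (flops, bytes written, exec time, ...).
struct RuntimeData {
  int64_t flops = 0;
  int64_t bytes_written = 0;
};

// Thread-safe cache with the same Get/Set/Invalidate surface as the cache
// added above. `Key` stands in for the instruction identity.
template <typename Key>
class RuntimeDataCache {
 public:
  std::optional<RuntimeData> Get(const Key& key) {
    absl::MutexLock lock(&mutex_);
    auto it = data_.find(key);
    if (it != data_.end()) return it->second;
    return std::nullopt;
  }
  void Set(const Key& key, const RuntimeData& value) {
    absl::MutexLock lock(&mutex_);
    data_[key] = value;
  }
  void Invalidate(const Key& key) {
    absl::MutexLock lock(&mutex_);
    data_.erase(key);
  }

 private:
  absl::Mutex mutex_;
  absl::flat_hash_map<Key, RuntimeData> data_;
};

// Check-compute-store wrapper, analogous to EstimateRunTimeForInstructionCached:
// consult the cache when one is configured, otherwise run the expensive
// estimate and remember the result for the next caller.
template <typename Key, typename EstimateFn>
RuntimeData EstimateCached(const Key& key, EstimateFn&& estimate,
                           RuntimeDataCache<Key>* cache) {
  if (cache != nullptr) {
    if (std::optional<RuntimeData> cached = cache->Get(key)) return *cached;
  }
  RuntimeData result = estimate(key);
  if (cache != nullptr) cache->Set(key, result);
  return result;
}

Because entries are keyed by the instruction itself, a cached value is only meaningful while that instruction exists unchanged, which is why the priority-fusion change above calls Invalidate on the newly created fusion whenever the graph is rewritten.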
From 7a06fc70c4d581ad50a9853b3476b54e68185b4e Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 28 Nov 2023 03:51:52 -0800 Subject: [PATCH 137/381] Retire deallocation dialect. Most of this is currently unused and it was upstreamed to MLIR. The buffer reuse pass appears to still be used. PiperOrigin-RevId: 585918957 --- third_party/xla/xla/mlir_hlo/BUILD | 76 -- .../xla/mlir_hlo/deallocation/CMakeLists.txt | 1 - .../mlir_hlo/deallocation/IR/CMakeLists.txt | 43 - .../deallocation/IR/deallocation_ops.cc | 62 -- .../deallocation/IR/deallocation_ops.h | 37 - .../deallocation/IR/deallocation_ops.td | 110 --- .../xla/xla/mlir_hlo/deallocation/README.md | 446 --------- .../deallocation/transforms/CMakeLists.txt | 8 - .../deallocation/transforms/analysis.cc | 115 --- .../deallocation/transforms/analysis.h | 58 -- .../deallocation/transforms/buffer_reuse.cc | 7 - .../convert_deallocation_ops_to_llvm.cc | 202 ---- .../deallocation/transforms/deallocate.cc | 585 ----------- .../transforms/deallocation_simplification.cc | 237 ----- .../transforms/deallocation_to_scf.cc | 143 --- .../deallocation/transforms/debug_passes.cc | 88 -- .../mlir_hlo/deallocation/transforms/passes.h | 31 - .../deallocation/transforms/passes.td | 77 -- .../transforms/split_alloc_tensors.cc | 70 -- .../transforms/xla_buffer_arg_rewrite.cc | 91 -- .../tests/Dialect/deallocation/analysis.mlir | 65 -- .../Dialect/deallocation/buffer_reuse.mlir | 93 -- .../convert_deallocation_ops_to_llvm.mlir | 77 -- .../Dialect/deallocation/deallocate.mlir | 928 ------------------ .../deallocation/deallocate_invalid.mlir | 76 -- .../deallocation/deallocation_ops.mlir | 39 - .../deallocation_simplification.mlir | 181 ---- .../deallocation/deallocation_to_scf.mlir | 42 - .../deallocation/split_alloc_tensors.mlir | 33 - .../tools/mlir-hlo-opt/CMakeLists.txt | 1 - .../tools/mlir-hlo-opt/mlir-hlo-opt.cc | 4 +- .../transforms/generic_host_to_llvm.cc | 4 - .../xla/xla/mlir_hlo/transforms/passes.td | 1 - .../xla/xla/service/cpu/cpu_compiler.cc | 9 +- .../service/cpu/hlo_xla_runtime_pipeline.cc | 40 +- .../service/cpu/hlo_xla_runtime_pipeline.h | 1 - 36 files changed, 16 insertions(+), 4065 deletions(-) delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/IR/CMakeLists.txt delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.h delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.td delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/README.md delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.h delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/convert_deallocation_ops_to_llvm.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocate.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_simplification.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_to_scf.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/debug_passes.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/split_alloc_tensors.cc delete mode 100644 third_party/xla/xla/mlir_hlo/deallocation/transforms/xla_buffer_arg_rewrite.cc delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/analysis.mlir delete mode 100644 
third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/convert_deallocation_ops_to_llvm.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate_invalid.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_ops.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_simplification.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_to_scf.mlir delete mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/split_alloc_tensors.mlir diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index f9ca7a0a2b0c69..76c566c2bc78b6 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -375,73 +375,17 @@ gentbl_cc_library( deps = [":hlo_ops_td_files"], ) -td_library( - name = "deallocation_ops_td_files", - srcs = glob(["deallocation/IR/*.td"]), - compatible_with = get_compatible_with_portable(), - includes = ["."], - deps = [ - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "deallocation_ops_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "deallocation/IR/deallocation_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "deallocation/IR/deallocation_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "deallocation/IR/deallocation_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "deallocation/IR/deallocation_dialect.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "deallocation/IR/deallocation_typedefs.h.inc", - ), - ( - ["-gen-typedef-defs"], - "deallocation/IR/deallocation_typedefs.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "deallocation/IR/deallocation_ops.td", - deps = [":deallocation_ops_td_files"], -) - cc_library( name = "deallocation_passes", srcs = [ - "deallocation/transforms/analysis.cc", "deallocation/transforms/buffer_reuse.cc", - "deallocation/transforms/convert_deallocation_ops_to_llvm.cc", - "deallocation/transforms/deallocate.cc", - "deallocation/transforms/deallocation_simplification.cc", - "deallocation/transforms/deallocation_to_scf.cc", - "deallocation/transforms/debug_passes.cc", - "deallocation/transforms/split_alloc_tensors.cc", - "deallocation/transforms/xla_buffer_arg_rewrite.cc", ], hdrs = [ - "deallocation/transforms/analysis.h", "deallocation/transforms/passes.h", ], strip_include_prefix = ".", visibility = ["//visibility:public"], deps = [ - ":deallocation", ":deallocation_passes_inc_gen", ":deallocation_utils", "@llvm-project//llvm:Support", @@ -479,23 +423,6 @@ gentbl_cc_library( deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) -cc_library( - name = "deallocation", - srcs = ["deallocation/IR/deallocation_ops.cc"], - hdrs = ["deallocation/IR/deallocation_ops.h"], - strip_include_prefix = ".", - visibility = ["//visibility:public"], - deps = [ - ":deallocation_ops_inc_gen", - ":deallocation_utils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:SCFDialect", - ], -) - cc_library( name = "deallocation_utils", srcs = ["deallocation/utils/util.cc"], @@ -503,7 +430,6 @@ cc_library( strip_include_prefix = ".", visibility = ["//visibility:public"], deps = 
[ - ":deallocation_ops_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:IR", @@ -1294,7 +1220,6 @@ cc_library( strip_include_prefix = ".", visibility = ["//visibility:public"], deps = [ - ":deallocation", ":deallocation_passes", ":lhlo", ":mhlo_passes", @@ -1553,7 +1478,6 @@ cc_binary( srcs = ["tools/mlir-hlo-opt/mlir-hlo-opt.cc"], deps = [ ":all_passes", - ":deallocation", ":hlo_dialect_registration", ":lhlo", ":lhlo_gpu", diff --git a/third_party/xla/xla/mlir_hlo/deallocation/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/deallocation/CMakeLists.txt index 0868972988fc99..d758e74bb38e82 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/deallocation/CMakeLists.txt @@ -12,6 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_subdirectory(IR) add_subdirectory(transforms) add_subdirectory(utils) \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/deallocation/IR/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/deallocation/IR/CMakeLists.txt deleted file mode 100644 index 89ace26593b6d0..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/IR/CMakeLists.txt +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set(LLVM_TARGET_DEFINITIONS deallocation_ops.td) -mlir_tablegen(deallocation_ops.h.inc -gen-op-decls) -mlir_tablegen(deallocation_ops.cc.inc -gen-op-defs) -mlir_tablegen(deallocation_dialect.h.inc -gen-dialect-decls) -mlir_tablegen(deallocation_dialect.cc.inc -gen-dialect-defs) -mlir_tablegen(deallocation_typedefs.h.inc -gen-typedef-decls) -mlir_tablegen(deallocation_typedefs.cc.inc -gen-typedef-defs) - -add_public_tablegen_target(MLIRdeallocation_opsIncGen) -add_dependencies(mlir-headers MLIRdeallocation_opsIncGen) - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_dialect_library(DeallocationDialect - deallocation_ops.cc - - DEPENDS - MLIRdeallocation_opsIncGen - - LINK_LIBS PUBLIC - MLIRDeallocationUtils - MLIRControlFlowInterfaces - MLIRIR - MLIRMemRefDialect - MLIRSCFDialect - MLIRSupport -) diff --git a/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.cc b/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.cc deleted file mode 100644 index c899022593b755..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "deallocation/IR/deallocation_ops.h" - -#include "deallocation/IR/deallocation_dialect.cc.inc" -#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" - -#define GET_TYPEDEF_CLASSES -#include "deallocation/IR/deallocation_typedefs.cc.inc" -#undef GET_TYPEDEF_CLASSES - -namespace mlir { -namespace deallocation { - -void DeallocationDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "deallocation/IR/deallocation_ops.cc.inc" -#undef GET_OP_LIST - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "deallocation/IR/deallocation_typedefs.cc.inc" -#undef GET_TYPEDEF_LIST - >(); -} - -void OwnOp::build(OpBuilder& odsBuilder, OperationState& odsState, - Value memref) { - return build(odsBuilder, odsState, - OwnershipIndicatorType::get(odsBuilder.getContext()), memref); -} - -void NullOp::build(OpBuilder& odsBuilder, OperationState& odsState) { - return build(odsBuilder, odsState, - OwnershipIndicatorType::get(odsBuilder.getContext())); -} - -} // namespace deallocation -} // namespace mlir - -#define GET_OP_CLASSES -#include "deallocation/IR/deallocation_ops.cc.inc" diff --git a/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.h b/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.h deleted file mode 100644 index 74dec9c2e8ee43..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef MLIR_HLO_DEALLOACTION_DEALLOCATION_OPS_H -#define MLIR_HLO_DEALLOACTION_DEALLOCATION_OPS_H - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -#define GET_TYPEDEF_CLASSES -#include "deallocation/IR/deallocation_typedefs.h.inc" -#undef GET_TYPEDEF_CLASSES - -#define GET_OP_CLASSES -#include "deallocation/IR/deallocation_dialect.h.inc" -#include "deallocation/IR/deallocation_ops.h.inc" -#undef GET_OP_CLASSES - -#endif // MLIR_HLO_DEALLOACTION_DEALLOCATION_OPS_H diff --git a/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.td b/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.td deleted file mode 100644 index 97c001dfdc22ce..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/IR/deallocation_ops.td +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef DEALLOCATION_OPS_TD_ -#define DEALLOCATION_OPS_TD_ - -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/AttrTypeBase.td" - -def DeallocationDialect : Dialect { - let name = "deallocation"; - - let summary = "Operations for the deallocation dialect"; - let description = [{ - Ops for modelling owned/unowned memrefs using null pointers. 
- }]; - let cppNamespace = "::mlir::deallocation"; - - let useDefaultTypePrinterParser = 1; - let usePropertiesForAttributes = 0; -} - -def OwnershipIndicatorType : TypeDef { - let mnemonic = "ownership"; - let summary = "an ownership indicator"; -} - -class DeallocationOp traits = []> - : Op; - -def GetBufferOp : DeallocationOp<"get_buffer", [Pure]> { - let summary = "extracts the base pointer as an index"; - - let arguments = (ins AnyTypeOf<[AnyMemRef, OwnershipIndicatorType]>:$alloc); - let results = (outs Index:$result); - - let assemblyFormat = "attr-dict $alloc `:` type($alloc)"; -} - -def OwnOp : DeallocationOp<"own", [Pure]> { - let summary = "declare ownership"; - - let arguments = (ins AnyRankedOrUnrankedMemRef:$memref); - let results = (outs OwnershipIndicatorType:$result); - - let builders = [ - OpBuilder<(ins "Value":$memref)> - ]; - - let assemblyFormat = "attr-dict $memref `:` type($memref)"; -} - -def NullOp : DeallocationOp<"null", [Pure]> { - let summary = "null pointer"; - - let results = (outs OwnershipIndicatorType:$result); - - let builders = [ - OpBuilder<(ins)> - ]; - - let assemblyFormat = "attr-dict"; -} - -def FreeOp : DeallocationOp<"free"> { - let summary = "free"; - - let arguments = (ins OwnershipIndicatorType:$alloc); - - let assemblyFormat = "attr-dict $alloc"; -} - -// TODO(jreiffers): Implement InferTypeOpInterface. -def RetainOp : DeallocationOp<"retain", [AttrSizedOperandSegments]> { - let summary = "null-safe dealloc"; - - let description = [{ - For each memref in `retained`, finds the alloc in `allocs` that it is - derived from and returns it. If not found, returns `null`. - - Any allocs that are not in the result are deallocated. - - `allocs` may contain `null`s. Otherwise, all allocs must be distinct. - `retained` values may alias. - }]; - - let arguments = (ins Variadic:$retained, - Variadic:$allocs); - let results = (outs Variadic:$result_allocs); - - let assemblyFormat = [{ - `(` $retained `)` `of` `(` $allocs `)` attr-dict `:` - functional-type(operands, results) - }]; -} - -#endif // DEALLOCATION_TD_ diff --git a/third_party/xla/xla/mlir_hlo/deallocation/README.md b/third_party/xla/xla/mlir_hlo/deallocation/README.md deleted file mode 100644 index 9cbbed9769ee19..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/README.md +++ /dev/null @@ -1,446 +0,0 @@ -# MLIR-HLO deallocation and buffer reuse passes - -MLIR-HLO deallocation is an alternative to the upstream buffer-deallocation and -buffer-hoisting passes. - -The core concept is that of *ownership*, i.e. for each allocation, we track an -*ownership indicator* that can be moved around. These indicators can be -understood as a `std::unique_ptr` or alternatively a ref-counted pointer with a -maximum count of 1. At the end of a block, an ownership indicator must either -be yielded or the underlying alloc must be freed. In practice, it is not always -known whether a particular alloc is owned by the current block. Therefore, we -must also be able to represent empty ownership indicators (i.e., null pointers). - -## Usage - -This is the recommended and supported pass pipeline to use these passes: - -1. `hlo-split-alloc-tensors` -1. `one-shot-bufferize` with `create-deallocs=0` -1. `hlo-deallocate` -1. `hlo-deallocation-simplification` -1. `hlo-buffer-reuse` -1. `hlo-deallocation-simplification` -1. `hlo-deallocation-to-scf` -1. (...) -1. 
`convert-deallocation-ops-to-llvm` - -It is possible to use just the deallocation pass or just buffer-reuse, but the -former isn't recommended because the output will be inefficient. The latter will -work as long as the invariants assumed by this code are maintained (in -particular, there should be no unranked memrefs in the input IR, since as -described above, the code here assigns special meaning to those). - -## "ABI" - -As long as the IR contains only a single function, there shouldn't be any sharp -edges here. If there are multiple functions, it is important to pay attention to -the ABI assumed here: - -1. Function arguments are always owned by the caller. -1. Function results are always owned by the caller **and do not alias with any - function arguments**. In other words, function results are always freshly - allocated buffers. Function arguments may alias each other. - -Warning: The second condition here is particularly important - if a function -returns one of its arguments, the deallocation pass will silently introduce a -double free. - -This restriction could be lifted by introducing ownership indicators for -function arguments, but as of March 2023, this is not done. - -## The deallocation pass - -The deallocation pass assumes that: - -1. The input IR was fully bufferized (i.e., no tensors are left in the - program). -1. No `dealloc`s, `alloca`s or `realloc`s exist yet. -1. No `memrefs` with distinct element types alias (strict aliasing; in - particular, no `xla_cpu.memref_element_cast` ops should exist at this point) - -The basic deallocation algorithm works mostly locally within blocks. It -transforms the input IR op by op, keeping track of memref alias information as -it goes. For each op, it produces the following information: 1) which allocs -were released by the parent block (i.e., are no longer owned by it; more on that -in the section on transferring ownership), 2) which new allocs are now owned by -the parent block. For example, when processing an `alloc` op, nothing is -released, and the result of the op is now owned by the block. It also keeps -track of aliasing information. Conservatively, it is assumed that all inputs -alias all compatible outputs. - -When transforming a block, it is not possible to know in general whether -`memref` arguments are owned by it or by some ancestor. Therefore, we introduce -ownership indicator arguments (`!deallocation.ownership`) for each `memref` -argument. Inside the block, `allocs` and alias sets are tracked as described -above. At the end of the block, we must reconcile these memrefs and potentially -owned allocs. We can do this separately for those that are yielded from the -block and those that aren't. - -For `memrefs` (or rather sets of `memrefs` that potentially alias) that aren't -yielded, we must free the corresponding `alloc` if we own it. In general, we -can't know statically whether that's the case, so we use the `retain` op, which -frees non-null allocs [^1] that are no longer needed. To find the place to -insert the op, we simply traverse the block backwards, starting from the -terminator, and look for the last op that contains any reference to a memref -from the alias set. - -``` - // Free %alloc_0 and %alloc_1 iff they are non-null. - deallocation.retain() of(%alloc_0, %alloc_1) - : (!deallocation.ownership, !deallocation.ownership) -> () -``` - -For `memrefs` that are yielded, we also insert retain ops, but this time, we -must retain allocs if we own them. 
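As a purely illustrative sketch of the non-yielded case described above (the block structure, op names, and shapes here are assumed for illustration, not taken from the pass's actual output), a block-local temporary ends up looking like this after the pass:

```
  // Before: %tmp is only used inside the block and is not yielded.
  %tmp = memref.alloc() : memref<2xi32>
  "some.op"(%tmp) : (memref<2xi32>) -> ()

  // After: the alloc gets an ownership indicator, and a result-less retain
  // frees it after its last use.
  %tmp = memref.alloc() : memref<2xi32>
  %tmp_owned = deallocation.own %tmp : memref<2xi32>
  "some.op"(%tmp) : (memref<2xi32>) -> ()
  deallocation.retain() of(%tmp_owned) : (!deallocation.ownership) -> ()
```

For yielded `memrefs`, the emitted `retain` additionally returns ownership indicators, as shown next.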
The `retain` ops look like this: - -``` - // Check if %yielded_memref aliases with any of %a, %b or %c. If it does, - // return the corresponding memref. Free the others if they are non-null. - %maybe_owned = deallocation.retain(%yielded_memref) of(%a, %b, %c) - : (!deallocation.ownership, !deallocation.ownership, !deallocation.ownership) - -> (!deallocation.ownership) -``` - -To understand where such ops come from, consider the following code: - -``` - %result = scf.if %cond -> memref<2xi32> { - scf.yield %some_alloc : memref<2xi32> - } else { - %new_alloc = memref.alloc() : memref<2xi32> - scf.yield %new_alloc : memref<2xi32> - } -``` - -Whether the parent block owns the alloc that backs `%result` depends on which -branch was taken. Therefore, after transforming the block, the `if` will look -like this: - -``` - %result, %result_ownership = scf.if %cond -> memref<2xi32> { - %null = deallocation.null - scf.yield %some_alloc, %null : memref<2xi32>, !deallocation.ownership - } else { - %new_alloc = memref.alloc() : memref<2xi32> - %new_alloc_owned = deallocation.own %new_alloc : memref<2x32> - scf.yield %new_alloc, %new_alloc_owned : memref<2xi32>, !deallocation.ownership - } -``` - -`%result_ownership` is nonnull iff `%result` is owned by the parent block. If -`%result` is yielded, the corresponding retain op would be: - -``` - %yielded_result_ownership = deallocation.retain(%result) of(%result_ownership) -``` - -However, here we can statically determine that this always results in -`%result_ownership`, so the `retain` op will not be emitted. - -### Loops and if: `RegionBranchOpInterface` - -RegionBranchOpInterface ops mostly follow what was described above for blocks, -but there are two interesting things about them: - -1. Regions with multiple predecessors -1. Transferring ownership to the op - -*Multiple predecessors*. In `scf.while`, and `scf.if`, some regions have -multiple predecessors (in the case of `while`, the `before` region, in the case -of `if`, the parent region). As it turns out, no special logic is required to -handle this - the regions will always yield the same types of memrefs, and -therefore the added ownership indicators will also have the same types. - -*Transfer of ownership*. If a `memref` operand of a loop has no further uses -after the loop, we can transfer the ownership indicator for the operand to the -loop. Note that this does not necessarily mean ownership is actually -transferred - the ownership indicator may be null. - -#### Implicit capture / implicit transfer of ownership - -Consider the following program, which conditionally reallocates a memref: - -``` -%alloc = memref.alloc(%size) : memref -scf.for %i = %lb to %ub step %step iter_args(%arg0 = %alloc) { - %should_grow, %new_size = "dummy.check_capacity"(%arg0) - : (memref) -> (i1, index) - %mem = scf.if %should_grow { - %0 = memref.realloc %arg0(%new_size) : memref -> memref - scf.yield %0 : memref - } else { - scf.yield %arg0 : memref - } - "dummy.use"(%mem) : (memref) -> () - scf.yield %mem : memref -} -``` - -`%arg0` is owned by the loop, but it must not be deallocated at the end of the -loop body - otherwise, we'd run into a double free when it is reallocated. - -We solve this by defining implicit captures, or implicit transfer of ownership. -`memref.realloc` ops are considered to implicitly capture and release their -operand. There are a couple of restrictions to this: - -1. Only ops owned by the parent block can be implicitly captured. -1. Implicit capture is only allowed in `scf.if` ops. 
This rule may be applied - recursively. -1. The implicit capture must be the last use of the captured value across all - execution paths. -1. Implied by the previous rule: Implicit capture is not allowed in `scf.if` - ops that do not have an else branch. - -To illustrate these restrictions, we can look at some IR that violates them: - -``` -%alloc = memref.alloc() -scf.if %cond { - %0 = memref.realloc %alloc // invalid -} -``` - -This IR contains an implicit capture inside an `scf.if` without an `else` -branch. Since `%alloc` is only freed if `%cond` is true, there must be some -further use of `%alloc`, which is invalid. To make this valid, the following IR -should be emitted instead: - -``` -%alloc = memref.alloc() -%0 = scf.if %cond { - %1 = memref.realloc %alloc - scf.yield %1 -} else { - scf.yield %alloc -} -``` - -Note that `scf.yield %alloc` is executed no execution path that also executes -the `realloc`, so condition 3 is not violated. - -An example that violates condition 1: - -``` -%alloc = memref.alloc() -scf.for %i = %lb to %ub step %step { - scf.if ... { - %0 = memref.realloc %alloc // invalid - } else { - ... - } -} -``` - -`%alloc` cannot be implicitly captured here, since there is no chain of ancestor -`scf.if` ops to its definition. To make this valid, turn `%alloc` into an -`iter_arg`: - -``` -%alloc = memref.alloc() -%0 = scf.for %i = %lb to %ub step %step iter_args(%arg0 = %alloc) { - %1 = scf.if ... { - %2 = memref.realloc %alloc - } else { - ... - } - scf.yield %1 -} -``` - -## Ops in the deallocation dialect - -### The `null` op - -Creates a null pointer. - -### The `own` op - -Declares ownership of an alloc and returns an ownership indicator. This is -lowered to an extraction of the alloc's base pointer. - -### The `retain` op - -Takes a list of memrefs and a list of ownership indicator. For each memref, -returns the ownership (alloc) that it was derived from (if present). Each alloc -is returned at most once. Alloc that are not returned are freed. - -Some retain ops can be simplified to a no op (e.g. if there's only one alloc -and one memref, and they're the same). Others can be rewritten to memref.dealloc -(if we know that the alloc is non-null and there is no memref). This is done by -the `deallocation-simplification` pass. - -There are two lowerings of `retain`: retains with a single memref or a single -ownership indicator are lowered to a sequence of `scf.if` ops. Lowerings with -more than one of either are instead lowered to a library call. For details, see -the section on the deallocation-to-scf pass. - -### The `get_buffer` op - -Returns the memref's base pointer as an index. - -## The buffer reuse pass - -The buffer reuse pass is intended to be run after the deallocation pass and -assumes that the code has the structure that the pass guarantees (in particular, -unranked memref == ownership indicator). For best results, the IR should be -canonicalized first. - -### Loop simplification - -As a preprocessing step, this pass transforms `retain` ops that operate on the -result of loops. Consider the following IR: - -``` -%alloc1 = memref.alloc() : memref<4xi32> -%alloc2 = memref.alloc() : memref<4xi32> -%0:4 = scf.while(%arg0 = %alloc1, $arg1 = %alloc2) { - scf.condition(%cond) %arg1, %arg0 -do { - (...) - scf.yield %arg0, %arg1 -} -memref.dealloc %0#0 : memref<4xi32> -memref.dealloc %0#1 : memref<4xi32> -``` - -`%0#0` and `%0#1` are `%alloc1` and `%alloc2`, in some order. 
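A sketch of the rewritten deallocations this preprocessing step produces for the example above (illustrative only; as explained next, either assignment of the two allocs to the two deallocs is equally valid):

```
memref.dealloc %alloc1 : memref<4xi32>
memref.dealloc %alloc2 : memref<4xi32>
```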
Since there is no -further use of these allocs and they are all deallocated, we can rewrite the -operands to `%alloc1` and `%alloc2`, even though we don't know which one is -which. - -The purpose of this preprocessing step is to allow more buffer reuse, which -requires `dealloc`/`alloc` pairs to work. - -### Buffer reuse - -Buffer reuse coalesces `dealloc`/`alloc` pairs: - -``` -memref.dealloc %alloc : memref<100xi32> -(...) -%alloc_1 = memref.alloc() : memref<100xi32> -``` - -Instead of deallocating and allocating, we replace all uses of `%alloc_1` with -`%alloc`. Currently, we only do this for immediate `dealloc`/`alloc` pairs with -no other `alloc`/`dealloc` ops in between. So in the example above, if `(...)` -included any other allocation or deallocation, no reuse would occur. - -### Copy elision - -Another simple transformation eliminates `alloc`/`copy`/`dealloc` patterns: - -``` -%a = memref.alloc() : memref<100xi32> -(... 1) // no uses of %a -memref.copy %b, %a : memref<100xi32> to memref<100xi32> -memref.dealloc %b : memref<100xi32> -(... 2) // uses of %a -``` - -Since `%a` is completely overwritten with `%b`, which is deallocated immediately -afterwards, we can remove the allocation of `%a` and replace its uses with `%b`. - -``` -(... 1) // potential uses of %b -(... 2) // all uses of %a replaced with %b -``` - -Note: This pattern could be generalized to only look at copy ops and the uses of -its operand, leaving the elimination of the allocation and deallocation to other -patterns. As of March 2023, this is not done. - -### Hoisting - -The second transformation implemented in this pass is buffer hoisting. This -simply looks for allocs that happen in each iteration of a loop and moves them -out of the loop: - -``` -scf.for %i = %c0 to %c1000 step %c1 { - %foo = memref.alloc() : memref<100xi32> - (...) - memref.dealloc %foo : memref<100xi32> -} -``` - -Since the contents of a freshly allocated memref are undefined, this can be -transformed as follows: - -``` -%foo = memref.alloc() : memref<100xi32> -scf.for %i = %c0 to %c1000 step %c1 { - (...) -} -memref.dealloc %foo : memref<100xi32> -``` - -The same transformation applies for while loops, with the caveat that it may -increase peak heap usage in that case. - -### Double buffering - -Double buffering can be considered a variant of hoisting. It is useful in cases -where use ranges of buffers overlap, preventing simple hoisting. Consider the -following IR (ownership indicator omitted for clarity): - -``` -%0 = scf.for %i = %c0 to %c1000 step %c1 iter_args(%arg = %alloc) - -> memref<100xi32> { - %tmp = memref.alloc() : memref<100xi32> - "some.op"(%tmp, %arg) : (memref<100xi32>, memref<100xi32>) -> () - memref.dealloc %arg : memref<100xi32> - scf.yield %tmp : memref<100xi32> -} -memref.dealloc %0 : memref<100xi32> -``` - -The live ranges of `%alloc` and `%tmp` overlap, so we can't do straightforward -hoisting here. 
However, we only need two distinct buffers at any given time, so -instead, we introduce an additional iter arg for the temporary buffer, hoist and -swap in each iteration: - -``` -%tmp = memref.alloc() : memref<100xi32> -%0, %1 = scf.for %i = %c0 to %c1000 step %c1 - iter_args(%arg = %alloc, %tmp_ = %tmp) -> memref<100xi32> { - "some.op"(%tmp_, %arg) : (memref<100xi32>, memref<100xi32>) -> () - scf.yield %tmp_, %arg : memref<100xi32>, memref<100xi32> -} -memref.dealloc %1 : memref<100xi32> -memref.dealloc %0 : memref<100xi32> -``` - -Note that the presence of a deallocation of `%arg` inside the loop implies no -further uses of `%alloc` after the loop. So, similarly to the case described in -the section on loop simplification, it doesn't matter which alloc is in `%0` and -which one is in `%1`. - -Double buffering works analogously for `while` loops, with the exception that -buffers have to be plumbed through the before region. - -Note: as of March 2023, double buffering allocations in `while` loops is only -implemented for the `after` region. - -## The split-alloc-tensors pass - -This pass is a helper pass to improve the behavior of the other passes when used -together with `one-shot-bufferize`. The purpose of this pass is to prevent -accidental buffer reuse by `one-shot-bufferize` by ensuring each `alloc_tensor` -is used only once, thereby minimizing the sizes of live ranges and enabling the -buffer reuse pass to work optimally. - -## The deallocation-to-scf pass - -As described previously, most `deallocation.retain` ops are eliminated either by -canonicalization or by `buffer-reuse`. `deallocation-to-scf` lowers the ones -that remain to sequences of `scf.if` ops. - -Because the size of the emitted code is in `O(|allocs| * |memrefs|)`, we only -use this lowering when at least one of `|allocs|` or `|memrefs|` is 1. - -[^1]: `memref.dealloc` happens to tolerate null inputs as well, but at this - point of the pipeline, we assume that the argument is always non-null, - because 1) this behavior isn't documented 2) it simplifies analysis in - subsequent passes. diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt index a61329f0677657..efe5e92b44ac89 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt @@ -21,15 +21,7 @@ include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}) add_mlir_library(DeallocationPasses - analysis.cc buffer_reuse.cc - convert_deallocation_ops_to_llvm.cc - deallocate.cc - deallocation_simplification.cc - deallocation_to_scf.cc - debug_passes.cc - split_alloc_tensors.cc - xla_buffer_arg_rewrite.cc DEPENDS MLIRDeallocationPassesIncGen diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.cc deleted file mode 100644 index 85019895762d9b..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "deallocation/transforms/analysis.h" - -#include - -#include "deallocation/utils/util.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -namespace mlir { -namespace deallocation { - -namespace { - -bool isRestrictBbArg(Value value) { - auto bbarg = llvm::dyn_cast(value); - auto func = - llvm::dyn_cast(value.getParentBlock()->getParentOp()); - if (!bbarg || !func) return false; - auto isRestrict = func.getArgAttrOfType(bbarg.getArgNumber(), - "deallocation.restrict"); - return isRestrict && isRestrict.getValue(); -} - -bool isMemref(Value v) { return llvm::isa(v.getType()); } - -} // namespace - -void DeallocationAnalysis::collectBackingMemory( - Value source, DenseSet& visited, - breaks_if_you_move_ops::ValueSet& results) { - if (!isMemref(source)) return; - if (!visited.insert(source).second) return; - - auto type = getElementTypeOrSelf(source); - if (auto bbarg = llvm::dyn_cast(source)) { - results.insert(source); - if (llvm::isa(bbarg.getParentBlock()->getParentOp())) { - if (!isRestrictBbArg(source)) { - // Restrict bbargs can't alias anything else. 
- for (auto arg : bbarg.getParentBlock()->getArguments()) { - if (isMemref(arg) && getElementTypeOrSelf(arg.getType()) == type) { - results.insert(arg); - } - } - } - } else if (auto rbi = llvm::dyn_cast( - bbarg.getParentRegion()->getParentOp())) { - for (const auto& edge : - getPredecessorRegions(rbi, bbarg.getParentRegion())) { - if (bbarg.getArgNumber() >= edge.successorValueIndex && - static_cast(bbarg.getArgNumber() - - edge.successorValueIndex) <= - edge.getPredecessorOperands().size()) { - Value dep = edge.getPredecessorOperand(bbarg.getArgNumber()); - collectBackingMemory(dep, visited, results); - } - } - } - return; - } - - auto result = llvm::cast(source); - if (auto rbi = llvm::dyn_cast(result.getOwner())) { - for (const auto& edge : - getPredecessorRegions(rbi, RegionBranchPoint::parent())) { - collectBackingMemory(edge.getPredecessorOperand(result.getResultNumber()), - visited, results); - } - } - - if (auto mem = llvm::dyn_cast(result.getOwner())) { - if (mem.getEffectOnValue(result).has_value()) { - results.insert(result); - } - } - - for (auto operand : result.getOwner()->getOperands()) { - if (isMemref(operand) && getElementTypeOrSelf(operand) == type) { - collectBackingMemory(operand, visited, results); - } - } -} - -const breaks_if_you_move_ops::ValueSet& DeallocationAnalysis::getBackingMemory( - Value source) { - auto it = backingMemory.find(source); - if (it != backingMemory.end()) return it->second; - - auto& results = backingMemory[source]; - DenseSet visited; - collectBackingMemory(source, visited, results); - return results; -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.h b/third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.h deleted file mode 100644 index aefb8ff03ce03d..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/analysis.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef MLIR_HLO_DEALLOCATION_TRANSFORMS_ANALYSIS_H -#define MLIR_HLO_DEALLOCATION_TRANSFORMS_ANALYSIS_H - -#include "deallocation/utils/util.h" -#include "mlir/IR/Value.h" - -namespace mlir { -namespace deallocation { - -class DeallocationAnalysis { - public: - // Returns the set of all possible values that may back the given value. A - // value `A` is considered to back another value `B` if - // a) `A` is an alloc or a bbarg - // b) `B` depends on `A` (possibly indirectly) - // - // For example, in this IR: - // - // func.func @foo(%arg0: memref) -> memref { - // %c0 = arith.constant 0 : index - // %c4 = arith.constant 4 : index - // %c1 = arith.constant 1 : index - // %ret = scf.for %i = %c0 to %c4 step %c1 iter_args(%x = %arg0) - // -> memref { - // %y = some.op(%x) : memref -> memref - // scf.yield %y : memref - // } - // func.return %ret : memref - // } - // - // `getBackingMemory(%ret)` is {`%arg0`, `%x`, `%y`}. 
- const breaks_if_you_move_ops::ValueSet& getBackingMemory(Value source); - - private: - void collectBackingMemory(Value source, DenseSet& visited, - breaks_if_you_move_ops::ValueSet& results); - - DenseMap backingMemory; -}; - -} // namespace deallocation -} // namespace mlir - -#endif // MLIR_HLO_DEALLOCATION_TRANSFORMS_ANALYSIS_H diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 3c9d735e6bd5e8..7f95cdbbabc514 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "deallocation/IR/deallocation_ops.h" #include "deallocation/transforms/passes.h" #include "deallocation/utils/util.h" #include "llvm/ADT/STLExtras.h" @@ -375,12 +374,6 @@ void promoteToStack(memref::DeallocOp dealloc) { auto alloca = b.create( alloc->getLoc(), alloc->getResultTypes()[0].cast(), alloc.getAlignmentAttr()); - for (auto* user : alloc->getUsers()) { - if (auto ownership = llvm::dyn_cast(user)) { - b.setInsertionPoint(ownership); - ownership->replaceAllUsesWith(b.create(ownership.getLoc())); - } - } alloc->replaceAllUsesWith(ValueRange{alloca.getResult()}); alloc->erase(); dealloc->erase(); diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/convert_deallocation_ops_to_llvm.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/convert_deallocation_ops_to_llvm.cc deleted file mode 100644 index 5cdd1e9987f563..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/convert_deallocation_ops_to_llvm.cc +++ /dev/null @@ -1,202 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "deallocation/IR/deallocation_ops.h" -#include "deallocation/transforms/passes.h" -#include "mlir/Analysis/DataLayoutAnalysis.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace deallocation { -namespace { - -struct NullOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - NullOp nullOp, OpAdaptor, - ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOpWithNewOp( - nullOp, LLVM::LLVMPointerType::get(rewriter.getContext(), 0)); - return success(); - } -}; - -struct OwnOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - OwnOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOp(op, MemRefDescriptor(adaptor.getMemref()) - .allocatedPtr(rewriter, op->getLoc())); - return success(); - } -}; - -struct GetBufferOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - GetBufferOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - if (op.getAlloc().getType().isa()) { - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->getIndexType(), adaptor.getAlloc()); - } else { - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->getIndexType(), - MemRefDescriptor(adaptor.getAlloc()) - .allocatedPtr(rewriter, op->getLoc())); - } - return success(); - } -}; - -struct FreeOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - FreeOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - auto freeFn = LLVM::lookupOrCreateFreeFn(op->getParentOfType()); - - rewriter.replaceOpWithNewOp(op, freeFn, adaptor.getAlloc()); - return success(); - } -}; - -struct RetainOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult matchAndRewrite( - RetainOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - auto loc = op.getLoc(); - auto ptrTy = LLVM::LLVMPointerType::get(op.getContext()); - rewriter.setInsertionPoint(op); - auto alloca = rewriter.create( - loc, SmallVector(op->getNumResults(), ptrTy)); - auto& body = alloca.getBodyRegion().emplaceBlock(); - rewriter.setInsertionPoint(&body, body.begin()); - - auto i64Ty = rewriter.getI64Type(); - auto ptrPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); - Type indexType = ConvertOpToLLVMPattern::getIndexType(); - auto getBuffers = [&](ValueRange values) { - auto ret = rewriter.create( - loc, ptrPtrTy, ptrTy, - createIndexAttrConstant(rewriter, loc, indexType, - values.size() * - getTypeConverter()->getPointerBitwidth() / - CHAR_BIT)); - for (auto [index, value] : llvm::enumerate(values)) { - auto ptr = rewriter.create( - loc, 
ptrPtrTy, ptrTy, ret, - createIndexAttrConstant(rewriter, loc, indexType, index)); - rewriter.create(loc, value, ptr); - } - return ret; - }; - - Value numAllocs = createIndexAttrConstant(rewriter, loc, indexType, - op.getAllocs().size()); - Value allocBuffers = getBuffers(adaptor.getAllocs()); - Value numRetained = createIndexAttrConstant(rewriter, loc, indexType, - op.getRetained().size()); - Value retainedBuffers = getBuffers(adaptor.getRetained()); - - auto retainFn = - LLVM::lookupOrCreateFn(op->getParentOfType(), "retainBuffers", - {i64Ty, ptrPtrTy, i64Ty, ptrPtrTy}, - LLVM::LLVMVoidType::get(op->getContext())); - rewriter.create( - loc, retainFn, - ValueRange{numAllocs, allocBuffers, numRetained, retainedBuffers}); - - SmallVector results; - for (auto index : llvm::seq(0, op.getRetained().size())) { - auto ptr = rewriter.create( - loc, ptrPtrTy, ptrTy, retainedBuffers, - createIndexAttrConstant(rewriter, loc, indexType, index)); - results.push_back(rewriter.create(loc, ptrTy, ptr)); - } - rewriter.create(loc, results); - - rewriter.replaceOp(op, alloca->getResults()); - return success(); - } -}; - -#define GEN_PASS_DEF_CONVERTDEALLOCATIONOPSTOLLVMPASS -#include "deallocation/transforms/passes.h.inc" - -struct ConvertDeallocationOpsToLLVMPass - : public impl::ConvertDeallocationOpsToLLVMPassBase< - ConvertDeallocationOpsToLLVMPass> { - ConvertDeallocationOpsToLLVMPass() = default; - - void runOnOperation() override { - Operation* func = getOperation(); - const auto& dataLayoutAnalysis = getAnalysis(); - LowerToLLVMOptions options(&getContext(), - dataLayoutAnalysis.getAtOrAbove(func)); - - LLVMTypeConverter typeConverter(&getContext(), options, - &dataLayoutAnalysis); - RewritePatternSet patterns(&getContext()); - populateDeallocationToLLVMConversionPatterns(typeConverter, patterns); - - LLVMConversionTarget target(getContext()); - target.addLegalOp(); - target.addLegalOp(); - target.addIllegalOp(); - if (failed(applyPartialConversion(func, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -void populateDeallocationToLLVMConversionPatterns(LLVMTypeConverter& converter, - RewritePatternSet& patterns) { - converter.addConversion([&](OwnershipIndicatorType) { - return LLVM::LLVMPointerType::get(&converter.getContext()); - }); - patterns.add(converter); -} - -std::unique_ptr> -createConvertDeallocationOpsToLLVM() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocate.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocate.cc deleted file mode 100644 index 90a1ed051ffff1..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocate.cc +++ /dev/null @@ -1,585 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include -#include - -#include "deallocation/IR/deallocation_ops.h" -#include "deallocation/transforms/passes.h" -#include "deallocation/utils/util.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace deallocation { -namespace { - -bool isMemref(Value v) { return v.getType().isa(); } - -struct TransformResult { - // Allocs that are no longer owned by the current block. Note that it is valid - // for an alloc to be both in `acquired` and `released`, if it was temporarily - // released and then reacquired. It is valid to release an alloc that's not - // owned by the current block, if some ancestor that is reachable without - // crossing a loop boundary owns it. - // Collects values that are the actual memrefs. - breaks_if_you_move_ops::ValueSet released; - - // Allocs that are now owned by the current block. Order matters here - it's - // the same order as in the terminator/result list. - // Collects values that are the ownership indicators. - SmallVector acquired; -}; - -bool doesAlias(Operation* op, Value v, - breaks_if_you_move_ops::ValueEquivalenceClasses& aliases, - bool considerOperands = true) { - auto eq = [&](Value other) { return aliases.isEquivalent(v, other); }; - return op && ((considerOperands && llvm::any_of(op->getOperands(), eq)) || - llvm::any_of(op->getResults(), eq) || - llvm::any_of(op->getRegions(), [&](Region& region) { - return llvm::any_of(region.getOps(), [&](Operation& subOp) { - return doesAlias(&subOp, v, aliases); - }); - })); -} - -struct Deallocator { - void setOwnershipIndicator(Value owned, Value indicator); - Value findOwnershipIndicator(Value v); - - // Transform ops, introducing deallocs. - LogicalResult transformModuleOp(ModuleOp op); - LogicalResult transformFuncOp(func::FuncOp op); - FailureOr transformBlock(Block& block, - bool ownsInputs = true); - FailureOr transformIfImplicitCapture( - scf::IfOp op, TransformResult& ifResult, TransformResult& elseResult); - FailureOr transformOp( - RegionBranchOpInterface op, - const breaks_if_you_move_ops::ValueSet& ownedMemrefs); - FailureOr transformOp(func::CallOp op); - FailureOr transformOp( - Operation* op, const breaks_if_you_move_ops::ValueSet& ownedMemrefs); - - // Internal state keeping track of - // - inter-function aliasing, - // - intra-function aliasing, and - // - ownership indicators per memref. 
- std::map>> - functionAliasOverapprox; - breaks_if_you_move_ops::ValueEquivalenceClasses aliasOverapprox; - breaks_if_you_move_ops::ValueMap ownershipIndicator; -}; - -void Deallocator::setOwnershipIndicator(Value owned, Value indicator) { - ownershipIndicator[owned] = indicator; - aliasOverapprox.unionSets(owned, indicator); -} - -Value Deallocator::findOwnershipIndicator(Value v) { - if (llvm::isa_and_nonnull( - v.getDefiningOp())) { - return findOwnershipIndicator(v.getDefiningOp()->getOperand(0)); - } - auto it = ownershipIndicator.find(v); - if (it != ownershipIndicator.end()) return it->second; - return {}; -} - -LogicalResult Deallocator::transformModuleOp(ModuleOp op) { - LogicalResult result = success(); - op.walk([&](func::FuncOp funcOp) { - if (failed(transformFuncOp(funcOp))) { - result = failure(); - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - - return result; -} - -// TODO(frgossen): Also allow passing ownership to functions. -LogicalResult Deallocator::transformFuncOp(func::FuncOp op) { - // If we find an aliasing record for this function, it is already being - // transformed. We might be hitting a cycle in the call graph here, in which - // case this is a temporary aliasing overapproximation and may be refined - // later. - if (functionAliasOverapprox.find(op) != functionAliasOverapprox.end()) - return success(); - - // Mark function as being processed and provide a valid overapproximation for - // aliasing: every result may alias every argument. - SmallVector> trivialOverapproximation; - int numOwnershipResults = 0; - auto allArgs = llvm::to_vector(llvm::seq(0, op.getNumArguments())); - for (Type resultTy : op.getFunctionType().getResults()) { - auto& resultAliasing = trivialOverapproximation.emplace_back(); - if (!llvm::isa(resultTy)) continue; - resultAliasing = allArgs; - numOwnershipResults++; - } - trivialOverapproximation.append(numOwnershipResults, allArgs); - functionAliasOverapprox[op] = trivialOverapproximation; - - if (op->getNumRegions() == 0) return success(); - - // Transform function body. - assert(op.getBody().getBlocks().size() == 1 && - "expect single block functions"); - Block& block = op.getBody().front(); - auto transformedBlock = transformBlock(block, /*ownsInputs=*/false); - if (failed(transformedBlock)) return failure(); - if (!transformedBlock->released.empty()) { - op->emitOpError("invalid realloc of memref"); - return failure(); - } - - // Update terminator and pass on the ownership indicator per escaping memref. - auto returnOp = llvm::dyn_cast(block.getTerminator()); - returnOp->setOperands(returnOp.getNumOperands(), 0, - transformedBlock->acquired); - op.setFunctionType(mlir::FunctionType::get( - op.getContext(), block.getArgumentTypes(), returnOp.getOperandTypes())); - - // Refine function aliasing based on return values. - SmallVector> refinedOverapproximation; - for (Value result : returnOp.getOperands()) { - auto& resultAliasing = refinedOverapproximation.emplace_back(); - for (auto [j, arg] : llvm::enumerate(op.getArguments())) { - if (aliasOverapprox.isEquivalent(result, arg)) - resultAliasing.push_back(j); - } - } - functionAliasOverapprox[op] = refinedOverapproximation; - - return success(); -} - -FailureOr Deallocator::transformBlock(Block& block, - bool ownsInputs) { - auto loc = block.getParent()->getLoc(); - auto ownershipTy = OwnershipIndicatorType::get(loc.getContext()); - // Introduce block arguments for the owned inputs. 
- breaks_if_you_move_ops::ValueSet ownedMemrefs; - if (ownsInputs) { - for (auto arg : llvm::to_vector( - llvm::make_filter_range(block.getArguments(), isMemref))) { - // Add an argument for a potentially owned memref. - auto newArg = block.addArgument(ownershipTy, loc); - ownedMemrefs.insert(newArg); - setOwnershipIndicator(arg, newArg); - } - } - - TransformResult blockResult; - for (auto& op : llvm::make_early_inc_range(block.without_terminator())) { - auto opResult = transformOp(&op, ownedMemrefs); - if (failed(opResult)) return failure(); - // Remove released memrefs. - for (auto v : opResult->released) { - auto owned = llvm::find(ownedMemrefs, v); - // If we don't own the released value, pass the release on to the parent. - if (owned == ownedMemrefs.end()) { - if (!blockResult.released.insert(v).second) { - block.getParentOp()->emitOpError("same value released twice"); - return failure(); - } - } else { - ownedMemrefs.erase(owned); - } - } - ownedMemrefs.insert(opResult->acquired.begin(), opResult->acquired.end()); - } - auto yieldedMemrefs = llvm::to_vector( - llvm::make_filter_range(block.getTerminator()->getOperands(), isMemref)); - - // Group yielded memrefs and owned memrefs by equivalence class leader. - auto groupByLeader = [&](auto& values) { - breaks_if_you_move_ops::ValueMap> result; - for (auto v : values) { - aliasOverapprox.insert(v); - result[aliasOverapprox.getLeaderValue(v)].push_back(v); - } - return result; - }; - auto yieldedByLeader = groupByLeader(yieldedMemrefs); - auto ownedByLeader = groupByLeader(ownedMemrefs); - - // Create one retain per equivalence class. - DenseSet alreadyRetained; - ImplicitLocOpBuilder b(loc, block.getTerminator()); - auto null = b.create(); - blockResult.acquired = - SmallVector(yieldedMemrefs.size(), null.getResult()); - for (auto [leader, yielded] : yieldedByLeader) { - auto& ownedGroup = ownedByLeader[leader]; - alreadyRetained.insert(ownedGroup.begin(), ownedGroup.end()); - if (yielded.size() == 1 && ownedGroup.size() == 1) { - auto oi = ownershipIndicator.find(yielded[0]); - if (oi != ownershipIndicator.end() && oi->second == ownedGroup.front()) { - blockResult.acquired[llvm::find(yieldedMemrefs, yielded.front()) - - yieldedMemrefs.begin()] = ownedGroup.front(); - continue; - } - } - SmallVector types(yielded.size(), ownershipTy); - auto retain = b.create(types, yielded, ownedGroup); - for (auto [retained, result] : llvm::zip(retain.getResults(), yielded)) { - aliasOverapprox.unionSets(retained, result); - blockResult.acquired[llvm::find(yieldedMemrefs, result) - - yieldedMemrefs.begin()] = retained; - } - } - if (!llvm::is_contained(blockResult.acquired, null.getResult())) null.erase(); - - // Handle owned memrefs that don't alias any yielded memref. 
- for (auto v : ownedMemrefs) { - if (!alreadyRetained.contains(v)) { - b.create(TypeRange{}, ValueRange{}, ValueRange{v}); - } - } - - return blockResult; -} - -FailureOr -Deallocator::transformIfImplicitCapture(scf::IfOp op, TransformResult& ifResult, - TransformResult& elseResult) { - if (ifResult.released == elseResult.released) { - return ifResult.released; - } - - auto fixAcquiredAlloc = [&](Value v, Region& region, - TransformResult& result) -> LogicalResult { - if (region.empty()) { - op.emitOpError("cannot implicitly capture from an if without else"); - return failure(); - } - auto* terminator = region.front().getTerminator(); - auto operands = terminator->getOperands(); - auto it = llvm::find_if(operands, [&](Value operand) { - return findOwnershipIndicator(operand) == v; - }); - if (it == operands.end()) { - op.emitOpError("released value not yielded on other branch"); - return failure(); - } - ownershipIndicator.erase(v); - - auto index = std::count_if(operands.begin(), it, isMemref); - result.acquired[index] = v; - return success(); - }; - - for (auto v : ifResult.released) { - if (!llvm::is_contained(elseResult.released, v)) { - if (failed(fixAcquiredAlloc(v, op.getElseRegion(), elseResult))) - return failure(); - } - } - for (auto v : elseResult.released) { - if (!llvm::is_contained(ifResult.released, v)) { - if (failed(fixAcquiredAlloc(v, op.getThenRegion(), ifResult))) - return failure(); - } - } - - breaks_if_you_move_ops::ValueSet released = ifResult.released; - released.insert(elseResult.released.begin(), elseResult.released.end()); - return released; -} - -FailureOr Deallocator::transformOp( - RegionBranchOpInterface op, - const breaks_if_you_move_ops::ValueSet& ownedMemrefs) { - SmallVector originalNumArgsByRegion; - SmallVector transformResultsByRegion; - transformResultsByRegion.reserve(op->getNumRegions()); - - bool mayImplicitlyCapture = llvm::isa(op); - for (auto [index, region] : llvm::enumerate(op->getRegions())) { - assert(region.getBlocks().size() <= 1 && - "expected regions to have at most one block"); - auto edges = getSuccessorRegions(op, region); - originalNumArgsByRegion.push_back(region.getNumArguments()); - - auto& result = transformResultsByRegion.emplace_back(); - if (region.empty()) continue; - - // Transform the block and collect acquired/released memrefs. - auto transformResultOrError = transformBlock(region.front()); - if (failed(transformResultOrError)) return failure(); - - result = *std::move(transformResultOrError); // NOLINT - if (!result.released.empty() && !mayImplicitlyCapture) { - // This error means that there's a realloc or free in a loop, and the op - // defining the value is outside the loop. This is not valid. To fix - // this, turn the argument of realloc/free into an iter arg. - op.emitOpError( - "can't implicitly capture across loop boundaries; use an " - "explicit iter arg instead"); - return failure(); - } - } - - breaks_if_you_move_ops::ValueSet released; - if (llvm::any_of(transformResultsByRegion, [](const TransformResult& result) { - return !result.released.empty(); - })) { - auto releasedByIf = transformIfImplicitCapture( - llvm::cast(op.getOperation()), transformResultsByRegion[0], - transformResultsByRegion[1]); - if (failed(releasedByIf)) return failure(); - released = *std::move(releasedByIf); // NOLINT - } - - // Adjust terminator operands. 
- for (auto [region, transformResult] : - llvm::zip(op->getRegions(), transformResultsByRegion)) { - if (region.empty()) continue; - auto* terminator = region.front().getTerminator(); - terminator->setOperands(terminator->getNumOperands(), 0, - transformResult.acquired); - } - - ImplicitLocOpBuilder b(op.getLoc(), op); - SmallVector operands = op->getOperands(); - Value null = nullptr; - // If we pass an owned memref to the loop and don't reuse it afterwards, we - // can transfer ownership. - for (auto operand : llvm::make_filter_range(operands, isMemref)) { - auto isLastUse = [&]() { - for (auto* candidate = op.getOperation(); candidate != nullptr; - candidate = candidate->getNextNode()) { - if (doesAlias(candidate, operand, aliasOverapprox, - /*considerOperands=*/candidate != op.getOperation())) { - return false; - } - } - return true; - }; - - Value ownershipIndicator = findOwnershipIndicator(operand); - if (ownershipIndicator && - !llvm::is_contained(released, ownershipIndicator) && - llvm::is_contained(ownedMemrefs, ownershipIndicator) && isLastUse()) { - // This is an alloc that is not used again, so we can pass ownership - // to the loop. - op->insertOperands(op->getNumOperands(), ownershipIndicator); - released.insert(ownershipIndicator); - } else { - // Either the operand is not an alloc or it's reused. - if (!null) null = b.create().getResult(); - op->insertOperands(op->getNumOperands(), null); - } - } - - RegionBranchOpInterface newOp = moveRegionsToNewOpButKeepOldOp(op); - auto numOriginalResults = op->getNumResults(); - auto newResults = newOp->getResults().take_front(numOriginalResults); - auto retained = newOp->getResults().drop_front(numOriginalResults); - op->replaceAllUsesWith(newResults); - op->erase(); - - for (auto [result, indicator] : - llvm::zip(llvm::make_filter_range(newOp->getResults(), isMemref), - newOp->getResults().drop_front(numOriginalResults))) { - setOwnershipIndicator(result, indicator); - } - - auto setupAliases = [&](RegionBranchPoint point) { - for (auto& region : getSuccessorRegions(newOp, point)) { - for (auto [pred, succ] : llvm::zip(region.getPredecessorOperands(), - region.getSuccessorValues())) { - aliasOverapprox.unionSets(pred, succ); - } - } - }; - auto setMemrefAliases = [this](ValueRange a, ValueRange b) { - for (auto [aa, bb] : llvm::zip(llvm::make_filter_range(a, isMemref), b)) { - aliasOverapprox.unionSets(aa, bb); - } - }; - setupAliases(RegionBranchPoint::parent()); - for (uint32_t i = 0; i < newOp->getNumRegions(); ++i) { - setupAliases(newOp->getRegion(i)); - auto args = newOp->getRegion(i).getArguments(); - auto n = originalNumArgsByRegion[i]; - setMemrefAliases(args.take_front(n), args.drop_front(n)); - } - setMemrefAliases(newResults, retained); - return TransformResult{released, retained}; -} - -// TODO(frgossen): Also allow passing ownership to functions. -FailureOr Deallocator::transformOp(func::CallOp op) { - ImplicitLocOpBuilder b(op.getLoc(), op); - - // Extend result types with ownership indicators. - SmallVector newResultTys(op.getResultTypes()); - int64_t numMemrefResults = llvm::count_if(op.getResults(), isMemref); - newResultTys.append( - SmallVector(numMemrefResults, b.getType())); - auto newOp = b.create(op.getCalleeAttr(), newResultTys, - op.getOperands()); - - // Follow the call graph and process the callee first to get accurate aliasing - // information. 
- auto callee = llvm::cast( - op->getParentOfType().lookupSymbol(op.getCallee())); - if (failed(transformFuncOp(callee))) return failure(); - - // Update ownership indicators and aliasing. - int64_t numResults = op.getNumResults(); - int64_t ownershipIndicatorIdx = numResults; - for (auto [result, resultAliasing] : - llvm::zip(newOp.getResults().take_front(numResults), - functionAliasOverapprox[callee])) { - if (!isMemref(result)) continue; - setOwnershipIndicator(result, newOp.getResult(ownershipIndicatorIdx++)); - for (int64_t i : resultAliasing) { - aliasOverapprox.unionSets(result, op.getOperand(i)); - } - } - - // Replace old op. - op.replaceAllUsesWith(newOp.getResults().take_front(numResults)); - op.erase(); - - // Collect ownership indicators. - auto retained = newOp->getResults().drop_front(numResults); - return TransformResult{{}, retained}; -} - -// Returns the set of values that are potentially owned by the op. -FailureOr Deallocator::transformOp( - Operation* op, const breaks_if_you_move_ops::ValueSet& ownedMemrefs) { - if (auto rbi = llvm::dyn_cast(op)) { - return transformOp(rbi, ownedMemrefs); - } - if (auto callOp = llvm::dyn_cast(op)) { - return transformOp(callOp); - } - - if (auto me = llvm::dyn_cast(op)) { - if (llvm::isa(op)) { - // Don't attempt to memory manage memref.alloca. - return TransformResult{}; - } - TransformResult result; - OpBuilder b(op->getContext()); - b.setInsertionPointAfter(op); - - SmallVector> allocs, - frees; - me.getEffects(allocs); - me.getEffects(frees); - if (!allocs.empty() || !frees.empty()) { - for (const auto& alloc : allocs) { - auto owned = b.create(op->getLoc(), alloc.getValue()); - setOwnershipIndicator(alloc.getValue(), owned); - result.acquired.push_back(owned); - } - for (const auto& free : frees) { - auto ownershipIndicator = findOwnershipIndicator(free.getValue()); - if (!ownershipIndicator) { - op->emitOpError("unable to find ownership indicator for operand"); - return failure(); - } - result.released.insert(ownershipIndicator); - } - return result; - } - } - - // Deallocate ops inside unknown op regions. - // Also assert that unknown ops with regions return no memrefs. There is no - // way to generically transform such ops, if they exist. Eventually we'll need - // an interface for this. - if (op->getNumRegions() > 0) { - assert(llvm::none_of(op->getResults(), isMemref)); - for (auto& region : op->getRegions()) { - for (auto& block : region.getBlocks()) { - auto transformedBlock = transformBlock(block, /*ownsInputs=*/false); - if (failed(transformedBlock)) return failure(); - if (!transformedBlock->acquired.empty() || - !transformedBlock->released.empty()) { - op->emitOpError("block unexpectededly released or returned an alloc"); - return failure(); - } - } - } - } - - // Assume any memref operand may alias any memref result. - for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) { - for (auto arg : llvm::make_filter_range(op->getOperands(), isMemref)) { - if (getElementTypeOrSelf(result.getType()) == - getElementTypeOrSelf(arg.getType())) { - aliasOverapprox.unionSets(result, arg); - } - } - } - // No new allocations or releases. 
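The fallback for unknown ops above only records aliasing facts: every memref result is merged into the equivalence class of every type-compatible memref operand. A minimal stand-in for the role played by `aliasOverapprox`, written as a textbook union-find over value names; the class name `AliasSets` is invented for this sketch and the pass operates on MLIR values rather than strings.

    #include <cassert>
    #include <map>
    #include <string>

    // Minimal union-find keyed by value names.
    class AliasSets {
      std::map<std::string, std::string> parent;

      std::string find(const std::string& v) {
        auto it = parent.emplace(v, v).first;
        if (it->second != v) it->second = find(it->second);  // path compression
        return it->second;
      }

     public:
      void unionSets(const std::string& a, const std::string& b) {
        parent[find(a)] = find(b);
      }
      bool mayAlias(const std::string& a, const std::string& b) {
        return find(a) == find(b);
      }
    };

    int main() {
      AliasSets aliases;
      // Unknown op: its memref result may alias either operand, so all three
      // values end up in one equivalence class.
      aliases.unionSets("%result", "%arg0");
      aliases.unionSets("%result", "%arg1");
      assert(aliases.mayAlias("%arg0", "%arg1"));
      assert(!aliases.mayAlias("%arg0", "%other"));
      return 0;
    }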
- return TransformResult{}; -} - -#define GEN_PASS_DEF_DEALLOCATEPASS -#include "deallocation/transforms/passes.h.inc" - -struct DeallocatePass : public impl::DeallocatePassBase { - void runOnOperation() override { - ModuleOp moduleOp = getOperation(); - if (failed(Deallocator().transformModuleOp(moduleOp))) { - signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> createDeallocatePass() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_simplification.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_simplification.cc deleted file mode 100644 index b140e8ff742b70..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_simplification.cc +++ /dev/null @@ -1,237 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "deallocation/IR/deallocation_ops.h" -#include "deallocation/transforms/passes.h" -#include "deallocation/utils/util.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace deallocation { -namespace { - -// Returns the value owned by the given ownership indicator. Returns null if it -// could not be determined. -Value getOwnedValue(Value v) { - ValueRange vals; - unsigned valueNum; - if (auto bbarg = v.dyn_cast()) { - vals = v.getParentBlock()->getArguments(); - valueNum = bbarg.getArgNumber(); - } else { - vals = v.getDefiningOp()->getResults(); - valueNum = v.cast().getResultNumber(); - } - - int64_t num = llvm::count_if(vals.take_front(valueNum), [](Value it) { - return it.getType().isa(); - }); - - auto memrefs = llvm::make_filter_range( - vals, [](Value it) { return it.getType().isa(); }); - - auto it = memrefs.begin(); - for (auto end = memrefs.end(); it != end && num > 0; ++it) { - --num; - } - if (it == memrefs.end()) return {}; - return *it; -} - -enum AllocNullability : uint32_t { - UNDEFINED = 0, - ALWAYS_NULL = 1, - NEVER_NULL = 2, - SOMETIMES_NULL = 3 -}; - -AllocNullability operator|=(AllocNullability& lhs, AllocNullability rhs) { - return lhs = static_cast(static_cast(lhs) | rhs); -} - -struct AllocInfo { - AllocNullability nullability; - // Set only if nullability is NEVER_NULL. - Value nonNullValue; -}; - -// Returns the nullability of `v`. `pending` contains a set of `Values` we're -// already considering in the computation of some value's nullability. It is -// assumed that we will eventually take the maximum (logical or) of all -// nullability values in this set. 
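The enum above is a four-element lattice in which bitwise OR acts as the join: UNDEFINED (0) is the neutral element, NEVER_NULL joined with ALWAYS_NULL gives SOMETIMES_NULL, and SOMETIMES_NULL absorbs everything else. The standalone snippet below replays a few joins; the cast targets in `operator|=` are spelled out as `AllocNullability` and `uint32_t`, which is what the surrounding definitions imply.

    #include <cassert>
    #include <cstdint>

    enum AllocNullability : uint32_t {
      UNDEFINED = 0,
      ALWAYS_NULL = 1,
      NEVER_NULL = 2,
      SOMETIMES_NULL = 3
    };

    AllocNullability operator|=(AllocNullability& lhs, AllocNullability rhs) {
      return lhs = static_cast<AllocNullability>(static_cast<uint32_t>(lhs) | rhs);
    }

    int main() {
      AllocNullability n = UNDEFINED;
      n |= NEVER_NULL;   // one predecessor is a fresh `own`: still NEVER_NULL
      assert(n == NEVER_NULL);
      n |= ALWAYS_NULL;  // another predecessor yields `null`: now SOMETIMES_NULL
      assert(n == SOMETIMES_NULL);
      n |= NEVER_NULL;   // SOMETIMES_NULL absorbs everything from here on
      assert(n == SOMETIMES_NULL);
      return 0;
    }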
-AllocInfo getAllocNullabilityImpl(Value v, llvm::DenseSet& pending) { - if (llvm::isa_and_present(v.getDefiningOp())) { - return {NEVER_NULL, v.getDefiningOp()->getOperand(0)}; - } - - if (llvm::isa_and_present(v.getDefiningOp())) { - return {ALWAYS_NULL, {}}; - } - - if (auto retain = llvm::dyn_cast_or_null(v.getDefiningOp())) { - // We start with ALWAYS_NULL because a retain without any allocs is null. - // Also, because a retain with a non-null alloc can be null (otherwise, this - // would have been cleaned up by `retainNoOp`). - AllocNullability nullability = ALWAYS_NULL; - for (auto alloc : retain.getAllocs()) { - if (pending.insert(alloc).second) { - // We can ignore the non-null value here, since the final outcome won't - // be NEVER_NULL. - nullability |= getAllocNullabilityImpl(alloc, pending).nullability; - } - if (nullability == SOMETIMES_NULL) break; - } - return {nullability, {}}; - } - - // Returns the nullability of an operand in each of the region's predecessors. - auto getPredecessorNullability = - [&](RegionBranchOpInterface rbi, - RegionBranchPoint successorRegionPoint, - int64_t successorArgIndex) -> AllocInfo { - AllocNullability nullability = UNDEFINED; - for (const auto& pred : getPredecessorRegions(rbi, successorRegionPoint)) { - Value operand = pred.getPredecessorOperand(successorArgIndex); - // It is safe to skip values that are already being considered higher - // up in the call stack, because we end up taking the maximum of all - // nullability values. - if (pending.insert(operand).second) { - nullability |= getAllocNullabilityImpl(operand, pending).nullability; - } - if (nullability == SOMETIMES_NULL) break; - } - if (nullability == NEVER_NULL) { - return {NEVER_NULL, getOwnedValue(v)}; - } - return {nullability, {}}; - }; - - // If `v` is a block argument, check all incoming edges. - if (auto bbarg = v.dyn_cast()) { - if (auto rbi = llvm::dyn_cast( - bbarg.getParentRegion()->getParentOp())) { - return getPredecessorNullability( - rbi, bbarg.getParentRegion(), - bbarg.getArgNumber()); - } - } - - if (auto rbi = - llvm::dyn_cast_or_null(v.getDefiningOp())) { - return getPredecessorNullability(rbi, mlir::RegionBranchPoint::parent(), - llvm::cast(v).getResultNumber()); - } - - // Something we don't understand. - return {AllocNullability::SOMETIMES_NULL, {}}; -} - -bool allocIsNull(Value v) { - llvm::DenseSet pendingChecks; - return getAllocNullabilityImpl(v, pendingChecks).nullability == ALWAYS_NULL; -} - -// Returns true if the value is just passed around, but never really used. 
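A reduced model of the recursive nullability walk above: values are either a fresh `own`, a literal `null`, or a merge point with predecessor values, and the `pending` set is what keeps the recursion from spinning on cycles such as a loop-carried block argument. Skipping a value already in `pending` is sound because the caller ORs every partial result together anyway. The `Kind` enum, the node names, and the `nullability` helper are invented for this sketch.

    #include <cassert>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    enum Nullability : unsigned { UNDEF = 0, ALWAYS = 1, NEVER = 2, SOMETIMES = 3 };
    enum class Kind { Own, Null, Merge };

    struct Node {
      Kind kind;
      std::vector<std::string> preds;  // only used for Merge nodes
    };
    using Graph = std::map<std::string, Node>;

    Nullability nullability(const Graph& g, const std::string& v,
                            std::set<std::string>& pending) {
      const Node& node = g.at(v);
      if (node.kind == Kind::Own) return NEVER;
      if (node.kind == Kind::Null) return ALWAYS;
      unsigned result = UNDEF;
      for (const std::string& pred : node.preds) {
        // A predecessor already on the walk contributes nothing new, since the
        // caller ORs every partial result together anyway.
        if (pending.insert(pred).second)
          result |= nullability(g, pred, pending);
        if (result == SOMETIMES) break;
      }
      return static_cast<Nullability>(result);
    }

    int main() {
      // %arg is a loop-carried merge of itself and a fresh allocation.
      Graph g = {{"%own", {Kind::Own, {}}},
                 {"%null", {Kind::Null, {}}},
                 {"%arg", {Kind::Merge, {"%arg", "%own"}}},
                 {"%maybe", {Kind::Merge, {"%null", "%own"}}}};
      std::set<std::string> pending;
      assert(nullability(g, "%arg", pending) == NEVER);
      std::set<std::string> pending2;
      assert(nullability(g, "%maybe", pending2) == SOMETIMES);
      return 0;
    }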
-bool valueIsUnused(Value value) { - llvm::DenseSet pendingChecks; - std::function checkValue; - std::function checkUser; - - checkUser = [&](OpOperand& user) -> bool { - RegionBranchPoint regionPoint = mlir::RegionBranchPoint::parent(); - auto rbi = llvm::dyn_cast(user.getOwner()); - if (user.getOwner()->mightHaveTrait()) { - rbi = llvm::dyn_cast( - user.getOwner()->getParentOp()); - regionPoint = user.getOwner()->getParentRegion(); - } - return rbi && llvm::all_of(getSuccessorRegions(rbi, regionPoint), - [&](const RegionEdge& edge) { - return checkValue(edge.getSuccessorValue( - user.getOperandNumber())); - }); - }; - checkValue = [&](Value value) { - if (!pendingChecks.insert(value).second) return true; - return llvm::all_of(value.getUses(), checkUser); - }; - - return checkValue(value); -} - -#define GEN_PASS_DEF_DEALLOCATIONSIMPLIFICATIONPASS -#include "deallocation/transforms/passes.h.inc" - -struct DeallocationSimplificationPass - : public impl::DeallocationSimplificationPassBase< - DeallocationSimplificationPass> { - void runOnOperation() override { - getOperation()->walk([](RetainOp op) { - OpBuilder b(op); - // If all allocs are null, the result is null and there is nothing to - // deallocate. - if (llvm::all_of(op.getAllocs(), allocIsNull)) { - auto null = b.create(op.getLoc()); - auto nulls = llvm::SmallVector(op.getNumResults(), null); - op.replaceAllUsesWith(nulls); - op.erase(); - return; - } - - if (op.getRetained().empty() && op.getAllocs().size() == 1) { - llvm::DenseSet pendingChecks; - auto nullability = - getAllocNullabilityImpl(op.getAllocs()[0], pendingChecks); - if (nullability.nullability != NEVER_NULL || - !nullability.nonNullValue) { - return; - } - - b.setInsertionPoint(op); - b.create(op.getLoc(), nullability.nonNullValue); - op.erase(); - } - }); - getOperation()->walk([](OwnOp op) { - if (op.use_empty()) { - op.erase(); - } else if (valueIsUnused(op.getResult())) { - OpBuilder b(op); - op.replaceAllUsesWith(b.create(op.getLoc()).getResult()); - op.erase(); - } - }); - } -}; - -} // namespace - -std::unique_ptr> -createDeallocationSimplificationPass() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_to_scf.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_to_scf.cc deleted file mode 100644 index b6e82d4d4c3c99..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/deallocation_to_scf.cc +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include "deallocation/IR/deallocation_ops.h" -#include "deallocation/transforms/passes.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace deallocation { -namespace { - -#define GEN_PASS_DEF_DEALLOCATIONTOSCFPASS -#include "deallocation/transforms/passes.h.inc" - -LogicalResult rewriteRetain(RetainOp op, PatternRewriter& rewriter) { - assert(!op.getAllocs().empty() && "run canonicalization first"); - - if (op.getRetained().size() != 1 && op.getAllocs().size() != 1) { - return rewriter.notifyMatchFailure( - op, "this retain needs to be lowered to a library call"); - } - // Note: The generated code has size O(|`allocs`| * |`retains`|). If there are - // cases where this gets too big, we should lower it to a library call - // instead. - - auto loc = op.getLoc(); - - // Get the buffers of all `alloc` values. - SmallVector remainingBuffersAndResult; - for (Value alloc : op.getAllocs()) { - remainingBuffersAndResult.push_back(alloc); - } - llvm::copy(llvm::map_range(op.getAllocs(), - [&](Value alloc) -> Value { - return rewriter.create( - loc, rewriter.getIndexType(), alloc); - }), - std::back_inserter(remainingBuffersAndResult)); - remainingBuffersAndResult.push_back({}); - - Value null = rewriter.create(loc); - auto zero = rewriter.create(loc, 0); - SmallVector results; - - size_t nAllocs = op.getAllocs().size(); - for (auto [retainedIndex, retained] : llvm::enumerate(op.getRetained())) { - auto retainedBuffer = - rewriter.create(loc, rewriter.getIndexType(), retained); - - remainingBuffersAndResult.back() = null; - for (auto allocIndex : llvm::seq(0, nAllocs)) { - auto isSame = rewriter.create( - loc, arith::CmpIPredicate::eq, retainedBuffer, - remainingBuffersAndResult[nAllocs + allocIndex]); - - // If the buffers are the same, remove the alloc from consideration for - // future `retained` values. - SmallVector yieldedIfSame{null, zero, - remainingBuffersAndResult[allocIndex]}; - SmallVector yieldedIfDifferent{ - remainingBuffersAndResult[allocIndex], - remainingBuffersAndResult[allocIndex + nAllocs], - remainingBuffersAndResult.back()}; - - auto ifOp = - rewriter.create(loc, TypeRange{ValueRange{yieldedIfSame}}, - isSame, /*withElseRegion=*/true); - ifOp.getThenBodyBuilder().create(loc, yieldedIfSame); - - // Otherwise, keep the current results. - ifOp.getElseBodyBuilder().create(loc, yieldedIfDifferent); - - remainingBuffersAndResult[allocIndex] = ifOp.getResult(0); - remainingBuffersAndResult[allocIndex + nAllocs] = ifOp.getResult(1); - remainingBuffersAndResult.back() = ifOp.getResult(2); - } - - results.push_back(remainingBuffersAndResult.back()); - } - - // Deallocate any remaining buffers. 
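In scalar terms, the nested `scf.if` ladder generated above does the following: each retained value scans the remaining allocs for a matching base buffer, claims it so that later retained values cannot, and everything left unclaimed is freed afterwards. A standalone model over integer "buffer addresses" (0 standing in for the null buffer), with one comparison per (retained, alloc) pair, mirroring the O(|allocs| * |retains|) size noted above; `modelRetain` is an invented name and is not the runtime library call referenced elsewhere in this patch.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // For each retained buffer, return the alloc that now backs it (0 = null
    // buffer); append every unclaimed alloc to `toFree`.
    std::vector<uint64_t> modelRetain(const std::vector<uint64_t>& retained,
                                      std::vector<uint64_t> allocs,
                                      std::vector<uint64_t>& toFree) {
      std::vector<uint64_t> results;
      for (uint64_t r : retained) {
        uint64_t result = 0;          // deallocation.null
        for (uint64_t& a : allocs) {
          if (a != 0 && a == r) {     // same base buffer
            result = a;
            a = 0;                    // remove from consideration for later values
          }
        }
        results.push_back(result);
      }
      for (uint64_t a : allocs)
        if (a != 0) toFree.push_back(a);  // these remaining buffers get deallocated
      return results;
    }

    int main() {
      std::vector<uint64_t> toFree;
      // Two allocs at addresses 100 and 200; only the buffer at 200 is retained.
      std::vector<uint64_t> results = modelRetain({200}, {100, 200}, toFree);
      assert(results.size() == 1 && results[0] == 200);
      assert(toFree.size() == 1 && toFree[0] == 100);
      return 0;
    }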
- for (auto index : llvm::seq(0, nAllocs)) { - auto nonZero = rewriter.create( - loc, arith::CmpIPredicate::ne, - remainingBuffersAndResult[index + nAllocs], zero); - rewriter.create( - loc, nonZero, [&](OpBuilder& thenBuilder, Location loc) { - thenBuilder.create(loc, remainingBuffersAndResult[index]); - thenBuilder.create(loc); - }); - } - - rewriter.replaceOp(op, results); - - return success(); -} - -struct DeallocationToScfPass - : public impl::DeallocationToScfPassBase { - void runOnOperation() override { - MLIRContext* ctx = &getContext(); - RewritePatternSet patterns(ctx); - patterns.add(rewriteRetain); - - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -createDeallocationToScfPass() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/debug_passes.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/debug_passes.cc deleted file mode 100644 index 9900a74b9e22b9..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/debug_passes.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include -#include - -#include "deallocation/transforms/analysis.h" -#include "deallocation/transforms/passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Operation.h" - -namespace mlir { -namespace deallocation { -namespace { - -#define GEN_PASS_DEF_ANNOTATEDEALLOCATIONPASS -#include "deallocation/transforms/passes.h.inc" - -std::string getDebugString(AsmState& state, DeallocationAnalysis& analysis, - Value value) { - std::string out; - llvm::raw_string_ostream os(out); - llvm::interleaveComma(analysis.getBackingMemory(value), os, - [&](Value v) { v.printAsOperand(os, state); }); - return out; -} - -Attribute getDebugAttribute(AsmState& state, DeallocationAnalysis& analysis, - Region& region) { - mlir::OpBuilder b(region.getContext()); - return b.getArrayAttr(llvm::to_vector( - llvm::map_range(region.getArguments(), [&](Value arg) -> Attribute { - return b.getStringAttr(getDebugString(state, analysis, arg)); - }))); -} - -struct AnnotatePass : public impl::AnnotateDeallocationPassBase { - void runOnOperation() override { - DeallocationAnalysis analysis; - AsmState state(getOperation()); - mlir::OpBuilder b(getOperation()); - getOperation().walk([&](Operation* op) { - std::string out; - llvm::raw_string_ostream os(out); - if (op->getNumRegions() > 0) { - op->setAttr("deallocation.region_args_backing_memory", - b.getArrayAttr(llvm::to_vector( - llvm::map_range(op->getRegions(), [&](Region& region) { - return getDebugAttribute(state, analysis, region); - })))); - } - - if (op->getNumResults() > 0) { - op->setAttr("deallocation.result_backing_memory", - b.getArrayAttr(llvm::to_vector(llvm::map_range( - op->getResults(), [&](Value result) -> Attribute { - return b.getStringAttr( - getDebugString(state, analysis, result)); - })))); - } - }); - } -}; - -} // namespace - -// Pass to annotate ops with debug information. -std::unique_ptr> -createDeallocationAnnotationPass() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.h b/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.h index 183082fc54e04d..4ba90fed2eae98 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.h @@ -26,42 +26,11 @@ limitations under the License. namespace mlir { namespace deallocation { -// Pass to split bufferization.alloc_tensor ops to optimize buffer reuse. -std::unique_ptr> -createSplitAllocTensorsPass(); - -// Pass to insert deallocations (in the form of `deallocation.retain`) ops. Most -// deallocations are typically converted to `memref.dealloc` by -// canonicalization. -std::unique_ptr> createDeallocatePass(); - -// Pass to annotate ops with debug information. -std::unique_ptr> -createDeallocationAnnotationPass(); - -// Pass to annotate buffer arguments with aliasing information. -std::unique_ptr> -createXlaBufferArgRewritePass(); - // Pass to reuse buffers (hoisting, double buffering, dealloc/alloc // coalescing). std::unique_ptr> createBufferReusePass(); -// Lowers retain to SCF. -std::unique_ptr> -createDeallocationToScfPass(); - -// Convert `deallocation` ops to LLVM. 
-std::unique_ptr> -createConvertDeallocationOpsToLLVM(); - -std::unique_ptr> -createDeallocationSimplificationPass(); - -void populateDeallocationToLLVMConversionPatterns(LLVMTypeConverter& converter, - RewritePatternSet& patterns); - #define GEN_PASS_REGISTRATION #include "deallocation/transforms/passes.h.inc" diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.td b/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.td index cc9ce1e5e44974..34dc1f56be6c49 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/passes.td @@ -15,63 +15,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def SplitAllocTensorsPass : - Pass<"hlo-split-alloc-tensors", "mlir::func::FuncOp"> { - let summary = "Split bufferization.alloc_tensor ops."; - let description = [{ - `bufferization.alloc_tensor` ops that are reused for non-conflicting ops - prevent buffer reuse. This pass replaces each use of an alloc tensor with a - fresh alloc. - }]; - let constructor = "::mlir::deallocation::createSplitAllocTensorsPass()"; - let dependentDialects = ["::mlir::bufferization::BufferizationDialect"]; -} - -def AnnotateDeallocationPass : - Pass<"hlo-deallocation-annotation", "mlir::ModuleOp"> { - let summary = "Annotate ops with deallocation debug information."; - let description = [{ - Adds attributes to annotate ops with debug information about the - deallocation analysis. - }]; - let constructor = "::mlir::deallocation::createDeallocationAnnotationPass()"; -} - -def DeallocatePass : Pass<"hlo-deallocate", "mlir::ModuleOp"> { - let summary = "Deallocate buffers by inserting `deallocation.retain` ops."; - let description = [{ - Inserts deallocations (in the form of `deallocation.retain`) ops. Most - deallocations are typically converted to `memref.dealloc` by - canonicalization. - }]; - let constructor = "::mlir::deallocation::createDeallocatePass()"; - let dependentDialects = ["::mlir::deallocation::DeallocationDialect"]; -} - -def DeallocationSimplificationPass : Pass<"hlo-deallocation-simplification", - "mlir::func::FuncOp"> { - let summary = "Simplifies deallocation.retain ops."; - let constructor = "::mlir::deallocation::createDeallocationSimplificationPass()"; - let dependentDialects = ["::mlir::deallocation::DeallocationDialect"]; -} - -def XlaBufferArgRewritePass : - Pass<"hlo-xla-buffer-arg-rewrite", "mlir::func::FuncOp"> { - let summary = "Rewrites XLA framework buffer arguments with alias information"; - let description = [{ - In the presence of variables, some results of the main function will alias - other parameters. This pass rewrites the main function to annotate results - for which this isn't the case with the `deallocation.restrict` attribute, - indicating that they do not alias with any other buffer and allowing the - buffer-reuse pass to optimize them. - - The pass uses attributes present in XLA programs - (`xla_framework.input_mapping`, `xla_framework.result_mapping` and - `xla_framework.result_inner_mapping`, specifically). 
- }]; - let constructor = "::mlir::deallocation::createXlaBufferArgRewritePass()"; -} - def BufferReusePass : Pass<"hlo-buffer-reuse", "mlir::func::FuncOp"> { let summary = "Reuse buffers."; let description = [{ @@ -93,23 +36,3 @@ def BufferReusePass : Pass<"hlo-buffer-reuse", "mlir::func::FuncOp"> { let dependentDialects = ["::mlir::memref::MemRefDialect"]; } -def ConvertDeallocationOpsToLLVMPass - : Pass<"hlo-convert-deallocation-ops-to-llvm", "mlir::func::FuncOp"> { - let summary = "Convert `deallocation` ops to LLVM"; - let constructor = "::mlir::deallocation::createConvertDeallocationOpsToLLVM()"; - let dependentDialects = [ - "::mlir::LLVM::LLVMDialect", - "::mlir::memref::MemRefDialect", - ]; -} - -def DeallocationToScfPass : Pass<"hlo-deallocation-to-scf", - "mlir::func::FuncOp"> { - let summary = "Lowers retain to scf."; - let constructor = "::mlir::deallocation::createDeallocationToScfPass()"; - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::scf::SCFDialect", - "::mlir::memref::MemRefDialect", - ]; -} diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/split_alloc_tensors.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/split_alloc_tensors.cc deleted file mode 100644 index 0d0bf7a6a790fd..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/split_alloc_tensors.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "deallocation/transforms/passes.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace deallocation { -namespace { - -void splitAllocTensors(Block& block) { - for (auto& op : block) { - for (auto [index, operand] : llvm::enumerate(op.getOperands())) { - auto* definingOp = operand.getDefiningOp(); - if (llvm::isa_and_nonnull(definingOp)) { - op.setOperand(index, OpBuilder(&op).clone(*definingOp)->getResult(0)); - } - } - - for (auto& region : op.getRegions()) { - for (auto& block : region.getBlocks()) { - splitAllocTensors(block); - } - } - } - - for (auto& op : llvm::make_early_inc_range(block)) { - if (llvm::isa(op) && op.use_empty()) { - op.erase(); - } - } -} - -#define GEN_PASS_DEF_SPLITALLOCTENSORSPASS -#include "deallocation/transforms/passes.h.inc" - -struct SplitAllocTensorsPass - : public impl::SplitAllocTensorsPassBase { - void runOnOperation() override { - splitAllocTensors(getOperation().getBody().front()); - } -}; - -} // namespace - -std::unique_ptr> -createSplitAllocTensorsPass() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/xla_buffer_arg_rewrite.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/xla_buffer_arg_rewrite.cc deleted file mode 100644 index 919308910c27d8..00000000000000 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/xla_buffer_arg_rewrite.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "deallocation/transforms/passes.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/TypeUtilities.h" - -namespace mlir { -namespace deallocation { -namespace { - -#define GEN_PASS_DEF_XLABUFFERARGREWRITEPASS -#include "deallocation/transforms/passes.h.inc" - -constexpr char kInputMapping[] = "xla_framework.input_mapping"; -constexpr char kResultMapping[] = "xla_framework.result_mapping"; -constexpr char kResultInnerMapping[] = "xla_framework.result_inner_mapping"; - -struct XlaBufferArgRewritePass - : public impl::XlaBufferArgRewritePassBase { - void runOnOperation() override { - func::FuncOp op = getOperation(); - if (!op->hasAttr(kResultMapping)) return; - - // Collect result arguments and input arguments. 
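Continued below, the pass reduces to a set-membership test: a result argument may be marked `deallocation.restrict` only if the XLA buffer slot it writes to is not also the slot of any input argument. A small standalone model with plain integer buffer indices; `isRestrictResult` is an invented helper name, not part of the pass.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // A result argument may be marked `deallocation.restrict` iff no input
    // argument maps to the same XLA buffer slot; -1 marks arguments that carry
    // no xla_framework.input_mapping attribute.
    bool isRestrictResult(int64_t resultBufferIndex,
                          const std::vector<int64_t>& inputBufferIndices) {
      return std::find(inputBufferIndices.begin(), inputBufferIndices.end(),
                       resultBufferIndex) == inputBufferIndices.end();
    }

    int main() {
      std::vector<int64_t> inputs = {0, 2, -1};  // input_mapping values per argument
      assert(isRestrictResult(3, inputs));       // fresh output buffer: restrict
      assert(!isRestrictResult(2, inputs));      // aliases an input (e.g. a variable)
      return 0;
    }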
- auto results = llvm::to_vector( - llvm::make_filter_range(op.getArguments(), [&](auto arg) { - return op.getArgAttr(arg.getArgNumber(), kInputMapping) == nullptr; - })); - auto args = - llvm::to_vector(llvm::map_range(op.getArguments(), [&](auto arg) { - auto buffer = op.getArgAttrOfType(arg.getArgNumber(), - kInputMapping); - return buffer ? buffer.getInt() : -1; - })); - - SmallVector resultMapping; - if (auto innerMapping = op->getAttrOfType(kResultInnerMapping)) { - resultMapping = llvm::to_vector(llvm::map_range( - innerMapping.getAsValueRange(), - [](const APInt& value) { return value.getSExtValue(); })); - } else if (auto mapping = op->getAttrOfType(kResultMapping)) { - resultMapping = {mapping.getInt()}; - } - - if (resultMapping.size() != results.size()) { - op.emitOpError( - "number of result arguments does not match size of mapping"); - signalPassFailure(); - return; - } - - for (auto [bufferIndex, result] : llvm::zip(resultMapping, results)) { - // If the result doesn't alias any argument, add the - // `deallocation.restrict` attribute to signal to the buffer reuse pass - // that this buffer is guaranteed not to alias any other argument. - if (!llvm::is_contained(args, bufferIndex)) { - op.setArgAttr(result.getArgNumber(), "deallocation.restrict", - OpBuilder(op).getBoolAttr(true)); - } - } - } -}; - -} // namespace - -std::unique_ptr> -createXlaBufferArgRewritePass() { - return std::make_unique(); -} - -} // namespace deallocation -} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/analysis.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/analysis.mlir deleted file mode 100644 index fb4d520a8d6888..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/analysis.mlir +++ /dev/null @@ -1,65 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --allow-unregistered-dialect \ -// RUN: --hlo-deallocation-annotation | \ -// RUN: FileCheck %s - -func.func @loop_nested_alloc( - %lb: index, %ub: index, %step: index, - %buf: memref<2xf32>, %res: memref<2xf32>) { - // CHECK-LABEL: func.func @loop_nested_alloc - // CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, - // CHECK-SAME: %[[BUF:.*]]: memref<2xf32>, %[[RES:.*]]: memref<2xf32>) - // CHECK-SAME: attributes {deallocation.region_args_backing_memory = {{\[\[}} - // CHECK-SAME: "", "", "", "%[[BUF]], %[[RES]]", "%[[BUF]], %[[RES]]"]]} { - %0 = memref.alloc() : memref<2xf32> - // CHECK: %[[ALLOC1:.*]] = memref.alloc() - // CHECK-SAME: {deallocation.result_backing_memory = ["%[[ALLOC1]]"]} : memref<2xf32> - %1 = scf.for %i = %lb to %ub step %step - iter_args(%iterBuf = %buf) -> memref<2xf32> { - // CHECK: %[[FOR1:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] - // CHECK-SAME: iter_args(%[[ITER_BUF:.*]] = %[[BUF]]) -> (memref<2xf32>) - %2 = scf.for %j = %lb to %ub step %step - iter_args(%iterBuf2 = %iterBuf) -> memref<2xf32> { - // CHECK: %[[FOR2:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] - // CHECK-SAME: iter_args(%[[ITER_BUF2:.*]] = %[[ITER_BUF]]) -> (memref<2xf32>) - %3 = memref.alloc() : memref<2xf32> - // CHECK: %[[ALLOC2:.*]] = memref.alloc() - %4 = arith.cmpi eq, %i, %ub : index - // CHECK: arith.cmpi - %5 = scf.if %4 -> (memref<2xf32>) { - // CHECK: %[[IF:.*]] = scf.if - %6 = memref.alloc() : memref<2xf32> - // CHECK: %[[ALLOC3:.*]] = memref.alloc() - scf.yield %6 : memref<2xf32> - // CHECK: scf.yield %[[ALLOC3]] - } else { - scf.yield %iterBuf2 : memref<2xf32> - // CHECK: scf.yield %[[ITER_BUF2]] - } - 
scf.yield %5 : memref<2xf32> - // CHECK: scf.yield %[[IF]] - } - scf.yield %2 : memref<2xf32> - // CHECK: scf.yield %[[FOR2]] - } - // CHECK: } {deallocation.region_args_backing_memory = {{\[\[}}"", "%[[ITER_BUF]], %[[ITER_BUF2]], %[[BUF]], %[[RES]], %[[ALLOC3]]"]], - // CHECK-SAME: deallocation.result_backing_memory = ["%[[ITER_BUF]], %[[ITER_BUF2]], %[[BUF]], %[[RES]], %[[ALLOC3]]"]} - memref.copy %1, %res : memref<2xf32> to memref<2xf32> - return -} - -// ----- - -func.func @arith_select() -> (memref, memref) { - %cond = "test.make_condition"() : () -> (i1) - %a = memref.alloc() : memref - %b = memref.alloc() : memref - %c = arith.select %cond, %a, %b : memref - return %a, %c : memref, memref -} - -// CHECK-LABEL: @arith_select -// CHECK: %[[COND:.*]] = "test.make_condition" -// CHECK: %[[A:.*]] = memref.alloc -// CHECK: %[[B:.*]] = memref.alloc -// CHECK: %[[C:.*]] = arith.select %[[COND]], %[[A]], %[[B]] -// CHECK-SAME: result_backing_memory = ["%[[A]], %[[B]]"] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/buffer_reuse.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/buffer_reuse.mlir index b7eb2bfb1145da..19c0bd67274c8c 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/buffer_reuse.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/buffer_reuse.mlir @@ -206,99 +206,6 @@ func.func @double_buffer_while_both(%lb: index, %ub: index, %step: index) { // ----- -func.func @simplify_loop_dealloc() { - %a = memref.alloc() : memref - %a_owned = deallocation.own %a : memref - %b = memref.alloc() : memref - %b_owned = deallocation.own %b : memref - %c = memref.alloc() : memref - %c_owned = deallocation.own %c : memref - %w:6 = scf.while (%arg0 = %a, %arg1 = %b, %arg2 = %c, %arg3 = %a_owned, %arg4 = %b_owned, %arg5 = %c_owned) - : (memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership) -> - (memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership) { - %cond = "test.make_condition"() : () -> i1 - scf.condition(%cond) %arg2, %arg1, %arg0, %arg5, %arg4, %arg3 - : memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership - } do { - ^bb0(%arg0: memref, %arg1: memref, %arg2: memref, - %arg3: !deallocation.ownership, %arg4: !deallocation.ownership, %arg5: !deallocation.ownership): - scf.yield %arg1, %arg0, %arg2, %arg4, %arg3, %arg5 - : memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership - } - memref.dealloc %w#0 : memref - memref.dealloc %w#1 : memref - memref.dealloc %w#2 : memref - return -} - -// CHECK-LABEL: @simplify_loop_dealloc -// CHECK: memref.alloca -// CHECK: memref.alloca -// CHECK: memref.alloca -// CHECK-NOT: memref.alloc -// CHECK-NOT: memref.dealloc - -// ----- - -func.func @hoist_always_reallocated() { - %a = memref.alloc() : memref - %b = deallocation.own %a : memref - %w:3 = scf.while(%arg0 = %a, %arg1 = %b) - : (memref, !deallocation.ownership) - -> (i32, memref, !deallocation.ownership) { - %cond = "test.make_condition"() : () -> i1 - %v = "test.dummy"() : () -> i32 - memref.dealloc %arg0 : memref - %0 = memref.alloc() : memref - %1 = deallocation.own %0 : memref - scf.condition (%cond) %v, %0, %1 : i32, memref, !deallocation.ownership - } do { - ^bb0(%_: i32, %arg0: memref, %arg1 : !deallocation.ownership): - memref.dealloc %arg0 : memref - %0 = memref.alloc() : memref - %1 = deallocation.own %0 : memref - scf.yield %0, %1 : memref, 
!deallocation.ownership - } - memref.dealloc %w#1 : memref - return -} - -// CHECK-LABEL: @hoist_always_reallocated -// CHECK-NEXT: memref.alloca -// CHECK-NEXT: deallocation.null -// CHECK-NEXT: scf.while -// CHECK-NOT: memref.alloc - -// ----- - -func.func @hoist_passthrough() { - %a = memref.alloc() : memref - %b = deallocation.own %a : memref - %w:3 = scf.while(%arg0 = %a, %arg1 = %b) - : (memref, !deallocation.ownership) - -> (i32, memref, !deallocation.ownership) { - %cond = "test.make_condition"() : () -> i1 - %v = "test.dummy"() : () -> i32 - memref.dealloc %arg0 : memref - %0 = memref.alloc() : memref - %1 = deallocation.own %0 : memref - scf.condition (%cond) %v, %0, %1 : i32, memref, !deallocation.ownership - } do { - ^bb0(%_: i32, %arg0: memref, %arg1: !deallocation.ownership): - scf.yield %arg0, %arg1 : memref, !deallocation.ownership - } - memref.dealloc %w#1 : memref - return -} - -// CHECK-LABEL: @hoist_passthrough -// CHECK-NEXT: memref.alloca -// CHECK-NEXT: deallocation.null -// CHECK-NEXT: scf.while -// CHECK-NOT: memref.alloc - -// ----- - func.func @allocs_in_different_scopes_with_no_overlap() { %alloc0 = memref.alloc() : memref<4xi32> "test.use"(%alloc0) : (memref<4xi32>) -> () diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/convert_deallocation_ops_to_llvm.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/convert_deallocation_ops_to_llvm.mlir deleted file mode 100644 index 121b9f346cd67f..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/convert_deallocation_ops_to_llvm.mlir +++ /dev/null @@ -1,77 +0,0 @@ -// RUN: mlir-hlo-opt -hlo-convert-deallocation-ops-to-llvm %s \ -// RUN: -split-input-file | FileCheck %s - -// CHECK-LABEL: func.func @null() -func.func @null() -> !deallocation.ownership { - %null = deallocation.null - func.return %null : !deallocation.ownership -} -// CHECK: %[[NULL:.*]] = llvm.mlir.zero : !llvm.ptr -// CHECK: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[NULL]] -// CHECK: return %[[RET]] - -// ----- - -// CHECK-LABEL: func.func @memref_get_buffer -func.func @memref_get_buffer(%arg0: memref<2x?xf32>) -> index { - %ret = deallocation.get_buffer %arg0 : memref<2x?xf32> - return %ret : index -} - -// CHECK-NEXT: builtin.unrealized_conversion_cast -// CHECK-NEXT: llvm.extractvalue -// CHECK-NEXT: llvm.ptrtoint - -// ----- - -// CHECK-LABEL: func.func @ownership_get_buffer -func.func @ownership_get_buffer(%arg0: !deallocation.ownership) -> index { - %ret = deallocation.get_buffer %arg0 : !deallocation.ownership - return %ret : index -} - -// CHECK-NEXT: builtin.unrealized_conversion_cast -// CHECK-NEXT: llvm.ptrtoint - -// ----- - -// CHECK-LABEL: func.func @own( -func.func @own(%arg0: memref<2x?xf32>) -> !deallocation.ownership { - %ret = deallocation.own %arg0 : memref<2x?xf32> - return %ret : !deallocation.ownership -} - -// CHECK-NEXT: builtin.unrealized_conversion_cast -// CHECK-NEXT: llvm.extractvalue -// CHECK-NEXT: builtin.unrealized_conversion_cast - -// ----- - -func.func @freeAlloc(%arg0: !deallocation.ownership) { - deallocation.free %arg0 - return -} - -// CHECK: @freeAlloc -// CHECK-NEXT: builtin.unrealized_conversion_cast -// CHECK-NEXT: llvm.call @free - -// ----- - -func.func @retain_multiple(%arg0: memref, %arg1: memref, - %arg2: !deallocation.ownership, %arg3: !deallocation.ownership) - -> (!deallocation.ownership, !deallocation.ownership) { - %ret:2 = deallocation.retain(%arg0, %arg1) of (%arg2, %arg3) - : (memref, memref, !deallocation.ownership, 
!deallocation.ownership) - -> (!deallocation.ownership, !deallocation.ownership) - return %ret#0, %ret#1 : !deallocation.ownership, !deallocation.ownership -} - -// CHECK-LABEL: @retain_multiple -// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref -// CHECK-SAME: %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: -// CHECK: memref.alloca_scope -// CHECK: llvm.alloca -// CHECK: llvm.alloca -// CHECK: call @retainBuffers -// CHECK: memref.alloca_scope.return diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate.mlir deleted file mode 100644 index 83a03c9853d062..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate.mlir +++ /dev/null @@ -1,928 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --allow-unregistered-dialect \ -// RUN: --hlo-deallocate | \ -// RUN: FileCheck %s - -// RUN: mlir-hlo-opt %s --split-input-file --allow-unregistered-dialect \ -// RUN: --hlo-deallocate --hlo-deallocation-simplification | \ -// RUN: FileCheck %s --check-prefix=CHECK-SIMPLE - -func.func @loop_nested_alloc( - %lb: index, %ub: index, %step: index, - %buf: memref<2xf32>, %res: memref<2xf32>) { - %0 = memref.alloc() : memref<2xf32> - %1 = scf.for %i = %lb to %ub step %step - iter_args(%iterBuf = %buf) -> memref<2xf32> { - %2 = scf.for %i2 = %lb to %ub step %step - iter_args(%iterBuf2 = %iterBuf) -> memref<2xf32> { - %3 = memref.alloc() : memref<2xf32> - %4 = arith.cmpi eq, %i, %ub : index - %5 = scf.if %4 -> (memref<2xf32>) { - %6 = memref.alloc() : memref<2xf32> - scf.yield %6 : memref<2xf32> - } else { - scf.yield %iterBuf2 : memref<2xf32> - } - scf.yield %5 : memref<2xf32> - } - scf.yield %2 : memref<2xf32> - } - memref.copy %1, %res : memref<2xf32> to memref<2xf32> - return -} - -// CHECK-LABEL: func @loop_nested_alloc -// CHECK-SAME: %[[ARG3:[a-z0-9]*]]: memref<2xf32>, %[[OUT:.*]]: memref<2xf32>) -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2xf32> -// CHECK: %[[ALLOC_OWNED:.*]] = deallocation.own %[[ALLOC]] -// CHECK: %[[ARG3_UNOWNED:.*]] = deallocation.null -// CHECK: %[[FOR1:.*]]:2 = scf.for {{.*}}iter_args(%[[A:.*]] = %[[ARG3]], %[[A_OWNERSHIP:.*]] = %[[ARG3_UNOWNED]]) -// CHECK: %[[FOR2:.*]]:2 = scf.for {{.*}} iter_args(%[[B:.*]] = %[[A]], %[[B_OWNERSHIP:.*]] = %[[A_OWNERSHIP]]) -// CHECK: %[[ALLOC2:.*]] = memref.alloc() : memref<2xf32> -// CHECK: %[[ALLOC2_OWNED:.*]] = deallocation.own %[[ALLOC2]] -// CHECK: %[[IF:.*]]:2 = scf.if -// CHECK: %[[ALLOC3:.*]] = memref.alloc() : memref<2xf32> -// CHECK: %[[ALLOC3_OWNED:.*]] = deallocation.own %[[ALLOC3]] -// CHECK: scf.yield %[[ALLOC3]], %[[ALLOC3_OWNED]] -// CHECK: } else { -// CHECK: %[[NULL:.*]] = deallocation.retain(%[[B]]) of() -// CHECK: scf.yield %[[B]], %[[NULL]] -// CHECK: } -// CHECK: %[[RETAINED_IF:.*]] = deallocation.retain(%[[IF]]#0) of(%[[B_OWNERSHIP]], %[[IF]]#1) -// CHECK: deallocation.retain() of(%[[ALLOC2_OWNED]]) -// CHECK: scf.yield %[[IF]]#0, %[[RETAINED_IF]] -// CHECK: } -// CHECK: scf.yield %[[FOR2]]#0, %[[FOR2]]#1 -// CHECK: } -// CHECK: memref.copy %[[FOR1]]#0, %[[OUT]] -// CHECK: deallocation.retain() of(%[[ALLOC_OWNED]]) -// CHECK: deallocation.retain() of(%[[FOR1]]#1) -// CHECK: return - -// ----- - -func.func @nested_if() -> (memref<2xf32>, memref<2xf32>) { - %alloc_0 = memref.alloc() : memref<2xf32> - %alloc_1 = memref.alloc() : memref<2xf32> - %a = "test.condition"() : () -> i1 - %0 = scf.if %a -> (memref<2xf32>) { - %2 = memref.alloc() : memref<2xf32> - scf.yield %2 : memref<2xf32> - } else 
{ - %b = "test.condition"() : () -> i1 - %3 = scf.if %b -> (memref<2xf32>) { - scf.yield %alloc_0 : memref<2xf32> - } else { - scf.yield %alloc_1 : memref<2xf32> - } - scf.yield %3 : memref<2xf32> - } - return %alloc_0, %0 : memref<2xf32>, memref<2xf32> -} - -// CHECK-LABEL: func @nested_if -// CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK: %[[ALLOC0_OWNED:.*]] = deallocation.own %[[ALLOC0]] -// CHECK: %[[ALLOC1:.*]] = memref.alloc() -// CHECK: %[[ALLOC1_OWNED:.*]] = deallocation.own %[[ALLOC1]] -// CHECK: %[[IF1:.*]]:2 = scf.if -// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: %[[ALLOC2_OWNED:.*]] = deallocation.own %[[ALLOC2]] -// CHECK-NEXT: scf.yield %[[ALLOC2]], %[[ALLOC2_OWNED]] -// CHECK-NEXT: } else { -// CHECK: %[[IF2:.*]]:2 = scf.if -// CHECK-NEXT: %[[NULL:.*]] = deallocation.retain(%[[ALLOC0]]) of() -// CHECK-NEXT: scf.yield %[[ALLOC0]], %[[NULL]] -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[NULL:.*]] = deallocation.retain(%[[ALLOC1]]) of() -// CHECK-NEXT: scf.yield %[[ALLOC1]], %[[NULL]] -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[IF2]]#0, %[[IF2]]#1 -// CHECK-NEXT: } -// CHECK-NEXT: %[[RETAINED:.*]]:2 = deallocation.retain(%[[ALLOC0]], %[[IF1]]#0) of(%[[ALLOC0_OWNED]], %[[ALLOC1_OWNED]], %[[IF1]]#1) -// CHECK-NEXT: return %[[ALLOC0]], %[[IF1]]#0, %[[RETAINED]]#0, %[[RETAINED]]#1 : memref<2xf32>, memref<2xf32>, !deallocation.ownership, !deallocation.ownership - -// ----- - -func.func @while(%arg0: index) -> (memref, memref, memref) { - %a = memref.alloc(%arg0) : memref - %w:3 = scf.while (%arg1 = %a, %arg2 = %a, %arg3 = %a) : (memref, memref, memref) - -> (memref, memref, memref) { - %0 = "test.make_condition"() : () -> i1 - scf.condition(%0) %arg1, %arg2, %arg3 : memref, memref, memref - } do { - ^bb0(%arg1: memref, %arg2: memref, %arg3: memref): - %b = memref.alloc(%arg0) : memref - %q = memref.alloc(%arg0) : memref - scf.yield %q, %b, %arg2: memref, memref, memref - } - return %w#0, %w#1, %w#2 : memref, memref, memref -} - -// CHECK-LABEL: func @while( -// CHECK-SAME: %[[ARG0:.*]]: -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%arg0) : memref -// CHECK-NEXT: %[[ALLOC_OWNED:.*]] = deallocation.own %[[ALLOC]] -// CHECK-NEXT: %[[NULL:.*]] = deallocation.null -// CHECK-NEXT: %[[WHILE:.*]]:6 = scf.while (%[[A:[a-z0-9]*]] = %[[ALLOC]], %[[B:[a-z0-9]*]] = %[[ALLOC]], %[[C:[a-z0-9]*]] = %[[ALLOC]], -// CHECK-SAME: %[[A_OWNERSHIP:.*]] = %[[ALLOC_OWNED]], %[[B_OWNERSHIP:.*]] = %[[NULL]], %[[C_OWNERSHIP:.*]] = %[[NULL]]) -// CHECK: scf.condition{{.*}} %[[A]], %[[B]], %[[C]], %[[A_OWNERSHIP]], %[[B_OWNERSHIP]], %[[C_OWNERSHIP]] -// CHECK: } do { -// CHECK: %[[ALLOC1:.*]] = memref.alloc(%[[ARG0]]) -// CHECK: %[[ALLOC1_OWNED:.*]] = deallocation.own %[[ALLOC1]] -// CHECK: %[[ALLOC2:.*]] = memref.alloc(%[[ARG0]]) -// CHECK: %[[ALLOC2_OWNED:.*]] = deallocation.own %[[ALLOC2]] -// CHECK: deallocation.retain() of(%[[A_OWNERSHIP]]) -// CHECK: deallocation.retain() of(%[[C_OWNERSHIP]]) -// CHECK: scf.yield %[[ALLOC2]], %[[ALLOC1]], %[[B]], %[[ALLOC2_OWNED]], %[[ALLOC1_OWNED]], %[[B_OWNERSHIP]] -// CHECK: } -// CHECK: %[[RESULTS_RETAINED:.*]] = deallocation.retain(%[[WHILE]]#0, %[[WHILE]]#1, %[[WHILE]]#2) -// CHECK-SAME: of(%[[WHILE]]#3, %[[WHILE]]#4, %[[WHILE]]#5) -// CHECK: return %[[WHILE]]#0, %[[WHILE]]#1, %[[WHILE]]#2 - -// ----- - -func.func @if_without_else() { - %cond = "test.make_condition"() : () -> i1 - scf.if %cond { - %x = memref.alloc() : memref<2xf32> - "test.use"(%x) : (memref<2xf32>) -> () - scf.yield - } - return -} - -// CHECK-LABEL: @if_without_else -// 
CHECK: scf.if -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: %[[ALLOC_OWNED:.*]] = deallocation.own %[[ALLOC]] -// CHECK-NEXT: test.use -// CHECK-NEXT: deallocation.retain() of(%[[ALLOC_OWNED]]) - -// CHECK-SIMPLE-LABEL: @if_without_else -// CHECK-SIMPLE: scf.if -// CHECK-SIMPLE-NEXT: memref.alloc -// CHECK-SIMPLE-NEXT: test.use -// CHECK-SIMPLE-NEXT: memref.dealloc - -// ----- - -func.func @yield_same_alloc_twice() { - %alloc = memref.alloc() : memref - scf.while (%a = %alloc, %b = %alloc) : (memref, memref) -> () { - %cond = "test.make_condition"() : () -> i1 - scf.condition(%cond) - } do { - ^bb0(): - scf.yield %alloc, %alloc : memref, memref - } - return -} - -// CHECK-LABEL: @yield_same_alloc_twice -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: %[[ALLOC_OWNED:.*]] = deallocation.own %[[ALLOC]] -// CHECK-NEXT: %[[NULL:.*]] = deallocation.null -// CHECK: scf.while -// CHECK-SAME: %[[ALLOC]] -// CHECK-SAME: %[[ALLOC]] -// CHECK-SAME: %[[NULL]] -// CHECK-SAME: %[[NULL]] -// CHECK: do -// CHECK-NEXT: %[[NULL:.*]] = deallocation.null -// CHECK-NEXT: %[[RETAIN:.*]]:2 = deallocation.retain(%[[ALLOC]], %[[ALLOC]]) of() -// CHECK-NEXT: scf.yield %[[ALLOC]], %[[ALLOC]], %[[RETAIN]]#1, %[[NULL]] - -// ----- - -func.func @yield_derived(%lb: index, %ub: index, %step: index) { - %0 = memref.alloc() : memref<2xi32> - %1 = scf.for %i2 = %lb to %ub step %step iter_args(%arg0 = %0) -> memref<2xi32> { - %2 = memref.alloc() : memref<2xi32> - %3 = "test.someop"(%2) : (memref<2xi32>) -> memref<1xi32> - %4 = "test.someop"(%3) : (memref<1xi32>) -> memref<2xi32> - scf.yield %4 : memref<2xi32> - } - "test.use"(%1) : (memref<2xi32>) -> () - return -} - -// CHECK-LABEL: @yield_derived -// CHECK-NEXT: memref.alloc -// CHECK-NEXT: deallocation.own -// CHECK-NEXT: scf.for -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: %[[ALLOC_OWNED:.*]] = deallocation.own -// CHECK-NEXT: "test.someop" -// CHECK-NEXT: %[[RESULT:.*]] = "test.someop" -// CHECK-NEXT: %[[RETAINED:.*]] = deallocation.retain -// CHECK-NEXT: deallocation.retain() of -// CHECK-NEXT: scf.yield %[[RESULT]], %[[RETAINED]] -// CHECK-NEXT: } -// CHECK-NEXT: test.use -// CHECK-NEXT: retain - -// ----- - -func.func @unknown_op() { - %c0 = arith.constant 0 : index - %c512 = arith.constant 512 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c512, %c512) step (%c1, %c8) { - %alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<512x512xf32> - "test.use"(%alloc_14) : (memref<512x512xf32>) -> () - scf.yield - } - return -} - -// TODO(jreiffers): Remove the `own` op in simplification. 
-// CHECK-SIMPLE-LABEL: @unknown_op -// CHECK-SIMPLE: scf.parallel -// CHECK-SIMPLE-NEXT: memref.alloc() -// CHECK-SIMPLE: test.use -// CHECK-SIMPLE-NEXT: memref.dealloc - -// ----- - -func.func @unconditional_realloc(%init: index, %new: index) { - %alloc = memref.alloc(%init) : memref - "test.use"(%alloc) : (memref) -> () - %realloc = memref.realloc %alloc(%new) : memref to memref - "test.use"(%realloc) : (memref) -> () - return -} - -// CHECK-LABEL: @unconditional_realloc -// CHECK-NEXT: memref.alloc -// CHECK-NEXT: deallocation.own -// CHECK-NEXT: test.use -// CHECK-NEXT: %[[REALLOC:.*]] = memref.realloc -// CHECK-NEXT: %[[OWNED:.*]] = deallocation.own %[[REALLOC]] -// CHECK-NEXT: test.use -// CHECK-NEXT: deallocation.retain() of(%[[OWNED]]) -// CHECK-NEXT: return - -// CHECK-SIMPLE-LABEL: @unconditional_realloc -// CHECK-SIMPLE-NEXT: memref.alloc -// CHECK-SIMPLE-NEXT: test.use -// CHECK-SIMPLE-NEXT: %[[REALLOC:.*]] = memref.realloc -// CHECK-SIMPLE-NEXT: test.use -// CHECK-SIMPLE-NEXT: memref.dealloc %[[REALLOC]] - -// ----- - -func.func @realloc_in_if(%init: index) { - %alloc = memref.alloc(%init) : memref - %cond = "test.make_condition"() : () -> (i1) - %new_alloc = scf.if %cond -> memref { - %new_size = "test.make_index"() : () -> (index) - %ret = memref.realloc %alloc(%new_size) : memref to memref - scf.yield %ret : memref - } else { - scf.yield %alloc: memref - } - "test.use"(%new_alloc) : (memref) -> () - return -} - -// CHECK-LABEL: @realloc_in_if -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: %[[OWNED:.*]] = deallocation.own %[[ALLOC]] -// CHECK-NEXT: test.make_condition -// CHECK-NEXT: %[[NEW_ALLOC:.*]]:2 = scf.if -// CHECK-NEXT: test.make_index -// CHECK-NEXT: %[[REALLOC:.*]] = memref.realloc %[[ALLOC]] -// CHECK-NEXT: %[[REALLOC_OWNED:.*]] = deallocation.own %[[REALLOC]] -// CHECK-NEXT: scf.yield %[[REALLOC]], %[[REALLOC_OWNED]] -// CHECK-NEXT: } else { -// CHECK-NEXT: deallocation.retain(%[[ALLOC]]) of() -// CHECK-NEXT: scf.yield %[[ALLOC]], %[[OWNED]] -// CHECK-NEXT: } -// CHECK-NEXT: "test.use"(%[[NEW_ALLOC]]#0) -// CHECK-NEXT: deallocation.retain() of(%[[NEW_ALLOC]]#1) -// CHECK-NEXT: return - -// ----- - -func.func @realloc_in_if_strange_but_ok(%size: index, %cond: i1) { - %alloc = memref.alloc(%size) : memref - scf.if %cond -> memref { - %realloc = memref.realloc %alloc(%size) : memref to memref - %new = memref.alloc(%size) : memref - scf.yield %new : memref - } else { - "test.dummy"() : () -> () - scf.yield %alloc : memref - } - return -} - -// CHECK-LABEL: @realloc_in_if_strange_but_ok -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: %[[OWNED:.*]] = deallocation.own %[[ALLOC]] -// CHECK-NOT: deallocation.retain() of(%[[OWNED]]) - -// ----- - -func.func @realloc_in_loop(%size: index, %lb: index, %ub: index, %step: index) { - %alloc = memref.alloc(%size) : memref - scf.for %i = %lb to %ub step %step iter_args(%arg0 = %alloc) -> memref { - %cond = "test.make_condition"() : () -> i1 - %new = scf.if %cond -> memref { - %realloc = memref.realloc %arg0(%size) : memref to memref - scf.yield %realloc : memref - } else { - scf.yield %arg0 : memref - } - scf.yield %new : memref - } - return -} - -// CHECK-LABEL: @realloc_in_loop -// CHECK-NEXT: memref.alloc -// CHECK-NEXT: %[[OWNED:.*]] = deallocation.own -// CHECK-NEXT: %[[FOR:.*]]:2 = scf.for -// CHECK: %[[IF:.*]]:2 = scf.if -// CHECK: scf.yield %[[IF]]#0, %[[IF]]#1 -// CHECK-NEXT: } -// CHECK-NEXT: deallocation.retain() of(%[[FOR]]#1) -// CHECK-NEXT: return - -// ----- - -func.func @alloca() { - 
%alloca = memref.alloca() : memref<2xf32> - %passthrough = "test.use"(%alloca) : (memref<2xf32>) -> (memref<2xf32>) - "test.use"(%passthrough) : (memref<2xf32>) -> () - return -} - -// CHECK-LABEL: @alloca() -// CHECK-NEXT: memref.alloca -// CHECK-NEXT: test.use -// CHECK-NEXT: test.use -// CHECK-NEXT: return - -// ----- - -func.func @dealloc() { - %alloc = memref.alloc() : memref - "test.use"(%alloc) : (memref) -> () - memref.dealloc %alloc: memref - return -} - -// CHECK-LABEL: @dealloc -// CHECK-SIMPLE-LABEL: @dealloc -// CHECK-SIMPLE-NEXT: memref.alloc -// CHECK-SIMPLE-NEXT: test.use -// CHECK-SIMPLE-NEXT: memref.dealloc -// CHECK-SIMPLE-NEXT: return - -// ----- - -func.func @dealloc_in_loop(%lb: index, %ub: index, %step: index) { - scf.for %i = %lb to %ub step %step { - %alloc = memref.alloc() : memref - "test.use"(%alloc) : (memref) -> () - memref.dealloc %alloc: memref - } - return -} - -// CHECK-LABEL: @dealloc_in_loop -// CHECK-SIMPLE-LABEL: @dealloc_in_loop -// CHECK-SIMPLE-NEXT: scf.for -// CHECK-SIMPLE-NEXT: memref.alloc -// CHECK-SIMPLE-NEXT: test.use -// CHECK-SIMPLE-NEXT: memref.dealloc -// CHECK-SIMPLE-NEXT: } -// CHECK-SIMPLE-NEXT: return - -// ----- - -func.func @dealloc_around_loop(%lb: index, %ub: index, %step: index) { - %alloc = memref.alloc() : memref - scf.for %i = %lb to %ub step %step { - "test.use"(%alloc) : (memref) -> () - } - memref.dealloc %alloc: memref - return -} - -// CHECK-LABEL: @dealloc_around_loop -// CHECK-SIMPLE-LABEL: @dealloc_around_loop -// CHECK-SIMPLE-NEXT: memref.alloc -// CHECK-SIMPLE-NEXT: scf.for -// CHECK-SIMPLE-NEXT: test.use -// CHECK-SIMPLE-NEXT: } -// CHECK-SIMPLE-NEXT: memref.dealloc -// CHECK-SIMPLE-NEXT: return - -// ----- - -func.func @memory_effect_no_free_or_alloc() { - %alloc = memref.alloc() : memref - %expand_shape = memref.expand_shape %alloc [] : memref into memref<1x1xi32> - "test.use"(%expand_shape) : (memref<1x1xi32>) -> () - return -} - -// CHECK-LABEL: @memory_effect_no_free_or_alloc -// CHECK-NEXT: memref.alloc -// CHECK-NEXT: deallocation.own -// CHECK-NEXT: memref.expand_shape -// CHECK-NEXT: test.use -// CHECK-NEXT: deallocation.retain - -// ----- - -func.func @id(%arg0: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - return %arg0 : memref<1x2x3xf32> -} - -func.func @user(%arg0: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = call @id(%arg0) : (memref<1x2x3xf32>) -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @id(%[[ARG0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG0]]) of() -// CHECK: return %[[ARG0]], %[[RETAIN]] - -// CHECK: @user(%[[ARG0_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @id(%[[ARG0_0]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// ----- - -func.func @id_select(%arg0: i1, %arg1: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = arith.select %arg0, %arg1, %arg1 : memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -func.func @user(%arg0: i1, %arg1: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = call @id_select(%arg0, %arg1) : (i1, memref<1x2x3xf32>) -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @id_select(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: memref<1x2x3xf32>) -// CHECK: %[[SELECT:.*]] = arith.select %[[ARG0]], %[[ARG1]], %[[ARG1]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[SELECT]]) of() -// CHECK: return %[[SELECT]], %[[RETAIN]] - -// CHECK: @user(%[[ARG0_0:.*]]: i1, %[[ARG1_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @id_select(%[[ARG0_0]], %[[ARG1_0]]) -// CHECK: 
return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// ----- - -func.func @ite(%arg0: i1, %arg1: memref<1x2x3xf32>, %arg2: memref<1x2x3xf32>) - -> memref<1x2x3xf32> { - %0 = scf.if %arg0 -> (memref<1x2x3xf32>) { - scf.yield %arg1 : memref<1x2x3xf32> - } else { - scf.yield %arg2 : memref<1x2x3xf32> - } - return %0 : memref<1x2x3xf32> -} - -func.func @user(%arg0: i1, %arg1: memref<1x2x3xf32>, %arg2: memref<1x2x3xf32>) - -> memref<1x2x3xf32> { - %0 = call @ite(%arg0, %arg1, %arg2) - : (i1, memref<1x2x3xf32>, memref<1x2x3xf32>) -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @ite(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: memref<1x2x3xf32>, %[[ARG2:.*]]: memref<1x2x3xf32>) -// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG0]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: scf.yield %[[ARG1]], %[[RETAIN]] -// CHECK: else -// CHECK: %[[RETAIN_0:.*]] = deallocation.retain(%[[ARG2]]) of() -// CHECK: scf.yield %[[ARG2]], %[[RETAIN_0]] -// CHECK: return %[[IF]]#0, %[[IF]]#1 - -// CHECK: @user(%[[ARG0_0:.*]]: i1, %[[ARG1_0:.*]]: memref<1x2x3xf32>, %[[ARG2_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @ite(%[[ARG0_0]], %[[ARG1_0]], %[[ARG2_0]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// ----- - -func.func @ite_select(%arg0: i1, %arg1: memref<1x2x3xf32>, - %arg2: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = arith.select %arg0, %arg1, %arg2 : memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -func.func @user(%arg0: i1, %arg1: memref<1x2x3xf32>, %arg2: memref<1x2x3xf32>) - -> memref<1x2x3xf32> { - %0 = call @ite_select(%arg0, %arg1, %arg2) - : (i1, memref<1x2x3xf32>, memref<1x2x3xf32>) -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @ite_select(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: memref<1x2x3xf32>, %[[ARG2:.*]]: memref<1x2x3xf32>) -// CHECK: %[[SELECT:.*]] = arith.select %[[ARG0]], %[[ARG1]], %[[ARG2]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[SELECT]]) of() -// CHECK: return %[[SELECT]], %[[RETAIN]] - -// CHECK: @user(%[[ARG0_0:.*]]: i1, %[[ARG1_0:.*]]: memref<1x2x3xf32>, %[[ARG2_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @ite_select(%[[ARG0_0]], %[[ARG1_0]], %[[ARG2_0]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// ----- - -func.func @may_reuse(%arg0: i1, %arg1: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = scf.if %arg0 -> (memref<1x2x3xf32>) { - scf.yield %arg1 : memref<1x2x3xf32> - } else { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3xf32> - scf.yield %alloc : memref<1x2x3xf32> - } - return %0 : memref<1x2x3xf32> -} - -func.func @user(%arg0: i1, %arg1: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = call @may_reuse(%arg0, %arg1) : (i1, memref<1x2x3xf32>) - -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @may_reuse(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: memref<1x2x3xf32>) -// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG0]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: scf.yield %[[ARG1]], %[[RETAIN]] -// CHECK: else -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: scf.yield %[[ALLOC]], %[[OWN]] -// CHECK: return %[[IF]]#0, %[[IF]]#1 - -// CHECK: @user(%[[ARG0_0:.*]]: i1, %[[ARG1_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @may_reuse(%[[ARG0_0]], %[[ARG1_0]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// ----- - -func.func @insert(%arg0: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 
: index - %cst = arith.constant 7.000000e+00 : f32 - memref.store %cst, %arg0[%c0, %c1, %c1] : memref<1x2x3xf32> - return %arg0 : memref<1x2x3xf32> -} - -func.func @user(%arg0: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = call @insert(%arg0) : (memref<1x2x3xf32>) -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @insert(%[[ARG0:.*]]: memref<1x2x3xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[CST:.*]] = arith.constant 7.0 -// CHECK: memref.store %[[CST]], %[[ARG0]][%[[C0]], %[[C1]], %[[C1]]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG0]]) of() -// CHECK: return %[[ARG0]], %[[RETAIN]] - -// CHECK: @user(%[[ARG0_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @insert(%[[ARG0_0]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// ----- - -func.func @ite_no_yielded_buffers(%pred: i1) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 7.000000e+00 : f32 - %outer_alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3xf32> - scf.if %pred { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3xf32> - memref.store %cst, %alloc[%c0, %c1, %c1] : memref<1x2x3xf32> - scf.yield - } else { - memref.store %cst, %outer_alloc[%c0, %c1, %c1] : memref<1x2x3xf32> - scf.yield - } - return -} - -func.func @user(%arg0: i1) { - call @ite_no_yielded_buffers(%arg0) : (i1) -> () - return -} - -// CHECK: @ite_no_yielded_buffers(%[[ARG0:.*]]: i1) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[CST:.*]] = arith.constant 7.0 -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: scf.if %[[ARG0]] -// CHECK: %[[ALLOC_0:.*]] = memref.alloc -// CHECK: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] -// CHECK: memref.store %[[CST]], %[[ALLOC_0]][%[[C0]], %[[C1]], %[[C1]]] -// CHECK: deallocation.retain() of(%[[OWN_0]]) -// CHECK: else -// CHECK: memref.store %[[CST]], %[[ALLOC]][%[[C0]], %[[C1]], %[[C1]]] -// CHECK: deallocation.retain() of(%[[OWN]]) -// CHECK: return - -// CHECK: @user(%[[ARG0_0:.*]]: i1) -// CHECK: call @ite_no_yielded_buffers(%[[ARG0_0]]) -// CHECK: return - -// ----- - -func.func @may_reuse(%pred: i1, %arg: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %0 = scf.if %pred -> (memref<1x2x3xf32>) { - scf.yield %arg : memref<1x2x3xf32> - } else { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3xf32> - scf.yield %alloc : memref<1x2x3xf32> - } - return %0 : memref<1x2x3xf32> -} - -func.func @user(%pred: i1, %arg: memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %may_escape_indirectly = memref.alloc() {alignment = 64 : i64} - : memref<1x2x3xf32> - %0 = call @may_reuse(%pred, %may_escape_indirectly) : (i1, memref<1x2x3xf32>) - -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @may_reuse(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: memref<1x2x3xf32>) -// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG0]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: scf.yield %[[ARG1]], %[[RETAIN]] -// CHECK: else -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: scf.yield %[[ALLOC]], %[[OWN]] -// CHECK: return %[[IF]]#0, %[[IF]]#1 - -// CHECK: @user(%[[ARG0_0:.*]]: i1, %[[ARG1_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[ALLOC_0:.*]] = memref.alloc -// CHECK: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] -// CHECK: %[[OWNERSHIP:.*]]:2 = call @may_reuse(%[[ARG0_0]], 
%[[ALLOC_0]]) -// CHECK: %[[RETAIN_0:.*]] = deallocation.retain(%[[OWNERSHIP]]#0) of(%[[OWN_0]], %[[OWNERSHIP]]#1) -// CHECK: return %[[OWNERSHIP]]#0, %[[RETAIN_0]] - -// ----- - -func.func @insert_may_reuse_and_forward(%arg0: i1, %arg1: memref<1x2x3xf32>) - -> (memref<1x2x3xf32>, memref<1x2x3xf32>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 7.000000e+00 : f32 - %0 = scf.if %arg0 -> (memref<1x2x3xf32>) { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3xf32> - memref.copy %arg1, %alloc : memref<1x2x3xf32> to memref<1x2x3xf32> - memref.store %cst, %alloc[%c0, %c1, %c1] : memref<1x2x3xf32> - scf.yield %alloc : memref<1x2x3xf32> - } else { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x3xf32> - memref.store %cst, %alloc[%c0, %c1, %c1] : memref<1x2x3xf32> - scf.yield %alloc : memref<1x2x3xf32> - } - return %0, %arg1 : memref<1x2x3xf32>, memref<1x2x3xf32> -} - -func.func @user(%arg0: i1, %arg1: memref<1x2x3xf32>) - -> (memref<1x2x3xf32>, memref<1x2x3xf32>) { - %5:2 = call @insert_may_reuse_and_forward(%arg0, %arg1) - : (i1, memref<1x2x3xf32>) -> (memref<1x2x3xf32>, memref<1x2x3xf32>) - return %5#0, %5#1 : memref<1x2x3xf32>, memref<1x2x3xf32> -} - -// CHECK: @insert_may_reuse_and_forward(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: memref<1x2x3xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[CST:.*]] = arith.constant 7.0 -// CHECK: %[[IF:.*]]:2 = scf.if %[[ARG0]] -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: memref.copy %[[ARG1]], %[[ALLOC]] -// CHECK: memref.store %[[CST]], %[[ALLOC]][%[[C0]], %[[C1]], %[[C1]]] -// CHECK: scf.yield %[[ALLOC]], %[[OWN]] -// CHECK: else -// CHECK: %[[ALLOC_0:.*]] = memref.alloc -// CHECK: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] -// CHECK: memref.store %[[CST]], %[[ALLOC_0]][%[[C0]], %[[C1]], %[[C1]]] -// CHECK: scf.yield %[[ALLOC_0]], %[[OWN_0]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: return %[[IF]]#0, %[[ARG1]], %[[IF]]#1, %[[RETAIN]] - -// CHECK: @user(%[[ARG0_0:.*]]: i1, %[[ARG1_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[RESULT:.*]]:4 = call @insert_may_reuse_and_forward(%[[ARG0_0]], %[[ARG1_0]]) -// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#3 - -// ----- - -func.func @f(%a : memref<1x2x3xf32>, %b : memref<1x2x3xf32>, - %c : memref<1x2x3xf32>, %d : memref<1x2x3xf32>, %e : memref<1x2x3xf32>) - -> memref<1x2x3xf32> { - %0 = func.call @f(%a, %a, %b, %c, %d) : (memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) - -> memref<1x2x3xf32> - func.return %0 : memref<1x2x3xf32> -} - -func.func @user() -> memref<1x2x3xf32> { - %a = memref.alloc() : memref<1x2x3xf32> - %b = memref.alloc() : memref<1x2x3xf32> - %c = memref.alloc() : memref<1x2x3xf32> - %d = memref.alloc() : memref<1x2x3xf32> - %e = memref.alloc() : memref<1x2x3xf32> - %0 = func.call @f(%a, %b, %c, %d, %e) : (memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) - -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @f(%[[ARG0:.*]]: memref<1x2x3xf32>, %[[ARG1:.*]]: memref<1x2x3xf32>, %[[ARG2:.*]]: memref<1x2x3xf32>, %[[ARG3:.*]]: memref<1x2x3xf32>, %[[ARG4:.*]]: memref<1x2x3xf32>) -// CHECK: %[[OWNERSHIP:.*]]:2 = call @f(%[[ARG0]], %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 - -// CHECK: @user() -// 
CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc -// CHECK-NEXT: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK-NEXT: %[[ALLOC_0:.*]] = memref.alloc -// CHECK-NEXT: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] -// CHECK-NEXT: %[[ALLOC_1:.*]] = memref.alloc -// CHECK-NEXT: %[[OWN_1:.*]] = deallocation.own %[[ALLOC_1]] -// CHECK-NEXT: %[[ALLOC_2:.*]] = memref.alloc -// CHECK-NEXT: %[[OWN_2:.*]] = deallocation.own %[[ALLOC_2]] -// CHECK-NEXT: %[[ALLOC_3:.*]] = memref.alloc -// CHECK-NEXT: %[[OWN_3:.*]] = deallocation.own %[[ALLOC_3]] -// CHECK-NEXT: %[[OWNERSHIP_0:.*]]:2 = call @f(%[[ALLOC]], %[[ALLOC_0]], %[[ALLOC_1]], %[[ALLOC_2]], %[[ALLOC_3]]) -// CHECK-NEXT: %[[RETAIN:.*]] = deallocation.retain(%[[OWNERSHIP_0]]#0) of(%[[OWN]], %[[OWN_0]], %[[OWN_1]], %[[OWN_2]], %[[OWNERSHIP_0]]#1) -// CHECK-NEXT: deallocation.retain() of(%[[OWN_3]]) -// CHECK-NEXT: return %[[OWNERSHIP_0]]#0, %[[RETAIN]] - -// ----- - -func.func @terminating_f(%i : i32, %a : memref<1x2x3xf32>, - %b : memref<1x2x3xf32>, %c : memref<1x2x3xf32>, %d : memref<1x2x3xf32>, - %e : memref<1x2x3xf32>) -> memref<1x2x3xf32> { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 - %pred = arith.cmpi slt, %i, %c0 : i32 - %0 = scf.if %pred -> memref<1x2x3xf32> { - scf.yield %a : memref<1x2x3xf32> - } else { - %i_ = arith.subi %i, %c1 : i32 - %1 = func.call @terminating_f(%i_, %a, %a, %b, %c, %d) - : (i32, memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32>) -> memref<1x2x3xf32> - scf.yield %1 : memref<1x2x3xf32> - } - func.return %0 : memref<1x2x3xf32> -} - -func.func @user() -> memref<1x2x3xf32> { - %c0 = arith.constant 0 : i32 - %a = memref.alloc() : memref<1x2x3xf32> - %b = memref.alloc() : memref<1x2x3xf32> - %c = memref.alloc() : memref<1x2x3xf32> - %d = memref.alloc() : memref<1x2x3xf32> - %e = memref.alloc() : memref<1x2x3xf32> - %0 = func.call @terminating_f(%c0, %a, %b, %c, %d, %e) - : (i32, memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32>) -> memref<1x2x3xf32> - return %0 : memref<1x2x3xf32> -} - -// CHECK: @terminating_f(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: memref<1x2x3xf32>, %[[ARG2:.*]]: memref<1x2x3xf32>, %[[ARG3:.*]]: memref<1x2x3xf32>, %[[ARG4:.*]]: memref<1x2x3xf32>, %[[ARG5:.*]]: memref<1x2x3xf32>) -// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32 -// CHECK: %[[CMPI:.*]] = arith.cmpi slt, %[[ARG0]], %[[C0_I32]] -// CHECK: %[[IF:.*]]:2 = scf.if %[[CMPI]] -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: scf.yield %[[ARG1]], %[[RETAIN]] -// CHECK: else -// CHECK: %[[SUBI:.*]] = arith.subi %[[ARG0]], %[[C1_I32]] -// CHECK: %[[OWNERSHIP:.*]]:2 = func.call @terminating_f(%[[SUBI]], %[[ARG1]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -// CHECK: scf.yield %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1 -// CHECK: return %[[IF]]#0, %[[IF]]#1 - -// CHECK: @user() -// CHECK-DAG: %[[C0_I32_0:.*]] = arith.constant 0 : i32 -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: %[[ALLOC_0:.*]] = memref.alloc -// CHECK: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] -// CHECK: %[[ALLOC_1:.*]] = memref.alloc -// CHECK: %[[OWN_1:.*]] = deallocation.own %[[ALLOC_1]] -// CHECK: %[[ALLOC_2:.*]] = memref.alloc -// CHECK: %[[OWN_2:.*]] = deallocation.own %[[ALLOC_2]] -// CHECK: %[[ALLOC_3:.*]] = memref.alloc -// CHECK: %[[OWN_3:.*]] = deallocation.own %[[ALLOC_3]] -// CHECK: %[[OWNERSHIP_0:.*]]:2 = call 
@terminating_f(%[[C0_I32_0]], %[[ALLOC]], %[[ALLOC_0]], %[[ALLOC_1]], %[[ALLOC_2]], %[[ALLOC_3]]) -// CHECK: %[[RETAIN_0:.*]] = deallocation.retain(%[[OWNERSHIP_0]]#0) of(%[[OWN]], %[[OWN_0]], %[[OWN_1]], %[[OWN_2]], %[[OWNERSHIP_0]]#1) -// CHECK: deallocation.retain() of(%[[OWN_3]]) -// CHECK: return %[[OWNERSHIP_0]]#0, %[[RETAIN_0]] - -// ----- - -func.func @id(%arg0 : memref<1x2x3xf32>, %arg1 : memref<1x2x3xf32>) - -> memref<1x2x3xf32> { - func.return %arg1 : memref<1x2x3xf32> -} - -func.func @user() -> (memref<1x2x3xf32>, memref<1x2x3xf32>) { - %alloc0 = memref.alloc() : memref<1x2x3xf32> - %alloc1 = memref.alloc() : memref<1x2x3xf32> - %alloc2 = memref.alloc() : memref<1x2x3xf32> - %0 = func.call @id(%alloc0, %alloc2) : (memref<1x2x3xf32>, memref<1x2x3xf32>) - -> memref<1x2x3xf32> - %1 = func.call @id(%alloc1, %alloc2) : (memref<1x2x3xf32>, memref<1x2x3xf32>) - -> memref<1x2x3xf32> - func.return %0, %1 : memref<1x2x3xf32>, memref<1x2x3xf32> -} - -// CHECK: @id(%[[ARG0:.*]]: memref<1x2x3xf32>, %[[ARG1:.*]]: memref<1x2x3xf32>) -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: return %[[ARG1]], %[[RETAIN]] - -// CHECK: @user() -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: %[[ALLOC_0:.*]] = memref.alloc -// CHECK: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] : memref<1x2x3xf32> -// CHECK: %[[ALLOC_1:.*]] = memref.alloc -// CHECK: %[[OWN_1:.*]] = deallocation.own %[[ALLOC_1]] : memref<1x2x3xf32> -// CHECK: %[[OWNERSHIP:.*]]:2 = call @id(%[[ALLOC]], %[[ALLOC_1]]) -// CHECK: %[[OWNERSHIP_0:.*]]:2 = call @id(%[[ALLOC_0]], %[[ALLOC_1]]) -// CHECK: %[[RETAIN_0:.*]]:2 = deallocation.retain(%[[OWNERSHIP]]#0, %[[OWNERSHIP_0]]#0) of(%[[OWN_1]], %[[OWNERSHIP]]#1, %[[OWNERSHIP_0]]#1) -// CHECK: deallocation.retain() of(%[[OWN]]) -// CHECK: deallocation.retain() of(%[[OWN_0]]) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP_0]]#0, %[[RETAIN_0]]#0, %[[RETAIN_0]]#1 - -// ----- - -func.func @forward(%arg0: memref<1x2x3xf32>, %arg1: memref<1x2x3xf32>, - %arg2: memref<1x2x3xf32>) -> (memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>) { - func.return %arg0, %arg1, %arg2 : memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32> -} - -func.func @replace(%arg0: memref<1x2x3xf32>, %arg1: memref<1x2x3xf32>, - %arg2: memref<1x2x3xf32>) -> (memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>) { - %alloc0 = memref.alloc() : memref<1x2x3xf32> - %alloc1 = memref.alloc() : memref<1x2x3xf32> - %alloc2 = memref.alloc() : memref<1x2x3xf32> - func.return %alloc0, %alloc1, %alloc2 - : memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32> -} - -func.func @user() -> (memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) { - %alloc0 = memref.alloc() : memref<1x2x3xf32> - %alloc1 = memref.alloc() : memref<1x2x3xf32> - %alloc2 = memref.alloc() : memref<1x2x3xf32> - %0:3 = func.call @forward(%alloc0, %alloc1, %alloc2) - : (memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) - -> (memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) - %1:3 = func.call @replace(%alloc0, %alloc1, %alloc2) - : (memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) - -> (memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>) - func.return %0#0, %0#1, %0#2, %1#0, %1#1, %1#2 : memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32>, memref<1x2x3xf32>, - memref<1x2x3xf32>, memref<1x2x3xf32> -} - -// CHECK: @forward(%[[ARG0:.*]]: memref<1x2x3xf32>, %[[ARG1:.*]]: 
memref<1x2x3xf32>, %[[ARG2:.*]]: memref<1x2x3xf32>) -// CHECK: %[[RETAIN:.*]] = deallocation.retain(%[[ARG0]]) of() -// CHECK: %[[RETAIN_0:.*]] = deallocation.retain(%[[ARG1]]) of() -// CHECK: %[[RETAIN_1:.*]] = deallocation.retain(%[[ARG2]]) of() -// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[RETAIN]], %[[RETAIN_0]], %[[RETAIN_1]] - -// CHECK: @replace(%[[ARG0_0:.*]]: memref<1x2x3xf32>, %[[ARG1_0:.*]]: memref<1x2x3xf32>, %[[ARG2_0:.*]]: memref<1x2x3xf32>) -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: %[[OWN:.*]] = deallocation.own %[[ALLOC]] -// CHECK: %[[ALLOC_0:.*]] = memref.alloc -// CHECK: %[[OWN_0:.*]] = deallocation.own %[[ALLOC_0]] -// CHECK: %[[ALLOC_1:.*]] = memref.alloc -// CHECK: %[[OWN_1:.*]] = deallocation.own %[[ALLOC_1]] -// CHECK: return %[[ALLOC]], %[[ALLOC_0]], %[[ALLOC_1]], %[[OWN]], %[[OWN_0]], %[[OWN_1]] - -// CHECK: @user() -// CHECK: %[[ALLOC_2:.*]] = memref.alloc -// CHECK: %[[OWN_2:.*]] = deallocation.own %[[ALLOC_2]] -// CHECK: %[[ALLOC_0_0:.*]] = memref.alloc -// CHECK: %[[OWN_3:.*]] = deallocation.own %[[ALLOC_0_0]] -// CHECK: %[[ALLOC_1_0:.*]] = memref.alloc -// CHECK: %[[OWN_4:.*]] = deallocation.own %[[ALLOC_1_0]] -// CHECK: %[[OWNERSHIP:.*]]:6 = call @forward(%[[ALLOC_2]], %[[ALLOC_0_0]], %[[ALLOC_1_0]]) -// CHECK: %[[OWNERSHIP_0:.*]]:6 = call @replace(%[[ALLOC_2]], %[[ALLOC_0_0]], %[[ALLOC_1_0]]) -// CHECK: %[[RETAIN_2:.*]] = deallocation.retain(%[[OWNERSHIP]]#0) of(%[[OWN_2]], %[[OWNERSHIP]]#3) -// CHECK: %[[RETAIN_3:.*]] = deallocation.retain(%[[OWNERSHIP]]#1) of(%[[OWN_3]], %[[OWNERSHIP]]#4) -// CHECK: %[[RETAIN_4:.*]] = deallocation.retain(%[[OWNERSHIP]]#2) of(%[[OWN_4]], %[[OWNERSHIP]]#5) -// CHECK: return %[[OWNERSHIP]]#0, %[[OWNERSHIP]]#1, %[[OWNERSHIP]]#2, %[[OWNERSHIP_0]]#0, %[[OWNERSHIP_0]]#1, %[[OWNERSHIP_0]]#2, %[[RETAIN_2]], %[[RETAIN_3]], %[[RETAIN_4]], %[[OWNERSHIP_0]]#3, %[[OWNERSHIP_0]]#4, %[[OWNERSHIP_0]]#5 \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate_invalid.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate_invalid.mlir deleted file mode 100644 index 9604880ced0e86..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocate_invalid.mlir +++ /dev/null @@ -1,76 +0,0 @@ -// RUN: mlir-hlo-opt -allow-unregistered-dialect %s -split-input-file -hlo-deallocate -verify-diagnostics - -func.func @dealloc_invalid(%lb: index, %ub: index, %step: index) { - %alloc = memref.alloc() : memref - scf.for %i = %lb to %ub step %step { // expected-error {{can't implicitly capture across loop boundaries}} - memref.dealloc %alloc: memref - } - return -} - -// ----- - -func.func @realloc_no_else(%size: index, %cond: i1) { - %alloc = memref.alloc(%size) : memref - scf.if %cond { // expected-error {{cannot implicitly capture from an if without else}} - %realloc = memref.realloc %alloc(%size) : memref to memref - } - return -} - -// ----- - -func.func @realloc_not_yielded(%size: index, %cond: i1) { - %alloc = memref.alloc(%size) : memref - scf.if %cond { // expected-error {{released value not yielded on other branch}} - %realloc = memref.realloc %alloc(%size) : memref to memref - } else { - "test.dummy"() : () -> () - } - return -} - -// ----- - -func.func @realloc_arg(%arg: memref, %size: index) { - %realloc = memref.realloc %arg(%size) : memref to memref // expected-error {{unable to find ownership indicator for operand}} - return -} - -// ----- - -func.func @realloc_twice(%size: index) { // expected-error {{invalid realloc of memref}} - 
%alloc = memref.alloc(%size) : memref - %realloc0 = memref.realloc %alloc(%size) : memref to memref - %realloc1 = memref.realloc %alloc(%size) : memref to memref - return -} - -// ----- - -func.func @realloc_twice_in_if(%size: index, %cond: i1) { // expected-error {{invalid realloc of memref}} - %alloc = memref.alloc(%size) : memref - scf.if %cond -> memref { - %realloc = memref.realloc %alloc(%size) : memref to memref - scf.yield %realloc : memref - } else { - scf.yield %alloc : memref - } - scf.if %cond -> memref { - %realloc = memref.realloc %alloc(%size) : memref to memref - scf.yield %realloc : memref - } else { - scf.yield %alloc : memref - } - return -} - -// ----- - -func.func @cross_loop_boundary(%size: index, %lb: index, %ub: index, %step: index) { - %alloc = memref.alloc(%size) : memref - scf.for %i = %lb to %ub step %step { // expected-error {{can't implicitly capture across loop boundaries}} - memref.realloc %alloc(%size) : memref to memref - } - return -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_ops.mlir deleted file mode 100644 index 0dd5e47ad9c1b1..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_ops.mlir +++ /dev/null @@ -1,39 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --verify-diagnostics | FileCheck %s - -// CHECK-LABEL: @retain -func.func @retain(%arg0: memref<2xf32>, %arg1: !deallocation.ownership, %arg2: !deallocation.ownership) - -> !deallocation.ownership { - %0 = deallocation.retain(%arg0) of(%arg1, %arg2) - : (memref<2xf32>, !deallocation.ownership, !deallocation.ownership) -> !deallocation.ownership - return %0 : !deallocation.ownership -} - -// CHECK-LABEL: @get_buffer -func.func @get_buffer(%arg0: memref<2xf32>) -> index { - %0 = deallocation.get_buffer %arg0 : memref<2xf32> - return %0 : index -} - -// CHECK-LABEL: @get_ownership_buffer -func.func @get_ownership_buffer(%arg0: !deallocation.ownership) -> index { - %0 = deallocation.get_buffer %arg0 : !deallocation.ownership - return %0 : index -} - -// CHECK-LABEL: @own -func.func @own(%arg0: memref<2xf32>) -> !deallocation.ownership { - %0 = deallocation.own %arg0 : memref<2xf32> - return %0 : !deallocation.ownership -} - -// CHECK-LABEL: @null -func.func @null() -> !deallocation.ownership { - %0 = deallocation.null - return %0 : !deallocation.ownership -} - -// CHECK-LABEL: @free -func.func @free(%arg0: !deallocation.ownership) { - deallocation.free %arg0 - return -} \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_simplification.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_simplification.mlir deleted file mode 100644 index ebefd3dce49da7..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_simplification.mlir +++ /dev/null @@ -1,181 +0,0 @@ -// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -hlo-deallocation-simplification | FileCheck %s - -func.func @retain_is_dealloc() { - %alloc = memref.alloc() : memref<2xf32> - %alloc_owned = deallocation.own %alloc : memref<2xf32> - "test.use"(%alloc) : (memref<2xf32>) -> () - deallocation.retain() of (%alloc_owned) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_is_dealloc -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() -// CHECK-NEXT: test.use -// CHECK-NEXT: memref.dealloc %[[ALLOC]] - -// ----- - -func.func @retain_of_nothing(%arg: 
memref<2xf32>) -> !deallocation.ownership { - %ret = deallocation.retain(%arg) of() : (memref<2xf32>) -> (!deallocation.ownership) - return %ret : !deallocation.ownership -} - -// CHECK-LABEL: @retain_of_nothing -// CHECK-SAME: (%[[ARG:.*]]: memref<2xf32> -// CHECK-NEXT: %[[NULL:.*]] = deallocation.null -// CHECK-NEXT: return %[[NULL]] - -// ----- - -func.func @retain_is_dealloc_for(%lb: index, %ub: index, %step: index) { - %alloc = memref.alloc() : memref<2xf32> - %alloc_owned = deallocation.own %alloc : memref<2xf32> - %for:2 = scf.for %i = %lb to %ub step %step iter_args(%arg0 = %alloc, %arg1 = %alloc_owned) - -> (memref<2xf32>, !deallocation.ownership) { - "some.use"(%arg0) : (memref<2xf32>) -> () - scf.yield %arg0, %arg1 : memref<2xf32>, !deallocation.ownership - } - deallocation.retain() of(%for#1) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_is_dealloc_for -// CHECK-NEXT: memref.alloc() -// CHECK-NEXT: deallocation.null -// CHECK-NEXT: %[[FOR:.*]]:2 = scf.for -// CHECK-NEXT: some.use -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: memref.dealloc %[[FOR]]#0 -// CHECK-NEXT: return - -// ----- - -func.func @retain_is_dealloc_reallocated(%lb: index, %ub: index, %step: index) { - %alloc = memref.alloc() : memref<2xf32> - %alloc_owned = deallocation.own %alloc : memref<2xf32> - %for:2 = scf.for %i = %lb to %ub step %step iter_args(%arg0 = %alloc, %arg1 = %alloc_owned) - -> (memref<2xf32>, !deallocation.ownership) { - "some.use"(%arg0) : (memref<2xf32>) -> () - deallocation.retain() of(%arg1) : (!deallocation.ownership) -> () - %alloc0 = memref.alloc() : memref<2xf32> - %alloc0_owned = deallocation.own %alloc0 : memref<2xf32> - scf.yield %alloc, %alloc0_owned : memref<2xf32>, !deallocation.ownership - } - deallocation.retain() of(%for#1) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_is_dealloc_reallocated -// CHECK-NEXT: memref.alloc -// CHECK-NEXT: deallocation.null -// CHECK-NEXT: %[[FOR:.*]]:2 = scf.for -// CHECK: memref.dealloc -// CHECK: } -// CHECK: memref.dealloc %[[FOR]] - -// ----- - -func.func @retain_is_not_dealloc_for( - %x: memref<2xf32>, %x_owned: !deallocation.ownership, - %lb: index, %ub: index, %step: index) { - %for:2 = scf.for %i = %lb to %ub step %step iter_args(%arg0 = %x, %arg1 = %x_owned) - -> (memref<2xf32>, !deallocation.ownership) { - "some.use"(%arg0) : (memref<2xf32>) -> () - deallocation.retain() of(%arg1) : (!deallocation.ownership) -> () - %alloc = memref.alloc() : memref<2xf32> - %alloc_owned = deallocation.own %alloc : memref<2xf32> - scf.yield %alloc, %alloc_owned : memref<2xf32>, !deallocation.ownership - } - deallocation.retain() of(%for#1) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_is_not_dealloc_for -// CHECK: %[[FOR:.*]]:2 = scf.for -// CHECK: deallocation.retain() of(%[[FOR]]#1) - -// ----- - -func.func @retain_is_dealloc_while() { - %a = memref.alloc() : memref<2xf32> - %a_owned = deallocation.own %a : memref<2xf32> - %while:2 = scf.while (%arg0 = %a, %arg1 = %a_owned) - : (memref<2xf32>, !deallocation.ownership) -> (memref<2xf32>, !deallocation.ownership) { - %0 = "test.make_condition"() : () -> i1 - scf.condition(%0) %arg0, %arg1 : memref<2xf32>, !deallocation.ownership - } do { - ^bb0(%arg0: memref<2xf32>, %arg1: !deallocation.ownership): - "some.use"(%arg0) : (memref<2xf32>) -> () - deallocation.retain() of(%arg1) : (!deallocation.ownership) -> () - %b = memref.alloc() : memref<2xf32> - %b_owned = deallocation.own %b : memref<2xf32> - scf.yield %b, 
%b_owned: memref<2xf32>, !deallocation.ownership - } - deallocation.retain() of (%while#1) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_is_dealloc_while -// CHECK: %[[WHILE:.*]]:2 = scf.while -// CHECK: memref.dealloc %[[WHILE]]#0 - -// ----- - -func.func @retain_is_dealloc_while_permute() { - %a = memref.alloc() : memref - %a_owned = deallocation.own %a : memref - %b = memref.alloc() : memref - %b_owned = deallocation.own %b : memref - %c = memref.alloc() : memref - %c_owned = deallocation.own %c : memref - %w:6 = scf.while (%arg0 = %a, %arg1 = %b, %arg2 = %c, - %arg3 = %a_owned, %arg4 = %b_owned, %arg5 = %c_owned) - : (memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership) -> - (memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership) { - %cond = "test.make_condition"() : () -> i1 - scf.condition(%cond) %arg2, %arg1, %arg0, %arg5, %arg4, %arg3 - : memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership - } do { - ^bb0(%arg0: memref, %arg1: memref, %arg2: memref, - %arg3: !deallocation.ownership, %arg4: !deallocation.ownership, %arg5: !deallocation.ownership): - scf.yield %arg1, %arg0, %arg2, %arg4, %arg3, %arg5 - : memref, memref, memref, !deallocation.ownership, !deallocation.ownership, !deallocation.ownership - } - "test.use"(%w#1) : (memref) -> () - deallocation.retain() of (%w#3) : (!deallocation.ownership) -> () - deallocation.retain() of (%w#4) : (!deallocation.ownership) -> () - deallocation.retain() of (%w#5) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_is_dealloc_while_permute -// CHECK: memref.alloc -// CHECK: memref.alloc -// CHECK: memref.alloc -// CHECK: %[[WHILE:.*]]:6 = scf.while -// CHECK: memref.dealloc %[[WHILE]] -// CHECK: memref.dealloc %[[WHILE]] -// CHECK: memref.dealloc %[[WHILE]] - -func.func @retain_of_null(%arg0: memref<4xi32>, %arg1: memref<4xi32>, - %arg2: index, %arg3: index, %arg4: index) { - %0 = deallocation.null - %2:4 = scf.for %arg5 = %arg2 to %arg3 step %arg4 - iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %0, %arg9 = %0) -> - (memref<4xi32>, memref<4xi32>, !deallocation.ownership, !deallocation.ownership) { - "test.use"(%arg6, %arg7) : (memref<4xi32>, memref<4xi32>) -> () - %3 = deallocation.retain(%arg6) of(%arg8) - : (memref<4xi32>, !deallocation.ownership) -> !deallocation.ownership - %4 = deallocation.retain(%arg7) of(%arg9) - : (memref<4xi32>, !deallocation.ownership) -> !deallocation.ownership - scf.yield %arg7, %arg6, %4, %3 - : memref<4xi32>, memref<4xi32>, !deallocation.ownership, !deallocation.ownership - } - deallocation.retain() of(%2#2) : (!deallocation.ownership) -> () - deallocation.retain() of(%2#3) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_of_null -// CHECK-NOT: deallocation.retain() diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_to_scf.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_to_scf.mlir deleted file mode 100644 index a181ef97c44788..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/deallocation_to_scf.mlir +++ /dev/null @@ -1,42 +0,0 @@ -// RUN: mlir-hlo-opt %s -hlo-deallocation-to-scf | FileCheck %s - -func.func @retain_nothing(%arg0: !deallocation.ownership) { - deallocation.retain() of (%arg0) : (!deallocation.ownership) -> () - return -} - -// CHECK-LABEL: @retain_nothing -// CHECK-SAME: %[[ARG:.*]]: -// CHECK-NEXT: 
%[[ZERO:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[BUF:.*]] = deallocation.get_buffer %[[ARG]] -// CHECK-NEXT: %[[NONNULL:.*]] = arith.cmpi ne, %[[BUF]], %[[ZERO]] -// CHECK-NEXT: scf.if %[[NONNULL]] { -// CHECK-NEXT: deallocation.free %[[ARG]] -// CHECK-NEXT: } - -// ----- - -func.func @retain_something(%arg0: memref<2xf32>, %arg1: !deallocation.ownership) - -> !deallocation.ownership { - %ret = deallocation.retain(%arg0) of (%arg1) : (memref<2xf32>, !deallocation.ownership) - -> (!deallocation.ownership) - return %ret : !deallocation.ownership -} - -// CHECK-LABEL: @retain_something -// CHECK-SAME: %[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: -// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[BUF:.*]] = deallocation.get_buffer %[[ARG1]] -// CHECK-NEXT: %[[NULL:.*]] = deallocation.null -// CHECK-NEXT: %[[RETAINED_BUF:.*]] = deallocation.get_buffer %[[ARG0]] -// CHECK-NEXT: %[[SAME:.*]] = arith.cmpi eq, %[[RETAINED_BUF]], %[[BUF]] -// CHECK-NEXT: %[[RET:.*]]:3 = scf.if %[[SAME]] -// CHECK-NEXT: scf.yield %[[NULL]], %[[ZERO]], %[[ARG1]] -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[ARG1]], %[[BUF]], %[[NULL]] -// CHECK-NEXT: } -// CHECK-NEXT: %[[DEALLOC:.*]] = arith.cmpi ne, %[[RET]]#1, %[[ZERO]] -// CHECK-NEXT: scf.if %[[DEALLOC]] { -// CHECK-NEXT: deallocation.free %[[RET]]#0 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[RET]]#2 diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/split_alloc_tensors.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/split_alloc_tensors.mlir deleted file mode 100644 index 9718a69372d996..00000000000000 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/deallocation/split_alloc_tensors.mlir +++ /dev/null @@ -1,33 +0,0 @@ -// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -hlo-split-alloc-tensors | FileCheck %s - -func.func @split() { - %alloc_tensor = bufferization.alloc_tensor() : tensor<2xf32> - %a = "some.op"(%alloc_tensor) : (tensor<2xf32>) -> (tensor<2xf32>) - %b = "some.op"(%a, %alloc_tensor) - : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>) - "some.use"(%b) : (tensor<2xf32>) -> () - %c = "some.op"(%alloc_tensor) : (tensor<2xf32>) -> (tensor<2xf32>) - return -} - -// CHECK-LABEL: @split -// CHECK-NEXT: alloc_tensor -// CHECK-NEXT: some.op -// CHECK-NEXT: alloc_tensor -// CHECK-NEXT: some.op -// CHECK-NEXT: some.use -// CHECK-NEXT: alloc_tensor -// CHECK-NEXT: some.op - -func.func @split_empty_region() { - %alloc_tensor = bufferization.alloc_tensor() : tensor<2xf32> - %cond = "test.cond"() : () -> (i1) - scf.if %cond { - %a = "some.op"(%alloc_tensor) : (tensor<2xf32>) -> (tensor<2xf32>) - } - // No else. - return -} - -// This is a regression test. Just check that this is processed successfully. 
-// CHECK-LABEL: @split_empty_region diff --git a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt index 0f20301d11c76d..a7a65587e11905 100644 --- a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -23,7 +23,6 @@ set(LIBS MLIROptLib AllMhloPasses - DeallocationDialect DeallocationPasses LmhloDialect LmhloGPUDialect diff --git a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc index db46cfdfd34f5f..742bf5707b8eaf 100644 --- a/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc +++ b/third_party/xla/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "deallocation/IR/deallocation_ops.h" #include "deallocation/transforms/passes.h" #include "lhlo/IR/lhlo_ops.h" #include "lhlo/transforms/passes.h" @@ -43,7 +42,6 @@ int main(int argc, char** argv) { registerAllExtensions(registry); mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); - registry.insert(); + registry.insert(); return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc b/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc index b338623b1135bc..bef78953286b93 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc @@ -15,8 +15,6 @@ limitations under the License. #include #include -#include "deallocation/IR/deallocation_ops.h" // IWYU pragma: keep -#include "deallocation/transforms/passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" @@ -107,8 +105,6 @@ class GenericHostToLLVMPass populateSCFToControlFlowConversionPatterns(patterns); populateComplexToLLVMConversionPatterns(typeConverter, patterns); populateMathToLibmConversionPatterns(patterns); - deallocation::populateDeallocationToLLVMConversionPatterns(typeConverter, - patterns); // Vector patterns. vector::populateVectorMaskMaterializationPatterns(patterns, true); diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.td b/third_party/xla/xla/mlir_hlo/transforms/passes.td index d9966b11f764db..f97752add62536 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.td @@ -140,7 +140,6 @@ def GenericHostToLLVMPass : Pass<"generic-host-to-llvm", "ModuleOp"> { "::mlir::arith::ArithDialect", "::mlir::cf::ControlFlowDialect", "::mlir::complex::ComplexDialect", - "::mlir::deallocation::DeallocationDialect", "::mlir::func::FuncDialect", "::mlir::math::MathDialect", "::mlir::memref::MemRefDialect", diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 5125dd46256764..ae2a2cd5672ae6 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -70,6 +70,7 @@ limitations under the License. 
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -254,6 +255,10 @@ void LoadMLIRDialects(mlir::MLIRContext& context) { xla::runtime::RuntimeDialect>(); mlir::registerBuiltinDialectTranslation(context); mlir::registerLLVMDialectTranslation(context); + + mlir::DialectRegistry registry; + mlir::memref::registerAllocationOpInterfaceExternalModels(registry); + context.appendDialectRegistry(registry); } xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( @@ -266,9 +271,6 @@ xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_n_dim(), xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_k_dim()}; } - options.experimental_deallocation = - xla::GetDebugOptionsFromFlags() - .xla_cpu_enable_experimental_deallocation(); options.enable_avx2 = [&] { // Derive whether this is an x86 CPU with AVX2 enabled. if (!target_triple.isX86()) return false; @@ -279,7 +281,6 @@ xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( options.cpu_name = cpu_name; if (xla::GetDebugOptionsFromFlags().xla_cpu_enable_mlir_fusion_outlining()) { options.enable_fusion_outlining = true; - options.experimental_deallocation = true; } return options; } diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index c6e7792e1cfe0a..caff8ee13e90f7 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project @@ -48,7 +49,6 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/runtime/transforms/compiler.h" -#include "xla/mlir_hlo/deallocation/transforms/passes.h" #include "xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" @@ -222,16 +222,9 @@ static Status CreateHloXlaPipeline( // bufferizing anything. pm.addPass(mlir::createCanonicalizerPass()); - if (options.experimental_deallocation) { - // Experimental deallocation needs input IR without any buffer reuse to - // work optimally. This pass ensures that's the case. - pm.addNestedPass(mlir::deallocation::createSplitAllocTensorsPass()); - } - if (options.sparse_bufferization) { // Convert Sparse tensors. 
- AddSparsificationPasses(pm, options.experimental_deallocation, - options.xla_cpu_sparse_cuda_threads); + AddSparsificationPasses(pm, false, options.xla_cpu_sparse_cuda_threads); } else { pm.addPass(mlir::hlo::createOneShotBufferizePass()); } @@ -258,28 +251,14 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::bufferization::createBufferResultsToOutParamsPass( out_params_options)); - if (options.experimental_deallocation) { - pm.addNestedPass( - mlir::deallocation::createXlaBufferArgRewritePass()); - pm.addPass(mlir::deallocation::createDeallocatePass()); - pm.addNestedPass( - mlir::deallocation::createDeallocationSimplificationPass()); - // Remove SCF iter args that became redundant after simplification. - pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::deallocation::createBufferReusePass()); - pm.addNestedPass( - mlir::deallocation::createDeallocationSimplificationPass()); - pm.addNestedPass(mlir::deallocation::createDeallocationToScfPass()); - } else { - pm.addNestedPass( - mlir::bufferization::createPromoteBuffersToStackPass(nullptr)); + pm.addNestedPass( + mlir::bufferization::createPromoteBuffersToStackPass(nullptr)); + pm.addNestedPass( + mlir::bufferization::createBufferDeallocationPass()); + pm.addPass(mlir::createBufferizationToMemRefPass()); + if (options.remove_copies_to_outparams) { pm.addNestedPass( - mlir::bufferization::createBufferDeallocationPass()); - pm.addPass(mlir::createBufferizationToMemRefPass()); - if (options.remove_copies_to_outparams) { - pm.addNestedPass( - xla::cpu::createRemoveCopiesToOutParamsPass()); - } + xla::cpu::createRemoveCopiesToOutParamsPass()); } // Specialize linalg.matmul to linalg.dot, linalg.matvec or linalg.vecmat, @@ -315,6 +294,7 @@ void RegisterHloXlaRuntimePipelineDialects(mlir::DialectRegistry& dialects) { mlir::arith::registerBufferizableOpInterfaceExternalModels(dialects); mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( dialects); + mlir::memref::registerAllocationOpInterfaceExternalModels(dialects); mlir::linalg::registerBufferizableOpInterfaceExternalModels(dialects); mlir::linalg::registerTilingInterfaceExternalModels(dialects); mlir::mhlo::registerBufferizableOpInterfaceExternalModels(dialects); diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.h b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.h index b77436c16565d2..c2fa5f6f4b4d2e 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.h +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.h @@ -34,7 +34,6 @@ struct HloXlaRuntimePipelineOptions { bool enable_fusion_outlining = true; bool remove_copies_to_outparams = true; bool sparse_bufferization = true; - bool experimental_deallocation = false; bool enable_avx2 = true; // Accelerate sparse computations with CUDA threading. // This is an experimental feature, so off by default. From 47d7b236b796ab7dd0dbd41cba5c7e7dcebe44fa Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Tue, 28 Nov 2023 04:28:59 -0800 Subject: [PATCH 138/381] PR #7287: Fix TF32 in gemm calls Imported from GitHub PR https://github.com/openxla/xla/pull/7287 - This condition appears to have been accidentally inverted in https://github.com/tensorflow/tensorflow/commit/14ea9d18 - This was discovered due to large regressions in some models. 
cc @reedwm Copybara import of the project: -- 28c1798aee55090de9d5df64a3332a99ae3322b2 by Ben Barsdell : Fix TF32 in gemm calls - This condition appears to have been accidentally inverted in https://github.com/tensorflow/tensorflow/commit/14ea9d18 - This was discovered due to large regressions in some models. Merging this change closes #7287 PiperOrigin-RevId: 585927485 --- third_party/xla/xla/stream_executor/cuda/cuda_blas.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc index 843b3470977b53..a03f609398c1fd 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc @@ -612,7 +612,7 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, #else if (dtype == blas::DataType::kFloat) { math_type = CUBLAS_TF32_TENSOR_OP_MATH; - if (numeric_options.allow_tf32) { + if (!numeric_options.allow_tf32) { math_type = CUBLAS_DEFAULT_MATH; } } From 4e413157ed1e54b414b04d25c43c137c7694d0d3 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 28 Nov 2023 04:49:47 -0800 Subject: [PATCH 139/381] [XLA:GPU] Add cache for fusion runtime. PiperOrigin-RevId: 585932142 --- .../gpu/model/gpu_performance_model.cc | 55 ++++++++++++++++++- .../service/gpu/model/gpu_performance_model.h | 21 +++++-- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 0f759abd33e60e..7bd454134cebb4 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -303,6 +303,20 @@ std::optional GpuPerformanceModelCache::Get( return std::nullopt; } +std::optional GpuPerformanceModelCache::Get( + const HloInstruction& producer, const HloInstruction& consumer) { + absl::MutexLock lock(&mutex_); + + auto it = fusion_runtime_data_.find(HloInstructionAdaptor(producer)); + if (it != fusion_runtime_data_.end()) { + auto jt = it->second.find(HloInstructionAdaptor(consumer)); + if (jt != it->second.end()) { + return jt->second; + } + } + return std::nullopt; +} + void GpuPerformanceModelCache::Set(const HloInstruction& instruction, const EstimateRunTimeData& runtime_data) { absl::MutexLock lock(&mutex_); @@ -310,10 +324,33 @@ void GpuPerformanceModelCache::Set(const HloInstruction& instruction, instruction_runtime_data_[HloInstructionAdaptor(instruction)] = runtime_data; } +void GpuPerformanceModelCache::Set(const HloInstruction& producer, + const HloInstruction& consumer, + absl::Duration runtime) { + absl::MutexLock lock(&mutex_); + fusion_runtime_data_[HloInstructionAdaptor(producer)] + [HloInstructionAdaptor(consumer)] = runtime; +} + void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { absl::MutexLock lock(&mutex_); + HloInstructionAdaptor adaptor(instruction); - instruction_runtime_data_.erase(HloInstructionAdaptor(instruction)); + // Remove runtime data for the instruction. + instruction_runtime_data_.erase(adaptor); + + // Remove cache for all producer-consumer pairs where the instruction is + // producer. + fusion_runtime_data_.erase(adaptor); + + // Iterate through operands to find all producer-consumer pairs where + // instruction is consumer and remove them from cache. 
+ for (auto* operand : instruction.operands()) { + auto it = fusion_runtime_data_.find(HloInstructionAdaptor(*operand)); + if (it != fusion_runtime_data_.end()) { + it->second.erase(adaptor); + } + } } /*static*/ EstimateRunTimeData @@ -642,6 +679,15 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( kKernelLaunchOverhead * fused_consumers.size(); for (auto [idx, fused_consumer] : llvm::enumerate(fused_consumers)) { VLOG(8) << "Fused consumer: " << fused_consumer->name(); + + if (config.calculate_full_priority && config.gpu_performance_model_cache) { + if (auto fusion_runtime = config.gpu_performance_model_cache->Get( + *producer, *fused_consumer)) { + exec_time_fused += *fusion_runtime; + continue; + } + } + float utilization_by_this_consumer = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); @@ -668,10 +714,15 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( // With `calculate_full_priority`, consumer computation and full read time // is accounted in the priority. if (config.calculate_full_priority) { - exec_time_fused += EstimateRunTimeForFusion( + auto fusion_runtime = EstimateRunTimeForFusion( producer, fused_consumer, producer_runtime, consumer_runtimes[idx], launch_dimensions_fused, utilization_by_this_consumer, cost_analysis, analysis_fused, config); + exec_time_fused += fusion_runtime; + if (config.gpu_performance_model_cache) { + config.gpu_performance_model_cache->Set(*producer, *fused_consumer, + fusion_runtime); + } continue; } diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index 8962a5d6b6d176..9e5cfe49d58c2f 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -59,22 +59,35 @@ struct EstimateRunTimeData { class GpuPerformanceModelCache { public: - // Returns cached runtime data for the instruction. Returns nullopt if there - // is no data in cache. + // Returns cached runtime data for the instruction or producer-consumer pair. + // Returns nullopt if there is no data in cache. std::optional Get(const HloInstruction& instruction); + std::optional Get(const HloInstruction& producer, + const HloInstruction& consumer); - // Sets cache value for the instruction. + // Sets cache value for the instruction or producer-consumer pair. void Set(const HloInstruction& instruction, const EstimateRunTimeData& runtime_data); + void Set(const HloInstruction& producer, const HloInstruction& consumer, + absl::Duration runtime); - // Removes all cache entries for this instruction. + // Removes all cache entries for this instruction. The cache contains entries + // for individual instructions in instruction_runtime_data_ and for + // producer-consumer pairs in fusion_runtime_data_. void Invalidate(const HloInstruction& instruction); private: absl::Mutex mutex_; + // Stores unfused runtime data for individual instructions. absl::flat_hash_map instruction_runtime_data_; + + // Stores fused runtime data for producer-consumer pairs. + absl::flat_hash_map< + HloInstructionAdaptor, + absl::flat_hash_map> + fusion_runtime_data_; }; struct GpuPerformanceModelOptions { From bf02f8e3bf7fc57d33e64b2c448ad8e11305b5aa Mon Sep 17 00:00:00 2001 From: Jackson Stokes Date: Tue, 28 Nov 2023 05:45:04 -0800 Subject: [PATCH 140/381] [XLA:GPU] Fix crash in (Triton) tiling of reductions. 
Previously, this was avoided as modules with unsupported reduces would crash earlier on subsequent broadcast ops. With the addition of new triton softmax fusion patterns, it's possible that inputs will be reduced without subsequent broadcasts. PiperOrigin-RevId: 585945011 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 8 +++-- .../service/gpu/gemm_rewriter_triton_test.cc | 29 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 0737a0a9ce717a..9413c839b16403 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -951,8 +951,12 @@ FusionContext::GetPropagatedDimOrdersForDimAlteringOp( } const auto* reduce = Cast(&hlo); dst_logical.resize(src_logical.size() + reduce->dimensions().size()); + if (reduce->dimensions().size() != 1) { return FusionDecision("Unsupported reduction."); + } else if (reduce->dimensions().front() != + reduce->operand(0)->shape().rank() - 1) { + return FusionDecision("Only row reductions are supported."); } for (int i = 0; i < dst_logical.size(); ++i) { if (i == reduce->dimensions().front()) { @@ -1713,9 +1717,7 @@ Status FusionContext::PropagateDimensionOrdersToParameters( DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( *hlo, dim_orders_.at(hlo), TransformDirection::kOutputToInput, properties_); - if (std::holds_alternative(result)) { - LOG(FATAL) << std::get(result).Explain(); - } + TF_RET_CHECK(std::holds_alternative(result)); TF_RET_CHECK(CombineDimOrdersAndReqs(std::get(result))); iter_specs[hlo] = DimensionOrderToTensorIterationSpec(dim_orders_.at(hlo)); for (const HloInstruction* operand : hlo->operands()) { diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 9a9c0bef826420..8c17d7c63a9d69 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -796,6 +796,35 @@ ENTRY e { nullptr); } +TEST_F(TritonSoftmaxAnalysisTest, ReduceOfNonRowDimensionIsNotSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_0 = f32[8,4,127]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[4,127]{1,0} reduce(param_0, constant), dimensions={0}, to_apply=add +} + +ENTRY main { + param_0 = f32[8,4,127]{2,1,0} parameter(0) + ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto analysis = TritonFusionAnalysis::Execute(*computation); + EXPECT_FALSE(analysis.ok()); +} + TEST_F(GemmRewriterTritonTest, HandleDotIfCublasRequiresPadding) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( From ac14d9178dec55afb52bdd5da29ceb90476b9f5a Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 28 Nov 2023 07:53:58 -0800 Subject: [PATCH 141/381] Wrap line to 80 characters --- tensorflow/python/ops/image_ops_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 6e19d47460ce0c..e476217347bed6 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -2030,7 +2030,8 @@ def random_brightness(image, max_delta, seed=None): Args: image: An image or images to adjust. - max_delta: float, must be non-negative. This parameter controls the maximum relative change in brightness. + max_delta: float, must be non-negative. This parameter controls the maximum + relative change in brightness. seed: A Python integer. Used to create a random seed. See `tf.compat.v1.set_random_seed` for behavior. From af0eaf0a61de3a42b92a57761d3ef4c55bd02c88 Mon Sep 17 00:00:00 2001 From: Prateek Kumar Date: Tue, 28 Nov 2023 22:23:07 +0530 Subject: [PATCH 142/381] Fix duplicate checkpoint removal in Saver class --- tensorflow/python/training/saver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 92ec7cff402129..fd4243e4d3021e 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1029,9 +1029,10 @@ def _RecordLastCheckpoint(self, latest_save_path): if not self.saver_def.max_to_keep: return # Remove first from list if the same name was used before. - for p in self._last_checkpoints: + for p in self._last_checkpoints[:]: if latest_save_path == self._CheckpointFilename(p): self._last_checkpoints.remove(p) + # Append new path to list self._last_checkpoints.append((latest_save_path, time.time())) From 6880b6a06f88892c5a006250cb386cbacf3da410 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 08:49:23 -0800 Subject: [PATCH 143/381] Update XNNPACK and cpuinfo version This picks up, among other things, the fix for the invalid memcpy call in https://github.com/google/XNNPACK/commit/07e1a4a461a7ca4777b0dfa8081199174f57b697 The new XNNPACK requires a new cpuinfo, so update that too. 
PiperOrigin-RevId: 585992920 --- tensorflow/lite/tools/cmake/modules/cpuinfo.cmake | 4 ++-- tensorflow/lite/tools/cmake/modules/xnnpack.cmake | 2 +- tensorflow/workspace2.bzl | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake index 7866627555d030..d72fa2c18c07ca 100644 --- a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake +++ b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake @@ -22,8 +22,8 @@ include(OverridableFetchContent) OverridableFetchContent_Declare( cpuinfo GIT_REPOSITORY https://github.com/pytorch/cpuinfo - # Sync with tensorflow/third_party/cpuinfo/workspace.bzl - GIT_TAG 959002f82d7962a473d8bf301845f2af720e0aa4 + # Sync with tensorflow/workspace2.bzl + GIT_TAG ef634603954d88d2643d5809011288b890ac126e GIT_PROGRESS TRUE SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo" ) diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index a6b36451cb819b..436be3901c4865 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG c7e7cde37615a81a529c326aa278bfab4cd6fe5a + GIT_TAG 0cbbe74a16e6ca11acf8484ccac85f620336dea4 GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 0cc7b5e1c5aae2..af03c5a8c5e6a2 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -150,9 +150,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "88e0158aff1e1498e34dfcaf08d948a73a3246a04fe96e548da71f6b9245a009", - strip_prefix = "XNNPACK-c7e7cde37615a81a529c326aa278bfab4cd6fe5a", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/c7e7cde37615a81a529c326aa278bfab4cd6fe5a.zip"), + sha256 = "ca829b6486d7dcc0a63eae9d5d5be21dcb542e6601af4cada17b9d5f7d5fafb7", + strip_prefix = "XNNPACK-0cbbe74a16e6ca11acf8484ccac85f620336dea4", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/0cbbe74a16e6ca11acf8484ccac85f620336dea4.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -172,9 +172,9 @@ def _tf_repositories(): tf_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-959002f82d7962a473d8bf301845f2af720e0aa4", - sha256 = "a0f53ccfb477c57753c595df02bf79ed67bf092fd9a5c61ec5b8992b81bc1e65", - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip"), + strip_prefix = "cpuinfo-ef634603954d88d2643d5809011288b890ac126e", + sha256 = "e07512a11e1c71687359a133f49d60583d7465b737fe5dbe11f461c9aaa72a2b", + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/ef634603954d88d2643d5809011288b890ac126e.zip"), ) tf_http_archive( From 017e96a3237ac33d9e11aeb52ce73bbe3fdef41e Mon Sep 17 00:00:00 2001 From: Chao Date: Tue, 28 Nov 2023 09:01:00 -0800 Subject: [PATCH 144/381] PR #7323: [ROCm] Restore private visibility of stream executor private targets Imported from GitHub PR https://github.com/openxla/xla/pull/7323 fixed rocm build due to https://github.com/openxla/xla/commit/33fc605a8a118368eaf8f748c8e90eeb18f2e0d3 @xla-rotation Copybara import of the project: -- ad859aa6fa0d44e2a7609eaee6bedbcd4d3968da by Chao Chen : remove command_buffer and kernel links 
in rocm build Merging this change closes #7323 PiperOrigin-RevId: 585997544 --- third_party/xla/xla/stream_executor/rocm/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 4a6487b9d47947..03fbbe60445eb3 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -107,8 +107,6 @@ cc_library( "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", "//xla/stream_executor", - "//xla/stream_executor:command_buffer", - "//xla/stream_executor:kernel", "//xla/stream_executor:plugin_registry", "//xla/stream_executor/gpu:gpu_activation_header", "//xla/stream_executor/gpu:gpu_event", From 051bde107908fde1d429415b40445110b3b4acd3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 28 Nov 2023 09:03:49 -0800 Subject: [PATCH 145/381] [XLA:CPU] Refactor local collectives into a separate file behind an interface. Reimplement local collectives to utilize thread-parallelism, rather than having one thread do all the work. They are simpler this way! PiperOrigin-RevId: 585998597 --- .../xla/xla/service/collective_ops_utils.h | 88 +-- third_party/xla/xla/service/cpu/BUILD | 52 +- .../xla/service/cpu/collectives_interface.h | 81 +++ .../xla/xla/service/cpu/cpu_runtime.cc | 627 ++++-------------- .../xla/service/cpu/in_process_collectives.cc | 439 ++++++++++++ .../xla/service/cpu/in_process_collectives.h | 78 +++ 6 files changed, 772 insertions(+), 593 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/collectives_interface.h create mode 100644 third_party/xla/xla/service/cpu/in_process_collectives.cc create mode 100644 third_party/xla/xla/service/cpu/in_process_collectives.h diff --git a/third_party/xla/xla/service/collective_ops_utils.h b/third_party/xla/xla/service/collective_ops_utils.h index 6917e34adc2225..55af5f51f113b4 100644 --- a/third_party/xla/xla/service/collective_ops_utils.h +++ b/third_party/xla/xla/service/collective_ops_utils.h @@ -17,7 +17,9 @@ limitations under the License. #define XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_ #include +#include #include +#include #include #include @@ -265,59 +267,17 @@ void WaitAndLogIfStuck(tsl::BlockingCounter* counter, const DescFn& desc_fn) { // Participant data for each rendezvous. struct ParticipantData { - explicit ParticipantData(const RendezvousKey& rendezvous_key) - : rendezvous_key(rendezvous_key) {} + ParticipantData(const RendezvousKey& rendezvous_key, int local_rank) + : rendezvous_key(rendezvous_key), local_rank(local_rank) {} virtual ~ParticipantData() {} RendezvousKey rendezvous_key; + int local_rank; // Which of the local participants is this? virtual std::string ToString() const = 0; }; -// Encapsulates parameters to Rendezvous::SubmitParticipant. -struct AllReduceParticipantData : ParticipantData { - AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, - int64_t device_ordinal_p, se::Stream* stream_p) - : ParticipantData(rendezvous_key_p), - device_ordinal(device_ordinal_p), - stream(stream_p) {} - - // TODO(b/125951860): We should vet that we're buffer allocating such that - // source_buffer == destination_buffer if that avoids a NCCL copy (will depend - // on how well the NCCL in-place implementation performs vs the out-of-place - // implementation). 
- struct Buffer { - int64_t element_count; - se::DeviceMemoryBase source_data; - se::DeviceMemoryBase destination_data; - PrimitiveType primitive_type; - }; - int64_t device_ordinal; - se::Stream* stream; - std::vector buffers; - - ReductionKind reduction_kind; - - // For each local all-reduce participant a (global ID, local device ordinal) - // pair for the participant. Participants are in no particular order. - std::vector> local_devices; - - std::string ToString() const override { - std::vector buffer_strs; - buffer_strs.reserve(buffers.size()); - for (const Buffer& buffer : buffers) { - buffer_strs.push_back( - absl::StrFormat("{element_count=%d}", buffer.element_count)); - } - return absl::StrFormat( - "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, " - "device_ordinal=%d, stream=%p}", - absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(), - device_ordinal, stream); - } -}; - // The set of threads that want to do a collective op together all pick the same // Rendezvous object out of the global cache and call SubmitParticipant. // @@ -334,7 +294,8 @@ template RunCollectiveOp(const I& participant) = 0; - // Initialize the rendezvous by the first ("primary") thread which reaches the - // barrier. Returns whether this thread is primary. - bool InitializationBarrier() { - absl::MutexLock lock(&mu_); - if (!initialized_) { - initialized_ = true; - return true; - } - return false; - } + // Adding participants_ requires holding mu_. + // Not annotated with ABSL_GUARDED_BY(mu_) because we do not require the lock + // to be held during CollectiveOp(), since at that point all the data is known + // to be present due to the global barrier. + std::vector> participants_; + private: absl::Mutex mu_; - bool initialized_ ABSL_GUARDED_BY(mu_) = false; - - std::vector participants_ ABSL_GUARDED_BY(mu_); - - private: // Runs the all-reduce on the given thread. If successful, returns // - a handle to the clique that was used, so that the caller may keep the // clique alive if it chooses. @@ -396,18 +348,8 @@ class Rendezvous { SubmitParticipant(const I& participant) { { absl::MutexLock lock(&mu_); - CHECK(!initialized_); - - // Spot check for consistent replica counts among submitting threads. - if (!participants_.empty() && - participants_.back().rendezvous_key != participant.rendezvous_key) { - return InvalidArgument( - "Mismatch among all-reduce participants. Expected same " - "replica-count, element-count, and rendezvous-key but were %s and " - "%s", - participants_.back().ToString(), participant.ToString()); - } - participants_.push_back(participant); + CHECK(!participants_[participant.local_rank].has_value()); + participants_[participant.local_rank] = participant; } // Wait for all participants to arrive. 
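
The refactored Rendezvous above keeps one rank-indexed slot per participant: slots are written under mu_, and once every participant has arrived they are read without the lock. Below is a standalone sketch of that pattern (illustrative only; MiniRendezvous, Participant::value, and the condition variable standing in for the blocking counter are assumptions, not the XLA classes). It shows why the post-barrier reads are safe: nothing is written after the barrier.

#include <cassert>
#include <condition_variable>
#include <mutex>
#include <optional>
#include <thread>
#include <vector>

struct Participant {
  int local_rank;
  int value;
};

class MiniRendezvous {
 public:
  explicit MiniRendezvous(int world_size) : participants_(world_size) {}

  // Each thread publishes its slot under the lock, blocks until every slot
  // is filled, and only then reads the other slots without the lock.
  int SubmitAndSum(const Participant& p) {
    {
      std::unique_lock<std::mutex> lock(mu_);
      assert(!participants_[p.local_rank].has_value());
      participants_[p.local_rank] = p;
      ++arrived_;
      cv_.notify_all();
      cv_.wait(lock, [&] {
        return arrived_ == static_cast<int>(participants_.size());
      });
    }
    // Barrier passed: every slot is filled and no one writes afterwards,
    // so reading participants_ here without holding mu_ is safe.
    int sum = 0;
    for (const auto& slot : participants_) sum += slot->value;
    return sum;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int arrived_ = 0;
  std::vector<std::optional<Participant>> participants_;
};

int main() {
  constexpr int kWorldSize = 3;
  MiniRendezvous rendezvous(kWorldSize);
  std::vector<std::thread> threads;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    threads.emplace_back([&rendezvous, rank] {
      // value = rank + 1, so every participant should observe 1 + 2 + 3.
      const int sum = rendezvous.SubmitAndSum({rank, rank + 1});
      assert(sum == 6);
    });
  }
  for (auto& t : threads) t.join();
  return 0;
}
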
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 8d05c8ce160edb..ef51b94dcd80ad 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -897,28 +897,27 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":collectives_interface", + ":in_process_collectives", "//xla:executable_run_options", - "//xla:refcounting_hash_map", "//xla:shape_util", "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", + "//xla/service:global_device_id", "//xla/service:hlo_parser", - "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -1685,3 +1684,42 @@ cc_library( "//xla/service:symbol_repository", ], ) + +cc_library( + name = "collectives_interface", + hdrs = ["collectives_interface.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "in_process_collectives", + srcs = ["in_process_collectives.cc"], + hdrs = ["in_process_collectives.h"], + visibility = ["//visibility:public"], + deps = [ + ":collectives_interface", + "//xla:refcounting_hash_map", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h new file mode 100644 index 00000000000000..bd518db3a780bc --- /dev/null +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -0,0 +1,81 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ +#define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class CollectivesCommunicator { + public: + virtual ~CollectivesCommunicator() = default; + + // Performs an all-reduce. + virtual absl::Status AllReduce(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, + size_t num_elements, const void* input_buffer, + void* output_buffer, + absl::Duration timeout) = 0; + + // Performs a collective permute. + // Arguments: + // source_rank: the rank from which this rank should receive its data. + // Optional; if absent, then the output is filled with zeros. + // target_rank: the ranks to which this rank should send its data. + virtual absl::Status CollectivePermute(const RendezvousKey& key, + size_t num_bytes, + std::optional source_rank, + absl::Span target_ranks, + const void* input_buffer, + void* output_buffer, + absl::Duration timeout) = 0; + + // Performs an all-to-all. + // The all-to-all chunks are passed separately and do not have to be + // contiguous in memory. + virtual absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffer, + absl::Span output_buffer, + absl::Duration timeout) = 0; +}; + +class CollectivesInterface { + public: + virtual ~CollectivesInterface() = default; + + // Builds a context for a collective group. + // Args: + // devices: the devices participating in this collective. + // rank: the rank of this process. + virtual absl::StatusOr> + GetCommunicator(absl::Span devices, int rank) = 0; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index ee0208644f9b4a..a4090d67fc8d9b 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -15,37 +15,38 @@ limitations under the License. 
#include "xla/service/cpu/cpu_runtime.h" -#include #include -#include +#include #include -#include -#include -#include +#include #include #include #include -#include +#include #include #include -#include "absl/base/dynamic_annotations.h" +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" #include "xla/executable_run_options.h" #include "xla/layout_util.h" -#include "xla/primitive_util.h" -#include "xla/refcounting_hash_map.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/in_process_collectives.h" #include "xla/service/cpu/xfeed_manager.h" +#include "xla/service/global_device_id.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/profiler/lib/traceme.h" @@ -152,69 +153,6 @@ extern const char* const kOneDnnMatMulSymbolName = namespace { -struct CollectivePermuteParticipantData : ParticipantData { - CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p, - int64_t device_ordinal_p, - se::Stream* stream_p) - : ParticipantData(rendezvous_key_p), - device_ordinal(device_ordinal_p), - stream(stream_p) {} - - int64_t device_ordinal; - se::Stream* stream; - int replica_id; - se::DeviceMemoryBase source_data; - se::DeviceMemoryBase destination_data; - int64_t byte_size; - std::vector replica_ids_to_copy_to; - - std::string ToString() const override { - return absl::StrFormat( - "CollectivePermuteParticipantData{replica_id=%d, " - "source_data=%p, destination_data=%p, byte_size=%d, " - "replica_ids_to_copy_to=[%s], device_ordinal=%d, stream=%p}", - replica_id, source_data.opaque(), destination_data.opaque(), byte_size, - absl::StrJoin(replica_ids_to_copy_to, ", "), device_ordinal, stream); - } -}; - -struct AllToAllParticipantData : ParticipantData { - AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, - int64_t device_ordinal_p, se::Stream* stream_p) - : ParticipantData(rendezvous_key_p), - device_ordinal(device_ordinal_p), - stream(stream_p) {} - - int64_t device_ordinal; - se::Stream* stream; - std::vector source_buffers; - std::vector destination_buffers; - GlobalDeviceId device_id; - - // Replica ids participating in AllToAll, concatenation happens in the order - // of appearance. 
- std::vector devices_to_copy_to; - - std::string ToString() const override { - auto addr_formatter = [](std::string* out, - const se::DeviceMemoryBase& mem) { - absl::StrAppend(out, absl::StrFormat("%p", mem.opaque())); - }; - auto device_formatter = [](std::string* out, const GlobalDeviceId& device) { - absl::StrAppend(out, device.value()); - }; - return absl::StrFormat( - "AllToAllParticipantData{replica_id=%d, " - "replica_ids_to_copy_to=[%s], source_buffers=[%s], " - "destination_buffers=[%s], device_ordinal=%d, stream=%p}", - device_id.value(), - absl::StrJoin(devices_to_copy_to, ", ", device_formatter), - absl::StrJoin(source_buffers, ", ", addr_formatter), - absl::StrJoin(destination_buffers, ", ", addr_formatter), - device_ordinal, stream); - } -}; - // Inverses the encoding of a Shape protobuf into an LLVM global variable. StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, int32_t size_bytes) { @@ -250,343 +188,6 @@ int GetDeviceOrdinal(const ExecutableRunOptions* run_options) { return run_options->stream()->parent()->device_ordinal(); } -class CpuAllToAllRendezvous - : public Rendezvous { - public: - explicit CpuAllToAllRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - StatusOr RunCollectiveOp( - const AllToAllParticipantData& /*participant*/) override { - bool is_primary = InitializationBarrier(); - - if (is_primary) { - absl::MutexLock lock(&mu_); - - CHECK(!participants_.empty()); - CHECK(!participants_[0].source_buffers.empty()); - int expected_buffer_size = participants_[0].source_buffers[0].size(); - - // Device id -> position in participants_. - absl::flat_hash_map device_map; - - for (int pos = 0; pos < participants_.size(); pos++) { - const AllToAllParticipantData& p = participants_[pos]; - CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size()); - CHECK_EQ(p.source_buffers.size(), participants_.size()); - for (int i = 0; i < p.source_buffers.size(); i++) { - CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size); - CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size); - } - device_map[p.device_id] = pos; - } - - const std::vector& devices_to_copy_to = - participants_[0].devices_to_copy_to; - - // Device id -> rank - absl::flat_hash_map device_ranks; - for (int rank = 0; rank < devices_to_copy_to.size(); ++rank) { - auto device_id = devices_to_copy_to[rank]; - device_ranks[device_id] = rank; - } - - for (const AllToAllParticipantData& sender : participants_) { - VLOG(3) << "Processing AllToAll participant: " << sender.ToString(); - - int rank = FindOrDie(device_ranks, sender.device_id); - - for (int i = 0; i < participants_.size(); ++i) { - auto device_id = devices_to_copy_to[i]; - int participant_num = FindOrDie(device_map, device_id); - AllToAllParticipantData& receiver = participants_[participant_num]; - - std::memcpy(receiver.destination_buffers[rank].opaque(), - sender.source_buffers[i].opaque(), expected_buffer_size); - } - } - } - return nullptr; - } -}; - -class CpuCollectivePermuteRendezvous - : public Rendezvous { - public: - explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - StatusOr RunCollectiveOp( - const CollectivePermuteParticipantData& /*participant*/) override { - bool primary = InitializationBarrier(); - - // Perform all copies from the primary thread. 
- if (primary) { - absl::MutexLock lock(&mu_); - - std::map replica_idx_to_participant_idx; - for (int p_idx = 0; p_idx < participants_.size(); p_idx++) { - replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx; - } - for (auto& p : participants_) { - for (int dest_replica : p.replica_ids_to_copy_to) { - auto& dest_p = participants_[FindOrDie(replica_idx_to_participant_idx, - dest_replica)]; - std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(), - p.byte_size); - - // Each replica may be copied into only once. - replica_idx_to_participant_idx.erase(dest_replica); - } - } - - // Zero out untouched participants. - for (auto& replica_p : replica_idx_to_participant_idx) { - auto& p = participants_[replica_p.second]; - std::memset(p.destination_data.opaque(), 0, p.byte_size); - } - } - return nullptr; - } -}; - -class CpuAllReduceRendezvous - : public Rendezvous { - public: - explicit CpuAllReduceRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - StatusOr RunCollectiveOp( - const AllReduceParticipantData& participant) override { - PrimitiveType datatype = participant.buffers.front().primitive_type; - bool primary = InitializationBarrier(); - - if (primary) { - switch (datatype) { - case S8: - DoAllReduce(participant); - break; - case PRED: - case U8: - DoAllReduce(participant); - break; - case S16: - DoAllReduce(participant); - break; - case U16: - DoAllReduce(participant); - break; - case S32: - DoAllReduce(participant); - break; - case U32: - DoAllReduce(participant); - break; - case S64: - DoAllReduce(participant); - break; - case U64: - DoAllReduce(participant); - break; - case F16: - DoAllReduce(participant); - break; - case F32: - DoAllReduce(participant); - break; - case F64: - DoAllReduce(participant); - break; - case C64: - DoAllReduce(participant); - break; - case C128: - DoAllReduce(participant); - break; - default: - LOG(FATAL) << "Unexpected datatype;"; - } - } - return nullptr; - } - - private: - template - void DoAllReduce(AllReduceParticipantData participant) { - using T = typename primitive_util::PrimitiveTypeToNative::type; - absl::MutexLock lock(&mu_); - CHECK(!participants_.empty()); - ReductionKind reduction_kind = participant.reduction_kind; - for (const auto& p : participants_) { - CHECK(p.reduction_kind == reduction_kind); - } - int num_participants = participants_.size(); - - // participant_idx -> buffer_idx -> buffer. 
- std::vector>> input_buffers; - std::vector>> output_buffers; - input_buffers.reserve(num_participants); - output_buffers.reserve(num_participants); - const AllReduceParticipantData& first_participant = participants_.front(); - - int buffers_per_participant = first_participant.buffers.size(); - for (AllReduceParticipantData& p : participants_) { - CHECK_EQ(p.buffers.size(), buffers_per_participant); - - input_buffers.emplace_back(); - output_buffers.emplace_back(); - std::vector>& participant_input_buffers = - input_buffers.back(); - std::vector>& participant_output_buffers = - output_buffers.back(); - participant_input_buffers.reserve(p.buffers.size()); - participant_output_buffers.reserve(p.buffers.size()); - - for (int buffer_idx = 0; buffer_idx < buffers_per_participant; - buffer_idx++) { - auto& participant_buffer = p.buffers[buffer_idx]; - participant_input_buffers.emplace_back( - static_cast(participant_buffer.source_data.opaque()), - participant_buffer.element_count); - participant_output_buffers.emplace_back( - static_cast(participant_buffer.destination_data.opaque()), - participant_buffer.element_count); - CHECK_EQ(participant_buffer.element_count, - first_participant.buffers[buffer_idx].element_count); - } - } - - for (int buffer_idx = 0; buffer_idx < buffers_per_participant; - buffer_idx++) { - int element_count = first_participant.buffers[buffer_idx].element_count; - for (int idx = 0; idx < element_count; idx++) { - T out = GetInitialValue(reduction_kind); - for (int participant_idx = 0; participant_idx < participants_.size(); - participant_idx++) { - out = PerformReductionStep( - reduction_kind, out, - input_buffers[participant_idx][buffer_idx][idx]); - } - for (int participant_idx = 0; participant_idx < participants_.size(); - participant_idx++) { - output_buffers[participant_idx][buffer_idx][idx] = out; - } - } - } - } - - template - T GetInitialValue(ReductionKind reduction_kind) { - switch (reduction_kind) { - case ReductionKind::SUM: - return static_cast(0); - case ReductionKind::PRODUCT: - return static_cast(1); - case ReductionKind::MIN: - return std::numeric_limits::max(); - case ReductionKind::MAX: - return std::numeric_limits::min(); - } - } - - template - struct SumProductTypeForReductionStep { - using type = T; - }; - - template - struct SumProductTypeForReductionStep { - using type = typename std::make_unsigned_t; - }; - - template >::type* = nullptr> - T PerformReductionStep(ReductionKind reduction_kind, T a, T b) { - using SumProductType = typename SumProductTypeForReductionStep< - T, std::is_integral::value && std::is_signed::value>::type; - switch (reduction_kind) { - case ReductionKind::SUM: - return absl::bit_cast( - static_cast(absl::bit_cast(a) + - absl::bit_cast(b))); - case ReductionKind::PRODUCT: - return absl::bit_cast( - static_cast(absl::bit_cast(a) * - absl::bit_cast(b))); - case ReductionKind::MIN: - return std::min(a, b); - case ReductionKind::MAX: - return std::max(a, b); - } - } - - template >::type* = nullptr> - T PerformReductionStep(ReductionKind reduction_kind, T a, T b) { - using SumProductType = typename SumProductTypeForReductionStep< - T, std::is_integral::value && std::is_signed::value>::type; - switch (reduction_kind) { - case ReductionKind::SUM: - return absl::bit_cast( - static_cast(absl::bit_cast(a) + - absl::bit_cast(b))); - case ReductionKind::PRODUCT: - return absl::bit_cast( - static_cast(absl::bit_cast(a) * - absl::bit_cast(b))); - case ReductionKind::MIN: - case ReductionKind::MAX: - LOG(FATAL) << "min/max not valid for 
complex types"; - } - } -}; - -RefcountingHashMap& -GlobalAllReduceRendezvousMap() { - static auto& m = - *new RefcountingHashMap; - return m; -} - -RefcountingHashMap& -GlobalCollectivePermuteRendezvousMap() { - static auto& m = - *new RefcountingHashMap; - return m; -} - -RefcountingHashMap& -GlobalAllToAllRendezvousMap() { - static auto& m = - *new RefcountingHashMap; - return m; -} - -RendezvousKey GetRendezvousKey(const ExecutableRunOptions* run_options, - std::vector group, - int32_t channel_id_present, - std::optional use_global_device_ids, - int64_t op_id) { - const DeviceAssignment& device_assignment = *run_options->device_assignment(); - int device_ordinal = GetDeviceOrdinal(run_options); - RendezvousKey::CollectiveOpKind op_kind = channel_id_present - ? RendezvousKey::kCrossModule - : RendezvousKey::kCrossReplica; - std::vector participating_devices = - GetParticipatingDevices(GlobalDeviceId(device_ordinal), device_assignment, - group, - GetCollectiveOpGroupMode(channel_id_present != 0, - use_global_device_ids) - .value()) - .value(); - int num_local_participants = participating_devices.size(); - return RendezvousKey{run_options->run_id(), std::move(participating_devices), - num_local_participants, op_kind, op_id}; -} - ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void* AcquireInfeedBufferForDequeueImpl(const ExecutableRunOptions* run_options, int32_t buffer_length, @@ -664,6 +265,55 @@ void ReleaseOutfeedBufferAfterPopulationImpl( std::move(shape)); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void ReplicaIdImpl(const ExecutableRunOptions* run_options, + void* output_buffer) { + int device_ordinal = GetDeviceOrdinal(run_options); + int32_t replica_id = run_options->device_assignment() + ->ReplicaIdForDevice(GlobalDeviceId(device_ordinal)) + .value(); + std::memcpy(output_buffer, &replica_id, 4); +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void PartitionIdImpl(const ExecutableRunOptions* run_options, + void* output_buffer) { + int device_ordinal = GetDeviceOrdinal(run_options); + const DeviceAssignment::LogicalID logical_id = + run_options->device_assignment() + ->LogicalIdForDevice(GlobalDeviceId(device_ordinal)) + .value(); + std::memcpy(output_buffer, &logical_id.computation_id, 4); +} + +RendezvousKey GetRendezvousKey(const ExecutableRunOptions* run_options, + GlobalDeviceId device, + std::vector group, + int32_t channel_id_present, + std::optional use_global_device_ids, + int64_t op_id) { + const DeviceAssignment& device_assignment = *run_options->device_assignment(); + RendezvousKey::CollectiveOpKind op_kind = channel_id_present + ? 
RendezvousKey::kCrossModule + : RendezvousKey::kCrossReplica; + std::vector participating_devices = + GetParticipatingDevices(GlobalDeviceId(device), device_assignment, group, + GetCollectiveOpGroupMode(channel_id_present != 0, + use_global_device_ids) + .value()) + .value(); + int num_local_participants = participating_devices.size(); + return RendezvousKey{run_options->run_id(), std::move(participating_devices), + num_local_participants, op_kind, op_id}; +} + +CollectivesInterface* GetInProcessCollectivesImpl() { + static InProcessCollectives* c = new InProcessCollectives(); + return c; +} + +absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); } + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllToAllImpl(const ExecutableRunOptions* run_options, int32_t channel_id_present, int64_t op_id, @@ -671,41 +321,28 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size, void** source_buffers, void** destination_buffers) { - int device_ordinal = GetDeviceOrdinal(run_options); - absl::string_view replica_groups_serialized( + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( static_cast(replica_groups_str), replica_groups_str_size); std::vector group = ParseReplicaGroupsOnly(replica_groups_serialized).value(); RendezvousKey rendezvous_key = - GetRendezvousKey(run_options, group, channel_id_present, + GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - AllToAllParticipantData participant(rendezvous_key, device_ordinal, - run_options->stream()); - participant.device_id = GlobalDeviceId(device_ordinal); - participant.devices_to_copy_to = - GetParticipatingDevices( - GlobalDeviceId(device_ordinal), *run_options->device_assignment(), - group, - GetCollectiveOpGroupMode(channel_id_present != 0, - /*use_global_device_ids=*/std::nullopt) - .value()) - .value(); - for (int i = 0; i < num_buffers; i++) { - participant.source_buffers.emplace_back(source_buffers[i], buffer_size); - participant.destination_buffers.emplace_back(destination_buffers[i], - buffer_size); - } - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant( - [&] { - return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent( - rendezvous_key, make_cpu_rendezvous); - }, - participant) - .status()); + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); + + CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->AllToAll( + rendezvous_key, buffer_size, + absl::Span(source_buffers, num_buffers), + absl::Span(destination_buffers, num_buffers), + DefaultCollectiveTimeout())); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -716,13 +353,14 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, int32_t reduction_kind, const void* shape_ptr, int32_t shape_length, int32_t num_buffers, void** input_buffers, void** output_buffers) { - int device_ordinal = GetDeviceOrdinal(run_options); - absl::string_view replica_groups_serialized( + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( static_cast(replica_groups_str), replica_groups_str_size); 
std::vector group = ParseReplicaGroupsOnly(replica_groups_serialized).value(); - RendezvousKey rendezvous_key = GetRendezvousKey( - run_options, group, channel_id_present, use_global_device_ids, op_id); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + use_global_device_ids, op_id); auto shape_str = ShapeString(shape_ptr, shape_length); VLOG(2) << "All-reduce input/output shape : " << shape_str; @@ -732,53 +370,21 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, CHECK((num_buffers > 1 && shape.IsTuple()) || (num_buffers == 1 && LayoutUtil::IsDenseArray(shape))); - AllReduceParticipantData participant(rendezvous_key, device_ordinal, - run_options->stream()); - participant.reduction_kind = static_cast(reduction_kind); + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); + + CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); for (int i = 0; i < num_buffers; i++) { Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i); - AllReduceParticipantData::Buffer buffer; - buffer.element_count = ShapeUtil::ElementsIn(subshape); - buffer.primitive_type = subshape.element_type(); - buffer.source_data = - se::DeviceMemoryBase(input_buffers[i], ShapeUtil::ByteSizeOf(subshape)); - buffer.destination_data = se::DeviceMemoryBase( - output_buffers[i], ShapeUtil::ByteSizeOf(subshape)); - participant.buffers.push_back(buffer); + TF_CHECK_OK(communicator->AllReduce( + rendezvous_key, static_cast(reduction_kind), + subshape.element_type(), ShapeUtil::ElementsIn(subshape), + input_buffers[i], output_buffers[i], DefaultCollectiveTimeout())); } - - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - - TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant( - [&] { - return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent( - rendezvous_key, make_cpu_rendezvous); - }, - participant) - .status()); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY -void ReplicaIdImpl(const ExecutableRunOptions* run_options, - void* output_buffer) { - int device_ordinal = GetDeviceOrdinal(run_options); - int32_t replica_id = run_options->device_assignment() - ->ReplicaIdForDevice(GlobalDeviceId(device_ordinal)) - .value(); - std::memcpy(output_buffer, &replica_id, 4); -} - -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY -void PartitionIdImpl(const ExecutableRunOptions* run_options, - void* output_buffer) { - int device_ordinal = GetDeviceOrdinal(run_options); - const DeviceAssignment::LogicalID logical_id = - run_options->device_assignment() - ->LogicalIdForDevice(GlobalDeviceId(device_ordinal)) - .value(); - std::memcpy(output_buffer, &logical_id.computation_id, 4); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -787,17 +393,16 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, int32_t byte_size, void* input_buffer, void* output_buffer, const void* source_target_pairs, int32_t source_target_pairs_size) { - int device_ordinal = GetDeviceOrdinal(run_options); - absl::string_view source_target_pairs_serialized( + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view source_target_pairs_serialized( static_cast(source_target_pairs), source_target_pairs_size); auto pairs = absl::StrSplit(source_target_pairs_serialized, ','); const DeviceAssignment::LogicalID logical_id = 
- run_options->device_assignment() - ->LogicalIdForDevice(GlobalDeviceId(device_ordinal)) - .value(); + run_options->device_assignment()->LogicalIdForDevice(device).value(); int32_t logical_device_id = channel_id_present ? logical_id.computation_id : logical_id.replica_id; + std::optional source_replica_id; std::vector copy_to; for (auto& p : pairs) { std::vector mapping = absl::StrSplit(p, '='); @@ -807,30 +412,26 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, if (from == logical_device_id) { copy_to.push_back(to); } + if (to == logical_device_id) { + CHECK(!source_replica_id.has_value()); + source_replica_id = from; + } } RendezvousKey rendezvous_key = - GetRendezvousKey(run_options, {}, channel_id_present, + GetRendezvousKey(run_options, device, {}, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal, - run_options->stream()); - participant.replica_id = logical_device_id; - participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size); - participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size); - participant.replica_ids_to_copy_to = copy_to; - participant.byte_size = byte_size; - - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - TF_CHECK_OK( - CpuCollectivePermuteRendezvous::SubmitParticipant( - [&] { - return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent( - rendezvous_key, make_cpu_rendezvous); - }, - participant) - .status()); + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); + + CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->CollectivePermute( + rendezvous_key, byte_size, source_replica_id, copy_to, input_buffer, + output_buffer, DefaultCollectiveTimeout())); } } // namespace } // namespace runtime diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc new file mode 100644 index 00000000000000..0de37e2f6017e4 --- /dev/null +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -0,0 +1,439 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/in_process_collectives.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/primitive_util.h" +#include "xla/refcounting_hash_map.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace cpu { +namespace runtime { +namespace { + +void FormatGlobalId(std::string* out, const GlobalDeviceId& device) { + absl::StrAppend(out, device.value()); +} + +struct AllReduceParticipantData : ParticipantData { + explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, + int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + int64_t element_count; + const void* source_data; + void* destination_data; + PrimitiveType primitive_type; + + ReductionKind reduction_kind; + + std::string ToString() const override { + return absl::StrFormat( + "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, " + "rendezvous_key=%s}", + local_rank, element_count, PrimitiveType_Name(primitive_type), + rendezvous_key.ToString()); + } +}; + +template +T GetInitialValue(ReductionKind reduction_kind) { + switch (reduction_kind) { + case ReductionKind::SUM: + return static_cast(0); + case ReductionKind::PRODUCT: + return static_cast(1); + case ReductionKind::MIN: + return std::numeric_limits::max(); + case ReductionKind::MAX: + return std::numeric_limits::min(); + } +} + +template +void Reduce(absl::Span acc, absl::Span const> inputs) { + // TODO(penporn): make sure this gets vectorized. + if constexpr (reduction_kind == ReductionKind::SUM) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] += inputs[j][i]; + } + } + } else if constexpr (reduction_kind == ReductionKind::PRODUCT) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] *= inputs[j][i]; + } + } + } else if constexpr (reduction_kind == ReductionKind::MIN) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] = std::min(acc[i], inputs[j][i]); + } + } + } else if constexpr (reduction_kind == ReductionKind::MAX) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] = std::max(acc[i], inputs[j][i]); + } + } + } else { + static_assert(false, "Unsupported reduction kind"); + } +} + +class CpuAllReduceRendezvous + : public Rendezvous { + public: + explicit CpuAllReduceRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + absl::StatusOr RunCollectiveOp( + const AllReduceParticipantData& me) override { + VLOG(3) << me.ToString(); + int64_t world_size = participants_.size(); + // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th + // chunk of the output. 
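+    // For example, with element_count = 10 and world_size = 4,
+    // chunk_elems = CeilOfRatio(10, 4) = 3, so rank 0 handles elements
+    // [0, 3), rank 1 handles [3, 6), rank 2 handles [6, 9), and rank 3
+    // handles only [9, 10); trailing chunks may be short or even empty.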
+ int64_t chunk_elems = CeilOfRatio(me.element_count, world_size); + + int64_t start_elem = me.local_rank * chunk_elems; + int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count); + chunk_elems = std::max(int64_t{0}, end_elem - start_elem); + if (chunk_elems == 0) { + return nullptr; + } + + switch (me.primitive_type) { + case S8: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case S16: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case U16: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case S32: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case U32: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case S64: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case U64: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case F16: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case F32: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case F64: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case C64: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + case C128: + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + break; + default: + return absl::UnimplementedError("Unexpected datatype"); + } + + auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); + int64_t chunk_offset = start_elem * bytes_per_elem; + int64_t chunk_bytes = chunk_elems * bytes_per_elem; + for (const auto& p : participants_) { + if (p->local_rank != me.local_rank) { + std::memcpy( + reinterpret_cast(p->destination_data) + chunk_offset, + reinterpret_cast(me.destination_data) + chunk_offset, + chunk_bytes); + } + } + return nullptr; + } + + template + absl::Status DoAllReduce(const AllReduceParticipantData& me, + int64_t start_elem, int64_t num_elems) { + using T = typename primitive_util::PrimitiveTypeToNative::type; + T initial_value = GetInitialValue(me.reduction_kind); + T* acc = reinterpret_cast(me.destination_data); + for (int64_t i = start_elem; i < start_elem + num_elems; ++i) { + acc[i] = initial_value; + } + + absl::Span out_chunk = absl::MakeSpan( + reinterpret_cast(me.destination_data) + start_elem, num_elems); + std::vector> inputs; + inputs.reserve(participants_.size()); + for (const auto& p : participants_) { + inputs.push_back(absl::Span( + reinterpret_cast(p->source_data) + start_elem, num_elems)); + } + switch (me.reduction_kind) { + case ReductionKind::SUM: + Reduce(out_chunk, inputs); + break; + case ReductionKind::PRODUCT: + Reduce(out_chunk, inputs); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + Reduce(out_chunk, inputs); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + Reduce(out_chunk, inputs); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); + } +}; + +struct CollectivePermuteParticipantData : ParticipantData { + CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p, + int rank) + : ParticipantData(rendezvous_key_p, rank) {} + const void* source_buffer; + void* destination_buffer; + size_t num_bytes; + + // From 
which rank is this participant receiving its data? Optional; if + // absent fill with zeros. + std::optional source_rank; + + std::string ToString() const override { + return absl::StrFormat( + "CollectivePermuteParticipantData{rank=%d, " + "source_buffer=%p, destination_buffer=%p, num_bytes=%d, " + "source_replica_id=%d, " + "devices=[%s]}", + local_rank, source_buffer, destination_buffer, num_bytes, + source_rank.value_or(-1), + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId)); + } +}; + +class CpuCollectivePermuteRendezvous + : public Rendezvous { + public: + explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + + absl::StatusOr RunCollectiveOp( + const CollectivePermuteParticipantData& p) override { + VLOG(3) << p.ToString(); + if (p.source_rank) { + std::memcpy(p.destination_buffer, + participants_[*p.source_rank]->source_buffer, p.num_bytes); + } else { + std::memset(p.destination_buffer, 0, p.num_bytes); + } + return nullptr; + } +}; + +struct AllToAllParticipantData : ParticipantData { + AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + std::vector source_buffers; + std::vector destination_buffers; + size_t chunk_size; + + std::string ToString() const override { + auto addr_formatter = [](std::string* out, const void* mem) { + absl::StrAppend(out, absl::StrFormat("%p", mem)); + }; + return absl::StrFormat( + "AllToAllParticipantData{rank=%d, " + "devices=[%s], source_buffers=[%s], " + "destination_buffers=[%s], chunk_size=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + absl::StrJoin(source_buffers, ", ", addr_formatter), + absl::StrJoin(destination_buffers, ", ", addr_formatter), chunk_size); + } +}; + +class CpuAllToAllRendezvous + : public Rendezvous { + public: + explicit CpuAllToAllRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + absl::StatusOr RunCollectiveOp( + const AllToAllParticipantData& p) override { + int world_size = p.rendezvous_key.global_devices.size(); + for (int i = 0; i < world_size; ++i) { + std::memcpy(participants_[i]->destination_buffers[p.local_rank], + p.source_buffers[i], p.chunk_size); + } + return nullptr; + } +}; + +} // namespace + +struct InProcessCollectivesState { + RefcountingHashMap + all_reduce_rendezvous_map; + RefcountingHashMap + collective_permute_rendezvous_map; + RefcountingHashMap + all_to_all_rendezvous_map; +}; + +InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( + InProcessCollectivesState* state, int rank, int size) + : state_(state), rank_(rank) {} +InProcessCollectivesCommunicator::~InProcessCollectivesCommunicator() = default; + +absl::Status InProcessCollectivesCommunicator::AllReduce( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, + const void* const input_buffer, void* const output_buffer, + absl::Duration timeout) { + AllReduceParticipantData participant(key, rank_); + participant.element_count = num_elements; + participant.primitive_type = element_type; + participant.source_data = input_buffer; + participant.destination_data = output_buffer; + participant.reduction_kind = reduction_kind; + + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + + return CpuAllReduceRendezvous::SubmitParticipant( + [&] { + return 
state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCollectivesCommunicator::CollectivePermute( + const RendezvousKey& key, size_t num_bytes, std::optional source_rank, + absl::Span target_ranks, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + CollectivePermuteParticipantData participant(key, rank_); + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + participant.num_bytes = num_bytes; + participant.source_rank = source_rank; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuCollectivePermuteRendezvous::SubmitParticipant( + [&] { + return state_->collective_permute_rendezvous_map + .GetOrCreateIfAbsent(key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCollectivesCommunicator::AllToAll( + const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) { + AllToAllParticipantData participant(key, rank_); + TF_RET_CHECK(input_buffers.size() == output_buffers.size()); + participant.chunk_size = chunk_bytes; + participant.source_buffers.reserve(input_buffers.size()); + participant.destination_buffers.reserve(output_buffers.size()); + for (const void* input_buffer : input_buffers) { + participant.source_buffers.push_back(input_buffer); + } + for (void* output_buffer : output_buffers) { + participant.destination_buffers.push_back(output_buffer); + } + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuAllToAllRendezvous::SubmitParticipant( + [&] { + return state_->all_to_all_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +InProcessCollectives::InProcessCollectives() + : state_(std::make_unique()) {} +InProcessCollectives::~InProcessCollectives() = default; + +absl::StatusOr> +InProcessCollectives::GetCommunicator(absl::Span devices, + int rank) { + // We don't care about devices here: we share rendezvous state globally. + return std::make_shared(state_.get(), rank, + devices.size()); +} + +} // namespace runtime +} // namespace cpu +} // namespace xla diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h new file mode 100644 index 00000000000000..fb25fd3528d606 --- /dev/null +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -0,0 +1,78 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ +#define XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu::runtime { + +struct InProcessCollectivesState; + +class InProcessCollectivesCommunicator : public CollectivesCommunicator { + public: + InProcessCollectivesCommunicator(InProcessCollectivesState* state, int rank, + int size); + ~InProcessCollectivesCommunicator() override; + + absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, + std::optional source_rank, + absl::Span target_ranks, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, + absl::Duration timeout) override; + + private: + InProcessCollectivesState* state_; + int rank_; +}; + +class InProcessCollectives : public CollectivesInterface { + public: + InProcessCollectives(); + ~InProcessCollectives() override; + + // Thread-safe. + absl::StatusOr> GetCommunicator( + absl::Span devices, int rank) override; + + private: + std::unique_ptr state_; +}; + +} // namespace xla::cpu::runtime + +#endif // XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ From ee10e3d87270153265f78c061bcbdfc0a34f0201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Tue, 28 Nov 2023 09:27:21 -0800 Subject: [PATCH 146/381] [XLA:GPU][NFC] Split GemmRewriterTriton into 4 parts triton_support - basic Triton support checks triton_tiling_propagation - The code for propagating the tilings in a functional paradigm triton_fusion_analysis - FusionContext and TritonFusionAnalysis gemm_rewriter_triton - GemmRewriterTriton PiperOrigin-RevId: 586006558 --- third_party/xla/xla/service/gpu/BUILD | 103 +- .../xla/service/gpu/gemm_rewriter_triton.cc | 1638 +---------------- .../xla/service/gpu/gemm_rewriter_triton.h | 121 +- .../service/gpu/gemm_rewriter_triton_test.cc | 638 +------ .../xla/xla/service/gpu/ir_emitter_triton.cc | 3 +- .../xla/xla/service/gpu/ir_emitter_triton.h | 2 +- .../ir_emitter_triton_parametrized_test.cc | 2 +- .../xla/service/gpu/ir_emitter_triton_test.cc | 1 - .../service/gpu/softmax_rewriter_triton.cc | 2 +- .../xla/service/gpu/split_k_gemm_rewriter.cc | 4 +- .../service/gpu/split_k_gemm_rewriter_test.cc | 2 +- .../xla/service/gpu/triton_fusion_analysis.cc | 306 +++ .../xla/service/gpu/triton_fusion_analysis.h | 134 ++ .../gpu/triton_fusion_analysis_test.cc | 678 +++++++ .../xla/xla/service/gpu/triton_support.cc | 128 ++ .../xla/xla/service/gpu/triton_support.h | 51 + .../service/gpu/triton_tiling_propagation.cc | 1075 +++++++++++ .../service/gpu/triton_tiling_propagation.h | 241 +++ 18 files changed, 2752 insertions(+), 2377 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/triton_fusion_analysis.cc create mode 100644 third_party/xla/xla/service/gpu/triton_fusion_analysis.h create mode 100644 
third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc create mode 100644 third_party/xla/xla/service/gpu/triton_support.cc create mode 100644 third_party/xla/xla/service/gpu/triton_support.h create mode 100644 third_party/xla/xla/service/gpu/triton_tiling_propagation.cc create mode 100644 third_party/xla/xla/service/gpu/triton_tiling_propagation.h diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index dde94e3aee7bbb..83ec740c3b5861 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -446,12 +446,13 @@ cc_library( hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]), visibility = ["//visibility:public"], deps = [ - ":gemm_rewriter_triton", ":hlo_traversal", ":ir_emission_utils", ":launch_dimensions", ":matmul_utils", ":target_util", + ":triton_fusion_analysis", + ":triton_tiling_propagation", "//xla:autotuning_proto_cc", "//xla:comparison_util", "//xla:literal", @@ -533,7 +534,6 @@ xla_test( tags = ["nomac"], deps = [ ":backend_configs_cc", - ":gemm_rewriter_triton", ":gpu_device_info_for_tests", ":ir_emission_utils", ":ir_emitter_triton", @@ -607,7 +607,7 @@ xla_test( shard_count = 10, tags = ["nomac"], deps = [ - ":gemm_rewriter_triton", + ":triton_support", "//xla:comparison_util", "//xla:error_spec", "//xla:xla_data_proto_cc", @@ -633,7 +633,6 @@ cc_library( ":autotuner_util", ":backend_configs_cc", ":buffer_comparator", - ":gemm_rewriter_triton", ":gemm_rewriter", ":gpu_float_support", ":gpu_fusible", @@ -1276,6 +1275,84 @@ cc_library( ]), ) +cc_library( + name = "triton_support", + srcs = ["triton_support.cc"], + hdrs = ["triton_support.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( + name = "triton_tiling_propagation", + srcs = ["triton_tiling_propagation.cc"], + hdrs = ["triton_tiling_propagation.h"], + visibility = ["//visibility:public"], + deps = [ + ":triton_support", + "//xla:permutation_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:instruction_fusion", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "triton_fusion_analysis", + srcs = ["triton_fusion_analysis.cc"], + hdrs = ["triton_fusion_analysis.h"], + visibility = ["//visibility:public"], + deps = [ + ":matmul_utils", + ":triton_tiling_propagation", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:instruction_fusion", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "triton_fusion_analysis_test", + srcs = ["triton_fusion_analysis_test.cc"], + deps = [ + ":gemm_rewriter_triton", + ":triton_fusion_analysis", + "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + 
"//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "gemm_rewriter_triton", srcs = ["gemm_rewriter_triton.cc"], @@ -1286,27 +1363,24 @@ cc_library( ":cublas_padding_requirements", ":ir_emission_utils", ":matmul_utils", - "//xla:autotuning_proto_cc", - "//xla:permutation_util", + ":triton_fusion_analysis", + ":triton_support", + ":triton_tiling_propagation", "//xla:shape_util", "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_utils", @@ -1319,6 +1393,7 @@ xla_cc_test( deps = [ ":cublas_padding_requirements", ":gemm_rewriter_triton", + ":triton_fusion_analysis", "//xla:autotuning_proto_cc", "//xla:statusor", "//xla:xla_data_proto_cc", @@ -1343,9 +1418,11 @@ cc_library( hdrs = ["split_k_gemm_rewriter.h"], visibility = ["//visibility:public"], deps = [ - ":gemm_rewriter_triton", ":ir_emission_utils", ":matmul_utils", + ":triton_fusion_analysis", + ":triton_support", + ":triton_tiling_propagation", "//xla:autotuning_proto_cc", "//xla:literal_util", "//xla:shape_util", @@ -1372,9 +1449,9 @@ xla_cc_test( name = "split_k_gemm_rewriter_test", srcs = ["split_k_gemm_rewriter_test.cc"], deps = [ - ":gemm_rewriter_triton", ":matmul_utils", ":split_k_gemm_rewriter", + ":triton_fusion_analysis", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -1404,8 +1481,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", - ":gemm_rewriter_triton", ":ir_emission_utils", + ":triton_support", "//xla:shape_util", "//xla:status", "//xla:status_macros", diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 9413c839b16403..fa20343a435827 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -15,52 +15,42 @@ limitations under the License. 
#include "xla/service/gpu/gemm_rewriter_triton.h" -#include #include #include -#include #include #include #include #include -#include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/autotuning.pb.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/utils/hlo_query.h" -#include "xla/layout.h" -#include "xla/permutation_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/triton_support.h" +#include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/instruction_fusion.h" -#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" -#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" @@ -68,1284 +58,11 @@ limitations under the License. namespace xla { namespace gpu { -bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { - VLOG(9) << this->ToString(); - VLOG(9) << other.ToString(); - auto it_this = dim_iteration_specs_.cbegin(); - while (it_this != dim_iteration_specs_.cend()) { - auto it_other = other.dim_iteration_specs_.find(it_this->first); - if (it_other == other.dim_iteration_specs_.cend()) { - return false; - } - if (it_this->second.size() != it_other->second.size()) { - return false; - } - for (int fragment = 0; fragment < it_this->second.size(); ++fragment) { - if (it_this->second[fragment] != it_other->second[fragment]) { - return false; - } - } - ++it_this; - } - return true; -} - -bool IsDistributiveOverAddition(const HloInstruction& hlo) { - // The list is most likely incomplete. - // For example division can be added too but only for operand #0. - if (hlo.opcode() == HloOpcode::kMultiply || - hlo.opcode() == HloOpcode::kNegate || - hlo.opcode() == HloOpcode::kBitcast || - hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || - hlo.opcode() == HloOpcode::kTranspose || - hlo.opcode() == HloOpcode::kConvert || - hlo.opcode() == HloOpcode::kBroadcast || - hlo.opcode() == HloOpcode::kSlice) { - return true; - } - return false; -} - namespace { -FusionDecision IsConversionWorthFusing(const HloInstruction& input, - se::GpuComputeCapability gpu_version) { - // TODO(b/266862494): Can pick up almost any - // convert, but if it's reducing the data volume it should rather be fused - // to the output of the producer kernel. However not all operations support - // output fusion - then it should be fused here anyway! 
- if (ShapeUtil::ByteSizeOf(input.operand(0)->shape()) > - ShapeUtil::ByteSizeOf(input.shape())) { - return "Narrowing conversion."; - } - return FusionDecision{}; -} - -// Handles numbers of dimensions of an HLO instruction -// projected onto another one. -// Used to calculate cumulative index transformations done by non-elementwise -// instructions between source and target. -class DimensionOrder { - public: - // Softmax fusions have a fixed tiling scheme. These numbers are chosen to - // reflect that reductions in softmax fusions currently happen on the minor- - // most dimension (dimensions_minor(0)) and the rest (1+) is treated as a - // single non-tiled batch dimension. The numbers have to match those the - // emitter uses in the queries to the analysis. - static constexpr int kSoftmaxReductionDimension = 0; - static constexpr int kSoftmaxBatchDimension = 1; - - static DimensionOrder FromDotOperandOrOutput( - const HloInstruction& hlo, const int split_k_dimension_index = -1) { - DimensionOrder dim_order; - dim_order.tensor_fragments_order_.reserve(hlo.shape().rank()); - for (const int i : hlo.shape().layout().minor_to_major()) { - int target_dim_number = i; - if (i == split_k_dimension_index) { - CHECK(!dim_order.tensor_fragments_order_.empty()) - << "The split-K batch dimension has be preceded by the contracting " - "dimension it originates from by construction."; - target_dim_number = - dim_order.tensor_fragments_order_.back().dst_dim_number(); - } - dim_order.dim_fragments_orders_[target_dim_number].push_back( - dim_order.tensor_fragments_order_.size()); - dim_order.tensor_fragments_order_.push_back( - Fragment{target_dim_number, hlo.shape().dimensions(i)}); - } - return dim_order; - } - - static DimensionOrder FromSoftmaxRoot(const HloInstruction& hlo) { - DimensionOrder dim_order; - dim_order.tensor_fragments_order_.reserve(hlo.shape().rank()); - dim_order.dim_fragments_orders_[kSoftmaxReductionDimension].push_back( - dim_order.tensor_fragments_order_.size()); - dim_order.tensor_fragments_order_.push_back( - Fragment{kSoftmaxReductionDimension, hlo.shape().dimensions_minor(0)}); - for (int i = 1; i < hlo.shape().rank(); ++i) { - dim_order.dim_fragments_orders_[kSoftmaxBatchDimension].push_back( - dim_order.tensor_fragments_order_.size()); - dim_order.tensor_fragments_order_.push_back( - Fragment{kSoftmaxBatchDimension, hlo.shape().dimensions_minor(i)}); - } - return dim_order; - } - - // Description of a continuous fragment of one dimension of a tensor. - class Fragment { - public: - explicit Fragment(int dst_dim_number, int64_t size) - : dst_dim_number_(dst_dim_number), - size_(size), - slice_start_(0), - slice_limit_(size) {} - - std::string ToString() const { - return absl::StrCat(dst_dim_number_, ":", size_, ":", slice_start_, "-", - slice_limit_); - } - // Label carrying the dimension number of an defining operation. - int dst_dim_number() const { return dst_dim_number_; } - // Total number of elements in the fragment ignoring slicing. - int64_t full_size() const { return size_; } - // First used element. - int64_t slice_start() const { return slice_start_; } - // Last used element. 
- int64_t slice_limit() const { return slice_limit_; } - int64_t sliced_size() const { return slice_limit_ - slice_start_; } - bool is_sliced() const { return full_size() != sliced_size(); } - void set_slice(int64_t start, int64_t limit) { - slice_start_ = start; - slice_limit_ = limit; - } - void set_size(int64_t size) { size_ = size; } - - private: - const int dst_dim_number_; - int64_t size_; - int64_t slice_start_; - int64_t slice_limit_; - }; - using Fragments = std::vector; - using FragmentOrders = absl::flat_hash_map>; - - const Fragments& TensorFragmentsOrder() const { - return tensor_fragments_order_; - } - Fragments& TensorFragmentsOrder() { return tensor_fragments_order_; } - - const FragmentOrders& DimFragmentsOrders() const { - return dim_fragments_orders_; - } - FragmentOrders& DimFragmentsOrders() { return dim_fragments_orders_; } - - // Tells that two dimension orders describe the same tensor physical layout. - bool IsPhysicallyEquivalent(const DimensionOrder& other) const; - - std::string ToString() const { - std::string ret = absl::StrJoin(tensor_fragments_order_, " - ", - [](std::string* out, const Fragment& f) { - absl::StrAppend(out, f.ToString(), " "); - }); - absl::StrAppend(&ret, "|"); - for (const auto& [dim, fragments] : dim_fragments_orders_) { - absl::StrAppend(&ret, dim, ":", absl::StrJoin(fragments, ","), " "); - } - return ret; - } - - private: - // Sequence of all fragments of dimensions of tensor's shape - // in layout minor-to-major (physical) order. - Fragments tensor_fragments_order_; - // Iteration orders of fragments of each dimension of the defining operation - // (fragments can be physically unordered and disconnected within - // the shape due to reshapes and transposes). - FragmentOrders dim_fragments_orders_; -}; - -using DimIterationSpec = TensorIterationSpec::DimIterationSpec; -using Fragment = DimensionOrder::Fragment; -using Fragments = DimensionOrder::Fragments; -using FragmentOrders = DimensionOrder::FragmentOrders; -using DimOrderMap = absl::flat_hash_map; -using DimOrderMapOrError = std::variant; - -// This represents an invalid dimension index. -constexpr int kNoDimensionIndex = -1; -struct DotProperties { - const int noncontracting_dimension; - // Index of dot dimension that can be split. - // Currently typically LHS non-contracting one. - const int splittable_dimension_index; -}; -struct SoftmaxProperties { - const int softmax_reduction_dimension; - const int softmax_batch_dimension; -}; -// HeroProperties depend only on the hero op and they don't change as we -// change the fusion. -using HeroProperties = std::variant; - -// A special value for splittable_dimension_major_part_size. -constexpr int kNoSplitRequirement = 1; -struct DotRequirements { - explicit DotRequirements(int64_t splittable_dimension_major_part_size) - : splittable_dimension_major_part_size( - splittable_dimension_major_part_size) { - CHECK_GE(splittable_dimension_major_part_size, 1); - } - // If not kNoSplitRequirement, then the major part size of the splittable - // dimension must be the given value. - int64_t splittable_dimension_major_part_size; -}; -struct SoftmaxRequirements {}; -// Requirements can change depending on what we fuse. -using Requirements = std::variant; -using RequirementsOrError = std::variant; - -// The dimension orders and requirements resulting from propagating the -// dimension orders through an HLO. 
-struct DimOrdersAndReqs { - DimOrderMap dim_orders; - Requirements requirements; -}; -using DimOrdersAndReqsOrError = std::variant; - -using Int64OrError = std::variant; -Int64OrError CombineSplitDimMajorPartSizeReqs(int64_t a, int64_t b) { - if (a == b || b == kNoSplitRequirement) { - return a; - } - if (a == kNoSplitRequirement) { - return b; - } - return FusionDecision("Conflicting splits of splittable dimension"); -} - -RequirementsOrError CombineDotRequirements(DotRequirements a, - DotRequirements b) { - Int64OrError combined_size_req = - CombineSplitDimMajorPartSizeReqs(a.splittable_dimension_major_part_size, - b.splittable_dimension_major_part_size); - if (std::holds_alternative(combined_size_req)) { - return std::get(combined_size_req); - } - return DotRequirements(std::get(combined_size_req)); -} - -RequirementsOrError CombineSoftmaxRequirements(SoftmaxRequirements a, - SoftmaxRequirements b) { - // SoftmaxRequirements is an empty class for now. - return a; -} - -RequirementsOrError CombineRequirements(Requirements a, - RequirementsOrError b_or_error) { - if (std::holds_alternative(b_or_error)) { - return b_or_error; - } - const Requirements& b = std::get(b_or_error); - if (std::holds_alternative(b)) { - return CombineDotRequirements(std::get(a), - std::get(b)); - } - return CombineSoftmaxRequirements(std::get(a), - std::get(b)); -} - -TensorIterationSpec DimensionOrderToTensorIterationSpec( - const DimensionOrder& order) { - const Fragments& dim_fragments = order.TensorFragmentsOrder(); - TensorIterationSpec tensor_spec; - int64_t accumulated_stride = 1; - int last_dim = -1; - auto remove_last_fragment_if_degenerate = [&tensor_spec](const int dim_idx) { - if (dim_idx >= 0 && !tensor_spec[dim_idx].empty() && - tensor_spec[dim_idx].back().count == 1) { - tensor_spec[dim_idx].pop_back(); - } - }; - for (int dim_order_index = 0; dim_order_index < dim_fragments.size(); - ++dim_order_index) { - const DimensionOrder::Fragment& fragment = dim_fragments[dim_order_index]; - VLOG(6) << fragment.ToString(); - - DimIterationSpec& dim_spec = tensor_spec[fragment.dst_dim_number()]; - if (last_dim == fragment.dst_dim_number()) { - // Remove previous 1-sized subfragment if present. - if (!dim_spec.empty() && !dim_spec.back().subfragments.empty() && - dim_spec.back().subfragments.back() == 1) { - dim_spec.back().subfragments.pop_back(); - } - // Contiguous dimension, split only logically. Merge it back. - if (fragment.full_size() > 1) { - CHECK(!dim_spec.empty()); - CHECK(!dim_spec.back().is_sliced()) - << "Only the major-most fragment can have an offset."; - dim_spec.back().slice_start = - fragment.slice_start() * dim_spec.back().count; - dim_spec.back().slice_limit = - fragment.slice_limit() * dim_spec.back().count; - dim_spec.back().count *= fragment.full_size(); - dim_spec.back().subfragments.push_back(fragment.sliced_size()); - } - } else { - remove_last_fragment_if_degenerate(last_dim); - // Add part of the dimension. 
- dim_spec.push_back( - TensorIterationSpec::IterationSpecFragment{accumulated_stride, - fragment.full_size(), - fragment.slice_start(), - fragment.slice_limit(), - {fragment.sliced_size()}}); - } - - accumulated_stride *= fragment.full_size(); - last_dim = fragment.dst_dim_number(); - } - remove_last_fragment_if_degenerate(last_dim); - tensor_spec.RemoveEmptyDimensions(); - return tensor_spec; -} - -bool DimensionOrder::IsPhysicallyEquivalent(const DimensionOrder& other) const { - return DimensionOrderToTensorIterationSpec(*this) == - DimensionOrderToTensorIterationSpec(other); -} - -// Logical index of a dimension in `shape` labeled with `label` in the -// `dim_order` describing the shape. -std::optional LogicalIndexOfLabeledDimension( - const Shape& shape, const DimensionOrder& dim_order, const int label) { - auto fragment_it = dim_order.TensorFragmentsOrder().cbegin(); - for (int dim : shape.layout().minor_to_major()) { - const int64_t dim_size = shape.dimensions()[dim]; - int64_t fragments_size = 1; - while (fragments_size < dim_size) { - fragments_size *= fragment_it->full_size(); - if (fragment_it->dst_dim_number() == label) { - return dim; - } - ++fragment_it; - } - } - return std::nullopt; -} - -enum class TransformDirection { kInputToOutput, kOutputToInput }; - using OldToNewHloMap = absl::flat_hash_map; -class FusionContext { - FusionContext(HeroProperties properties, Requirements requirements) - : properties_(properties), requirements_(requirements) {} - - public: - // Create fusion context from a dot operand according to - // the currently supported configurations. - static FusionContext FromDotOperand(const HloInstruction& dot, - int operand_number, int split_k = 1); - - // Create fusion context from dot's output. - static FusionContext FromDotOutput( - const HloInstruction& dot, int split_k, - int64_t splittable_dimension_major_part_size); - - static FusionContext FromSoftmaxRoot(const HloInstruction&); - - // If possible, propagates `src_dim_order` (describing one side of `hlo`) to - // the other side and returns those dim orders. - static DimOrderMapOrError GetPropagatedDimOrders( - const HloInstruction& hlo, TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties); - - // If the dimension order is supported by the triton emitters, this returns - // which requirements does this order impose on the fusion. - // - // All subdimensions within a dimension have to be ordered. - static RequirementsOrError GetRequirementsIfSupportedOrder( - const DimensionOrder& order, const HeroProperties& properties); - // Apply GetRequirementsIfSupportedOrder() to all known - // dimension orders around `hlo` and combine the result. - static RequirementsOrError GetRequirementsIfSupportedOrders( - const HloInstruction& hlo, const DimOrderMap& dim_orders, - const HeroProperties& properties); - // If fusing the instruction is possible then it propagates - // the `src_dim_order` (describing one side of `hlo`) to the other side and - // returns those dim orders and the requirements that they impose on the - // fusion. 
- static DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements( - const HloInstruction& hlo, const DimensionOrder& src_dim_order, - TransformDirection direction, const HeroProperties& properties); - // If fusing the instruction is possible *and profitable* then it propagates - // the `src_dim_order` (describing one side of `hlo`) to the other side and - // returns those dim orders and the requirements that they impose on the - // fusion. - // - // `src_operand_index` must be set iff `transform_direction` is - // kInputToOutput. - static DimOrdersAndReqsOrError - GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - const HloInstruction& hlo, TransformDirection transform_direction, - const std::optional& src_operand_index, - const DimensionOrder& src_dim_order, - const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties); - - // Add dimension orders from `update` to `dim_orders_` and update - // `requirements_` if all of them are compatible. - bool CombineDimOrdersAndReqs(const DimOrdersAndReqs& update); - // Fuse an instruction with all its fusible inputs. - // If an input is not fusible stop there and make a parameter of the new - // fusion, otherwise put it onto stack and check its own inputs first. - void TryToFuseWithInputsRecursively( - HloInstruction& root, se::GpuComputeCapability gpu_version, - OldToNewHloMap& old_to_new_map, - std::vector& fusion_inputs, - HloComputation::Builder& builder); - // Propagate dimension orders in consumer->producer direction starting at - // `origin` with output `origin_dim_order` till parameters of the computation. - // Store the found parameters and their iteration specs. - Status PropagateDimensionOrdersToParameters( - const HloInstruction& origin, ConstHloInstructionSet& parameters, - ConstHloInstructionMap& iter_specs); - - int64_t splittable_dimension_major_part_size() const { - CHECK(std::holds_alternative(requirements_)); - return std::get(requirements_) - .splittable_dimension_major_part_size; - } - const HeroProperties& hero_properties() const { return properties_; } - const DimOrderMap& dim_orders() const { return dim_orders_; } - - private: - static DimOrderMap GetPropagatedDimOrdersForElementwise( - const HloInstruction& hlo, TransformDirection direction, - const DimensionOrder& src_dim_order); - static DimOrderMapOrError GetPropagatedDimOrdersForBitcast( - const HloInstruction& hlo, TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties); - static DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( - const HloInstruction& hlo, TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties); - - const HeroProperties properties_; - Requirements requirements_; - DimOrderMap dim_orders_; -}; - -FusionContext FusionContext::FromDotOperand(const HloInstruction& dot, - const int operand_number, - const int split_k) { - // There can be either none or one split-K batch dimension. - const int num_split_k_batch_dims = split_k > 1; - int split_k_dimension_index = kNoDimensionIndex; - if (split_k > 1) { - split_k_dimension_index = - ContractingDimensionIndex(dot, operand_number) - 1; - } - int splittable_dimension_index = kNoDimensionIndex; - // LHS non-contracting dimension can be split if non-splitK batch is absent. 
- if (operand_number == 0 && - dot.dot_dimension_numbers().lhs_batch_dimensions_size() - - num_split_k_batch_dims == - 0) { - splittable_dimension_index = - NonContractingDimensionIndex(dot, operand_number); - } - FusionContext context( - DotProperties{ - static_cast(NonContractingDimensionIndex(dot, operand_number)), - splittable_dimension_index}, - DotRequirements(kNoSplitRequirement)); - context.dim_orders_[dot.operand(operand_number)] = - DimensionOrder::FromDotOperandOrOutput(*dot.operand(operand_number), - split_k_dimension_index); - return context; -} - -FusionContext FusionContext::FromDotOutput( - const HloInstruction& dot, const int split_k, - const int64_t splittable_dimension_major_part_size) { - // Allow non-contracting dimension originating from LHS to split if - // this dimension is split at the output at the same ratio as - // at the input. - int splittable_dimension_index = kNoDimensionIndex; - if (splittable_dimension_major_part_size > 1) { - // Split-K dimension is the first one in the output if present; - // LHS non-contracting follows (batch is absent in this case). - splittable_dimension_index = (split_k > 1) ? 1 : 0; - } - FusionContext context(DotProperties{/*noncontracting_dimension=*/-1, - splittable_dimension_index}, - DotRequirements(splittable_dimension_major_part_size)); - context.dim_orders_[&dot] = DimensionOrder::FromDotOperandOrOutput(dot); - return context; -} - -FusionContext FusionContext::FromSoftmaxRoot(const HloInstruction& root) { - FusionContext context( - SoftmaxProperties{DimensionOrder::kSoftmaxReductionDimension, - DimensionOrder::kSoftmaxBatchDimension}, - SoftmaxRequirements{}); - context.dim_orders_[&root] = DimensionOrder::FromSoftmaxRoot(root); - return context; -} - -/*static*/ RequirementsOrError FusionContext::GetRequirementsIfSupportedOrder( - const DimensionOrder& order, const HeroProperties& properties) { - VLOG(8) << order.ToString(); - int64_t split_dim_major_part = kNoSplitRequirement; - const Fragments& tensor_dim_fragments = order.TensorFragmentsOrder(); - for (const auto& [dim_index, dim_fragments] : order.DimFragmentsOrders()) { - CHECK(!dim_fragments.empty()); - for (int i = 0; i < dim_fragments.size() - 1; ++i) { - if (tensor_dim_fragments[dim_fragments[i]].is_sliced()) { - return "Sliced non-major-most fragment."; - } - } - int group_counter = 0; - int last_seen_group_last_fragment_index = -1; - auto fragment_it = dim_fragments.cbegin(); - while (true) { - if (fragment_it == dim_fragments.cend()) { - break; - } - int64_t grouped_size = tensor_dim_fragments[*fragment_it].full_size(); - // Gather contiguous fragments: they have consecutive indices. - while ((fragment_it + 1) != dim_fragments.cend() && - *(fragment_it + 1) == *fragment_it + 1) { - ++fragment_it; - grouped_size *= tensor_dim_fragments[*fragment_it].full_size(); - } - // Ignore 1-sized groups of fragments. - if (grouped_size == 1) { - ++fragment_it; - continue; - } - - if (last_seen_group_last_fragment_index > *fragment_it) { - return "Transpose within a dimension."; - } - - ++group_counter; - if (group_counter > 1) { - if (!std::holds_alternative(properties)) { - return "Splitting a dimension is not supported for Softmax."; - } - // Only the dimension indicated by `splittable_dimension_index` (if any) - // can be split physically once by other dimensions. Other ones can be - // only split logically. 
- const int splittable_dimension_index = - std::get(properties).splittable_dimension_index; - if (dim_index == splittable_dimension_index) { - if (group_counter == 2) { - if (split_dim_major_part != kNoSplitRequirement && - split_dim_major_part != grouped_size) { - return "Conflicting splits of splittable dimension"; - } - split_dim_major_part = grouped_size; - } else if (group_counter > 2) { - return "2nd split of a splittable dimension."; - } - } else { - return "Unsupported split of a dimension."; - } - } - - last_seen_group_last_fragment_index = *fragment_it; - ++fragment_it; - } - } - - if (std::holds_alternative(properties)) { - return DotRequirements(split_dim_major_part); - } - return SoftmaxRequirements{}; -} - -/*static*/ RequirementsOrError FusionContext::GetRequirementsIfSupportedOrders( - const HloInstruction& hlo, const DimOrderMap& dim_orders, - const HeroProperties& properties) { - const Requirements empty_requirements = - std::holds_alternative(properties) - ? Requirements(DotRequirements(kNoSplitRequirement)) - : Requirements(SoftmaxRequirements{}); - auto get_requirements = - [&](const HloInstruction& instr) -> RequirementsOrError { - if (auto it = dim_orders.find(&instr); it != dim_orders.end()) { - return GetRequirementsIfSupportedOrder(it->second, properties); - } - return empty_requirements; - }; - - Requirements requirements = empty_requirements; - for (const HloInstruction* operand : hlo.operands()) { - RequirementsOrError requirements_or_error = - CombineRequirements(requirements, get_requirements(*operand)); - if (std::holds_alternative(requirements_or_error)) { - return requirements_or_error; - } - requirements = std::get(requirements_or_error); - } - - return CombineRequirements(requirements, get_requirements(hlo)); -} - -/*static*/ DimOrderMap FusionContext::GetPropagatedDimOrdersForElementwise( - const HloInstruction& hlo, TransformDirection direction, - const DimensionOrder& src_dim_order) { - if (direction == TransformDirection::kOutputToInput) { - DimOrderMap map; - for (const HloInstruction* operand : hlo.operands()) { - map.insert({operand, src_dim_order}); - } - return map; - } - - DimOrderMap map; - map.insert({&hlo, src_dim_order}); - // TODO(tdanyluk): For now, the "input to output" direction of this function - // also returns the dim orders for the operands, not just the output. This is - // needed to propagate the dim order of one input to the other(s) when fusing - // elementwise ops to the output. Perhaps we can separate the "input to - // output" and "output to input" directions of that in a later CL. 
- for (const HloInstruction* operand : hlo.operands()) { - map.insert({operand, src_dim_order}); - } - return map; -} - -const HloInstruction& GetSourceHlo(const HloInstruction& hlo, - TransformDirection direction) { - CHECK_GE(hlo.operand_count(), 1); - - if (direction == TransformDirection::kOutputToInput) { - return hlo; - } - return *hlo.operand(0); -} - -using ConstInstructionVector = absl::InlinedVector; -ConstInstructionVector GetDestHlos(const HloInstruction& hlo, - TransformDirection direction) { - if (direction == TransformDirection::kInputToOutput) { - return {&hlo}; - } - - ConstInstructionVector hlos; - hlos.reserve(hlo.operands().size()); - for (const HloInstruction* operand : hlo.operands()) { - hlos.push_back(operand); - } - return hlos; -} - -const HloInstruction& GetDestHlo(const HloInstruction& hlo, - TransformDirection direction) { - CHECK_EQ(hlo.operand_count(), 1); - - if (direction == TransformDirection::kInputToOutput) { - return hlo; - } - - return *hlo.operand(0); -} - -/*static*/ DimOrderMapOrError FusionContext::GetPropagatedDimOrdersForBitcast( - const HloInstruction& hlo, const TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties) { - const HloInstruction& dst = GetDestHlo(hlo, direction); - const Shape& dst_shape = dst.shape(); - const Fragments& src_fragments_order = src_dim_order.TensorFragmentsOrder(); - DimOrderMap dst_dim_orders; - DimensionOrder& dst_dim_order = - dst_dim_orders.insert({&dst, DimensionOrder()}).first->second; - Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder(); - // Size of not yet assigned part of current target dimension. - int64_t dst_remaining_size = 1; - // Track destination fragments created from a source one. - absl::flat_hash_map> src_to_dst; - // Iterate in parallel over source dimension order and target dimensions - // in minor_to_major order. Find groups of dimensions of equal size - // and project the source dimension order onto the destination. - auto dst_dim_it = dst_shape.layout().minor_to_major().cbegin(); - const auto dst_dim_end = dst_shape.layout().minor_to_major().cend(); - for (auto src_dim = src_fragments_order.cbegin(); - src_dim != src_fragments_order.cend(); ++src_dim) { - auto add_new_fragment = [&](const Fragment& fragment) { - dst_fragments_order.push_back(fragment); - src_to_dst[&*src_dim].push_back(dst_fragments_order.size() - 1); - }; - if (std::holds_alternative(properties) && - src_dim->dst_dim_number() == - std::get(properties).softmax_batch_dimension) { - // Special handling for softmax batch dimension: allow arbitrary reshapes - // on it because it's guaranteed by the construction of the fusion to have - // no physical alterations like transposes. - // Find a continuous group of fragments corresponding to this dimension in - // the source and assign the corresponding size in fragments of the - // destination ignoring the source ones. 
- dst_remaining_size = src_dim->full_size(); - while (src_dim + 1 != src_fragments_order.cend() && - (src_dim + 1)->dst_dim_number() == src_dim->dst_dim_number()) { - ++src_dim; - dst_remaining_size *= src_dim->full_size(); - } - while (dst_remaining_size > 1) { - CHECK(dst_dim_it != dst_dim_end); - add_new_fragment(Fragment{src_dim->dst_dim_number(), - dst_shape.dimensions(*dst_dim_it)}); - dst_remaining_size /= dst_shape.dimensions(*dst_dim_it); - ++dst_dim_it; - } - continue; - } - if (dst_remaining_size >= src_dim->full_size()) { - if (dst_remaining_size % src_dim->full_size()) { - return "Unsupported bitcast"; - } - // Source dimension fragment completely fits into the destination one: - // just copy it as is. - add_new_fragment(*src_dim); - // Update the size of the remaining part of the destination that is - // carried over to next source dimensions. - dst_remaining_size /= src_dim->full_size(); - } else { - // Source is larger than destination. - // Assign further destination dimensions. - // Size of the not yet assigned part of the source dimension. - int64_t src_remaining_size = src_dim->full_size(); - // Handle dimension splits. - if (dst_remaining_size > 1) { - // If there is a remaining fragment of a previous destination dimension - // assign it first. - if (src_remaining_size % dst_remaining_size || (src_dim->is_sliced())) { - return "Unsupported bitcast"; - } - add_new_fragment( - Fragment{src_dim->dst_dim_number(), dst_remaining_size}); - // Update the size of the fragment remaining to assign. - src_remaining_size /= dst_remaining_size; - dst_remaining_size = 1; - } - while (src_remaining_size > 1) { - // Assign destination dimensions until the source remainder is covered. - CHECK(dst_dim_it != dst_dim_end); - int64_t dst_dim_size = dst_shape.dimensions(*dst_dim_it); - int64_t new_fragment_size = dst_dim_size; - if (dst_dim_size > src_remaining_size) { - // If adding the next destination dimension exceeds source fragment - // size assign the remainder of the source and carry over the - // remainder of the destination. - if (dst_dim_size % src_remaining_size) { - return "Unsupported bitcast"; - } - dst_remaining_size = dst_dim_size / src_remaining_size; - new_fragment_size = src_remaining_size; - } - if (src_dim->is_sliced()) { - return "Unsupported bitcast"; - } - add_new_fragment( - Fragment{src_dim->dst_dim_number(), new_fragment_size}); - src_remaining_size /= new_fragment_size; - ++dst_dim_it; - } - } - } - CHECK_EQ(dst_remaining_size, 1); - - // Handle remaining major dimensions of the destination. Call all degenerate - // ones subdimensions of the most-major non-degenerate one. Otherwise - // give up. 
- while (dst_dim_it != dst_dim_end) { - if (dst_shape.dimensions(*dst_dim_it) != 1) { - return "Unsupported bitcast"; - } - if (!dst_fragments_order.empty()) { - dst_fragments_order.push_back( - Fragment{dst_fragments_order.back().dst_dim_number(), 1}); - src_to_dst[&src_fragments_order.back()].push_back( - dst_fragments_order.size() - 1); - } - ++dst_dim_it; - } - - FragmentOrders& dst_dim_fragment_orders = dst_dim_order.DimFragmentsOrders(); - for (const auto& [dim_index, dim_sequence] : - src_dim_order.DimFragmentsOrders()) { - std::vector& dst = dst_dim_fragment_orders[dim_index]; - dst.reserve(dim_sequence.size()); - for (const int src : dim_sequence) { - std::copy(src_to_dst[&src_fragments_order[src]].cbegin(), - src_to_dst[&src_fragments_order[src]].cend(), - std::back_inserter(dst)); - } - } - - return dst_dim_orders; -} - -// Handle copy, transpose, broadcast or reduce. -// Common between them is that they alter the tensor dimensions or their order -// and the way to handle layouts. -/*static*/ DimOrderMapOrError -FusionContext::GetPropagatedDimOrdersForDimAlteringOp( - const HloInstruction& hlo, const TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties) { - // Temporary storage for new fragments local to this function. - // Please keep this as the first local variable of this function, with type - // std::list to make sure that all pointers to elements of this remain valid - // throughout the entire function. std::deque would also work but it is - // unnecessarily big for a typical size of 1. - std::list new_fragments; - - const HloInstruction& src = GetSourceHlo(hlo, direction); - // Note: copying instead of using a const reference because - // some operations (slice) will modify fragment properties in-place. - Fragments src_fragments_order = src_dim_order.TensorFragmentsOrder(); - if (hlo.opcode() == HloOpcode::kSlice && - ShapeUtil::IsEffectiveScalar(hlo.shape())) { - return FusionDecision("Slice to scalar is not implemented yet."); - } - // Every HLO dimension can correspond to a group of subdimensions in - // dim_order_. For the easier handling of permutations: group dim_order_ by - // dimension, apply permutations, then finally remove the grouping. - // Group subdimensions by iterating over them in the same order as over - // full dimensions and matching by total size. - std::vector> src_physical; - src_physical.reserve(src.shape().rank()); - auto src_fragment_it = src_fragments_order.begin(); - for (int64_t dim_index : src.shape().layout().minor_to_major()) { - const int64_t dim_size = src.shape().dimensions(dim_index); - int64_t subdim_size_accumulator = 1; - std::vector subdim_group; - do { - CHECK(src_fragment_it != src_fragments_order.end()); - subdim_size_accumulator *= src_fragment_it->full_size(); - subdim_group.push_back(&*src_fragment_it); - ++src_fragment_it; - } while (subdim_size_accumulator < dim_size); - CHECK_EQ(subdim_size_accumulator, dim_size); - src_physical.push_back(subdim_group); - } - - // Source physical -> source logical. - std::vector> src_logical; - src_logical.resize(src_physical.size()); - for (int i = 0; i < src_physical.size(); ++i) { - src_logical[src.shape().layout().minor_to_major(i)] = src_physical[i]; - } - - DimOrderMap dst_dim_orders; - for (const HloInstruction* dst : GetDestHlos(hlo, direction)) { - DimensionOrder& dst_dim_order = - dst_dim_orders.insert({dst, DimensionOrder()}).first->second; - // Source logical -> destination logical. 
- std::vector> dst_logical; - if (hlo.opcode() == HloOpcode::kTranspose) { - const auto* transpose = Cast(&hlo); - std::vector permutation(transpose->dimensions().cbegin(), - transpose->dimensions().cend()); - if (direction == TransformDirection::kInputToOutput) { - permutation = InversePermutation(permutation); - } - dst_logical.resize(permutation.size()); - for (int i = 0; i < permutation.size(); ++i) { - dst_logical[permutation[i]] = src_logical[i]; - } - } else if (hlo.opcode() == HloOpcode::kBroadcast) { - const auto* broadcast = Cast(&hlo); - dst_logical.resize(broadcast->dimensions().size()); - for (int i = 0; i < broadcast->dimensions().size(); ++i) { - dst_logical[i] = src_logical[broadcast->dimensions()[i]]; - } - } else if (hlo.opcode() == HloOpcode::kReduce) { - // Operand 1 (the neutral value) has to be a scalar. - if (dst != &hlo && hlo.operand_index(dst) == 1) { - continue; - } - const auto* reduce = Cast(&hlo); - dst_logical.resize(src_logical.size() + reduce->dimensions().size()); - - if (reduce->dimensions().size() != 1) { - return FusionDecision("Unsupported reduction."); - } else if (reduce->dimensions().front() != - reduce->operand(0)->shape().rank() - 1) { - return FusionDecision("Only row reductions are supported."); - } - for (int i = 0; i < dst_logical.size(); ++i) { - if (i == reduce->dimensions().front()) { - // This way to assign the reduction dimension will only work for - // softmax fusions with known patterns for now. Generally a reduction - // should create a new tiled dimension. - dst_logical[i] = {&new_fragments.emplace_back( - std::get(properties) - .softmax_reduction_dimension, - reduce->operand(0)->shape().dimensions(i))}; - } else { - dst_logical[i] = src_logical[i]; - } - } - } else if (hlo.opcode() == HloOpcode::kConcatenate) { - dst_logical.resize(src_logical.size()); - for (int i = 0; i < src_logical.size(); ++i) { - dst_logical[i] = src_logical[i]; - if (i == hlo.concatenate_dimension()) { - if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { - return FusionDecision("Unsupported concatenation."); - } - dst_logical[i][0]->set_size(dst->shape().dimensions(i)); - dst_logical[i][0]->set_slice(0, dst->shape().dimensions(i)); - } - } - } else if (hlo.opcode() == HloOpcode::kCopy) { - // Copy preserves the logical shape, just permutes the layout. - CHECK(ShapeUtil::SameDimensions(src.shape(), dst->shape())); - dst_logical = src_logical; - } else if (hlo.opcode() == HloOpcode::kPad) { - // Operand 1 (the padding value) has to be a scalar. - if (dst != &hlo && hlo.operand_index(dst) == 1) { - continue; - } - const auto* pad = Cast(&hlo); - dst_logical.resize(src_logical.size()); - for (int i = 0; i < src_logical.size(); ++i) { - // This only handles the padding added by - // PadDotOperandsIfNeededForSplitK, which sets only edge_padding_high. - const int padding = - pad->padding_config().dimensions(i).edge_padding_high(); - CHECK_EQ(pad->padding_config().dimensions(i).edge_padding_low(), 0); - CHECK_EQ(pad->padding_config().dimensions(i).interior_padding(), 0); - if (padding == 0) { - dst_logical[i] = src_logical[i]; - } else { - // This case is executed for the contracting dimension when we run the - // TritonFusionAnalysis after the padding and the split-k transform - // are applied. - const std::vector& fragments = src_logical[i]; - // We must have 2 fragments at this point. - CHECK_EQ(fragments.size(), 2); - // The dst_dim_numbers must be the same for the 2 fragments of the - // contracting dimension after applying split-k. 
- CHECK_EQ(fragments[0]->dst_dim_number(), - fragments[1]->dst_dim_number()); - - new_fragments.emplace_back( - fragments[0]->dst_dim_number(), - fragments[0]->full_size() * fragments[1]->full_size() - padding); - dst_logical[i] = {&new_fragments.back()}; - } - } - } else if (hlo.opcode() == HloOpcode::kSlice) { - const auto slice = Cast(&hlo); - dst_logical.resize(src_logical.size()); - for (int dim = 0; dim < src_logical.size(); ++dim) { - dst_logical[dim] = src_logical[dim]; - if (slice->slice_limits(dim) - slice->slice_starts(dim) != - dst->shape().dimensions(dim)) { - if (dst_logical[dim].size() > 1) { - return FusionDecision("Slicing of fragmented dimension."); - } - auto fragment = dst_logical[dim].front(); - fragment->set_size(dst->shape().dimensions(dim)); - // Slicing of an already sliced dimension means adding offsets. - fragment->set_slice( - fragment->slice_start() + slice->slice_starts(dim), - fragment->slice_start() + slice->slice_starts(dim) + - fragment->sliced_size()); - } - } - } else { - return FusionDecision("Function called on a wrong instruction."); - } - // Destination logical -> destination physical and ungroup subdimensions. - // Map original fragments to the resulting ones to derive their new - // logical ordering within each dimension. - absl::flat_hash_map src_to_dst; - Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder(); - FragmentOrders& dst_dim_fragments_order = - dst_dim_order.DimFragmentsOrders(); - // Remember which dimensions are present before a broadcast; - // skip cases when already present dimension is being expanded. - absl::flat_hash_set dim_numbers_present_in_dst; - for (const int64_t dim_idx : dst->shape().layout().minor_to_major()) { - for (const Fragment* subdim : dst_logical[dim_idx]) { - dst_fragments_order.push_back(*subdim); - src_to_dst[subdim] = dst_fragments_order.size() - 1; - dim_numbers_present_in_dst.insert(subdim->dst_dim_number()); - } - } - for (const auto& [dim_index, dim_sequence] : - src_dim_order.DimFragmentsOrders()) { - for (const int fragment_number : dim_sequence) { - const auto it = src_to_dst.find(&src_fragments_order[fragment_number]); - if (it == src_to_dst.cend()) { - if (hlo.opcode() == HloOpcode::kBroadcast && - src_fragments_order[fragment_number].full_size() > 1 && - dim_numbers_present_in_dst.contains(dim_index)) { - return FusionDecision("Unsupported broadcast"); - } - continue; - } - dst_dim_fragments_order[dim_index].push_back(it->second); - } - } - } - return dst_dim_orders; -} - -// Infers DimensionOrders of all unknown sides (output, operands) -// of `hlo` from the known ones. 
-/*static*/ DimOrderMapOrError FusionContext::GetPropagatedDimOrders( - const HloInstruction& hlo, const TransformDirection direction, - const DimensionOrder& src_dim_order, const HeroProperties& properties) { - VLOG(7) << "Analyzing " << hlo.ToString(); - if (hlo.opcode() != HloOpcode::kParameter && - direction == TransformDirection::kOutputToInput && - absl::c_any_of(hlo.users(), [](const HloInstruction* user) { - return user->opcode() == HloOpcode::kConcatenate; - })) { - return "No fusion into concatenations"; - } - if (hlo.opcode() == HloOpcode::kParameter || - hlo_query::IsScalarConstant(&hlo)) { - CHECK(direction == TransformDirection::kOutputToInput); - return DimOrderMap{}; - } else if (hlo.opcode() == HloOpcode::kTranspose || - hlo.opcode() == HloOpcode::kCopy) { - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); - } else if (hlo.opcode() == HloOpcode::kBroadcast) { - if (direction != TransformDirection::kOutputToInput) { - return "Unsupported broadcast direction."; - } - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); - } else if (hlo.opcode() == HloOpcode::kReduce) { - if (!std::holds_alternative(properties)) { - return "Reductions are not supported in GEMM fusions yet."; - } - if (direction != TransformDirection::kOutputToInput) { - return "Unsupported direction of reduction."; - } - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); - } else if (hlo.opcode() == HloOpcode::kPad) { - if (direction != TransformDirection::kOutputToInput) { - return "Unsupported pad direction."; - } - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); - } else if (hlo.operand_count() > 0 && - IsTritonSupportedElementwise( - hlo.opcode(), hlo.operand(0)->shape().element_type())) { - return GetPropagatedDimOrdersForElementwise(hlo, direction, src_dim_order); - } else if (hlo.opcode() == HloOpcode::kBitcast) { - return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, - properties); - } else if (hlo.opcode() == HloOpcode::kSlice) { - if (direction != TransformDirection::kOutputToInput) { - return "Unsupported slice direction."; - } - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); - } else if (hlo.opcode() == HloOpcode::kReshape) { - if (!ShapeUtil::ReshapeIsBitcast(hlo.operand(0)->shape(), hlo.shape())) { - return "Non-bitcast reshape."; - } - return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, - properties); - } else if (hlo.opcode() == HloOpcode::kConcatenate && - direction == TransformDirection::kOutputToInput) { - if (!std::holds_alternative(properties)) { - return "Concatenations for now are only supported in GEMM fusions."; - } - auto dim = LogicalIndexOfLabeledDimension( - hlo.shape(), src_dim_order, - std::get(properties).noncontracting_dimension); - if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { - return "Unsupported concatenation."; - } - if (absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) { - return operand->user_count() > 1; - })) { - return FusionDecision( - "Concatenation has to be the only user of its inputs."); - } - if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { - // In the current simple implementation of concatenation the size of - // each of its inputs along the concatenated dimension has to be - // divisible by the tile size used for this dimension. 
Concatenations - // with any operand not divisible by kMinConcatFragmentSize will not - // be fused; tiling configurations with tile size for this dimension - // larger than kMinConcatFragmentSize will not be emitted. - constexpr int kMinConcatFragmentSize = 128; - return operand->shape().dimensions(hlo.concatenate_dimension()) % - kMinConcatFragmentSize != - 0; - })) { - return FusionDecision( - "One or more operands of concatenation can not be perfectly tiled."); - } - return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, - properties); - } - return "Unimplemented instruction."; -} - -// Difference of input and output data volumes of an instruction. -int64_t InputMinusOutputBytes(const HloInstruction& hlo) { - CHECK(!hlo.shape().IsTuple()); - int64_t input_size = 0; - for (const HloInstruction* operand : hlo.operands()) { - CHECK(!operand->shape().IsTuple()); - input_size += ShapeUtil::ByteSizeOf(operand->shape()); - } - return input_size - ShapeUtil::ByteSizeOf(hlo.shape()); -} - -// Tells if an instruction has no user into which it could be fused. -// More cases should be added here. -bool CanNotBeFusedIntoAUser(const HloInstruction& hlo) { - return hlo.IsRoot() || (hlo.user_count() == 1 && hlo.users()[0]->IsRoot() && - hlo.users()[0]->opcode() == HloOpcode::kTuple); -} - -// Let input and output data volumes of a fusion grow by small amounts. -constexpr int kIoToleranceBytes = 1024; - -// Tells that fusing an instruction as an input is efficient. -bool IsInputWorthFusing(const HloInstruction& hlo) { - if (InputMinusOutputBytes(hlo) <= kIoToleranceBytes) { - return true; - } - if (hlo.user_count() > 1) { - return false; - } - if (hlo.opcode() == HloOpcode::kSlice && - hlo_query::AllOperandsAreParametersOrConstants(hlo)) { - return true; - } - return hlo_query::AllOperandsAreParametersOrConstantsWithSingleUser(hlo); -} - -// Tells that fusing an instruction as an output is efficient. 
-bool IsOutputWorthFusing(const HloInstruction& hlo) { - return CanNotBeFusedIntoAUser(hlo) || - InputMinusOutputBytes(hlo) >= -kIoToleranceBytes; -} - -/*static*/ DimOrdersAndReqsOrError -FusionContext::GetPropagatedDimOrdersAndRequirements( - const HloInstruction& hlo, const DimensionOrder& src_dim_order, - TransformDirection direction, const HeroProperties& properties) { - DimOrderMapOrError propagated_dim_orders_or_error = - GetPropagatedDimOrders(hlo, direction, src_dim_order, properties); - if (std::holds_alternative(propagated_dim_orders_or_error)) { - return std::get(propagated_dim_orders_or_error); - } - DimOrderMap propagated_dim_orders = - std::move(std::get(propagated_dim_orders_or_error)); - RequirementsOrError requirements_or_error = - GetRequirementsIfSupportedOrders(hlo, propagated_dim_orders, properties); - if (std::holds_alternative(requirements_or_error)) { - return std::get(requirements_or_error); - } - return DimOrdersAndReqs{propagated_dim_orders, - std::get(requirements_or_error)}; -} - -/*static*/ DimOrdersAndReqsOrError -FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - const HloInstruction& hlo, TransformDirection transform_direction, - const std::optional& src_operand_index, - const DimensionOrder& src_dim_order, - const se::GpuComputeCapability& gpu_version, - const HeroProperties& properties) { - CHECK_EQ(transform_direction == TransformDirection::kInputToOutput, - src_operand_index.has_value()); - - if (hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { - return "Unsupported instruction."; - } - if (hlo.opcode() == HloOpcode::kReduce) { - return "Reductions are not fused yet."; - } - if (hlo.opcode() == HloOpcode::kPad) { - return "Pads are not fused yet."; - } - for (const HloInstruction* operand : hlo.operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { - return "Unsupported input data type."; - } - } - if (!IsTritonSupportedDataType(hlo.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; - } - DimOrdersAndReqsOrError result_or_error = - GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order, - transform_direction, properties); - if (!std::holds_alternative(result_or_error)) { - return result_or_error; - } - DimOrdersAndReqs dim_orders_and_requirements = - std::move(std::get(result_or_error)); - int fusion_level = - hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level(); - if (!std::get(gpu_version) - .IsAtLeast(se::CudaComputeCapability::AMPERE)) { - fusion_level = std::min(fusion_level, 1); - } - if (transform_direction == TransformDirection::kOutputToInput) { - if (fusion_level < 2) { - if (hlo.opcode() == HloOpcode::kConvert) { - if (FusionDecision decision = IsConversionWorthFusing(hlo, gpu_version); - !decision) { - return decision; - } - } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { - return "Ignored elementwise operation"; - } - } else { - // Exception for binary elementwise operations: in most cases these are - // not trivial to fuse because they increase DRAM traffic but if one - // of the inputs is for example a broadcast that can be fused too it - // becomes worth fusing. Look ahead and analyze operands here. 
- bool accepted = false; - if (hlo.IsElementwise() && hlo.operand_count() == 2) { - for (const HloInstruction* operand : hlo.operands()) { - if (operand->opcode() == HloOpcode::kBroadcast && - (operand->operand(0)->opcode() == HloOpcode::kParameter || - operand->operand(0)->opcode() == HloOpcode::kConstant) && - std::holds_alternative( - GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *operand, TransformDirection::kOutputToInput, - /*src_operand_index=*/std::nullopt, - /*src_dim_order=*/ - dim_orders_and_requirements.dim_orders.at(operand), - gpu_version, properties))) { - accepted = true; - break; - } - } - } - if (!accepted && !IsInputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as input."; - } - } - } else { - if (fusion_level < 2) { - return "Skipping fusing outputs at low fusion levels."; - } - for (int i = 0; i < hlo.operand_count(); ++i) { - const HloInstruction* operand = hlo.operand(i); - // Skip source operand. - if (i == *src_operand_index) { - continue; - } - // Currently only broadcasts of scalar constants or parameters - // are accepted as other inputs of non-unary operations - // in the output fusion. - if (hlo_query::IsBroadcastOfScalarConstant(*operand) || - operand->opcode() == HloOpcode::kParameter) { - continue; - } - return "Has multiple inputs - not properly analyzed yet."; - } - if (!IsOutputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as output."; - } - } - return dim_orders_and_requirements; -} - // Gets the fused HLO corresponding to `hlo` or adds a new parameter if not // found. HloInstruction* GetFusedHloOrAddParameter( @@ -1419,31 +136,15 @@ int64_t NumAddedParameters(const HloInstruction& hlo) { return hlo.operand_count() - 1; } -bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { - // First check that all updates to insert are compatible to avoid - // incomplete merges. - for (const auto& [key, value] : update.dim_orders) { - auto it = dim_orders_.find(key); - if (it != dim_orders_.cend() && !it->second.IsPhysicallyEquivalent(value)) { - return false; - } - } - - RequirementsOrError requirements_or_error = - CombineRequirements(requirements_, update.requirements); - if (std::holds_alternative(requirements_or_error)) { - return false; - } - - requirements_ = std::move(std::get(requirements_or_error)); - dim_orders_.insert(update.dim_orders.begin(), update.dim_orders.end()); - return true; -} - -void FusionContext::TryToFuseWithInputsRecursively( - HloInstruction& root, const se::GpuComputeCapability gpu_version, - OldToNewHloMap& old_to_new_map, std::vector& fusion_inputs, - HloComputation::Builder& builder) { +// Fuse an instruction with all its fusible inputs. +// If an input is not fusible stop there and make a parameter of the new +// fusion, otherwise put it onto stack and check its own inputs first. +void TryToFuseWithInputsRecursively(HloInstruction& root, + se::GpuComputeCapability gpu_version, + triton_fusion::FusionContext& context, + OldToNewHloMap& old_to_new_map, + std::vector& fusion_inputs, + HloComputation::Builder& builder) { // Instructions at the fusion edge that can either get fused too or // become parameters of the fusion. Used to track the number of parameters. 
absl::flat_hash_set inputs; @@ -1469,13 +170,14 @@ void FusionContext::TryToFuseWithInputsRecursively( continue; } num_requeued = 0; - const DimOrdersAndReqsOrError result = + const triton_fusion::DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *hlo, TransformDirection::kOutputToInput, - /*src_operand_index=*/std::nullopt, dim_orders_.at(hlo), - gpu_version, properties_); - if (!std::holds_alternative(result) || - !CombineDimOrdersAndReqs(std::get(result))) { + *hlo, triton_fusion::TransformDirection::kOutputToInput, + /*src_operand_index=*/std::nullopt, context.dim_orders().at(hlo), + gpu_version, context.hero_properties()); + if (!std::holds_alternative(result) || + !context.CombineDimOrdersAndReqs( + std::get(result))) { continue; } if (hlo->opcode() != HloOpcode::kParameter) { @@ -1534,15 +236,15 @@ StatusOr FuseDot(HloInstruction& dot, // differently shaped tiles but may go through same HLO graph nodes. // Direct dot inputs have well defined dimension orders. - auto fuse_inputs = - [&](int operand_number, - OldToNewHloMap& old_to_new_map) -> StatusOr { + auto fuse_inputs = [&](int operand_number, OldToNewHloMap& old_to_new_map) + -> StatusOr { const int operand_count_before = fusion_inputs.size(); // Direct dot inputs have well defined dimension orders. - auto context = FusionContext::FromDotOperand(dot, operand_number); - context.TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number), - gpu_version, old_to_new_map, - fusion_inputs, builder); + auto context = + triton_fusion::FusionContext::FromDotOperand(dot, operand_number); + TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number), + gpu_version, context, old_to_new_map, + fusion_inputs, builder); const int new_parameters = fusion_inputs.size() - operand_count_before; TF_RET_CHECK(new_parameters <= TritonFusionAnalysis::kMaxParameterPerDotScope) @@ -1553,7 +255,7 @@ StatusOr FuseDot(HloInstruction& dot, // Original instruction -> fused one. Separate for each scope. OldToNewHloMap lhs_old_to_new_map; - TF_ASSIGN_OR_RETURN(const FusionContext lhs_context, + TF_ASSIGN_OR_RETURN(const triton_fusion::FusionContext lhs_context, fuse_inputs(0, lhs_old_to_new_map)); OldToNewHloMap rhs_old_to_new_map; @@ -1570,7 +272,7 @@ StatusOr FuseDot(HloInstruction& dot, // Fusion at dot's output. // These describe _outputs_ of corresponding HLOs. 
- auto context = FusionContext::FromDotOutput( + auto context = triton_fusion::FusionContext::FromDotOutput( dot, /*split_k=*/1, lhs_context.splittable_dimension_major_part_size()); HloInstruction* fusion_output = ˙ bool output_changed = true; @@ -1583,21 +285,22 @@ StatusOr FuseDot(HloInstruction& dot, if (!IsDistributiveOverAddition(*user)) { break; } - DimOrdersAndReqsOrError result = - FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *user, TransformDirection::kInputToOutput, + triton_fusion::DimOrdersAndReqsOrError result = + triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + *user, triton_fusion::TransformDirection::kInputToOutput, user->operand_index(fusion_output), context.dim_orders().at(fusion_output), gpu_version, context.hero_properties()); - if (!std::holds_alternative(result) || - !context.CombineDimOrdersAndReqs(std::get(result))) { + if (!std::holds_alternative(result) || + !context.CombineDimOrdersAndReqs( + std::get(result))) { break; } for (HloInstruction* operand : user->operands()) { if (!output_old_to_new_map.contains(operand)) { - context.TryToFuseWithInputsRecursively(*operand, gpu_version, - output_old_to_new_map, - fusion_inputs, builder); + TryToFuseWithInputsRecursively(*operand, gpu_version, context, + output_old_to_new_map, fusion_inputs, + builder); } } Fuse(*user, output_old_to_new_map, fusion_inputs, builder); @@ -1694,213 +397,8 @@ StatusOr RunOnComputation(HloComputation* computation, return visitor.changed(); } -Status FusionContext::PropagateDimensionOrdersToParameters( - const HloInstruction& origin, ConstHloInstructionSet& parameters, - ConstHloInstructionMap& iter_specs) { - absl::flat_hash_set visited; - std::queue to_process; - // Dimension orders describing outputs of corresponding instructions. - visited.insert(&origin); - to_process.push(&origin); - while (!to_process.empty()) { - const HloInstruction* hlo = to_process.front(); - to_process.pop(); - if (hlo->opcode() == HloOpcode::kParameter) { - // One parameter corresponds to one iteration spec in the results of the - // analysis. This describes well situations when a parameter has one or - // more elementwise users - they share the same tiling. Situations when - // one instruction is read differently by different users in the same - // scope of the dot are currently prevented during the fusion. - TF_RET_CHECK(parameters.insert(hlo).second); - VLOG(5) << hlo->ToString(); - } - DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( - *hlo, dim_orders_.at(hlo), TransformDirection::kOutputToInput, - properties_); - TF_RET_CHECK(std::holds_alternative(result)); - TF_RET_CHECK(CombineDimOrdersAndReqs(std::get(result))); - iter_specs[hlo] = DimensionOrderToTensorIterationSpec(dim_orders_.at(hlo)); - for (const HloInstruction* operand : hlo->operands()) { - if (!visited.insert(operand).second) { - continue; - } - if (operand->opcode() == HloOpcode::kDot) { - // Encountering the dot itself happens during the processing of the - // output fusion. The propagation should stop at it. - continue; - } - to_process.push(operand); - } - } - return OkStatus(); -} - } // namespace -// Data types that are supported by the Triton emitters. 
-bool IsTritonSupportedDataType(PrimitiveType type, - se::GpuComputeCapability gpu_version) { - auto cuda_compute_capability = - std::get(gpu_version); - switch (type) { - case PRED: - case S8: - case S16: - case S32: - case F16: - case F32: - return true; - case BF16: - return cuda_compute_capability.IsAtLeast( - stream_executor::CudaComputeCapability::AMPERE); - default: - return false; - } -} - -// BF16 is supported in a sense that all operations on it are implemented -// through F32 and converts have to be inserted into the HLO graph, but -// they can be missing during fusion. - -std::vector TritonSupportedUnaryElementwise( - PrimitiveType element_type) { - std::vector ret = {HloOpcode::kConvert}; - if (element_type == PrimitiveType::PRED) { - ret.push_back(HloOpcode::kNot); - return ret; - } - ret.push_back(HloOpcode::kAbs); - ret.push_back(HloOpcode::kNegate); - if (element_type == PrimitiveType::F32 || - element_type == PrimitiveType::BF16 || - element_type == PrimitiveType::F64) { - absl::c_copy(std::vector{HloOpcode::kCos, HloOpcode::kExp, - HloOpcode::kExpm1, HloOpcode::kLog, - HloOpcode::kLog1p, HloOpcode::kRsqrt, - HloOpcode::kSin, HloOpcode::kSqrt, - HloOpcode::kCbrt, HloOpcode::kTan, - HloOpcode::kTanh}, - std::back_inserter(ret)); - } - return ret; -} - -std::vector TritonSupportedBinaryElementwise( - PrimitiveType element_type) { - if (element_type == PrimitiveType::PRED) { - return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, - HloOpcode::kCompare}; - } - std::vector ret = {HloOpcode::kAdd, HloOpcode::kCompare, - HloOpcode::kMaximum, HloOpcode::kMinimum, - HloOpcode::kMultiply, HloOpcode::kSubtract}; - if (element_type == PrimitiveType::F32 || - element_type == PrimitiveType::BF16 || - element_type == PrimitiveType::F64) { - ret.push_back(HloOpcode::kAtan2); - ret.push_back(HloOpcode::kDivide); - ret.push_back(HloOpcode::kPower); - } - return ret; -} - -std::vector TritonSupportedTernaryElementwise( - PrimitiveType element_type) { - return {HloOpcode::kSelect}; -} - -bool IsTritonSupportedElementwise(HloOpcode opcode, - PrimitiveType element_type) { - return absl::c_linear_search(TritonSupportedUnaryElementwise(element_type), - opcode) || - absl::c_linear_search(TritonSupportedBinaryElementwise(element_type), - opcode) || - absl::c_linear_search(TritonSupportedTernaryElementwise(element_type), - opcode); -} - -StatusOr TritonFusionAnalysis::Execute( - const HloComputation& computation, const int split_k) { - VLOG(5) << computation.ToString(HloPrintOptions::ShortParsable()); - TritonFusionAnalysis analysis; - const HloInstruction* dot = - hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot); - if (dot != nullptr) { - TF_RETURN_IF_ERROR(analysis.ExecuteForDotFusion(*dot, split_k)); - } else { - TF_RETURN_IF_ERROR( - analysis.ExecuteForSoftmaxFusion(*computation.root_instruction())); - } - return analysis; -} - -Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( - const HloInstruction& root) { - auto context = FusionContext::FromSoftmaxRoot(root); - // Softmax fusion uses one tiled scope. 
- TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( - root, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); - iter_specs_[Scope::LHS] = {}; - iter_specs_[Scope::RHS] = {}; - return OkStatus(); -} - -Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, - const int split_k) { - int64_t lhs_nc_split_major_part_size = kNoSplitRequirement; - for (const Scope scope : {Scope::LHS, Scope::RHS}) { - const int operand_number = static_cast(scope); - auto context = FusionContext::FromDotOperand(dot, operand_number, split_k); - TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( - *dot.operand(operand_number), parameters_[scope], iter_specs_[scope])); - if (scope == Scope::LHS) { - lhs_nc_split_major_part_size = - context.splittable_dimension_major_part_size(); - } - } - - auto context = - FusionContext::FromDotOutput(dot, split_k, lhs_nc_split_major_part_size); - const HloInstruction* output = ˙ - // Currently supported is one fusion output and one path from dot to it. - // Propagate dimension order from dot to root. - while (!output->IsRoot()) { - TF_RET_CHECK(output->user_count() == 1); - const HloInstruction* input = output; - output = output->users()[0]; - DimOrdersAndReqsOrError result = - context.GetPropagatedDimOrdersAndRequirements( - *output, context.dim_orders().at(input), - TransformDirection::kInputToOutput, context.hero_properties()); - TF_RET_CHECK(std::holds_alternative(result)); - TF_RET_CHECK( - context.CombineDimOrdersAndReqs(std::get(result))); - } - TF_RET_CHECK(iter_specs_[Scope::OUTPUT] - .insert({output, DimensionOrderToTensorIterationSpec( - context.dim_orders().at(output))}) - .second); - if (output != &dot) { - // Propagate back to parameters of the output fusion. - TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( - *output, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); - } - return OkStatus(); -} - -const DimIterationSpec* TritonFusionAnalysis::IterSpec( - const TritonFusionAnalysis::Scope scope, const HloInstruction* hlo, - const int dimension) const { - auto hlo_spec = iter_specs_.at(scope).find(hlo); - if (hlo_spec != iter_specs_.at(scope).cend()) { - auto dim_spec = hlo_spec->second.Storage().find(dimension); - if (dim_spec != hlo_spec->second.Storage().cend()) { - return &dim_spec->second; - } - } - return nullptr; -} - FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, const se::GpuComputeCapability gpu_version) { if (dot.opcode() != HloOpcode::kDot || @@ -1992,67 +490,5 @@ StatusOr GemmRewriterTriton::Run( return changed; } -static std::string IterationSpecByInstructionMapToString( // NOLINT - const TritonFusionAnalysis::IterationSpecByInstructionMap& m) { - return absl::StrCat("IterSpec{", - absl::StrJoin(m, ", ", - [&](std::string* s, const auto& kv) { - absl::StrAppend(s, kv.first->name(), ": ", - kv.second.ToString()); - }), - "}"); -} - -static std::string ScopeToString(TritonFusionAnalysis::Scope s) { // NOLINT - switch (s) { - case TritonFusionAnalysis::Scope::LHS: - return "LHS"; - case TritonFusionAnalysis::Scope::RHS: - return "RHS"; - case TritonFusionAnalysis::Scope::OUTPUT: - return "OUTPUT"; - } -} - -std::string TensorIterationSpec::IterationSpecFragment::ToString() const { - return absl::StrCat("{stride=", stride, ", count=", count, - ", slice_start=", slice_start, ", subfragments=[", - absl::StrJoin(subfragments, ", "), "]}"); -} - -bool TensorIterationSpec::IterationSpecFragment::operator!=( - const IterationSpecFragment& other) const { - return stride 
!= other.stride || count != other.count || - slice_start != other.slice_start || slice_limit != other.slice_limit; -} - -std::string TensorIterationSpec::ToString() const { - return absl::StrCat( - "{", - absl::StrJoin(dim_iteration_specs_, ", ", - [&](std::string* s, const auto& kv) { - absl::StrAppend( - s, kv.first, ": ", "[", - absl::StrJoin(kv.second, ", ", - [&](std::string* ss, const auto& v) { - absl::StrAppend(ss, v.ToString()); - }), - "]"); - }), - "}"); -} - -std::string TritonFusionAnalysis::ToString() const { - return absl::StrCat( - "TritonFusionAnalysis{\n", - absl::StrJoin(iter_specs_, ",\n", - [&](std::string* s, const auto& kv) { - absl::StrAppend( - s, ScopeToString(kv.first), ": ", - IterationSpecByInstructionMapToString(kv.second)); - }), - "\n}"); -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h index a95b2f02f920ea..34a77bf905bbbe 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h @@ -15,47 +15,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GEMM_REWRITER_TRITON_H_ #define XLA_SERVICE_GPU_GEMM_REWRITER_TRITON_H_ -#include -#include -#include -#include +// This file contains the code for fusing dots and other operations into Triton +// GEMM fusions. -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "xla/autotuning.pb.h" -#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/instruction_fusion.h" -#include "xla/status.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" -#include "xla/xla_data.pb.h" namespace xla { namespace gpu { -// Tells if f(a+b) == f(a) + f(b). -bool IsDistributiveOverAddition(const HloInstruction& hlo); - -// Allowlist of unary elementwise operations supported by Triton GEMM codegen. -std::vector TritonSupportedUnaryElementwise(PrimitiveType); - -// Allowlist of binary elementwise operations supported by Triton GEMM codegen. -std::vector TritonSupportedBinaryElementwise(PrimitiveType); - -// Allowlist of ternary elementwise operations supported by Triton GEMM codegen. -std::vector TritonSupportedTernaryElementwise(PrimitiveType); - -// Data types that are supported by the Triton emitters. -bool IsTritonSupportedDataType(PrimitiveType, se::GpuComputeCapability); - -// Checks elementwise operation against all supported by Triton GEMM codegen. -bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType); - // Filters GEMMs which can be handled using Triton. FusionDecision CanTritonHandleGEMM(const HloInstruction&, se::GpuComputeCapability gpu_version); @@ -64,96 +37,6 @@ FusionDecision CanTritonHandleGEMM(const HloInstruction&, bool ShouldTritonHandleGEMM(HloInstruction&, se::GpuComputeCapability gpu_version); -class TensorIterationSpec { - public: - // Description of basic iteration: `count` elements separated by `stride`. - struct IterationSpecFragment { - int64_t stride; - int64_t count; - int64_t slice_start; - int64_t slice_limit; - // Logical subfragments when this iteration is composed - // of several HLO dimensions. 
- std::vector subfragments; - - bool is_sliced() const { return count != slice_limit - slice_start; } - bool operator!=(const IterationSpecFragment& other) const; - std::string ToString() const; - }; - // Description of complex iteration over a sequence of several strides. - // Describes a logically contiguous dimension of a tensor physically - // separated into multiple fragments by other dimensions. - using DimIterationSpec = std::vector; - - using StorageType = absl::flat_hash_map; - const DimIterationSpec& operator[](const int dimension) const { - return dim_iteration_specs_.at(dimension); - } - DimIterationSpec& operator[](const int dimension) { - return dim_iteration_specs_[dimension]; - } - const StorageType& Storage() const { return dim_iteration_specs_; } - void RemoveEmptyDimensions() { - absl::erase_if(dim_iteration_specs_, - [](const auto& it) { return it.second.empty(); }); - } - - // Compares physical layouts of tensors ignoring subfragments of dimensions. - bool operator==(const TensorIterationSpec& other) const; - - std::string ToString() const; - - private: - StorageType dim_iteration_specs_; -}; - -// Analysis of tensor iteration orders within tiled fusions. -class TritonFusionAnalysis { - Status ExecuteForDotFusion(const HloInstruction& dot, int split_k); - Status ExecuteForSoftmaxFusion(const HloInstruction& root); - - public: - // Execute the analysis of a fusion computation. - // `split_k` indicates whether this operation was converted to the split-K - // form and tells the analysis how to interpret the batch dimensions. - static StatusOr Execute( - const HloComputation& computation, int split_k = 1); - - // A scope is an HLO graph that can be tiled efficiently using same or - // compatible tile shapes on all operations. GEMM fusion has 3 scopes - // defined by left operand, right operand and output. - enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; - - using IterationSpecByInstructionMap = - ConstHloInstructionMap; - using IterationSpecByInstructionByScopeMap = - std::map; - - // Every parameter requires a separate piece of shared memory for asynchronous - // loads. Multiple parameters are approximately equivalent to multiple - // pipeline stages. - // Note: this has been tuned specifically for GEMMs, where pipelining with - // more than 4 stages has been shown to rarely be practical. This limitation - // is not necessarily applicable to other operations. - static constexpr int kMaxParameterPerDotScope = 4; - - // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. - const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, - const HloInstruction*, - int dimension) const; - // Parameter HLO instructions used in a scope of `dot`. - const ConstHloInstructionSet& ScopeParameters(const Scope scope) const { - return parameters_.at(scope); - } - - std::string ToString() const; - - private: - IterationSpecByInstructionByScopeMap iter_specs_; - // HLO computation parameters per scope. - std::map parameters_; -}; - // Rewrite compatible dot() calls into custom calls with fused computations // that target Triton-based matmul emitter. class GemmRewriterTriton : public HloModulePass { diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 8c17d7c63a9d69..537fea463d09c0 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/cublas_padding_requirements.h" +#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/statusor.h" @@ -188,643 +189,6 @@ ENTRY e { EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); } -using TritonDotAnalysisTest = HloTestBase; - -TEST_F(TritonDotAnalysisTest, NopBitcasts) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - param_0.1 = s8[48,4]{1,0} parameter(0) - bitcast.18 = s8[1,48,4]{2,1,0} bitcast(param_0.1) - bitcast.19 = s8[48,4]{1,0} bitcast(bitcast.18) - convert.4 = bf16[48,4]{1,0} convert(bitcast.19) - param_1.1 = bf16[4,3]{1,0} parameter(1) - ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[48,4]{1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - custom-call = bf16[48,3]{1,0} custom-call(p0, p1), - custom_call_target="__triton", - called_computations={triton_dot} - ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = module->entry_computation() - ->root_instruction() - ->operand(0) - ->called_computations()[0]; - const HloInstruction* p0 = dot_computation->parameter_instruction(0); - const HloInstruction* p1 = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), - p0); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), - p1); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), - ElementsAre(FieldsAre(/*stride=*/4, /*count=*/48, /*slice_start=*/0, - /*slice_limit=*/48, ElementsAre(48)))); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, /*slice_start=*/0, - /*slice_limit=*/4, ElementsAre(4)))); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), - ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, /*slice_start=*/0, - /*slice_limit=*/4, ElementsAre(4)))); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, /*slice_start=*/0, - /*slice_limit=*/3, ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, Merge) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - param_0.1 = s8[1,8,6,4]{3,2,1,0} parameter(0) - bitcast.18 = s8[48,4]{1,0} bitcast(param_0.1) - convert.4 = bf16[48,4]{1,0} convert(bitcast.18) - param_1.1 = bf16[4,3]{1,0} parameter(1) - ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[1,8,6,4]{3,2,1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - custom-call = bf16[48,3]{1,0} custom-call(p0, p1), - custom_call_target="__triton", - called_computations={triton_dot} - ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = module->entry_computation() - ->root_instruction() - ->operand(0) - ->called_computations()[0]; - const HloInstruction* p0 = 
dot_computation->parameter_instruction(0); - const HloInstruction* p1 = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), - p0); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), - p1); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), - ElementsAre(FieldsAre(/*stride=*/4, /*count=*/6 * 8, - /*slice_start=*/0, /*slice_limit=*/6 * 8, - /*subfragments=*/ElementsAre(6, 8)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), - ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, - /*slice_start=*/0, /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, Split) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - %parameter_1 = f32[24000,2]{1,0} parameter(1) - %convert.15 = f16[24000,2]{1,0} convert(%parameter_1) - %parameter_0 = f16[4]{0} parameter(0) - %bitcast.45 = f16[2,2]{1,0} bitcast(%parameter_0) - ROOT %dot.26 = f16[24000,2]{1,0} dot(%convert.15, %bitcast.45), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f16[4]{0} parameter(0) - p1 = f32[24000,2]{1,0} parameter(1) - ROOT r = f16[24000,2]{1,0} custom-call(p0, p1), - custom_call_target="__triton", - called_computations={triton_dot} -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const HloInstruction* p0 = dot_computation->parameter_instruction(0); - const HloInstruction* p1 = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), - p1); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), - p0); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p1, 0), - ElementsAre(FieldsAre(/*stride=*/2, /*count=*/24000, - /*slice_start=*/0, /*slice_limit=*/24000, - /*subfragments=*/ElementsAre(24000)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p1, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/2, - /*slice_start=*/0, /*slice_limit=*/2, - /*subfragments=*/ElementsAre(2)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p0, 0), - ElementsAre(FieldsAre(/*stride=*/2, /*count=*/2, - /*slice_start=*/0, /*slice_limit=*/2, - /*subfragments=*/ElementsAre(2)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p0, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/2, - /*slice_start=*/0, /*slice_limit=*/2, - /*subfragments=*/ElementsAre(2)))); -} - -TEST_F(TritonDotAnalysisTest, TransposeMerge) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - param_0.1 = s8[1,4,8,6]{3,2,1,0} parameter(0) - transpose.3 = s8[1,8,6,4]{3,2,1,0} transpose(param_0.1), dimensions={0,2,3,1} 
- bitcast.18 = s8[48,4]{1,0} bitcast(transpose.3) - convert.4 = bf16[48,4]{1,0} convert(bitcast.18) - param_1.1 = bf16[4,3]{1,0} parameter(1) - ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[1,4,8,6]{3,2,1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - custom-call = bf16[48,3]{1,0} custom-call(p0, p1), - custom_call_target="__triton", - called_computations={triton_dot} - ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = module->entry_computation() - ->root_instruction() - ->operand(0) - ->called_computations()[0]; - const HloInstruction* p0 = dot_computation->parameter_instruction(0); - const HloInstruction* p1 = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), - p0); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), - p1); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8 * 6, - /*slice_start=*/0, /*slice_limit=*/8 * 6, - /*subfragments=*/ElementsAre(6, 8)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), - ElementsAre(FieldsAre(/*stride=*/8 * 6, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), - ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, - /*slice_start=*/0, /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, CopyMerge) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - param_0.1 = s8[1,4,8,6]{3,2,1,0} parameter(0) - bitcast.99 = s8[1,8,6,4]{2,1,3,0} bitcast(param_0.1) - copy.3 = s8[1,8,6,4]{3,2,1,0} copy(bitcast.99) - bitcast.18 = s8[48,4]{1,0} bitcast(copy.3) - convert.4 = bf16[48,4]{1,0} convert(bitcast.18) - param_1.1 = bf16[4,3]{1,0} parameter(1) - ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = s8[1,4,8,6]{3,2,1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - custom-call = bf16[48,3]{1,0} custom-call(p0, p1), - custom_call_target="__triton", - called_computations={triton_dot} - ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = module->entry_computation() - ->root_instruction() - ->operand(0) - ->called_computations()[0]; - const HloInstruction* p0 = dot_computation->parameter_instruction(0); - const HloInstruction* p1 = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), - p0); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), - p1); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), - 
ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8 * 6, - /*slice_start=*/0, /*slice_limit=*/8 * 6, - /*subfragments=*/ElementsAre(6, 8)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), - ElementsAre(FieldsAre(/*stride=*/8 * 6, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), - ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, - /*slice_start=*/0, /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, TransposeMergeNCN) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - param_0.1 = bf16[3,4,8,1]{3,2,1,0} parameter(0) - transpose.3 = bf16[3,8,1,4]{3,2,1,0} transpose(param_0.1), dimensions={0,2,3,1} - bitcast.18 = bf16[24,4]{1,0} bitcast(transpose.3) - param_1.1 = bf16[4,3]{1,0} parameter(1) - ROOT dot = bf16[24,3]{1,0} dot(bitcast.18, param_1.1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = bf16[3,4,8,1]{3,2,1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - custom-call = bf16[24,3]{1,0} custom-call(p0, p1), - custom_call_target="__triton", called_computations={triton_dot} - ROOT bitcast.2 = bf16[3,8,1,3]{3,2,1,0} bitcast(custom-call) -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = module->entry_computation() - ->root_instruction() - ->operand(0) - ->called_computations()[0]; - const HloInstruction* p0 = dot_computation->parameter_instruction(0); - const HloInstruction* p1 = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), - p0); - EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), - p1); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8, - /*slice_start=*/0, /*slice_limit=*/8, - /*subfragments=*/ElementsAre(8)), - FieldsAre(/*stride=*/4 * 8, /*count=*/3, - /*slice_start=*/0, /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), - ElementsAre(FieldsAre(/*stride=*/8, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), - ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, - /*slice_start=*/0, /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, TransposeOutput) { - const std::string hlo_text = R"( -HloModule t - -triton_dot { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - dot = bf16[24,3]{1,0} dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - bc = bf16[12,2,3]{2,1,0} bitcast(dot) - ROOT t = bf16[3,12,2]{2,1,0} transpose(bc), dimensions={2,0,1} -} - -ENTRY e { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - 
ROOT r = bf16[3,12,2]{2,1,0} fusion(p0, p1), kind=kCustom, - calls=triton_dot -})"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - const HloComputation* dot_computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const HloInstruction* dot_output = dot_computation->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, dot_output, 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, /*slice_start=*/0, - /*slice_limit=*/24, - /*subfragments=*/ElementsAre(2, 12)))); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, dot_output, 1), - ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, /*slice_start=*/0, - /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, OutputParameterIsHandled) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -triton_dot { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - dot = bf16[24,3]{1,0} dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - p2 = f16[3,24]{1,0} parameter(2) - p2t = f16[24,3]{1,0} transpose(p2), dimensions={1,0} - p2tc = bf16[24,3]{1,0} convert(p2t) - ROOT r = bf16[24,3]{1,0} divide(p2tc, dot) -} - -ENTRY e { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[4,3]{1,0} parameter(1) - p2 = f16[3,24]{1,0} parameter(2) - ROOT r = bf16[24,3]{1,0} fusion(p0, p1, p2), kind=kCustom, - calls=triton_dot -})")); - const HloComputation* dot_computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const HloInstruction* output_param = - dot_computation->parameter_instruction(2); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ( - analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 0) - ->size(), - 1); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, /*slice_start=*/0, - /*slice_limit=*/24, - /*subfragments=*/ElementsAre(24)))); - EXPECT_EQ( - analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 1) - ->size(), - 1); - EXPECT_THAT( - *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 1), - ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, /*slice_start=*/0, - /*slice_limit=*/3, - /*subfragments=*/ElementsAre(3)))); -} - -TEST_F(TritonDotAnalysisTest, InputBroadcastFromScalarIsHandled) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -triton_dot { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[] parameter(1) - p1b = bf16[4,3] broadcast(p1) - ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[] parameter(1) - ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, - calls=triton_dot -})")); - const HloComputation* dot_computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const HloInstruction* scalar = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, scalar, 0), - nullptr); - 
EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, scalar, 1), - nullptr); -} - -TEST_F(TritonDotAnalysisTest, InputBroadcastFromVectorIsHandled) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -triton_dot { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[4] parameter(1) - p1b = bf16[4,3] broadcast(p1), dimensions={0} - ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = bf16[24,4]{1,0} parameter(0) - p1 = bf16[4] parameter(1) - ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, - calls=triton_dot -})")); - const HloComputation* dot_computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const HloInstruction* vector = dot_computation->parameter_instruction(1); - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_EQ( - analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, vector, 0)->size(), - 1); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, vector, 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, - /*slice_start=*/0, /*slice_limit=*/4, - /*subfragments=*/ElementsAre(4)))); -} - -TEST_F(TritonDotAnalysisTest, OutputBroadcastIsNotAccepted) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -ENTRY e { - p0 = f16[2,35] parameter(0) - p0c = bf16[2,35] convert(p0) - p1 = bf16[35,2] parameter(1) - dot = bf16[2,2] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT bc = bf16[2,2,100] broadcast(dot), dimensions={0,1} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kBroadcast); -} - -TEST_F(TritonDotAnalysisTest, DegenerateSplitFragmentIsHandled) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -triton_gemm_r { - Arg_0.1 = s8[30,913,8,21]{3,2,1,0} parameter(0) - bitcast.6 = s8[30,8,21,913]{2,1,3,0} bitcast(Arg_0.1) - copy.7 = s8[30,8,21,913]{3,2,1,0} copy(bitcast.6) - bitcast.8 = s8[5040,913]{1,0} bitcast(copy.7) - convert.9 = bf16[5040,913]{1,0} convert(bitcast.8) - bitcast.32 = bf16[58,913]{1,0} parameter(1) - dot.33 = bf16[5040,58]{1,0} dot(convert.9, bitcast.32), - lhs_contracting_dims={1}, rhs_contracting_dims={1} - bitcast.34 = bf16[30,8,21,58]{3,2,1,0} bitcast(dot.33) - copy.35 = bf16[30,8,21,58]{2,1,3,0} copy(bitcast.34) - ROOT bitcast.41 = bf16[30,1,58,8,21]{4,3,2,1,0} bitcast(copy.35) -} - -ENTRY e { - Arg_0.1 = s8[30,913,8,21]{3,2,1,0} parameter(0) - Arg_1.2 = bf16[58,913]{1,0} parameter(1) - ROOT r = bf16[30,1,58,8,21]{4,3,2,1,0} fusion(Arg_0.1, Arg_1.2), kind=kCustom, - calls=triton_gemm_r, - backend_config={kind: "__triton_gemm"} -})")); - const HloComputation* dot_computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*dot_computation)); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - dot_computation->root_instruction(), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8 * 21, - /*slice_start=*/0, /*slice_limit=*/8 * 21, - /*subfragments=*/ElementsAre(21, 8)), - FieldsAre(/*stride=*/8 * 21 * 58, /*count=*/30, - /*slice_start=*/0, /*slice_limit=*/30, - /*subfragments=*/ElementsAre(30)))); -} - -using 
TritonSoftmaxAnalysisTest = HloTestBase; - -TEST_F(TritonSoftmaxAnalysisTest, DegenerateBatchDimensionIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -max { - p1 = f32[] parameter(1) - p0 = f32[] parameter(0) - ROOT m = f32[] maximum(p0, p1) -} - -triton_softmax_computation { - p0 = f32[1,97]{1,0} parameter(0) - bitcast = f32[97]{0} bitcast(p0) - constant = f32[] constant(-inf) - reduce = f32[] reduce(bitcast, constant), dimensions={0}, to_apply=max - broadcast = f32[1,97]{1,0} broadcast(reduce), dimensions={} - ROOT subtract = f32[1,97]{1,0} subtract(p0, broadcast) -} - -ENTRY e { - p0 = f32[1,97]{1,0} parameter(0) - ROOT r = f32[1,97]{1,0} fusion(p0), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} -})")); - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/97, - /*slice_start=*/0, /*slice_limit=*/97, - /*subfragments=*/ElementsAre(97)))); - EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 1), - nullptr); -} - -TEST_F(TritonSoftmaxAnalysisTest, BroadcastIntoBatchDimensionIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -c { - p1 = f32[127]{0} parameter(0) - ROOT b = f32[125,127]{1,0} broadcast(p1), dimensions={1} -} - -ENTRY e { - p0 = f32[127]{0} parameter(0) - ROOT t = f32[125,127]{1,0} fusion(p0), kind=kCustom, calls=c -})")); - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127, - /*slice_start=*/0, /*slice_limit=*/127, - /*subfragments=*/ElementsAre(127)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 1), - ElementsAre(FieldsAre(/*stride=*/127, /*count=*/125, - /*slice_start=*/0, /*slice_limit=*/125, - /*subfragments=*/ElementsAre(125)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->parameter_instruction(0), 0), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127, - /*slice_start=*/0, /*slice_limit=*/127, - /*subfragments=*/ElementsAre(127)))); - EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->parameter_instruction(0), 1), - nullptr); -} - -TEST_F(TritonSoftmaxAnalysisTest, ReduceOfNonRowDimensionIsNotSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t -add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) -} - -triton_softmax_computation { - param_0 = f32[8,4,127]{2,1,0} parameter(0) - constant = f32[] constant(0) - ROOT reduce = f32[4,127]{1,0} reduce(param_0, constant), dimensions={0}, to_apply=add -} - -ENTRY main { - param_0 = f32[8,4,127]{2,1,0} parameter(0) - ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, - calls=triton_softmax_computation, - backend_config={"kind":"__triton_softmax"} -})")); - - const HloComputation* 
computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - const auto analysis = TritonFusionAnalysis::Execute(*computation); - EXPECT_FALSE(analysis.ok()); -} - TEST_F(GemmRewriterTritonTest, HandleDotIfCublasRequiresPadding) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 7f5d4ff2b412f9..93f511506eea37 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -94,13 +94,14 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/primitive_util.h" #include "xla/service/dump.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/target_util.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/status.h" diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.h b/third_party/xla/xla/service/gpu/ir_emitter_triton.h index dbb7b160f06504..8a6e233c17f02b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.h @@ -26,10 +26,10 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "triton/Dialect/Triton/IR/Dialect.h" diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index 10a4e8681eeffc..848c774aa94382 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/triton_support.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 890ddb3ece7381..e0990e5ba2b14b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -33,7 +33,6 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc index eb5c93fc03766e..039bc7baf371df 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc @@ -31,8 +31,8 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/layout_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/triton_support.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 345b1b6e3b18ce..2da6897e22734a 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -39,9 +39,11 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/literal_util.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/triton_support.h" +#include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index da131e5d7c6ad6..40b2bde030dd39 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/hlo_verifier.h" #include "xla/service/layout_assignment.h" #include "xla/service/pattern_matcher.h" diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc new file mode 100644 index 00000000000000..8af0e9529d49cf --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/triton_fusion_analysis.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_tiling_propagation.h" +#include "xla/service/instruction_fusion.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace gpu { + +namespace triton_fusion { + +/*static*/ FusionContext FusionContext::FromDotOperand( + const HloInstruction& dot, const int operand_number, const int split_k) { + // There can be either none or one split-K batch dimension. + const int num_split_k_batch_dims = split_k > 1; + int split_k_dimension_index = kNoDimensionIndex; + if (split_k > 1) { + split_k_dimension_index = + ContractingDimensionIndex(dot, operand_number) - 1; + } + int splittable_dimension_index = kNoDimensionIndex; + // LHS non-contracting dimension can be split if non-splitK batch is absent. + if (operand_number == 0 && + dot.dot_dimension_numbers().lhs_batch_dimensions_size() - + num_split_k_batch_dims == + 0) { + splittable_dimension_index = + NonContractingDimensionIndex(dot, operand_number); + } + FusionContext context( + DotProperties{ + static_cast(NonContractingDimensionIndex(dot, operand_number)), + splittable_dimension_index}, + DotRequirements(kNoSplitRequirement)); + context.dim_orders_[dot.operand(operand_number)] = + DimensionOrder::FromDotOperandOrOutput(*dot.operand(operand_number), + split_k_dimension_index); + return context; +} + +/*static*/ FusionContext FusionContext::FromDotOutput( + const HloInstruction& dot, const int split_k, + const int64_t splittable_dimension_major_part_size) { + // Allow non-contracting dimension originating from LHS to split if + // this dimension is split at the output at the same ratio as + // at the input. + int splittable_dimension_index = kNoDimensionIndex; + if (splittable_dimension_major_part_size > 1) { + // Split-K dimension is the first one in the output if present; + // LHS non-contracting follows (batch is absent in this case). + splittable_dimension_index = (split_k > 1) ? 1 : 0; + } + FusionContext context(DotProperties{/*noncontracting_dimension=*/-1, + splittable_dimension_index}, + DotRequirements(splittable_dimension_major_part_size)); + context.dim_orders_[&dot] = DimensionOrder::FromDotOperandOrOutput(dot); + return context; +} + +/*static*/ FusionContext FusionContext::FromSoftmaxRoot( + const HloInstruction& root) { + FusionContext context( + SoftmaxProperties{DimensionOrder::kSoftmaxReductionDimension, + DimensionOrder::kSoftmaxBatchDimension}, + SoftmaxRequirements{}); + context.dim_orders_[&root] = DimensionOrder::FromSoftmaxRoot(root); + return context; +} + +namespace { +// Tells how many new parameters does a fusion gain by fusing the operation as +// an input. +int64_t NumAddedParameters(const HloInstruction& hlo) { + // Non-scalar constant is equivalent to a parameter: one input, one output. 
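+  // For example: fusing a unary Convert or Broadcast nets zero extra
+  // parameters (its single input replaces the output being fused), fusing a
+  // binary Add nets one, and a non-scalar constant nets zero because it acts
+  // like a parameter itself.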
+ if (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape())) { + return 0; + } + // All other instructions add all own inputs and remove own single output. + return hlo.operand_count() - 1; +} +} // namespace + +bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { + // First check that all updates to insert are compatible to avoid + // incomplete merges. + for (const auto& [key, value] : update.dim_orders) { + auto it = dim_orders_.find(key); + if (it != dim_orders_.cend() && !it->second.IsPhysicallyEquivalent(value)) { + return false; + } + } + + RequirementsOrError requirements_or_error = + triton_fusion::CombineRequirements(requirements_, update.requirements); + if (std::holds_alternative(requirements_or_error)) { + return false; + } + + requirements_ = std::move(std::get(requirements_or_error)); + dim_orders_.insert(update.dim_orders.begin(), update.dim_orders.end()); + return true; +} + +Status FusionContext::PropagateDimensionOrdersToParameters( + const HloInstruction& origin, ConstHloInstructionSet& parameters, + ConstHloInstructionMap& iter_specs) { + absl::flat_hash_set visited; + std::queue to_process; + // Dimension orders describing outputs of corresponding instructions. + visited.insert(&origin); + to_process.push(&origin); + while (!to_process.empty()) { + const HloInstruction* hlo = to_process.front(); + to_process.pop(); + if (hlo->opcode() == HloOpcode::kParameter) { + // One parameter corresponds to one iteration spec in the results of the + // analysis. This describes well situations when a parameter has one or + // more elementwise users - they share the same tiling. Situations when + // one instruction is read differently by different users in the same + // scope of the dot are currently prevented during the fusion. + TF_RET_CHECK(parameters.insert(hlo).second); + VLOG(5) << hlo->ToString(); + } + DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( + *hlo, dim_orders_.at(hlo), TransformDirection::kOutputToInput, + properties_); + TF_RET_CHECK(std::holds_alternative(result)); + TF_RET_CHECK(CombineDimOrdersAndReqs(std::get(result))); + iter_specs[hlo] = dim_orders_.at(hlo).ToTensorIterationSpec(); + for (const HloInstruction* operand : hlo->operands()) { + if (!visited.insert(operand).second) { + continue; + } + if (operand->opcode() == HloOpcode::kDot) { + // Encountering the dot itself happens during the processing of the + // output fusion. The propagation should stop at it. + continue; + } + to_process.push(operand); + } + } + return OkStatus(); +} + +} // namespace triton_fusion + +StatusOr TritonFusionAnalysis::Execute( + const HloComputation& computation, const int split_k) { + VLOG(5) << computation.ToString(HloPrintOptions::ShortParsable()); + TritonFusionAnalysis analysis; + const HloInstruction* dot = + hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot); + if (dot != nullptr) { + TF_RETURN_IF_ERROR(analysis.ExecuteForDotFusion(*dot, split_k)); + } else { + TF_RETURN_IF_ERROR( + analysis.ExecuteForSoftmaxFusion(*computation.root_instruction())); + } + return analysis; +} + +Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( + const HloInstruction& root) { + auto context = triton_fusion::FusionContext::FromSoftmaxRoot(root); + // Softmax fusion uses one tiled scope. 
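+  // Everything reachable from the root is analyzed into Scope::OUTPUT; the
+  // LHS and RHS maps are reset to empty below so that IterSpec() queries for
+  // those scopes return nullptr instead of failing.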
+ TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( + root, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); + iter_specs_[Scope::LHS] = {}; + iter_specs_[Scope::RHS] = {}; + return OkStatus(); +} + +Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, + const int split_k) { + int64_t lhs_nc_split_major_part_size = triton_fusion::kNoSplitRequirement; + for (const Scope scope : {Scope::LHS, Scope::RHS}) { + const int operand_number = static_cast(scope); + auto context = triton_fusion::FusionContext::FromDotOperand( + dot, operand_number, split_k); + TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( + *dot.operand(operand_number), parameters_[scope], iter_specs_[scope])); + if (scope == Scope::LHS) { + lhs_nc_split_major_part_size = + context.splittable_dimension_major_part_size(); + } + } + + auto context = triton_fusion::FusionContext::FromDotOutput( + dot, split_k, lhs_nc_split_major_part_size); + const HloInstruction* output = ˙ + // Currently supported is one fusion output and one path from dot to it. + // Propagate dimension order from dot to root. + while (!output->IsRoot()) { + TF_RET_CHECK(output->user_count() == 1); + const HloInstruction* input = output; + output = output->users()[0]; + triton_fusion::DimOrdersAndReqsOrError result = + triton_fusion::GetPropagatedDimOrdersAndRequirements( + *output, context.dim_orders().at(input), + triton_fusion::TransformDirection::kInputToOutput, + context.hero_properties()); + TF_RET_CHECK( + std::holds_alternative(result)); + TF_RET_CHECK(context.CombineDimOrdersAndReqs( + std::get(result))); + } + TF_RET_CHECK( + iter_specs_[Scope::OUTPUT] + .insert( + {output, context.dim_orders().at(output).ToTensorIterationSpec()}) + .second); + if (output != &dot) { + // Propagate back to parameters of the output fusion. 
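+    // Instructions fused past the dot can pull their own operands into the
+    // fusion (e.g. the second input of a fused divide); those become
+    // Scope::OUTPUT parameters with their own iteration specs.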
+ TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( + *output, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); + } + return OkStatus(); +} + +const TensorIterationSpec::DimIterationSpec* TritonFusionAnalysis::IterSpec( + const TritonFusionAnalysis::Scope scope, const HloInstruction* hlo, + const int dimension) const { + auto hlo_spec = iter_specs_.at(scope).find(hlo); + if (hlo_spec != iter_specs_.at(scope).cend()) { + auto dim_spec = hlo_spec->second.Storage().find(dimension); + if (dim_spec != hlo_spec->second.Storage().cend()) { + return &dim_spec->second; + } + } + return nullptr; +} + +namespace { +std::string IterationSpecByInstructionMapToString( + const TritonFusionAnalysis::IterationSpecByInstructionMap& m) { + return absl::StrCat("IterSpec{", + absl::StrJoin(m, ", ", + [&](std::string* s, const auto& kv) { + absl::StrAppend(s, kv.first->name(), ": ", + kv.second.ToString()); + }), + "}"); +} + +std::string ScopeToString(TritonFusionAnalysis::Scope s) { + switch (s) { + case TritonFusionAnalysis::Scope::LHS: + return "LHS"; + case TritonFusionAnalysis::Scope::RHS: + return "RHS"; + case TritonFusionAnalysis::Scope::OUTPUT: + return "OUTPUT"; + } +} +} // namespace + +std::string TritonFusionAnalysis::ToString() const { + return absl::StrCat( + "TritonFusionAnalysis{\n", + absl::StrJoin(iter_specs_, ",\n", + [&](std::string* s, const auto& kv) { + absl::StrAppend( + s, ScopeToString(kv.first), ": ", + IterationSpecByInstructionMapToString(kv.second)); + }), + "\n}"); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.h b/third_party/xla/xla/service/gpu/triton_fusion_analysis.h new file mode 100644 index 00000000000000..4f8c225ab378d7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.h @@ -0,0 +1,134 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_TRITON_FUSION_ANALYSIS_H_ +#define XLA_SERVICE_GPU_TRITON_FUSION_ANALYSIS_H_ + +// This file contains TritonFusionAnalysis and FusionContext. + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/triton_tiling_propagation.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +// Analysis of tensor iteration orders within tiled fusions. +class TritonFusionAnalysis { + Status ExecuteForDotFusion(const HloInstruction& dot, int split_k); + Status ExecuteForSoftmaxFusion(const HloInstruction& root); + + public: + // Execute the analysis of a fusion computation. + // `split_k` indicates whether this operation was converted to the split-K + // form and tells the analysis how to interpret the batch dimensions. 
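+  //
+  // Illustrative usage sketch (hypothetical call site, not part of this
+  // patch):
+  //
+  //   TF_ASSIGN_OR_RETURN(const auto analysis,
+  //                       TritonFusionAnalysis::Execute(*fusion_computation));
+  //   const auto* spec = analysis.IterSpec(
+  //       TritonFusionAnalysis::Scope::LHS, some_parameter, /*dimension=*/0);
+  //   if (spec != nullptr) {
+  //     // Use the stride/count of the fragments describing this dimension.
+  //   }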
+ static StatusOr Execute( + const HloComputation& computation, int split_k = 1); + + // A scope is an HLO graph that can be tiled efficiently using same or + // compatible tile shapes on all operations. GEMM fusion has 3 scopes + // defined by left operand, right operand and output. + enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; + + using IterationSpecByInstructionMap = + ConstHloInstructionMap; + using IterationSpecByInstructionByScopeMap = + std::map; + + // Every parameter requires a separate piece of shared memory for asynchronous + // loads. Multiple parameters are approximately equivalent to multiple + // pipeline stages. + // Note: this has been tuned specifically for GEMMs, where pipelining with + // more than 4 stages has been shown to rarely be practical. This limitation + // is not necessarily applicable to other operations. + static constexpr int kMaxParameterPerDotScope = 4; + + // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. + const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, + const HloInstruction*, + int dimension) const; + // Parameter HLO instructions used in a scope of `dot`. + const ConstHloInstructionSet& ScopeParameters(const Scope scope) const { + return parameters_.at(scope); + } + + std::string ToString() const; + + private: + IterationSpecByInstructionByScopeMap iter_specs_; + // HLO computation parameters per scope. + std::map parameters_; +}; + +// The details of the Triton fusion / tiling propagation are in a separate +// namespace to avoid littering the xla::gpu namespace. +namespace triton_fusion { +class FusionContext { + FusionContext(HeroProperties properties, Requirements requirements) + : properties_(properties), requirements_(requirements) {} + + public: + // Create fusion context from a dot operand according to + // the currently supported configurations. + static FusionContext FromDotOperand(const HloInstruction& dot, + int operand_number, int split_k = 1); + + // Create fusion context from dot's output. + static FusionContext FromDotOutput( + const HloInstruction& dot, int split_k, + int64_t splittable_dimension_major_part_size); + + static FusionContext FromSoftmaxRoot(const HloInstruction&); + + // Add dimension orders from `update` to `dim_orders_` and update + // `requirements_` if all of them are compatible. + bool CombineDimOrdersAndReqs(const DimOrdersAndReqs& update); + + // Propagate dimension orders in consumer->producer direction starting at + // `origin` with output `origin_dim_order` till parameters of the + // computation. Store the found parameters and their iteration specs. 
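+  // For example (illustrative): starting from a dot operand, the walk visits
+  // producers such as converts and bitcasts until it reaches Parameter
+  // instructions, and records one iteration spec per reached parameter.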
+ Status PropagateDimensionOrdersToParameters( + const HloInstruction& origin, ConstHloInstructionSet& parameters, + ConstHloInstructionMap& iter_specs); + + int64_t splittable_dimension_major_part_size() const { + CHECK(std::holds_alternative(requirements_)); + return std::get(requirements_) + .splittable_dimension_major_part_size; + } + const HeroProperties& hero_properties() const { return properties_; } + const DimOrderMap& dim_orders() const { return dim_orders_; } + + private: + const HeroProperties properties_; + Requirements requirements_; + DimOrderMap dim_orders_; +}; + +} // namespace triton_fusion + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRITON_FUSION_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc new file mode 100644 index 00000000000000..d2f67580234f60 --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -0,0 +1,678 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_fusion_analysis.h" + +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/gemm_rewriter_triton.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; +using ::testing::FieldsAre; + +using TritonDotAnalysisTest = HloTestBase; + +TEST_F(TritonDotAnalysisTest, NopBitcasts) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + param_0.1 = s8[48,4]{1,0} parameter(0) + bitcast.18 = s8[1,48,4]{2,1,0} bitcast(param_0.1) + bitcast.19 = s8[48,4]{1,0} bitcast(bitcast.18) + convert.4 = bf16[48,4]{1,0} convert(bitcast.19) + param_1.1 = bf16[4,3]{1,0} parameter(1) + ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[48,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + custom-call = bf16[48,3]{1,0} custom-call(p0, p1), + custom_call_target="__triton", + called_computations={triton_dot} + ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = module->entry_computation() + ->root_instruction() + ->operand(0) + ->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + 
EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p0); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p1); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/4, /*count=*/48, /*slice_start=*/0, + /*slice_limit=*/48, ElementsAre(48)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, /*slice_start=*/0, + /*slice_limit=*/4, ElementsAre(4)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), + ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, /*slice_start=*/0, + /*slice_limit=*/4, ElementsAre(4)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, /*slice_start=*/0, + /*slice_limit=*/3, ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, Merge) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + param_0.1 = s8[1,8,6,4]{3,2,1,0} parameter(0) + bitcast.18 = s8[48,4]{1,0} bitcast(param_0.1) + convert.4 = bf16[48,4]{1,0} convert(bitcast.18) + param_1.1 = bf16[4,3]{1,0} parameter(1) + ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[1,8,6,4]{3,2,1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + custom-call = bf16[48,3]{1,0} custom-call(p0, p1), + custom_call_target="__triton", + called_computations={triton_dot} + ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = module->entry_computation() + ->root_instruction() + ->operand(0) + ->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p0); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p1); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/4, /*count=*/6 * 8, + /*slice_start=*/0, /*slice_limit=*/6 * 8, + /*subfragments=*/ElementsAre(6, 8)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), + ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, + /*slice_start=*/0, /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, Split) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + %parameter_1 = f32[24000,2]{1,0} parameter(1) + %convert.15 = f16[24000,2]{1,0} convert(%parameter_1) + %parameter_0 = f16[4]{0} parameter(0) + %bitcast.45 = f16[2,2]{1,0} bitcast(%parameter_0) + ROOT %dot.26 = f16[24000,2]{1,0} dot(%convert.15, %bitcast.45), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[4]{0} 
parameter(0) + p1 = f32[24000,2]{1,0} parameter(1) + ROOT r = f16[24000,2]{1,0} custom-call(p0, p1), + custom_call_target="__triton", + called_computations={triton_dot} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p1); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p0); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p1, 0), + ElementsAre(FieldsAre(/*stride=*/2, /*count=*/24000, + /*slice_start=*/0, /*slice_limit=*/24000, + /*subfragments=*/ElementsAre(24000)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/2, + /*slice_start=*/0, /*slice_limit=*/2, + /*subfragments=*/ElementsAre(2)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/2, /*count=*/2, + /*slice_start=*/0, /*slice_limit=*/2, + /*subfragments=*/ElementsAre(2)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/2, + /*slice_start=*/0, /*slice_limit=*/2, + /*subfragments=*/ElementsAre(2)))); +} + +TEST_F(TritonDotAnalysisTest, TransposeMerge) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + param_0.1 = s8[1,4,8,6]{3,2,1,0} parameter(0) + transpose.3 = s8[1,8,6,4]{3,2,1,0} transpose(param_0.1), dimensions={0,2,3,1} + bitcast.18 = s8[48,4]{1,0} bitcast(transpose.3) + convert.4 = bf16[48,4]{1,0} convert(bitcast.18) + param_1.1 = bf16[4,3]{1,0} parameter(1) + ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[1,4,8,6]{3,2,1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + custom-call = bf16[48,3]{1,0} custom-call(p0, p1), + custom_call_target="__triton", + called_computations={triton_dot} + ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = module->entry_computation() + ->root_instruction() + ->operand(0) + ->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p0); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p1); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8 * 6, + /*slice_start=*/0, /*slice_limit=*/8 * 6, + /*subfragments=*/ElementsAre(6, 8)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/8 * 6, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, 
p1, 0), + ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, + /*slice_start=*/0, /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, CopyMerge) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + param_0.1 = s8[1,4,8,6]{3,2,1,0} parameter(0) + bitcast.99 = s8[1,8,6,4]{2,1,3,0} bitcast(param_0.1) + copy.3 = s8[1,8,6,4]{3,2,1,0} copy(bitcast.99) + bitcast.18 = s8[48,4]{1,0} bitcast(copy.3) + convert.4 = bf16[48,4]{1,0} convert(bitcast.18) + param_1.1 = bf16[4,3]{1,0} parameter(1) + ROOT dot = bf16[48,3]{1,0} dot(convert.4, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[1,4,8,6]{3,2,1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + custom-call = bf16[48,3]{1,0} custom-call(p0, p1), + custom_call_target="__triton", + called_computations={triton_dot} + ROOT bitcast.2 = bf16[1,8,6,3]{3,2,1,0} bitcast(custom-call) +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = module->entry_computation() + ->root_instruction() + ->operand(0) + ->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p0); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p1); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8 * 6, + /*slice_start=*/0, /*slice_limit=*/8 * 6, + /*subfragments=*/ElementsAre(6, 8)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/8 * 6, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), + ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, + /*slice_start=*/0, /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, TransposeMergeNCN) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + param_0.1 = bf16[3,4,8,1]{3,2,1,0} parameter(0) + transpose.3 = bf16[3,8,1,4]{3,2,1,0} transpose(param_0.1), dimensions={0,2,3,1} + bitcast.18 = bf16[24,4]{1,0} bitcast(transpose.3) + param_1.1 = bf16[4,3]{1,0} parameter(1) + ROOT dot = bf16[24,3]{1,0} dot(bitcast.18, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[3,4,8,1]{3,2,1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + custom-call = bf16[24,3]{1,0} custom-call(p0, p1), + custom_call_target="__triton", called_computations={triton_dot} + ROOT bitcast.2 = bf16[3,8,1,3]{3,2,1,0} bitcast(custom-call) +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = module->entry_computation() + ->root_instruction() + ->operand(0) 
+ ->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p0); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p1); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8, + /*slice_start=*/0, /*slice_limit=*/8, + /*subfragments=*/ElementsAre(8)), + FieldsAre(/*stride=*/4 * 8, /*count=*/3, + /*slice_start=*/0, /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/8, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), + ElementsAre(FieldsAre(/*stride=*/3, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/3, + /*slice_start=*/0, /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, TransposeOutput) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + dot = bf16[24,3]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + bc = bf16[12,2,3]{2,1,0} bitcast(dot) + ROOT t = bf16[3,12,2]{2,1,0} transpose(bc), dimensions={2,0,1} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + ROOT r = bf16[3,12,2]{2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* dot_output = dot_computation->root_instruction(); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, dot_output, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, /*slice_start=*/0, + /*slice_limit=*/24, + /*subfragments=*/ElementsAre(2, 12)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, dot_output, 1), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, /*slice_start=*/0, + /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, OutputParameterIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + dot = bf16[24,3]{1,0} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = f16[3,24]{1,0} parameter(2) + p2t = f16[24,3]{1,0} transpose(p2), dimensions={1,0} + p2tc = bf16[24,3]{1,0} convert(p2t) + ROOT r = bf16[24,3]{1,0} divide(p2tc, dot) +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4,3]{1,0} parameter(1) + p2 = f16[3,24]{1,0} parameter(2) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1, p2), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + 
module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* output_param = + dot_computation->parameter_instruction(2); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ( + analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 0) + ->size(), + 1); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, /*slice_start=*/0, + /*slice_limit=*/24, + /*subfragments=*/ElementsAre(24)))); + EXPECT_EQ( + analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 1) + ->size(), + 1); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, output_param, 1), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/3, /*slice_start=*/0, + /*slice_limit=*/3, + /*subfragments=*/ElementsAre(3)))); +} + +TEST_F(TritonDotAnalysisTest, InputBroadcastFromScalarIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[] parameter(1) + p1b = bf16[4,3] broadcast(p1) + ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[] parameter(1) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* scalar = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, scalar, 0), + nullptr); + EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, scalar, 1), + nullptr); +} + +TEST_F(TritonDotAnalysisTest, InputBroadcastFromVectorIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +triton_dot { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4] parameter(1) + p1b = bf16[4,3] broadcast(p1), dimensions={0} + ROOT dot = bf16[24,3]{1,0} dot(p0, p1b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = bf16[24,4]{1,0} parameter(0) + p1 = bf16[4] parameter(1) + ROOT r = bf16[24,3]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* vector = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ( + analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, vector, 0)->size(), + 1); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, vector, 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/4, + /*slice_start=*/0, /*slice_limit=*/4, + /*subfragments=*/ElementsAre(4)))); +} + +TEST_F(TritonDotAnalysisTest, OutputBroadcastIsNotAccepted) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f16[2,35] parameter(0) + p0c = bf16[2,35] convert(p0) + p1 = bf16[35,2] parameter(1) + dot = bf16[2,2] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT bc = bf16[2,2,100] broadcast(dot), dimensions={0,1} +})")); + 
EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kBroadcast); +} + +TEST_F(TritonDotAnalysisTest, DegenerateSplitFragmentIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_gemm_r { + Arg_0.1 = s8[30,913,8,21]{3,2,1,0} parameter(0) + bitcast.6 = s8[30,8,21,913]{2,1,3,0} bitcast(Arg_0.1) + copy.7 = s8[30,8,21,913]{3,2,1,0} copy(bitcast.6) + bitcast.8 = s8[5040,913]{1,0} bitcast(copy.7) + convert.9 = bf16[5040,913]{1,0} convert(bitcast.8) + bitcast.32 = bf16[58,913]{1,0} parameter(1) + dot.33 = bf16[5040,58]{1,0} dot(convert.9, bitcast.32), + lhs_contracting_dims={1}, rhs_contracting_dims={1} + bitcast.34 = bf16[30,8,21,58]{3,2,1,0} bitcast(dot.33) + copy.35 = bf16[30,8,21,58]{2,1,3,0} copy(bitcast.34) + ROOT bitcast.41 = bf16[30,1,58,8,21]{4,3,2,1,0} bitcast(copy.35) +} + +ENTRY e { + Arg_0.1 = s8[30,913,8,21]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[58,913]{1,0} parameter(1) + ROOT r = bf16[30,1,58,8,21]{4,3,2,1,0} fusion(Arg_0.1, Arg_1.2), kind=kCustom, + calls=triton_gemm_r, + backend_config={kind: "__triton_gemm"} +})")); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + dot_computation->root_instruction(), 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/8 * 21, + /*slice_start=*/0, /*slice_limit=*/8 * 21, + /*subfragments=*/ElementsAre(21, 8)), + FieldsAre(/*stride=*/8 * 21 * 58, /*count=*/30, + /*slice_start=*/0, /*slice_limit=*/30, + /*subfragments=*/ElementsAre(30)))); +} + +using TritonSoftmaxAnalysisTest = HloTestBase; + +TEST_F(TritonSoftmaxAnalysisTest, DegenerateBatchDimensionIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +max { + p1 = f32[] parameter(1) + p0 = f32[] parameter(0) + ROOT m = f32[] maximum(p0, p1) +} + +triton_softmax_computation { + p0 = f32[1,97]{1,0} parameter(0) + bitcast = f32[97]{0} bitcast(p0) + constant = f32[] constant(-inf) + reduce = f32[] reduce(bitcast, constant), dimensions={0}, to_apply=max + broadcast = f32[1,97]{1,0} broadcast(reduce), dimensions={} + ROOT subtract = f32[1,97]{1,0} subtract(p0, broadcast) +} + +ENTRY e { + p0 = f32[1,97]{1,0} parameter(0) + ROOT r = f32[1,97]{1,0} fusion(p0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->root_instruction(), 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/97, + /*slice_start=*/0, /*slice_limit=*/97, + /*subfragments=*/ElementsAre(97)))); + EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->root_instruction(), 1), + nullptr); +} + +TEST_F(TritonSoftmaxAnalysisTest, BroadcastIntoBatchDimensionIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +c { + p1 = f32[127]{0} parameter(0) + ROOT b = f32[125,127]{1,0} broadcast(p1), dimensions={1} +} + +ENTRY e { + p0 = 
f32[127]{0} parameter(0) + ROOT t = f32[125,127]{1,0} fusion(p0), kind=kCustom, calls=c +})")); + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->root_instruction(), 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127, + /*slice_start=*/0, /*slice_limit=*/127, + /*subfragments=*/ElementsAre(127)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->root_instruction(), 1), + ElementsAre(FieldsAre(/*stride=*/127, /*count=*/125, + /*slice_start=*/0, /*slice_limit=*/125, + /*subfragments=*/ElementsAre(125)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->parameter_instruction(0), 0), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127, + /*slice_start=*/0, /*slice_limit=*/127, + /*subfragments=*/ElementsAre(127)))); + EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->parameter_instruction(0), 1), + nullptr); +} + +TEST_F(TritonSoftmaxAnalysisTest, ReduceOfNonRowDimensionIsNotSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_0 = f32[8,4,127]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[4,127]{1,0} reduce(param_0, constant), dimensions={0}, to_apply=add +} + +ENTRY main { + param_0 = f32[8,4,127]{2,1,0} parameter(0) + ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto analysis = TritonFusionAnalysis::Execute(*computation); + EXPECT_FALSE(analysis.ok()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc new file mode 100644 index 00000000000000..1f7845ab79bb49 --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_support.cc @@ -0,0 +1,128 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_support.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +bool IsDistributiveOverAddition(const HloInstruction& hlo) { + // The list is most likely incomplete. + // For example division can be added too but only for operand #0. 
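+  // Illustrative examples: negate(a + b) == negate(a) + negate(b) and
+  // transpose(a + b) == transpose(a) + transpose(b) hold elementwise, so those
+  // opcodes qualify; maximum does not distribute over addition and is absent.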
+ if (hlo.opcode() == HloOpcode::kMultiply || + hlo.opcode() == HloOpcode::kNegate || + hlo.opcode() == HloOpcode::kBitcast || + hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || + hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kConvert || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kSlice) { + return true; + } + return false; +} + +// Data types that are supported by the Triton emitters. +// +// BF16 is supported in a sense that all operations on it are implemented +// through F32 and converts have to be inserted into the HLO graph, but +// they can be missing during fusion. +bool IsTritonSupportedDataType(PrimitiveType type, + se::GpuComputeCapability gpu_version) { + auto cuda_compute_capability = + std::get(gpu_version); + switch (type) { + case PRED: + case S8: + case S16: + case S32: + case F16: + case F32: + return true; + case BF16: + return cuda_compute_capability.IsAtLeast( + stream_executor::CudaComputeCapability::AMPERE); + default: + return false; + } +} + +std::vector TritonSupportedUnaryElementwise( + PrimitiveType element_type) { + std::vector ret = {HloOpcode::kConvert}; + if (element_type == PrimitiveType::PRED) { + ret.push_back(HloOpcode::kNot); + return ret; + } + ret.push_back(HloOpcode::kAbs); + ret.push_back(HloOpcode::kNegate); + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F64) { + absl::c_copy(std::vector{HloOpcode::kCos, HloOpcode::kExp, + HloOpcode::kExpm1, HloOpcode::kLog, + HloOpcode::kLog1p, HloOpcode::kRsqrt, + HloOpcode::kSin, HloOpcode::kSqrt, + HloOpcode::kCbrt, HloOpcode::kTan, + HloOpcode::kTanh}, + std::back_inserter(ret)); + } + return ret; +} + +std::vector TritonSupportedBinaryElementwise( + PrimitiveType element_type) { + if (element_type == PrimitiveType::PRED) { + return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, + HloOpcode::kCompare}; + } + std::vector ret = {HloOpcode::kAdd, HloOpcode::kCompare, + HloOpcode::kMaximum, HloOpcode::kMinimum, + HloOpcode::kMultiply, HloOpcode::kSubtract}; + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F64) { + ret.push_back(HloOpcode::kAtan2); + ret.push_back(HloOpcode::kDivide); + ret.push_back(HloOpcode::kPower); + } + return ret; +} + +std::vector TritonSupportedTernaryElementwise( + PrimitiveType element_type) { + return {HloOpcode::kSelect}; +} + +bool IsTritonSupportedElementwise(HloOpcode opcode, + PrimitiveType element_type) { + return absl::c_linear_search(TritonSupportedUnaryElementwise(element_type), + opcode) || + absl::c_linear_search(TritonSupportedBinaryElementwise(element_type), + opcode) || + absl::c_linear_search(TritonSupportedTernaryElementwise(element_type), + opcode); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_support.h b/third_party/xla/xla/service/gpu/triton_support.h new file mode 100644 index 00000000000000..e95f196fdc6a5f --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_support.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_TRITON_SUPPORT_H_ +#define XLA_SERVICE_GPU_TRITON_SUPPORT_H_ + +// This file is the home of the basic Triton support checks which are used by +// multiple other components. + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" +namespace xla { +namespace gpu { + +// Tells if f(a+b) == f(a) + f(b). +bool IsDistributiveOverAddition(const HloInstruction& hlo); + +// Allowlist of unary elementwise operations supported by Triton GEMM codegen. +std::vector TritonSupportedUnaryElementwise(PrimitiveType); + +// Allowlist of binary elementwise operations supported by Triton GEMM codegen. +std::vector TritonSupportedBinaryElementwise(PrimitiveType); + +// Allowlist of ternary elementwise operations supported by Triton GEMM codegen. +std::vector TritonSupportedTernaryElementwise(PrimitiveType); + +// Data types that are supported by the Triton emitters. +bool IsTritonSupportedDataType(PrimitiveType, se::GpuComputeCapability); + +// Checks elementwise operation against all supported by Triton GEMM codegen. +bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRITON_SUPPORT_H_ diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc new file mode 100644 index 00000000000000..f64042649ec032 --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -0,0 +1,1075 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/triton_tiling_propagation.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/permutation_util.h" +#include "xla/service/gpu/triton_support.h" +#include "xla/service/instruction_fusion.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { + VLOG(9) << this->ToString(); + VLOG(9) << other.ToString(); + auto it_this = dim_iteration_specs_.cbegin(); + while (it_this != dim_iteration_specs_.cend()) { + auto it_other = other.dim_iteration_specs_.find(it_this->first); + if (it_other == other.dim_iteration_specs_.cend()) { + return false; + } + if (it_this->second.size() != it_other->second.size()) { + return false; + } + for (int fragment = 0; fragment < it_this->second.size(); ++fragment) { + if (it_this->second[fragment] != it_other->second[fragment]) { + return false; + } + } + ++it_this; + } + return true; +} + +std::string TensorIterationSpec::IterationSpecFragment::ToString() const { + return absl::StrCat("{stride=", stride, ", count=", count, + ", slice_start=", slice_start, ", subfragments=[", + absl::StrJoin(subfragments, ", "), "]}"); +} + +bool TensorIterationSpec::IterationSpecFragment::operator!=( + const IterationSpecFragment& other) const { + return stride != other.stride || count != other.count || + slice_start != other.slice_start || slice_limit != other.slice_limit; +} + +std::string TensorIterationSpec::ToString() const { + return absl::StrCat( + "{", + absl::StrJoin(dim_iteration_specs_, ", ", + [&](std::string* s, const auto& kv) { + absl::StrAppend( + s, kv.first, ": ", "[", + absl::StrJoin(kv.second, ", ", + [&](std::string* ss, const auto& v) { + absl::StrAppend(ss, v.ToString()); + }), + "]"); + }), + "}"); +} + +namespace triton_fusion { + +using Fragment = DimensionOrder::Fragment; +using Fragments = DimensionOrder::Fragments; +using FragmentOrders = DimensionOrder::FragmentOrders; + +/*static*/ DimensionOrder DimensionOrder::FromDotOperandOrOutput( + const HloInstruction& hlo, const int split_k_dimension_index) { + DimensionOrder dim_order; + dim_order.tensor_fragments_order_.reserve(hlo.shape().rank()); + for (const int i : hlo.shape().layout().minor_to_major()) { + int target_dim_number = i; + if (i == split_k_dimension_index) { + CHECK(!dim_order.tensor_fragments_order_.empty()) + << "The split-K batch dimension has be preceded by the contracting " + "dimension it originates from by construction."; + target_dim_number = + dim_order.tensor_fragments_order_.back().dst_dim_number(); + } + dim_order.dim_fragments_orders_[target_dim_number].push_back( + dim_order.tensor_fragments_order_.size()); + dim_order.tensor_fragments_order_.push_back( + Fragment{target_dim_number, hlo.shape().dimensions(i)}); + } + return dim_order; +} + +/*static*/ 
DimensionOrder DimensionOrder::FromSoftmaxRoot( + const HloInstruction& hlo) { + DimensionOrder dim_order; + dim_order.tensor_fragments_order_.reserve(hlo.shape().rank()); + dim_order.dim_fragments_orders_[kSoftmaxReductionDimension].push_back( + dim_order.tensor_fragments_order_.size()); + dim_order.tensor_fragments_order_.push_back( + Fragment{kSoftmaxReductionDimension, hlo.shape().dimensions_minor(0)}); + for (int i = 1; i < hlo.shape().rank(); ++i) { + dim_order.dim_fragments_orders_[kSoftmaxBatchDimension].push_back( + dim_order.tensor_fragments_order_.size()); + dim_order.tensor_fragments_order_.push_back( + Fragment{kSoftmaxBatchDimension, hlo.shape().dimensions_minor(i)}); + } + return dim_order; +} + +std::string DimensionOrder::Fragment::ToString() const { + return absl::StrCat(dst_dim_number_, ":", size_, ":", slice_start_, "-", + slice_limit_); +} + +std::string DimensionOrder::ToString() const { + std::string ret = absl::StrJoin(tensor_fragments_order_, " - ", + [](std::string* out, const Fragment& f) { + absl::StrAppend(out, f.ToString(), " "); + }); + absl::StrAppend(&ret, "|"); + for (const auto& [dim, fragments] : dim_fragments_orders_) { + absl::StrAppend(&ret, dim, ":", absl::StrJoin(fragments, ","), " "); + } + return ret; +} + +TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { + const Fragments& dim_fragments = TensorFragmentsOrder(); + TensorIterationSpec tensor_spec; + int64_t accumulated_stride = 1; + int last_dim = -1; + auto remove_last_fragment_if_degenerate = [&tensor_spec](const int dim_idx) { + if (dim_idx >= 0 && !tensor_spec[dim_idx].empty() && + tensor_spec[dim_idx].back().count == 1) { + tensor_spec[dim_idx].pop_back(); + } + }; + for (int dim_order_index = 0; dim_order_index < dim_fragments.size(); + ++dim_order_index) { + const DimensionOrder::Fragment& fragment = dim_fragments[dim_order_index]; + VLOG(6) << fragment.ToString(); + + TensorIterationSpec::DimIterationSpec& dim_spec = + tensor_spec[fragment.dst_dim_number()]; + if (last_dim == fragment.dst_dim_number()) { + // Remove previous 1-sized subfragment if present. + if (!dim_spec.empty() && !dim_spec.back().subfragments.empty() && + dim_spec.back().subfragments.back() == 1) { + dim_spec.back().subfragments.pop_back(); + } + // Contiguous dimension, split only logically. Merge it back. + if (fragment.full_size() > 1) { + CHECK(!dim_spec.empty()); + CHECK(!dim_spec.back().is_sliced()) + << "Only the major-most fragment can have an offset."; + dim_spec.back().slice_start = + fragment.slice_start() * dim_spec.back().count; + dim_spec.back().slice_limit = + fragment.slice_limit() * dim_spec.back().count; + dim_spec.back().count *= fragment.full_size(); + dim_spec.back().subfragments.push_back(fragment.sliced_size()); + } + } else { + remove_last_fragment_if_degenerate(last_dim); + // Add part of the dimension. + dim_spec.push_back( + TensorIterationSpec::IterationSpecFragment{accumulated_stride, + fragment.full_size(), + fragment.slice_start(), + fragment.slice_limit(), + {fragment.sliced_size()}}); + } + + accumulated_stride *= fragment.full_size(); + last_dim = fragment.dst_dim_number(); + } + remove_last_fragment_if_degenerate(last_dim); + tensor_spec.RemoveEmptyDimensions(); + return tensor_spec; +} + +namespace { +// Logical index of a dimension in `shape` labeled with `label` in the +// `dim_order` describing the shape. 
+std::optional LogicalIndexOfLabeledDimension( + const Shape& shape, const DimensionOrder& dim_order, const int label) { + auto fragment_it = dim_order.TensorFragmentsOrder().cbegin(); + for (int dim : shape.layout().minor_to_major()) { + const int64_t dim_size = shape.dimensions()[dim]; + int64_t fragments_size = 1; + while (fragments_size < dim_size) { + fragments_size *= fragment_it->full_size(); + if (fragment_it->dst_dim_number() == label) { + return dim; + } + ++fragment_it; + } + } + return std::nullopt; +} + +using Int64OrError = std::variant; +Int64OrError CombineSplitDimMajorPartSizeReqs(int64_t a, int64_t b) { + if (a == b || b == kNoSplitRequirement) { + return a; + } + if (a == kNoSplitRequirement) { + return b; + } + return FusionDecision("Conflicting splits of splittable dimension"); +} + +RequirementsOrError CombineDotRequirements(DotRequirements a, + DotRequirements b) { + Int64OrError combined_size_req = + CombineSplitDimMajorPartSizeReqs(a.splittable_dimension_major_part_size, + b.splittable_dimension_major_part_size); + if (std::holds_alternative(combined_size_req)) { + return std::get(combined_size_req); + } + return DotRequirements(std::get(combined_size_req)); +} + +RequirementsOrError CombineSoftmaxRequirements(SoftmaxRequirements a, + SoftmaxRequirements b) { + // SoftmaxRequirements is an empty class for now. + return a; +} +} // namespace + +RequirementsOrError CombineRequirements(Requirements a, + RequirementsOrError b_or_error) { + if (std::holds_alternative(b_or_error)) { + return b_or_error; + } + const Requirements& b = std::get(b_or_error); + if (std::holds_alternative(b)) { + return CombineDotRequirements(std::get(a), + std::get(b)); + } + return CombineSoftmaxRequirements(std::get(a), + std::get(b)); +} + +namespace { + +// If the dimension order is supported by the triton emitters, this returns +// which requirements does this order impose on the fusion. +// +// All subdimensions within a dimension have to be ordered. +RequirementsOrError GetRequirementsIfSupportedOrder( + const DimensionOrder& order, const HeroProperties& properties) { + VLOG(8) << order.ToString(); + int64_t split_dim_major_part = kNoSplitRequirement; + const Fragments& tensor_dim_fragments = order.TensorFragmentsOrder(); + for (const auto& [dim_index, dim_fragments] : order.DimFragmentsOrders()) { + CHECK(!dim_fragments.empty()); + for (int i = 0; i < dim_fragments.size() - 1; ++i) { + if (tensor_dim_fragments[dim_fragments[i]].is_sliced()) { + return "Sliced non-major-most fragment."; + } + } + int group_counter = 0; + int last_seen_group_last_fragment_index = -1; + auto fragment_it = dim_fragments.cbegin(); + while (true) { + if (fragment_it == dim_fragments.cend()) { + break; + } + int64_t grouped_size = tensor_dim_fragments[*fragment_it].full_size(); + // Gather contiguous fragments: they have consecutive indices. + while ((fragment_it + 1) != dim_fragments.cend() && + *(fragment_it + 1) == *fragment_it + 1) { + ++fragment_it; + grouped_size *= tensor_dim_fragments[*fragment_it].full_size(); + } + // Ignore 1-sized groups of fragments. 
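+      // (Illustrative) a group whose total size is 1, e.g. a lone degenerate
+      // fragment left over from a reshape, is skipped before group_counter is
+      // incremented and therefore never triggers the split checks below.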
+ if (grouped_size == 1) { + ++fragment_it; + continue; + } + + if (last_seen_group_last_fragment_index > *fragment_it) { + return "Transpose within a dimension."; + } + + ++group_counter; + if (group_counter > 1) { + if (!std::holds_alternative(properties)) { + return "Splitting a dimension is not supported for Softmax."; + } + // Only the dimension indicated by `splittable_dimension_index` (if any) + // can be split physically once by other dimensions. Other ones can be + // only split logically. + const int splittable_dimension_index = + std::get(properties).splittable_dimension_index; + if (dim_index == splittable_dimension_index) { + if (group_counter == 2) { + if (split_dim_major_part != kNoSplitRequirement && + split_dim_major_part != grouped_size) { + return "Conflicting splits of splittable dimension"; + } + split_dim_major_part = grouped_size; + } else if (group_counter > 2) { + return "2nd split of a splittable dimension."; + } + } else { + return "Unsupported split of a dimension."; + } + } + + last_seen_group_last_fragment_index = *fragment_it; + ++fragment_it; + } + } + + if (std::holds_alternative(properties)) { + return DotRequirements(split_dim_major_part); + } + return SoftmaxRequirements{}; +} + +// Apply GetRequirementsIfSupportedOrder() to all known +// dimension orders around `hlo` and combine the result. +RequirementsOrError GetRequirementsIfSupportedOrders( + const HloInstruction& hlo, const DimOrderMap& dim_orders, + const HeroProperties& properties) { + const Requirements empty_requirements = + std::holds_alternative(properties) + ? Requirements(DotRequirements(kNoSplitRequirement)) + : Requirements(SoftmaxRequirements{}); + auto get_requirements = + [&](const HloInstruction& instr) -> RequirementsOrError { + if (auto it = dim_orders.find(&instr); it != dim_orders.end()) { + return GetRequirementsIfSupportedOrder(it->second, properties); + } + return empty_requirements; + }; + + Requirements requirements = empty_requirements; + for (const HloInstruction* operand : hlo.operands()) { + RequirementsOrError requirements_or_error = + CombineRequirements(requirements, get_requirements(*operand)); + if (std::holds_alternative(requirements_or_error)) { + return requirements_or_error; + } + requirements = std::get(requirements_or_error); + } + + return CombineRequirements(requirements, get_requirements(hlo)); +} + +DimOrderMap GetPropagatedDimOrdersForElementwise( + const HloInstruction& hlo, TransformDirection direction, + const DimensionOrder& src_dim_order) { + if (direction == TransformDirection::kOutputToInput) { + DimOrderMap map; + for (const HloInstruction* operand : hlo.operands()) { + map.insert({operand, src_dim_order}); + } + return map; + } + + DimOrderMap map; + map.insert({&hlo, src_dim_order}); + // TODO(tdanyluk): For now, the "input to output" direction of this function + // also returns the dim orders for the operands, not just the output. This is + // needed to propagate the dim order of one input to the other(s) when fusing + // elementwise ops to the output. Perhaps we can separate the "input to + // output" and "output to input" directions of that in a later CL. 
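+  // (Illustrative) e.g. when add(p0, p1) is fused towards the output, the dim
+  // order computed for the add is assigned to the add itself above and, via
+  // the loop below, also to p0 and p1, so both inputs are tiled identically.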
+ for (const HloInstruction* operand : hlo.operands()) { + map.insert({operand, src_dim_order}); + } + return map; +} + +const HloInstruction& GetSourceHlo(const HloInstruction& hlo, + TransformDirection direction) { + CHECK_GE(hlo.operand_count(), 1); + + if (direction == TransformDirection::kOutputToInput) { + return hlo; + } + return *hlo.operand(0); +} + +using ConstInstructionVector = absl::InlinedVector; +ConstInstructionVector GetDestHlos(const HloInstruction& hlo, + TransformDirection direction) { + if (direction == TransformDirection::kInputToOutput) { + return {&hlo}; + } + + ConstInstructionVector hlos; + hlos.reserve(hlo.operands().size()); + for (const HloInstruction* operand : hlo.operands()) { + hlos.push_back(operand); + } + return hlos; +} + +const HloInstruction& GetDestHlo(const HloInstruction& hlo, + TransformDirection direction) { + CHECK_EQ(hlo.operand_count(), 1); + + if (direction == TransformDirection::kInputToOutput) { + return hlo; + } + + return *hlo.operand(0); +} + +DimOrderMapOrError GetPropagatedDimOrdersForBitcast( + const HloInstruction& hlo, const TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties) { + const HloInstruction& dst = GetDestHlo(hlo, direction); + const Shape& dst_shape = dst.shape(); + const Fragments& src_fragments_order = src_dim_order.TensorFragmentsOrder(); + DimOrderMap dst_dim_orders; + DimensionOrder& dst_dim_order = + dst_dim_orders.insert({&dst, DimensionOrder()}).first->second; + Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder(); + // Size of not yet assigned part of current target dimension. + int64_t dst_remaining_size = 1; + // Track destination fragments created from a source one. + absl::flat_hash_map> src_to_dst; + // Iterate in parallel over source dimension order and target dimensions + // in minor_to_major order. Find groups of dimensions of equal size + // and project the source dimension order onto the destination. + auto dst_dim_it = dst_shape.layout().minor_to_major().cbegin(); + const auto dst_dim_end = dst_shape.layout().minor_to_major().cend(); + for (auto src_dim = src_fragments_order.cbegin(); + src_dim != src_fragments_order.cend(); ++src_dim) { + auto add_new_fragment = [&](const Fragment& fragment) { + dst_fragments_order.push_back(fragment); + src_to_dst[&*src_dim].push_back(dst_fragments_order.size() - 1); + }; + if (std::holds_alternative(properties) && + src_dim->dst_dim_number() == + std::get(properties).softmax_batch_dimension) { + // Special handling for softmax batch dimension: allow arbitrary reshapes + // on it because it's guaranteed by the construction of the fusion to have + // no physical alterations like transposes. + // Find a continuous group of fragments corresponding to this dimension in + // the source and assign the corresponding size in fragments of the + // destination ignoring the source ones. 
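+      // (Illustrative) e.g. a bitcast reshaping batch dimensions
+      // [4,6,127] <-> [24,127] is accepted: the batch fragments are collapsed
+      // to their total size and re-split to match the destination dimensions.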
+ dst_remaining_size = src_dim->full_size(); + while (src_dim + 1 != src_fragments_order.cend() && + (src_dim + 1)->dst_dim_number() == src_dim->dst_dim_number()) { + ++src_dim; + dst_remaining_size *= src_dim->full_size(); + } + while (dst_remaining_size > 1) { + CHECK(dst_dim_it != dst_dim_end); + add_new_fragment(Fragment{src_dim->dst_dim_number(), + dst_shape.dimensions(*dst_dim_it)}); + dst_remaining_size /= dst_shape.dimensions(*dst_dim_it); + ++dst_dim_it; + } + continue; + } + if (dst_remaining_size >= src_dim->full_size()) { + if (dst_remaining_size % src_dim->full_size()) { + return "Unsupported bitcast"; + } + // Source dimension fragment completely fits into the destination one: + // just copy it as is. + add_new_fragment(*src_dim); + // Update the size of the remaining part of the destination that is + // carried over to next source dimensions. + dst_remaining_size /= src_dim->full_size(); + } else { + // Source is larger than destination. + // Assign further destination dimensions. + // Size of the not yet assigned part of the source dimension. + int64_t src_remaining_size = src_dim->full_size(); + // Handle dimension splits. + if (dst_remaining_size > 1) { + // If there is a remaining fragment of a previous destination dimension + // assign it first. + if (src_remaining_size % dst_remaining_size || (src_dim->is_sliced())) { + return "Unsupported bitcast"; + } + add_new_fragment( + Fragment{src_dim->dst_dim_number(), dst_remaining_size}); + // Update the size of the fragment remaining to assign. + src_remaining_size /= dst_remaining_size; + dst_remaining_size = 1; + } + while (src_remaining_size > 1) { + // Assign destination dimensions until the source remainder is covered. + CHECK(dst_dim_it != dst_dim_end); + int64_t dst_dim_size = dst_shape.dimensions(*dst_dim_it); + int64_t new_fragment_size = dst_dim_size; + if (dst_dim_size > src_remaining_size) { + // If adding the next destination dimension exceeds source fragment + // size assign the remainder of the source and carry over the + // remainder of the destination. + if (dst_dim_size % src_remaining_size) { + return "Unsupported bitcast"; + } + dst_remaining_size = dst_dim_size / src_remaining_size; + new_fragment_size = src_remaining_size; + } + if (src_dim->is_sliced()) { + return "Unsupported bitcast"; + } + add_new_fragment( + Fragment{src_dim->dst_dim_number(), new_fragment_size}); + src_remaining_size /= new_fragment_size; + ++dst_dim_it; + } + } + } + CHECK_EQ(dst_remaining_size, 1); + + // Handle remaining major dimensions of the destination. Call all degenerate + // ones subdimensions of the most-major non-degenerate one. Otherwise + // give up. 
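+  // (Illustrative) e.g. bitcasting [48,4] to [1,1,48,4]: the two major 1-sized
+  // destination dimensions are appended as extra size-1 fragments of the
+  // most-major non-degenerate dimension instead of being rejected.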
+ while (dst_dim_it != dst_dim_end) { + if (dst_shape.dimensions(*dst_dim_it) != 1) { + return "Unsupported bitcast"; + } + if (!dst_fragments_order.empty()) { + dst_fragments_order.push_back( + Fragment{dst_fragments_order.back().dst_dim_number(), 1}); + src_to_dst[&src_fragments_order.back()].push_back( + dst_fragments_order.size() - 1); + } + ++dst_dim_it; + } + + FragmentOrders& dst_dim_fragment_orders = dst_dim_order.DimFragmentsOrders(); + for (const auto& [dim_index, dim_sequence] : + src_dim_order.DimFragmentsOrders()) { + std::vector& dst = dst_dim_fragment_orders[dim_index]; + dst.reserve(dim_sequence.size()); + for (const int src : dim_sequence) { + std::copy(src_to_dst[&src_fragments_order[src]].cbegin(), + src_to_dst[&src_fragments_order[src]].cend(), + std::back_inserter(dst)); + } + } + + return dst_dim_orders; +} + +// Handle copy, transpose, broadcast or reduce. +// Common between them is that they alter the tensor dimensions or their order +// and the way to handle layouts. +DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( + const HloInstruction& hlo, const TransformDirection direction, + const DimensionOrder& src_dim_order, const HeroProperties& properties) { + // Temporary storage for new fragments local to this function. + // Please keep this as the first local variable of this function, with type + // std::list to make sure that all pointers to elements of this remain valid + // throughout the entire function. std::deque would also work but it is + // unnecessarily big for a typical size of 1. + std::list new_fragments; + + const HloInstruction& src = GetSourceHlo(hlo, direction); + // Note: copying instead of using a const reference because + // some operations (slice) will modify fragment properties in-place. + Fragments src_fragments_order = src_dim_order.TensorFragmentsOrder(); + if (hlo.opcode() == HloOpcode::kSlice && + ShapeUtil::IsEffectiveScalar(hlo.shape())) { + return FusionDecision("Slice to scalar is not implemented yet."); + } + // Every HLO dimension can correspond to a group of subdimensions in + // dim_order_. For the easier handling of permutations: group dim_order_ by + // dimension, apply permutations, then finally remove the grouping. + // Group subdimensions by iterating over them in the same order as over + // full dimensions and matching by total size. + std::vector> src_physical; + src_physical.reserve(src.shape().rank()); + auto src_fragment_it = src_fragments_order.begin(); + for (int64_t dim_index : src.shape().layout().minor_to_major()) { + const int64_t dim_size = src.shape().dimensions(dim_index); + int64_t subdim_size_accumulator = 1; + std::vector subdim_group; + do { + CHECK(src_fragment_it != src_fragments_order.end()); + subdim_size_accumulator *= src_fragment_it->full_size(); + subdim_group.push_back(&*src_fragment_it); + ++src_fragment_it; + } while (subdim_size_accumulator < dim_size); + CHECK_EQ(subdim_size_accumulator, dim_size); + src_physical.push_back(subdim_group); + } + + // Source physical -> source logical. + std::vector> src_logical; + src_logical.resize(src_physical.size()); + for (int i = 0; i < src_physical.size(); ++i) { + src_logical[src.shape().layout().minor_to_major(i)] = src_physical[i]; + } + + DimOrderMap dst_dim_orders; + for (const HloInstruction* dst : GetDestHlos(hlo, direction)) { + DimensionOrder& dst_dim_order = + dst_dim_orders.insert({dst, DimensionOrder()}).first->second; + // Source logical -> destination logical. 
+ std::vector> dst_logical; + if (hlo.opcode() == HloOpcode::kTranspose) { + const auto* transpose = Cast(&hlo); + std::vector permutation(transpose->dimensions().cbegin(), + transpose->dimensions().cend()); + if (direction == TransformDirection::kInputToOutput) { + permutation = InversePermutation(permutation); + } + dst_logical.resize(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + dst_logical[permutation[i]] = src_logical[i]; + } + } else if (hlo.opcode() == HloOpcode::kBroadcast) { + const auto* broadcast = Cast(&hlo); + dst_logical.resize(broadcast->dimensions().size()); + for (int i = 0; i < broadcast->dimensions().size(); ++i) { + dst_logical[i] = src_logical[broadcast->dimensions()[i]]; + } + } else if (hlo.opcode() == HloOpcode::kReduce) { + // Operand 1 (the neutral value) has to be a scalar. + if (dst != &hlo && hlo.operand_index(dst) == 1) { + continue; + } + const auto* reduce = Cast(&hlo); + dst_logical.resize(src_logical.size() + reduce->dimensions().size()); + + if (reduce->dimensions().size() != 1) { + return FusionDecision("Unsupported reduction."); + } else if (reduce->dimensions().front() != + reduce->operand(0)->shape().rank() - 1) { + return FusionDecision("Only row reductions are supported."); + } + for (int i = 0; i < dst_logical.size(); ++i) { + if (i == reduce->dimensions().front()) { + // This way to assign the reduction dimension will only work for + // softmax fusions with known patterns for now. Generally a reduction + // should create a new tiled dimension. + dst_logical[i] = {&new_fragments.emplace_back( + std::get(properties) + .softmax_reduction_dimension, + reduce->operand(0)->shape().dimensions(i))}; + } else { + dst_logical[i] = src_logical[i]; + } + } + } else if (hlo.opcode() == HloOpcode::kConcatenate) { + dst_logical.resize(src_logical.size()); + for (int i = 0; i < src_logical.size(); ++i) { + dst_logical[i] = src_logical[i]; + if (i == hlo.concatenate_dimension()) { + if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { + return FusionDecision("Unsupported concatenation."); + } + dst_logical[i][0]->set_size(dst->shape().dimensions(i)); + dst_logical[i][0]->set_slice(0, dst->shape().dimensions(i)); + } + } + } else if (hlo.opcode() == HloOpcode::kCopy) { + // Copy preserves the logical shape, just permutes the layout. + CHECK(ShapeUtil::SameDimensions(src.shape(), dst->shape())); + dst_logical = src_logical; + } else if (hlo.opcode() == HloOpcode::kPad) { + // Operand 1 (the padding value) has to be a scalar. + if (dst != &hlo && hlo.operand_index(dst) == 1) { + continue; + } + const auto* pad = Cast(&hlo); + dst_logical.resize(src_logical.size()); + for (int i = 0; i < src_logical.size(); ++i) { + // This only handles the padding added by + // PadDotOperandsIfNeededForSplitK, which sets only edge_padding_high. + const int padding = + pad->padding_config().dimensions(i).edge_padding_high(); + CHECK_EQ(pad->padding_config().dimensions(i).edge_padding_low(), 0); + CHECK_EQ(pad->padding_config().dimensions(i).interior_padding(), 0); + if (padding == 0) { + dst_logical[i] = src_logical[i]; + } else { + // This case is executed for the contracting dimension when we run the + // TritonFusionAnalysis after the padding and the split-k transform + // are applied. + const std::vector& fragments = src_logical[i]; + // We must have 2 fragments at this point. + CHECK_EQ(fragments.size(), 2); + // The dst_dim_numbers must be the same for the 2 fragments of the + // contracting dimension after applying split-k. 
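+ // Illustrative numbers (hypothetical, not from this change): with a
+ // contracting dimension of size 555 and split_k = 8, the operand is
+ // padded to 560 and carries fragments of full sizes {8, 70}; the merged
+ // fragment created below then gets size 8 * 70 - 5 = 555, i.e. the
+ // un-padded size.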
+ CHECK_EQ(fragments[0]->dst_dim_number(), + fragments[1]->dst_dim_number()); + + new_fragments.emplace_back( + fragments[0]->dst_dim_number(), + fragments[0]->full_size() * fragments[1]->full_size() - padding); + dst_logical[i] = {&new_fragments.back()}; + } + } + } else if (hlo.opcode() == HloOpcode::kSlice) { + const auto slice = Cast(&hlo); + dst_logical.resize(src_logical.size()); + for (int dim = 0; dim < src_logical.size(); ++dim) { + dst_logical[dim] = src_logical[dim]; + if (slice->slice_limits(dim) - slice->slice_starts(dim) != + dst->shape().dimensions(dim)) { + if (dst_logical[dim].size() > 1) { + return FusionDecision("Slicing of fragmented dimension."); + } + auto fragment = dst_logical[dim].front(); + fragment->set_size(dst->shape().dimensions(dim)); + // Slicing of an already sliced dimension means adding offsets. + fragment->set_slice( + fragment->slice_start() + slice->slice_starts(dim), + fragment->slice_start() + slice->slice_starts(dim) + + fragment->sliced_size()); + } + } + } else { + return FusionDecision("Function called on a wrong instruction."); + } + // Destination logical -> destination physical and ungroup subdimensions. + // Map original fragments to the resulting ones to derive their new + // logical ordering within each dimension. + absl::flat_hash_map src_to_dst; + Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder(); + FragmentOrders& dst_dim_fragments_order = + dst_dim_order.DimFragmentsOrders(); + // Remember which dimensions are present before a broadcast; + // skip cases when already present dimension is being expanded. + absl::flat_hash_set dim_numbers_present_in_dst; + for (const int64_t dim_idx : dst->shape().layout().minor_to_major()) { + for (const Fragment* subdim : dst_logical[dim_idx]) { + dst_fragments_order.push_back(*subdim); + src_to_dst[subdim] = dst_fragments_order.size() - 1; + dim_numbers_present_in_dst.insert(subdim->dst_dim_number()); + } + } + for (const auto& [dim_index, dim_sequence] : + src_dim_order.DimFragmentsOrders()) { + for (const int fragment_number : dim_sequence) { + const auto it = src_to_dst.find(&src_fragments_order[fragment_number]); + if (it == src_to_dst.cend()) { + if (hlo.opcode() == HloOpcode::kBroadcast && + src_fragments_order[fragment_number].full_size() > 1 && + dim_numbers_present_in_dst.contains(dim_index)) { + return FusionDecision("Unsupported broadcast"); + } + continue; + } + dst_dim_fragments_order[dim_index].push_back(it->second); + } + } + } + return dst_dim_orders; +} + +// If possible, propagates `src_dim_order` (describing one side of `hlo`) to +// the other side and returns those dim orders. 
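+// A minimal caller-side sketch (illustrative; it mirrors the pattern used by
+// GetPropagatedDimOrdersAndRequirements further below):
+//
+//   DimOrderMapOrError result =
+//       GetPropagatedDimOrders(hlo, direction, src_dim_order, properties);
+//   if (std::holds_alternative<FusionDecision>(result)) {
+//     return std::get<FusionDecision>(result);  // Explains why not fusible.
+//   }
+//   DimOrderMap dim_orders = std::move(std::get<DimOrderMap>(result));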
+DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, + const TransformDirection direction, + const DimensionOrder& src_dim_order, + const HeroProperties& properties) { + VLOG(7) << "Analyzing " << hlo.ToString(); + if (hlo.opcode() != HloOpcode::kParameter && + direction == TransformDirection::kOutputToInput && + absl::c_any_of(hlo.users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kConcatenate; + })) { + return "No fusion into concatenations"; + } + if (hlo.opcode() == HloOpcode::kParameter || + hlo_query::IsScalarConstant(&hlo)) { + CHECK(direction == TransformDirection::kOutputToInput); + return DimOrderMap{}; + } else if (hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kCopy) { + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kBroadcast) { + if (direction != TransformDirection::kOutputToInput) { + return "Unsupported broadcast direction."; + } + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kReduce) { + if (!std::holds_alternative(properties)) { + return "Reductions are not supported in GEMM fusions yet."; + } + if (direction != TransformDirection::kOutputToInput) { + return "Unsupported direction of reduction."; + } + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kPad) { + if (direction != TransformDirection::kOutputToInput) { + return "Unsupported pad direction."; + } + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.operand_count() > 0 && + IsTritonSupportedElementwise( + hlo.opcode(), hlo.operand(0)->shape().element_type())) { + return GetPropagatedDimOrdersForElementwise(hlo, direction, src_dim_order); + } else if (hlo.opcode() == HloOpcode::kBitcast) { + return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kSlice) { + if (direction != TransformDirection::kOutputToInput) { + return "Unsupported slice direction."; + } + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kReshape) { + if (!ShapeUtil::ReshapeIsBitcast(hlo.operand(0)->shape(), hlo.shape())) { + return "Non-bitcast reshape."; + } + return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, + properties); + } else if (hlo.opcode() == HloOpcode::kConcatenate && + direction == TransformDirection::kOutputToInput) { + if (!std::holds_alternative(properties)) { + return "Concatenations for now are only supported in GEMM fusions."; + } + auto dim = LogicalIndexOfLabeledDimension( + hlo.shape(), src_dim_order, + std::get(properties).noncontracting_dimension); + if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { + return "Unsupported concatenation."; + } + if (absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) { + return operand->user_count() > 1; + })) { + return FusionDecision( + "Concatenation has to be the only user of its inputs."); + } + if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { + // In the current simple implementation of concatenation the size of + // each of its inputs along the concatenated dimension has to be + // divisible by the tile size used for this dimension. 
Concatenations + // with any operand not divisible by kMinConcatFragmentSize will not + // be fused; tiling configurations with tile size for this dimension + // larger than kMinConcatFragmentSize will not be emitted. + constexpr int kMinConcatFragmentSize = 128; + return operand->shape().dimensions(hlo.concatenate_dimension()) % + kMinConcatFragmentSize != + 0; + })) { + return FusionDecision( + "One or more operands of concatenation can not be perfectly tiled."); + } + return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, + properties); + } + return "Unimplemented instruction."; +} + +// Difference of input and output data volumes of an instruction. +int64_t InputMinusOutputBytes(const HloInstruction& hlo) { + CHECK(!hlo.shape().IsTuple()); + int64_t input_size = 0; + for (const HloInstruction* operand : hlo.operands()) { + CHECK(!operand->shape().IsTuple()); + input_size += ShapeUtil::ByteSizeOf(operand->shape()); + } + return input_size - ShapeUtil::ByteSizeOf(hlo.shape()); +} + +// Tells if an instruction has no user into which it could be fused. +// More cases should be added here. +bool CanNotBeFusedIntoAUser(const HloInstruction& hlo) { + return hlo.IsRoot() || (hlo.user_count() == 1 && hlo.users()[0]->IsRoot() && + hlo.users()[0]->opcode() == HloOpcode::kTuple); +} + +// Let input and output data volumes of a fusion grow by small amounts. +constexpr int kIoToleranceBytes = 1024; + +// Tells that fusing an instruction as an input is efficient. +bool IsInputWorthFusing(const HloInstruction& hlo) { + if (InputMinusOutputBytes(hlo) <= kIoToleranceBytes) { + return true; + } + if (hlo.user_count() > 1) { + return false; + } + if (hlo.opcode() == HloOpcode::kSlice && + hlo_query::AllOperandsAreParametersOrConstants(hlo)) { + return true; + } + return hlo_query::AllOperandsAreParametersOrConstantsWithSingleUser(hlo); +} + +// Tells that fusing an instruction as an output is efficient. +bool IsOutputWorthFusing(const HloInstruction& hlo) { + return CanNotBeFusedIntoAUser(hlo) || + InputMinusOutputBytes(hlo) >= -kIoToleranceBytes; +} + +FusionDecision IsConversionWorthFusing(const HloInstruction& input, + se::GpuComputeCapability gpu_version) { + // TODO(b/266862494): Can pick up almost any + // convert, but if it's reducing the data volume it should rather be fused + // to the output of the producer kernel. However not all operations support + // output fusion - then it should be fused here anyway! 
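+ // For example (illustrative): an f32 -> bf16 convert reads twice as many
+ // bytes as it writes, so it is classified as narrowing below and is not
+ // fused as an input here.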
+ if (ShapeUtil::ByteSizeOf(input.operand(0)->shape()) >
+ ShapeUtil::ByteSizeOf(input.shape())) {
+ return "Narrowing conversion.";
+ }
+ return FusionDecision{};
+}
+
+} // namespace
+
+DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements(
+ const HloInstruction& hlo, const DimensionOrder& src_dim_order,
+ TransformDirection direction, const HeroProperties& properties) {
+ DimOrderMapOrError propagated_dim_orders_or_error =
+ GetPropagatedDimOrders(hlo, direction, src_dim_order, properties);
+ if (std::holds_alternative<FusionDecision>(propagated_dim_orders_or_error)) {
+ return std::get<FusionDecision>(propagated_dim_orders_or_error);
+ }
+ DimOrderMap propagated_dim_orders =
+ std::move(std::get<DimOrderMap>(propagated_dim_orders_or_error));
+ RequirementsOrError requirements_or_error =
+ GetRequirementsIfSupportedOrders(hlo, propagated_dim_orders, properties);
+ if (std::holds_alternative<FusionDecision>(requirements_or_error)) {
+ return std::get<FusionDecision>(requirements_or_error);
+ }
+ return DimOrdersAndReqs{propagated_dim_orders,
+ std::get<Requirements>(requirements_or_error)};
+}
+
+DimOrdersAndReqsOrError
+GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
+ const HloInstruction& hlo, TransformDirection transform_direction,
+ const std::optional<int>& src_operand_index,
+ const DimensionOrder& src_dim_order,
+ const se::GpuComputeCapability& gpu_version,
+ const HeroProperties& properties) {
+ CHECK_EQ(transform_direction == TransformDirection::kInputToOutput,
+ src_operand_index.has_value());
+
+ if (hlo.opcode() == HloOpcode::kTuple ||
+ hlo.opcode() == HloOpcode::kGetTupleElement) {
+ return "Unsupported instruction.";
+ }
+ if (hlo.opcode() == HloOpcode::kReduce) {
+ return "Reductions are not fused yet.";
+ }
+ if (hlo.opcode() == HloOpcode::kPad) {
+ return "Pads are not fused yet.";
+ }
+ for (const HloInstruction* operand : hlo.operands()) {
+ if (!IsTritonSupportedDataType(operand->shape().element_type(),
+ gpu_version)) {
+ return "Unsupported input data type.";
+ }
+ }
+ if (!IsTritonSupportedDataType(hlo.shape().element_type(), gpu_version)) {
+ return "Unsupported output data type.";
+ }
+ DimOrdersAndReqsOrError result_or_error =
+ GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order,
+ transform_direction, properties);
+ if (!std::holds_alternative<DimOrdersAndReqs>(result_or_error)) {
+ return result_or_error;
+ }
+ DimOrdersAndReqs dim_orders_and_requirements =
+ std::move(std::get<DimOrdersAndReqs>(result_or_error));
+ int fusion_level =
+ hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
+ if (!std::get<se::CudaComputeCapability>(gpu_version)
+ .IsAtLeast(se::CudaComputeCapability::AMPERE)) {
+ fusion_level = std::min(fusion_level, 1);
+ }
+ if (transform_direction == TransformDirection::kOutputToInput) {
+ if (fusion_level < 2) {
+ if (hlo.opcode() == HloOpcode::kConvert) {
+ if (FusionDecision decision = IsConversionWorthFusing(hlo, gpu_version);
+ !decision) {
+ return decision;
+ }
+ } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) {
+ return "Ignored elementwise operation";
+ }
+ } else {
+ // Exception for binary elementwise operations: in most cases these are
+ // not trivial to fuse because they increase DRAM traffic but if one
+ // of the inputs is for example a broadcast that can be fused too it
+ // becomes worth fusing. Look ahead and analyze operands here.
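+ // For instance (illustrative), multiply(x, broadcast(parameter)) can be
+ // accepted even though it is binary, because the broadcast side adds no
+ // extra DRAM traffic once fused.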
+ bool accepted = false; + if (hlo.IsElementwise() && hlo.operand_count() == 2) { + for (const HloInstruction* operand : hlo.operands()) { + if (operand->opcode() == HloOpcode::kBroadcast && + (operand->operand(0)->opcode() == HloOpcode::kParameter || + operand->operand(0)->opcode() == HloOpcode::kConstant) && + std::holds_alternative( + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + *operand, TransformDirection::kOutputToInput, + /*src_operand_index=*/std::nullopt, + /*src_dim_order=*/ + dim_orders_and_requirements.dim_orders.at(operand), + gpu_version, properties))) { + accepted = true; + break; + } + } + } + if (!accepted && !IsInputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as input."; + } + } + } else { + if (fusion_level < 2) { + return "Skipping fusing outputs at low fusion levels."; + } + for (int i = 0; i < hlo.operand_count(); ++i) { + const HloInstruction* operand = hlo.operand(i); + // Skip source operand. + if (i == *src_operand_index) { + continue; + } + // Currently only broadcasts of scalar constants or parameters + // are accepted as other inputs of non-unary operations + // in the output fusion. + if (hlo_query::IsBroadcastOfScalarConstant(*operand) || + operand->opcode() == HloOpcode::kParameter) { + continue; + } + return "Has multiple inputs - not properly analyzed yet."; + } + if (!IsOutputWorthFusing(hlo)) { + return "Not obviously profitable to fuse as output."; + } + } + return dim_orders_and_requirements; +} + +} // namespace triton_fusion +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h new file mode 100644 index 00000000000000..1887b962a90e6c --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h @@ -0,0 +1,241 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRITON_TILING_PROPAGATION_H_ +#define XLA_SERVICE_GPU_TRITON_TILING_PROPAGATION_H_ + +// This file contains the logic of the Triton Tiling Propagation in a functional +// paradigm. Stateful operations belong in triton_fusion_analysis. + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/instruction_fusion.h" +#include "xla/stream_executor/device_description.h" +namespace xla { +namespace gpu { + +class TensorIterationSpec { + public: + // Description of basic iteration: `count` elements separated by `stride`. + struct IterationSpecFragment { + int64_t stride; + int64_t count; + int64_t slice_start; + int64_t slice_limit; + // Logical subfragments when this iteration is composed + // of several HLO dimensions. 
+ std::vector subfragments; + + bool is_sliced() const { return count != slice_limit - slice_start; } + bool operator!=(const IterationSpecFragment& other) const; + std::string ToString() const; + }; + // Description of complex iteration over a sequence of several strides. + // Describes a logically contiguous dimension of a tensor physically + // separated into multiple fragments by other dimensions. + using DimIterationSpec = std::vector; + + using StorageType = absl::flat_hash_map; + const DimIterationSpec& operator[](const int dimension) const { + return dim_iteration_specs_.at(dimension); + } + DimIterationSpec& operator[](const int dimension) { + return dim_iteration_specs_[dimension]; + } + const StorageType& Storage() const { return dim_iteration_specs_; } + void RemoveEmptyDimensions() { + absl::erase_if(dim_iteration_specs_, + [](const auto& it) { return it.second.empty(); }); + } + + // Compares physical layouts of tensors ignoring subfragments of dimensions. + bool operator==(const TensorIterationSpec& other) const; + + std::string ToString() const; + + private: + StorageType dim_iteration_specs_; +}; + +// The details of the Triton fusion / tiling propagation are in a separate +// namespace to avoid littering the xla::gpu namespace. +namespace triton_fusion { + +// Handles numbers of dimensions of an HLO instruction +// projected onto another one. +// Used to calculate cumulative index transformations done by non-elementwise +// instructions between source and target. +class DimensionOrder { + public: + // Softmax fusions have a fixed tiling scheme. These numbers are chosen to + // reflect that reductions in softmax fusions currently happen on the minor- + // most dimension (dimensions_minor(0)) and the rest (1+) is treated as a + // single non-tiled batch dimension. The numbers have to match those the + // emitter uses in the queries to the analysis. + static constexpr int kSoftmaxReductionDimension = 0; + static constexpr int kSoftmaxBatchDimension = 1; + + static DimensionOrder FromDotOperandOrOutput( + const HloInstruction& hlo, int split_k_dimension_index = -1); + + static DimensionOrder FromSoftmaxRoot(const HloInstruction& hlo); + + // Description of a continuous fragment of one dimension of a tensor. + class Fragment { + public: + explicit Fragment(int dst_dim_number, int64_t size) + : dst_dim_number_(dst_dim_number), + size_(size), + slice_start_(0), + slice_limit_(size) {} + + std::string ToString() const; + + // Label carrying the dimension number of an defining operation. + int dst_dim_number() const { return dst_dim_number_; } + // Total number of elements in the fragment ignoring slicing. + int64_t full_size() const { return size_; } + // First used element. + int64_t slice_start() const { return slice_start_; } + // Last used element. 
+ int64_t slice_limit() const { return slice_limit_; }
+ int64_t sliced_size() const { return slice_limit_ - slice_start_; }
+ bool is_sliced() const { return full_size() != sliced_size(); }
+ void set_slice(int64_t start, int64_t limit) {
+ slice_start_ = start;
+ slice_limit_ = limit;
+ }
+ void set_size(int64_t size) { size_ = size; }
+
+ private:
+ const int dst_dim_number_;
+ int64_t size_;
+ int64_t slice_start_;
+ int64_t slice_limit_;
+ };
+ using Fragments = std::vector<Fragment>;
+ using FragmentOrders = absl::flat_hash_map<int, std::vector<int>>;
+
+ const Fragments& TensorFragmentsOrder() const {
+ return tensor_fragments_order_;
+ }
+ Fragments& TensorFragmentsOrder() { return tensor_fragments_order_; }
+
+ const FragmentOrders& DimFragmentsOrders() const {
+ return dim_fragments_orders_;
+ }
+ FragmentOrders& DimFragmentsOrders() { return dim_fragments_orders_; }
+
+ std::string ToString() const;
+
+ TensorIterationSpec ToTensorIterationSpec() const;
+
+ // Tells that two dimension orders describe the same tensor physical layout.
+ bool IsPhysicallyEquivalent(const DimensionOrder& other) const {
+ return ToTensorIterationSpec() == other.ToTensorIterationSpec();
+ }
+
+ private:
+ // Sequence of all fragments of dimensions of tensor's shape
+ // in layout minor-to-major (physical) order.
+ Fragments tensor_fragments_order_;
+ // Iteration orders of fragments of each dimension of the defining operation
+ // (fragments can be physically unordered and disconnected within
+ // the shape due to reshapes and transposes).
+ FragmentOrders dim_fragments_orders_;
+};
+
+// This represents an invalid dimension index.
+inline constexpr int kNoDimensionIndex = -1;
+struct DotProperties {
+ const int noncontracting_dimension;
+ // Index of dot dimension that can be split.
+ // Currently typically LHS non-contracting one.
+ const int splittable_dimension_index;
+};
+struct SoftmaxProperties {
+ const int softmax_reduction_dimension;
+ const int softmax_batch_dimension;
+};
+// HeroProperties depend only on the hero op and they don't change as we
+// change the fusion.
+using HeroProperties = std::variant<DotProperties, SoftmaxProperties>;
+
+// A special value for splittable_dimension_major_part_size.
+inline constexpr int kNoSplitRequirement = 1;
+struct DotRequirements {
+ explicit DotRequirements(int64_t splittable_dimension_major_part_size)
+ : splittable_dimension_major_part_size(
+ splittable_dimension_major_part_size) {
+ CHECK_GE(splittable_dimension_major_part_size, 1);
+ }
+ // If not kNoSplitRequirement, then the major part size of the splittable
+ // dimension must be the given value.
+ int64_t splittable_dimension_major_part_size;
+};
+struct SoftmaxRequirements {};
+// Requirements can change depending on what we fuse.
+using Requirements = std::variant<DotRequirements, SoftmaxRequirements>;
+using RequirementsOrError = std::variant<Requirements, FusionDecision>;
+
+RequirementsOrError CombineRequirements(Requirements a,
+ RequirementsOrError b_or_error);
+
+enum class TransformDirection { kInputToOutput, kOutputToInput };
+using DimOrderMap = absl::flat_hash_map<const HloInstruction*, DimensionOrder>;
+using DimOrderMapOrError = std::variant<DimOrderMap, FusionDecision>;
+
+// The dimension orders and requirements resulting from propagating the
+// dimension orders through an HLO.
+struct DimOrdersAndReqs {
+ DimOrderMap dim_orders;
+ Requirements requirements;
+};
+using DimOrdersAndReqsOrError = std::variant<DimOrdersAndReqs, FusionDecision>;
+
+// If fusing the instruction is possible then it propagates
+// the `src_dim_order` (describing one side of `hlo`) to the other side and
+// returns those dim orders and the requirements that they impose on the
+// fusion.
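+//
+// A minimal usage sketch (illustrative):
+//
+//   DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements(
+//       hlo, src_dim_order, TransformDirection::kOutputToInput, properties);
+//   if (std::holds_alternative<DimOrdersAndReqs>(result)) {
+//     const DimOrderMap& dim_orders =
+//         std::get<DimOrdersAndReqs>(result).dim_orders;
+//   }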
+DimOrdersAndReqsOrError GetPropagatedDimOrdersAndRequirements( + const HloInstruction& hlo, const DimensionOrder& src_dim_order, + TransformDirection direction, const HeroProperties& properties); +// If fusing the instruction is possible *and profitable* then it propagates +// the `src_dim_order` (describing one side of `hlo`) to the other side and +// returns those dim orders and the requirements that they impose on the +// fusion. +// +// `src_operand_index` must be set iff `transform_direction` is +// kInputToOutput. +DimOrdersAndReqsOrError +GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + const HloInstruction& hlo, TransformDirection transform_direction, + const std::optional& src_operand_index, + const DimensionOrder& src_dim_order, + const se::GpuComputeCapability& gpu_version, + const HeroProperties& properties); + +} // namespace triton_fusion +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRITON_TILING_PROPAGATION_H_ From 247f77e44784f9254d2bd304372a71510a0849fc Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Tue, 28 Nov 2023 09:35:06 -0800 Subject: [PATCH 147/381] Add type annotations to test_util.py. Adds parameter and return type annotations to the majority of public functions and methods in the `test_util` module. This includes annotations for methods on `TensorFlowTestCase` which return values, but omits the assertion methods. If adding types is currently infeasible (due to complexity of the signature, limitations of the supported versions of python, type checker limitations, etc.), then this change simply does not add those annotations. PiperOrigin-RevId: 586008562 --- tensorflow/python/framework/BUILD | 2 +- tensorflow/python/framework/test_util.py | 554 ++++++++++++++--------- 2 files changed, 336 insertions(+), 220 deletions(-) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 0d4369274e02ee..959cb5920475ba 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -2063,7 +2063,7 @@ py_strict_library( deps = [], ) -py_strict_library( +pytype_strict_library( name = "test_lib", srcs = ["test_util.py"], srcs_version = "PY3", diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index a4085a0f2b7571..4c982e87873f09 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -18,7 +18,7 @@ import collections from collections import OrderedDict -from collections.abc import Callable, Iterator +from collections.abc import Iterable, Iterator, Callable, Collection, Sequence import contextlib import functools import gc @@ -30,16 +30,21 @@ import tempfile import threading import time -from typing import Any, cast, Optional, overload, TypeVar, Union +from typing import Any, cast, Union, Optional, overload, TypeVar import unittest from absl.testing import parameterized import numpy as np from google.protobuf import descriptor_pool +from google.protobuf import message from google.protobuf import text_format from tensorflow.core.config import flags from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import tensor_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import pywrap_sanitizers from tensorflow.python import tf2 @@ -98,17 +103,20 @@ from tensorflow.python.util.protobuf import 
compare from tensorflow.python.util.tf_export import tf_export + +_F = TypeVar("_F", bound=Callable[..., Any]) +_T = TypeVar("_T") _TC = TypeVar("_TC", bound=type["TensorFlowTestCase"]) -_R = TypeVar("_R") # If the below import is made available through the BUILD rule, then this # function is overridden and will instead return True and cause Tensorflow # graphs to be compiled with XLA. -def is_xla_enabled(): +def is_xla_enabled() -> bool: return False +# pytype: disable=import-error try: from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top, unused-import except Exception: # pylint: disable=broad-except @@ -117,7 +125,7 @@ def is_xla_enabled(): # Uses the same mechanism as above to selectively enable/disable MLIR # compilation. -def is_mlir_bridge_enabled(): +def is_mlir_bridge_enabled() -> Optional[bool]: return None @@ -128,36 +136,39 @@ def is_mlir_bridge_enabled(): from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import except ImportError: pass +# pytype: enable=import-error -def is_asan_enabled(): +def is_asan_enabled() -> bool: """Check if ASAN is enabled.""" return pywrap_sanitizers.is_asan_enabled() -def is_msan_enabled(): +def is_msan_enabled() -> bool: """Check if MSAN is enabled.""" return pywrap_sanitizers.is_msan_enabled() -def is_tsan_enabled(): +def is_tsan_enabled() -> bool: """Check if TSAN is enabled.""" return pywrap_sanitizers.is_tsan_enabled() -def is_ubsan_enabled(): +def is_ubsan_enabled() -> bool: """Check if UBSAN is enabled.""" return pywrap_sanitizers.is_ubsan_enabled() -def _get_object_count_by_type(exclude=()): +def _get_object_count_by_type( + exclude: Iterable[Any] = (), +) -> collections.Counter[str]: return ( collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - collections.Counter([type(obj).__name__ for obj in exclude])) @tf_export("test.gpu_device_name") -def gpu_device_name(): +def gpu_device_name() -> str: """Returns the name of a GPU device if available or a empty string. This method should only be used in tests written with `tf.test.TestCase`. @@ -178,7 +189,9 @@ def gpu_device_name(): return "" -def assert_ops_in_graph(expected_ops, graph): +def assert_ops_in_graph( + expected_ops: dict[str, str], graph: ops.Graph +) -> dict[str, node_def_pb2.NodeDef]: """Assert all expected operations are found. Args: @@ -191,8 +204,8 @@ def assert_ops_in_graph(expected_ops, graph): Raises: ValueError: If the expected ops are not present in the graph. """ - actual_ops = {} - gd = graph.as_graph_def() + actual_ops: dict[str, node_def_pb2.NodeDef] = {} + gd = cast(graph_pb2.GraphDef, graph.as_graph_def()) for node in gd.node: if node.name in expected_ops: if expected_ops[node.name] != node.op: @@ -206,7 +219,9 @@ def assert_ops_in_graph(expected_ops, graph): @tf_export("test.assert_equal_graph_def", v1=[]) -def assert_equal_graph_def_v2(expected, actual): +def assert_equal_graph_def_v2( + expected: graph_pb2.GraphDef, actual: graph_pb2.GraphDef +) -> None: """Asserts that two `GraphDef`s are (mostly) the same. 
Compares two `GraphDef` protos for equality, ignoring versions and ordering of @@ -227,8 +242,12 @@ def assert_equal_graph_def_v2(expected, actual): @tf_export(v1=["test.assert_equal_graph_def"]) -def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, - hash_table_shared_name=False): +def assert_equal_graph_def_v1( + actual: graph_pb2.GraphDef, + expected: graph_pb2.GraphDef, + checkpoint_v2: bool = False, + hash_table_shared_name: bool = False +) -> None: """Asserts that two `GraphDef`s are (mostly) the same. Compares two `GraphDef` protos for equality, ignoring versions and ordering of @@ -251,8 +270,12 @@ def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, hash_table_shared_name) -def assert_equal_graph_def(actual, expected, checkpoint_v2=False, - hash_table_shared_name=False): +def assert_equal_graph_def( + actual: graph_pb2.GraphDef, + expected: graph_pb2.GraphDef, + checkpoint_v2: bool = False, + hash_table_shared_name: bool = False +)-> None: if not isinstance(actual, graph_pb2.GraphDef): raise TypeError("Expected tf.GraphDef for actual, got %s" % type(actual).__name__) @@ -274,7 +297,11 @@ def assert_equal_graph_def(actual, expected, checkpoint_v2=False, raise AssertionError(compat.as_str(diff)) -def assert_meta_graph_protos_equal(tester, a, b): +def assert_meta_graph_protos_equal( + tester: "TensorFlowTestCase", + a: meta_graph_pb2.MetaGraphDef, + b: meta_graph_pb2.MetaGraphDef, +) -> None: """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.""" # Carefully check the collection_defs tester.assertEqual(set(a.collection_def), set(b.collection_def)) @@ -282,7 +309,7 @@ def assert_meta_graph_protos_equal(tester, a, b): for k in collection_keys: a_value = a.collection_def[k] b_value = b.collection_def[k] - proto_type = ops.get_collection_proto_type(k) + proto_type = cast(type[message.Message], ops.get_collection_proto_type(k)) if proto_type: a_proto = proto_type() b_proto = proto_type() @@ -318,11 +345,12 @@ def assert_meta_graph_protos_equal(tester, a, b): _SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part" -def _strip_checkpoint_v2_randomized(graph_def): +def _strip_checkpoint_v2_randomized(graph_def: graph_pb2.GraphDef) -> None: for node in graph_def.node: - delete_keys = [] + delete_keys: list[str] = [] for attr_key in node.attr: - attr_tensor_value = node.attr[attr_key].tensor + attr_tensor_value = cast( + tensor_pb2.TensorProto, node.attr[attr_key].tensor) if attr_tensor_value and len(attr_tensor_value.string_val) == 1: attr_tensor_string_value = attr_tensor_value.string_val[0] if (attr_tensor_string_value and @@ -336,9 +364,9 @@ def _strip_checkpoint_v2_randomized(graph_def): _TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+" -def _strip_hash_table_shared_name(graph_def): +def _strip_hash_table_shared_name(graph_def: graph_pb2.GraphDef) -> None: for node in graph_def.node: - delete_keys = [] + delete_keys: list[str] = [] if node.op == "HashTableV2" and "shared_name" in node.attr: if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN), node.attr["shared_name"].s): @@ -347,35 +375,37 @@ def _strip_hash_table_shared_name(graph_def): del node.attr[attr_key] -def IsGoogleCudaEnabled(): +def IsGoogleCudaEnabled() -> bool: return _pywrap_util_port.IsGoogleCudaEnabled() -def IsBuiltWithROCm(): +def IsBuiltWithROCm() -> bool: return _pywrap_util_port.IsBuiltWithROCm() -def IsBuiltWithXLA(): +def IsBuiltWithXLA() -> bool: return _pywrap_util_port.IsBuiltWithXLA() -def IsBuiltWithNvcc(): +def IsBuiltWithNvcc() -> bool: return 
_pywrap_util_port.IsBuiltWithNvcc() -def GpuSupportsHalfMatMulAndConv(): +def GpuSupportsHalfMatMulAndConv() -> bool: return _pywrap_util_port.GpuSupportsHalfMatMulAndConv() -def IsMklEnabled(): +def IsMklEnabled() -> bool: return _pywrap_util_port.IsMklEnabled() -def InstallStackTraceHandler(): +def InstallStackTraceHandler() -> None: _pywrap_stacktrace_handler.InstallStacktraceHandler() -def NHWCToNCHW(input_tensor): +def NHWCToNCHW( + input_tensor: Union[tensor_lib.Tensor, list[int]] +) -> Union[tensor_lib.Tensor, list[int]]: """Converts the input from the NHWC format to NCHW. Args: @@ -394,7 +424,9 @@ def NHWCToNCHW(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] -def NHWCToNCHW_VECT_C(input_shape_or_tensor): +def NHWCToNCHW_VECT_C( + input_shape_or_tensor: Union[tensor_lib.Tensor, list[int]] +)-> Union[tensor_lib.Tensor, list[int]]: """Transforms the input from the NHWC layout to NCHW_VECT_C layout. Note: Does not include quantization or type conversion steps, which should @@ -412,7 +444,7 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): """ permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) - temp_shape = ( + temp_shape: list[int] = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) if temp_shape[-1] % 4 != 0: @@ -429,7 +461,9 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): return [temp_shape[a] for a in permutation] -def NCHW_VECT_CToNHWC(input_shape_or_tensor): +def NCHW_VECT_CToNHWC( + input_shape_or_tensor: Union[tensor_lib.Tensor, list[int]] +) -> Union[tensor_lib.Tensor, list[int]]: """Transforms the input from the NCHW_VECT_C layout to NHWC layout. Note: Does not include de-quantization or type conversion steps, which should @@ -446,7 +480,7 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): """ permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) - input_shape = ( + input_shape: list[int] = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) if input_shape[-1] != 4: @@ -461,7 +495,9 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): return nhwc_shape -def NCHWToNHWC(input_tensor): +def NCHWToNHWC( + input_tensor: Union[tensor_lib.Tensor, list[int]] +) -> Union[tensor_lib.Tensor, list[int]]: """Converts the input from the NCHW format to NHWC. Args: @@ -480,7 +516,7 @@ def NCHWToNHWC(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] -def skip_if(condition): +def skip_if(condition: Union[Callable[[], bool], bool]) -> Callable[[_F], _F]: """Skips the decorated function if condition is or evaluates to True. Args: @@ -491,7 +527,7 @@ def skip_if(condition): The wrapped function """ - def real_skip_if(fn): + def real_skip_if(fn: _F) -> _F: def wrapper(*args, **kwargs): if callable(condition): @@ -507,7 +543,11 @@ def wrapper(*args, **kwargs): @contextlib.contextmanager -def skip_if_error(test_obj, error_type, messages=None): +def skip_if_error( + test_obj: unittest.TestCase, + error_type: type[Exception], + messages: Union[str, list[str], None] = None +) -> Iterator[None]: """Context manager to skip cases not considered failures by the tests. Note that this does not work if used in setUpClass/tearDownClass. @@ -538,17 +578,17 @@ def skip_if_error(test_obj, error_type, messages=None): raise -def enable_c_shapes(fn): +def enable_c_shapes(fn: _F) -> _F: """No-op. 
TODO(b/74620627): Remove this.""" return fn -def with_c_shapes(cls): +def with_c_shapes(cls: type[_T]) -> type[_T]: """No-op. TODO(b/74620627): Remove this.""" return cls -def enable_control_flow_v2(fn): +def enable_control_flow_v2(fn: _F) -> _F: """Decorator for enabling CondV2 and WhileV2 on a test. Note this enables using CondV2 and WhileV2 after running the test class's @@ -575,7 +615,7 @@ def wrapper(*args, **kwargs): return wrapper -def with_control_flow_v2(cls): +def with_control_flow_v2(cls: _TC) -> _TC: """Adds methods that call original methods with WhileV2 and CondV2 enabled. Note this enables CondV2 and WhileV2 in new methods after running the test @@ -630,7 +670,7 @@ def testDisabledForV2(self): return cls -def disable_control_flow_v2(unused_msg): +def disable_control_flow_v2(unused_msg: str) -> Callable[[_F], _F]: """Decorator for a function in a with_control_flow_v2 enabled test class. Blocks the function from being run with v2 control flow ops. @@ -642,14 +682,14 @@ def disable_control_flow_v2(unused_msg): The wrapped function with _disable_control_flow_v2 attr set to True. """ - def wrapper(func): + def wrapper(func: _F) -> _F: func._disable_control_flow_v2 = True return func return wrapper -def enable_output_all_intermediates(fn): +def enable_output_all_intermediates(fn: _F) -> _F: """Force-enable outputing all intermediates from functional control flow ops. Args: @@ -789,7 +829,7 @@ def decorator(self: "TensorFlowTestCase", *args, **kwargs) -> None: return wrap_f -def assert_no_new_tensors(f): +def assert_no_new_tensors(f: _F) -> _F: """Decorator for asserting that no new Tensors persist after a test. Mainly useful for checking that code using the Python C API has correctly @@ -808,10 +848,10 @@ def assert_no_new_tensors(f): The decorated test case. """ - def decorator(*args, **kwargs): + def decorator(self: "TensorFlowTestCase", **kwargs): """Finds existing Tensors, runs the test, checks for new Tensors.""" - def _is_tensorflow_object(obj): + def _is_tensorflow_object(obj) -> bool: try: return isinstance(obj, (tensor_lib.Tensor, variables.Variable, @@ -822,7 +862,7 @@ def _is_tensorflow_object(obj): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - outside_executed_eagerly = context.executing_eagerly() + outside_executed_eagerly = cast(bool, context.executing_eagerly()) # Run the test in a new graph so that collections get cleared when it's # done, but inherit the graph key so optimizers behave. outside_graph_key = ops.get_default_graph()._graph_key @@ -830,9 +870,9 @@ def _is_tensorflow_object(obj): ops.get_default_graph()._graph_key = outside_graph_key if outside_executed_eagerly: with context.eager_mode(): - result = f(*args, **kwargs) + result = f(self, **kwargs) else: - result = f(*args, **kwargs) + result = f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. context.context()._clear_caches() # pylint: disable=protected-access @@ -851,9 +891,9 @@ def _is_tensorflow_object(obj): return tf_decorator.make_decorator(f, decorator) -def _find_reference_cycle(objects, idx): +def _find_reference_cycle(objects: Sequence[Any], idx: int) -> bool: - def get_ignore_reason(obj, denylist): + def get_ignore_reason(obj: Any, denylist: Collection[Any]) -> Optional[str]: """Tests whether an object should be omitted from the dependency graph.""" if len(denylist) > 100: return "" @@ -870,7 +910,9 @@ def get_ignore_reason(obj, denylist): # Note: this function is meant to help with diagnostics. 
Its output is purely # a human-readable representation, so you may freely modify it to suit your # needs. - def describe(obj, denylist, leaves_only=False): + def describe( + obj: Any, denylist: Collection[Any], leaves_only: bool = False, + ) -> str: """Returns a custom human-readable summary of obj. Args: @@ -902,7 +944,12 @@ def describe(obj, denylist, leaves_only=False): else: return "{}, {}".format(type(obj), id(obj)) - def build_ref_graph(obj, graph, reprs, denylist): + def build_ref_graph( + obj: Any, + graph: dict[int, list[int]], + reprs: dict[int, str], + denylist: tuple[Any, ...], + ) -> None: """Builds a reference graph as -> . Args: @@ -928,7 +975,12 @@ def build_ref_graph(obj, graph, reprs, denylist): build_ref_graph(r, graph, reprs, denylist) reprs[r_id] = describe(r, denylist) - def find_cycle(el, graph, reprs, path): + def find_cycle( + el: int, + graph: dict[int, list[int]], + reprs: dict[int, str], + path: tuple[int, ...], + ) -> Optional[bool]: """Finds and prints a single cycle in the dependency graph.""" if el not in graph: return @@ -944,8 +996,8 @@ def find_cycle(el, graph, reprs, path): return False obj = objects[idx] - graph = {} # referrer ID -> object ID - reprs = {} # object ID -> description + graph: dict[int, list[int]] = {} # referrer ID -> object ID + reprs: dict[int, str] = {} # object ID -> description build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason, describe, build_ref_graph, find_cycle)) for k in graph: @@ -954,7 +1006,7 @@ def find_cycle(el, graph, reprs, path): return False -def assert_no_garbage_created(f): +def assert_no_garbage_created(f: _F) -> _F: """Test method decorator to assert that no garbage has been created. Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters @@ -970,7 +1022,7 @@ def assert_no_garbage_created(f): # FIXME(power) -- Update documentation, we no longer care if garbage is # created, we only want to verify we don't have memory leaks. - def decorator(self, **kwargs): + def decorator(self: "TensorFlowTestCase", **kwargs): """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" gc.disable() previous_debug_flags = gc.get_debug() @@ -996,7 +1048,7 @@ def decorator(self, **kwargs): logging.error("Object %d of %d", i, len(gc.garbage) - previous_garbage) - def _safe_object_str(obj): + def _safe_object_str(obj) -> str: return "<%s %d>" % (obj.__class__.__name__, id(obj)) logging.error(" Object type: %s", _safe_object_str(obj)) @@ -1034,7 +1086,7 @@ def _safe_object_str(obj): return decorator -def _combine_named_parameters(**kwargs): +def _combine_named_parameters(**kwargs) -> list[OrderedDict[str, Any]]: """Generate combinations based on its keyword arguments. Two sets of returned combinations can be concatenated using +. Their product @@ -1050,7 +1102,7 @@ def _combine_named_parameters(**kwargs): corresponding keyword argument values. """ sort_by_key = lambda k: k[0] - combinations = [] + combinations: list[list[tuple[str, Any]]] = [] for key, values in sorted(kwargs.items(), key=sort_by_key): if not isinstance(values, list): values = [values] @@ -1059,7 +1111,9 @@ def _combine_named_parameters(**kwargs): return [OrderedDict(result) for result in itertools.product(*combinations)] -def generate_combinations_with_testcase_name(**kwargs): +def generate_combinations_with_testcase_name( + **kwargs, +) -> list[OrderedDict[str, Any]]: """Generate combinations based on its keyword arguments using combine(). 
This function calls combine() and appends a testcase name to the list of @@ -1076,7 +1130,7 @@ def generate_combinations_with_testcase_name(**kwargs): corresponding keyword argument values. """ combinations = _combine_named_parameters(**kwargs) - named_combinations = [] + named_combinations: list[OrderedDict[str, Any]] = [] for combination in combinations: assert isinstance(combination, OrderedDict) name = "".join([ @@ -1092,7 +1146,7 @@ def generate_combinations_with_testcase_name(**kwargs): return named_combinations -def run_all_in_graph_and_eager_modes(cls): +def run_all_in_graph_and_eager_modes(cls: _TC) -> _TC: """Execute all test methods in the given class with and without eager.""" base_decorator = run_in_graph_and_eager_modes for name in dir(cls): @@ -1108,7 +1162,7 @@ def run_all_in_graph_and_eager_modes(cls): return cls -def run_class_in_v1_v2(cls): +def run_class_in_v1_v2(cls: _TC) -> _TC: """Execute all test methods in a given class in v1 and v2 modes.""" base_decorator = run_in_v1_v2 for name in dir(cls): @@ -1127,7 +1181,7 @@ def run_class_in_v1_v2(cls): return cls -def enable_nested_function_shape_inference(fn): +def enable_nested_function_shape_inference(fn: _F) -> _F: """Decorator for enabling nested_function_shape_inference on a test. This function returns a decorator intended to be applied to test methods in @@ -1164,7 +1218,7 @@ def wrapper(*args, **kwargs): return wrapper -def enable_quantized_dtypes_training(fn): +def enable_quantized_dtypes_training(fn: _F) -> _F: """Decorator for enabling quantized_dtypes_training on a test. This function returns a decorator intended to be applied to test methods in @@ -1201,7 +1255,7 @@ def wrapper(*args, **kwargs): return wrapper -def enable_eager_op_as_function(fn): +def enable_eager_op_as_function(fn: _F) -> _F: """Returns the same fn. This will be removed once all usages are removed. Args: @@ -1217,8 +1271,27 @@ def wrapper(*args, **kwargs): return wrapper +@overload +def with_eager_op_as_function( + cls: type[_T], + only_as_function: bool = False, +) -> type[_T]: + ... + + +@overload +def with_eager_op_as_function( + cls: None = None, + only_as_function: bool = False, +) -> Callable[[type[_T]], type[_T]]: + ... + + @tf_export("test.with_eager_op_as_function") -def with_eager_op_as_function(cls=None, only_as_function=False): # pylint: disable=unused-argument +def with_eager_op_as_function( + cls: Optional[type[_T]] = None, + only_as_function: bool = False, # pylint: disable=unused-argument +) -> Union[Callable[[type[_T]], type[_T]], type[_T]]: """Returns the same class. This will be removed once all usages are removed. Args: @@ -1229,16 +1302,16 @@ def with_eager_op_as_function(cls=None, only_as_function=False): # pylint: disa cls """ - def decorator(cls): + def decorator(cls: type[_T]) -> type[_T]: return cls if cls is not None: return decorator(cls) - return decorator + return decorator # pytype: disable=bad-return-type -def enable_graph_building_optimization(fn): +def enable_graph_building_optimization(fn: _F) -> _F: """Decorator for enabling graph_building_optimization on a test. This function returns a decorator intended to be applied to test methods in @@ -1315,7 +1388,7 @@ def testBarWithGraphBuildingOptimization(self): return cls -def disable_eager_op_as_function(unused_msg): +def disable_eager_op_as_function(unused_msg: str) -> Callable[[_F], _F]: """Decorator for a function in a with_eager_op_as_function enabled test class. Blocks the function from being run with eager_op_as_function enabled. 
@@ -1329,9 +1402,7 @@ def disable_eager_op_as_function(unused_msg): return _disable_test(execute_func=False) -def set_xla_env_flag( - flag: str = "", -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def set_xla_env_flag(flag: str = "") -> Callable[[_F], _F]: """Decorator for setting XLA_FLAGS prior to running a test. This function returns a decorator intended to be applied to test methods in @@ -1355,10 +1426,10 @@ def testFoo(self): function. """ - def decorator(f: Callable[..., _R]) -> Callable[..., _R]: + def decorator(f: _F) -> _F: @functools.wraps(f) - def decorated(*args, **kwargs) -> _R: + def decorated(*args, **kwargs): original_xla_flags = os.environ.get("XLA_FLAGS") new_xla_flags = flag if original_xla_flags: @@ -1430,12 +1501,12 @@ def function_in_eager(): return decorated -def run_in_async_and_sync_mode(f): +def run_in_async_and_sync_mode(f: _F) -> _F: """Execute the test in async mode and sync mode.""" @parameterized.named_parameters([("Async", True), ("", False)]) @functools.wraps(f) - def decorator(self, async_mode, *args, **kwargs): + def decorator(self: "TensorFlowTestCase", async_mode: bool, *args, **kwargs): if async_mode: with context.execution_mode(context.ASYNC): f(self, *args, **kwargs) @@ -1445,10 +1516,35 @@ def decorator(self, async_mode, *args, **kwargs): return decorator -def run_in_graph_and_eager_modes(func=None, - config=None, - use_gpu=True, - assert_no_eager_garbage=False): +@overload +def run_in_graph_and_eager_modes( + func: Callable[..., Any], + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + assert_no_eager_garbage: bool = False, +) -> Callable[..., None]: + ... + + +@overload +def run_in_graph_and_eager_modes( + func: None = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + assert_no_eager_garbage: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., None]]: + ... + + +def run_in_graph_and_eager_modes( + func: Optional[Callable[..., Any]] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + assert_no_eager_garbage: bool = False, +) -> Union[ + Callable[[Callable[..., Any]], Callable[..., None]], + Callable[..., None], +]: """Execute the decorated test with and without enabling eager execution. This function returns a decorator intended to be applied to test methods in @@ -1506,13 +1602,13 @@ def test_foo(self): eager execution enabled. """ - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: if tf_inspect.isclass(f): raise ValueError( "`run_in_graph_and_eager_modes` only supports test methods. 
" "Did you mean to use `run_all_in_graph_and_eager_modes`?") - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> None: logging.info("Running %s in GRAPH mode.", f.__name__) try: with context.graph_mode(), self.subTest("graph_mode"): @@ -1531,7 +1627,7 @@ def decorated(self, *args, **kwargs): except unittest.case.SkipTest: pass - def run_eagerly(self, **kwargs): + def run_eagerly(self: "TensorFlowTestCase", **kwargs) -> None: logging.info("Running %s in EAGER mode.", f.__name__) if not use_gpu: with ops.device("/device:CPU:0"): @@ -1608,7 +1704,7 @@ def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> None: except unittest.case.SkipTest: pass - def run_v2(self, **kwargs): + def run_v2(self: "TensorFlowTestCase", **kwargs) -> None: logging.info("Running %s in V2 mode.", f.__name__) if device_to_use: with ops.device(device_to_use): @@ -1639,14 +1735,14 @@ def run_v2(self, **kwargs): return decorator -def py_func_if_in_function(f): +def py_func_if_in_function(f: _F) -> _F: def decorated(*args, **kwds): if not ops.inside_function(): return f(*args, **kwds) - tensor_args = [] - tensor_indices = [] + tensor_args: list[Union[tensor_lib.Tensor, variables.Variable]] = [] + tensor_indices: list[int] = [] for i, arg in enumerate(args): if isinstance(arg, (tensor_lib.Tensor, variables.Variable)): tensor_args.append(arg) @@ -1663,7 +1759,7 @@ def inner_f(*inner_tensor_args): return tf_decorator.make_decorator(f, decorated) -def also_run_as_tf_function(f): +def also_run_as_tf_function(f: Callable[..., Any]) -> Callable[..., None]: """Runs the decorated test twice--once as is, once inside a tf.function. This allows you to run a test both in eager execution and inside a @@ -1683,9 +1779,9 @@ def also_run_as_tf_function(f): tf.function. """ - def decorated(*args, **kwds): + def decorated(*args, **kwds) -> None: - def bound_f(): + def bound_f() -> None: f(*args, **kwds) with context.eager_mode(): @@ -1699,7 +1795,7 @@ def bound_f(): @overload -def deprecated_graph_mode_only(func: Callable[..., _R]) -> Callable[..., _R]: +def deprecated_graph_mode_only(func: _F) -> _F: ... @@ -1708,9 +1804,7 @@ def deprecated_graph_mode_only(func: _TC) -> Optional[_TC]: ... -def deprecated_graph_mode_only( - func: Union[_TC, Callable[..., _R]], -) -> Union[_TC, Callable[..., _R]]: +def deprecated_graph_mode_only(func: Union[_TC, _F]) -> Union[_TC, _F]: """Execute the decorated test in graph mode. This is a decorator intended to be applied to tests that are not compatible @@ -1757,7 +1851,7 @@ def decorated(*args, **kwargs): run_deprecated_v1 = deprecated_graph_mode_only -def run_all_in_deprecated_graph_mode_only(cls): +def run_all_in_deprecated_graph_mode_only(cls: _TC) -> _TC: """Execute all tests in a class in graph mode.""" base_decorator = deprecated_graph_mode_only for name in dir(cls): @@ -1843,7 +1937,7 @@ def run_v2_only(func=None, reason=None): return _run_vn_only(func=func, v2=True, reason=reason) -def run_gpu_only(func: Callable[..., _R]) -> Callable[..., _R]: +def run_gpu_only(func: _F) -> _F: """Execute the decorated test only if a GPU is available. 
This function is intended to be applied to tests that require the presence @@ -1859,7 +1953,7 @@ def run_gpu_only(func: Callable[..., _R]) -> Callable[..., _R]: if tf_inspect.isclass(func): raise ValueError("`run_gpu_only` only supports test methods.") - def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: + def decorated(self: "TensorFlowTestCase", *args, **kwargs): if not is_gpu_available(): self.skipTest("Test requires GPU") @@ -1868,7 +1962,7 @@ def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: return decorated -def run_cuda_only(func: Callable[..., _R]) -> Callable[..., _R]: +def run_cuda_only(func: _F) -> _F: """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence @@ -1884,7 +1978,7 @@ def run_cuda_only(func: Callable[..., _R]) -> Callable[..., _R]: if tf_inspect.isclass(func): raise ValueError("`run_cuda_only` only supports test methods.") - def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: + def decorated(self: "TensorFlowTestCase", *args, **kwargs): if not is_gpu_available(cuda_only=True): self.skipTest("Test requires CUDA GPU") @@ -1893,7 +1987,7 @@ def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: return decorated -def run_gpu_or_tpu(func: Callable[..., _R]) -> Callable[..., _R]: +def run_gpu_or_tpu(func: _F) -> _F: """Execute the decorated test only if a physical GPU or TPU is available. This function is intended to be applied to tests that require the presence @@ -1912,7 +2006,7 @@ def run_gpu_or_tpu(func: Callable[..., _R]) -> Callable[..., _R]: if tf_inspect.isclass(func): raise ValueError("`run_gpu_or_tpu` only supports test methods.") - def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: + def decorated(self: "TensorFlowTestCase", *args, **kwargs): if config.list_physical_devices("GPU"): return func(self, "GPU", *args, **kwargs) @@ -1924,7 +2018,9 @@ def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: return decorated -def with_forward_compatibility_horizons(*horizons): +def with_forward_compatibility_horizons( + *horizons: Optional[tuple[int, int, int]] +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Executes the decorated test with the specified forward-compat horizons. Args: @@ -1942,7 +2038,7 @@ def with_forward_compatibility_horizons(*horizons): (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): raise ValueError("Bad horizon value: %r" % horizon) - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: if tf_inspect.isclass(f): raise ValueError("`with_forward_compatibility_horizons` only " "supports test methods.") @@ -1962,7 +2058,10 @@ def decorated(*args, **kwargs): @deprecation.deprecated(None, "Use `tf.config.list_physical_devices('GPU')` instead.") @tf_export("test.is_gpu_available") -def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): +def is_gpu_available( + cuda_only: bool = False, + min_cuda_compute_capability: Optional[tuple[int, int]] = None, +) -> bool: """Returns whether TensorFlow can access a GPU. 
Warning: if a non-GPU version of the package is installed, the function would @@ -2018,7 +2117,7 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): @contextlib.contextmanager -def device(use_gpu): +def device(use_gpu: bool) -> Iterator[None]: """Uses gpu when requested and available.""" if use_gpu and is_gpu_available(): dev = "/device:GPU:0" @@ -2029,28 +2128,28 @@ def device(use_gpu): @contextlib.contextmanager -def use_gpu(): +def use_gpu() -> Iterator[None]: """Uses gpu when requested and available.""" with device(use_gpu=True): yield @contextlib.contextmanager -def force_gpu(): +def force_gpu() -> Iterator[None]: """Force the gpu to be used.""" with ops.device("/device:GPU:0"): yield @contextlib.contextmanager -def force_cpu(): +def force_cpu() -> Iterator[None]: """Force the cpu to be used.""" with ops.device("/device:CPU:0"): yield @contextlib.contextmanager -def deterministic_ops(): +def deterministic_ops() -> Iterator[None]: """Enables deterministic ops.""" try: config.enable_op_determinism() @@ -2062,10 +2161,10 @@ def deterministic_ops(): class CapturedWrites: """A utility class to load the captured writes made to a stream.""" - def __init__(self, capture_location): + def __init__(self, capture_location: str): self.capture_location = capture_location - def contents(self): + def contents(self) -> str: """Get the captured writes as a single string.""" with open(self.capture_location) as tmp_file: output_data = "".join(tmp_file.readlines()) @@ -2144,7 +2243,7 @@ def run(self, *args, **kwargs): raise -def disable_cudnn_autotune(func: Callable[..., _R]) -> Callable[..., _R]: +def disable_cudnn_autotune(func: _F) -> _F: """Disable autotuning during the call to this function. Some tests want to base assertions on a graph being isomorphic with a copy. @@ -2157,7 +2256,7 @@ def disable_cudnn_autotune(func: Callable[..., _R]) -> Callable[..., _R]: Decorated function. """ - def decorated(*args, **kwargs) -> _R: + def decorated(*args, **kwargs): original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" original_xla_flags = os.environ.get("XLA_FLAGS") @@ -2183,17 +2282,13 @@ def decorated(*args, **kwargs) -> _R: # The description is just for documentation purposes. -def enable_tf_xla_constant_folding( - description: str, -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def enable_tf_xla_constant_folding(description: str) -> Callable[[_F], _F]: if not isinstance(description, str): raise ValueError("'description' should be string, got {}".format( type(description))) - def enable_tf_xla_constant_folding_impl( - func: Callable[..., _R], - ) -> Callable[..., _R]: + def enable_tf_xla_constant_folding_impl(func: _F) -> _F: """Enable constant folding during the call to this function. Some tests fail without constant folding. @@ -2205,7 +2300,7 @@ def enable_tf_xla_constant_folding_impl( Decorated function. """ - def decorated(*args, **kwargs) -> _R: + def decorated(*args, **kwargs): original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) result = func(*args, **kwargs) @@ -2218,13 +2313,11 @@ def decorated(*args, **kwargs) -> _R: # Updates test function by selectively disabling it. 
-def _disable_test( - execute_func: bool, -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def _disable_test(execute_func: bool) -> Callable[[_F], _F]: - def disable_test_impl(func: Callable[..., _R]) -> Callable[..., _R]: + def disable_test_impl(func: _F) -> _F: - def decorated(*args, **kwargs) -> _R: + def decorated(*args, **kwargs): if execute_func: return func(*args, **kwargs) @@ -2234,54 +2327,42 @@ def decorated(*args, **kwargs) -> _R: # The description is just for documentation purposes. -def disable_xla( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def disable_xla(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if xla is not enabled.""" execute_func = not is_xla_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_mlir_bridge( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def disable_mlir_bridge(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if MLIR bridge is not enabled.""" execute_func = not is_mlir_bridge_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_asan( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def disable_asan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if ASAN is not enabled.""" execute_func = not is_asan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_msan( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def disable_msan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if MSAN is not enabled.""" execute_func = not is_msan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_tsan( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def disable_tsan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if TSAN is not enabled.""" execute_func = not is_tsan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_ubsan( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def disable_ubsan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if UBSAN is not enabled.""" execute_func = not is_ubsan_enabled() return _disable_test(execute_func) @@ -2290,14 +2371,9 @@ def disable_ubsan( # The description is just for documentation purposes. 
def disable_tfrt( unused_description: str, # pylint: disable=unused-argument -) -> Callable[ - [Union[_TC, Callable[..., _R]]], - Union[_TC, Callable[..., _R], None] -]: +) -> Callable[[Union[_TC, _F]], Union[_TC, _F, None]]: - def disable_tfrt_impl( - cls_or_func: Union[_TC, Callable[..., _R]] - ) -> Union[_TC, Callable[..., _R], None]: + def disable_tfrt_impl(cls_or_func: Union[_TC, _F]) -> Union[_TC, _F, None]: """Execute the test only if tfrt is not enabled.""" if tf_inspect.isclass(cls_or_func): @@ -2306,8 +2382,8 @@ def disable_tfrt_impl( else: return cast(_TC, cls_or_func) else: - func = cast(Callable[..., _R], cls_or_func) - def decorated(*args, **kwargs) -> _R: + func = cast(Callable[..., Any], cls_or_func) + def decorated(*args, **kwargs): if tfrt_utils.enabled(): return else: @@ -2318,7 +2394,9 @@ def decorated(*args, **kwargs) -> _R: return disable_tfrt_impl -def for_all_test_methods(decorator, *args, **kwargs): +def for_all_test_methods( + decorator: Callable[..., Any], *args, **kwargs, +) -> Callable[[_TC], _TC]: """Generate class-level decorator from given method-level decorator. It is expected for the given decorator to take some arguments and return @@ -2333,7 +2411,7 @@ def for_all_test_methods(decorator, *args, **kwargs): decorator. """ - def all_test_methods_impl(cls): + def all_test_methods_impl(cls: _TC) -> _TC: """Apply decorator to all test methods in class.""" for name in dir(cls): value = getattr(cls, name) @@ -2346,23 +2424,19 @@ def all_test_methods_impl(cls): # The description is just for documentation purposes. -def no_xla_auto_jit( - description: str, # pylint: disable=unused-argument -) -> Callable[[Callable[..., _R]], Callable[..., _R]]: +def no_xla_auto_jit(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """This test is not intended to be run with XLA auto jit enabled.""" execute_func = not is_xla_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def xla_allow_fallback( - description: str, # pylint: disable=unused-argument -): +def xla_allow_fallback(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument - def xla_allow_fallback_impl(func: Callable[..., _R]) -> Callable[..., _R]: + def xla_allow_fallback_impl(func: _F) -> _F: """Allow fallback to TF even though testing xla.""" - def decorated(*args, **kwargs) -> _R: + def decorated(*args, **kwargs): if is_xla_enabled(): # Update the global XLABuildOpsPassFlags to enable lazy compilation, # which allows the compiler to fall back to TF classic. Remember the @@ -2380,7 +2454,9 @@ def decorated(*args, **kwargs) -> _R: # The description is just for documentation purposes. -def run_without_tensor_float_32(description): # pylint: disable=unused-argument +def run_without_tensor_float_32( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Execute test with TensorFloat-32 disabled. While almost every real-world deep learning model runs fine with @@ -2396,7 +2472,7 @@ def run_without_tensor_float_32(description): # pylint: disable=unused-argument Decorator which runs a test with TensorFloat-32 disabled. """ - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: @functools.wraps(f) def decorated(*args, **kwargs): @@ -2413,7 +2489,7 @@ def decorated(*args, **kwargs): # The description is just for documentation purposes. 
-def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument +def run_all_without_tensor_float_32(description: str) -> Callable[[_TC], _TC]: # pylint: disable=unused-argument """Execute all tests in a class with TensorFloat-32 disabled.""" return for_all_test_methods(run_without_tensor_float_32, description) @@ -2555,7 +2631,7 @@ def _ClearCachedSession(self): self._cached_session.close() self._cached_session = None - def get_temp_dir(self): + def get_temp_dir(self) -> str: """Returns a unique temporary directory for the test to use. If you call this method multiple times during in a test, it will return the @@ -2784,7 +2860,11 @@ def evaluate( # pylint: disable=redefined-outer-name @contextlib.contextmanager def session( - self, graph=None, config=None, use_gpu=True, force_gpu=False + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + force_gpu: bool = False, ) -> Iterator[s.Session]: """A context manager for a TensorFlow Session for use in executing tests. @@ -2829,11 +2909,13 @@ def testMyOperator(self): yield sess @contextlib.contextmanager - def cached_session(self, - graph=None, - config=None, - use_gpu=True, - force_gpu=False) -> Iterator[s.Session]: + def cached_session( + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + force_gpu: bool = False, + ) -> Iterator[s.Session]: """Returns a TensorFlow Session for use in executing tests. This method behaves differently than self.session(): for performance reasons @@ -2883,11 +2965,13 @@ def testMyOperator(self): @contextlib.contextmanager @deprecation.deprecated(None, "Use `self.session()` or " "`self.cached_session()` instead.") - def test_session(self, - graph=None, - config=None, - use_gpu=True, - force_gpu=False): + def test_session( + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + force_gpu: bool = False, + ) -> Iterator[s.Session]: """Use cached_session instead.""" if self.id().endswith(".test_session"): self.skipTest( @@ -2917,7 +3001,13 @@ class _CheckedThread(object): method. """ - def __init__(self, testcase, target, args=None, kwargs=None): + def __init__( + self, + testcase: "TensorFlowTestCase", + target: Callable[..., Any], + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + ): """Constructs a new instance of _CheckedThread. Args: @@ -2929,21 +3019,21 @@ def __init__(self, testcase, target, args=None, kwargs=None): """ self._testcase = testcase self._target = target - self._args = () if args is None else args - self._kwargs = {} if kwargs is None else kwargs + self._args: tuple[Any, ...] = () if args is None else args + self._kwargs: dict[str, Any] = {} if kwargs is None else kwargs self._thread = threading.Thread(target=self._protected_run) self._exception = None self._is_thread_joined = False - def _protected_run(self): + def _protected_run(self) -> None: """Target for the wrapper thread. Sets self._exception on failure.""" try: self._target(*self._args, **self._kwargs) except Exception as e: # pylint: disable=broad-except self._exception = e - def start(self): + def start(self) -> None: """Starts the thread's activity. This must be called at most once per _CheckedThread object. It arranges @@ -2951,7 +3041,7 @@ def start(self): """ self._thread.start() - def join(self): + def join(self) -> None: """Blocks until the thread terminates. 
Raises: @@ -2963,7 +3053,7 @@ def join(self): if self._exception is not None: self._testcase.fail("Error in checkedThread: %s" % str(self._exception)) - def is_alive(self): + def is_alive(self) -> bool: """Returns whether the thread is alive. This method returns True just before the run() method starts @@ -2974,7 +3064,7 @@ def is_alive(self): """ return self._thread.is_alive() - def check_termination(self): + def check_termination(self) -> None: """Returns whether the checked thread was properly used and did terminate. Every checked thread should be "join"ed after starting, and before the @@ -2996,7 +3086,12 @@ def check_termination(self): else: self._testcase.fail("A checked thread was not joined.") - def checkedThread(self, target, args=None, kwargs=None): + def checkedThread( + self, + target: Callable[..., Any], + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + ) -> _CheckedThread: """Returns a Thread wrapper that asserts 'target' completes successfully. This method should be used to create all threads in test cases, as @@ -3618,8 +3713,13 @@ def assertRaisesWithPredicateMatch(self, exception_type, else: def predicate(e): - err_str = e.message if isinstance(e, errors.OpError) else str(e) - op = e.op if isinstance(e, errors.OpError) else None + if isinstance(e, errors.OpError): + e = cast(errors.OpError, e) + err_str = cast(str, e.message) + op = e.op + else: + err_str = str(e) + op = None while op is not None: err_str += "\nCaused by: " + op.name op = op._original_op # pylint: disable=protected-access @@ -3718,7 +3818,8 @@ def assertDictEqual(self, a, b, msg=None): def _GetPyList(self, a): """Converts `a` to a nested python list.""" if isinstance(a, ragged_tensor.RaggedTensor): - return self.evaluate(a).to_list() + a = cast(ragged_tensor_value.RaggedTensorValue, self.evaluate(a)) + return a.to_list() elif isinstance(a, tensor_lib.Tensor): a = self.evaluate(a) return a.tolist() if isinstance(a, np.ndarray) else a @@ -3772,7 +3873,9 @@ def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"): # pylint: enable=invalid-name @contextlib.contextmanager - def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): + def _constrain_devices_and_set_default( + self, sess: s.Session, use_gpu: bool, force_gpu: bool, + ) -> Iterator[s.Session]: """Set the session and its graph to global default and constrain devices.""" if context.executing_eagerly(): yield None @@ -3792,10 +3895,17 @@ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): with sess.graph.device("/device:CPU:0"): yield sess - def _create_session(self, graph, config, force_gpu): + def _create_session( + self, + graph: Optional[ops.Graph], + config: Optional[config_pb2.ConfigProto], + force_gpu: bool, + ) -> s.Session: """See session() for details.""" - def prepare_config(config): + def prepare_config( + config: Optional[config_pb2.ConfigProto], + ) -> config_pb2.ConfigProto: """Returns a config for sessions. 
Args: @@ -3831,11 +3941,13 @@ def prepare_config(config): return ErrorLoggingSession(graph=graph, config=prepare_config(config)) - def _get_cached_session(self, - graph=None, - config=None, - force_gpu=False, - crash_if_inconsistent_args=True): + def _get_cached_session( + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + force_gpu: bool = False, + crash_if_inconsistent_args: bool = True, + ) -> s.Session: """See cached_session() for documentation.""" if self._cached_session is None: sess = self._create_session( @@ -3866,7 +3978,7 @@ def _get_cached_session(self, return self._cached_session -ASSIGNED_PORTS = set() +ASSIGNED_PORTS: set[int] = set() lock = threading.Lock() @@ -3889,11 +4001,13 @@ def pick_unused_port(): @tf_export("test.create_local_cluster") -def create_local_cluster(num_workers, - num_ps, - protocol="grpc", - worker_config=None, - ps_config=None): +def create_local_cluster( + num_workers: int, + num_ps: int, + protocol: str = "grpc", + worker_config: Optional[config_pb2.ConfigProto] = None, + ps_config: Optional[config_pb2.ConfigProto] = None, +) -> tuple[list[server_lib.Server], list[server_lib.Server]]: """Create and start local servers and return the associated `Server` objects. "PS" stands for "parameter server": a task responsible for storing and @@ -3976,7 +4090,9 @@ def create_local_cluster(num_workers, return workers, ps_servers -def get_node_def_from_graph(node_name, graph_def): +def get_node_def_from_graph( + node_name: str, graph_def: graph_pb2.GraphDef, +) -> Optional[node_def_pb2.NodeDef]: """Returns the `NodeDef` instance for given node name in the graph def. This method explores only the NodeDefs in `graph_def.node`. @@ -3994,7 +4110,7 @@ def get_node_def_from_graph(node_name, graph_def): return None -def set_producer_version(graph, producer_version): +def set_producer_version(graph: ops.Graph, producer_version: int) -> None: """Sets graph.graph_def_versions.producer to `producer_version`.""" # The C API doesn't expose altering GraphDefVersions. We can indirectly set # it via import_graph_def though. @@ -4059,7 +4175,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @contextlib.contextmanager -def run_functions_eagerly(run_eagerly): +def run_functions_eagerly(run_eagerly: bool) -> Iterator[None]: """Runs functions eagerly if `run_eagerly` is true. WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode @@ -4104,17 +4220,17 @@ def __init__(self, name, label): self.label = label self.Reset() - def Reset(self): + def Reset(self) -> None: self.last_value = _test_metrics_util.test_counter_value( self.name, self.label) - def Get(self): + def Get(self) -> int: value = _test_metrics_util.test_counter_value(self.name, self.label) return value - self.last_value @tf_export("test.experimental.sync_devices") -def sync_devices(): +def sync_devices() -> None: """Synchronizes all devices. By default, GPUs run asynchronously. This means that when you run an op on the From 8407c142c99a2b4b82a7452044f62406dd4b352c Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 28 Nov 2023 10:10:50 -0800 Subject: [PATCH 148/381] Failed compatibility tests were disabled for older version plugins and we can remove the compatibility support here. 
PiperOrigin-RevId: 586020786 --- .../xla/xla/python/pjrt_ifrt/pjrt_array.cc | 57 ++++++------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 92c079f56e0928..048003004c3ec8 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/log/check.h" -#include "absl/status/status.h" #include "absl/strings/str_join.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" @@ -387,55 +386,35 @@ StatusOr> PjRtArray::Reshard( "first fetched to the host and then sent to the destination " "device."); } - if (new_sharding_has_memory_kind && memories_supported && - semantics == ArrayCopySemantics::kDonateInput && !memory_kind_equal) { - return Unimplemented( - "Donation across different memory kinds is not implemented."); - } - // Try using `PjRtBuffer::CopyToMemorySpace` instead of + // Use `PjRtBuffer::CopyToMemorySpace` instead of // `PjRtBuffer::CopyToDevice` when memories are supported. Because the // semantics of the latter one is to copy to the default memory space of // the device. - std::unique_ptr copied_buffer; if (new_sharding_has_memory_kind && memories_supported) { TF_ASSIGN_OR_RETURN( auto memory_space, GetMemorySpaceFromMemoryKind(new_sharding->devices()[i], canonicalized_sharding_memory_kind)); - StatusOr> copied_buffer_using_memory_space = - pjrt_buffers_[i]->CopyToMemorySpace(memory_space); - if (copied_buffer_using_memory_space.ok()) { - copied_buffer = std::move(*copied_buffer_using_memory_space); - } else if (!absl::IsUnimplemented( - copied_buffer_using_memory_space.status())) { - return copied_buffer_using_memory_space.status(); - } else { - // Returns unimplemented if the sharding's memory space isn't the - // device's default memory space. Otherwise continue on to the - // CopyToDevice fallback. - // TODO(b/307743645): clean up this branch when memory space is better - // supported. - TF_ASSIGN_OR_RETURN( - PjRtMemorySpace * default_memory_space, - new_sharding->devices()[i]->default_memory_space()); - if (canonicalized_sharding_memory_kind.memory_kind() != - default_memory_space->memory_space_kind()) { - return copied_buffer_using_memory_space.status(); + TF_ASSIGN_OR_RETURN(std::unique_ptr copied_buffer, + pjrt_buffers_[i]->CopyToMemorySpace(memory_space)); + if (semantics == ArrayCopySemantics::kDonateInput) { + if (!memory_kind_equal) { + return Unimplemented( + "Donation across different memory kinds is not implemented."); } + pjrt_buffers_[i] = nullptr; } + buffers.push_back(std::shared_ptr(copied_buffer.release())); + } else { + // Use `PjRtBuffer::CopyToDevice` when memories are not supported. + TF_ASSIGN_OR_RETURN( + std::unique_ptr copied_buffer, + pjrt_buffers_[i]->CopyToDevice(new_sharding->devices()[i])); + if (semantics == ArrayCopySemantics::kDonateInput) { + pjrt_buffers_[i] = nullptr; + } + buffers.push_back(std::shared_ptr(copied_buffer.release())); } - // Fallback to `PjRtBuffer::CopyToDevice` if (1) memories are not - // supported or (2) `PjRtBuffer::CopyToMemorySpace` returns unimplemented - // and canonicalized_sharding_memory_kind is the same as the - // default_memory_space of `new_sharding->devices()[i]`. 
- if (copied_buffer == nullptr) { - TF_ASSIGN_OR_RETURN(copied_buffer, pjrt_buffers_[i]->CopyToDevice( - new_sharding->devices()[i])); - } - if (semantics == ArrayCopySemantics::kDonateInput) { - pjrt_buffers_[i] = nullptr; - } - buffers.push_back(std::shared_ptr(copied_buffer.release())); } } return PjRtArray::Create(client_, dtype_, shape_, std::move(new_sharding), From 8d84d04265b96110213cedd09e8ef39278ff063c Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 28 Nov 2023 10:25:59 -0800 Subject: [PATCH 149/381] `third_party/gpus` changes from PR #7277 that were missed https://github.com/openxla/xla/pull/7277/files PiperOrigin-RevId: 586026438 --- .../clang/bin/crosstool_wrapper_driver_rocm.tpl | 4 ++-- third_party/gpus/rocm_configure.bzl | 10 ++++++---- .../clang/bin/crosstool_wrapper_driver_rocm.tpl | 4 ++-- .../tsl/third_party/gpus/rocm_configure.bzl | 10 ++++++---- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 8fb22313010a45..77ec948af32c6e 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -86,8 +86,8 @@ def GetHostCompilerOptions(argv): opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) if args.g: opts += ' -g' + ' -g'.join(sum(args.g, [])) - #if args.fno_canonical_system_headers: - # opts += ' -fno-canonical-system-headers' + if args.fno_canonical_system_headers: + opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 90464b07264101..520c9bce6c5265 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -198,6 +198,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") @@ -345,14 +347,14 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths = [ (name, _rocm_lib_paths(repository_ctx, name, path)) for name, path in [ - ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"), + ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), ("hipsparse", rocm_config.rocm_toolkit_path), - ("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"), + ("roctracer64", rocm_config.rocm_toolkit_path), ("rocsolver", rocm_config.rocm_toolkit_path), ] ] @@ -694,7 +696,7 @@ def _create_local_rocm_repository(repository_ctx): rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", - "-D__HIP_PLATFORM_HCC__", + "-D__HIP_PLATFORM_AMD__", "-DEIGEN_USE_HIP", ]) @@ -729,7 +731,7 @@ def _create_local_rocm_repository(repository_ctx): "%{hipcc_env}": _hipcc_env(repository_ctx), "%{rocr_runtime_path}": 
rocm_config.rocm_toolkit_path + "/lib", "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib", + "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 8fb22313010a45..77ec948af32c6e 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -86,8 +86,8 @@ def GetHostCompilerOptions(argv): opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) if args.g: opts += ' -g' + ' -g'.join(sum(args.g, [])) - #if args.fno_canonical_system_headers: - # opts += ' -fno-canonical-system-headers' + if args.fno_canonical_system_headers: + opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index 0bbbc09832db13..5c1195bada43f8 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -198,6 +198,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") @@ -345,14 +347,14 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths = [ (name, _rocm_lib_paths(repository_ctx, name, path)) for name, path in [ - ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"), + ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), ("hipsparse", rocm_config.rocm_toolkit_path), - ("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"), + ("roctracer64", rocm_config.rocm_toolkit_path), ("rocsolver", rocm_config.rocm_toolkit_path), ] ] @@ -694,7 +696,7 @@ def _create_local_rocm_repository(repository_ctx): rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", - "-D__HIP_PLATFORM_HCC__", + "-D__HIP_PLATFORM_AMD__", "-DEIGEN_USE_HIP", ]) @@ -729,7 +731,7 @@ def _create_local_rocm_repository(repository_ctx): "%{hipcc_env}": _hipcc_env(repository_ctx), "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib", + "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), 
From 1436d4d6ea8a6c2c588adfaf965458944d06f411 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Tue, 28 Nov 2023 10:30:51 -0800 Subject: [PATCH 150/381] Rewrite debugger functionalities in c++. Factored out a common pattern of mutating `NodeDef`s by iterating all node defs in a `GraphDef` into a templated function and applied it for both `enable_dump_tensor` and `change_dump_tensor_file_name`. PiperOrigin-RevId: 586028409 --- .../mlir/quantization/stablehlo/cc/BUILD | 21 +++++++ .../quantization/stablehlo/cc/graph_def.h | 46 ++++++++++++++ .../stablehlo/cc/graph_def_test.cc | 62 +++++++++++++++++++ .../mlir/quantization/stablehlo/python/BUILD | 3 + .../stablehlo/python/pywrap_quantization.cc | 32 ++++++++-- .../mlir/quantization/tensorflow/python/BUILD | 2 + .../tensorflow/python/py_function_lib.h | 27 -------- .../tensorflow/python/py_function_lib.py | 55 ---------------- .../tensorflow/python/pywrap_function_lib.cc | 17 +---- .../tensorflow/python/pywrap_function_lib.pyi | 10 --- .../python/pywrap_quantize_model.cc | 27 ++++++-- 11 files changed, 184 insertions(+), 118 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index f63ffe76399657..511757314e2da7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -52,3 +52,24 @@ tf_cc_test( "@local_tsl//tsl/platform:types", ], ) + +cc_library( + name = "graph_def", + srcs = [], + hdrs = ["graph_def.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "graph_def_test", + srcs = ["graph_def_test.cc"], + deps = [ + ":graph_def", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:protobuf", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h new file mode 100644 index 00000000000000..5796b18e65d632 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ + +#include + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace stablehlo::quantization { + +// Mutates all `NodeDef`s in `graph_def` by applying `func`. It modifies the +// top-level `NodeDef`s as well as all `NodeDef`s in the function library. 
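+//
+// Illustrative usage sketch (the lambda and device string below are examples
+// only, not part of this header): apply the same mutation to every node,
+// including nodes nested inside library functions, e.g.
+//
+//   tensorflow::GraphDef graph_def = ...;
+//   MutateNodeDefs(graph_def, [](tensorflow::NodeDef& node_def) {
+//     node_def.set_device("/device:CPU:0");
+//   });
+//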
+// `func` should accept a `NodeDef` reference. +template >> +void MutateNodeDefs(tensorflow::GraphDef& graph_def, FuncT&& func) { + for (tensorflow::NodeDef& node_def : *graph_def.mutable_node()) { + func(node_def); + } + + for (tensorflow::FunctionDef& function_def : + *graph_def.mutable_library()->mutable_function()) { + for (tensorflow::NodeDef& node_def : *function_def.mutable_node_def()) { + func(node_def); + } + } +} + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc new file mode 100644 index 00000000000000..58796acc4231bf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" + +#include +#include +#include "tensorflow/core/framework/node_def.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; + +TEST(GraphDefTest, MutateNodeDefsMutatesTopLevelNodeDefs) { + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + node { name: "foo" } + )pb", + &graph_def)); + MutateNodeDefs(graph_def, + [](NodeDef& node_def) { node_def.set_name("bar"); }); + + ASSERT_THAT(graph_def.node(), SizeIs(1)); + EXPECT_THAT(graph_def.node()[0].name(), StrEq("bar")); +} + +TEST(GraphDefTest, MutateNodeDefsMutatesFunctionNodeDefs) { + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + library { function { node_def { name: "foo" } } } + )pb", + &graph_def)); + + MutateNodeDefs(graph_def, + [](NodeDef& node_def) { node_def.set_name("bar"); }); + + ASSERT_THAT(graph_def.library().function(), SizeIs(1)); + ASSERT_THAT(graph_def.library().function()[0].node_def(), SizeIs(1)); + EXPECT_THAT(graph_def.library().function()[0].node_def()[0].name(), + StrEq("bar")); +} + +} // namespace +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 4ef4ad4ce1a720..b8a559b51786e1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -85,13 +85,16 @@ tf_python_pybind_extension( srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", 
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:env", "@pybind11", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index ce66277ddd0e5d..ca4518462779e5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" #include "absl/strings/string_view.h" #include "pybind11/detail/common.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 @@ -25,19 +26,27 @@ limitations under the License. #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tsl/platform/env.h" namespace py = pybind11; namespace { +using ::stablehlo::quantization::MutateNodeDefs; using ::stablehlo::quantization::io::CreateTmpDir; +using ::tensorflow::FunctionDef; +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; @@ -62,8 +71,16 @@ ExportedModel EnableDebugging( const std::unordered_set& tags, const absl::flat_hash_map& signature_def_map) { ExportedModel debugger_enabled_exported_model = exported_model; - *debugger_enabled_exported_model.mutable_graph_def() = - py_function_library.EnableDumpTensor(exported_model.graph_def()); + + // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by + // default to avoid logging data during calibration. + MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), + [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["enabled"].set_b(true); + } + }); + if (debugger_options.debugger_type() == DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. 
@@ -74,9 +91,14 @@ ExportedModel EnableDebugging( debugger_enabled_exported_model, src_saved_model_path, tags, signature_def_map); - *debugger_enabled_exported_model.mutable_graph_def() = - py_function_library.ChangeDumpTensorFileName( - debugger_enabled_exported_model.graph_def()); + // Update the `DumpTensor` ops' file name in `graph_def`. + MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), + [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["file_name"].set_s( + "quantized_tensor_data.pb"); + } + }); } return debugger_enabled_exported_model; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index b9cc102aa33b1b..d3c86b151e7912 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -203,9 +203,11 @@ tf_python_pybind_extension( deps = [ ":py_function_lib", ":type_casters", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/python/lib/core:pybind11_lib", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h index 96d937a7e59a9c..eb064f24a42e48 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h @@ -103,33 +103,6 @@ class PyFunctionLibrary { // pywrap_function_lib.pyi:run_calibration, // py_function_lib.py:run_calibration, // ) - - // Enables the `DumpTensor` ops in `graph_def`. This is done by updating the - // `enabled` attribute of `DumpTensor` ops to true. Returns the updated - // `GraphDef`. - // - // If the function signature changes, likely its corresponding .pyi type - // hinting and definition should also change. - // LINT.IfChange - virtual GraphDef EnableDumpTensor(const GraphDef& graph_def) const = 0; - // LINT.ThenChange( - // pywrap_function_lib.pyi:enable_dump_tensor, - // py_function_lib.py:enable_dump_tensor, - // ) - - // Updates the `DumpTensor` ops' file name in `graph_def`. Sets the - // `file_name` attribute to `quantized_tensor_data.pb`. Returns the updated - // `GraphDef`. - // - // If the function signature changes, likely its corresponding .pyi type - // hinting and definition should also change. 
- // LINT.IfChange - virtual GraphDef ChangeDumpTensorFileName( - const GraphDef& graph_def) const = 0; - // LINT.ThenChange( - // pywrap_function_lib.pyi:change_dump_tensor_file_name, - // py_function_lib.py:change_dump_tensor_file_name, - // ) }; } // namespace tensorflow::quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index c55f11c98b23b7..6c787a58357600 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -697,58 +697,3 @@ def run_calibration( _add_calibration_statistics(exported_model.graph_def, calibration_options) return exported_model.SerializeToString() - - # TODO: b/312371048 - Rewrite this in c++. - # LINT.IfChange(enable_dump_tensor) - def enable_dump_tensor(self, graph_def_serialized: bytes) -> bytes: - """Enable DumpTensor in the graph def. - - DumpTensor is disabled by default to avoid logging data during calibration. - This function is called after calibration to enable DumpTensor. - - Args: - graph_def_serialized: Serialized `GraphDef` to enable DumpTensor - - Returns: - Updated serialized GraphDef where DumpTensors are enabled. - """ - # LINT.ThenChange(py_function_lib.h:enable_dump_tensor) - graph_def = graph_pb2.GraphDef.FromString(graph_def_serialized) - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'DumpTensor': - continue - - node_def.attr['enabled'].b = True - - return graph_def.SerializeToString() - - # TODO: b/312371048 - Rewrite this in c++. - # LINT.IfChange(change_dump_tensor_file_name) - def change_dump_tensor_file_name(self, graph_def_serialized: bytes) -> bytes: - # LINT.ThenChange(py_function_lib.h:change_dump_tensor_file_name) - """Change file_name used by DumpTensor to quantized_tensor_data.pb. - - In whole model verify, DumpTensor in unquantized model uses file_name - unquantized_tensor_data.pb. - After unquantized dump model is created, this function allows quantized dump - model to use quantized_tensor_data.pb as file_name. - - Args: - graph_def_serialized: Serialized `GraphDef` to change file_name of - DumpTensor - - Returns: - Serialized GraphDef with updated file names for DumpTensors. 
- """ - graph_def = graph_pb2.GraphDef.FromString(graph_def_serialized) - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'DumpTensor': - continue - - node_def.attr['file_name'].s = 'quantized_tensor_data.pb'.encode( - 'utf-8' - ) - - return graph_def.SerializeToString() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index f850b7effbe0fd..928bcd3690c307 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -79,16 +79,6 @@ class PyFunctionLibraryTrampoline : public PyFunctionLibrary { signature_keys, tags, exported_model, calibration_options, force_graph_mode_calibration, representative_dataset); } - - GraphDef EnableDumpTensor(const GraphDef& graph_def) const override { - PYBIND11_OVERRIDE_PURE(GraphDef, PyFunctionLibrary, enable_dump_tensor, - graph_def); - } - - GraphDef ChangeDumpTensorFileName(const GraphDef& graph_def) const override { - PYBIND11_OVERRIDE_PURE(GraphDef, PyFunctionLibrary, - change_dump_tensor_file_name, graph_def); - } }; } // namespace @@ -110,10 +100,5 @@ PYBIND11_MODULE(pywrap_function_lib, m) { py::arg("tags"), py::arg("exported_model_serialized"), py::arg("calibration_options_serialized"), py::arg("force_graph_mode_calibration"), - py::arg("representative_dataset")) - .def("enable_dump_tensor", &PyFunctionLibrary::EnableDumpTensor, - py::arg("graph_def_serialized")) - .def("change_dump_tensor_file_name", - &PyFunctionLibrary::ChangeDumpTensorFileName, - py::arg("graph_def_serialized")); + py::arg("representative_dataset")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi index d464670c5593f0..8d94442f5fcf0d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi @@ -45,13 +45,3 @@ class PyFunctionLibrary: representative_dataset: Any, ) -> bytes: ... # LINT.ThenChange() - - # LINT.IfChange(enable_dump_tensor) - def enable_dump_tensor(self, graph_def_serialized: bytes) -> bytes: ... - # LINT.ThenChange() - - # LINT.IfChange(change_dump_tensor_file_name) - def change_dump_tensor_file_name( - self, graph_def_serialized: bytes - ) -> bytes: ... - # LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index bf56d8607d9058..5f9c9254fe1940 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -29,18 +29,22 @@ limitations under the License. 
#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/python/lib/core/pybind11_lib.h" namespace { +using ::stablehlo::quantization::MutateNodeDefs; using ::stablehlo::quantization::io::CreateTmpDir; +using ::tensorflow::NodeDef; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; @@ -69,8 +73,16 @@ ExportedModel EnableDebugging( const std::unordered_set& tags, const absl::flat_hash_map& signature_def_map) { ExportedModel debugger_enabled_exported_model = exported_model; - *debugger_enabled_exported_model.mutable_graph_def() = - py_function_library.EnableDumpTensor(exported_model.graph_def()); + + // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by + // default to avoid logging data during calibration. + MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), + [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["enabled"].set_b(true); + } + }); + if (debugger_options.debugger_type() == DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. @@ -81,9 +93,14 @@ ExportedModel EnableDebugging( debugger_enabled_exported_model, src_saved_model_path, tags, signature_def_map); - *debugger_enabled_exported_model.mutable_graph_def() = - py_function_library.ChangeDumpTensorFileName( - debugger_enabled_exported_model.graph_def()); + // Update the `DumpTensor` ops' file name in `graph_def`. + MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), + [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["file_name"].set_s( + "quantized_tensor_data.pb"); + } + }); } return debugger_enabled_exported_model; From 44ac2acf95b96335a0be305ae1735fa4cb0b1591 Mon Sep 17 00:00:00 2001 From: Yaning Liang Date: Tue, 28 Nov 2023 12:27:32 -0800 Subject: [PATCH 151/381] Add missing #includes that define symbols referenced by simple_*delegate* Also remove some unused #includes from the .cc files. Also use "= default" syntax for destructor. 
PiperOrigin-RevId: 586069299 --- tensorflow/lite/delegates/utils/BUILD | 5 +---- tensorflow/lite/delegates/utils/simple_delegate.cc | 6 +----- tensorflow/lite/delegates/utils/simple_delegate.h | 5 +---- tensorflow/lite/delegates/utils/simple_delegate_test.cc | 3 +-- .../lite/delegates/utils/simple_opaque_delegate.cc | 7 +++---- tensorflow/lite/delegates/utils/simple_opaque_delegate.h | 3 --- .../lite/delegates/utils/simple_opaque_delegate_test.cc | 9 --------- 7 files changed, 7 insertions(+), 31 deletions(-) diff --git a/tensorflow/lite/delegates/utils/BUILD b/tensorflow/lite/delegates/utils/BUILD index 9bb7079a8a9386..924eb3cb0c02a4 100644 --- a/tensorflow/lite/delegates/utils/BUILD +++ b/tensorflow/lite/delegates/utils/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -19,7 +19,6 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/lite:array", "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/c:common", @@ -53,7 +52,6 @@ cc_library_with_tflite( "//tensorflow/lite/c:common", ], deps = [ - "//tensorflow/lite:array", "//tensorflow/lite:builtin_ops", "//tensorflow/lite:minimal_logging", "//tensorflow/lite:util", @@ -77,7 +75,6 @@ cc_test( data = [":c_api_test_builtin_op_models"], deps = [ ":simple_opaque_delegate", - "//tensorflow/lite:framework_stable", "//tensorflow/lite/c:c_api", "//tensorflow/lite/c:c_api_experimental", "//tensorflow/lite/c:c_api_types", diff --git a/tensorflow/lite/delegates/utils/simple_delegate.cc b/tensorflow/lite/delegates/utils/simple_delegate.cc index 6b0401ba0385d9..7b3e9647a052a7 100644 --- a/tensorflow/lite/delegates/utils/simple_delegate.cc +++ b/tensorflow/lite/delegates/utils/simple_delegate.cc @@ -14,20 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/utils/simple_delegate.h" -#include -#include - #include #include #include #include -#include "tensorflow/lite/array.h" #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" namespace tflite { diff --git a/tensorflow/lite/delegates/utils/simple_delegate.h b/tensorflow/lite/delegates/utils/simple_delegate.h index 8b65fa892d4004..0fb67abc5f7621 100644 --- a/tensorflow/lite/delegates/utils/simple_delegate.h +++ b/tensorflow/lite/delegates/utils/simple_delegate.h @@ -29,10 +29,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ -#include - #include -#include #include "tensorflow/lite/core/c/common.h" @@ -45,7 +42,7 @@ using TfLiteDelegateUniquePtr = // Each instance represents a single part of the graph (subgraph). class SimpleDelegateKernelInterface { public: - virtual ~SimpleDelegateKernelInterface() = default; + virtual ~SimpleDelegateKernelInterface() {} // Initializes a delegated subgraph. 
// The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace diff --git a/tensorflow/lite/delegates/utils/simple_delegate_test.cc b/tensorflow/lite/delegates/utils/simple_delegate_test.cc index fc589f1fb2dfc7..f2d9d352551601 100644 --- a/tensorflow/lite/delegates/utils/simple_delegate_test.cc +++ b/tensorflow/lite/delegates/utils/simple_delegate_test.cc @@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include #include +#include #include #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc b/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc index 8d0b7aba519e45..db22cbe891f7c3 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc @@ -14,19 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include -#include - +#include #include +#include #include -#include "tensorflow/lite/array.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_opaque.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace { diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate.h b/tensorflow/lite/delegates/utils/simple_opaque_delegate.h index b6f8e91946ee58..817186c01a1cc3 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate.h +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate.h @@ -31,10 +31,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_OPAQUE_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_OPAQUE_DELEGATE_H_ -#include - #include -#include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc b/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc index 953f348d734b41..dd6155b4b52667 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate_test.cc @@ -14,29 +14,20 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include -#include -#include - #include #include #include #include #include -#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_opaque.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/delegate_test_util.h" #include "tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/interpreter_builder.h" -#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model_builder.h" namespace tflite { From 3c9304accf67e1b4856c0af44eef249a4258ef41 Mon Sep 17 00:00:00 2001 From: Robert David Date: Tue, 28 Nov 2023 14:03:13 -0800 Subject: [PATCH 152/381] Make the allocated memory array size private, only allow querying the data size. This is an implementation detail. PiperOrigin-RevId: 586097271 --- tensorflow/lite/simple_memory_arena.cc | 22 +++++++++++++--------- tensorflow/lite/simple_memory_arena.h | 17 ++++------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index a61f68c7cdb97e..d2885f4bc8cd84 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -39,6 +39,10 @@ T AlignTo(size_t alignment, T offset) { : offset + (alignment - offset % alignment); } +size_t RequiredAllocationSize(size_t data_array_size, size_t alignment) { + return data_array_size + alignment - 1; +} + } // namespace namespace tflite { @@ -48,16 +52,17 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { // Skip reallocation when resizing down. 
return false; } - const size_t new_allocation_size = RequiredAllocationSize(new_size); #ifdef TF_LITE_TENSORFLOW_PROFILER PauseHeapMonitoring(/*pause=*/true); OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), - new_allocation_size); + new_size); #endif + const size_t new_allocation_size = + RequiredAllocationSize(new_size, alignment_); char* new_buffer = reinterpret_cast(std::malloc(new_allocation_size)); #if defined(__clang__) #if __has_feature(memory_sanitizer) - memset(new_buffer, 0, new_allocation_size); + std::memset(new_buffer, 0, new_allocation_size); #endif #endif char* new_aligned_ptr = reinterpret_cast( @@ -73,8 +78,7 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { #ifdef TF_LITE_TENSORFLOW_PROFILER if (data_size_ > 0) { OnTfLiteArenaDealloc(subgraph_index_, - reinterpret_cast(this), - RequiredAllocationSize(data_size_)); + reinterpret_cast(this), data_size_); } #endif data_size_ = new_size; @@ -90,7 +94,7 @@ void ResizableAlignedBuffer::Release() { } #ifdef TF_LITE_TENSORFLOW_PROFILER OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - RequiredAllocationSize(data_size_)); + data_size_); #endif std::free(buffer_); buffer_ = nullptr; @@ -209,8 +213,8 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - TF_LITE_ENSURE( - context, underlying_buffer_.GetDataSize() >= (alloc.offset + alloc.size)); + TF_LITE_ENSURE(context, + underlying_buffer_.GetSize() >= (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { @@ -240,7 +244,7 @@ TFLITE_ATTRIBUTE_WEAK void DumpArenaInfo( void SimpleMemoryArena::DumpDebugInfo( const std::string& name, const std::vector& execution_plan) const { - tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_.GetDataSize(), + tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_.GetSize(), active_allocs_); } diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index c72b3595919a88..05a26ccc20d27d 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -76,12 +76,9 @@ class ResizableAlignedBuffer { // Pointer to the data array. char* GetPtr() const { return aligned_ptr_; } - // Size of the data array (NOT of the allocation). - size_t GetDataSize() const { return data_size_; } - // Size of the allocation (NOT of the data array). - size_t GetAllocationSize() const { - return RequiredAllocationSize(data_size_); - } + // Size of the data array. Note: the allocated memory block might be larger + // due to excess alignment requirements. + size_t GetSize() const { return data_size_; } // Alignment of the data array. size_t GetAlignment() const { return alignment_; } @@ -91,10 +88,6 @@ class ResizableAlignedBuffer { ResizableAlignedBuffer(ResizableAlignedBuffer&&) = delete; ResizableAlignedBuffer& operator=(ResizableAlignedBuffer&&) = delete; - size_t RequiredAllocationSize(size_t data_array_size) const { - return data_array_size + alignment_ - 1; - } - char* buffer_; size_t data_size_; size_t alignment_; @@ -161,9 +154,7 @@ class SimpleMemoryArena { // again until Commit() is called & tensor allocations are resolved. 
TfLiteStatus ReleaseBuffer(); - size_t GetBufferSize() const { - return underlying_buffer_.GetAllocationSize(); - } + size_t GetBufferSize() const { return underlying_buffer_.GetSize(); } std::intptr_t BasePointer() const { return reinterpret_cast(underlying_buffer_.GetPtr()); From ddc414002cdb2345399dc63d23ce00b0258a99fa Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Tue, 28 Nov 2023 14:37:04 -0800 Subject: [PATCH 153/381] Remove a dangling TODO comment. There is apparently no feasible way of resolving the TODO comment. PiperOrigin-RevId: 586106981 --- .../mlir/quantization/stablehlo/passes/quantization_pattern.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h index 133bafab2978fe..3922374e402353 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h @@ -76,8 +76,6 @@ bool IsOpQuantizableStableHlo(Operation* op); // // Implementation of this pattern is mostly copied from QuantizationPattern in // third_party/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h. -// TODO - b/310545259 : Split declarations and implementations of -// StableHloQuantizationPattern. template class StableHloQuantizationPattern : public RewritePattern { From 37f97080ae958920809d3a13c45c89b0469dbc56 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 28 Nov 2023 16:03:46 -0800 Subject: [PATCH 154/381] [xla:runtime] NFC: Remove runner library PiperOrigin-RevId: 586130880 --- third_party/xla/xla/runtime/runner/BUILD | 91 ----- third_party/xla/xla/runtime/runner/runner.cc | 342 ------------------ third_party/xla/xla/runtime/runner/runner.h | 56 --- .../xla/xla/runtime/runner/runner.proto | 59 --- third_party/xla/xla/runtime/runner/runner.py | 149 -------- .../xla/xla/runtime/runner/testlib_runner.cc | 37 -- .../xla/runtime/runner/testlib_runner_test.py | 85 ----- 7 files changed, 819 deletions(-) delete mode 100644 third_party/xla/xla/runtime/runner/BUILD delete mode 100644 third_party/xla/xla/runtime/runner/runner.cc delete mode 100644 third_party/xla/xla/runtime/runner/runner.h delete mode 100644 third_party/xla/xla/runtime/runner/runner.proto delete mode 100644 third_party/xla/xla/runtime/runner/runner.py delete mode 100644 third_party/xla/xla/runtime/runner/testlib_runner.cc delete mode 100644 third_party/xla/xla/runtime/runner/testlib_runner_test.py diff --git a/third_party/xla/xla/runtime/runner/BUILD b/third_party/xla/xla/runtime/runner/BUILD deleted file mode 100644 index ff3276c1d65901..00000000000000 --- a/third_party/xla/xla/runtime/runner/BUILD +++ /dev/null @@ -1,91 +0,0 @@ -load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") # maybe @unused in OSS -load("//xla:xla.bzl", "xla_py_proto_library") -load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") - -package( - default_visibility = ["//visibility:public"], - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -cc_library( - name = "runner_lib", - srcs = ["runner.cc"], - hdrs = ["runner.h"], - visibility = ["//visibility:public"], - deps = [ - ":runner_proto_cc", - "//xla/runtime:arguments", - "//xla/runtime:executable", - "//xla/runtime:jit_executable", - "//xla/runtime:logical_result", - "//xla/runtime:results", - "//xla/runtime:types", - 
"@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/util:command_line_flags", - ], -) - -tf_proto_library( - name = "runner_proto", - srcs = ["runner.proto"], - cc_api_version = 2, - protodeps = ["//xla:xla_data_proto"], - visibility = ["//visibility:public"], -) - -xla_py_proto_library( - name = "runner_pb2", - api_version = 2, - visibility = ["//visibility:public"], - deps = [":runner_proto"], -) - -xla_py_proto_library( - name = "xla_data_pb2", - api_version = 2, - visibility = ["//visibility:public"], - deps = ["//xla:xla_data_proto"], -) - -py_strict_library( - name = "runner", - testonly = True, - srcs = ["runner.py"], - deps = [ - ":runner_proto_py", - "//third_party/py/numpy", - "//xla:xla_data_proto_py", - ], -) - -# copybara:uncomment_begin(b/254857628) -# py_strict_test( -# name = "testlib_runner_test", -# size = "small", -# srcs = ["testlib_runner_test.py"], -# data = [":testlib_runner"], -# python_version = "PY3", -# srcs_version = "PY3", -# deps = [ -# ":runner", -# "//third_party/py/numpy", -# "@absl_py//absl/testing:absltest", -# ], -# ) -# -# cc_binary( -# name = "testlib_runner", -# testonly = True, -# srcs = ["testlib_runner.cc"], -# deps = [ -# ":runner_lib", -# "//xla/mlir/runtime/transforms/tests:testlib_pipeline", -# ], -# ) -# copybara:uncomment_end diff --git a/third_party/xla/xla/runtime/runner/runner.cc b/third_party/xla/xla/runtime/runner/runner.cc deleted file mode 100644 index 41afe3d146f321..00000000000000 --- a/third_party/xla/xla/runtime/runner/runner.cc +++ /dev/null @@ -1,342 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/runtime/runner/runner.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "xla/runtime/arguments.h" -#include "xla/runtime/logical_result.h" -#include "xla/runtime/results.h" -#include "xla/runtime/runner/runner.pb.h" -#include "xla/runtime/types.h" -#include "tsl/platform/env.h" -#include "tsl/platform/init_main.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/protobuf.h" - -namespace xla { -namespace runtime { - -using absl::InternalError; -using absl::InvalidArgumentError; -using absl::StrFormat; - -using tsl::Env; -using tsl::ReadBinaryProto; -using tsl::ReadFileToString; -using tsl::ReadTextProto; -using tsl::WriteBinaryProto; -using tsl::WriteTextProto; - -using RunnerArgs = Arguments; - -void AppendRunnerFlags(std::vector* flag_list, RunnerFlags* flags) { - flag_list->emplace_back("function", &flags->function, "Test function name."); - - flag_list->emplace_back("module", &flags->module_path, "Path to MLIR input."); - - flag_list->emplace_back( - "arguments", &flags->arguments_path, - "Path to arguments file. If the file ends in '.pbtxt' it is expected to " - "be in the human-readable proto text format, otherwise it is expected " - "to be in the proto binary format."); - - flag_list->emplace_back( - "results", &flags->results_path, - "Path to results file. The runner tool will serialize results into a " - " proto message and write it to this file path."); -} -//===----------------------------------------------------------------------===// - -AsyncTaskRunner* NoAsyncTaskRunner() { - return reinterpret_cast(0xDEADBEEF); -} - -//===----------------------------------------------------------------------===// -// Helper functions to Read/Write protobuf messages. -//===----------------------------------------------------------------------===// - -template -static absl::Status ReadProtoFile(Env* env, const std::string& fname, - T* proto) { - if (absl::EndsWith(fname, ".pbtxt")) { - return ReadTextProto(env, fname, proto); - } else { - return ReadBinaryProto(env, fname, proto); - } -} - -template -static absl::Status WriteProtoFile(Env* env, const std::string& fname, - T& proto) { - if (absl::EndsWith(fname, ".pbtxt")) { - return WriteTextProto(env, fname, proto); - } else { - return WriteBinaryProto(env, fname, proto); - } -} - -//===----------------------------------------------------------------------===// -// Convert ArgumentsProto message to Xla runtime arguments. -//===----------------------------------------------------------------------===// - -static absl::Status ConvertScalar(const ScalarProto& scalar, RunnerArgs& args) { - switch (scalar.value_case()) { - case ScalarProto::ValueCase::kI32: - args.emplace_back(scalar.i32()); - break; - case ScalarProto::ValueCase::kI64: - args.emplace_back(scalar.i64()); - break; - default: - return InvalidArgumentError( - StrFormat("unsupported scalar argument: %s", scalar.DebugString())); - } - return absl::OkStatus(); -} - -static absl::Status ConvertTensor(const TensorProto& tensor, RunnerArgs& args) { - args.emplace_back( - tensor.dtype(), - static_cast(const_cast(&tensor.contents())), - /*offset=*/0, tensor.sizes(), tensor.strides()); - return absl::OkStatus(); -} - -// Converts arguments protobuf message into Xla runtime arguments. 
-static absl::Status ConvertArgs(ArgumentsProto& proto, RunnerArgs& args) { - for (auto& arg : proto.arguments()) { - switch (arg.argument_case()) { - // Convert `ScalarProto` -> `ScalarArg`. - case ArgumentProto::ArgumentCase::kScalar: - if (auto st = ConvertScalar(arg.scalar(), args); !st.ok()) return st; - break; - // Convert `TensorProto` -> `MemrefDesc`. - case ArgumentProto::ArgumentCase::kTensor: - if (auto st = ConvertTensor(arg.tensor(), args); !st.ok()) return st; - break; - // Unsupported argument type. - default: - return InvalidArgumentError( - StrFormat("unsupported argument: %s", arg.DebugString())); - } - } - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Convert returned results to ResultsProto message. -//===----------------------------------------------------------------------===// - -// TODO(ezhulenev): Implement error propagation through the results proto. -static void CheckNoError(const absl::Status& status) { - LOG(FATAL) << "Unexpected call to `ReturnError`"; -} - -// Converts results returned from compiled Xla executable to results proto. -struct ReturnResults { - LogicalResult operator()(unsigned result_index, const Type* type, - const Type* runtime_type, void* ret) const { - // We rely on the fact that result converter handles results from left to - // right and we can push new results to the back of the list. - auto* result = proto->add_results(); - - // Return scalar result as `ScalarProto`. - auto* scalar = llvm::dyn_cast(type); - switch (scalar ? scalar->type() : PrimitiveType::PRIMITIVE_TYPE_INVALID) { - case PrimitiveType::S32: - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(ret, sizeof(int32_t)); - result->mutable_scalar()->set_i32(*reinterpret_cast(ret)); - return success(); - default: - break; - } - - // Assuming result cannot be processed as Scalar, try `TensorProto` - auto* memref = llvm::dyn_cast(runtime_type); - if (memref) { - auto desc = ConvertReturnedMemref(*this, memref, ret); - if (failed(desc)) return failure(); - - char* data = static_cast(desc->data()); - int64_t size_in_bytes = primitive_util::ByteWidth(desc->dtype()); - - TensorProto* tensor_proto = result->mutable_tensor(); - for (int64_t size : desc->sizes()) { - size_in_bytes *= size; - tensor_proto->add_sizes(size); - } - - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(data, size_in_bytes); - tensor_proto->set_contents(std::string(data, size_in_bytes)); - tensor_proto->set_dtype(desc->dtype()); - - std::free(desc->data()); - return success(); - } - - return failure(); - } - - MemrefDesc operator()(PrimitiveType element_type, void* base_ptr, - void* data_ptr, int64_t offset, - absl::Span sizes, - absl::Span strides) const { - return MemrefDesc(element_type, base_ptr, offset, sizes, strides); - } - - ResultsProto* proto = nullptr; -}; - -// Converts arguments protobuf message into Xla runtime arguments. 
-static absl::Status WriteInoutResults(ArgumentsProto& proto, RunnerArgs& args, - ResultsProto* results) { - for (int i = 0; i < proto.arguments().size(); ++i) { - ArgumentProto arg = proto.arguments().Get(i); - switch (arg.argument_case()) { - case ArgumentProto::ArgumentCase::kScalar: - continue; - case ArgumentProto::ArgumentCase::kTensor: - if (arg.tensor().inout()) { - auto* result = results->add_results(); - TensorProto* tensor_proto = result->mutable_tensor(); - - auto* memref = llvm::cast(&args[i]); - - char* sv = static_cast(memref->data()); - int64_t size_in_bytes = primitive_util::ByteWidth(memref->dtype()); - - for (int64_t size : memref->sizes()) { - size_in_bytes *= size; - tensor_proto->add_sizes(size); - } - - tensor_proto->set_contents(std::string(sv, size_in_bytes)); - tensor_proto->set_dtype(memref->dtype()); - } - break; - // Unsupported argument type. - default: - return InvalidArgumentError( - StrFormat("unsupported argument: %s", arg.DebugString())); - } - } - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// - -absl::Status Execute(RunnerFlags flags, - const JitExecutable::Options& compile_opts, - const Executable::ExecuteOpts& execute_opts) { - LOG(INFO) << "Executing runner tool:\n" - << " - module: " << flags.module_path << "\n" - << " - arguments: " << flags.arguments_path << "\n" - << " - results: " << flags.results_path; - - Env* env = Env::Default(); - - // Read MLIR module from the input file. - std::string module; - if (auto st = ReadFileToString(env, flags.module_path, &module); !st.ok()) { - return InternalError( - StrFormat("failed to read module input from %s, error: %s", - flags.module_path, st.message())); - } - - // Read arguments from the input file. - ArgumentsProto args_proto; - if (auto read = ReadProtoFile(env, flags.arguments_path, &args_proto); - !read.ok()) { - return InternalError( - StrFormat("failed to read arguments input from %s, error %s", - flags.arguments_path, read.message())); - } - - // Convert arguments proto message to the Xla runtime arguments. - RunnerArgs args(args_proto.arguments_size()); - if (auto converted = ConvertArgs(args_proto, args); !converted.ok()) - return converted; - - // Instantiate JitExecutable from the input module. - absl::StatusOr jit_executable = - JitExecutable::Instantiate(module, compile_opts, {flags.function}); - if (!jit_executable.ok()) return jit_executable.status(); - - // TODO(ezhulenev): Add support for specializing to arguments shapes/values. - AsyncValuePtr executable = jit_executable->DefaultExecutable(); - if (executable.IsError()) return executable.GetError(); - - // Convert returned results to results proto. - ResultsProto results_proto; - ResultConverterSet converter(CheckNoError, ReturnResults{&results_proto}); - - // Execute and convert results to proto message. - if (auto executed = executable->Execute(args, converter, execute_opts); - !executed.ok()) - return executed.status(); - - if (auto inout = WriteInoutResults(args_proto, args, &results_proto); - !inout.ok()) - return inout; - - // Write results proto to the requested file location. - if (auto wrote = WriteProtoFile(env, flags.results_path, results_proto); - !wrote.ok()) - return InternalError( - StrFormat("failed to write results proto to %s, error %s", - flags.results_path, wrote.message())); - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Compose Xla Runtime Runner into `main` function. 
-//===----------------------------------------------------------------------===// - -int Main(int argc, char** argv, const JitExecutable::Options& compile_opts, - const Executable::ExecuteOpts& execute_opts) { - xla::runtime::RunnerFlags flags; - - std::vector flag_list; - xla::runtime::AppendRunnerFlags(&flag_list, &flags); - - if (auto parsed = tsl::Flags::Parse(&argc, argv, flag_list); !parsed) { - std::cerr << "Failed to parse runner flags"; - return 1; - } - - tsl::port::InitMain(argv[0], &argc, &argv); - - if (auto executed = Execute(flags, compile_opts, execute_opts); - !executed.ok()) { - std::cerr << "Failed to execute runner tool: " << executed.message(); - return 1; - } - - return 0; -} - -} // namespace runtime -} // namespace xla diff --git a/third_party/xla/xla/runtime/runner/runner.h b/third_party/xla/xla/runtime/runner/runner.h deleted file mode 100644 index 20344553ded097..00000000000000 --- a/third_party/xla/xla/runtime/runner/runner.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_RUNTIME_RUNNER_RUNNER_H_ -#define XLA_RUNTIME_RUNNER_RUNNER_H_ - -#include -#include - -#include "absl/status/status.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" -#include "tsl/util/command_line_flags.h" - -namespace xla { -namespace runtime { - -struct RunnerFlags { - std::string function; - std::string module_path; - std::string arguments_path; - std::string results_path; -}; - -void AppendRunnerFlags(std::vector* flag_list, RunnerFlags* flags); - -// Fake AsyncTaskRunner for programs that do not plan to execute any async work. -AsyncTaskRunner* NoAsyncTaskRunner(); - -// Compiles and executes the MLIR input program defined by `flags` using -// user-provided compilation and execution options. -absl::Status Execute(RunnerFlags flags, - const JitExecutable::Options& compile_opts, - const Executable::ExecuteOpts& execute_opts); - -// A wrapper around `Execute` that does argument parsing and binary -// initialization. Can be used as a main function in user-defined tools. -int Main(int argc, char** argv, const JitExecutable::Options& compile_opts, - const Executable::ExecuteOpts& execute_opts); - -} // namespace runtime -} // namespace xla - -#endif // XLA_RUNTIME_RUNNER_RUNNER_H_ diff --git a/third_party/xla/xla/runtime/runner/runner.proto b/third_party/xla/xla/runtime/runner/runner.proto deleted file mode 100644 index e2e6948c8eb9a4..00000000000000 --- a/third_party/xla/xla/runtime/runner/runner.proto +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto3"; - -package xla; - -import "xla/xla_data.proto"; - -message ScalarProto { - oneof value { - int32 i32 = 1; - int64 i64 = 2; - } -} - -message TensorProto { - PrimitiveType dtype = 1; - int64 offset = 2; - repeated int64 sizes = 3; - repeated int64 strides = 4; - bool inout = 5; - - bytes contents = 6; -} - -message ArgumentProto { - oneof argument { - ScalarProto scalar = 1; - TensorProto tensor = 2; - } -} - -message ResultProto { - oneof result { - ScalarProto scalar = 1; - TensorProto tensor = 2; - } -} - -message ArgumentsProto { - repeated ArgumentProto arguments = 1; -} - -message ResultsProto { - repeated ResultProto results = 1; -} diff --git a/third_party/xla/xla/runtime/runner/runner.py b/third_party/xla/xla/runtime/runner/runner.py deleted file mode 100644 index 5f443e8a04ab93..00000000000000 --- a/third_party/xla/xla/runtime/runner/runner.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Python helper for running Xla runtime runner tools.""" - -import os -import subprocess -import tempfile -from typing import Any, Sequence - -import numpy as np - -from local_xla.xla import xla_data_pb2 -from xla.runtime.runner import runner_pb2 - -PrimitiveType = xla_data_pb2.PrimitiveType - -XLA_ELEMENT_TYPE_TO_DTYPE = { - PrimitiveType.PRED: np.dtype("bool"), - PrimitiveType.S8: np.dtype("int8"), - PrimitiveType.S16: np.dtype("int16"), - PrimitiveType.S32: np.dtype("int32"), - PrimitiveType.S64: np.dtype("int64"), - PrimitiveType.U8: np.dtype("uint8"), - PrimitiveType.U16: np.dtype("uint16"), - PrimitiveType.U32: np.dtype("uint32"), - PrimitiveType.U64: np.dtype("uint64"), - PrimitiveType.F16: np.dtype("float16"), - PrimitiveType.F32: np.dtype("float32"), - PrimitiveType.F64: np.dtype("float64"), - PrimitiveType.C64: np.dtype("complex64"), - PrimitiveType.C128: np.dtype("complex128"), - PrimitiveType.TUPLE: np.dtype(np.object_), - PrimitiveType.TOKEN: np.dtype(np.object_), -} - -# Note the conversion on the key. Numpy has a known issue wherein dtype hashing -# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, -# when keying by dtype in this dict, we use the string form of dtypes. 
-DTYPE_TO_XLA_ELEMENT_TYPE = { - str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} - - -class Runner: - """Python helper for running Xla runtime runner tools.""" - - def __init__(self, runner: str): - self.runner = runner - - def execute(self, - module: str, - function: str, - arguments: Sequence[Any], - inout: Sequence[int] = None) -> Sequence[Any]: - """Executes `module` with user-provided arguments.""" - temp = tempfile.mkdtemp() - - # Write input mlir module to a file. - module_file = os.path.join(temp, "module.mlir") - with open(module_file, "w") as f: - f.write(module) - - inout = set(inout or []) - - # Pack arguments into a proto message. - args_proto = runner_pb2.ArgumentsProto() - for i, arg in enumerate(arguments): - if isinstance(arg, int): - args_proto.arguments.append( - runner_pb2.ArgumentProto(scalar=runner_pb2.ScalarProto(i32=arg))) - if i in inout: - raise RuntimeError(f"inout param {i} cannot be of type ScalarArg") - continue - elif isinstance(arg, np.ndarray): - element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(arg.dtype)] - args_proto.arguments.append( - runner_pb2.ArgumentProto( - tensor=runner_pb2.TensorProto( - dtype=element_type, - sizes=arg.shape, - strides=arg.strides, - inout=(i in inout), - contents=arg.tobytes()))) - - continue - - raise TypeError("Unsupported argument type") - - # Serialize argument proto message to a file. - arguments_file = os.path.join(temp, "arguments.pb") - with open(arguments_file, "wb") as f: - f.write(args_proto.SerializeToString()) - - # Expected results file path. - results_file = os.path.join(temp, "results.pb") - - # Execute the runner tool. - runner_cmd = [ - self.runner, "--logtostderr", f"--function={function}", - f"--module={module_file}", f"--arguments={arguments_file}", - f"--results={results_file}" - ] - result = subprocess.run(runner_cmd, capture_output=False, check=False) - - if result.returncode != 0: - err = result.stderr.decode("utf-8") - raise RuntimeError(f"failed to execute runner tool: {err}") - - # Read returned results. - with open(results_file, "rb") as f: - results_proto = runner_pb2.ResultsProto.FromString(f.read()) - - # Convert results from proto back to python objects. - results = [] - - for res in results_proto.results: - # Convert ScalarProto to scalar object - if res.HasField("scalar"): - scalar = res.scalar - - if hasattr(scalar, "i32"): - results.append(scalar.i32) - continue - if hasattr(scalar, "i64"): - results.append(scalar.i64) - continue - - # Convert TensorProto to numpy array - elif res.HasField("tensor"): - tensor = res.tensor - dtype = XLA_ELEMENT_TYPE_TO_DTYPE[tensor.dtype] - result_array = np.frombuffer(tensor.contents, dtype=dtype) - results.append(result_array) - continue - - raise ValueError(f"Unknown result {res}") - - return results diff --git a/third_party/xla/xla/runtime/runner/testlib_runner.cc b/third_party/xla/xla/runtime/runner/testlib_runner.cc deleted file mode 100644 index 4b66a8c426b79c..00000000000000 --- a/third_party/xla/xla/runtime/runner/testlib_runner.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/runtime/transforms/tests/testlib_pipeline.h" -#include "xla/runtime/runner/runner.h" - -using namespace xla::runtime; // NOLINT - -static JitExecutable::Options CompileOpts() { - JitExecutable::Options opts; - opts.specialization = JitExecutable::Specialization::kDisabled; - opts.compiler.register_dialects = RegisterXlaRuntimeTestlibDialects; - opts.compiler.create_compilation_pipeline = CreateXlaRuntimeTestlibPipeline; - return opts; -} - -static Executable::ExecuteOpts ExecuteOpts() { - Executable::ExecuteOpts opts; - opts.async_task_runner = xla::runtime::NoAsyncTaskRunner(); - return opts; -} - -int main(int argc, char** argv) { - return xla::runtime::Main(argc, argv, CompileOpts(), ExecuteOpts()); -} diff --git a/third_party/xla/xla/runtime/runner/testlib_runner_test.py b/third_party/xla/xla/runtime/runner/testlib_runner_test.py deleted file mode 100644 index 70a36e4dbff0ba..00000000000000 --- a/third_party/xla/xla/runtime/runner/testlib_runner_test.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for testlib_runner.""" - -import pathlib - -from absl.testing import absltest -import numpy as np - -from xla.runtime.runner import runner - -# We assume that the testlib runner is defined in the same project as this test. 
-r = runner.Runner(f'{pathlib.Path(__file__).parent.resolve()}/testlib_runner') - - -class TestlibRunnerTest(absltest.TestCase): - - def testScalarAdd(self): - module = """ - func.func @add(%arg0: i32) -> i32 { - %0 = arith.constant 42 : i32 - %1 = arith.addi %arg0, %0 : i32 - return %1 : i32 - }""" - - [res] = r.execute(module, 'add', [42]) - self.assertEqual(res, 84) - - def testTensorAdd(self): - module = """ - func.func @addtensor(%arg0: memref) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 3 : index - %step = arith.constant 1 : index - - scf.for %i = %c0 to %c1 step %step { - %0 = arith.constant 42.0 : f32 - %1 = memref.load %arg0[%i] : memref - %2 = arith.addf %0, %1 : f32 - memref.store %2, %arg0[%i] : memref - } - - func.return - }""" - - arg = np.array([1.0, 2.0, 3.0], dtype=np.float32) - [res] = r.execute(module, 'addtensor', [arg], inout=[0]) - self.assertTrue( - np.array_equal(res, np.array([43.0, 44.0, 45.0], dtype=np.float32))) - - def testTensorReturn(self): - module = """ - func.func @returntensor(%arg0: memref) -> memref<4xf32> { - %out = memref.alloc() : memref<4xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 4 : index - %step = arith.constant 1 : index - - scf.for %i = %c0 to %c1 step %step { - %0 = memref.load %arg0[%i] : memref - memref.store %0, %out[%i] : memref<4xf32> - } - - return %out : memref<4xf32> - }""" - - arg = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) - [res] = r.execute(module, 'returntensor', [arg]) - - self.assertTrue( - np.array_equal(res, np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))) - -if __name__ == '__main__': - absltest.main() From af3d28d4bd35533b0c57a93a732319789afaaf1e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 28 Nov 2023 17:01:31 -0800 Subject: [PATCH 155/381] Enable Linux Arm64 GitHub presubmit This adds the configs needed for us to be able to run a Linux Arm64 GitHub presubmit on incoming PRs. It runs tests by cross-compiling test binaries on remote Linux x86 VMs using RBE and then executing the built test binaries on the host Arm64 VM. On average, this presubmit should take about ~30 mins and is ~83% faster than the current GitHub Linux Arm64 presubmit (https://github.com/tensorflow/tensorflow/actions/workflows/arm-ci.yml). I have changed the name of the cross-compile env file to add the Python version it runs and to be consistent with other env names. PiperOrigin-RevId: 586144808 --- ...oss_compile => continuous_linux_arm64_cpu_py311_cross_compile} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ci/official/envs/{continuous_linux_arm64_cpu_cross_compile => continuous_linux_arm64_cpu_py311_cross_compile} (100%) diff --git a/ci/official/envs/continuous_linux_arm64_cpu_cross_compile b/ci/official/envs/continuous_linux_arm64_cpu_py311_cross_compile similarity index 100% rename from ci/official/envs/continuous_linux_arm64_cpu_cross_compile rename to ci/official/envs/continuous_linux_arm64_cpu_py311_cross_compile From 87959ee5747214979bb08fd765c477b65781b902 Mon Sep 17 00:00:00 2001 From: shawnwang18 <35983922+shawnwang18@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:34:02 -0800 Subject: [PATCH 156/381] PR #7136: [XLA:GPU] Add `Allocate` command to command buffer Imported from GitHub PR https://github.com/openxla/xla/pull/7136 This PR add the `Allocate` command to command buffer. The `Allocate` command is constructed with the pointer to `BufferAllocation`. 
The allocation will be performed when the command is recorded, the allocated address will be tracked by command buffer runtime through allocation index. For the consumer commands who want to access the allocated buffer, the record parameter buffer address should be provided as se::DeviceMemoryBase with special address (LAZY_ALLOCATE_ADDRESS_MARKER) and non-zero size, and it can be created with API se::DeviceMemory<>::MakeLazyAllocAddressFromByteSize(byte_length); Below is an example how to construct command sequences that access buffers allocated inside command buffer: ``` BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; commands.Emplace(&alloc_b); commands.Emplace(slice_b, slice_a, byte_length); commands.Emplace(slice_c, slice_b, byte_length); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); // Prepare arguments: a=42, b=0 se::DeviceMemory a = executor->AllocateArray(length, 0); stream.ThenMemset32(&a, 42, byte_length); se::DeviceMemory b = se::DeviceMemory::MakeLazyAllocAddressFromByteSize(byte_length); se::DeviceMemory c = executor->AllocateArray(length, 0); BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator()); ServiceExecutableRunOptions run_options; Thunk::ExecuteParams params(run_options, allocations, &stream, {}); // Execute command buffer thunk and verify that it copied the memory. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); ``` For CUDA implementation, the command has no update parameters, which means that when the command is added to command buffer, the address range allocated for this command is fixed across command buffer launches. 
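For buffers that are allocated inside the command buffer but consumed by regular (non command buffer) operations, this PR also adds accessors that resolve the lazily allocated address after recording: se::CommandBuffer::GetAllocationAddress(index) and CommandBufferThunk::GetLazyAllocationAddress(params, index). A minimal sketch, continuing the example above (it reuses `thunk`, `params`, `stream`, `length` and `byte_length` from that example; error handling is reduced to a single assert):

```
// Resolve where the Allocate command placed allocation index 1 ("b") so that
// ordinary stream operations can read it back.
StatusOr<se::DeviceMemoryBase> b_addr =
    thunk.GetLazyAllocationAddress(params, /*index=*/1);
TF_ASSERT_OK(b_addr.status());

// Copy the lazily allocated buffer back to the host.
std::vector<int32_t> host(length, 0);
stream.ThenMemcpy(host.data(), *b_addr, byte_length);
```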
The `Allocation` command is only implemented for CUDA platform Copybara import of the project: -- d2cdd0423fe5947e06d8d7b8d5192a8845b2beae by Shawn Wang : Add Allocate command to command buffer Merging this change closes #7136 PiperOrigin-RevId: 586150993 --- .../xla/xla/backends/interpreter/executor.h | 6 ++ .../xla/xla/service/gpu/buffer_allocations.cc | 23 +++++ .../xla/xla/service/gpu/buffer_allocations.h | 14 +++ .../xla/xla/service/gpu/gpu_executable.cc | 5 +- .../gpu/runtime3/command_buffer_cmd.cc | 36 ++++++-- .../service/gpu/runtime3/command_buffer_cmd.h | 23 +++++ .../gpu/runtime3/command_buffer_thunk.cc | 10 +++ .../gpu/runtime3/command_buffer_thunk.h | 6 ++ .../gpu/runtime3/command_buffer_thunk_test.cc | 57 ++++++++++++ .../xla/xla/stream_executor/command_buffer.cc | 9 ++ .../xla/xla/stream_executor/command_buffer.h | 12 +++ .../xla/stream_executor/cuda/cuda_driver.cc | 86 +++++++++++++++++++ .../xla/xla/stream_executor/device_memory.h | 26 +++++- .../stream_executor/gpu/gpu_command_buffer.cc | 50 +++++++++++ .../stream_executor/gpu/gpu_command_buffer.h | 11 ++- .../xla/xla/stream_executor/gpu/gpu_driver.h | 40 +++++++++ .../xla/stream_executor/gpu/gpu_executor.h | 2 + .../stream_executor_internal.h | 10 +++ 18 files changed, 413 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index a2c6fc13b360f9..683d83a15d58e0 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -48,9 +48,11 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { XlaInterpreterExecutor() = default; tsl::Status Init(int device_ordinal, DeviceOptions device_options) override { + device_ordinal_ = device_ordinal; return ::tsl::OkStatus(); } + int device_ordinal() const override { return device_ordinal_; }; tsl::Status GetKernel(const MultiKernelLoaderSpec &spec, Kernel *kernel) override { return tsl::errors::Unimplemented("Not Implemented"); @@ -182,6 +184,10 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } private: + // The device ordinal value that this executor was initialized with; recorded + // for use in getting device metadata. Immutable post-initialization. 
+ int device_ordinal_; + DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); tsl::StatusOr AllocateOutputBuffer(const xla::Shape &shape); diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.cc b/third_party/xla/xla/service/gpu/buffer_allocations.cc index 0b1b879d28707f..a1b8a7214f62b9 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/buffer_allocations.cc @@ -79,5 +79,28 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( buffer_slice.size()); } +se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( + const BufferAllocation::Slice& buffer_slice, + const se::CommandBuffer* command_buffer) const { + se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index()); + CHECK_LE(buffer_slice.offset(), base.size()); + CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size()); + + if (base.is_external_allocation_marker()) { + auto cmd_buffer_base = command_buffer->GetAllocationAddress( + buffer_slice.allocation()->index()); + CHECK(cmd_buffer_base.ok()) + << "Get allocation address from command_buffer failed"; + CHECK(!cmd_buffer_base.value().is_null()) + << "Allocation is not yet allocated by command buffer for slice: " + << buffer_slice.ToString(); + return cmd_buffer_base.value(); + } + + return se::DeviceMemoryBase( + static_cast(base.opaque()) + buffer_slice.offset(), + buffer_slice.size()); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.h b/third_party/xla/xla/service/gpu/buffer_allocations.h index 04e76c5157ee13..e805a66053c0cb 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.h +++ b/third_party/xla/xla/service/gpu/buffer_allocations.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/statusor.h" +#include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" @@ -68,6 +69,15 @@ class BufferAllocations { se::DeviceMemoryBase GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const; + // For buffers that are lazily allocated through command buffer, this is + // indicated by specifying a special buffer address + // (LAZY_ALLOCATE_ADDRESS_MARKER), the real buffer address is tracked in + // CommandBuffer, this API will fetches the real address from CommandBuffer + // runtime. + se::DeviceMemoryBase GetDeviceAddress( + const BufferAllocation::Slice& buffer_slice, + const se::CommandBuffer* command_buffer) const; + // Tears down all buffers allocated by this object that are not in // `live_addresses`. Status TearDown(const std::set& live_addresses, @@ -89,6 +99,10 @@ class BufferAllocations { // An array of device pointers that stores the address of each buffer // indexed by Index. Each element can point to a temporary buffer, an // input buffer, or nullptr if no buffer is needed for that Index. + + // a nullptr buffer with non-zero size buffer is assumed to be lazily + // allocated buffer, and will be allocated through command buffer Allocate + // command during runtime. 
std::vector buffers_; int device_ordinal_; se::DeviceMemoryAllocator* memory_allocator_; diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 1f9903cb3f6b94..05283bbdc2b8a5 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -936,9 +936,8 @@ GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) { GpuExecutable::GpuExecutable( std::shared_ptr hlo_module, std::string asm_text, std::vector binary, std::vector constants, - se::GpuComputeCapability gpu_version, - absl::string_view module_name, Shape xla_output_shape, - std::vector allocations, + se::GpuComputeCapability gpu_version, absl::string_view module_name, + Shape xla_output_shape, std::vector allocations, absl::flat_hash_map output_info, std::unique_ptr gpu_runtime_executable) : Executable(std::move(hlo_module)), diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 97c6e20b551ede..63841534daeed4 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -130,7 +130,8 @@ Status LaunchCmd::Record(const RecordParams& params, absl::InlinedVector buffers; for (const BufferAllocation::Slice& arg : args_) { - se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); + se::DeviceMemoryBase buf = + params.buffer_allocations->GetDeviceAddress(arg, command_buffer); VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); buffers.push_back(buf); } @@ -164,8 +165,10 @@ Status MemcpyDeviceToDeviceCmd::Record(const RecordParams& params, se::CommandBuffer* command_buffer) { VLOG(5) << "MemcpyDeviceToDeviceCmd: dst=" << dst_ << ", src=" << src_ << ", num_bytes=" << num_bytes_; - se::DeviceMemoryBase dst = params.buffer_allocations->GetDeviceAddress(dst_); - se::DeviceMemoryBase src = params.buffer_allocations->GetDeviceAddress(src_); + se::DeviceMemoryBase dst = + params.buffer_allocations->GetDeviceAddress(dst_, command_buffer); + se::DeviceMemoryBase src = + params.buffer_allocations->GetDeviceAddress(src_, command_buffer); return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_); } @@ -204,6 +207,24 @@ CommandBufferCmd::Slices IfCmd::slices() { return {slices.begin(), slices.end()}; } +//===----------------------------------------------------------------------===// +// AllocateCmd +//===----------------------------------------------------------------------===// + +AllocateCmd::AllocateCmd(BufferAllocation* allocation) + : allocation_(allocation) {} + +Status AllocateCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + // Memory allocation address is returned on graph creation, and there is no + // update operation + VLOG(5) << "AllocationCmd: index=" << allocation_->index(); + return command_buffer->Allocate(se::CommandBuffer::AllocIndexSize{ + allocation_->index(), static_cast(allocation_->size())}); +} + +CommandBufferCmd::Slices AllocateCmd::slices() { return {}; } + //===----------------------------------------------------------------------===// // GemmCmd //===----------------------------------------------------------------------===// @@ -235,12 +256,11 @@ Status GemmCmd::Record(const RecordParams& params, se::DeviceMemoryBase workspace(nullptr, 0); se::DeviceMemoryBase lhs = - params.buffer_allocations->GetDeviceAddress(lhs_buffer_); + 
params.buffer_allocations->GetDeviceAddress(lhs_buffer_, command_buffer); se::DeviceMemoryBase rhs = - params.buffer_allocations->GetDeviceAddress(rhs_buffer_); - se::DeviceMemoryBase out = - params.buffer_allocations->GetDeviceAddress(output_buffer_); - + params.buffer_allocations->GetDeviceAddress(rhs_buffer_, command_buffer); + se::DeviceMemoryBase out = params.buffer_allocations->GetDeviceAddress( + output_buffer_, command_buffer); TF_ASSIGN_OR_RETURN( auto nested_buffer, se::CommandBuffer::Trace(params.executor, [&](se::Stream* stream) { diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index 8477c093c6e6cb..ab894bb7952747 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -53,6 +53,11 @@ class CommandBufferCmd { // buffer. For example when we emit command buffer cmd sequence from an HLO // module, we only know the buffer slices required for HLO operations, but the // concrete device pointers become available only at run time. + // + // For allocations that performed through command buffer Allocate command, the + // target addresses are tracked by command buffer runtime. To record command + // that consumes buffers allocated inside command buffer, user should specify + // the target address as se::DeviceMemoryBase{nullptr, size}. struct RecordParams { se::StreamExecutor* executor; const BufferAllocations* buffer_allocations; @@ -205,6 +210,24 @@ class IfCmd : public CommandBufferCmd { CommandBufferCmdSequence then_cmds_; }; +//===----------------------------------------------------------------------===// +// AllocateCmd +//===----------------------------------------------------------------------===// + +class AllocateCmd : public CommandBufferCmd { + public: + explicit AllocateCmd(BufferAllocation* allocation); + + // After calling this function, the allocated memory address is updated to + Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + BufferAllocation* allocation_; +}; + //===----------------------------------------------------------------------===// // GemmCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc index 9dca8e0be18fb7..b621a3f9c0f38d 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/service/gpu/runtime3/command_buffer_thunk.h" +#include #include #include "absl/synchronization/mutex.h" @@ -104,4 +105,13 @@ CommandBufferThunk::GetOrCreateCommandBuffer(se::StreamExecutor* executor) { return &emplaced.first->second; } +StatusOr CommandBufferThunk::GetLazyAllocationAddress( + const ExecuteParams& params, int64_t index) { + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(ExecutorCommandBuffer * cmd_buffer, + GetOrCreateCommandBuffer(executor)); + absl::MutexLock lock(&cmd_buffer->mutex); + return cmd_buffer->command_buffer.GetAllocationAddress(index); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h index dada9c27fbb2ec..47d7298545a1ff 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h @@ -73,6 +73,12 @@ class CommandBufferThunk : public Thunk { StatusOr GetOrCreateCommandBuffer( se::StreamExecutor* executor); + // Return the allocation address that was lazilly allocated inside command + // buffer. This API is required when the buffers are allocated inside command + // buffer but will be consumed by non-command buffer operations. + StatusOr GetLazyAllocationAddress( + const ExecuteParams& params, int64_t index); + // Command sequence that initializes command buffers on each executor. CommandBufferCmdSequence commands_; diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 7bd5e1b1aa4395..66a9c17d92259e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -104,6 +104,63 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { ASSERT_EQ(dst, std::vector(4, 42)); } +// This test does the following operations: +// 1. Allocates memory region "a" and "c" outside command buffer. +// 2. Allocates memory region "b" inside command buffer. +// 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer. +// 4. MemCopyDEviceToDevice from "b" to "c" inside command buffer. +// 5. Verify that region "c" has the same content as "a". +TEST(CommandBufferThunkTest, MemallocCmd) { + se::StreamExecutor* executor = CudaExecutor(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + // Prepare arguments: + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(&alloc_b); + commands.Emplace(slice_b, slice_a, byte_length); + commands.Emplace(slice_c, slice_b, byte_length); + + // Construct a thunk with command sequence. 
+ CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + stream.ThenMemset32(&a, 42, byte_length); + + se::DeviceMemory b = + se::DeviceMemory::MakeExternalAllocationFromByteSize( + byte_length); + se::DeviceMemory c = executor->AllocateArray(length, 0); + BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator()); + + ServiceExecutableRunOptions run_options; + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), allocations.GetMutableDeviceAddress(2), + byte_length); + + ASSERT_EQ(dst, std::vector(4, 42)); +} + TEST(CommandBufferThunkTest, LaunchCmd) { se::StreamExecutor* executor = CudaExecutor(); diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 28fc7533fe2a49..0e901d98312e37 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -116,6 +116,15 @@ tsl::Status CommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return implementation_->MemcpyDeviceToDevice(dst, src, size); } +tsl::Status CommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { + return implementation_->Allocate(alloc); +} + +tsl::StatusOr CommandBuffer::GetAllocationAddress( + int64_t index) const { + return implementation_->GetAllocationAddress(index); +} + tsl::Status CommandBuffer::If(StreamExecutor* executor, DeviceMemory pred, Builder then_builder) { return implementation_->If(executor, pred, std::move(then_builder)); diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 0a69a463167e30..4935fc964576ff 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -53,6 +53,10 @@ class CommandBuffer { public: // Builder constructs nested command buffers owned by a parent command buffer. using Builder = std::function; + struct AllocIndexSize { + int64_t index; + uint64_t size; + }; ~CommandBuffer(); CommandBuffer(CommandBuffer&&); @@ -168,6 +172,14 @@ class CommandBuffer { //--------------------------------------------------------------------------// + // Adds a device memory allocation command to the command buffer, allocated + // address is tracked by command buffer runtime. + tsl::Status Allocate(AllocIndexSize alloc); + + // Get the device address for allocations previously allocated through + // Allocate command. + tsl::StatusOr GetAllocationAddress(int64_t index) const; + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. tsl::Status Finalize(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index e9effc95c21025..c1bc32abdcb812 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -882,6 +883,91 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, return ::tsl::OkStatus(); } +static CUmemAccess_flags ToCudaMemAccessFlags( + GpuDriver::MemAccessFlags access_flags) { + switch (access_flags) { + case GpuDriver::MemAccessFlags::kNone: + return CU_MEM_ACCESS_FLAGS_PROT_NONE; + case GpuDriver::MemAccessFlags::kRead: + return CU_MEM_ACCESS_FLAGS_PROT_READ; + case GpuDriver::MemAccessFlags::kReadWrite: + return CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } +} + +static CUmemLocationType ToCudaLocationType( + GpuDriver::MemLocationType location_type) { + switch (location_type) { + case GpuDriver::MemLocationType::kInvalid: + return CU_MEM_LOCATION_TYPE_INVALID; + case GpuDriver::MemLocationType::kDevice: + return CU_MEM_LOCATION_TYPE_DEVICE; + case GpuDriver::MemLocationType::kHost: + return CU_MEM_LOCATION_TYPE_HOST; + case GpuDriver::MemLocationType::kHostNuma: + return CU_MEM_LOCATION_TYPE_HOST_NUMA; + case GpuDriver::MemLocationType::kHostNumaCurrent: + return CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT; + } +} + +static CUmemAllocationType ToCudaAllocationType( + GpuDriver::MemAllocationType alocation_type) { + switch (alocation_type) { + case GpuDriver::MemAllocationType::kInvalid: + return CU_MEM_ALLOCATION_TYPE_INVALID; + case GpuDriver::MemAllocationType::kPinned: + return CU_MEM_ALLOCATION_TYPE_PINNED; + } +} + +/*static*/ tsl::Status GpuDriver::GraphAddMemAllocNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, + GpuDriver::MemAccessFlags access_flags, + GpuDriver::MemLocationType location_type, int device_id, + GpuDriver::MemAllocationType allocation_type, uint64_t size, + CUdeviceptr* d_ptr, uint64_t max_pool_size) { + CUDA_MEM_ALLOC_NODE_PARAMS params; + memset(¶ms, 0, sizeof(params)); + + CUmemLocation mem_location; + mem_location.id = device_id; + mem_location.type = ToCudaLocationType(location_type); + + CUmemAccessDesc mem_desc; + mem_desc.flags = ToCudaMemAccessFlags(access_flags); + mem_desc.location = mem_location; + + CUmemPoolProps mem_pool_props; + mem_pool_props.allocType = ToCudaAllocationType(allocation_type); + mem_pool_props.handleTypes = CU_MEM_HANDLE_TYPE_NONE; + mem_pool_props.location = mem_location; + mem_pool_props.maxSize = max_pool_size; + + params.accessDescCount = 1; + params.bytesize = size; + params.accessDescs = &mem_desc; + params.poolProps = mem_pool_props; + + RETURN_IF_CUDA_RES_ERROR( + cuGraphAddMemAllocNode(node, graph, deps.data(), deps.size(), ¶ms), + "Failed to add memory allocation node to a CUDA graph"); + + VLOG(2) << "Add MemAllocNode to a graph " << graph << " size " << size + << " address " << reinterpret_cast(params.dptr); + + *d_ptr = params.dptr; + return ::tsl::OkStatus(); +} + +/*static*/ tsl::StatusOr> +GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { + CUDA_MEM_ALLOC_NODE_PARAMS params; + RETURN_IF_CUDA_RES_ERROR(cuGraphMemAllocNodeGetParams(node, ¶ms), + "Failed to get memory allocation node parameter"); + return std::pair{params.dptr, params.bytesize}; +} + /* static */ tsl::Status GpuDriver::GraphAddMemcpyD2DNode( GpuContext* context, CUgraphNode* node, CUgraph graph, absl::Span deps, CUdeviceptr gpu_dst, CUdeviceptr gpu_src, diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index f69fe8148c67a0..bd8957ab14b648 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -26,6 +26,8 @@ limitations under the 
License. #include +#include + #include "xla/stream_executor/platform/port.h" namespace stream_executor { @@ -33,6 +35,11 @@ namespace stream_executor { class DeviceMemoryAllocator; class StreamExecutor; +// This special address is used to indicate that the allocation is not ready +// when constructing DeviceMemory object, and will be lazily allocated by +// an external allocator (e.g. command buffer for GPU backend). +inline constexpr uintptr_t kExternalAllocationMarker = 0xDEADBEEF; + // void*-analogous device memory allocation. For the typed variation, see // DeviceMemory. // @@ -54,6 +61,11 @@ class DeviceMemoryBase { // Returns whether the backing memory is the null pointer. // A `== nullptr` convenience method is also provided. bool is_null() const { return opaque_ == nullptr; } + + bool is_external_allocation_marker() const { + return reinterpret_cast(opaque_) == kExternalAllocationMarker; + } + bool operator==(std::nullptr_t other) const { return is_null(); } bool operator!=(std::nullptr_t other) const { return !is_null(); } @@ -96,7 +108,13 @@ class DeviceMemoryBase { } private: - void *opaque_; // Platform-dependent value representing allocated memory. + // Platform-dependent value representing allocated memory. + // + // User may also constructs the object with `kExternalAllocationMarker` + // address and non-zero size, which indicates the case that buffer is + // allocated externally (for Gpu backends we use it to allocate memory via + // command buffer APIs). + void *opaque_; uint64_t size_; // Size in bytes of this allocation. uint64_t payload_ = 0; // Payload data associated with this allocation. }; @@ -136,6 +154,12 @@ class DeviceMemory final : public DeviceMemoryBase { return DeviceMemory(opaque, bytes); } + static DeviceMemory MakeExternalAllocationFromByteSize( + uint64_t bytes) { + return DeviceMemory( + reinterpret_cast(kExternalAllocationMarker), bytes); + } + // Resets the DeviceMemory data, in MakeFromByteSize fashion. // This simply clobbers the prior values. void ResetFromByteSize(void *opaque, uint64_t bytes) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 7737c73aa29058..7d978364f64944 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -273,6 +273,56 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } +tsl::Status GpuCommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { + TF_RETURN_IF_ERROR(CheckNotFinalized()); + + if (state_ == State::kCreate) { + Dependencies deps = GetDependencies(); + GpuGraphNodeHandle* node = &nodes_.emplace_back(); + + GpuDevicePtr ptr; + TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemAllocNode( + node, graph_, absl::MakeSpan(deps), + GpuDriver::MemAccessFlags::kReadWrite, + GpuDriver::MemLocationType::kDevice, parent_->device_ordinal(), + GpuDriver::MemAllocationType::kPinned, alloc.size, &ptr)); + // For CUDA impl, VA range is reserved when adding memory allocation node. 
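+    // The returned pointer is therefore valid immediately, while the backing
+    // device memory is only committed when the graph is launched.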
+ CHECK(ptr) << "CUDA graph memory allocation node returned nullptr"; + + VLOG(2) << "Setting device memory base with opaque pointer " + << reinterpret_cast(ptr) + << " device ordinal: " << parent_->device_ordinal(); + allocations_map_[alloc.index] = + DeviceMemoryBase{reinterpret_cast(ptr), alloc.size}; + return tsl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // Memory allocation node implemented through CUDA graph does not allocate + // new memory region on update, just return the memory region allocated + // during the create step. + TF_ASSIGN_OR_RETURN( + AllocationResult params, + GpuDriver::GraphGetMemAllocNodeParams(nodes_[update_state_.node_idx])); + update_state_.node_idx++; + allocations_map_[alloc.index] = + DeviceMemoryBase{reinterpret_cast(params.first), params.second}; + return tsl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + +tsl::StatusOr GpuCommandBuffer::GetAllocationAddress( + int64_t index) const { + if (allocations_map_.contains(index)) { + return allocations_map_.at(index); + } else { + return absl::InternalError( + absl::StrCat("Allocation is not yet allocated: ", index)); + } +} + //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 4ac4329e53df02..e817c9fd07728d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -23,7 +23,6 @@ limitations under the License. #include #include -#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" @@ -60,6 +59,11 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { const DeviceMemoryBase& src, uint64_t size) override; + tsl::Status Allocate(CommandBuffer::AllocIndexSize alloc) override; + + tsl::StatusOr GetAllocationAddress( + int64_t index) const override; + tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) override; @@ -179,6 +183,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { std::vector command_buffers; }; + using AllocationResult = std::pair; + tsl::StatusOr> CreateConditionalHandles(size_t num_handles); @@ -212,6 +218,9 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // error. tsl::Status CheckPrimary(); + // Keep tracks of allocations that is performed by allocation command. + absl::flat_hash_map allocations_map_; + // Returns OK status if the number of command buffers is equal to the expected // one, otherwise returns internal error. tsl::Status CheckNumCommandBuffers( diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 109a64ace754fa..9db00ac9f54b40 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "absl/types/span.h" @@ -491,6 +492,45 @@ class GpuDriver { unsigned int block_dim_z, unsigned int shared_mem_bytes, void** kernel_params, void** extra); + // Memory protection flags for mappings. 
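+  // These values are translated to the CUDA driver's CUmemAccess_flags (see
+  // ToCudaMemAccessFlags in cuda_driver.cc).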
+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gfba87b8c4a8cd091554d8e2c3fc9b40a + enum class MemAccessFlags { + kNone, + kRead, + kReadWrite, + }; + + // Specifies the type of memory location + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g75cfd5b9fa5c1c6ee2be2547bfbe882e + enum class MemLocationType { + kInvalid, + kDevice, + kHost, + kHostNuma, + kHostNumaCurrent, + }; + + // The memory allocation type + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g7ed3482e0df8712d79a99bcb3bc4a95b + enum class MemAllocationType { + kInvalid, + kPinned, + }; + + // Creates a memory allocation node and adds it to a graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g73a351cb71b2945a0bcb913a93f69ec9 + static tsl::Status GraphAddMemAllocNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, MemAccessFlags access_flags, + MemLocationType location_type, int device_id, + MemAllocationType allocation_type, uint64_t size, GpuDevicePtr* d_ptr, + uint64_t max_pool_size = 0); + + // Fetch memory allocation node's allocated address; + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gee2c7d66d3d96b1470c1d1a769f250a2 + static tsl::StatusOr> + GraphGetMemAllocNodeParams(GpuGraphNodeHandle node); + // Creates a memcpy node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g674da6ab54a677f13e0e0e8206ff5073 static tsl::Status GraphAddMemcpyD2DNode(GpuContext* context, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 6bc5ee304498f7..1055c167c9d946 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -101,6 +101,8 @@ class GpuExecutor : public internal::StreamExecutorInterface { tsl::Status Init(int device_ordinal, DeviceOptions device_options) override; + int device_ordinal() const override { return device_ordinal_; }; + tsl::Status GetKernel(const MultiKernelLoaderSpec& spec, Kernel* kernel) override; diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index e21adebc37665e..de0b82e9d5e1c8 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -141,6 +141,14 @@ class CommandBufferInterface { const DeviceMemoryBase& src, uint64_t size) = 0; + // Adds a device memory allocation node to the command buffer. + virtual tsl::Status Allocate(CommandBuffer::AllocIndexSize alloc) = 0; + + // Get the device address for allocations performed through command buffer + // Allocate command. + virtual tsl::StatusOr GetAllocationAddress( + int64_t index) const = 0; + // For all conditional command APIs defined below, nested command buffers // constructed for conditional branches owned by *this and should never be // finalized or updated inside builders. 
@@ -280,6 +288,8 @@ class StreamExecutorInterface { return std::nullopt; } + virtual int device_ordinal() const { return -1; } + virtual tsl::Status GetKernel(const MultiKernelLoaderSpec& spec, Kernel* kernel) { return absl::UnimplementedError("Not Implemented"); From 14f8066bcae22834baab7daff706f3a9ad66b33a Mon Sep 17 00:00:00 2001 From: Carlos Guia Date: Tue, 28 Nov 2023 18:26:13 -0800 Subject: [PATCH 157/381] Don't upload macOS Arm64 build artifacts. These are not ready yet. PiperOrigin-RevId: 586160312 --- ci/official/envs/nightly_macos_arm64_py310 | 2 -- ci/official/envs/nightly_macos_arm64_py311 | 2 -- ci/official/envs/nightly_macos_arm64_py312 | 2 -- ci/official/envs/nightly_macos_arm64_py39 | 2 -- 4 files changed, 8 deletions(-) diff --git a/ci/official/envs/nightly_macos_arm64_py310 b/ci/official/envs/nightly_macos_arm64_py310 index 4c950b626e1daa..2c902aef38764a 100644 --- a/ci/official/envs/nightly_macos_arm64_py310 +++ b/ci/official/envs/nightly_macos_arm64_py310 @@ -5,8 +5,6 @@ TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos-arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py311 b/ci/official/envs/nightly_macos_arm64_py311 index 6d7699640758a3..61eb97792ce620 100644 --- a/ci/official/envs/nightly_macos_arm64_py311 +++ b/ci/official/envs/nightly_macos_arm64_py311 @@ -5,8 +5,6 @@ TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.11 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos-arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py312 b/ci/official/envs/nightly_macos_arm64_py312 index 6a530907d13347..41371ad4befbef 100644 --- a/ci/official/envs/nightly_macos_arm64_py312 +++ b/ci/official/envs/nightly_macos_arm64_py312 @@ -5,8 +5,6 @@ TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.12 -TFCI_UPLOAD_WHL_GCS_ENABLE= -TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos-arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py39 b/ci/official/envs/nightly_macos_arm64_py39 index f969f20d782ee7..8f3c172065b974 100644 --- a/ci/official/envs/nightly_macos_arm64_py39 +++ b/ci/official/envs/nightly_macos_arm64_py39 @@ -7,8 +7,6 @@ TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.9 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos/arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION From 4f1d4bc65b67bb3a10dd39ba4557de15a5b9cb03 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 18:46:20 -0800 Subject: [PATCH 158/381] Update XLA GPU config with NVCC compiler. 
PiperOrigin-RevId: 586163360 --- .bazelrc | 29 ++---------------------- third_party/xla/.bazelrc | 29 ++---------------------- third_party/xla/third_party/tsl/.bazelrc | 29 ++---------------------- 3 files changed, 6 insertions(+), 81 deletions(-) diff --git a/.bazelrc b/.bazelrc index 4ebb3f0ebb2e26..42330a369c9f6c 100644 --- a/.bazelrc +++ b/.bazelrc @@ -526,34 +526,9 @@ build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_c build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=cuda -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true -build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true -build:rbe_linux_cuda_nvcc --config=tensorrt -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" -build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9" -build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3" -# These you may need to change for your own GCP project. 
-common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl" -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 4ebb3f0ebb2e26..42330a369c9f6c 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -526,34 +526,9 @@ build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_c build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=cuda -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true -build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true -build:rbe_linux_cuda_nvcc --config=tensorrt -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" -build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9" -build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3" -# These you may need to change for your own GCP project. 
-common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl" -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 4ebb3f0ebb2e26..42330a369c9f6c 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -526,34 +526,9 @@ build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_c build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=cuda -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true -build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true -build:rbe_linux_cuda_nvcc --config=tensorrt -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" -build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9" -build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3" -# These you may need to change for your own GCP project. 
-common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl" -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base From 8509ee8d8018d021520072aa014ef5d0dd6c5076 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 18:48:02 -0800 Subject: [PATCH 159/381] For convolutions that can be interpreted as dots, rely on DotHandler to generate sharding strategies. For those that cannot be, we rely on pre-existing convolution handling code. PiperOrigin-RevId: 586163568 --- .../auto_sharding_dot_handler.cc | 172 ++++++++++++++---- 1 file changed, 133 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 159e3de1625a86..d574a68758142c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -33,7 +33,9 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" @@ -244,6 +246,7 @@ class DotHandler : public HandlerBase { const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph), + is_dot_(true), space_base_dim_( ins->dot_dimension_numbers().lhs_batch_dimensions_size()), lhs_con_dims_( @@ -258,6 +261,38 @@ class DotHandler : public HandlerBase { CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size()); } + DotHandler( + std::unique_ptr& strategy_group, StrategyMap& strategy_map, + const HloConvolutionInstruction* ins, + const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, + const CallGraph& call_graph) + : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, + option, call_graph), + is_dot_(false), + space_base_dim_(-1) { + CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); + + for (auto dim_idx : conv_as_dot_dims.batch_dims) { + if (dim_idx.lhs >= 0) lhs_batch_dims_.Add(dim_idx.lhs); + if (dim_idx.rhs >= 0) rhs_batch_dims_.Add(dim_idx.rhs); + } + + for (auto dim_idx : conv_as_dot_dims.contracting_dims) { + if (dim_idx.lhs >= 0) lhs_con_dims_.Add(dim_idx.lhs); + if (dim_idx.rhs >= 0) rhs_con_dims_.Add(dim_idx.rhs); + } + + for (auto dim_idx : 
conv_as_dot_dims.lhs_non_contracting_dims) { + if (dim_idx.lhs >= 0) lhs_space_dims_.Add(dim_idx.lhs); + } + + for (auto dim_idx : conv_as_dot_dims.rhs_non_contracting_dims) { + if (dim_idx.rhs >= 0) rhs_space_dims_.Add(dim_idx.rhs); + } + } + void SplitLhsSpaceRhsSpace() { auto func = [this](const Enumeration& e) { const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}}; @@ -265,10 +300,14 @@ class DotHandler : public HandlerBase { std::string name = absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = DimMap{ - {space_base_dim_ + e.i, e.mesh_dims[0]}, - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, - e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = + DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, + {space_base_dim_ + + static_cast(lhs_space_dims_.size()) + e.j, + e.mesh_dims[1]}}; + } MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, lhs_space_dims_.size(), rhs_space_dims_.size()); @@ -280,9 +319,11 @@ class DotHandler : public HandlerBase { {lhs_space_dims_[e.j], e.mesh_dims[1]}}; std::string name = absl::StrFormat("SSR = SSR x RR @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = - DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, - {space_base_dim_ + e.j, e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, + {space_base_dim_ + e.j, e.mesh_dims[1]}}; + } MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_); }; EnumerateHalf(func, lhs_space_dims_.size(), lhs_space_dims_.size()); @@ -294,11 +335,16 @@ class DotHandler : public HandlerBase { {rhs_space_dims_[e.j], e.mesh_dims[1]}}; std::string name = absl::StrFormat("RSS = RR x RSS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = DimMap{ - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[0]}, - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, - e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = + DimMap{{space_base_dim_ + + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[0]}, + {space_base_dim_ + + static_cast(lhs_space_dims_.size()) + e.j, + e.mesh_dims[1]}}; + } MaybeAppend(name, {}, rhs_dim_map, out_dim_map, device_mesh_); }; EnumerateHalf(func, rhs_space_dims_.size(), rhs_space_dims_.size()); @@ -315,8 +361,10 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, {lhs_con_dims_[e.j], e.mesh_dims[1]}}; const DimMap rhs_dim_map = {{rhs_con_dims_[e.j], e.mesh_dims[1]}}; - const DimMap out_dim_map = - DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}}; + } auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); @@ -337,9 +385,13 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, {rhs_con_dims_[e.j], e.mesh_dims[0]}}; const DimMap lhs_dim_map = {{lhs_con_dims_[e.j], e.mesh_dims[0]}}; - const DimMap out_dim_map = DimMap{ - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = + DimMap{{space_base_dim_ + + 
static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[1]}}; + } auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); @@ -359,7 +411,10 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_batch_dims_[e.i], e.j}}; const DimMap rhs_dim_map = {{rhs_batch_dims_[e.i], e.j}}; std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", e.i, e.j); - const DimMap out_dim_map = DimMap{{e.i, e.j}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{e.i, e.j}}; + } MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, lhs_batch_dims_.size(), device_mesh_.num_dimensions()); @@ -377,8 +432,10 @@ class DotHandler : public HandlerBase { {rhs_batch_dims_[1], e.mesh_dims[1]}}; std::string name = absl::StrFormat("Sb = Sb x Sb @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = - DimMap{{0, e.mesh_dims[0]}, {1, e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{0, e.mesh_dims[0]}, {1, e.mesh_dims[1]}}; + } MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; EnumerateHalf(func, lhs_batch_dims_.size(), lhs_batch_dims_.size()); @@ -395,8 +452,11 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[1]}, {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap out_dim_map = DimMap{ - {e.j, e.mesh_dims[0]}, {space_base_dim_ + e.i, e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{e.j, e.mesh_dims[0]}, + {space_base_dim_ + e.i, e.mesh_dims[1]}}; + } MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, lhs_space_dims_.size(), lhs_batch_dims_.size()); @@ -413,10 +473,14 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; const DimMap lhs_dim_map = {{lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap out_dim_map = { - {e.j, e.mesh_dims[0]}, - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = + DimMap{{e.j, e.mesh_dims[0]}, + {space_base_dim_ + + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[1]}}; + } MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, rhs_space_dims_.size(), lhs_batch_dims_.size()); @@ -434,7 +498,10 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[1]}, {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; + } auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); @@ -459,7 +526,10 @@ class DotHandler : public HandlerBase { {lhs_con_dims_[e.j], e.mesh_dims[1]}}; const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}, {rhs_con_dims_[e.j], e.mesh_dims[1]}}; - 
const DimMap out_dim_map = DimMap{}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{}; + } auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0], @@ -480,7 +550,10 @@ class DotHandler : public HandlerBase { e.mesh_dims[0], e.mesh_dims[0]); const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}}; const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}}; - const DimMap out_dim_map = DimMap{}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{}; + } double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape()); auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); @@ -511,7 +584,10 @@ class DotHandler : public HandlerBase { continue; } std::string name = absl::StrFormat("Si = Si x R @ %d", mesh_dim); - const DimMap out_dim_map = DimMap{{space_base_dim_ + i, mesh_dim}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{space_base_dim_ + i, mesh_dim}}; + } MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); } @@ -529,7 +605,10 @@ class DotHandler : public HandlerBase { } std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", mesh_dim, mesh_dim); - const DimMap out_dim_map = DimMap{}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{}; + } auto communication_cost_fn = [this, mesh_dim]( const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); @@ -551,7 +630,10 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_batch_dims_[i], mesh_dim}}; std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); - const DimMap out_dim_map = DimMap{{i, mesh_dim}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{i, mesh_dim}}; + } MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_); } @@ -649,12 +731,13 @@ class DotHandler : public HandlerBase { } // Dimension information + bool is_dot_; int64_t space_base_dim_; tsl::protobuf::RepeatedField lhs_space_dims_, rhs_space_dims_; - const tsl::protobuf::RepeatedField& lhs_con_dims_; - const tsl::protobuf::RepeatedField& rhs_con_dims_; - const tsl::protobuf::RepeatedField& lhs_batch_dims_; - const tsl::protobuf::RepeatedField& rhs_batch_dims_; + tsl::protobuf::RepeatedField lhs_con_dims_; + tsl::protobuf::RepeatedField rhs_con_dims_; + tsl::protobuf::RepeatedField lhs_batch_dims_; + tsl::protobuf::RepeatedField rhs_batch_dims_; }; // Register strategies for dot instructions. 
@@ -862,9 +945,20 @@ Status HandleConv(std::unique_ptr& strategy_group, strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - ConvHandler handler(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph); - TF_RETURN_IF_ERROR(handler.RegisterStrategies()); + auto conv_as_dot_dims = + dot_as_convolution_util::ParseConvolutionDimsInfo(ins); + if (conv_as_dot_dims.conv_spatial_dims.empty()) { + DotHandler handler(strategy_group, strategy_map, + Cast(ins), conv_as_dot_dims, + cluster_env, batch_map, option, call_graph); + TF_RETURN_IF_ERROR(handler.RegisterStrategies()); + + } else { + ConvHandler handler(strategy_group, strategy_map, ins, cluster_env, + batch_map, option, call_graph); + TF_RETURN_IF_ERROR(handler.RegisterStrategies()); + } + return OkStatus(); } From 192718f5ebe6d35e040bcd61f29bc67a19d2a190 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 28 Nov 2023 18:53:23 -0800 Subject: [PATCH 160/381] [stream_executor] Record cond_builder before evaluating loop condition in a While command PiperOrigin-RevId: 586164354 --- .../xla/xla/stream_executor/command_buffer.cc | 19 ++++++++++++++-- .../xla/xla/stream_executor/command_buffer.h | 22 +++++++++++++++---- .../cuda/cuda_command_buffer_test.cc | 8 +++---- .../stream_executor/gpu/gpu_command_buffer.cc | 7 +++--- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 0e901d98312e37..0082b44e9e0b79 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -37,6 +37,11 @@ CommandBuffer::~CommandBuffer() = default; CommandBuffer::CommandBuffer(CommandBuffer&&) = default; CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default; +void CommandBuffer::Deleter::operator()( + internal::CommandBufferInterface* impl) { + if (owned) delete impl; +} + /*static*/ tsl::StatusOr CommandBuffer::Create( StreamExecutor* executor, Mode mode) { TF_ASSIGN_OR_RETURN( @@ -91,14 +96,24 @@ internal::CommandBufferInterface* CommandBuffer::implementation() { return implementation_.get(); } -/*static*/ CommandBuffer CommandBuffer::Wrap( +/*static*/ CommandBuffer CommandBuffer::Create( std::unique_ptr implementation) { return CommandBuffer(std::move(implementation)); } +/*static*/ tsl::Status CommandBuffer::Build( + internal::CommandBufferInterface* implementation, + const CommandBuffer::Builder& builder) { + CommandBuffer command_buffer(implementation); + return builder(&command_buffer); +} + CommandBuffer::CommandBuffer( std::unique_ptr implementation) - : implementation_(std::move(implementation)) {} + : implementation_(implementation.release(), {/*owned=*/true}) {} + +CommandBuffer::CommandBuffer(internal::CommandBufferInterface* implementation) + : implementation_(implementation, {/*owned=*/false}) {} tsl::Status CommandBuffer::Launch(const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 4935fc964576ff..fde601a7d52df9 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -210,16 +210,30 @@ class CommandBuffer { const internal::CommandBufferInterface* implementation() const; internal::CommandBufferInterface* implementation(); - // Wraps platform-specific command buffer 
implementation into a top-level - // StreamExecutor command buffer. - static CommandBuffer Wrap( + // Creates a command buffer from a platform-specific command buffer + // implementation. + static CommandBuffer Create( std::unique_ptr implementation); + // An adaptor for a command buffer builder that records commands into the + // platform-specific implementation + static tsl::Status Build(internal::CommandBufferInterface* implementation, + const CommandBuffer::Builder& builder); + private: explicit CommandBuffer( std::unique_ptr implementation); - std::unique_ptr implementation_; + explicit CommandBuffer(internal::CommandBufferInterface* implementation); + + // A custom deleter to be able to construct command buffer that doesn't own + // underlying implementation (behaves like std::weak_ptr for implementation). + struct Deleter { + void operator()(internal::CommandBufferInterface*); + bool owned = true; + }; + + std::unique_ptr implementation_; CommandBuffer(const CommandBuffer&) = delete; void operator=(const CommandBuffer&) = delete; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 3f9b94a0e39ae2..bfa1052fb72f98 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -595,15 +595,15 @@ TEST(CudaCommandBufferTest, ConditionalWhile) { int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; - // Prepare arguments: a=1, b=0, loop_index=1, pred=true + // Prepare arguments: a=1, b=0, loop_index=0, pred=false DeviceMemory pred = executor->AllocateArray(1, 0); DeviceMemory loop_index = executor->AllocateArray(1, 0); DeviceMemory a = executor->AllocateArray(length, 0); DeviceMemory b = executor->AllocateArray(length, 0); - static constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&loop_index, 1, sizeof(int32_t)); + static constexpr bool kFalse = false; + stream.ThenMemcpy(&pred, &kFalse, 1); + stream.ThenMemset32(&loop_index, 0, sizeof(int32_t)); stream.ThenMemset32(&a, 1, byte_length); stream.ThenMemZero(&b, byte_length); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 7d978364f64944..70ebf6b2d89ed6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -388,7 +388,7 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( auto command_buffer_impl = parent_->GetCommandBufferImplementation( nested, graphs[i], is_owned_graph); - auto command_buffer = CommandBuffer::Wrap(std::move(command_buffer_impl)); + auto command_buffer = CommandBuffer::Create(std::move(command_buffer_impl)); TF_RETURN_IF_ERROR(builders[i](&command_buffer, handles[i])); TF_RETURN_IF_ERROR(command_buffer.Finalize()); @@ -617,9 +617,8 @@ tsl::Status GpuCommandBuffer::While(StreamExecutor* executor, TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_while_condition)); } - // TODO(ezhulenev): We assume that `pred` already has a value that decides if - // we should go into the first loop iteration. Instead we should run - // `cond_builder` to update primary command buffer. + // Record condition commands into the parent command buffer. 
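+  // This evaluates the loop condition once before the first iteration instead
+  // of assuming that `pred` was already initialized by the caller.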
+ TF_RETURN_IF_ERROR(CommandBuffer::Build(this, cond_builder)); auto set_cond_fn = [&](absl::Span handles) { return Launch(set_while_condition, ThreadDim(), BlockDim(), handles[0], From b67b97e58a920e40dd00efec8a65ca3308fec862 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 19:05:13 -0800 Subject: [PATCH 161/381] Implemented QuantizeConvolutionOpPattern and integration test, integration test is skipped for now because the full support for convolution is not implemented. Refactored the target op quantization pattern matching to compatible with dot-like ops. PiperOrigin-RevId: 586166124 --- .../quantization/stablehlo/passes/passes.h | 2 +- .../passes/quantize_composite_functions.cc | 208 ++++++++++-------- .../mlir/quantization/stablehlo/python/BUILD | 1 + .../integration_test/quantize_model_test.py | 95 ++++++++ .../quantize_model_test_base.py | 75 +++++++ .../tests/quantize_composite_functions.mlir | 80 ++++++- 6 files changed, 367 insertions(+), 94 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h index 0b05069b265989..f6e7fa8f1e4a26 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -39,7 +39,7 @@ std::unique_ptr> CreateQuantizeWeightPass( // Creates an instance of the StableHLO dialect PrepareQuantize pass without any // arguments. Preset method of SRQ is set to the quantization option by default. std::unique_ptr> CreatePrepareQuantizePass( - bool enable_per_channel_quantization = true, int bit_width = 8); + bool enable_per_channel_quantization = false, int bit_width = 8); // Adds generated pass default constructors or options definitions. #define GEN_PASS_DECL diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 5d391b3e858c16..dd558a08bc642c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -63,6 +63,7 @@ namespace { using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; using ::mlir::stablehlo::DynamicBroadcastInDimOp; using ::mlir::stablehlo::UniformQuantizeOp; @@ -106,11 +107,11 @@ bool HasFusibleQuantizationPattern(Operation& op) { // Returns dynamically broadcasted user op of an input op. Returns null if // the op is used multiple times or the user op is not dynamically broadcasted. // Dynamic shapes usually has the following pattern. In the example below, -// the input operand would be stablehlo.dot_general op, and return value would +// the input operand would be stablehlo.gemm_style op, and return value would // be stablehlo.add op. // // ``` -// %2 = stablehlo.dot_general(%0, %1) +// %2 = stablehlo.gemm_style(%0, %1) // %3 = shape.shape_of %2 // %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 // %5 = stablehlo.add %2, %4 @@ -175,8 +176,6 @@ func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, const auto entry_function_symbol_ref = xla_call_module_op->getAttrOfType(kEntryFuncAttrName); - // Don't match if there are no DotGeneralOp. 
- // if (target_func_op.getOps().empty()) return {}; return dyn_cast_or_null( symbol_table.lookup(entry_function_symbol_ref.getValue())); } @@ -238,102 +237,126 @@ class EntryFuncBodyQuantizationPattern { PatternRewriter& rewriter) const = 0; }; +// Gemm Style Op: glossary/gemm. +template +// Match for all gemm_style op and check for possible fusions. +LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { + // function must have input, filter, and optionally bias. + auto& operations = entry_func_op.getBody().front().getOperations(); + if (operations.size() != 2 && operations.size() != 3) { + return failure(); + } + if (!isa(operations.front())) { + return failure(); + } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { + LLVM_DEBUG(llvm::dbgs() + << "Currently gemm style ops quantization only supports static " + " shapes.\n"); + return failure(); + } else if (!isa( + operations.front().getResult(0).getType())) { + return failure(); + } + return success(); +} + +// Gemm Style Op: glossary/gemm. +template +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { + // Update the output type of the gemm_style op. + GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); + + const Type input_type = entry_func_op.getArgumentTypes()[0]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + const double input_scale = + getElementTypeOrSelf(input_type).cast().getScale(); + const double filter_scale = + getElementTypeOrSelf(filter_type).cast().getScale(); + const double result_scale = input_scale * filter_scale; + + // Define the intermediate output type, which is an i32 quantized type. + // This is intermediate because the final output type of the entry_func_op + // should be an i8 quantized type. + const UniformQuantizedType gemm_style_quantized_element_type = + CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), + *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + Value gemm_style_op_result = gemm_style_op->getResult(0); + auto gemm_style_op_result_type = + gemm_style_op_result.getType().cast(); + const ArrayRef gemm_style_shape = + gemm_style_op_result_type.getShape(); + + const TensorType new_gemm_style_op_result_type = + gemm_style_op_result_type.cloneWith(gemm_style_shape, + gemm_style_quantized_element_type); + gemm_style_op_result.setType(new_gemm_style_op_result_type); + + rewriter.setInsertionPointAfter(gemm_style_op); + + Operation& next_op = *gemm_style_op->getNextNode(); + // If an op is used multiple times, do not apply quantization of fused + // patterns to prevent removal of dependee ops. + const bool should_quantize_without_fusion = + HasFusibleQuantizationPattern(*gemm_style_op.getOperation()) && + !gemm_style_op->hasOneUse(); + + // TODO: b/307620428 - Add support for dynamic shapes. + if (should_quantize_without_fusion || !isa(next_op)) { + // no bias + CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, + func_result_type); + return; + } + // bias fusion + Value bias_op = next_op.getOperand(1); + Value add_op_result = next_op.getResult(0); + const auto add_op_result_type = + add_op_result.getType().cast(); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. 
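+  // Here that common type is the i32 accumulation type derived above from the
+  // input and filter scales.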
+ const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, gemm_style_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = + rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); +} + // Quantizes the entry function's body containing a `DotGeneralOp`. class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeDotGeneralOpPattern(MLIRContext& ctx) : ctx_(&ctx) {} + explicit QuantizeDotGeneralOpPattern() = default; - // Match for all dot_general op and check for possible fusions. LogicalResult match(func::FuncOp entry_func_op) const override { - // function must have input, filter, and optionally bias. - auto& operations = entry_func_op.getBody().front().getOperations(); - if (operations.size() != 2 && operations.size() != 3) { - return failure(); - } - if (!isa(operations.front())) { - return failure(); - } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { - LLVM_DEBUG(llvm::dbgs() - << "Currently dot_general quantization only supports static " - " shapes.\n"); - return failure(); - } - return success(); + return MatchGemmStyleOp(entry_func_op); } void rewrite(func::FuncOp entry_func_op, PatternRewriter& rewriter) const override { - // Update the output type of the dot_general op. - DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); - - const Type input_type = entry_func_op.getArgumentTypes()[0]; - const Type filter_type = entry_func_op.getArgumentTypes()[1]; - const Type func_result_type = entry_func_op.getResultTypes()[0]; - - const double input_scale = getElementTypeOrSelf(input_type) - .cast() - .getScale(); - const double filter_scale = getElementTypeOrSelf(filter_type) - .cast() - .getScale(); - const double result_scale = input_scale * filter_scale; - - // Define the intermediate output type, which is an i32 quantized type. - // This is intermediate because the final output type of the entry_func_op - // should be an i8 quantized type. - const UniformQuantizedType dot_general_quantized_element_type = - CreateI32F32UniformQuantizedType(dot_general_op->getLoc(), *ctx_, - result_scale, - /*zero_point=*/0); - - Value dot_general_op_result = dot_general_op->getResult(0); - auto dot_general_op_result_type = - dot_general_op_result.getType().cast(); - const ArrayRef dot_general_shape = - dot_general_op_result_type.getShape(); - - const TensorType new_dot_general_op_result_type = - dot_general_op_result_type.cloneWith( - dot_general_shape, dot_general_quantized_element_type); - dot_general_op_result.setType(new_dot_general_op_result_type); - - rewriter.setInsertionPointAfter(dot_general_op); - - Operation& next_op = *dot_general_op->getNextNode(); - - // If an op is used multiple times, do not apply quantization of fused - // patterns to prevent removal of dependee ops. - const bool should_quantize_without_fusion = - HasFusibleQuantizationPattern(*dot_general_op.getOperation()) && - !dot_general_op->hasOneUse(); - - // TODO: b/307620428 - Add support for dynamic shapes. 
- if (should_quantize_without_fusion || !isa(next_op)) { - // no bias - CreateAndReturnUniformQuantizeOp(rewriter, *dot_general_op, entry_func_op, - func_result_type); - return; - } - // bias fusion - Value bias_op = next_op.getOperand(1); - Value add_op_result = next_op.getResult(0); - const auto add_op_result_type = - add_op_result.getType().cast(); - const ArrayRef add_op_shape = add_op_result_type.getShape(); - // For quantized bias add case, lhs, rhs, and result have the same types. - const TensorType new_add_op_result_type = add_op_result_type.cloneWith( - add_op_shape, dot_general_quantized_element_type); - add_op_result.setType(new_add_op_result_type); - - AddOp bias_add_op = rewriter.create(dot_general_op->getLoc(), - dot_general_op, bias_op); - - CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, - func_result_type); + RewriteGemmStyleOp(entry_func_op, rewriter); } +}; - private: - MLIRContext* ctx_ = nullptr; +// Quantizes the entry function's body containing a `ConvolutionOp`. +class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeConvolutionOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } }; // Converts `entry_func_op` to be quantized according to the respective @@ -408,14 +431,14 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { return failure(); } - return FuncBodyRewritePatternT(*getContext()).match(entry_func_op); + return FuncBodyRewritePatternT().match(entry_func_op); } void rewrite(TF::XlaCallModuleOp xla_call_module_op, PatternRewriter& rewriter) const override { ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT(*getContext())); + FuncBodyRewritePatternT()); } }; @@ -444,7 +467,8 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { // TODO - b/307839649: Move this as a separate pass. 
RewritePatternSet patterns(&ctx); - patterns.add>(ctx); + patterns.add, + XlaCallModuleOpToCallOp>(ctx); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index b8a559b51786e1..24a8248e750c56 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -55,6 +55,7 @@ pytype_strict_library( "//tensorflow/python/module", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model:save", "//tensorflow/python/types:core", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index e92b996a1c5110..3ae86aa83339cd 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -14,6 +14,7 @@ # ============================================================================== import itertools from typing import Optional, Sequence +import unittest from absl.testing import parameterized import numpy as np @@ -218,6 +219,100 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # TODO: b/309674337 - Fix the large numerical errors. self.assertAllClose(new_outputs, expected_outputs, rtol=0.3) + @parameterized.named_parameters( + { + 'testcase_name': 'none', + 'activation_fn': None, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.STABLEHLO, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + ) + @test_util.run_in_graph_and_eager_modes + @unittest.skip('b/307620966: e2e support for conv is under development.') + def test_conv_ptq_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + has_batch_norm: bool, + target_opset: quant_opts_pb2.OpSet, + input_shape_dynamic: bool, + enable_per_channel_quantization: bool, + dilations: Sequence[int] = None, + ): + input_shape = (None, None, None, 3) if input_shape_dynamic else (1, 3, 4, 3) + filter_shape = (2, 3, 3, 2) + strides = (1, 1, 1, 1) + model = self._create_conv2d_model( + input_shape, + filter_shape, + self._input_saved_model_path, + has_bias, + has_batch_norm, + activation_fn, + strides, + dilations, + ) + + # Generate model input data. 
+ rng = np.random.default_rng(seed=1224) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 + ) + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + tags = {tag_constants.SERVING} + + config = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 + ), + tags=tags, + signature_keys=['serving_default'], + op_set=target_opset, + representative_datasets={ + 'serving_default': quant_opts_pb2.RepresentativeDatasetFile( + tfrecord_file_path=dataset_path + ) + }, + enable_per_channel_quantization=enable_per_channel_quantization, + ) + + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + expected_outputs = model.conv2d(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # Tests that the quantized graph outputs similar values. The rtol value is + # arbitrary. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.02) + def test_when_preset_not_srq_raise_error(self): self._create_matmul_model( input_shape=(1, 1024), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index 9600754c67e7d1..cc35e66ce641ee 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -26,6 +26,7 @@ from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import save as saved_model_save from tensorflow.python.types import core @@ -227,3 +228,77 @@ def matmul_and_same_scale( ), ) return model + + def _create_conv2d_model( + self, + input_shape: Sequence[int], + filter_shape: Sequence[int], + saved_model_path: str, + has_bias: bool = False, + has_batch_norm: bool = False, + activation_fn: Optional[ops.Operation] = None, + strides: Sequence[int] = (1, 1, 1, 1), + dilations: Sequence[int] = (1, 1, 1, 1), + padding: str = 'SAME', + ) -> module.Module: + class ConvModel(module.Module): + """A simple model with a single conv2d, bias and relu.""" + + def __init__(self): + self.out_channel_size = filter_shape[-1] + + # This ensures filters will have different value range per out channel + self.filters = np.stack( + [ + np.random.uniform( + low=-(i + 1), high=(i + 1), size=filter_shape[:-1] + ).astype('f4') + for i in range(self.out_channel_size) + ], + axis=-1, + ) + + self.bias = np.random.uniform( + low=0, high=10, size=(self.out_channel_size) + ).astype('f4') + + 
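The per-channel value ranges set up above are what make the enable_per_channel_quantization variants meaningful: with symmetric int8 weight quantization, each output channel gets a scale derived from that channel's own maximum magnitude. A rough illustration of that relationship (not part of the test, names hypothetical):

    # Hypothetical per-channel scales for symmetric int8 weight quantization.
    per_channel_max = np.max(np.abs(model.filters), axis=(0, 1, 2))  # one max per out channel
    per_channel_scales = per_channel_max / 127.0

Since channel i is drawn from [-(i + 1), i + 1], the scales grow with the channel index, so per-channel and per-tensor quantization produce visibly different results.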
@def_function.function + def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs a 2D convolution operation. + + Args: + input_tensor: Input tensor to perform convolution on. + + Returns: + A map of: output key -> output result. + """ + scale = [1.0] * self.out_channel_size + offset = [0.5] * self.out_channel_size + mean, variance = scale, offset + out = nn_ops.conv2d( + input_tensor, + self.filters, + strides=strides, + dilations=dilations, + padding=padding, + data_format='NHWC', + name='sample/conv', + ) + if has_batch_norm: + # Fusing is supported for non-training case. + out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( + out, scale, offset, mean, variance, is_training=False + ) + return {'output': out} + + model = ConvModel() + saved_model_save.save( + model, + saved_model_path, + signatures=model.conv2d.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir index b82d8177c68537..d38083150ba8be 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir @@ -32,7 +32,7 @@ module attributes {tf_saved_model.semantics} { // i8 quantized tensor. // CHECK: func.func private @quantized_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> // CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> } @@ -153,3 +153,81 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] // CHECK: return %[[DOT_GENERAL]] } + +// ----- + +// Test basic convolution is quantized. + +module attributes {tf_saved_model.semantics} { +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
+// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} + func.func private @quantize_convolution(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_convolution_fn, _original_entry_function = "composite_convolution_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. + +// CHECK-LABEL: func.func private @quantize_convolution +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_convolution_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_convolution_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_convolution_fn(%[[ARG_2:.*]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.*]] = stablehlo.convolution(%[[ARG_2]], %[[ARG_3]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused bias pattern is properly quantized. 
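In the fused pattern below, the bias constant is quantized into the convolution's i32 accumulator domain rather than to i8: its scale is the product of the input and weight scales with a zero point of 0, matching the type the GemmStyle rewrite gives the convolution output before requantization, so stablehlo.add operates directly on i32 tensors. Schematically (illustrative names, not code from the pass):

    // What ends up in the tensor<1x3x4x2xi32> bias constant checked below.
    bias_scale = input_scale * weight_scale;   // zero_point = 0
    quantized_bias = round(bias / bias_scale);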
+ +module attributes {tf_saved_model.semantics} { +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} + func.func private @quantize_convolution_with_bias(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3x4x2xf32>} : () -> tensor<1x3x4x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_convolution_with_bias_fn, _original_entry_function = "composite_convolution_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } + +// CHECK-LABEL: func.func private @quantize_convolution_with_bias +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_convolution_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK: func.func private @quantized_convolution_with_bias_fn(%[[ARG_2:.*]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_4:.*]]: tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_convolution_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } +// CHECK: %[[CONVOLUTION_0:.*]] = stablehlo.convolution(%[[ARG_2]], %[[ARG_3]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> 
tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[CONVOLUTION_0]], %[[ARG_4]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} From fc8a13b4b4131f2e21cdc51dab531132382ce9b3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 28 Nov 2023 19:10:56 -0800 Subject: [PATCH 162/381] [stream_executor] Add Memset command to CommandBuffer PiperOrigin-RevId: 586166961 --- .../xla/xla/stream_executor/command_buffer.cc | 10 +- .../xla/xla/stream_executor/command_buffer.h | 9 +- .../cuda/cuda_command_buffer_test.cc | 61 ++++++++++-- .../xla/stream_executor/cuda/cuda_driver.cc | 92 +++++++++++++++++++ .../stream_executor/gpu/gpu_command_buffer.cc | 35 ++++++- .../stream_executor/gpu/gpu_command_buffer.h | 6 +- .../xla/xla/stream_executor/gpu/gpu_driver.h | 15 +++ .../stream_executor_internal.h | 5 + 8 files changed, 215 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 0082b44e9e0b79..734a3faeb6f946 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/command_buffer.h" +#include #include #include #include @@ -131,6 +132,11 @@ tsl::Status CommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return implementation_->MemcpyDeviceToDevice(dst, src, size); } +tsl::Status CommandBuffer::Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, + size_t num_elements) { + return implementation_->Memset(dst, bit_pattern, num_elements); +} + tsl::Status CommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { return implementation_->Allocate(alloc); } @@ -159,9 +165,9 @@ tsl::Status CommandBuffer::Case(StreamExecutor* executor, } tsl::Status CommandBuffer::For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_index, + DeviceMemory loop_counter, Builder body_builder) { - return implementation_->For(executor, num_iteration, loop_index, + return implementation_->For(executor, num_iteration, loop_counter, std::move(body_builder)); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index fde601a7d52df9..54349127dcfeed 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_COMMAND_BUFFER_H_ #define XLA_STREAM_EXECUTOR_COMMAND_BUFFER_H_ +#include #include #include #include +#include #include #include "absl/functional/any_invocable.h" @@ -127,6 +129,11 @@ class CommandBuffer { tsl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size); + // Adds a memset node to the command buffer. 
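The bit pattern can be 8, 16, or 32 bits wide, and its width selects the element size of the underlying memset node. A minimal usage sketch, assuming a tsl::Status-returning caller with an allocated DeviceMemory<int32_t> mem and an initialized Stream stream (the same sequence is exercised by CudaCommandBufferTest.Memset further down):

    auto cmd_buffer = CommandBuffer::Create(executor).value();
    TF_RETURN_IF_ERROR(cmd_buffer.Memset(&mem, uint32_t{42}, /*num_elements=*/4));
    TF_RETURN_IF_ERROR(cmd_buffer.Finalize());
    TF_RETURN_IF_ERROR(executor->Submit(&stream, cmd_buffer));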
+ using BitPattern = std::variant; + tsl::Status Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, + size_t num_elements); + //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// @@ -153,7 +160,7 @@ class CommandBuffer { // Adds a conditional operation that will execute a command buffer constructed // by the `body_builder` exactly `num_iteration` times. tsl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_index, Builder body_builder); + DeviceMemory loop_counter, Builder body_builder); // Adds a conditional operation that will execute a command buffer constructed // by the `cond_builder` that must update `pred` value, and then depending on diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index bfa1052fb72f98..191c21f552c9e3 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -220,6 +220,48 @@ TEST(CudaCommandBufferTest, LaunchNestedCommandBuffer) { ASSERT_EQ(dst, expected); } +TEST(CudaCommandBufferTest, Memset) { + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + DeviceMemory a = executor->AllocateArray(length, 0); + + // Create a command buffer with a single memset command. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.Memset(&a, uint32_t{42}, length)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `a` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), a, byte_length); + + std::vector expected = {42, 42, 42, 42}; + ASSERT_EQ(dst, expected); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer.Update()); + TF_ASSERT_OK(cmd_buffer.Memset(&a, uint32_t{43}, length)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), a, byte_length); + + expected = {43, 43, 43, 43}; + ASSERT_EQ(dst, expected); +} + TEST(CudaCommandBufferTest, ConditionalIf) { Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); if (!CommandBuffer::SupportsConditionalCommands(platform)) { @@ -534,12 +576,13 @@ TEST(CudaCommandBufferTest, ConditionalFor) { int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; - // Prepare arguments: a=1, b=0, loop_index=0 - DeviceMemory loop_index = executor->AllocateArray(1, 0); + // Prepare arguments: a=1, b=0, loop_counter=100 + DeviceMemory loop_counter = executor->AllocateArray(1, 0); DeviceMemory a = executor->AllocateArray(length, 0); DeviceMemory b = executor->AllocateArray(length, 0); - stream.ThenMemset32(&loop_index, 0, sizeof(int32_t)); + // Set loop counter to 100 to check that command buffer resets it. 
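The reset itself now happens inside the command buffer: as the gpu_command_buffer.cc hunk later in this patch shows, GpuCommandBuffer::For records

    // Reset loop counter to zero.
    TF_RETURN_IF_ERROR(Memset(&loop_counter, uint32_t{0}, 1));

before building the conditional, so initializing the counter to 100 here verifies that the recorded reset actually takes effect.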
+ stream.ThenMemset32(&loop_counter, 100, sizeof(int32_t)); stream.ThenMemset32(&a, 1, byte_length); stream.ThenMemZero(&b, byte_length); @@ -552,7 +595,7 @@ TEST(CudaCommandBufferTest, ConditionalFor) { // Create a command buffer with a single conditional operation. auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.For(executor, num_iters, loop_index, body_builder)); + TF_ASSERT_OK(cmd_buffer.For(executor, num_iters, loop_counter, body_builder)); TF_ASSERT_OK(cmd_buffer.Finalize()); TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); @@ -595,23 +638,23 @@ TEST(CudaCommandBufferTest, ConditionalWhile) { int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; - // Prepare arguments: a=1, b=0, loop_index=0, pred=false + // Prepare arguments: a=1, b=0, loop_counter=0, pred=false DeviceMemory pred = executor->AllocateArray(1, 0); - DeviceMemory loop_index = executor->AllocateArray(1, 0); + DeviceMemory loop_counter = executor->AllocateArray(1, 0); DeviceMemory a = executor->AllocateArray(length, 0); DeviceMemory b = executor->AllocateArray(length, 0); static constexpr bool kFalse = false; stream.ThenMemcpy(&pred, &kFalse, 1); - stream.ThenMemset32(&loop_index, 0, sizeof(int32_t)); + stream.ThenMemset32(&loop_counter, 0, sizeof(int32_t)); stream.ThenMemset32(&a, 1, byte_length); stream.ThenMemZero(&b, byte_length); int32_t num_iters = 10; - // Loop cond: loop_index++ < num_iters; + // Loop cond: loop_counter++ < num_iters; CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { - return cond_cmd->Launch(inc_and_cmp, ThreadDim(), BlockDim(), loop_index, + return cond_cmd->Launch(inc_and_cmp, ThreadDim(), BlockDim(), loop_counter, pred, num_iters); }; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index c1bc32abdcb812..959625859650ea 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -996,6 +996,98 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { return ::tsl::OkStatus(); } +namespace { + +struct BitPatternToString { + std::string operator()(uint8_t pattern) { + return absl::StrCat("u8:", pattern); + } + std::string operator()(uint16_t pattern) { + return absl::StrCat("u16:", pattern); + } + std::string operator()(uint32_t pattern) { + return absl::StrCat("u32:", pattern); + } +}; + +// Broadcasts a pattern value of 1/2/4 bytes to a 4 byte value. 
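For example, a uint8_t pattern 0xAB is widened to the 32-bit word 0xABABABAB with element_size 1, a uint16_t pattern 0xABCD becomes 0xABCDABCD with element_size 2, and a uint32_t pattern is passed through unchanged with element_size 4.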
+struct BitPatternToValue { + std::pair operator()(uint8_t pattern) { + unsigned value = pattern; + return {(value << 24) | (value << 16) | (value << 8) | value, + /*element_size=*/1}; + } + std::pair operator()(uint16_t pattern) { + unsigned value = pattern; + return {(value << 16) | value, /*element_size=*/2}; + } + std::pair operator()(uint32_t pattern) { + return {pattern, /*element_size=*/4}; + } +}; + +} // namespace + +/* static */ tsl::Status GpuDriver::GraphAddMemsetNode( + GpuContext* context, CUgraphNode* node, GpuGraphHandle graph, + absl::Span deps, CUdeviceptr dst, + std::variant bit_pattern, + uint64_t num_elements) { + VLOG(2) << "Add memset node to a graph " << graph + << "; dst: " << reinterpret_cast(dst) + << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) + << "; num_elements: " << num_elements + << "; context: " << context->context() << "; deps: " << deps.size(); + + CUDA_MEMSET_NODE_PARAMS params; + memset(¶ms, 0, sizeof(params)); + + auto [value, element_size] = std::visit(BitPatternToValue(), bit_pattern); + + params.dst = dst; + params.elementSize = element_size; + params.height = 1; + params.pitch = 0; // unused if height is 1 + params.value = value; + params.width = num_elements; + + RETURN_IF_CUDA_RES_ERROR( + cuGraphAddMemsetNode(node, graph, deps.data(), deps.size(), ¶ms, + context->context()), + "Failed to add memset node to a CUDA graph"); + + return ::tsl::OkStatus(); +} + +/* static */ tsl::Status GpuDriver::GraphExecMemsetNodeSetParams( + GpuContext* context, CUgraphExec exec, CUgraphNode node, CUdeviceptr dst, + std::variant bit_pattern, + uint64_t num_elements) { + VLOG(2) << "Set memset node params " << node << " in graph executable " + << exec << "; dst: " << reinterpret_cast(dst) + << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) + << "; num_elements: " << num_elements + << "; context: " << context->context(); + + CUDA_MEMSET_NODE_PARAMS params; + memset(¶ms, 0, sizeof(params)); + + auto [value, element_size] = std::visit(BitPatternToValue(), bit_pattern); + + params.dst = dst; + params.elementSize = element_size; + params.height = 1; + params.pitch = 0; // unused if height is 1 + params.value = value; + params.width = num_elements; + + RETURN_IF_CUDA_RES_ERROR( + cuGraphExecMemsetNodeSetParams(exec, node, ¶ms, context->context()), + "Failed to set memset node params"); + + return ::tsl::OkStatus(); +} + /* static */ tsl::Status GpuDriver::GraphAddChildNode( CUgraphNode* node, CUgraph graph, absl::Span deps, CUgraph child) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 70ebf6b2d89ed6..d4564f9ea3bb20 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -273,6 +273,29 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } +tsl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, + CommandBuffer::BitPattern bit_pattern, + size_t num_elements) { + TF_RETURN_IF_ERROR(CheckNotFinalized()); + + if (state_ == State::kCreate) { + Dependencies deps = GetDependencies(); + GpuGraphNodeHandle* node = &nodes_.emplace_back(); + return GpuDriver::GraphAddMemsetNode( + parent_->gpu_context(), node, graph_, absl::MakeSpan(deps), + AsDevicePtr(*dst), bit_pattern, num_elements); + } + + if (state_ == State::kUpdate) { + GpuGraphNodeHandle node = 
nodes_[update_state_.node_idx++]; + return GpuDriver::GraphExecMemsetNodeSetParams( + parent_->gpu_context(), exec_, node, AsDevicePtr(*dst), bit_pattern, + num_elements); + } + + return UnsupportedStateError(state_); +} + tsl::Status GpuCommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { TF_RETURN_IF_ERROR(CheckNotFinalized()); @@ -418,6 +441,8 @@ tsl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( tsl::Status GpuCommandBuffer::CreateConditionalCommand( ConditionType type, SetConditionFn set_condition, absl::Span builders) { + TF_RETURN_IF_ERROR(CheckNotFinalized()); + // Every conditional command buffer is controlled by its own handle. size_t num_handles = builders.size(); @@ -564,7 +589,7 @@ tsl::Status GpuCommandBuffer::Case( tsl::Status GpuCommandBuffer::For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_index, + DeviceMemory loop_counter, CommandBuffer::Builder body_builder) { DCHECK(executor->implementation() == parent_); @@ -579,12 +604,12 @@ tsl::Status GpuCommandBuffer::For(StreamExecutor* executor, TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_for_condition)); } - // TODO(ezhulenev): We currently assume that `loop_index` initialized to - // zero, instead we should explicitly add a memset to clear it. + // Reset loop counter to zero. + TF_RETURN_IF_ERROR(Memset(&loop_counter, uint32_t{0}, 1)); auto set_cond_fn = [&](absl::Span handles) { return Launch(set_for_condition, ThreadDim(), BlockDim(), handles[0], - loop_index, num_iteration); + loop_counter, num_iteration); }; auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { @@ -592,7 +617,7 @@ tsl::Status GpuCommandBuffer::For(StreamExecutor* executor, // Decide if we want to continue loop iteration. return body->Launch(set_for_condition, ThreadDim(), BlockDim(), handle, - loop_index, num_iteration); + loop_counter, num_iteration); }; std::array builders = {std::move(body)}; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index e817c9fd07728d..3a5269f283b81d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -59,6 +59,10 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { const DeviceMemoryBase& src, uint64_t size) override; + tsl::Status Memset(DeviceMemoryBase* dst, + CommandBuffer::BitPattern bit_pattern, + size_t num_elements) override; + tsl::Status Allocate(CommandBuffer::AllocIndexSize alloc) override; tsl::StatusOr GetAllocationAddress( @@ -75,7 +79,7 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { std::vector branches) override; tsl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_index, + DeviceMemory loop_counter, CommandBuffer::Builder body_builder) override; tsl::Status While(StreamExecutor* executor, DeviceMemory pred, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 9db00ac9f54b40..be4c94a95f54f7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -540,6 +540,21 @@ class GpuDriver { GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size); + // Creates a memset node and adds it to a graph. 
+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g89dc8fc3743392777c0daa2c4aca40d3 + static tsl::Status GraphAddMemsetNode( + GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr dst, + std::variant bit_pattern, + uint64_t num_elements); + + // Sets the parameters for a memset node in the given graph exec. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g5df5be09a0b7b3513e740ebbbcd59739 + static tsl::Status GraphExecMemsetNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr dst, std::variant bit_pattern, + uint64_t num_elements); + // Creates a child graph node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gde52afbcf91a8c79d4d7efbe0e3b6844 static tsl::Status GraphAddChildNode(GpuGraphNodeHandle* node, diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index de0b82e9d5e1c8..62289f7cedacf0 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -141,6 +141,11 @@ class CommandBufferInterface { const DeviceMemoryBase& src, uint64_t size) = 0; + // Adds a memset node to the command buffer. + virtual tsl::Status Memset(DeviceMemoryBase* dst, + CommandBuffer::BitPattern bit_pattern, + size_t num_elements) = 0; + // Adds a device memory allocation node to the command buffer. virtual tsl::Status Allocate(CommandBuffer::AllocIndexSize alloc) = 0; From 16f2028b003878185b8585fdde83edcc4f1e3b6c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 28 Nov 2023 19:23:20 -0800 Subject: [PATCH 163/381] [stream_executor] Add support for updating Memcpy command parameters PiperOrigin-RevId: 586169003 --- .../cuda/cuda_command_buffer_test.cc | 47 +++++++++++++++++++ .../xla/stream_executor/cuda/cuda_driver.cc | 26 ++++++++++ .../stream_executor/gpu/gpu_command_buffer.cc | 8 +++- .../xla/xla/stream_executor/gpu/gpu_driver.h | 6 +++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 191c21f552c9e3..1464df79e8e6d1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -220,6 +220,53 @@ TEST(CudaCommandBufferTest, LaunchNestedCommandBuffer) { ASSERT_EQ(dst, expected); } +TEST(CudaCommandBufferTest, MemcpyDeviceToDevice) { + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=uninitialized + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + + stream.ThenMemset32(&a, 42, byte_length); + + // Create a command buffer with a single a to b memcpy command. 
+ auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.MemcpyDeviceToDevice(&b, a, byte_length)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), a, byte_length); + + std::vector expected = {42, 42, 42, 42}; + ASSERT_EQ(dst, expected); + + // Update command buffer to swap the memcpy direction. + TF_ASSERT_OK(cmd_buffer.Update()); + TF_ASSERT_OK(cmd_buffer.MemcpyDeviceToDevice(&a, b, byte_length)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + // Clear destination to test that command buffer actually copied memory. + stream.ThenMemset32(&a, 0, byte_length); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `a` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), a, byte_length); + ASSERT_EQ(dst, expected); +} + TEST(CudaCommandBufferTest, Memset) { Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); StreamExecutor* executor = platform->ExecutorForDevice(0).value(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 959625859650ea..469043ba02b80b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -996,6 +996,32 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { return ::tsl::OkStatus(); } +/* static */ tsl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { + VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " + << exec << "; dst: " << reinterpret_cast(gpu_dst) + << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size + << "; context: " << context->context(); + + CUDA_MEMCPY3D params; + memset(¶ms, 0, sizeof(params)); + + params.srcMemoryType = CU_MEMORYTYPE_DEVICE; + params.srcDevice = gpu_src; + params.dstMemoryType = CU_MEMORYTYPE_DEVICE; + params.dstDevice = gpu_dst; + params.WidthInBytes = size; + params.Height = 1; + params.Depth = 1; + + RETURN_IF_CUDA_RES_ERROR( + cuGraphExecMemcpyNodeSetParams(exec, node, ¶ms, context->context()), + "Failed to set memcpy d2d node params"); + + return ::tsl::OkStatus(); +} + namespace { struct BitPatternToString { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index d4564f9ea3bb20..0742b62c994d7b 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -261,7 +261,6 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, uint64_t size) { TF_RETURN_IF_ERROR(CheckNotFinalized()); - // Adds a new memcpy node to the graph under construction. 
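Note that the update path added below relies on commands being re-recorded in the same order in which the graph was originally built: update_state_.node_idx simply walks nodes_ in recording order to pick the graph node whose parameters are patched, so an update that records a different command sequence would pair parameters with the wrong nodes.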
if (state_ == State::kCreate) { Dependencies deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); @@ -270,6 +269,13 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, AsDevicePtr(*dst), AsDevicePtr(src), size); } + if (state_ == State::kUpdate) { + GpuGraphNodeHandle node = nodes_[update_state_.node_idx++]; + return GpuDriver::GraphExecMemcpyD2DNodeSetParams( + parent_->gpu_context(), exec_, node, AsDevicePtr(*dst), + AsDevicePtr(src), size); + } + return UnsupportedStateError(state_); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index be4c94a95f54f7..2dff75b93e1cef 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -540,6 +540,12 @@ class GpuDriver { GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size); + // Sets the parameters for a memcpy node in the given graphExec. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g26186d58858ab32ccc7425b53786cce5 + static tsl::Status GraphExecMemcpyD2DNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size); + // Creates a memset node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g89dc8fc3743392777c0daa2c4aca40d3 static tsl::Status GraphAddMemsetNode( From 739d18a1f1a8cdda2c36266610976679da7c70cc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Nov 2023 19:23:33 -0800 Subject: [PATCH 164/381] Converts the AutoShardingSolverRequest object into a proto (which provides built-in utilities for saving & loading). 
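Being a proto3 message, the request can be round-tripped with the standard protobuf API, which is what makes it straightforward to capture a solver request from a real compilation and replay it in tests or benchmarks. A minimal sketch (BuildRequest is a hypothetical helper):

    xla::AutoShardingSolverRequest request = BuildRequest();
    std::string serialized;
    request.SerializeToString(&serialized);   // save
    xla::AutoShardingSolverRequest restored;
    restored.ParseFromString(serialized);     // load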
PiperOrigin-RevId: 586169032 --- .../xla/hlo/experimental/auto_sharding/BUILD | 12 + .../auto_sharding/auto_sharding.cc | 123 ++++--- .../auto_sharding/auto_sharding.proto | 61 ++++ .../auto_sharding/auto_sharding_solver.cc | 325 +++++++++--------- .../auto_sharding/auto_sharding_solver.h | 27 +- .../auto_sharding_solver_impl.cc | 1 + .../auto_sharding_solver_test.cc | 188 ++++++---- 7 files changed, 439 insertions(+), 298 deletions(-) create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index c27a2cc6984e03..b871a6012fbb5a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -2,6 +2,7 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//xla:xla.bzl", "auto_sharding_deps", "auto_sharding_solver_deps", "xla_cc_binary", "xla_cc_test") +load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") package( default_visibility = ["//visibility:public"], @@ -80,6 +81,7 @@ cc_library( srcs = ["auto_sharding_solver_impl.cc"], visibility = ["//visibility:public"], deps = [ + ":auto_sharding_proto_cc", ":auto_sharding_strategy", "@com_google_ortools//ortools/linear_solver", ], @@ -90,6 +92,7 @@ cc_library( srcs = ["auto_sharding_solver.cc"], visibility = ["//visibility:public"], deps = [ + ":auto_sharding_proto_cc", ":auto_sharding_strategy", "//xla:statusor", "//xla:util", @@ -116,6 +119,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":auto_sharding_proto_cc", "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", @@ -273,6 +277,12 @@ xla_cc_binary( ], ) +tf_proto_library( + name = "auto_sharding_proto", + srcs = ["auto_sharding.proto"], + visibility = ["//visibility:public"], +) + build_test( name = "auto_sharding_runner_build_test", targets = [ @@ -291,6 +301,7 @@ xla_cc_test( deps = [ ":auto_sharding", ":auto_sharding_option", + ":auto_sharding_proto_cc", ":auto_sharding_util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", @@ -313,6 +324,7 @@ xla_cc_test( "no_oss", ], deps = [ + ":auto_sharding_proto_cc", ":auto_sharding_solver", ":auto_sharding_strategy", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 6750ec3e374ba9..b9fca89d5770c7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -87,6 +87,12 @@ limitations under the License. namespace xla { namespace spmd { + +namespace { +constexpr double kOverbudgetCoeff = 1e6; +constexpr double kSaltiplier = 0.001; // Modifies each obj. term by at most .1% +} // namespace + // Compute the resharding cost vector from multiple possible strategies // to a desired sharding spec. 
std::vector ReshardingCostVector( @@ -2507,24 +2513,31 @@ AutoShardingSolverResult CallSolver( sharding_propagation_solution) { // Serialize edges and edge costs to 1d numpy arrays AutoShardingSolverRequest request; - request.num_nodes = strategy_groups.size(); - request.memory_budget = option.memory_budget_per_device; - request.s_len = cost_graph.node_lens_; - request.s_follow = cost_graph.follow_idx_; - request.s_hint = s_hint; - request.solver_timeout_in_seconds = solver_timeout_in_seconds; - request.crash_at_infinity_costs_check = !option.try_multiple_mesh_shapes; - request.compute_iis = compute_iis; - for (const auto& iter : cost_graph.edge_costs_) { - request.e.push_back(iter.first); - std::vector rij; - const Matrix& edge_cost = iter.second; + request.set_num_nodes(strategy_groups.size()); + request.set_memory_budget(option.memory_budget_per_device); + request.mutable_s_len()->Add(cost_graph.node_lens_.begin(), + cost_graph.node_lens_.end()); + request.mutable_s_follow()->Add(cost_graph.follow_idx_.begin(), + cost_graph.follow_idx_.end()); + request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); + request.mutable_solver_timeout()->set_solver_timeout_in_seconds( + solver_timeout_in_seconds); + request.mutable_overbudget_coeff()->set_coeff(kOverbudgetCoeff); + request.set_crash_at_infinity_costs_check(!option.try_multiple_mesh_shapes); + request.set_compute_iis(compute_iis); + request.set_saltiplier(kSaltiplier); + for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) { + AutoShardingSolverRequest_Pair raw_edge; + raw_edge.set_first(edge.first); + raw_edge.set_second(edge.second); + *request.add_edges() = raw_edge; + AutoShardingSolverRequest_Costs rij; for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) { for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) { - rij.push_back(edge_cost(i, j)); + rij.add_costs(edge_cost(i, j)); } } - request.r.push_back(std::move(rij)); + request.mutable_resharding_costs()->Add(std::move(rij)); } const HloInstructionSequence& sequence = @@ -2533,13 +2546,13 @@ AutoShardingSolverResult CallSolver( // Serialize node costs int num_nodes_without_default = 0; - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { const StrategyGroup* strategy_group = strategy_groups[node_idx]; auto instruction_name = instructions.at(strategy_group->instruction_id)->name(); - request.instruction_names.push_back( + request.add_instruction_names( absl::StrCat(instruction_name, " (id: ", node_idx, ")")); - std::vector ci, di, mi, pi; + AutoShardingSolverRequest_Costs ci, di, mi, pi; std::optional default_strategy; auto iter = sharding_propagation_solution.find(instruction_name); if (iter != sharding_propagation_solution.end()) { @@ -2555,23 +2568,23 @@ AutoShardingSolverResult CallSolver( for (NodeStrategyIdx j = 0; j < strategy_group->strategies.size(); ++j) { const ShardingStrategy& strategy = strategy_group->strategies[j]; const HloSharding& sharding = strategy.output_sharding; - ci.push_back(strategy.compute_cost); - di.push_back(strategy.communication_cost + + ci.add_costs(strategy.compute_cost); + di.add_costs(strategy.communication_cost + cost_graph.extra_node_costs_[node_idx][j]); - mi.push_back(strategy.memory_cost); - pi.push_back(default_strategy && sharding == *default_strategy ? 0 : 1); + mi.add_costs(strategy.memory_cost); + pi.add_costs(default_strategy && sharding == *default_strategy ? 
0 : 1); } if (option.use_sharding_propagation_for_default_shardings && - *std::min_element(pi.begin(), pi.end()) > 0) { + *std::min_element(pi.costs().begin(), pi.costs().end()) > 0) { LOG(WARNING) << "No default strategy for {node_idx " << node_idx << ", instruction ID " << strategy_group->instruction_id << ", instruction name " << instruction_name << "}"; ++num_nodes_without_default; } - request.c.push_back(ci); - request.d.push_back(di); - request.m.push_back(mi); - request.p.push_back(pi); + request.mutable_computation_costs()->Add(std::move(ci)); + request.mutable_communication_costs()->Add(std::move(di)); + request.mutable_memory_costs()->Add(std::move(mi)); + request.mutable_departure_costs()->Add(std::move(pi)); } LOG(INFO) << "Total nodes without default: " << num_nodes_without_default; @@ -2600,62 +2613,70 @@ AutoShardingSolverResult CallSolver( std::vector row_indices; std::vector col_indices; - if (request.s_follow[idx_a] >= 0) { + if (request.s_follow(idx_a) >= 0) { row_indices = cost_graph.reindexing_vector_.at(idx_a); - idx_a = request.s_follow[idx_a]; + idx_a = request.s_follow(idx_a); } else { - row_indices.assign(request.s_len[idx_a], 0); + row_indices.assign(request.s_len(idx_a), 0); std::iota(row_indices.begin(), row_indices.end(), 0); } - if (request.s_follow[idx_b] >= 0) { + if (request.s_follow(idx_b) >= 0) { col_indices = cost_graph.reindexing_vector_.at(idx_b); - idx_b = request.s_follow[idx_b]; + idx_b = request.s_follow(idx_b); } else { - col_indices.assign(request.s_len[idx_b], 0); + col_indices.assign(request.s_len(idx_b), 0); std::iota(col_indices.begin(), col_indices.end(), 0); } - CHECK_EQ(request.s_len[idx_a], row_indices.size()); - CHECK_EQ(request.s_len[idx_b], col_indices.size()); + CHECK_EQ(request.s_len(idx_a), row_indices.size()); + CHECK_EQ(request.s_len(idx_b), col_indices.size()); - std::vector vij; + AutoShardingSolverRequest_Costs vij; for (NodeStrategyIdx i : row_indices) { for (NodeStrategyIdx j : col_indices) { - vij.push_back(raw_cost(i, j)); + vij.add_costs(raw_cost(i, j)); } } - bool convertable = (row_indices.size() == col_indices.size()); - for (NodeStrategyIdx i = 0; i < row_indices.size() && convertable; ++i) { - if (vij[i * col_indices.size() + i] != 0.0) convertable = false; + bool convertible = (row_indices.size() == col_indices.size()); + for (NodeStrategyIdx i = 0; i < row_indices.size() && convertible; ++i) { + if (vij.costs(i * col_indices.size() + i) != 0.0) convertible = false; } - if (convertable && option.allow_alias_to_follower_conversion) { + if (convertible && option.allow_alias_to_follower_conversion) { new_followers.push_back({idx_a, idx_b}); } else { - request.a.push_back({idx_a, idx_b}); - request.v.push_back(vij); + AutoShardingSolverRequest_Pair alias; + alias.set_first(idx_a); + alias.set_second(idx_b); + *request.add_aliases() = alias; + request.mutable_value_costs()->Add(std::move(vij)); } } // Process any new followers that had originally been modeled as aliases. - std::vector& s_follow = request.s_follow; + auto s_follow = request.mutable_s_follow(); for (auto [follower, followee] : new_followers) { // New followers may have introduced chains, so find the root nodes. 
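For instance (hypothetical indices): if alias conversion makes node 3 a follower of node 5 while node 5 already follows node 7, the loops below resolve both endpoints to their roots and record s_follow[3] = 7 directly; the flattening pass afterwards rewrites any remaining transitive entries so that every follower points straight at a root node.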
- while (s_follow[follower] >= 0) follower = s_follow[follower]; - while (s_follow[followee] >= 0) followee = s_follow[followee]; - if (follower != followee) s_follow[follower] = followee; + while (s_follow->at(follower) >= 0) follower = s_follow->at(follower); + while (s_follow->at(followee) >= 0) followee = s_follow->at(followee); + if (follower != followee) s_follow->Set(follower, followee); } // Flatten the follower indices to remove any transitive arcs. - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - if (s_follow[node_idx] < 0) continue; - while (s_follow[s_follow[node_idx]] >= 0) { - s_follow[node_idx] = s_follow[s_follow[node_idx]]; + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + if (s_follow->at(node_idx) < 0) continue; + while (s_follow->at(s_follow->at(node_idx)) >= 0) { + s_follow->Set(node_idx, s_follow->at(s_follow->at(node_idx))); } } // Serialize liveness_set - request.live = liveness_node_set; + for (const auto& liveness_node_subset : liveness_node_set) { + AutoShardingSolverRequest_Nodes nodes; + nodes.mutable_nodes()->Add(liveness_node_subset.begin(), + liveness_node_subset.end()); + request.mutable_live()->Add(std::move(nodes)); + } PopulateTemporalValues(cost_graph, request); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto new file mode 100644 index 00000000000000..5683b4a02f957b --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto @@ -0,0 +1,61 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package xla; + +message AutoShardingSolverRequest { + message Pair { + int64 first = 1; + int64 second = 2; + } + message Costs { + repeated double costs = 1; + } + message Nodes { + repeated int64 nodes = 1; + } + message SolverTimeout { + int64 solver_timeout_in_seconds = 1; + } + message Coeff { + double coeff = 1; + } + + int64 num_nodes = 1; + int64 memory_budget = 2; + repeated int64 s_len = 3; + repeated int64 s_follow = 4; + repeated int64 s_hint = 5; + repeated Pair edges = 6; + repeated Nodes live = 7; + repeated Costs computation_costs = 8; + repeated Costs communication_costs = 9; + repeated Costs memory_costs = 10; + repeated Costs departure_costs = 11; + repeated Costs resharding_costs = 12; + repeated Costs duration_costs = 13; + repeated Pair aliases = 14; + repeated Costs value_costs = 15; + repeated string instruction_names = 16; + optional SolverTimeout solver_timeout = 17; + optional Coeff overbudget_coeff = 18; + optional Coeff makespan_coeff = 19; + optional Coeff max_departures = 20; + bool crash_at_infinity_costs_check = 21; + bool compute_iis = 22; + double saltiplier = 23; +} diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 0c71a705118e5a..6b1790f8b78070 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -20,12 +20,13 @@ limitations under the License. #include #include #include -#include #include #include #include #include +#include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" + #ifdef PLATFORM_GOOGLE #include "file/base/options.h" #endif @@ -65,18 +66,16 @@ bool AutoShardingSolverResult::operator==( void PrintLargestInstructions( const std::vector& chosen_strategy, - const std::vector>& memory_cost, - const std::vector>& liveness, - const std::vector& instruction_names) { + const AutoShardingSolverRequest& request) { // This memory consumption computation is different from // that in PrintAutoShardingSolution() because how L and m are created to be // different from liveness_set and strategy.memory_cost. 
std::vector> time_memory_usage; - for (LivenessIdx time_idx = 0; time_idx < liveness.size(); ++time_idx) { + for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { double mem = 0.0; - for (NodeIdx node_idx : liveness[time_idx]) { - mem += memory_cost[node_idx][chosen_strategy[node_idx]]; + for (NodeIdx node_idx : request.live(time_idx).nodes()) { + mem += request.memory_costs(node_idx).costs(chosen_strategy[node_idx]); } time_memory_usage.push_back(std::make_pair(time_idx, mem)); } @@ -96,9 +95,11 @@ void PrintLargestInstructions( k = std::min(k, time_memory_usage.size()); std::vector> instruction_mem; absl::flat_hash_set instruction_set; - for (LivenessIdx time_idx = 0; time_idx < k; time_idx++) { - for (NodeIdx node_idx : liveness[time_memory_usage.at(time_idx).first]) { - double mem = memory_cost[node_idx][chosen_strategy[node_idx]]; + for (auto usage_idx = 0; usage_idx < k; usage_idx++) { + LivenessIdx time_idx = time_memory_usage.at(usage_idx).first; + for (NodeIdx node_idx : request.live(time_idx).nodes()) { + double mem = + request.memory_costs(node_idx).costs(chosen_strategy[node_idx]); if (mem > 100 * 1024 * 1024 && instruction_set.find(node_idx) == instruction_set.end()) { instruction_mem.push_back(std::make_pair(node_idx, mem)); @@ -113,7 +114,7 @@ void PrintLargestInstructions( VLOG(1) << "Top " << top_tensors << " largest tensors:"; for (size_t i = 0; i < top_tensors; i++) { VLOG(1) << "instruction name: " - << instruction_names.at(instruction_mem.at(i).first) + << request.instruction_names(instruction_mem.at(i).first) << " memory usage: " << instruction_mem.at(i).second / (1024 * 1024 * 1024) << "GB"; } @@ -139,10 +140,10 @@ AutoShardingSolverResult SolveAndExtractSolution( double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { double minimum_memory_budget_required_estimate = 0.0; - for (LivenessIdx time_idx = 0; time_idx < request.live.size(); ++time_idx) { + for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { double minimum_memory_budget_required_estimate_local = 0.0; - for (NodeIdx node_idx : request.live[time_idx]) { - const std::vector& m = request.m[node_idx]; + for (NodeIdx node_idx : request.live(time_idx).nodes()) { + const auto& m = request.memory_costs(node_idx).costs(); const double fixed_memory_cost = *std::min_element(m.begin(), m.end()); minimum_memory_budget_required_estimate_local += fixed_memory_cost; } @@ -200,7 +201,7 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { AutoShardingSolverResult CallORToolsSolver( const AutoShardingSolverRequest& request) { - size_t num_edges = request.e.size(); + size_t num_edges = request.edges_size(); int32_t num_workers = 32; // SAT or SCIP @@ -221,64 +222,67 @@ AutoShardingSolverResult CallORToolsSolver( } #endif // Create variables - std::vector> s(request.num_nodes); + std::vector> s(request.num_nodes()); std::vector> e(num_edges); MPVariable* overbudget_var = nullptr; MPVariable* makespan_var = nullptr; size_t var_vector_cnt = 0; - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - if (request.s_follow[node_idx] < 0) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + if (request.s_follow(node_idx) < 0) { var_vector_cnt += 1; // Creates variables for instructions that do not follow others. 
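For orientation, the model assembled here is roughly: one boolean vector s[i] per node (one entry per sharding strategy, exactly one chosen), one boolean vector e[i,k] per edge over strategy pairs (followed nodes and edges reuse the variables of what they follow), an objective summing the chosen computation, communication and resharding costs plus optional overbudget and makespan terms, and per-liveness-step memory constraints, together with the usual constraints tying each edge variable to its endpoint node variables. Schematically:

    minimize    sum_i sum_j (c[i][j] + d[i][j]) * s[i][j]
              + sum_(i,k) sum_j r[i,k][j] * e[i,k][j]        (+ overbudget / makespan terms)
    subject to  sum_j s[i][j] = 1                            for every non-follower node i
                sum_{i live at t} m[i] . s[i] <= memory_budget (+ overbudget)   for every step t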
- solver->MakeBoolVarArray(request.s_len[node_idx], + solver->MakeBoolVarArray(request.s_len(node_idx), absl::StrCat("s[", node_idx, "]"), &s[node_idx]); } } - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - if (request.s_follow[node_idx] >= 0) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + if (request.s_follow(node_idx) >= 0) { // Copies the variable of followed instruction to the following // instruction. - s[node_idx] = s[request.s_follow[node_idx]]; + s[node_idx] = s[request.s_follow(node_idx)]; } } std::vector e_follow(num_edges, -1); absl::flat_hash_map, EdgeIdx> edge_map; for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - const std::pair& edge = request.e[edge_idx]; - std::pair followed_edge = edge; - if (int f = request.s_follow[edge.first]; f >= 0) followed_edge.first = f; - if (int f = request.s_follow[edge.second]; f >= 0) followed_edge.second = f; + const auto& raw_edge = request.edges(edge_idx); + const std::pair edge(raw_edge.first(), raw_edge.second()); + auto followed_edge = edge; + if (int f = request.s_follow(edge.first); f >= 0) followed_edge.first = f; + if (int f = request.s_follow(edge.second); f >= 0) followed_edge.second = f; if (const auto& it = edge_map.find(followed_edge); it != edge_map.end()) { e[edge_idx] = e[it->second]; // Copy variable of followed edge e_follow[edge_idx] = it->second; continue; } solver->MakeBoolVarArray( - request.s_len[edge.first] * request.s_len[edge.second], + request.s_len(edge.first) * request.s_len(edge.second), absl::StrCat("e[", edge.first, ",", edge.second, "]"), &e[edge_idx]); edge_map.insert({followed_edge, edge_idx}); } - if (request.memory_budget > 0 && request.overbudget_coeff) { + if (request.memory_budget() > 0 && request.has_overbudget_coeff()) { overbudget_var = solver->MakeNumVar(0.0, MPSolver::infinity(), "overbudget"); } - if (request.makespan_coeff) { + if (request.has_makespan_coeff()) { makespan_var = CreateMakespanVar(request, e, *solver); } // Objective // Node costs - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { double accumulated_coefficient = solver->Objective().GetCoefficient(s[node_idx][j]); - double coefficient = request.c[node_idx][j] + request.d[node_idx][j]; - AddSalt(absl::StrCat(node_idx, "S", j), request.saltiplier, &coefficient); + double coefficient = request.computation_costs(node_idx).costs(j) + + request.communication_costs(node_idx).costs(j); + AddSalt(absl::StrCat(node_idx, "S", j), request.saltiplier(), + &coefficient); solver->MutableObjective()->SetCoefficient( s[node_idx][j], accumulated_coefficient + coefficient); } @@ -288,8 +292,9 @@ AutoShardingSolverResult CallORToolsSolver( for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { double accumulated_coefficient = solver->Objective().GetCoefficient(e[edge_idx][j]); - double coefficient = request.r[edge_idx][j]; - AddSalt(absl::StrCat(edge_idx, "E", j), request.saltiplier, &coefficient); + double coefficient = request.resharding_costs(edge_idx).costs(j); + AddSalt(absl::StrCat(edge_idx, "E", j), request.saltiplier(), + &coefficient); solver->MutableObjective()->SetCoefficient( e[edge_idx][j], accumulated_coefficient + coefficient); } @@ -299,8 +304,8 @@ AutoShardingSolverResult CallORToolsSolver( // 0. 
Do not choose solutions with infinity costs, as it will make the // objective value so large that other solution choices do not matter anymore. // Remove these constraints once b/238210866 is done. - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - if (s[node_idx].empty() || request.s_follow[node_idx] >= 0) continue; + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + if (s[node_idx].empty() || request.s_follow(node_idx) >= 0) continue; bool all_infinity = true; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { if (solver->Objective().GetCoefficient(s[node_idx][j]) >= kInfinityCost) { @@ -331,10 +336,10 @@ AutoShardingSolverResult CallORToolsSolver( } } if (all_infinity) { - auto err_msg = - absl::StrCat("All of e[", request.e[edge_idx].first, "][", - request.e[edge_idx].second, "][*] have infinity costs"); - if (request.crash_at_infinity_costs_check) { + auto err_msg = absl::StrCat("All of e[", request.edges(edge_idx).first(), + "][", request.edges(edge_idx).second(), + "][*] have infinity costs"); + if (request.crash_at_infinity_costs_check()) { LOG(FATAL) << err_msg; } else { LOG(WARNING) << err_msg; @@ -345,8 +350,8 @@ AutoShardingSolverResult CallORToolsSolver( // a. specified via "BoolVarArray" // b. - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - if (request.s_follow[node_idx] >= 0) continue; + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + if (request.s_follow(node_idx) >= 0) continue; MPConstraint* constraint = solver->MakeRowConstraint( 1.0, 1.0, absl::StrCat("sum(s[", node_idx, "][j] for j = [0 .. ", @@ -356,48 +361,47 @@ AutoShardingSolverResult CallORToolsSolver( } } // c. - if (request.memory_budget > 0) { + if (request.memory_budget() > 0) { const double minimum_memory_budget_required_estimate = MinimumMemoryBudgetRequired(request); const double minimum_memory_overbudget = std::max( - 0.0, minimum_memory_budget_required_estimate - request.memory_budget); - for (LivenessIdx time_idx = 0; time_idx < request.live.size(); ++time_idx) { - const std::string str = - absl::StrCat("[", absl::StrJoin(request.live[time_idx], ", "), "]"); - double upper_bound = request.memory_budget; + 0.0, minimum_memory_budget_required_estimate - request.memory_budget()); + for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { + double upper_bound = request.memory_budget(); if (overbudget_var) upper_bound += minimum_memory_overbudget; - MPConstraint* constraint = solver->MakeRowConstraint( - -MPSolver::infinity(), upper_bound, - absl::StrCat("mem[", time_idx, "] = ", str)); + MPConstraint* constraint = + solver->MakeRowConstraint(-MPSolver::infinity(), upper_bound, + absl::StrCat("mem[", time_idx, "]")); if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0); - for (NodeIdx node_idx : request.live[time_idx]) { + for (NodeIdx node_idx : request.live(time_idx).nodes()) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { const double accumulated_coefficient = constraint->GetCoefficient(s[node_idx][j]); - constraint->SetCoefficient( - s[node_idx][j], accumulated_coefficient + request.m[node_idx][j]); + const double memory_cost = request.memory_costs(node_idx).costs(j); + constraint->SetCoefficient(s[node_idx][j], + accumulated_coefficient + memory_cost); } } } if (overbudget_var) { - solver->MutableObjective()->SetCoefficient(overbudget_var, - *request.overbudget_coeff); - solver->MutableObjective()->SetOffset(*request.overbudget_coeff * + 
solver->MutableObjective()->SetCoefficient( + overbudget_var, request.overbudget_coeff().coeff()); + solver->MutableObjective()->SetOffset(request.overbudget_coeff().coeff() * minimum_memory_overbudget); } LOG(INFO) << "Minimum memory budget estimate: " << minimum_memory_budget_required_estimate; - LOG(INFO) << "Using memory budget: " << request.memory_budget; + LOG(INFO) << "Using memory budget: " << request.memory_budget(); } // d. specified via "BoolVarArray" // e. for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { if (e_follow[edge_idx] >= 0) continue; - const std::pair& edge = request.e[edge_idx]; + const auto& edge = request.edges(edge_idx); MPConstraint* constraint = solver->MakeRowConstraint( 1.0, 1.0, - absl::StrCat("sum(e[", edge.first, "][", edge.second, "][*]) = 1")); + absl::StrCat("sum(e[", edge.first(), "][", edge.second(), "][*]) = 1")); for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { constraint->SetCoefficient(e[edge_idx][j], 1.0); } @@ -405,14 +409,14 @@ AutoShardingSolverResult CallORToolsSolver( // f. for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { if (e_follow[edge_idx] >= 0) continue; - const std::pair& edge = request.e[edge_idx]; - for (NodeStrategyIdx p = 0; p < s[edge.first].size(); ++p) { + const auto& edge = request.edges(edge_idx); + for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { MPConstraint* constraint = solver->MakeRowConstraint( -MPSolver::infinity(), 0, absl::StrCat("f for i = ", edge_idx, ", p = ", p)); - constraint->SetCoefficient(s[edge.first][p], -1.0); - for (NodeStrategyIdx q = 0; q < s[edge.second].size(); ++q) { - constraint->SetCoefficient(e[edge_idx][p * s[edge.second].size() + q], + constraint->SetCoefficient(s[edge.first()][p], -1.0); + for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { + constraint->SetCoefficient(e[edge_idx][p * s[edge.second()].size() + q], 1.0); } } @@ -420,55 +424,57 @@ AutoShardingSolverResult CallORToolsSolver( // g. for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { if (e_follow[edge_idx] >= 0) continue; - const std::pair& edge = request.e[edge_idx]; - for (NodeStrategyIdx q = 0; q < s[edge.second].size(); ++q) { + const auto& edge = request.edges(edge_idx); + for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { MPConstraint* constraint = solver->MakeRowConstraint( -MPSolver::infinity(), 0, absl::StrCat("g for i = ", edge_idx, ", q = ", q)); - constraint->SetCoefficient(s[edge.second][q], -1.0); - for (NodeStrategyIdx p = 0; p < s[edge.first].size(); ++p) { - constraint->SetCoefficient(e[edge_idx][p * s[edge.second].size() + q], + constraint->SetCoefficient(s[edge.second()][q], -1.0); + for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { + constraint->SetCoefficient(e[edge_idx][p * s[edge.second()].size() + q], 1.0); } } } // h. 
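  //    For each aliased pair (i, j), any strategy combination (p, q) whose
  //    value_costs entry is 1 (checked as > 0.5 below) is ruled out by adding
  //    the constraint s[i][p] + s[j][q] <= 1.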
- for (AliasIdx alias_idx = 0; alias_idx < request.a.size(); ++alias_idx) { - const std::pair& alias = request.a[alias_idx]; - for (NodeStrategyIdx p = 0; p < s[alias.first].size(); ++p) { - for (NodeStrategyIdx q = 0; q < s[alias.second].size(); ++q) { + for (auto alias_idx = 0; alias_idx < request.aliases_size(); ++alias_idx) { + const auto& alias = request.aliases(alias_idx); + const auto& value_costs = request.value_costs(alias_idx).costs(); + for (NodeStrategyIdx p = 0; p < s[alias.first()].size(); ++p) { + for (NodeStrategyIdx q = 0; q < s[alias.second()].size(); ++q) { // if lhs == 1 - if (request.v[alias_idx][p * s[alias.second].size() + q] > 0.5) { + if (value_costs[p * s[alias.second()].size() + q] > 0.5) { MPConstraint* constraint = solver->MakeRowConstraint( -MPSolver::infinity(), 1, - absl::StrCat("s[", alias.first, "][", p, "] + s[", alias.second, - "][", q, "] <= 1")); - constraint->SetCoefficient(s[alias.first][p], 1.0); - constraint->SetCoefficient(s[alias.second][q], 1.0); + absl::StrCat("s[", alias.first(), "][", p, "] + s[", + alias.second(), "][", q, "] <= 1")); + constraint->SetCoefficient(s[alias.first()][p], 1.0); + constraint->SetCoefficient(s[alias.second()][q], 1.0); } } } } - if (request.max_departures) { + if (request.has_max_departures()) { MPConstraint* constraint = solver->MakeRowConstraint( - 0, *request.max_departures, - absl::StrCat("departures <= ", *request.max_departures)); - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { + 0, request.max_departures().coeff(), + absl::StrCat("departures <= ", request.max_departures().coeff())); + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { double accumulated_coefficient = constraint->GetCoefficient(s[node_idx][j]); - constraint->SetCoefficient( - s[node_idx][j], accumulated_coefficient + request.p[node_idx][j]); + double departure_cost = request.departure_costs(node_idx).costs(j); + constraint->SetCoefficient(s[node_idx][j], + accumulated_coefficient + departure_cost); } } } - if (!request.s_hint.empty()) { + if (!request.s_hint().empty()) { std::vector> hint; - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - if (request.s_follow[node_idx] >= 0) continue; + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + if (request.s_follow(node_idx) >= 0) continue; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - double hint_val = (request.s_hint[node_idx] == j) ? 1.0 : 0.0; + double hint_val = (request.s_hint(node_idx) == j) ? 
1.0 : 0.0; hint.push_back({s[node_idx][j], hint_val}); } } @@ -490,8 +496,9 @@ AutoShardingSolverResult CallORToolsSolver( } } #endif - if (request.solver_timeout_in_seconds) { - solver->SetTimeLimit(absl::Seconds(*request.solver_timeout_in_seconds)); + if (request.has_solver_timeout()) { + solver->SetTimeLimit( + absl::Seconds(request.solver_timeout().solver_timeout_in_seconds())); } VLOG(0) << "Starting solver " << solver->ProblemType() << "\n" << "Solver parameter string: " << solver_parameter_str << "\n" @@ -500,8 +507,8 @@ AutoShardingSolverResult CallORToolsSolver( << "Time limit: " << solver->time_limit() << "\n" << "Number variables for ILP: " << solver->NumVariables() << "\n" << "Total vector of variables: " << var_vector_cnt << "\n" - << "Total instructions: " << request.num_nodes << "\n" - << "Memory budget: " << request.memory_budget / (1024 * 1024 * 1024) + << "Total instructions: " << request.num_nodes() << "\n" + << "Memory budget: " << request.memory_budget() / (1024 * 1024 * 1024) << "GB\n" << "Number of ILP constraints: " << solver->NumConstraints(); return SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, @@ -518,7 +525,7 @@ AutoShardingSolverResult SolveAndExtractSolution( if (status == operations_research::MPSolver::INFEASIBLE) { LOG(ERROR) << "MPSolver could not find any feasible solution."; #ifdef PLATFORM_GOOGLE - if (request.compute_iis) { + if (request.compute_iis()) { operations_research::MPModelRequest model_request; solver.ExportModelToProto(model_request.mutable_model()); if (solver.ProblemType() == @@ -586,16 +593,17 @@ AutoShardingSolverResult SolveAndExtractSolution( } // Return value - size_t num_edges = request.e.size(); + size_t num_edges = request.edges_size(); double unsalted_objective = 0.0; - std::vector chosen_strategy(request.num_nodes, -1); + std::vector chosen_strategy(request.num_nodes(), -1); std::vector e_val(num_edges, -1); - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { // if lhs == 1 if (s[node_idx][j]->solution_value() > 0.5) { chosen_strategy[node_idx] = j; - unsalted_objective += request.c[node_idx][j] + request.d[node_idx][j]; + unsalted_objective += request.computation_costs(node_idx).costs(j) + + request.communication_costs(node_idx).costs(j); break; } } @@ -605,31 +613,30 @@ AutoShardingSolverResult SolveAndExtractSolution( // if lhs == 1 if (e[edge_idx][j]->solution_value() > 0.5) { e_val[edge_idx] = j; - unsalted_objective += request.r[edge_idx][j]; + unsalted_objective += request.resharding_costs(edge_idx).costs(j); break; } } } if (overbudget_var) { unsalted_objective += - *request.overbudget_coeff * overbudget_var->solution_value(); + request.overbudget_coeff().coeff() * overbudget_var->solution_value(); unsalted_objective += solver.Objective().offset(); } if (makespan_var) { unsalted_objective += - *request.makespan_coeff * makespan_var->solution_value(); + request.makespan_coeff().coeff() * makespan_var->solution_value(); } LOG(INFO) << "Unsalted objective value: " << unsalted_objective; - LOG(INFO) << "N = " << request.num_nodes; - if (request.memory_budget < 0) { + LOG(INFO) << "N = " << request.num_nodes(); + if (request.memory_budget() < 0) { LOG(INFO) << "memory budget: -1"; } else { LOG(INFO) << "memory budget: " - << request.memory_budget / (1024 * 1024 * 1024) << " GB"; + << request.memory_budget() / (1024 * 1024 * 1024) << " GB"; } - 
PrintLargestInstructions(chosen_strategy, request.m, request.live, - request.instruction_names); + PrintLargestInstructions(chosen_strategy, request); return AutoShardingSolverResult( std::make_tuple(std::move(chosen_strategy), std::move(e_val), unsalted_objective), @@ -658,85 +665,90 @@ bool AutoShardingEvaluation::operator==( AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const AutoShardingSolverResult& result) { - const std::vector>& c = request.c; - const std::vector>& d = request.d; - const std::vector>& r = request.r; + const auto& c = request.computation_costs(); + const auto& d = request.communication_costs(); + const auto& r = request.resharding_costs(); + const auto& v = request.value_costs(); + const auto& p = request.departure_costs(); const std::vector& s_val = std::get<0>(*result.status); const std::vector& e_val = std::get<1>(*result.status); AutoShardingEvaluation evaluation; // Compute violations. - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - NodeIdx s_follow = request.s_follow[node_idx]; + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + NodeIdx s_follow = request.s_follow(node_idx); if (s_follow >= 0 && s_val[node_idx] != s_val[s_follow]) { evaluation.violation_codes.insert(kFollowerViolationCode); } } - for (AliasIdx alias_idx = 0; alias_idx < request.a.size(); ++alias_idx) { - const std::pair& alias = request.a[alias_idx]; - NodeStrategyIdx p = s_val[alias.first], q = s_val[alias.second]; - if (request.v[alias_idx][p * request.s_len[alias.second] + q] > 0.5) { + for (auto alias_idx = 0; alias_idx < request.aliases_size(); ++alias_idx) { + const auto& alias = request.aliases(alias_idx); + NodeStrategyIdx p = s_val[alias.first()], q = s_val[alias.second()]; + if (v.at(alias_idx).costs(p * request.s_len(alias.second()) + q) > 0.5) { evaluation.violation_codes.insert(kAliasViolationCode); } } - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { NodeStrategyIdx strat_idx = s_val[node_idx]; - if (c[node_idx][strat_idx] + d[node_idx][strat_idx] >= kInfinityCost) { + const double node_cost = + c.at(node_idx).costs(strat_idx) + d.at(node_idx).costs(strat_idx); + if (node_cost >= kInfinityCost) { evaluation.violation_codes.insert(kInfiniteCostViolationCode); } } - for (EdgeIdx edge_idx = 0; edge_idx < request.e.size(); ++edge_idx) { - if (request.r[edge_idx][e_val[edge_idx]] >= kInfinityCost) { + for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { + if (r.at(edge_idx).costs(e_val[edge_idx]) >= kInfinityCost) { evaluation.violation_codes.insert(kInfiniteCostViolationCode); } } - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - evaluation.total_departures += request.p[node_idx][s_val[node_idx]]; - if (request.max_departures && - evaluation.total_departures > *request.max_departures) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + evaluation.total_departures += p.at(node_idx).costs(s_val[node_idx]); + if (request.has_max_departures() && + evaluation.total_departures > request.max_departures().coeff()) { evaluation.violation_codes.insert(kMaxDeparturesViolationCode); } } - if (request.memory_budget > 0) { + if (request.memory_budget() > 0) { double total_overbudget = 0.0; double lower_bound_overbudget = 0.0; - for (LivenessIdx time_idx = 0; time_idx < request.live.size(); ++time_idx) { + for (LivenessIdx time_idx = 0; time_idx < 
request.live_size(); ++time_idx) { double total_memory_cost = 0.0; double lower_bound_memory_cost = 0.0; - for (NodeIdx node_idx : request.live[time_idx]) { - const std::vector& m = request.m[node_idx]; + for (NodeIdx node_idx : request.live(time_idx).nodes()) { + const auto& m = request.memory_costs(node_idx).costs(); total_memory_cost += m[s_val[node_idx]]; lower_bound_memory_cost += *std::min_element(m.begin(), m.end()); } - if (request.overbudget_coeff) { - total_overbudget = std::max(total_overbudget, - total_memory_cost - request.memory_budget); + if (request.has_overbudget_coeff()) { + total_overbudget = std::max( + total_overbudget, total_memory_cost - request.memory_budget()); lower_bound_overbudget = std::max(lower_bound_overbudget, - lower_bound_memory_cost - request.memory_budget); - } else if (total_memory_cost > request.memory_budget) { + lower_bound_memory_cost - request.memory_budget()); + } else if (total_memory_cost > request.memory_budget()) { evaluation.violation_codes.insert(kMemoryViolationCode); } } - if (request.overbudget_coeff) { + if (request.has_overbudget_coeff()) { evaluation.total.overbudget_cost = - *request.overbudget_coeff * total_overbudget; + request.overbudget_coeff().coeff() * total_overbudget; evaluation.lower_bound.overbudget_cost = - *request.overbudget_coeff * lower_bound_overbudget; + request.overbudget_coeff().coeff() * lower_bound_overbudget; } } // Compute metrics & lower bounds. - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - evaluation.total.communication_cost += d[node_idx][s_val[node_idx]]; - evaluation.total.computation_cost += c[node_idx][s_val[node_idx]]; - evaluation.lower_bound.communication_cost += - *std::min_element(d[node_idx].begin(), d[node_idx].end()); - evaluation.lower_bound.computation_cost += - *std::min_element(c[node_idx].begin(), c[node_idx].end()); - } - for (EdgeIdx edge_idx = 0; edge_idx < request.e.size(); ++edge_idx) { - evaluation.total.resharding_cost += r[edge_idx][e_val[edge_idx]]; - evaluation.lower_bound.resharding_cost += - *std::min_element(r[edge_idx].begin(), r[edge_idx].end()); + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + evaluation.total.communication_cost += + d.at(node_idx).costs(s_val[node_idx]); + evaluation.total.computation_cost += c.at(node_idx).costs(s_val[node_idx]); + evaluation.lower_bound.communication_cost += *std::min_element( + d.at(node_idx).costs().begin(), d.at(node_idx).costs().end()); + evaluation.lower_bound.computation_cost += *std::min_element( + c.at(node_idx).costs().begin(), c.at(node_idx).costs().end()); + } + for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { + evaluation.total.resharding_cost += r.at(edge_idx).costs(e_val[edge_idx]); + evaluation.lower_bound.resharding_cost += *std::min_element( + r.at(edge_idx).costs().begin(), r.at(edge_idx).costs().end()); } evaluation.total_makespan = EvaluateMakespan(request, result, evaluation); return evaluation; @@ -746,23 +758,25 @@ std::vector Rationalize(const AutoShardingSolverRequest& request, const AutoShardingSolverResult& result, const AutoShardingSolverResult& subopt) { std::vector rationales; - const std::vector& names = request.instruction_names; + const auto& names = request.instruction_names(); const std::vector& s_result = std::get<0>(*result.status); const std::vector& s_subopt = std::get<0>(*subopt.status); - for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); 
++node_idx) { const NodeStrategyIdx j = s_result[node_idx], k = s_subopt[node_idx]; if (j != k) { rationales.push_back(absl::StrCat( "strategy changes for ", names[node_idx], " (", j, " -> ", k, ")")); } - const double dj = request.d[node_idx][j], dk = request.d[node_idx][k]; + const double dj = request.communication_costs(node_idx).costs(j); + const double dk = request.communication_costs(node_idx).costs(k); if (dj < dk) { rationales.push_back(absl::StrCat("communication cost increases for ", names[node_idx], " (", dj, " -> ", dk, ")")); } - const double cj = request.c[node_idx][j], ck = request.c[node_idx][k]; + const double cj = request.computation_costs(node_idx).costs(j); + const double ck = request.computation_costs(node_idx).costs(k); if (cj < ck) { rationales.push_back(absl::StrCat("computation cost increases for ", names[node_idx], " (", cj, " -> ", ck, @@ -772,13 +786,14 @@ std::vector Rationalize(const AutoShardingSolverRequest& request, const std::vector& e_result = std::get<1>(*result.status); const std::vector& e_subopt = std::get<1>(*subopt.status); - for (EdgeIdx edge_idx = 0; edge_idx < request.e.size(); ++edge_idx) { - const std::pair& edge = request.e[edge_idx]; + for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { + const auto& edge = request.edges(edge_idx); const EdgeStrategyIdx j = e_result[edge_idx], k = e_subopt[edge_idx]; - const double rj = request.r[edge_idx][j], rk = request.r[edge_idx][k]; + const double rj = request.resharding_costs(edge_idx).costs(j); + const double rk = request.resharding_costs(edge_idx).costs(k); if (rj < rk) { const std::string edge_name = - absl::StrCat(names[edge.first], " and ", names[edge.second]); + absl::StrCat(names[edge.first()], " and ", names[edge.second()]); rationales.push_back(absl::StrCat("resharding cost increases for ", edge_name, " (", rj, " -> ", rk, ")")); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 28477c7b5e7971..9bfa64a1149901 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_set.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/statusor.h" #include "ortools/linear_solver/linear_solver.h" @@ -34,32 +35,6 @@ using MPVariable = operations_research::MPVariable; namespace xla { namespace spmd { -struct AutoShardingSolverRequest { - int64_t num_nodes = 0; - int64_t memory_budget = -1; - std::vector s_len; - std::vector s_follow; - std::vector s_hint; - std::vector> e; - std::vector> live; - std::vector> c; - std::vector> d; - std::vector> m; - std::vector> p; - std::vector> r; - std::vector> t; - std::vector> a; - std::vector> v; - std::vector instruction_names; - std::optional solver_timeout_in_seconds; - std::optional overbudget_coeff = 1e6; - std::optional makespan_coeff; - std::optional max_departures; - bool crash_at_infinity_costs_check = false; - bool compute_iis = true; - double saltiplier = 0.001; // Modifies each objective term by at most 0.1% -}; - struct AutoShardingSolverResult { public: AutoShardingSolverResult( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc index 0b1ddcabcc7acf..f155be5c12d8bd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include "ortools/linear_solver/linear_solver.h" diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 41e1c450e93bdc..1c54cd7aae11e8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -12,6 +12,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" +#include #include #include #include @@ -19,67 +20,112 @@ limitations under the License. 
#include #include +#include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" namespace xla { namespace spmd { namespace { +using CostMatrix = std::vector>; +using NodeMatrix = std::vector>; + +void AddCosts(proto2::RepeatedPtrField* costs, + const CostMatrix& cost_matrix) { + for (const auto& cost_row : cost_matrix) { + AutoShardingSolverRequest_Costs cost; + cost.mutable_costs()->Add(cost_row.begin(), cost_row.end()); + costs->Add(std::move(cost)); + } +} + +void AddNodes(proto2::RepeatedPtrField* nodes, + const NodeMatrix& node_matrix) { + for (const auto& node_row : node_matrix) { + AutoShardingSolverRequest_Nodes node; + node.mutable_nodes()->Add(node_row.begin(), node_row.end()); + nodes->Add(std::move(node)); + } +} + // clang-format off AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { - AutoShardingSolverRequest request; // The problem below is partially inspired by 'DotLHSTwoNonContractingDims' - request.num_nodes = 5; - request.memory_budget = 1500000; - request.s_len = {4, 3, 4, 4, 3}; - request.s_follow = {-1, -1, -1, 2, -1}; - request.e = {{0, 2}, {1, 2}}; - request.live = {{1, 0}, - {1, 0}, - {1, 2, 0}, - {1, 2, 3, 0}, - {1, 3, 0}}; - request.c = {{10, 11, 12, 13}, - {20, 21, 22}, - {30, 31, 32, 33}, - {40, 41, 42, 43}, - {50, 51, 52, 53}}; - request.d = {{100, 110, 120, 130}, - {200, 210, 220}, - {300, 310, 320, 330}, - {400, 410, 420, 430}, - {500, 510, 520}}; - request.m = {{100000, 110000, 990000, 130000}, - {200000, 210000, 220000}, - {300000, 310000, 320000, 330000}, - {400000, 410000, 420000, 430000}, - {500000, 510000, 520000}}; - request.p = {{1.0, 0.0, 1.0, 1.0}, - {1.0, 0.0, 1.0}, - {1.0, 0.0, 1.0, 1.0}, - {1.0, 0.0, 1.0, 1.0}, - {1.0, 0.0, 1.0}}; - request.r = {{1000, 1100, 1200, 1300, - 2000, 2100, 2200, 2300, - 3000, 3100, 3200, 3300, - 4000, 4100, 4200, 4300}, - {5000, 5100, 5200, 5300, - 6000, 6100, 6200, 6300, - 7000, 7100, 7200, 7300}}; - request.t = {{73000, 72000, 71000, 70000, - 63000, 62000, 61000, 60000, - 53000, 52000, 51000, 50000, - 43000, 42000, 41000, 40000}, - {33000, 32000, 31000, 30000, - 23000, 22000, 21000, 20000, - 13000, 12000, 11000, 10000}}; - request.a = {{1, 4}}; - request.v = {{0, 1, 1, - 1, 0, 1, - 1, 1, 0}}; - request.instruction_names = {"A", "B", "C", "D", "E"}; - request.overbudget_coeff = std::nullopt; + const auto s_len = {4, 3, 4, 4, 3}; + const auto s_follow = {-1, -1, -1, 2, -1}; + AutoShardingSolverRequest_Pair edge1, edge2; + edge1.set_first(0); + edge1.set_second(2); + edge2.set_first(1); + edge2.set_second(2); + const auto edges = {edge1, edge2}; + const NodeMatrix live = {{1, 0}, + {1, 0}, + {1, 2, 0}, + {1, 2, 3, 0}, + {1, 3, 0}}; + const CostMatrix c = {{10, 11, 12, 13}, + {20, 21, 22}, + {30, 31, 32, 33}, + {40, 41, 42, 43}, + {50, 51, 52, 53}}; + const CostMatrix d = {{100, 110, 120, 130}, + {200, 210, 220}, + {300, 310, 320, 330}, + {400, 410, 420, 430}, + {500, 510, 520}}; + const CostMatrix m = {{100000, 110000, 990000, 130000}, + {200000, 210000, 220000}, + {300000, 310000, 320000, 330000}, + {400000, 410000, 420000, 430000}, + {500000, 510000, 520000}}; + const CostMatrix p = {{1.0, 0.0, 1.0, 1.0}, + {1.0, 0.0, 1.0}, + {1.0, 0.0, 1.0, 1.0}, + {1.0, 0.0, 1.0, 1.0}, + {1.0, 0.0, 1.0}}; + const CostMatrix r = {{1000, 1100, 1200, 1300, + 2000, 2100, 2200, 2300, + 3000, 3100, 3200, 3300, + 4000, 4100, 4200, 4300}, + {5000, 5100, 5200, 5300, + 6000, 6100, 6200, 6300, + 7000, 7100, 7200, 7300}}; + const CostMatrix t = {{73000, 
72000, 71000, 70000, + 63000, 62000, 61000, 60000, + 53000, 52000, 51000, 50000, + 43000, 42000, 41000, 40000}, + {33000, 32000, 31000, 30000, + 23000, 22000, 21000, 20000, + 13000, 12000, 11000, 10000}}; + AutoShardingSolverRequest_Pair alias; + alias.set_first(1); + alias.set_second(4); + const auto aliases = {alias}; + const CostMatrix v = {{0, 1, 1, + 1, 0, 1, + 1, 1, 0}}; + const std::vector instruction_names = {"A", "B", "C", "D", "E"}; + + AutoShardingSolverRequest request; + request.set_num_nodes(5); + request.set_memory_budget(1500000); + request.mutable_s_len()->Add(s_len.begin(), s_len.end()); + request.mutable_s_follow()->Add(s_follow.begin(), s_follow.end()); + request.mutable_edges()->Add(edges.begin(), edges.end()); + AddNodes(request.mutable_live(), live); + AddCosts(request.mutable_computation_costs(), c); + AddCosts(request.mutable_communication_costs(), d); + AddCosts(request.mutable_memory_costs(), m); + AddCosts(request.mutable_departure_costs(), p); + AddCosts(request.mutable_resharding_costs(), r); + AddCosts(request.mutable_duration_costs(), t); + request.mutable_aliases()->Add(aliases.begin(), aliases.end()); + AddCosts(request.mutable_value_costs(), v); + request.mutable_instruction_names()->Add(instruction_names.begin(), + instruction_names.end()); + AddCosts(request.mutable_computation_costs(), c); return request; } @@ -99,8 +145,8 @@ TEST(CallORToolsSolverTest, SolvesOptimally) { TEST(CallORToolsSolverTest, SolvesOverbudget) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.memory_budget = 100000; - request.overbudget_coeff = 10.0; + request.set_memory_budget(100000); + request.mutable_overbudget_coeff()->set_coeff(10.0); const AutoShardingSolverResult result = CallORToolsSolver(request); @@ -115,7 +161,7 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) { TEST(CallORToolsSolverTest, SolvesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.max_departures = 3.0; + request.mutable_max_departures()->set_coeff(3.0); const AutoShardingSolverResult result = CallORToolsSolver(request); @@ -130,7 +176,9 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) { TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.c[0][0] = request.c[0][1] = request.c[0][2] = kInfinityCost; + request.mutable_computation_costs(0)->set_costs(0, kInfinityCost); + request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); + request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); const AutoShardingSolverResult result = CallORToolsSolver(request); @@ -145,7 +193,7 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.r[0][0] = kInfinityCost; + request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost); const AutoShardingSolverResult result = CallORToolsSolver(request); @@ -160,10 +208,15 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { TEST(CallORToolsSolverTest, HandlesFollowedEdges) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.e.push_back({1, 3}); // Reduces to {1, 2} since node 3 follows node 2 - request.r.push_back({5000, 5100, 5200, 5300, - 6000, 6100, 6200, 6300, - 7000, 7100, 7200, 7300}); + AutoShardingSolverRequest_Pair edge; + edge.set_first(1); + edge.set_second(3); + // Reduces to {1, 2} since node 3 follows node 
2 + *request.mutable_edges()->Add() = edge; + const CostMatrix r = {{5000, 5100, 5200, 5300, + 6000, 6100, 6200, 6300, + 7000, 7100, 7200, 7300}}; + AddCosts(request.mutable_resharding_costs(), r); const AutoShardingSolverResult result = CallORToolsSolver(request); @@ -178,7 +231,8 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { TEST(CallORToolsSolverTest, UsesHint) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close. + const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close. + request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); const AutoShardingSolverResult result = CallORToolsSolver(request); @@ -215,8 +269,8 @@ TEST(AutoShardingEvaluatorTest, NoViolations) { TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.memory_budget = 100000; - request.overbudget_coeff = 10.0; + request.set_memory_budget(100000); + request.mutable_overbudget_coeff()->set_coeff(10.0); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const std::vector e_val = {10, 6}; const double objective_value = 11138.0; @@ -310,7 +364,9 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) { TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.c[0][0] = request.c[0][1] = request.c[0][2] = kInfinityCost; + request.mutable_computation_costs(0)->set_costs(0, kInfinityCost); + request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); + request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); const std::vector s_val = {0 /* violates */, 1, 2, 2, 1}; const std::vector e_val = {2, 6}; const double objective_value = 1e+20; @@ -334,7 +390,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.r[0][2] = kInfinityCost; + request.mutable_resharding_costs(0)->set_costs(2, kInfinityCost); const std::vector s_val = {0, 1, 2, 2, 1}; const std::vector e_val = {2 /* violates */, 6}; const double objective_value = 1e+20; @@ -358,7 +414,7 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - request.max_departures = 2.0; + request.mutable_max_departures()->set_coeff(2.0); const std::vector s_val = {3, 1, 2, 2, 1}; const std::vector e_val = {14, 6}; const double objective_value = 12149.0; From 0d829d7ea0e710716e7d3c41d6faa3616bd5dba6 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 28 Nov 2023 19:54:40 -0800 Subject: [PATCH 165/381] Integrate StableHLO at openxla/stablehlo@83f095e7 PiperOrigin-RevId: 586173556 --- third_party/stablehlo/workspace.bzl | 4 ++-- third_party/xla/third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 80ab0e479b0ca8..a60b5db8e74b5d 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "e3276cd896751bfbebd7b18112f81547fbc2bc9c" - STABLEHLO_SHA256 = "948d0265a9ea4214ecfa854564793b08161d693e9cfbd73686f2df9e38034ada" + 
STABLEHLO_COMMIT = "83f095e7217c897f1eccac5652600ceb944cb0e0" + STABLEHLO_SHA256 = "00e442f7e9c8a52a1ac774ce997f8b5a99d12450c4dfe1594df816dcbad5126f" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 80ab0e479b0ca8..a60b5db8e74b5d 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "e3276cd896751bfbebd7b18112f81547fbc2bc9c" - STABLEHLO_SHA256 = "948d0265a9ea4214ecfa854564793b08161d693e9cfbd73686f2df9e38034ada" + STABLEHLO_COMMIT = "83f095e7217c897f1eccac5652600ceb944cb0e0" + STABLEHLO_SHA256 = "00e442f7e9c8a52a1ac774ce997f8b5a99d12450c4dfe1594df816dcbad5126f" # LINT.ThenChange(Google-internal path) tf_http_archive( From f0240adedd95a24380545fbf44d5ed5cb93c5b49 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 28 Nov 2023 19:56:07 -0800 Subject: [PATCH 166/381] Re-factor macOS CI build environment setup Our Mac builds require some specific build environment setup such as installing Bazelisk, upgrading Pyenv, installing Python, etc. Since these scripts are meant to be run by both internal CI builds and external users, we re-work some conditional logic that were previously only meant to run for internal CI builds. These will now instead use the `TFCI_*_ENABLE` variables. This makes the conditionals from being possibly confusing system checks in scripts to explicit settings in "envs" files and allows both internal CI builds and external users to decide if they want to enable or disable a particular macOS build environment setup task. 
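As a minimal sketch of the pattern (taking the pyenv step from this change;
the other setup steps are gated the same way), an "envs" file opts in
explicitly:

  TFCI_PYTHON_VERSION=3.10
  TFCI_MACOS_PYENV_INSTALL_ENABLE=1

and setup_macos.sh only performs the corresponding setup when the flag is set:

  if [[ "${TFCI_MACOS_PYENV_INSTALL_ENABLE}" == 1 ]]; then
    pyenv install "$TFCI_PYTHON_VERSION"
    pyenv local "$TFCI_PYTHON_VERSION"
    python --version  # sanity check that the expected interpreter is active
  fi
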
PiperOrigin-RevId: 586173730 --- ci/official/envs/ci_default | 7 ++- ci/official/envs/continuous_macos_arm64_py310 | 1 + ci/official/envs/continuous_macos_arm64_py311 | 1 + ci/official/envs/continuous_macos_arm64_py39 | 1 + ci/official/envs/nightly_macos_arm64_py310 | 3 +- ci/official/envs/nightly_macos_arm64_py311 | 2 +- ci/official/envs/nightly_macos_arm64_py312 | 3 +- ci/official/envs/nightly_macos_arm64_py39 | 3 +- ci/official/utilities/setup_macos.sh | 61 +++++++++++++------ 9 files changed, 57 insertions(+), 25 deletions(-) diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index 890069df93c982..a8e212ff47f005 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -30,4 +30,9 @@ TFCI_WHL_AUDIT_PLAT= TFCI_WHL_BAZEL_TEST_ENABLE=1 TFCI_WHL_SIZE_LIMIT= TFCI_WHL_SIZE_LIMIT_ENABLE=1 -TFCI_PYENV_INSTALL_LOCAL_ENABLE= +TFCI_MACOS_UPGRADE_PYENV_ENABLE= +TFCI_MACOS_INSTALL_BAZELISK_ENABLE= +TFCI_MACOS_INSTALL_BAZELISK_URL= +TFCI_MACOS_PYENV_INSTALL_ENABLE= +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE= +TFCI_MACOS_BAZEL_TEST_DIR_PATH= diff --git a/ci/official/envs/continuous_macos_arm64_py310 b/ci/official/envs/continuous_macos_arm64_py310 index 42b304729bc113..f967c768fdd473 100644 --- a/ci/official/envs/continuous_macos_arm64_py310 +++ b/ci/official/envs/continuous_macos_arm64_py310 @@ -2,3 +2,4 @@ TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_ca TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_DOCKER_ENABLE=0 TFCI_PYTHON_VERSION=3.10 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/continuous_macos_arm64_py311 b/ci/official/envs/continuous_macos_arm64_py311 index 8a47d7ef15c863..c6d06085fe3f44 100644 --- a/ci/official/envs/continuous_macos_arm64_py311 +++ b/ci/official/envs/continuous_macos_arm64_py311 @@ -2,3 +2,4 @@ TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_ca TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_DOCKER_ENABLE=0 TFCI_PYTHON_VERSION=3.11 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/continuous_macos_arm64_py39 b/ci/official/envs/continuous_macos_arm64_py39 index da892742753048..18275771093672 100644 --- a/ci/official/envs/continuous_macos_arm64_py39 +++ b/ci/official/envs/continuous_macos_arm64_py39 @@ -2,3 +2,4 @@ TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_ca TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_DOCKER_ENABLE=0 TFCI_PYTHON_VERSION=3.9 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py310 b/ci/official/envs/nightly_macos_arm64_py310 index 2c902aef38764a..daa47b4db03d56 100644 --- a/ci/official/envs/nightly_macos_arm64_py310 +++ b/ci/official/envs/nightly_macos_arm64_py310 @@ -7,4 +7,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M -TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py311 b/ci/official/envs/nightly_macos_arm64_py311 index 61eb97792ce620..5379f792fbcaf9 100644 --- a/ci/official/envs/nightly_macos_arm64_py311 +++ b/ci/official/envs/nightly_macos_arm64_py311 @@ -7,4 +7,4 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.11 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M 
-TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py312 b/ci/official/envs/nightly_macos_arm64_py312 index 41371ad4befbef..d843f25b2a83e4 100644 --- a/ci/official/envs/nightly_macos_arm64_py312 +++ b/ci/official/envs/nightly_macos_arm64_py312 @@ -7,4 +7,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.12 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M -TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py39 b/ci/official/envs/nightly_macos_arm64_py39 index 8f3c172065b974..1e259de5ca7fd3 100644 --- a/ci/official/envs/nightly_macos_arm64_py39 +++ b/ci/official/envs/nightly_macos_arm64_py39 @@ -9,4 +9,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.9 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M -TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index 61378bc4e75204..e91100b5e3d2ee 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -34,29 +34,50 @@ else exit 1 fi -if [[ -n "${KOKORO_JOB_NAME}" ]]; then - # Mac builds need ~150 GB of disk space to be able to run all the tests. By - # default, Kokoro runs the Bazel commands in a partition that does not have - # enough free space so we need to set TEST_TMPDIR explicitly. - mkdir -p /Volumes/BuildData/bazel_output - export TEST_TMPDIR=/Volumes/BuildData/bazel_output +# "TFCI_MACOS_BAZEL_TEST_DIR_PATH" specifies the directory that Bazel should use +# when running tests. Each test will be executed in a separate subdirectory +# inside this directory. TF Mac builds need ~150 GB of disk space to be able to +# run all the tests. Since TFCI Mac VMs execute Bazel test commands in a +# partition with insufficient storage, we specify the +# 'TFCI_MACOS_BAZEL_TEST_DIR_PATH' environment variable to point to a partition +# with ample storage. When this variable is empty (i.e by default), Bazel will +# use the output base directory to run tests. +if [[ "${TFCI_MACOS_BAZEL_TEST_DIR_ENABLE}"==1 ]]; then + mkdir -p "${TFCI_MACOS_BAZEL_TEST_DIR_PATH}" + export TEST_TMPDIR="${TFCI_MACOS_BAZEL_TEST_DIR_PATH}" +fi + +# "TFCI_MACOS_INSTALL_BAZELISK_ENABLE" is used to decide if we need to install +# Bazelisk manually. We enable this for macOS x86 builds as those VMs do not +# have Bazelisk pre-installed. "TFCI_MACOS_INSTALL_BAZELISK_URL" contains the +# link to the Bazelisk binary which needs to be downloaded. +if [[ "${TFCI_MACOS_INSTALL_BAZELISK_ENABLE}" == 1 ]]; then + sudo wget --no-verbose -O "/usr/local/bin/bazel" "${TFCI_MACOS_INSTALL_BAZELISK_URL}" + chmod +x "/usr/local/bin/bazel" +fi + +# "TFCI_MACOS_UPGRADE_PYENV_ENABLE" is used to decide if we need to upgrade the +# Pyenv version. We enable this for macOS x86 builds as the default Pyenv on +# those VMs does not support installing Python 3.10 and above which we need +# for running smoke tests in nightly/release wheel builds. +if [[ "${TFCI_MACOS_UPGRADE_PYENV_ENABLE}" == 1 ]]; then + brew upgrade pyenv +fi - # Before uploading the nightly and release wheels, we install them in a - # virtual environment and run some smoke tests on it. 
The Kokoro Mac VMs - # only have Python 3.11 installed so we need to install the other Python - # versions manually. - if [[ -n "${TFCI_BUILD_PIP_PACKAGE_ARGS}" ]] && [[ "${TFCI_PYENV_INSTALL_LOCAL_ENABLE}" != 3.11 ]]; then - pyenv install "${TFCI_PYENV_INSTALL_LOCAL_ENABLE}" - pyenv local "${TFCI_PYENV_INSTALL_LOCAL_ENABLE}" - fi -elif [[ "${TFCI_WHL_BAZEL_TEST_ENABLE}" == 1 ]]; then - echo '==TFCI==: Note: Mac builds need ~150 GB of disk space to be able to' - echo 'run all the tests. Please make sure your system has enough disk space' - echo 'You can control where Bazel stores test artifacts by setting the' - echo '`TEST_TMPDIR` environment variable.' +# "TFCI_MACOS_PYENV_INSTALL_ENABLE" controls whether to use Pyenv to install +# the Python version set in "TFCI_PYTHON_VERSION" and use it as default. +# We enable this in the nightly and release builds because before uploading the +# wheels, we install them in a virtual environment and run some smoke tests on +# it. TFCI Mac VMs only have one Python version installed so we need to install +# the other versions manually. +if [[ "${TFCI_MACOS_PYENV_INSTALL_ENABLE}" == 1 ]]; then + pyenv install "$TFCI_PYTHON_VERSION" + pyenv local "$TFCI_PYTHON_VERSION" + # Do a sanity check to make sure that we using the correct Python version + python --version fi -if [[ "${TFCI_PYTHON_VERSION}" == "3.12" ]]; then +if [[ "$TFCI_PYTHON_VERSION" == "3.12" ]]; then # dm-tree (Keras v3 dependency) doesn't have pre-built wheels for 3.12 yet. # Having CMake allows building them. # Once the wheels are added, this should be removed - b/308399490. From e1dbfeba8acb1df8f42dfa6f76262f5cb23e1fa1 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 28 Nov 2023 20:20:03 -0800 Subject: [PATCH 167/381] [stream_executor] NFC: Guard new features with CUDA_VERSION check PiperOrigin-RevId: 586180704 --- third_party/xla/xla/stream_executor/cuda/cuda_driver.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 469043ba02b80b..a85188c37627c2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -902,12 +902,19 @@ static CUmemLocationType ToCudaLocationType( return CU_MEM_LOCATION_TYPE_INVALID; case GpuDriver::MemLocationType::kDevice: return CU_MEM_LOCATION_TYPE_DEVICE; +#if CUDA_VERSION >= 12000 case GpuDriver::MemLocationType::kHost: return CU_MEM_LOCATION_TYPE_HOST; case GpuDriver::MemLocationType::kHostNuma: return CU_MEM_LOCATION_TYPE_HOST_NUMA; case GpuDriver::MemLocationType::kHostNumaCurrent: return CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT; +#else + case GpuDriver::MemLocationType::kHost: + case GpuDriver::MemLocationType::kHostNuma: + case GpuDriver::MemLocationType::kHostNumaCurrent: + return CU_MEM_LOCATION_TYPE_INVALID; +#endif // CUDA_VERSION >= 12000 } } @@ -942,7 +949,9 @@ static CUmemAllocationType ToCudaAllocationType( mem_pool_props.allocType = ToCudaAllocationType(allocation_type); mem_pool_props.handleTypes = CU_MEM_HANDLE_TYPE_NONE; mem_pool_props.location = mem_location; +#if CUDA_VERSION >= 12000 mem_pool_props.maxSize = max_pool_size; +#endif // CUDA_VERSION >= 12000 params.accessDescCount = 1; params.bytesize = size; From f9a4e14fc4fbc253309729f8cc0159de23a72e10 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Tue, 28 Nov 2023 20:47:42 -0800 Subject: [PATCH 168/381] [XLA] Add support for while loop simplifier to remove duplicated 
"dynamic-update-slice" stored values. PiperOrigin-RevId: 586186698 --- .../xla/xla/service/while_loop_simplifier.cc | 81 ++++++++++++++++--- .../xla/service/while_loop_simplifier_test.cc | 59 ++++++++++++++ 2 files changed, 127 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/service/while_loop_simplifier.cc b/third_party/xla/xla/service/while_loop_simplifier.cc index d4f3e919f80ae1..bf0e5118c3d55f 100644 --- a/third_party/xla/xla/service/while_loop_simplifier.cc +++ b/third_party/xla/xla/service/while_loop_simplifier.cc @@ -125,10 +125,11 @@ void CopyFrontendAttributes(HloInstruction* old_while_op, // This is a utility function that removes the given tuple indices from the // while loop init, body, and condition. The final shape returned is still the -// same as before. +// same as before. If set index_for_replaced will replace any use of the removed +// indices in the final shape with a copy of the removed index. static StatusOr RemoveDeadTupleIndices( - HloInstruction* while_op, - absl::flat_hash_set& used_tuple_indices) { + HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices, + int64_t index_for_replaced = -1) { // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), used_tuple_indices.end()); @@ -274,8 +275,11 @@ static StatusOr RemoveDeadTupleIndices( const int64_t tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); for (int64_t old_idx = 0; old_idx < tuple_size; ++old_idx) { auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); - if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { - int64_t gte_idx = new_tuple_idx_it->second; + if (new_tuple_idx_it != old_to_new_tuple_idx.end() || + index_for_replaced != -1) { + int64_t gte_idx = new_tuple_idx_it != old_to_new_tuple_idx.end() + ? new_tuple_idx_it->second + : index_for_replaced; new_tuple_elems.push_back( computation->AddInstruction(HloInstruction::CreateGetTupleElement( new_while_op->shape().tuple_shapes(gte_idx), new_while_op, @@ -546,7 +550,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // duplicates by replacing them with tuple_index, followed by a call to // RemoveDeadTupleIndices. static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( - HloInstruction* while_op, const int64_t tuple_index, + HloInstruction* while_op, const int64_t tuple_index, bool replace_with_init, absl::flat_hash_set& duplicates) { HloComputation* while_cond = while_op->while_condition(); HloComputation* while_body = while_op->while_body(); @@ -586,14 +590,23 @@ static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( used_tuple_indices.insert(index); } } - // Remove the duplicate tuple elements. - TF_ASSIGN_OR_RETURN(while_op, - RemoveDeadTupleIndices(while_op, used_tuple_indices)); + TF_ASSIGN_OR_RETURN( + while_op, RemoveDeadTupleIndices(while_op, used_tuple_indices, + replace_with_init ? -1 : tuple_index)); return while_op; } +// Returns if this instruction looks like an insertion inside a variable of a +// while loop. +static bool IsDynamicUpdateSliceWhileInsertion( + const HloInstruction* instr, const HloComputation* while_body) { + return instr->opcode() == HloOpcode::kDynamicUpdateSlice && + instr->operand(0)->opcode() == HloOpcode::kGetTupleElement && + instr->operand(0)->operand(0) == while_body->parameter_instruction(0); +} + // If the while loop init passes the same values to several tuple indices, and // if the body keeps on passing them through, we can remove the duplicates. 
static StatusOr TryRemoveRepeatedWhileTupleIndices( @@ -638,6 +651,7 @@ static StatusOr TryRemoveRepeatedWhileTupleIndices( absl::flat_hash_set duplicates; auto* pivot_init_elem = while_init->operand(index_to_investigate); auto* pivot_body_elem = while_body_root->operand(index_to_investigate); + bool replace_with_init = true; if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement && pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) { if (pivot_body_elem->tuple_index() != index_to_investigate) { @@ -647,6 +661,16 @@ static StatusOr TryRemoveRepeatedWhileTupleIndices( index_to_investigate++; continue; } + } else if (IsDynamicUpdateSliceWhileInsertion(pivot_body_elem, + while_body)) { + if (pivot_body_elem->operand(0)->tuple_index() != index_to_investigate) { + VLOG(2) + << "Mismatch between pivot_body_elem->operand(0)->tuple_index() " + << pivot_body_elem->operand(0)->tuple_index() + << " index_to_investigate " << index_to_investigate; + index_to_investigate++; + continue; + } } else { index_to_investigate++; continue; @@ -657,13 +681,44 @@ static StatusOr TryRemoveRepeatedWhileTupleIndices( i < while_shape.tuple_shapes_size(); ++i) { auto* init_elem = while_init->operand(i); auto* body_elem = while_body_root->operand(i); - if (body_elem->opcode() == HloOpcode::kGetTupleElement && + if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement && + body_elem->opcode() == HloOpcode::kGetTupleElement && body_elem->operand(0) == while_body->parameter_instruction(0)) { if (body_elem->tuple_index() != i) { VLOG(2) << "Mismatch between body_elem->tuple_index() " << body_elem->tuple_index() << " i " << i; continue; } + } else if (IsDynamicUpdateSliceWhileInsertion(pivot_body_elem, + while_body) && + IsDynamicUpdateSliceWhileInsertion(body_elem, while_body)) { + if (pivot_body_elem->operand_count() != body_elem->operand_count()) { + VLOG(2) << "Mismatch in operand count of dynamic-update-slice " + << pivot_body_elem->operand_count() << " vs " + << body_elem->operand_count(); + continue; + } + if (body_elem->operand(0)->tuple_index() != i) { + VLOG(2) << "Mismatch between body_elem->operand(0)->tuple_index() " + << body_elem->tuple_index() << " i " << i; + continue; + } + if (pivot_body_elem->operand(0) == body_elem->operand(0)) { + VLOG(2) << "Inserting in the same input index"; + continue; + } + bool mismatch = false; + for (int64_t i = 1; i < body_elem->operand_count(); ++i) { + if (body_elem->operand(i) != pivot_body_elem->operand(i)) { + VLOG(2) << "Mismatch in insertion indices or values"; + mismatch = true; + break; + } + } + if (mismatch) { + continue; + } + replace_with_init = false; } else { continue; } @@ -681,9 +736,9 @@ static StatusOr TryRemoveRepeatedWhileTupleIndices( if (!duplicates.empty()) { VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init " << pivot_init_elem->ToString(); - TF_ASSIGN_OR_RETURN(while_op, - TryRemoveRepeatedWhileTupleIndicesHelper( - while_op, index_to_investigate, duplicates)); + TF_ASSIGN_OR_RETURN(while_op, TryRemoveRepeatedWhileTupleIndicesHelper( + while_op, index_to_investigate, + replace_with_init, duplicates)); changed = true; VLOG(2) << "Changed while_op " << while_op->ToString() << " while_op operand count " << while_op->operand_count(); diff --git a/third_party/xla/xla/service/while_loop_simplifier_test.cc b/third_party/xla/xla/service/while_loop_simplifier_test.cc index b3c157aaaee6bf..94d87e754815b4 100644 --- a/third_party/xla/xla/service/while_loop_simplifier_test.cc +++ 
b/third_party/xla/xla/service/while_loop_simplifier_test.cc @@ -1124,5 +1124,64 @@ TEST_F(WhileLoopSimplifierTest, NotRemoveCompare) { .value()); } +TEST_F(WhileLoopSimplifierTest, RemoveDynUpdSlice) { + const std::string hlo_string = R"( +HloModule jit_scan + +%region_0.6 (arg_tuple.7: (s32[], f32[], f32[3], f32[3])) -> (s32[], f32[], f32[3], f32[3]) { + %arg_tuple.7 = (s32[], f32[], f32[3]{0}, f32[3]{0}) parameter(0) + %get-tuple-element.8 = s32[] get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %arg_tuple.7), index=0 + %constant.12 = s32[] constant(1) + %add.28 = s32[] add(s32[] %get-tuple-element.8, s32[] %constant.12) + %get-tuple-element.9 = f32[] get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %arg_tuple.7), index=1 + %sine.15 = f32[] sine(f32[] %get-tuple-element.9) + %get-tuple-element.10 = f32[3]{0} get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %arg_tuple.7), index=2 + %cosine.16 = f32[] cosine(f32[] %get-tuple-element.9) + %reshape.18 = f32[1]{0} reshape(f32[] %cosine.16) + %constant.14 = s32[] constant(0) + %compare.19 = pred[] compare(s32[] %get-tuple-element.8, s32[] %constant.14), direction=LT + %constant.13 = s32[] constant(3) + %add.20 = s32[] add(s32[] %get-tuple-element.8, s32[] %constant.13) + %select.21 = s32[] select(pred[] %compare.19, s32[] %add.20, s32[] %get-tuple-element.8) + %dynamic-update-slice.22 = f32[3]{0} dynamic-update-slice(f32[3]{0} %get-tuple-element.10, f32[1]{0} %reshape.18, s32[] %select.21) + %get-tuple-element.11 = f32[3]{0} get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %arg_tuple.7), index=3 + %dynamic-update-slice.27 = f32[3]{0} dynamic-update-slice(f32[3]{0} %get-tuple-element.11, f32[1]{0} %reshape.18, s32[] %select.21) + ROOT %tuple.29 = (s32[], f32[], f32[3]{0}, f32[3]{0}) tuple(s32[] %add.28, f32[] %sine.15, f32[3]{0} %dynamic-update-slice.22, f32[3]{0} %dynamic-update-slice.27) +} + +%region_1.30 (arg_tuple.31: (s32[], f32[], f32[3], f32[3])) -> pred[] { + %arg_tuple.31 = (s32[], f32[], f32[3]{0}, f32[3]{0}) parameter(0) + %get-tuple-element.32 = s32[] get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %arg_tuple.31), index=0 + %constant.36 = s32[] constant(3) + ROOT %compare.37 = pred[] compare(s32[] %get-tuple-element.32, s32[] %constant.36), direction=LT +} + +ENTRY %main.44 (Arg_0.1: f32[]) -> (f32[], f32[3], f32[3]) { + %constant.4 = s32[] constant(0) + %Arg_0.1 = f32[] parameter(0), sharding={replicated} + %constant.2 = f32[] constant(0) + %broadcast.3 = f32[3]{0} broadcast(f32[] %constant.2), dimensions={} + %tuple.5 = (s32[], f32[], f32[3]{0}, f32[3]{0}) tuple(s32[] %constant.4, f32[] %Arg_0.1, f32[3]{0} %broadcast.3, f32[3]{0} %broadcast.3) + %while.38 = (s32[], f32[], f32[3]{0}, f32[3]{0}) while((s32[], f32[], f32[3]{0}, f32[3]{0}) %tuple.5), condition=%region_1.30, body=%region_0.6 + %get-tuple-element.40 = f32[] get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %while.38), index=1 + %get-tuple-element.41 = f32[3]{0} get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %while.38), index=2 + %get-tuple-element.42 = f32[3]{0} get-tuple-element((s32[], f32[], f32[3]{0}, f32[3]{0}) %while.38), index=3 + ROOT %tuple.43 = (f32[], f32[3]{0}, f32[3]{0}) tuple(f32[] %get-tuple-element.40, f32[3]{0} %get-tuple-element.41, f32[3]{0} %get-tuple-element.42) +})"; + auto m = ParseAndReturnVerifiedModule(hlo_string).value(); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).value()); + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = ParseShape("(s32[], f32[], 
f32[3]{0})").value(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); +} + } // namespace } // namespace xla From 5b98c79b60d2b9f7359038d1b913720c189ff002 Mon Sep 17 00:00:00 2001 From: Jackson Stokes Date: Tue, 28 Nov 2023 21:18:43 -0800 Subject: [PATCH 169/381] [XLA:GPU] Add a pass merging producer fusions in to Triton Softmax fusions. This merges producer fusions into triton softmax fusions provided that the resulting fusion would be able to be tiled. The producer fusions can have multiple inputs, but for now must have a single output (the triton softmax fusion). The merger does not yet operate on consumer fusions or with any basic cost modeling, both of which will be added in a followup CL. PiperOrigin-RevId: 586193270 --- third_party/xla/xla/service/gpu/BUILD | 48 +++ .../xla/service/gpu/fusion_merger_triton.cc | 131 +++++++ .../xla/service/gpu/fusion_merger_triton.h | 55 +++ .../service/gpu/fusion_merger_triton_test.cc | 323 ++++++++++++++++++ .../xla/xla/service/gpu/gpu_compiler.cc | 5 + .../ir_emitter_triton_parametrized_test.cc | 84 +++++ 6 files changed, 646 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/fusion_merger_triton.cc create mode 100644 third_party/xla/xla/service/gpu/fusion_merger_triton.h create mode 100644 third_party/xla/xla/service/gpu/fusion_merger_triton_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 83ec740c3b5861..151fef12956b2a 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1474,6 +1474,53 @@ xla_cc_test( ], ) +cc_library( + name = "fusion_merger_triton", + srcs = ["fusion_merger_triton.cc"], + hdrs = ["fusion_merger_triton.h"], + visibility = ["//visibility:public"], + deps = [ + ":backend_configs_cc", + ":ir_emission_utils", + ":triton_fusion_analysis", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "fusion_merger_triton_test", + srcs = ["fusion_merger_triton_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-sm70", + ]}, + backends = [ + "gpu", + ], + deps = [ + ":fusion_merger_triton", + "//xla:autotune_results_proto_cc", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + ], +) + cc_library( name = "softmax_rewriter_triton", srcs = ["softmax_rewriter_triton.cc"], @@ -3103,6 +3150,7 @@ cc_library( "@local_tsl//tsl/platform:numbers", ]) + xla_export_hlo_deps() + [ ":command_buffer_scheduling", + ":fusion_merger_triton", ":fusion_pipeline", ":ir_emitter_context", ":ir_emitter_unnested", diff --git a/third_party/xla/xla/service/gpu/fusion_merger_triton.cc 
b/third_party/xla/xla/service/gpu/fusion_merger_triton.cc new file mode 100644 index 00000000000000..ad9d76d5cd6c87 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_merger_triton.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_merger_triton.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +// Taking in a producer HloFusionInstruction, tries to merge into consumer +// triton softmax fusion. +// The following is assumed: +// * The producer is an HloFusionInstruction +// * The consumer is a triton softmax fusion +// +// Returns true if the producer is merged into the consumer and replaced +// in the original computation. Returns false otherwise. 
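The helper below is speculative: it clones the softmax consumer, merges the producer into the clone, and only commits the clone if TritonFusionAnalysis::Execute accepts the merged computation. A minimal call-site sketch for the pass as a whole, assuming only the HloModulePass interface declared in fusion_merger_triton.h (the call site itself is hypothetical, not taken from this patch):

// `module` is an HloModule* owned by the caller.
FusionMergerTriton merger;
TF_ASSIGN_OR_RETURN(bool changed,
                    merger.Run(module, /*execution_threads=*/{}));
// `changed` is true iff at least one producer fusion was merged into a
// __triton_softmax fusion.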
+StatusOr TryMergeFusionProducerIntoTritonSoftmaxConsumer( + HloFusionInstruction* producer) { + HloComputation* computation = producer->parent(); + HloInstruction* original_softmax_instruction = producer->users().front(); + CHECK_EQ(original_softmax_instruction->opcode(), HloOpcode::kFusion); + + std::unique_ptr candidate = + original_softmax_instruction->Clone(); + HloInstruction* candidate_fusion = + static_cast(candidate.get()); + + // Try to merge the producer into candidate fusion + candidate_fusion->MergeFusionInstruction(producer); + + HloComputation* fused_computation = + candidate_fusion->called_computations().front(); + + TF_ASSIGN_OR_RETURN(const auto analysis, + TritonFusionAnalysis::Execute(*fused_computation)); + + computation->AddInstruction(std::move(candidate)); + + if (original_softmax_instruction->IsRoot()) { + computation->set_root_instruction(candidate_fusion); + } + + TF_CHECK_OK( + original_softmax_instruction->ReplaceAllUsesWith(candidate_fusion)); + TF_RETURN_IF_ERROR( + computation->RemoveInstruction(original_softmax_instruction)); + + CHECK_EQ(0, producer->user_count()) << producer->ToString(); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + return true; +} + +} // anonymous namespace + +StatusOr FusionMergerTriton::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + int fused_comps = 0; + for (HloComputation* comp : + module->MakeNonfusionComputations(execution_threads)) { + if (comp->IsCustomCallComputation()) { + continue; + } + + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kCustom && + instr->backend_config().ok() && + instr->backend_config()->kind() == + kTritonSoftmaxFusionKind) { + // TODO(b/313026024): Add support for multiple users + if (instr->operand(0)->opcode() != HloOpcode::kFusion || + instr->operand(0)->user_count() != 1) { + continue; + } + + HloFusionInstruction* producer = + Cast(instr->mutable_operand(0)); + + VLOG(6) << "Matched triton_softmax kernel, Fusing producer " + << producer->ToShortString() << " into " + << instr->ToShortString(); + + absl::StatusOr result = + TryMergeFusionProducerIntoTritonSoftmaxConsumer(producer); + + if (!result.ok()) { + VLOG(6) << "Did not fuse producer into " << instr->ToShortString(); + } + + if (result.ok() && *result) ++fused_comps; + } + } + } + return fused_comps > 0; +} +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusion_merger_triton.h b/third_party/xla/xla/service/gpu/fusion_merger_triton.h new file mode 100644 index 00000000000000..56fb5e4667bbb7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_merger_triton.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSION_MERGER_TRITON_H_ +#define XLA_SERVICE_GPU_FUSION_MERGER_TRITON_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" + +namespace xla { +namespace gpu { + +// An HLO pass that attempts to merge producer fusions into triton softmax +// fusions. +// +// Producer kernels are only merged if the resulting fusion can be correctly +// tiled. If the result can be tiled, all operations from the auxiliary +// producer fusion will be merged into the triton softmax computation, and this +// computation will replace both the auxiliary and original triton softmax +// fusion. +// +// Auxiliary fusions are not merged into consumer triton fusions if: +// * The auxiliary fusion has multiple users +// * The resulting merged fusion is not tilable +class FusionMergerTriton : public HloModulePass { + public: + explicit FusionMergerTriton() = default; + absl::string_view name() const override { return "fusion-merger-triton"; } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSION_MERGER_TRITON_H_ diff --git a/third_party/xla/xla/service/gpu/fusion_merger_triton_test.cc b/third_party/xla/xla/service/gpu/fusion_merger_triton_test.cc new file mode 100644 index 00000000000000..86de08d942864f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_merger_triton_test.cc @@ -0,0 +1,323 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/fusion_merger_triton.h" + +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "xla/autotune_results.pb.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status_matchers.h" + +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +namespace xla { +namespace gpu { +namespace { + +namespace m = ::xla::match; +using FusionMergerTritonTest = HloTestBase; + +TEST_F(FusionMergerTritonTest, + CanMergeTritonFusionWithSingleParameterProducer) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +auxiliary_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +})"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_THAT(fusion_merger.Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(verifier().Run(module.get()), IsOk()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter()))); +} + +TEST_F( + FusionMergerTritonTest, + CanMergeProducerFusionIntoTritonSoftmaxConsumerWhenTheConsumerIsNotRoot) { // NOLINT(whitespace/line_length) + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +auxiliary_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation + triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT broadcast = f32[10,125,127]{2,1,0} broadcast(triton_softmax), dimensions={1,2} +})"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_THAT(fusion_merger.Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(verifier().Run(module.get()), IsOk()); + 
VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Fusion(m::Parameter())))); +} + +TEST_F(FusionMergerTritonTest, + CanMergeTritonFusionWithMultipleParameterProducer) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +auxiliary_computation { + parameter_0 = f32[125]{0} parameter(0) + parameter_1 = f32[125,127]{1,0} parameter(1) + broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(parameter_1, broadcast) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + param_1 = f32[125,127]{1,0} parameter(1) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kLoop, calls=auxiliary_computation + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +})"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_TRUE(fusion_merger.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(FusionMergerTritonTest, CanMergeTritonFusionWithTransposeProducer) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +auxiliary_computation { + parameter_0 = f32[125]{0} parameter(0) + parameter_1 = f32[127,125]{1,0} parameter(1) + transpose = f32[125,127]{1,0} transpose(parameter_1), dimensions={1,0} + broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(transpose, broadcast) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + param_1 = f32[127,125]{1,0} parameter(1) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kLoop, calls=auxiliary_computation + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +})"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_TRUE(fusion_merger.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(FusionMergerTritonTest, + 
DoesNotMergeTritonFusionWithProducerContainingUntileableOp) { + // Right now, concatenate is not tileable. + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +auxiliary_computation { + parameter_0 = f32[125,63]{1,0} parameter(0) + parameter_1 = f32[125,64]{1,0} parameter(1) + ROOT concatenate = f32[125,127]{1,0} concatenate(parameter_0, parameter_1), dimensions={1} +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125,63]{1,0} parameter(0) + param_1 = f32[125,64]{1,0} parameter(1) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kLoop, calls=auxiliary_computation + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +})"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_FALSE(fusion_merger.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Fusion(m::Parameter(), m::Parameter())))); +} + +TEST_F(FusionMergerTritonTest, CanMergeTritonFusionWithElementwiseProducer) { + const std::string kHloText = R"( +HloModule layernorm + +add_f32 { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add_6 = f32[] add(Arg_0, Arg_1) +} + +auxiliary_fusion { + parameter_0 = f32[125,127]{1,0} parameter(0) + parameter_1 = f32[125,127]{1,0} parameter(1) + ROOT multiply_1 = f32[125,127]{1,0} multiply(parameter_0, parameter_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + constant_0 = f32[] constant(0) + reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=add_f32 + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply_result = f32[125,127]{1,0} multiply(parameter_0, broadcast) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[125,127]{1,0} parameter(1) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=auxiliary_fusion + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} + +)"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_TRUE(fusion_merger.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(FusionMergerTritonTest, + DoesNotMergeSoftmaxWithParamBroadcastedAlongBatchAndReduceDimensions) { + const std::string kHloText = R"( +HloModule t + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +auxiliary_computation { + param_0 = f32[10,125,127]{2,1,0} parameter(0) + param_1 = f32[10]{0} parameter(1) + 
broadcast_0 = f32[10,125,127]{2,1,0} broadcast(param_1), dimensions={0} + ROOT multiply_0 = f32[10,125,127]{2,1,0} multiply(param_0, broadcast_0) +} + +triton_softmax_computation { + param_0 = f32[10,125,127]{2,1,0} parameter(0) + multiply = f32[10,125,127]{2,1,0} multiply(param_0, param_0) + constant = f32[] constant(0) + reduce = f32[10,125]{1,0} reduce(multiply, constant), dimensions={2}, to_apply=add + broadcast = f32[10,125,127]{2,1,0} broadcast(reduce), dimensions={0,1} + ROOT multiply_out = f32[10,125,127]{2,1,0} multiply(param_0, broadcast) +} + +ENTRY main { + param_0 = f32[10,125,127]{2,1,0} parameter(0) + param_1 = f32[10]{0} parameter(1) + auxiliary_fusion = f32[10,125,127]{2,1,0} fusion(param_0, param_1), kind=kCustom, calls=auxiliary_computation + ROOT triton_softmax = f32[10,125,127]{2,1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} +)"; + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + FusionMergerTriton fusion_merger; + EXPECT_FALSE(fusion_merger.Run(module.get()).value()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Fusion()))); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index ceb72769ee7b79..4a5a8f39929025 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -127,6 +127,7 @@ limitations under the License. #include "xla/service/gpu/copy_fusion.h" #include "xla/service/gpu/custom_fusion_rewriter.h" #include "xla/service/gpu/dot_dimension_sorter.h" +#include "xla/service/gpu/fusion_merger_triton.h" #include "xla/service/gpu/fusion_pipeline.h" #include "xla/service/gpu/fusion_wrapper.h" #include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" @@ -948,6 +949,10 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, .Run(hlo_module) .status()); + if (debug_options.xla_gpu_enable_triton_softmax_fusion()) { + TF_RETURN_IF_ERROR(FusionMergerTriton().Run(hlo_module).status()); + } + if (debug_options.xla_gpu_collect_cost_model_stats()) { GpuHloCostAnalysis::Options cost_analysis_options{ ShapeSizeBytesFunction(), diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index 848c774aa94382..12b3f5205f9e59 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -2363,6 +2363,90 @@ ENTRY main { INSTANTIATE_TEST_SUITE_P(TritonSoftmaxTestSuite, TritonSoftmaxTest, ::testing::Values(F32, F16, BF16)); +TEST_F(TritonSoftmaxTest, CanFuseAndEmitTritonSoftmaxWithTwoParameters) { + const std::string hlo_text = R"( +HloModule layernorm + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + broadcast_0 = f32[125,127]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[125,127]{1,0} multiply(param_0, broadcast_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) 
+} +)"; + + // Param order is arbitrary. We test that only param_1 is in the fused root + // instruction below. + const std::string hlo_ref = R"( +; CHECK: ENTRY +; CHECK-DAG: %[[param_0:.*]] = f32[125,127]{1,0} parameter(0) +; CHECK-DAG: %[[param_1:.*]] = f32[127]{0} parameter(1) +; CHECK: ROOT +; CHECK-SAME: f32[125,127]{1,0} fusion +; CHECK-SAME: %[[param_1]] +; CHECK-SAME: kind=kCustom +; CHECK-SAME: triton_softmax +)"; + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance = 2e-6; + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + +TEST_F(TritonSoftmaxTest, CanFuseAndEmitTritonSoftmaxWithNonBatchReduce) { + const std::string hlo_text = R"( +HloModule layernorm + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[10,125,127]{2,1,0} parameter(1) + constant = f32[] constant(0) + reduce_0 = f32[125,127]{1,0} reduce(param_1, constant), dimensions={0}, to_apply=add + multiply_0 = f32[125,127]{1,0} multiply(param_0, reduce_0) + constant_0 = f32[] constant(0) + reduce_1 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_1), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} +)"; + + // We expect to not fuse everything into the triton softmax, because of the + // reduce over the non-row dimension. + const std::string hlo_ref = R"( +; CHECK: ENTRY +; CHECK-DAG: %[[P0:.*]] = f32[125,127]{1,0} parameter(0) +; CHECK-DAG: %[[P1:.*]] = f32[10,125,127]{2,1,0} parameter(1) +; CHECK: %[[FUSION:.*]] = f32[125,127]{1,0} fusion(%[[P0]], %[[P1]]) +; CHECK: kind=kLoop +; CHECK: ROOT +; CHECK-SAME: f32[125,127]{1,0} fusion(%[[FUSION]]) +; CHECK-SAME: kind=kCustom +; CHECK-SAME: triton_softmax +)"; + MatchOptimizedHlo(hlo_text, hlo_ref); + + float tolerance = 2e-6; + EXPECT_TRUE(RunAndCompare(hlo_text, + ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); +} + } // namespace } // namespace gpu } // namespace xla From 941c1a127c2f2158e50f3be47a05e2024c85eb23 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 29 Nov 2023 00:55:52 -0800 Subject: [PATCH 170/381] [XLA] Fix the MHLO TopK `assemblyFormat` to correctly support the optional `largest` attribute. PiperOrigin-RevId: 586242470 --- third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td | 2 +- .../xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index dea1b6a1ee74da..e566a7463a64b4 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2998,7 +2998,7 @@ def MHLO_TopKOp : MHLO_Op<"topk", [RecursiveMemoryEffects, InferTensorType]> { MHLO_Tensor:$indices); let assemblyFormat = [{ - `(`$operand `,` `k` `=` $k `,` `largest` `=` $largest `)` attr-dict `:` + `(`$operand `,` `k` `=` $k (`,` `largest` `=` $largest^)? 
`)` attr-dict `:` type($operand) `->` `(`type($values)`,` type($indices)`)` }]; let hasCustomHLOConverter = 1; diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index f3393b478d1dae..52e18f22335c85 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6552,6 +6552,20 @@ func.func @top_k_unranked(%arg0 : tensor<*xf32>) { // ----- +func.func @top_k_1d_false(%arg0 : tensor<16xf32>) { + %0:2 = mhlo.topk(%arg0, k=8, largest=false) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) + return +} + +// ----- + +func.func @top_k_1d_default(%arg0 : tensor<16xf32>) { + %0:2 = mhlo.topk(%arg0, k=8) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) + return +} + +// ----- + func.func @topk_rank_at_least_one(%arg0 : tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{operand's rank must be at least 1}} From 8a283df5824322fa33e17b87a0133a5899145b62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 01:01:59 -0800 Subject: [PATCH 171/381] Update GraphDef version to 1695. PiperOrigin-RevId: 586243999 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d766021c9e5a68..c7699beb5be775 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1694 // Updated: 2023/11/28 +#define TF_GRAPH_DEF_VERSION 1695 // Updated: 2023/11/29 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From dedcf06cfb646762c51bce72971bac95149cc45a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 01:01:59 -0800 Subject: [PATCH 172/381] compat: Update forward compatibility horizon to 2023-11-29 PiperOrigin-RevId: 586244002 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 221e88fa1b8b0a..288f3ffe741104 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 28) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 29) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 7a7a4c10bc6b86126733e85157b8068a6b6fe0da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Wed, 29 Nov 2023 03:11:16 -0800 Subject: [PATCH 173/381] [XLA:GPU][NFC] Beautify code related to Triton fusions This is a little beautification for the earlier "Split GemmRewriterTriton into 4 parts" effort. 
PiperOrigin-RevId: 586278961 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 38 +++++++++-------- .../xla/service/gpu/triton_fusion_analysis.cc | 41 ++++++++++++------- .../xla/xla/service/gpu/triton_support.h | 1 + .../service/gpu/triton_tiling_propagation.cc | 2 + .../service/gpu/triton_tiling_propagation.h | 1 + 5 files changed, 51 insertions(+), 32 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index fa20343a435827..f618411e7f49de 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -60,6 +60,12 @@ namespace gpu { namespace { +using triton_fusion::DimOrdersAndReqs; +using triton_fusion::DimOrdersAndReqsOrError; +using triton_fusion::FusionContext; +using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible; +using triton_fusion::TransformDirection; + using OldToNewHloMap = absl::flat_hash_map; @@ -170,14 +176,13 @@ void TryToFuseWithInputsRecursively(HloInstruction& root, continue; } num_requeued = 0; - const triton_fusion::DimOrdersAndReqsOrError result = + const DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *hlo, triton_fusion::TransformDirection::kOutputToInput, + *hlo, TransformDirection::kOutputToInput, /*src_operand_index=*/std::nullopt, context.dim_orders().at(hlo), gpu_version, context.hero_properties()); - if (!std::holds_alternative(result) || - !context.CombineDimOrdersAndReqs( - std::get(result))) { + if (!std::holds_alternative(result) || + !context.CombineDimOrdersAndReqs(std::get(result))) { continue; } if (hlo->opcode() != HloOpcode::kParameter) { @@ -236,12 +241,12 @@ StatusOr FuseDot(HloInstruction& dot, // differently shaped tiles but may go through same HLO graph nodes. // Direct dot inputs have well defined dimension orders. - auto fuse_inputs = [&](int operand_number, OldToNewHloMap& old_to_new_map) - -> StatusOr { + auto fuse_inputs = + [&](int operand_number, + OldToNewHloMap& old_to_new_map) -> StatusOr { const int operand_count_before = fusion_inputs.size(); // Direct dot inputs have well defined dimension orders. - auto context = - triton_fusion::FusionContext::FromDotOperand(dot, operand_number); + auto context = FusionContext::FromDotOperand(dot, operand_number); TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number), gpu_version, context, old_to_new_map, fusion_inputs, builder); @@ -255,7 +260,7 @@ StatusOr FuseDot(HloInstruction& dot, // Original instruction -> fused one. Separate for each scope. OldToNewHloMap lhs_old_to_new_map; - TF_ASSIGN_OR_RETURN(const triton_fusion::FusionContext lhs_context, + TF_ASSIGN_OR_RETURN(const FusionContext lhs_context, fuse_inputs(0, lhs_old_to_new_map)); OldToNewHloMap rhs_old_to_new_map; @@ -272,7 +277,7 @@ StatusOr FuseDot(HloInstruction& dot, // Fusion at dot's output. // These describe _outputs_ of corresponding HLOs. 
- auto context = triton_fusion::FusionContext::FromDotOutput( + auto context = FusionContext::FromDotOutput( dot, /*split_k=*/1, lhs_context.splittable_dimension_major_part_size()); HloInstruction* fusion_output = ˙ bool output_changed = true; @@ -285,15 +290,14 @@ StatusOr FuseDot(HloInstruction& dot, if (!IsDistributiveOverAddition(*user)) { break; } - triton_fusion::DimOrdersAndReqsOrError result = - triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *user, triton_fusion::TransformDirection::kInputToOutput, + DimOrdersAndReqsOrError result = + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + *user, TransformDirection::kInputToOutput, user->operand_index(fusion_output), context.dim_orders().at(fusion_output), gpu_version, context.hero_properties()); - if (!std::holds_alternative(result) || - !context.CombineDimOrdersAndReqs( - std::get(result))) { + if (!std::holds_alternative(result) || + !context.CombineDimOrdersAndReqs(std::get(result))) { break; } for (HloInstruction* operand : user->operands()) { diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index 8af0e9529d49cf..1d9188667a9bbf 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -42,6 +42,17 @@ limitations under the License. namespace xla { namespace gpu { +namespace { + +using triton_fusion::DimOrdersAndReqs; +using triton_fusion::DimOrdersAndReqsOrError; +using triton_fusion::FusionContext; +using triton_fusion::GetPropagatedDimOrdersAndRequirements; +using triton_fusion::kNoSplitRequirement; +using triton_fusion::TransformDirection; + +} // namespace + namespace triton_fusion { /*static*/ FusionContext FusionContext::FromDotOperand( @@ -103,6 +114,7 @@ namespace triton_fusion { } namespace { + // Tells how many new parameters does a fusion gain by fusing the operation as // an input. int64_t NumAddedParameters(const HloInstruction& hlo) { @@ -114,6 +126,7 @@ int64_t NumAddedParameters(const HloInstruction& hlo) { // All other instructions add all own inputs and remove own single output. return hlo.operand_count() - 1; } + } // namespace bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { @@ -127,7 +140,7 @@ bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { } RequirementsOrError requirements_or_error = - triton_fusion::CombineRequirements(requirements_, update.requirements); + CombineRequirements(requirements_, update.requirements); if (std::holds_alternative(requirements_or_error)) { return false; } @@ -197,7 +210,7 @@ StatusOr TritonFusionAnalysis::Execute( Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( const HloInstruction& root) { - auto context = triton_fusion::FusionContext::FromSoftmaxRoot(root); + auto context = FusionContext::FromSoftmaxRoot(root); // Softmax fusion uses one tiled scope. 
TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( root, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); @@ -208,11 +221,10 @@ Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, const int split_k) { - int64_t lhs_nc_split_major_part_size = triton_fusion::kNoSplitRequirement; + int64_t lhs_nc_split_major_part_size = kNoSplitRequirement; for (const Scope scope : {Scope::LHS, Scope::RHS}) { const int operand_number = static_cast(scope); - auto context = triton_fusion::FusionContext::FromDotOperand( - dot, operand_number, split_k); + auto context = FusionContext::FromDotOperand(dot, operand_number, split_k); TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( *dot.operand(operand_number), parameters_[scope], iter_specs_[scope])); if (scope == Scope::LHS) { @@ -221,8 +233,8 @@ Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, } } - auto context = triton_fusion::FusionContext::FromDotOutput( - dot, split_k, lhs_nc_split_major_part_size); + auto context = + FusionContext::FromDotOutput(dot, split_k, lhs_nc_split_major_part_size); const HloInstruction* output = ˙ // Currently supported is one fusion output and one path from dot to it. // Propagate dimension order from dot to root. @@ -230,15 +242,12 @@ Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, TF_RET_CHECK(output->user_count() == 1); const HloInstruction* input = output; output = output->users()[0]; - triton_fusion::DimOrdersAndReqsOrError result = - triton_fusion::GetPropagatedDimOrdersAndRequirements( - *output, context.dim_orders().at(input), - triton_fusion::TransformDirection::kInputToOutput, - context.hero_properties()); + DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( + *output, context.dim_orders().at(input), + TransformDirection::kInputToOutput, context.hero_properties()); + TF_RET_CHECK(std::holds_alternative(result)); TF_RET_CHECK( - std::holds_alternative(result)); - TF_RET_CHECK(context.CombineDimOrdersAndReqs( - std::get(result))); + context.CombineDimOrdersAndReqs(std::get(result))); } TF_RET_CHECK( iter_specs_[Scope::OUTPUT] @@ -267,6 +276,7 @@ const TensorIterationSpec::DimIterationSpec* TritonFusionAnalysis::IterSpec( } namespace { + std::string IterationSpecByInstructionMapToString( const TritonFusionAnalysis::IterationSpecByInstructionMap& m) { return absl::StrCat("IterSpec{", @@ -288,6 +298,7 @@ std::string ScopeToString(TritonFusionAnalysis::Scope s) { return "OUTPUT"; } } + } // namespace std::string TritonFusionAnalysis::ToString() const { diff --git a/third_party/xla/xla/service/gpu/triton_support.h b/third_party/xla/xla/service/gpu/triton_support.h index e95f196fdc6a5f..f3dcd23d954d34 100644 --- a/third_party/xla/xla/service/gpu/triton_support.h +++ b/third_party/xla/xla/service/gpu/triton_support.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" + namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index f64042649ec032..ee581ad85e6a23 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -215,6 +215,7 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { } namespace { + // Logical index of a dimension in `shape` labeled with `label` in the // `dim_order` describing the shape. std::optional LogicalIndexOfLabeledDimension( @@ -261,6 +262,7 @@ RequirementsOrError CombineSoftmaxRequirements(SoftmaxRequirements a, // SoftmaxRequirements is an empty class for now. return a; } + } // namespace RequirementsOrError CombineRequirements(Requirements a, diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h index 1887b962a90e6c..69485b909684cb 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h @@ -30,6 +30,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" + namespace xla { namespace gpu { From 1403549cc68cb8d257b400fe831c48693ba7bc7c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 29 Nov 2023 04:11:28 -0800 Subject: [PATCH 174/381] Fix a compile error caused by in_process_collectives As in the comment, the C++ standard forbids putting static_assert(false) in a branch of a constexpr if (at least until C++23 or so). This uses the workaround proposed in the C++ reference to fix the issue. PiperOrigin-RevId: 586291757 --- .../xla/xla/service/cpu/in_process_collectives.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index 0de37e2f6017e4..fc13c6c3870377 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -85,6 +85,13 @@ T GetInitialValue(ReductionKind reduction_kind) { } } +// We cannot use static_assert(false), because the C++ standard (prior to +// CWG2518) does not allow the statement discarded by a constexpr if to +// be ill-formed for every possible specialization. +// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if +template +constexpr bool always_false_v = false; + template void Reduce(absl::Span acc, absl::Span const> inputs) { // TODO(penporn): make sure this gets vectorized. @@ -113,7 +120,7 @@ void Reduce(absl::Span acc, absl::Span const> inputs) { } } } else { - static_assert(false, "Unsupported reduction kind"); + static_assert(always_false_v, "Unsupported reduction kind"); } } From 71f59bb3b7a8781ea48a6ca1e5722e35bb2c9027 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 29 Nov 2023 04:17:59 -0800 Subject: [PATCH 175/381] [TileAnalysis] Add indexing computation for reshape. 
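A minimal, self-contained illustration of the workaround used in the in_process_collectives change above (PATCH 174): before CWG2518, `static_assert(false)` is ill-formed even in the discarded branch of a `constexpr if`, so the assertion is made dependent on the template parameter through a variable template. The function and messages below are hypothetical stand-ins, not XLA code.

#include <cstdio>
#include <type_traits>

template <typename T>
constexpr bool always_false_v = false;

template <typename T>
void DescribeElementType() {
  if constexpr (std::is_integral_v<T>) {
    std::puts("integral element type");
  } else if constexpr (std::is_floating_point_v<T>) {
    std::puts("floating-point element type");
  } else {
    // Dependent on T, so it only fires if this branch is actually instantiated.
    static_assert(always_false_v<T>, "Unsupported element type");
  }
}

int main() {
  DescribeElementType<int>();     // prints "integral element type"
  DescribeElementType<double>();  // prints "floating-point element type"
  return 0;
}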
PiperOrigin-RevId: 586293353 --- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../xla/service/gpu/model/tile_analysis.cc | 172 ++++++++++++++++++ .../service/gpu/model/tile_analysis_test.cc | 164 +++++++++++++++++ 3 files changed, 337 insertions(+) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index d8a78bfc3515d6..497d601a7a0db1 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -224,6 +224,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:matmul_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index e36317645ce7fd..65b2cce028afab 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -15,7 +15,9 @@ limitations under the License. #include "xla/service/gpu/model/tile_analysis.h" +#include #include +#include #include #include #include @@ -23,6 +25,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -347,6 +350,172 @@ StatusOr ComputeReduceOpIndexing( return HloInstructionIndexing{std::move(operand_indexing_maps)}; } +// Computes strides for a shape. +std::vector ComputeStrides(absl::Span dims) { + size_t rank = dims.size(); + std::vector strides(rank, 1); + for (int i = rank - 2; i >= 0; --i) { + strides[i] = dims[i + 1] * strides[i + 1]; + } + return strides; +} + +// Computes 1D index given a shape and N-d indexing expressions. +AffineExpr LinearizeShape(absl::Span dims, + absl::Span dimension_exprs, + MLIRContext* mlir_context) { + AffineExpr linear_index = getAffineConstantExpr(0, mlir_context); + + auto strides = ComputeStrides(dims); + for (auto [stride, dimension_expr] : llvm::zip(strides, dimension_exprs)) { + linear_index = getAffineBinaryOpExpr( + AffineExprKind::Add, linear_index, + getAffineBinaryOpExpr(AffineExprKind::Mul, + getAffineConstantExpr(stride, mlir_context), + dimension_expr)); + } + return linear_index; +} + +// Computes N-d indexing expressions given a linear index and a shape. +std::vector DelinearizeIndex(absl::Span dims, + AffineExpr linear_index, + MLIRContext* mlir_context) { + std::vector multi_index; + multi_index.reserve(dims.size()); + + AffineExpr remainder = linear_index; + for (int64_t stride : ComputeStrides(dims)) { + AffineExpr stride_expr = getAffineConstantExpr(stride, mlir_context); + multi_index.push_back(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, + remainder, stride_expr)); + remainder = + getAffineBinaryOpExpr(AffineExprKind::Mod, remainder, stride_expr); + } + return multi_index; +} + +// Computes indexing for "minimal" reshapes, i.e. reshapes that cannot be +// represented by a series of composed reshapes, i.e. when there are no +// subshapes in input and output that have the same number of elements. +// For example, [8, 4] -> [8, 2, 2] is not a minimal reshape, it has matching +// subshapes [8] -> [8] and [4] -> [2, 2]. +// +// There are only 4 types of "minimal" reshapes considers only 4 cases: +// 1. Dimension is not changed, e.g. [8] -> [8] +// 2. Dimension is expanded, e.g. 
[8] -> [4, 2] +// 3. Dimension is collapsed, e.g. [4, 2] -> [8] +// 4. Dimension is collapsed and expanded, e.g. [8, 16] -> [4, 32] +// +// The function computes indexing maps for these 4 cases, i.e. considers given +// input/output shapes and checks if the shapes are the same, expanded or +// collapsed. Otherwise, performs linearization/delinearization. +void ComputeMinimalReshapeIndexing( + absl::Span input_dims, absl::Span output_dims, + absl::Span output_dims_exprs, + std::vector* exprs, MLIRContext* mlir_context) { + // The shape does not change. + if (input_dims.size() == 1 && output_dims.size() == 1) { + absl::c_copy(output_dims_exprs, std::back_inserter(*exprs)); + return; + } + // Expand shape. + if (input_dims.size() == 1) { + exprs->push_back( + LinearizeShape(output_dims, output_dims_exprs, mlir_context)); + return; + } + // Collapse shape. + if (output_dims.size() == 1) { + auto multi_index = + DelinearizeIndex(input_dims, output_dims_exprs.front(), mlir_context); + absl::c_copy(multi_index, std::back_inserter(*exprs)); + return; + } + // Generic case. + AffineExpr linear_index = + LinearizeShape(output_dims, output_dims_exprs, mlir_context); + auto multi_index = DelinearizeIndex(input_dims, linear_index, mlir_context); + absl::c_copy(multi_index, std::back_inserter(*exprs)); +} + +// Scans input and output shapes from left to right in an attempt to find +// subshapes with the same number of elements and then computes indexing map for +// every pair of subshapes. +// +// Example: +// p0 = f32[4, 8, 12] parameter(0) +// reshape = f32[32, 3, 4] reshape(p0) +// +// This reshape can be represented as a composition of two reshapes. +// The first reshape collapses dimensions first two input dimensions [4, 8] onto +// the output dimension [32]. +// The second reshape expands the input dimension [12] into two output +// dimensions [3, 4]. +// This is an optimization that allows us to construct simpler affine maps, +// otherwise we would need to linearize/delinearize even some of the simpler +// cases. +std::vector ComputeComposedReshapeIndexing( + absl::Span input_dims, absl::Span output_dims, + MLIRContext* mlir_context) { + std::vector exprs; + + size_t input_rank = input_dims.size(); + size_t output_rank = output_dims.size(); + std::vector output_dims_exprs; + + // Find subshapes with the same element count and compute indexing for them. 
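// Worked illustration (not additional logic in this patch) of the scan below
// on the example from the function comment, [4, 8, 12] -> [32, 3, 4]:
//
//   pass 1: input subshape [4, 8] (32 elements) pairs with output subshape
//           [32]; this is the collapse case and emits
//           (d0 floordiv 8, d0 mod 8) via DelinearizeIndex with strides {8, 1}.
//   pass 2: input subshape [12] (12 elements) pairs with output subshape
//           [3, 4]; this is the expand case and emits d1 * 4 + d2 via
//           LinearizeShape with strides ComputeStrides({3, 4}) == {4, 1}.
//
// Combined, the operand indexing map is
//   (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2),
// which is exactly what the ReshapeOpExpandAndCollapseShape test below expects.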
+ int64_t input_num_elements = 1; + int64_t output_num_elements = 1; + std::vector input_subshape, output_subshape; + size_t input_dim_id = 0, output_dim_id = 0; + while (input_dim_id < input_rank || output_dim_id < output_rank || + !input_subshape.empty()) { + if (input_dim_id < input_rank && + (input_subshape.empty() || input_num_elements < output_num_elements || + input_dims[input_dim_id] == 1)) { + input_num_elements *= input_dims[input_dim_id]; + input_subshape.push_back(input_dims[input_dim_id]); + ++input_dim_id; + continue; + } + if (output_dim_id < output_rank && + (output_subshape.empty() || output_num_elements < input_num_elements || + output_dims[output_dim_id] == 1)) { + output_num_elements *= output_dims[output_dim_id]; + output_subshape.push_back(output_dims[output_dim_id]); + output_dims_exprs.push_back( + getAffineDimExpr(output_dim_id, mlir_context)); + ++output_dim_id; + continue; + } + ComputeMinimalReshapeIndexing(input_subshape, output_subshape, + output_dims_exprs, &exprs, mlir_context); + input_num_elements = 1; + output_num_elements = 1; + input_subshape.clear(); + output_subshape.clear(); + output_dims_exprs.clear(); + } + return exprs; +} + +StatusOr ComputeReshapeOpIndexing( + const HloInstruction* reshape, MLIRContext* mlir_context) { + auto input_dims = reshape->operand(0)->shape().dimensions(); + auto output_dims = reshape->shape().dimensions(); + + std::vector exprs = + ComputeComposedReshapeIndexing(input_dims, output_dims, mlir_context); + + IndexingMap indexing_map{ + .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context), + .input_dims_sizes = {}}; + return HloInstructionIndexing{{HloOperandIndexing{ + .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; +} + StatusOr ComputeReverseOpIndexing( const HloReverseInstruction* reverse, MLIRContext* mlir_context) { absl::flat_hash_set reverse_dims(reverse->dimensions().begin(), @@ -483,6 +652,9 @@ StatusOr ComputeInstructionIndexing( if (auto reduce = DynCast(instr)) { return ComputeReduceOpIndexing(reduce, output_id, mlir_context); } + if (auto reshape = DynCast(instr)) { + return ComputeReshapeOpIndexing(reshape, mlir_context); + } if (auto reverse = DynCast(instr)) { return ComputeReverseOpIndexing(reverse, mlir_context); } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 7010fc24703dbb..5a10285386e38b 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -296,6 +296,72 @@ TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { std::vector{16, 128}))))); } +TEST_F(TileAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[128] parameter(0) + expand = f32[8, 16] reshape(p0) + ROOT collapse = f32[128] reshape(expand) + } + ENTRY e { + p0 = f32[128] parameter(0) + ROOT fusion = f32[128] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0) -> (d0)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[8, 16] parameter(0) + collapse = f32[128] reshape(p0) + ROOT expand = f32[8, 16] reshape(collapse) + } 
+ ENTRY e { + p0 = f32[8, 16] parameter(0) + ROOT fusion = f32[8, 16] fusion(p0), kind=kLoop, calls=f + } + )")); + // TODO(b/313840171): Simplify the composed affine expression. + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + f { + p0 = f32[10, 10, 10] parameter(0) + reshape1 = f32[50, 20] reshape(p0) + ROOT reshape2 = f32[10, 10, 10] reshape(reshape1) + } + ENTRY e { + p0 = f32[10, 10, 10] parameter(0) + ROOT fusion = f32[10, 10, 10] fusion(p0), kind=kLoop, calls=f + } + )")); + // TODO(b/313840171): Simplify the composed affine expression. + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " + "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, d2 mod 10)", + std::vector{}))))); +} + TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( @@ -320,6 +386,104 @@ TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { std::vector{}))))); } +TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4,8] parameter(0) + ROOT reshape = f32[32] reshape(p0) + } + )")); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0) -> (d0 floordiv 8, d0 mod 8)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[32] parameter(0) + ROOT reshape = f32[4, 8] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0 * 8 + d1)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, ReshapeOpExpandAndCollapseShape) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4, 8, 12] parameter(0) + ROOT reshape = f32[32, 3, 4] reshape(p0) + } + )")); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[16, 8] parameter(0) + ROOT reshape = f32[4, 4, 8] reshape(p0) + } + )")); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0 * 4 + d1, d2)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4,8] parameter(0) + ROOT reshape = f32[2, 4, 4] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, 
d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " + "(d0 * 16 + d1 * 4 + d2) mod 8)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[2, 4, 4] parameter(0) + ROOT reshape = f32[4, 8] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1) -> ((d0 * 8 + d1) floordiv 16, " + "((d0 * 8 + d1) mod 16) floordiv 4, d1 mod 4)", + std::vector{}))))); +} + TEST_F(TileAnalysisTest, ReduceOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( From 34c440cc2e011f8a75f14b4cc1cb04f990a8348f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 29 Nov 2023 06:36:06 -0800 Subject: [PATCH 176/381] Fix invalid conditional check in macOS CI script If no space is provided around the operator `==`, Bash evaluates the condition `[[ $FOO==1 ]]` as `[[ -n $FOO==1]]`. That is, instead of checking if `FOO` is equal to 1, we end up checking if the string "$FOO==1" is empty. PiperOrigin-RevId: 586323341 --- ci/official/utilities/setup_macos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index e91100b5e3d2ee..9dcc28406907d6 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -42,7 +42,7 @@ fi # 'TFCI_MACOS_BAZEL_TEST_DIR_PATH' environment variable to point to a partition # with ample storage. When this variable is empty (i.e by default), Bazel will # use the output base directory to run tests. -if [[ "${TFCI_MACOS_BAZEL_TEST_DIR_ENABLE}"==1 ]]; then +if [[ "${TFCI_MACOS_BAZEL_TEST_DIR_ENABLE}" == 1 ]]; then mkdir -p "${TFCI_MACOS_BAZEL_TEST_DIR_PATH}" export TEST_TMPDIR="${TFCI_MACOS_BAZEL_TEST_DIR_PATH}" fi From 74866075411bd9444246e16a79429b852e4db31c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Wed, 29 Nov 2023 06:54:26 -0800 Subject: [PATCH 177/381] Basic simplifier for indexing maps. Doesn't yet handle everything that can and should be handled, see TODOs in tests. 
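
As an illustration (an editorial sketch based on the tests updated in this change;
`input_indexing` stands for the `HloInstructionIndexing` computed for the
chained-reshape fusion in tile_analysis_test.cc), the simplifier folds the composed
reshape map back to the identity once the output dimension sizes are known:

    // Composed map before simplification:
    //   (d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100,
    //                    ((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10,
    //                    d2 mod 10)
    bool changed = input_indexing.Simplify(/*dimension_sizes=*/{10, 10, 10});
    // changed == true; the operand map is now (d0, d1, d2) -> (d0, d1, d2).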
PiperOrigin-RevId: 586327437 --- .../xla/service/gpu/model/tile_analysis.cc | 264 ++++++++++++++++++ .../xla/xla/service/gpu/model/tile_analysis.h | 8 + .../service/gpu/model/tile_analysis_test.cc | 30 +- 3 files changed, 291 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 65b2cce028afab..3f1f4eefd4df76 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -53,9 +53,12 @@ namespace gpu { namespace { using llvm::SmallVector; +using mlir::AffineBinaryOpExpr; +using mlir::AffineDimExpr; using mlir::AffineExpr; using mlir::AffineExprKind; using mlir::AffineMap; +using mlir::AffineSymbolExpr; using mlir::getAffineBinaryOpExpr; using mlir::getAffineConstantExpr; using mlir::getAffineDimExpr; @@ -590,8 +593,269 @@ std::string ToStringImpl(const T& value) { return ss.str(); } +struct IndexingMapSimplifier { + struct Bounds { + int64_t lower; + int64_t upper; + }; + + Bounds BoundsInclusive(AffineExpr expr) { + auto bound = bounds.find(expr); + if (bound != bounds.end()) return bound->second; + + switch (expr.getKind()) { + case AffineExprKind::Constant: { + int64_t value = mlir::cast(expr).getValue(); + CHECK_GE(value, 0); + return bounds[expr] = {value, value}; + } + case AffineExprKind::DimId: { + int64_t size = + dimension_sizes[mlir::cast(expr).getPosition()]; + return bounds[expr] = {0, size - 1}; + } + case AffineExprKind::SymbolId: { + int64_t size = + symbol_sizes[mlir::cast(expr).getPosition()]; + return bounds[expr] = {0, size - 1}; + } + default: + auto binary_op = mlir::dyn_cast(expr); + CHECK(binary_op); + auto lhs = BoundsInclusive(binary_op.getLHS()); + auto rhs = BoundsInclusive(binary_op.getRHS()); + + auto& result = bounds[expr]; + switch (expr.getKind()) { + case AffineExprKind::Add: + return result = {lhs.lower + rhs.lower, lhs.upper + rhs.upper}; + case AffineExprKind::Mul: + return result = {lhs.lower * rhs.lower, lhs.upper * rhs.upper}; + case AffineExprKind::Mod: { + CHECK_EQ(rhs.lower, rhs.upper) << "RHS of mod must be a constant"; + int64_t m = rhs.lower; + if (lhs.upper < m) { + return result = lhs; + } + return result = {0, m - 1}; + } + case AffineExprKind::FloorDiv: { + CHECK_EQ(rhs.lower, rhs.upper) + << "RHS of floor_div must be a constant"; + int64_t d = rhs.lower; + return result = {lhs.lower / d, lhs.upper / d}; + } + default: + // We don't use ceildiv, so we don't support it. + LOG(FATAL) << "Unsupported expression"; + } + } + } + + // Simplifier for mod. + // - Rewrites (a * 100 + ...) % 100 to (...) % 100 + // - Rewrites a % b to a if a is known to be less than b. + AffineExpr RewriteMod(AffineBinaryOpExpr mod) { + auto lhs_simplified = SimplifyOnce(mod.getLHS()); + + auto lhs = BoundsInclusive(lhs_simplified); + auto rhs = BoundsInclusive(mod.getRHS()); + + // a % b where b is always larger than a? + if (lhs.upper < rhs.lower) return lhs_simplified; + + // The logic below assumes we have a constant RHS. + if (rhs.lower != rhs.upper) return mod; + int64_t m = rhs.lower; + + auto new_lhs = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { + if (expr.getKind() != AffineExprKind::Mul) { + return true; + } + + auto mul_rhs = + BoundsInclusive(mlir::cast(expr).getRHS()); + bool remove = mul_rhs.lower == mul_rhs.upper && (mul_rhs.lower % m) == 0; + return !remove; // We keep it if we don't remove it! 
+ }); + + // If we weren't able to remove or simplify anything, return the original + // expression. + if (new_lhs == mod.getLHS()) { + return mod; + } + // If we removed everything, return 0. + if (!new_lhs) { + return getAffineConstantExpr(0, mlir_context); + } + // Otherwise, return new_sum % m. + return getAffineBinaryOpExpr(AffineExprKind::Mod, new_lhs, mod.getRHS()); + } + + // Simplifier for floordiv. + // - Rewrites (a * 100 + ...) / 100 to a + (...) / 100 + // - Rewrites a / 100 to 0 when a is known to be less than 100. + AffineExpr RewriteFloorDiv(AffineBinaryOpExpr div) { + auto lhs_simplified = SimplifyOnce(div.getLHS()); + auto lhs = BoundsInclusive(lhs_simplified); + auto rhs = BoundsInclusive(div.getRHS()); + + if (lhs.upper < rhs.lower) { + return getAffineConstantExpr(0, mlir_context); + } + + // The logic below assumes we have a constant RHS. + if (rhs.lower != rhs.upper) return div; + int64_t d = rhs.lower; + + AffineExpr extracted = getAffineConstantExpr(0, mlir_context); + auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhsMultiplier(expr)) { + // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep + // one x, but we currently have no reason to do that. + if (*multiplier % d != 0) return true; + int64_t factor = *multiplier / d; + extracted = getAffineBinaryOpExpr( + AffineExprKind::Add, extracted, + getAffineBinaryOpExpr(AffineExprKind::Mul, + cast(expr).getLHS(), + getAffineConstantExpr(factor, mlir_context))); + // Remove from dividend. + return false; + } + + // Not a constant multiplier, keep in dividend. + return true; + }); + + // If we removed everything, skip the div. + if (!new_dividend) return extracted; + // If we removed nothing, return the original division. + if (extracted == getAffineConstantExpr(0, mlir_context) && + new_dividend == div.getLHS()) { + return div; + } + + return getAffineBinaryOpExpr( + AffineExprKind::Add, extracted, + getAffineBinaryOpExpr(AffineExprKind::FloorDiv, new_dividend, + div.getRHS())); + } + + std::optional GetConstantRhsMultiplier(AffineExpr expr) { + if (expr.getKind() != AffineExprKind::Mul) return std::nullopt; + auto bound = BoundsInclusive(mlir::cast(expr).getRHS()); + if (bound.lower != bound.upper) return std::nullopt; + return bound.lower; + } + + AffineExpr RewriteSumIf(AffineExpr expr, + const std::function& pred) { + if (expr.getKind() == AffineExprKind::Add) { + auto add = mlir::dyn_cast(expr); + auto lhs = RewriteSumIf(add.getLHS(), pred); + auto rhs = RewriteSumIf(add.getRHS(), pred); + if (lhs == add.getLHS() && rhs == add.getRHS()) { + return add; + } + if (lhs && rhs) { + return getAffineBinaryOpExpr(AffineExprKind::Add, lhs, rhs); + } + return lhs ? lhs : (rhs ? rhs : nullptr); + } + return pred(expr) ? expr : nullptr; + } + + // Attempts to simplify the expression, but doesn't attempt to simplify the + // result further. + AffineExpr SimplifyOnce(AffineExpr expr) { + switch (expr.getKind()) { + case AffineExprKind::Mul: + case AffineExprKind::Add: { + auto binop = mlir::cast(expr); + auto lhs = SimplifyOnce(binop.getLHS()); + auto rhs = SimplifyOnce(binop.getRHS()); + if (lhs == binop.getLHS() && rhs == binop.getRHS()) { + return expr; + } + return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs); + } + case AffineExprKind::Mod: + return RewriteMod(cast(expr)); + case AffineExprKind::FloorDiv: + return RewriteFloorDiv(cast(expr)); + default: + return expr; + } + } + + // Simplifies the expression as much as possible. 
+ AffineExpr Simplify(AffineExpr expr) { + while (true) { + auto simplified = SimplifyOnce(expr); + if (simplified == expr) return expr; + expr = simplified; + } + } + + MLIRContext* mlir_context; + absl::Span dimension_sizes; + absl::Span symbol_sizes; + llvm::DenseMap bounds{}; +}; + } // namespace +bool IndexingMap::Simplify(absl::Span dimension_sizes) { + IndexingMapSimplifier simplifier{affine_map.getContext(), dimension_sizes, + input_dims_sizes}; + std::vector results; + bool any_changed = false; + for (auto expr : affine_map.getResults()) { + auto simplified = simplifier.Simplify(expr); + any_changed |= simplified != expr; + results.push_back(simplified); + } + + if (!any_changed) { + return false; + } + + affine_map = + AffineMap::get(affine_map.getNumDims(), affine_map.getNumSymbols(), + results, affine_map.getContext()); + return true; +} + +bool HloOperandIndexing::Simplify(absl::Span dimension_sizes) { + std::vector to_remove; + std::vector to_add; + for (auto map : indexing_maps) { + to_remove.push_back(map); + if (map.Simplify(dimension_sizes)) { + to_add.push_back(map); + } else { + to_remove.pop_back(); + } + } + for (auto& map : to_remove) { + indexing_maps.erase(map); + } + for (auto& map : to_add) { + indexing_maps.insert(map); + } + return !to_remove.empty(); +} + +bool HloInstructionIndexing::Simplify( + absl::Span dimension_sizes) { + bool any_simplified = false; + for (auto& operand_indexing : operand_indexing_maps) { + any_simplified |= operand_indexing.Simplify(dimension_sizes); + } + return any_simplified; +} + std::string ToString(const AffineMap& affine_map) { std::string s; llvm::raw_string_ostream ss(s); diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.h b/third_party/xla/xla/service/gpu/model/tile_analysis.h index 601272b1f68360..a8bed459116973 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.h @@ -65,6 +65,8 @@ namespace gpu { // could not be expressed via dimensions of the output. struct IndexingMap { std::string ToString() const; + // Returns true if the map was simplified. + bool Simplify(absl::Span dimension_sizes); mlir::AffineMap affine_map; std::vector input_dims_sizes; @@ -84,6 +86,9 @@ H AbslHashValue(H h, const IndexingMap& indexing_map) { struct HloOperandIndexing { std::string ToString() const; + // Returns true if the indexing was simplified. + bool Simplify(absl::Span dimension_sizes); + absl::flat_hash_set indexing_maps; int64_t operand_id; }; @@ -95,6 +100,9 @@ std::ostream& operator<<(std::ostream& out, struct HloInstructionIndexing { std::string ToString() const; + // Returns true if the indexing was simplified. + bool Simplify(absl::Span dimension_sizes); + std::vector operand_indexing_maps; }; std::ostream& operator<<(std::ostream& out, diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 5a10285386e38b..776be601656c22 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -330,12 +330,11 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { ROOT fusion = f32[8, 16] fusion(p0), kind=kLoop, calls=f } )")); - // TODO(b/313840171): Simplify the composed affine expression. 
+ EXPECT_TRUE(input_indexing.Simplify({8, 16})); EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)", - std::vector{}))))); + 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { @@ -352,14 +351,12 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { ROOT fusion = f32[10, 10, 10] fusion(p0), kind=kLoop, calls=f } )")); - // TODO(b/313840171): Simplify the composed affine expression. + EXPECT_TRUE(input_indexing.Simplify({10, 10, 10})); EXPECT_THAT( input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " - "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, d2 mod 10)", - std::vector{}))))); + 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0, d1, d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { @@ -395,6 +392,7 @@ TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { ROOT reshape = f32[32] reshape(p0) } )")); + EXPECT_FALSE(input_indexing.Simplify({32})); EXPECT_THAT( input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( @@ -411,6 +409,7 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { ROOT reshape = f32[4, 8] reshape(p0) } )")); + EXPECT_FALSE(input_indexing.Simplify({4, 8})); EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0 * 8 + d1)", @@ -426,6 +425,7 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandAndCollapseShape) { ROOT reshape = f32[32, 3, 4] reshape(p0) } )")); + EXPECT_FALSE(input_indexing.Simplify({32, 3, 4})); EXPECT_THAT( input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( @@ -443,6 +443,7 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { ROOT reshape = f32[4, 4, 8] reshape(p0) } )")); + EXPECT_FALSE(input_indexing.Simplify({4, 4, 8})); EXPECT_THAT( input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( @@ -459,11 +460,13 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { ROOT reshape = f32[2, 4, 4] reshape(p0) } )")); + EXPECT_TRUE(input_indexing.Simplify({2, 4, 4})); + // TODO(b/313840171): Simplify `(d1 * 4 + d2) floordiv 8` to `d1 floordiv 2`. EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " - "(d0 * 16 + d1 * 4 + d2) mod 8)", + "(d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, " + "(d1 * 4 + d2) mod 8)", std::vector{}))))); } @@ -476,6 +479,10 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { ROOT reshape = f32[4, 8] reshape(p0) } )")); + EXPECT_FALSE(input_indexing.Simplify({4, 8})); + // TODO(b/313840171): Simplify `(d0 * 8 + d1) floordiv 16` to `d0 floordiv 2`. + // TODO(b/313840171): Simplify `((d0 * 8 + d1) mod 16) floordiv 4` to + // `((d0 * 8 + d1) floordiv 4) mod 4` to `(d0 * 2 + d1 floordiv 4) mod 4`. EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( @@ -565,6 +572,7 @@ TEST_F(TileAnalysisTest, ReverseOp) { ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} } )")); + // TODO(b/313840171): Support simplifying this. 
EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( From 56016f2bfc0c3e85bcf072a0344b2da81f692953 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 29 Nov 2023 07:34:14 -0800 Subject: [PATCH 178/381] [TileAnalysis] Add indexing computation for HloBitcast. PiperOrigin-RevId: 586336936 --- .../xla/service/gpu/model/tile_analysis.cc | 93 ++++++++++++++----- .../service/gpu/model/tile_analysis_test.cc | 72 +++++++++++++- 2 files changed, 138 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 3f1f4eefd4df76..771ec1494e48e6 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -40,6 +41,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -352,10 +354,9 @@ StatusOr ComputeReduceOpIndexing( } return HloInstructionIndexing{std::move(operand_indexing_maps)}; } - // Computes strides for a shape. std::vector ComputeStrides(absl::Span dims) { - size_t rank = dims.size(); + int rank = static_cast(dims.size()); std::vector strides(rank, 1); for (int i = rank - 2; i >= 0; --i) { strides[i] = dims[i + 1] * strides[i + 1]; @@ -458,9 +459,9 @@ void ComputeMinimalReshapeIndexing( // This is an optimization that allows us to construct simpler affine maps, // otherwise we would need to linearize/delinearize even some of the simpler // cases. 
-std::vector ComputeComposedReshapeIndexing( - absl::Span input_dims, absl::Span output_dims, - MLIRContext* mlir_context) { +IndexingMap ComputeReshapeIndexingMap(absl::Span input_dims, + absl::Span output_dims, + MLIRContext* mlir_context) { std::vector exprs; size_t input_rank = input_dims.size(); @@ -500,23 +501,20 @@ std::vector ComputeComposedReshapeIndexing( output_subshape.clear(); output_dims_exprs.clear(); } - return exprs; + return IndexingMap{ + .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context), + .input_dims_sizes = {}}; } StatusOr ComputeReshapeOpIndexing( - const HloInstruction* reshape, MLIRContext* mlir_context) { + const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { auto input_dims = reshape->operand(0)->shape().dimensions(); auto output_dims = reshape->shape().dimensions(); - - std::vector exprs = - ComputeComposedReshapeIndexing(input_dims, output_dims, mlir_context); - - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .input_dims_sizes = {}}; + IndexingMap reshape_indexing_map = + ComputeReshapeIndexingMap(input_dims, output_dims, mlir_context); return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; + .indexing_maps = {std::move(reshape_indexing_map)}, .operand_id = 0}}}; } StatusOr ComputeReverseOpIndexing( @@ -572,17 +570,63 @@ StatusOr ComputeSliceOpIndexing( .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; } +IndexingMap ComputeTransposeIndexingMap(absl::Span permutation, + MLIRContext* mlir_context) { + auto forward_permutation = AffineMap::getPermutationMap( + std::vector(permutation.begin(), permutation.end()), + mlir_context); + return IndexingMap{ + .affine_map = mlir::inversePermutation(forward_permutation), + .input_dims_sizes = {}}; +} + StatusOr ComputeTransposeOpIndexing( const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { - std::vector permutation(transpose->dimensions().begin(), - transpose->dimensions().end()); - IndexingMap permutation_map{ - .affine_map = mlir::inversePermutation( - AffineMap::getPermutationMap(permutation, mlir_context)), - .input_dims_sizes = {}}; + IndexingMap transpose_indexing_map = + ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context); + return HloInstructionIndexing{{HloOperandIndexing{ + .indexing_maps = {std::move(transpose_indexing_map)}, .operand_id = 0}}}; +} +StatusOr ComputeBitcastOpIndexing( + const HloInstruction* bitcast, MLIRContext* mlir_context) { + const Shape& input_shape = bitcast->operand(0)->shape(); + const Shape& output_shape = bitcast->shape(); + ShapeUtil::BitcastDecomposition decomposed_bitcast = + ShapeUtil::DecomposeBitcast(input_shape, output_shape); + + if (std::holds_alternative( + decomposed_bitcast)) { + auto permutation = ShapeUtil::DeduceTransposeDimensionsForBitcast( + input_shape, output_shape); + CHECK(permutation.has_value()) + << "Failed to deduce permutation for a bitcast."; + IndexingMap transpose_indexing_map = + ComputeTransposeIndexingMap(*permutation, mlir_context); + return HloInstructionIndexing{{HloOperandIndexing{ + .indexing_maps = {std::move(transpose_indexing_map)}, + .operand_id = 0}}}; + } + if (std::holds_alternative( + decomposed_bitcast)) { + IndexingMap reshape_indexing_map = ComputeReshapeIndexingMap( + input_shape.dimensions(), output_shape.dimensions(), mlir_context); + return HloInstructionIndexing{{HloOperandIndexing{ + 
.indexing_maps = {std::move(reshape_indexing_map)}, .operand_id = 0}}}; + } + // `trt` stands for transpose-reshape-transpose decomposition of bitcast. + auto trt = std::get(decomposed_bitcast); + IndexingMap transpose_map_1 = + ComputeTransposeIndexingMap(trt.transpose1_dims, mlir_context); + IndexingMap reshape_map = + ComputeReshapeIndexingMap(trt.transpose1_shape.dimensions(), + trt.reshape_shape.dimensions(), mlir_context); + IndexingMap transpose_map_2 = + ComputeTransposeIndexingMap(trt.transpose2_dims, mlir_context); + IndexingMap composed_map = ComposeIndexingMaps( + ComposeIndexingMaps(transpose_map_1, reshape_map), transpose_map_2); return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(permutation_map)}, .operand_id = 0}}}; + .indexing_maps = {std::move(composed_map)}, .operand_id = 0}}}; } template @@ -907,6 +951,9 @@ StatusOr ComputeInstructionIndexing( if (auto bcast = DynCast(instr)) { return ComputeBroadcastOpIndexing(bcast, mlir_context); } + if (instr->opcode() == HloOpcode::kBitcast) { + return ComputeBitcastOpIndexing(instr, mlir_context); + } if (auto dot = DynCast(instr)) { return ComputeDotOpIndexing(dot, mlir_context); } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 776be601656c22..7dadcc852974c0 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -88,6 +88,54 @@ TEST_F(TileAnalysisTest, ElementwiseOp) { std::vector{}))))); } +TEST_F(TileAnalysisTest, BitcastIsReshape) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[4, 32] parameter(0) + ROOT bitcast = f32[4, 8, 4] bitcast(p0) + } + )")); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0, d1 * 4 + d2)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, BitcastIsTranspose) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[3, 12288, 6, 128] parameter(0) + ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, BitcastIsTransposeReshapeTranspose) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[16, 17, 3] parameter(0) + ROOT bitcast = f32[51, 16] {0, 1} bitcast(p0) + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1) -> (d1, d0 floordiv 3, d0 mod 3)", + std::vector{}))))); +} + TEST_F(TileAnalysisTest, BroadcastOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( @@ -602,15 +650,31 @@ TEST_F(TileAnalysisTest, TransposeOp) { GetIndexingMapsForEntryComputation(R"( HloModule m ENTRY e { - p0 = f16[1, 8, 1536, 512] parameter(0) - ROOT transpose = f16[1, 8, 512, 1536]{2, 3, 1, 0} - transpose(p0), dimensions={0, 1, 3, 2} + p0 = f32[3, 12288, 6, 128] parameter(0) + ROOT transpose = f32[3, 6, 128, 12288] + transpose(p0), dimensions={0, 2, 3, 1} + } + )")); + EXPECT_THAT(input_indexing.operand_indexing_maps, + 
ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", + std::vector{}))))); +} + +TEST_F(TileAnalysisTest, TransposeOp4D) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + ENTRY e { + p0 = f32[3, 12288, 6, 128] parameter(0) + ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d1, d3, d2)", + "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", std::vector{}))))); } From fe25113b710605f8ad7197203a12ea8793a11d71 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 29 Nov 2023 07:42:47 -0800 Subject: [PATCH 179/381] [XLA:GPU] Append ".0" suffix to all instructions names. This will make it much easier to map instructions before and after fusion for debug. PiperOrigin-RevId: 586338747 --- .../xla/xla/service/gpu/priority_fusion.cc | 16 ++++++++++++++++ .../xla/xla/service/gpu/priority_fusion_test.cc | 3 +-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 8b07d6c329df82..0c0c4c69a58775 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -451,6 +451,22 @@ StatusOr GpuPriorityFusion::Run( fusion_process_dump_ = std::make_unique(); } + // Appends ".0" suffix to all instructions. + // + // Every time an instruction is duplicated, the last integer suffix is + // incremented. + // Before: broadcast.123 -> broadcast.124 + // After: broadcast.123.0 -> broadcast.123.1 + // + // With this modification it will be easier to match intructions before and + // after fusion passes, because they will have the same unique prefix. Names + // are not used in the pipeline, but it makes debugging much easier. + for (auto* computation : GetFusionComputations(module, execution_threads)) { + for (auto* instruction : computation->instructions()) { + instruction->SetAndSanitizeName(absl::StrCat(instruction->name(), ".0")); + } + } + auto result = InstructionFusion::Run(module, execution_threads); if (dump_enabled) { diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 16cbcdf9e0cd65..8c122df7f1c00a 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -618,8 +618,7 @@ TEST_F(PriorityFusionTest, EpilogueFusion) { })"; RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( -CHECK: ROOT %fusion = f32[8,4,128]{2,1,0} fusion(%p{{.*}}), kind=kInput, calls=%fused_computation - )"); +CHECK: ROOT {{.*}} = f32[8,4,128]{2,1,0} fusion(%p{{.*}}), kind=kInput, calls=%fused_computation)"); } TEST_F(PriorityFusionTest, EpilogueFusionFails) { From 13cd409dc76b868c8020e9a93909f66e086a1e97 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 29 Nov 2023 08:00:00 -0800 Subject: [PATCH 180/381] [xla:gpu] NFC: Migrate custom CUTLASS kernel to GemmUniversal templates GemmUniversal is recommended way of constructing CUTLASS kernels and it has support for GEMM kernels with mixed dtypes. 
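
Sketch of the new loading path (illustrative; `F32xF32toF32` is the row-major fp32
GemmUniversal specialization defined in cutlass_gemm_kernel.cu.cc, and the template
arguments of the adaptor helpers are spelled out here for readability):

    using Gemm = F32xF32toF32;
    using Kernel = typename Gemm::GemmKernel;

    cutlass::gemm::GemmCoord problem_size = {m, n, k};

    // Pack StreamExecutor device-memory arguments into the kernel's Params
    // struct and derive the launch geometry from the GemmUniversal templates.
    se::MultiKernelLoaderSpec spec(
        /*arity=*/1, gemm_universal::ArgsPacking<Gemm>(problem_size));
    spec.AddInProcessSymbol(reinterpret_cast<void*>(cutlass::Kernel2<Kernel>),
                            "cutlass_universal_gemm");

    CustomKernel kernel("cutlass_gemm:f32<-f32xf32", std::move(spec),
                        gemm_universal::BlockDim<Gemm>(problem_size),
                        gemm_universal::ThreadDim<Gemm>(),
                        sizeof(typename Kernel::SharedStorage));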
PiperOrigin-RevId: 586342662 --- third_party/xla/xla/service/gpu/kernels/BUILD | 16 ++ .../gpu/kernels/cutlass_gemm_kernel.cu.cc | 128 ++++-------- .../gpu/kernels/cutlass_gemm_universal.cu.h | 185 ++++++++++++++++++ .../xla/stream_executor/cuda/cuda_driver.cc | 3 +- 4 files changed, 240 insertions(+), 92 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 267d51c520e97c..d08e19e189af96 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -121,6 +121,9 @@ cc_library( # visibility = ["//visibility:private"], # deps = [ # ":custom_kernel", +# ":cutlass_gemm_universal", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings", # "//third_party/gpus/cutlass", # "//xla:statusor", # "//xla:xla_data_proto_cc", @@ -128,6 +131,19 @@ cc_library( # ], # ) # +# cuda_library( +# name = "cutlass_gemm_universal", +# hdrs = ["cutlass_gemm_universal.cu.h"], +# visibility = ["//visibility:private"], +# deps = [ +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings", +# "//third_party/gpus/cutlass", +# "//xla:statusor", +# "//xla/stream_executor", +# ], +# ) +# # xla_test( # name = "cutlass_gemm_test", # srcs = if_cuda_is_configured(["cutlass_gemm_test.cc"]), diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc index 5a897b830d3c05..a84686a4617b3e 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc @@ -15,108 +15,54 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" -#include "third_party/gpus/cutlass/include/cutlass/gemm/device/gemm.h" -#include "xla/stream_executor/kernel.h" +#include +#include + +#include "absl/status/status.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_universal.cu.h" +#include "xla/statusor.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/xla_data.pb.h" namespace xla::gpu::kernel { -// The most basic CUTLASS f32 gemm kernel. -using CutlassGemm = - cutlass::gemm::device::Gemm; +using F32xF32toF32 = + cutlass::gemm::device::GemmUniversal; -StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, - int32_t n, int32_t k) { - if (dtype != PrimitiveType::F32) - return absl::InvalidArgumentError( - "Currently cutlass gemm kernel supports only F32 data type"); +//===----------------------------------------------------------------------===// +// Adaptor from a CUTLASS GemmUniversal to a CustomKernel. +//===----------------------------------------------------------------------===// - // Underlying CUDA kernel implementing gemm operation. 
- using GemmKernel = typename CutlassGemm::GemmKernel; +template +StatusOr LoadCutlassGemmUniversal(int32_t m, int32_t n, + int32_t k) { + using Kernel = typename Gemm::GemmKernel; cutlass::gemm::GemmCoord problem_size = {m, n, k}; - using ThreadblockShape = typename CutlassGemm::ThreadblockShape; - cutlass::gemm::GemmCoord tile_size = { - ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}; - - typename CutlassGemm::ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord tiled_shape = - threadblock_swizzle.get_tiled_shape(problem_size, tile_size, - /*split_k_slices=*/1); - - // Compute kernel launch grid size and shared memory requirement. - dim3 grid = threadblock_swizzle.get_grid_shape(tiled_shape); - se::BlockDim block_dims(grid.x, grid.y, grid.z); - se::ThreadDim thread_dims(GemmKernel::kThreadCount, 1, 1); - size_t shared_memory_bytes = sizeof(typename GemmKernel::SharedStorage); - - // Packs device memory arguments into CUTLASS kernel parameters struct. - using PackedArgs = StatusOr>; - auto pack = [problem_size, - tiled_shape](const se::KernelArgs &args) -> PackedArgs { - auto *mem_args = Cast(&args); - - // Converts DeviceMemoryBase to an opaque `void *` device pointer. - // - // TODO(ezhulenev): Add more checks for the number and types of device - // memory arguments. Right now we unsafely cast and extract buffers. - auto device_ptr = [&](size_t index) { - const void *opaque = mem_args->device_memory_ptr(index); - return static_cast(const_cast(opaque)); - }; - - // Strides for a row major layout. - int32_t lda = problem_size.k(); - int32_t ldb = problem_size.n(); - int32_t ldc = problem_size.n(); - - // Check if GemmKernel can implement the given problem size. - cutlass::Status can_implement = GemmKernel::can_implement( - problem_size, // problem size - {device_ptr(0), lda}, // Tensor-ref for source matrix A - {device_ptr(1), ldb}, // Tensor-ref for source matrix B - {device_ptr(2), ldc}, // Tensor-ref for source matrix C - {device_ptr(2), ldc} // Tensor-ref for destination matrix D - ); - - if (can_implement != cutlass::Status::kSuccess) { - return absl::InternalError( - "CUTLASS GemmKernel can not implement gemm for a given problem size"); - } - - // Sanity check that we do not accidentally get a giant parameters struct. - static_assert(sizeof(GemmKernel::Params) < 512, - "GemmKernel::Params struct size is unexpectedly large"); - - float alpha = 1.0, beta = 0.0; - GemmKernel::Params params{ - problem_size, - tiled_shape, - {device_ptr(0), lda}, // Tensor-ref for source matrix A - {device_ptr(1), ldb}, // Tensor-ref for source matrix B - {device_ptr(2), ldc}, // Tensor-ref for source matrix C - {device_ptr(2), ldc}, // Tensor-ref for destination matrix D - {alpha, beta}, // Scalars used in the Epilogue - /*workspace=*/nullptr, - /*gather_A_indices=*/nullptr, - /*gather_B_indices=*/nullptr, - /*gather_D_indices=*/nullptr}; - - return se::PackKernelArgs(args.number_of_shared_bytes(), - params); - }; - - // TODO(ezhulenev): We should generate a more descriptive names for custom + // TODO(ezhulenev): We should generate more descriptive names for custom // kernels, i.e. include tile and dimensions sizes, dtypes, etc. 
- se::MultiKernelLoaderSpec kernel_spec(/*arity=*/1, std::move(pack)); - kernel_spec.AddInProcessSymbol( - reinterpret_cast(cutlass::Kernel), "cutlass_gemm"); + se::MultiKernelLoaderSpec spec( + /*arity=*/1, gemm_universal::ArgsPacking(problem_size)); + spec.AddInProcessSymbol(reinterpret_cast(cutlass::Kernel2), + "cutlass_universal_gemm"); + + return CustomKernel("cutlass_gemm:f32<-f32xf32", std::move(spec), + gemm_universal::BlockDim(problem_size), + gemm_universal::ThreadDim(), + sizeof(typename Kernel::SharedStorage)); +} + +StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, + int32_t n, int32_t k) { + if (dtype != PrimitiveType::F32) + return absl::InvalidArgumentError( + "Currently cutlass gemm kernel supports only F32 data type"); - return CustomKernel("cutlass_gemm:f32<-f32xf32", std::move(kernel_spec), - block_dims, thread_dims, shared_memory_bytes); + return LoadCutlassGemmUniversal(m, n, k); } } // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h new file mode 100644 index 00000000000000..d5758dbac810a9 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h @@ -0,0 +1,185 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "third_party/gpus/cutlass/include/cutlass/cutlass.h" +#include "third_party/gpus/cutlass/include/cutlass/gemm/device/gemm_universal.h" +#include "third_party/gpus/cutlass/include/cutlass/gemm/gemm_enumerated_types.h" +#include "third_party/gpus/cutlass/include/cutlass/gemm_coord.h" +#include "third_party/gpus/cutlass/include/cutlass/layout/matrix.h" +#include "xla/statusor.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" + +namespace xla::gpu::kernel::gemm_universal { + +// This is a template library that implements an adaptor from a CUTLASS +// GemmUniversal kernel to StreamExecutor primitives for kernel arguments +// packing and kernel launching. +// +// In all templates defined below `typename Gemm` should be a +// an instance of `cutlass::gemm::device::GemmUniversal` template. + +namespace se = ::stream_executor; + +//===----------------------------------------------------------------------===// +// Gemm launch dimension computation. 
+//===----------------------------------------------------------------------===// + +template +se::ThreadDim ThreadDim() { + using Kernel = typename Gemm::GemmKernel; + return se::ThreadDim(Kernel::kThreadCount, 1, 1); +} + +template +se::BlockDim BlockDim(const cutlass::gemm::GemmCoord &problem_size) { + using ThreadblockSwizzle = typename Gemm::ThreadblockSwizzle; + using ThreadblockShape = typename Gemm::ThreadblockShape; + + cutlass::gemm::GemmCoord tile_size = { + ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}; + + cutlass::gemm::GemmCoord grid_tiled_shape = + ThreadblockSwizzle::get_tiled_shape(problem_size, tile_size, + /*split_k_slices=*/1); + + auto grid = ThreadblockSwizzle().get_grid_shape(grid_tiled_shape); + + return se::BlockDim(grid.x, grid.y, grid.z); +} + +//===----------------------------------------------------------------------===// +// Gemm strides computation. +//===----------------------------------------------------------------------===// + +template +int64_t LdA(const cutlass::gemm::GemmCoord &problem_size) { + using LayoutA = typename Gemm::LayoutA; + + if constexpr (std::is_same_v) { + return problem_size.k(); + } else { + static_assert(sizeof(Gemm) == 0, "unsupported layout type"); + } +} + +template +int64_t LdB(const cutlass::gemm::GemmCoord &problem_size) { + using LayoutB = typename Gemm::LayoutB; + + if constexpr (std::is_same_v) { + return problem_size.n(); + } else { + static_assert(sizeof(Gemm) == 0, "unsupported layout type"); + } +} + +template +int64_t LdC(const cutlass::gemm::GemmCoord &problem_size) { + using LayoutC = typename Gemm::LayoutA; + + if constexpr (std::is_same_v) { + return problem_size.n(); + } else { + static_assert(sizeof(Gemm) == 0, "unsupported layout type"); + } +} + +//===----------------------------------------------------------------------===// +// Packing kernel arguments to CUTLASS kernel parameters struct. +//===----------------------------------------------------------------------===// + +using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; + +template +auto *DevicePtr(const se::KernelArgsDeviceMemoryArray *args) { + const void *opaque = args->device_memory_ptr(index); + + if constexpr (index == 0) { + return static_cast(const_cast(opaque)); + } else if constexpr (index == 1) { + return static_cast(const_cast(opaque)); + } else if constexpr (index == 2) { + return static_cast(const_cast(opaque)); + } else { + static_assert(sizeof(Gemm) == 0, "illegal Gemm argument index"); + } +} + +template +KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size) { + using Arguments = typename Gemm::Arguments; + using Kernel = typename Gemm::GemmKernel; + using Params = typename Kernel::Params; + + // Sanity check that we do not accidentally get a giant parameters struct. 
+ static_assert(sizeof(Params) < 512, + "Params struct size is unexpectedly large"); + + using PackedArgs = StatusOr>; + + return [=](const se::KernelArgs &args) -> PackedArgs { + auto *mem_args = Cast(&args); + + cutlass::Status can_implement = Kernel::can_implement(problem_size); + if (can_implement != cutlass::Status::kSuccess) { + return absl::InternalError(absl::StrCat( + "CUTLASS kernel can not implement gemm for a given problem size", + ": m=", problem_size.m(), ", n=", problem_size.n(), + ", k=", problem_size.k())); + } + + auto lda = LdA(problem_size); + auto ldb = LdB(problem_size); + auto ldc = LdC(problem_size); + + auto ptr_a = DevicePtr(mem_args); + auto ptr_b = DevicePtr(mem_args); + auto ptr_c = DevicePtr(mem_args); + + auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + float alpha = 1.0, beta = 0.0; + + // CUTLASS operation arguments. + Arguments arguments(mode, problem_size, + 1, // batch + {alpha, beta}, // epilogue + ptr_a, ptr_b, ptr_c, ptr_c, // pointers + 0, 0, 0, 0, // batch strides + lda, ldb, ldc, ldc // strides + ); + + // TODO(ezhulenev): Get number of SMs from a DeviceDescription and calculate + // correct kernel occupancy using GpuRuntime. + Params params(arguments, /*device_sms=*/128, /*sm_occupancy=*/10); + + return se::PackKernelArgs(args.number_of_shared_bytes(), params); + }; +} + +} // namespace xla::gpu::kernel::gemm_universal + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index a85188c37627c2..19fb144b9fc11e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -1159,7 +1159,8 @@ struct BitPatternToValue { VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z << " bdx: " << block_dim_x << " bdy: " << block_dim_y - << " bdz: " << block_dim_z; + << " bdz: " << block_dim_z + << "; shared_mem_bytes: " << shared_mem_bytes; if (shared_mem_bytes != 0) { RETURN_IF_CUDA_RES_ERROR( cuFuncSetAttribute(function, From 2c822d6a7d7fbc2443acbb2386f700c7c1defdfd Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 29 Nov 2023 08:52:27 -0800 Subject: [PATCH 181/381] [XLA:GPU][NFC] Refactor tiling propagation. - Use `sliced count` instead of `slice limit` (will allow expressing slicing and concatenations same way). - Rename `size` to `count` for uniformity. 
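
For example (numbers taken from the updated gemm_rewriter_triton_test expectations;
the aggregate initialization below is only an illustration of the new field layout),
a fragment that uses 3 of its 6 elements starting at offset 2 is now written as:

    TensorIterationSpec::IterationSpecFragment fragment{
        /*stride=*/24, /*count=*/6, /*slice_start=*/2, /*sliced_count=*/3,
        /*subfragments=*/{3}};
    // Previously: {stride=24, count=6, slice_start=2, slice_limit=5}, i.e.
    // sliced_count == slice_limit - slice_start; is_sliced() now means
    // count != sliced_count.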
PiperOrigin-RevId: 586356499 --- .../service/gpu/gemm_rewriter_triton_test.cc | 17 ++--- .../service/gpu/triton_tiling_propagation.cc | 64 +++++++++---------- .../service/gpu/triton_tiling_propagation.h | 32 +++++----- 3 files changed, 57 insertions(+), 56 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 537fea463d09c0..8913782398cca7 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -792,12 +792,12 @@ ENTRY e { EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, computation->parameter_instruction(0), 0), ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6, - /*slice_start=*/2, /*slice_limit=*/5, + /*slice_start=*/2, /*sliced_count=*/3, /*subfragments=*/ElementsAre(3)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, computation->parameter_instruction(0), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, - /*slice_start=*/16, /*slice_limit=*/23, + /*slice_start=*/16, /*sliced_count=*/7, /*subfragments=*/ElementsAre(7)))); } @@ -829,34 +829,35 @@ e { EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(0), 0), ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153, - /*slice_start=*/0, /*slice_limit=*/153, + /*slice_start=*/0, /*sliced_count=*/153, /*subfragments=*/ElementsAre(153)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(0), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536, - /*slice_start=*/0, /*slice_limit=*/1536, + /*slice_start=*/0, /*sliced_count=*/1536, /*subfragments=*/ElementsAre(1536)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(1), 0), ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153, - /*slice_start=*/0, /*slice_limit=*/153, + /*slice_start=*/0, /*sliced_count=*/153, /*subfragments=*/ElementsAre(153)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(1), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128, - /*slice_start=*/0, /*slice_limit=*/128, + /*slice_start=*/0, /*sliced_count=*/128, /*subfragments=*/ElementsAre(128)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(2), 0), ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153, - /*slice_start=*/0, /*slice_limit=*/153, + /*slice_start=*/0, /*sliced_count=*/153, /*subfragments=*/ElementsAre(153)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(2), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256, - /*slice_start=*/0, /*slice_limit=*/256, + /*slice_start=*/0, + /*sliced_count=*/256, /*subfragments=*/ElementsAre(256)))); } diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index ee581ad85e6a23..ec45285cb517f8 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -79,7 +79,7 @@ std::string TensorIterationSpec::IterationSpecFragment::ToString() const { bool TensorIterationSpec::IterationSpecFragment::operator!=( const IterationSpecFragment& other) const { return stride != other.stride || count != other.count || - slice_start != other.slice_start || slice_limit != other.slice_limit; + 
slice_start != other.slice_start || sliced_count != other.sliced_count; } std::string TensorIterationSpec::ToString() const { @@ -143,8 +143,8 @@ using FragmentOrders = DimensionOrder::FragmentOrders; } std::string DimensionOrder::Fragment::ToString() const { - return absl::StrCat(dst_dim_number_, ":", size_, ":", slice_start_, "-", - slice_limit_); + return absl::StrCat(dst_dim_number_, ":", count_, ":", slice_start_, "-", + sliced_count_); } std::string DimensionOrder::ToString() const { @@ -184,29 +184,29 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { dim_spec.back().subfragments.pop_back(); } // Contiguous dimension, split only logically. Merge it back. - if (fragment.full_size() > 1) { + if (fragment.full_count() > 1) { CHECK(!dim_spec.empty()); CHECK(!dim_spec.back().is_sliced()) << "Only the major-most fragment can have an offset."; dim_spec.back().slice_start = fragment.slice_start() * dim_spec.back().count; - dim_spec.back().slice_limit = - fragment.slice_limit() * dim_spec.back().count; - dim_spec.back().count *= fragment.full_size(); - dim_spec.back().subfragments.push_back(fragment.sliced_size()); + dim_spec.back().sliced_count = + fragment.sliced_count() * dim_spec.back().count; + dim_spec.back().count *= fragment.full_count(); + dim_spec.back().subfragments.push_back(fragment.sliced_count()); } } else { remove_last_fragment_if_degenerate(last_dim); // Add part of the dimension. - dim_spec.push_back( - TensorIterationSpec::IterationSpecFragment{accumulated_stride, - fragment.full_size(), - fragment.slice_start(), - fragment.slice_limit(), - {fragment.sliced_size()}}); + dim_spec.push_back(TensorIterationSpec::IterationSpecFragment{ + accumulated_stride, + fragment.full_count(), + fragment.slice_start(), + fragment.sliced_count(), + {fragment.sliced_count()}}); } - accumulated_stride *= fragment.full_size(); + accumulated_stride *= fragment.full_count(); last_dim = fragment.dst_dim_number(); } remove_last_fragment_if_degenerate(last_dim); @@ -225,7 +225,7 @@ std::optional LogicalIndexOfLabeledDimension( const int64_t dim_size = shape.dimensions()[dim]; int64_t fragments_size = 1; while (fragments_size < dim_size) { - fragments_size *= fragment_it->full_size(); + fragments_size *= fragment_it->full_count(); if (fragment_it->dst_dim_number() == label) { return dim; } @@ -304,12 +304,12 @@ RequirementsOrError GetRequirementsIfSupportedOrder( if (fragment_it == dim_fragments.cend()) { break; } - int64_t grouped_size = tensor_dim_fragments[*fragment_it].full_size(); + int64_t grouped_size = tensor_dim_fragments[*fragment_it].full_count(); // Gather contiguous fragments: they have consecutive indices. while ((fragment_it + 1) != dim_fragments.cend() && *(fragment_it + 1) == *fragment_it + 1) { ++fragment_it; - grouped_size *= tensor_dim_fragments[*fragment_it].full_size(); + grouped_size *= tensor_dim_fragments[*fragment_it].full_count(); } // Ignore 1-sized groups of fragments. if (grouped_size == 1) { @@ -481,11 +481,11 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // Find a continuous group of fragments corresponding to this dimension in // the source and assign the corresponding size in fragments of the // destination ignoring the source ones. 
- dst_remaining_size = src_dim->full_size(); + dst_remaining_size = src_dim->full_count(); while (src_dim + 1 != src_fragments_order.cend() && (src_dim + 1)->dst_dim_number() == src_dim->dst_dim_number()) { ++src_dim; - dst_remaining_size *= src_dim->full_size(); + dst_remaining_size *= src_dim->full_count(); } while (dst_remaining_size > 1) { CHECK(dst_dim_it != dst_dim_end); @@ -496,8 +496,8 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( } continue; } - if (dst_remaining_size >= src_dim->full_size()) { - if (dst_remaining_size % src_dim->full_size()) { + if (dst_remaining_size >= src_dim->full_count()) { + if (dst_remaining_size % src_dim->full_count()) { return "Unsupported bitcast"; } // Source dimension fragment completely fits into the destination one: @@ -505,12 +505,12 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( add_new_fragment(*src_dim); // Update the size of the remaining part of the destination that is // carried over to next source dimensions. - dst_remaining_size /= src_dim->full_size(); + dst_remaining_size /= src_dim->full_count(); } else { // Source is larger than destination. // Assign further destination dimensions. // Size of the not yet assigned part of the source dimension. - int64_t src_remaining_size = src_dim->full_size(); + int64_t src_remaining_size = src_dim->full_count(); // Handle dimension splits. if (dst_remaining_size > 1) { // If there is a remaining fragment of a previous destination dimension @@ -617,7 +617,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( std::vector subdim_group; do { CHECK(src_fragment_it != src_fragments_order.end()); - subdim_size_accumulator *= src_fragment_it->full_size(); + subdim_size_accumulator *= src_fragment_it->full_count(); subdim_group.push_back(&*src_fragment_it); ++src_fragment_it; } while (subdim_size_accumulator < dim_size); @@ -690,7 +690,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { return FusionDecision("Unsupported concatenation."); } - dst_logical[i][0]->set_size(dst->shape().dimensions(i)); + dst_logical[i][0]->set_count(dst->shape().dimensions(i)); dst_logical[i][0]->set_slice(0, dst->shape().dimensions(i)); } } @@ -728,7 +728,8 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( new_fragments.emplace_back( fragments[0]->dst_dim_number(), - fragments[0]->full_size() * fragments[1]->full_size() - padding); + fragments[0]->full_count() * fragments[1]->full_count() - + padding); dst_logical[i] = {&new_fragments.back()}; } } @@ -743,12 +744,11 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( return FusionDecision("Slicing of fragmented dimension."); } auto fragment = dst_logical[dim].front(); - fragment->set_size(dst->shape().dimensions(dim)); + fragment->set_count(dst->shape().dimensions(dim)); // Slicing of an already sliced dimension means adding offsets. 
fragment->set_slice( fragment->slice_start() + slice->slice_starts(dim), - fragment->slice_start() + slice->slice_starts(dim) + - fragment->sliced_size()); + fragment->sliced_count()); } } } else { @@ -777,7 +777,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( const auto it = src_to_dst.find(&src_fragments_order[fragment_number]); if (it == src_to_dst.cend()) { if (hlo.opcode() == HloOpcode::kBroadcast && - src_fragments_order[fragment_number].full_size() > 1 && + src_fragments_order[fragment_number].full_count() > 1 && dim_numbers_present_in_dst.contains(dim_index)) { return FusionDecision("Unsupported broadcast"); } @@ -882,7 +882,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, 0; })) { return FusionDecision( - "One or more operands of concatenation can not be perfectly tiled."); + "At least one operand of concatenation can not be perfectly tiled."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h index 69485b909684cb..08445f73b76080 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h @@ -36,17 +36,18 @@ namespace gpu { class TensorIterationSpec { public: - // Description of basic iteration: `count` elements separated by `stride`. + // Description of basic iteration: `count` elements separated by `stride` + // with initial offset of `slice_start` and only `sliced_count` elements used. struct IterationSpecFragment { int64_t stride; int64_t count; int64_t slice_start; - int64_t slice_limit; + int64_t sliced_count; // Logical subfragments when this iteration is composed // of several HLO dimensions. std::vector subfragments; - bool is_sliced() const { return count != slice_limit - slice_start; } + bool is_sliced() const { return count != sliced_count; } bool operator!=(const IterationSpecFragment& other) const; std::string ToString() const; }; @@ -103,35 +104,34 @@ class DimensionOrder { // Description of a continuous fragment of one dimension of a tensor. class Fragment { public: - explicit Fragment(int dst_dim_number, int64_t size) + explicit Fragment(int dst_dim_number, int64_t count) : dst_dim_number_(dst_dim_number), - size_(size), + count_(count), slice_start_(0), - slice_limit_(size) {} + sliced_count_(count) {} std::string ToString() const; // Label carrying the dimension number of an defining operation. int dst_dim_number() const { return dst_dim_number_; } // Total number of elements in the fragment ignoring slicing. - int64_t full_size() const { return size_; } + int64_t full_count() const { return count_; } // First used element. int64_t slice_start() const { return slice_start_; } - // Last used element. - int64_t slice_limit() const { return slice_limit_; } - int64_t sliced_size() const { return slice_limit_ - slice_start_; } - bool is_sliced() const { return full_size() != sliced_size(); } - void set_slice(int64_t start, int64_t limit) { + // Number of used elements. 
+ int64_t sliced_count() const { return sliced_count_; } + bool is_sliced() const { return count_ != sliced_count_; } + void set_slice(int64_t start, int64_t count) { slice_start_ = start; - slice_limit_ = limit; + sliced_count_ = count; } - void set_size(int64_t size) { size_ = size; } + void set_count(int64_t count) { count_ = count; } private: const int dst_dim_number_; - int64_t size_; + int64_t count_; int64_t slice_start_; - int64_t slice_limit_; + int64_t sliced_count_; }; using Fragments = std::vector; using FragmentOrders = absl::flat_hash_map>; From cc7d75aa018ee5d7a8667ee8aba916df426064a4 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Wed, 29 Nov 2023 09:29:42 -0800 Subject: [PATCH 182/381] [XLA] Lower verbosity in MSA. PiperOrigin-RevId: 586366293 --- .../best_fit_repacker.cc | 17 ++++++++-------- .../memory_space_assignment.cc | 20 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc index aa129e88bb127f..7d2aa7af631961 100644 --- a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc +++ b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc @@ -246,7 +246,7 @@ class BestFitRepacker CHECK_EQ(allocation_blocks_.size(), full_buffer_interval_map_.size()); CHECK_EQ(allocation_blocks_.size(), sliced_buffer_interval_map_.size()); - VLOG(1) << [&]() -> std::string { + VLOG(2) << [&]() -> std::string { int sliced_blocks = 0; int colocation_sets = 0; int colocation_sets_with_multiple_sliced_blocks = 0; @@ -323,7 +323,7 @@ class BestFitRepacker // - chunks is sorted in slice time order void CommitChunks(const AllocationBlock* allocation_block, const std::vector& chunks) { - VLOG(2) << "Committing repack chunks for " << allocation_block->ToString(); + VLOG(3) << "Committing repack chunks for " << allocation_block->ToString(); int64_t new_offset = -1; std::optional repacked_slice_data = std::nullopt; @@ -345,7 +345,7 @@ class BestFitRepacker const Chunk& chunk = chunks[i]; int64_t start_time = sorted_inclusive_start_times[i]; result_.heap_size = result_.UpdatedHeapSize(chunk); - VLOG(2) << "Adding sliced chunk " << chunk.ToString() << " at [" + VLOG(3) << "Adding sliced chunk " << chunk.ToString() << " at [" << start_time << ", " << allocation_block->end_time << "]"; interval_tree_.Add(start_time, allocation_block->end_time, chunk); new_offset = (new_offset == -1 ? 
chunk.offset @@ -361,7 +361,7 @@ class BestFitRepacker CHECK_EQ(chunks.size(), 1); new_offset = chunks.front().offset; result_.heap_size = result_.UpdatedHeapSize(chunks.front()); - VLOG(2) << "Adding unsliced chunk " << chunks.front().ToString() + VLOG(3) << "Adding unsliced chunk " << chunks.front().ToString() << " at [" << allocation_block->inclusive_start_time << ", " << allocation_block->end_time << ")"; interval_tree_.Add(allocation_block->inclusive_start_time, @@ -555,8 +555,7 @@ class BestFitRepacker Finish(); bool success = result_.heap_size <= max_size_; if (!success) { - LOG(INFO) << "Repacking unsuccessful with heap size " - << result_.heap_size; + VLOG(1) << "Repacking unsuccessful with heap size " << result_.heap_size; return false; } @@ -576,13 +575,13 @@ class BestFitRepacker DebuggingValidate(); } - if (VLOG_IS_ON(1)) { + if (VLOG_IS_ON(2)) { for (AllocationBlock* block : allocation_blocks_) { - VLOG(1) << "AllocationBlock after repacking: " << block->ToString(); + VLOG(2) << "AllocationBlock after repacking: " << block->ToString(); } } - LOG(INFO) << "Repacking successful with heap size " << result_.heap_size; + VLOG(1) << "Repacking successful with heap size " << result_.heap_size; return true; } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index 66ea87342f95f1..85196bb5225823 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -6024,10 +6024,10 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( "CopyAllocations or SlicedCopyAllocations."; } if (prefetch_time != *request.preferred_prefetch_time) { - LOG(WARNING) << "Scheduled prefetch time (" << prefetch_time - << ") doesn't match the preferred prefetch time (" - << *request.preferred_prefetch_time - << "): " << request.use->hlo_use.ToString(); + VLOG(1) << "Scheduled prefetch time (" << prefetch_time + << ") doesn't match the preferred prefetch time (" + << *request.preferred_prefetch_time + << "): " << request.use->hlo_use.ToString(); } } return Result::kSuccess; @@ -6035,10 +6035,10 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( // Warn if there was a preferred prefetch time but we couldn't actually // prefetch. if (request.preferred_prefetch_time) { - LOG(WARNING) << "The request has a preferred prefetch time (" - << *request.preferred_prefetch_time - << ") which could not be satisfied: " - << request.use->hlo_use.ToString(); + VLOG(1) << "The request has a preferred prefetch time (" + << *request.preferred_prefetch_time + << ") which could not be satisfied: " + << request.use->hlo_use.ToString(); } result_mark(prefetch_result, allocation_result); } @@ -6334,8 +6334,8 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( return Result::kSuccess; } if (request.prefer_no_copy_alternate_mem_allocation) { - LOG(WARNING) << "Preferred no-copy allocation, but this was not possible: " - << request.use->hlo_use.ToString(); + VLOG(1) << "Preferred no-copy allocation, but this was not possible: " + << request.use->hlo_use.ToString(); } return Result::kFailOutOfMemory; } From e19035fd91fc9f0ab00575e32f3762d96f5d480b Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 29 Nov 2023 09:40:39 -0800 Subject: [PATCH 183/381] Record free_gpu_system_memory using addressable_devices() to avoid potential OOM issues across multiple processes. PiperOrigin-RevId: 586369191 --- third_party/xla/xla/pjrt/BUILD | 3 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 17 +++++++++++++- third_party/xla/xla/pjrt/metrics.cc | 23 ++++--------------- third_party/xla/xla/pjrt/metrics.h | 4 +++- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 59dc48817a1783..dd1e74b9528ce7 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -380,10 +380,7 @@ cc_library( hdrs = ["metrics.h"], visibility = ["//visibility:public"], deps = [ - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_init", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@local_tsl//tsl/lib/monitoring:counter", "@local_tsl//tsl/lib/monitoring:gauge", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index ad450b5ee94990..7070a25f5c67f9 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -537,7 +537,22 @@ StreamExecutorGpuClient::Compile(const XlaComputation& computation, auto executable = PjRtStreamExecutorClient::Compile(computation, options); #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) - metrics::RecordFreeGpuSystemMemory(); + for (const PjRtDevice* device : addressable_devices()) { + LocalDeviceState* local_device_state = + tensorflow::down_cast(device) + ->local_device_state(); + int64_t free_memory, total_memory; + if (local_device_state != nullptr) { + se::StreamExecutor* executor = local_device_state->executor(); + int device_ordinal = executor->device_ordinal(); + if (executor->DeviceMemoryUsage(&free_memory, &total_memory)) { + metrics::RecordFreeGpuSystemMemory(device_ordinal, free_memory); + } else { + LOG(ERROR) << "Failed to query available memory for GPU " + << device_ordinal; + } + } + } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM return executable; } diff --git a/third_party/xla/xla/pjrt/metrics.cc b/third_party/xla/xla/pjrt/metrics.cc index 06aaa18121c115..1a17cc6590e08d 100644 --- a/third_party/xla/xla/pjrt/metrics.cc +++ b/third_party/xla/xla/pjrt/metrics.cc @@ -17,11 +17,7 @@ limitations under the License. 
#include -#include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" #include "tsl/lib/monitoring/counter.h" #include "tsl/lib/monitoring/gauge.h" @@ -73,21 +69,10 @@ void RecordPjrtCompilerCompileModuleStatus(bool is_compiling) { pjrt_compiler_is_compiling_module->GetCell()->Set(is_compiling); } -void RecordFreeGpuSystemMemory() { - tensorflow::se::Platform* gpu_manager = tensorflow::se::GPUMachineManager(); - int visible_device_count = gpu_manager->VisibleDeviceCount(); - if (gpu_manager == nullptr || visible_device_count <= 0) return; - - for (int i = 0; i < visible_device_count; ++i) { - tensorflow::se::StreamExecutor* se = - gpu_manager->ExecutorForDevice(i).value(); - int64_t free_memory = 0, total_memory = 0; - if (se->DeviceMemoryUsage(&free_memory, &total_memory)) { - free_gpu_system_memory->GetCell(absl::StrCat(i))->Set(free_memory); - } else { - LOG(ERROR) << "Failed to query available memory for GPU " << i; - } - } +void RecordFreeGpuSystemMemory(const int device_ordinal, + const int64_t free_memory) { + free_gpu_system_memory->GetCell(absl::StrCat(device_ordinal)) + ->Set(free_memory); } int64_t GetFreeGpuSystemMemory(int gpu_id) { diff --git a/third_party/xla/xla/pjrt/metrics.h b/third_party/xla/xla/pjrt/metrics.h index 05c618e0e50ceb..870e473aac6a89 100644 --- a/third_party/xla/xla/pjrt/metrics.h +++ b/third_party/xla/xla/pjrt/metrics.h @@ -38,8 +38,10 @@ void RecordPjrtCompilerCompileComputationStatus(bool is_compiling); void RecordPjrtCompilerCompileModuleStatus(bool is_compiling); -void RecordFreeGpuSystemMemory(); +// TODO(xiangll): Refactor to a more appropriate location. +void RecordFreeGpuSystemMemory(int device_ordinal, int64_t free_memory); +// TODO(xiangll): Refactor to a more appropriate location. int64_t GetFreeGpuSystemMemory(int gpu_id); } // namespace metrics From 755370ccb4c06e51f2cc3fac6bc6dcbdf65f4b9f Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Wed, 29 Nov 2023 10:11:17 -0800 Subject: [PATCH 184/381] Limit "single source file per target" presubmit to check only python targets. Add path exclusions for certain directories. PiperOrigin-RevId: 586378455 --- tensorflow/build_cleaner_spec.textproto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/build_cleaner_spec.textproto b/tensorflow/build_cleaner_spec.textproto index 4c593a0ec4fd4c..bea7e8ac36462a 100644 --- a/tensorflow/build_cleaner_spec.textproto +++ b/tensorflow/build_cleaner_spec.textproto @@ -1,12 +1,12 @@ # proto-file: devtools/build_cleaner/proto/actions.proto # proto-message: ActionSpecs -# Rules (except for the allowlist) should not have more than one source file. +# Python rules should not have more than one source file. 
action_spec { action: CHECK_FILE_COUNT file_count_params { rule_selector { - rule_kind_regex: "^(?!filegroup|genrule|_policy_filegroup).*$" + rule_kind_regex: "^.*py(type)?(_strict)?_(binary|library|test).*$" generator_function_regex: "^(?!boq_header)$" } max_source_count: 1 From 890ed36c68f8df2a3823c3dcfc00c696eabd3fcb Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Wed, 29 Nov 2023 10:13:12 -0800 Subject: [PATCH 185/381] Integrate LLVM at llvm/llvm-project@3287ae8f6520 Updates LLVM usage to match [3287ae8f6520](https://github.com/llvm/llvm-project/commit/3287ae8f6520) PiperOrigin-RevId: 586379104 --- third_party/llvm/generated.patch | 48 +++++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- .../xla/xla/translate/hlo_to_mhlo/hlo_utils.h | 12 ++--- .../translate/mhlo_to_hlo/type_to_shape.cc | 2 +- 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..51af6376f76b2e 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,49 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/compiler-rt/lib/msan/msan_interceptors.cpp b/compiler-rt/lib/msan/msan_interceptors.cpp +--- a/compiler-rt/lib/msan/msan_interceptors.cpp ++++ b/compiler-rt/lib/msan/msan_interceptors.cpp +@@ -244,20 +244,23 @@ + #endif + + #if !SANITIZER_FREEBSD && !SANITIZER_NETBSD +-// This function actually returns a struct by value, but we can't unpoison a +-// temporary! The following is equivalent on all supported platforms but +-// aarch64 (which uses a different register for sret value). We have a test +-// to confirm that. +-INTERCEPTOR(void, mallinfo, __sanitizer_struct_mallinfo *sret) { +-#ifdef __aarch64__ +- uptr r8; +- asm volatile("mov %0,x8" : "=r" (r8)); +- sret = reinterpret_cast<__sanitizer_struct_mallinfo*>(r8); +-#endif +- REAL(memset)(sret, 0, sizeof(*sret)); ++ ++template ++static NOINLINE void clear_mallinfo(T *sret) { ++ ENSURE_MSAN_INITED(); ++ internal_memset(sret, 0, sizeof(*sret)); + __msan_unpoison(sret, sizeof(*sret)); + } +-#define MSAN_MAYBE_INTERCEPT_MALLINFO INTERCEPT_FUNCTION(mallinfo) ++ ++// Interceptor relies on NRVO and assumes that sret will be pre-allocated in ++// caller frame. 
++INTERCEPTOR(__sanitizer_struct_mallinfo, mallinfo) { ++ __sanitizer_struct_mallinfo sret; ++ clear_mallinfo(&sret); ++ return sret; ++} ++ ++# define MSAN_MAYBE_INTERCEPT_MALLINFO INTERCEPT_FUNCTION(mallinfo) + #else + #define MSAN_MAYBE_INTERCEPT_MALLINFO + #endif +diff -ruN --strip-trailing-cr a/compiler-rt/test/msan/Linux/mallinfo.cpp b/compiler-rt/test/msan/Linux/mallinfo.cpp +--- a/compiler-rt/test/msan/Linux/mallinfo.cpp ++++ b/compiler-rt/test/msan/Linux/mallinfo.cpp +@@ -1,5 +1,4 @@ + // RUN: %clangxx_msan -O0 -g %s -o %t && %run %t +-// UNSUPPORTED: aarch64-target-arch + + #include + #include diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index ad08dfe0605b8e..e5364bd9ec0ab3 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "5e5a22caf88ac1ccfa8dc5720295fdeba0ad9372" - LLVM_SHA256 = "9d9ae8ae30f6262ca0823493893398ea2ab6fbd49027e338e06ac7c25bb8caf4" + LLVM_COMMIT = "3287ae8f6520ef81570377c1fb4c7147782a13ef" + LLVM_SHA256 = "87c55be01fb53ab0f2ce03bca419a5b08393247c28ecc8b1facd78bd3e7614da" tf_http_archive( name = name, diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index 098a682c82045c..b5f44584f1abc5 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -91,7 +91,7 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, if (is_bounded_dynamic) return Unimplemented( "MHLO doesn't support bounded dynamic shapes for sparse tensors"); - llvm::SmallVector dlts; + llvm::SmallVector lts; for (size_t i = 0, e = layout.dim_level_types().size(); i < e; ++i) { auto dlt = layout.dim_level_types()[i]; bool ordered = @@ -100,19 +100,19 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, i < layout.dim_unique().size() ? 
layout.dim_unique()[i] : true; switch (dlt) { case DimLevelType::DIM_DENSE: - dlts.push_back(*mlir::sparse_tensor::buildLevelType( + lts.push_back(*mlir::sparse_tensor::buildLevelType( mlir::sparse_tensor::LevelFormat::Dense, ordered, unique)); break; case DimLevelType::DIM_COMPRESSED: - dlts.push_back(*mlir::sparse_tensor::buildLevelType( + lts.push_back(*mlir::sparse_tensor::buildLevelType( mlir::sparse_tensor::LevelFormat::Compressed, ordered, unique)); break; case DimLevelType::DIM_SINGLETON: - dlts.push_back(*mlir::sparse_tensor::buildLevelType( + lts.push_back(*mlir::sparse_tensor::buildLevelType( mlir::sparse_tensor::LevelFormat::Singleton, ordered, unique)); break; case DimLevelType::DIM_LOOSE_COMPRESSED: - dlts.push_back(*mlir::sparse_tensor::buildLevelType( + lts.push_back(*mlir::sparse_tensor::buildLevelType( mlir::sparse_tensor::LevelFormat::LooseCompressed, ordered, unique)); break; @@ -127,7 +127,7 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, builder.getContext()); // TODO(atondwal): support sizes other than 32 when XLA does encoding = SparseTensorEncodingAttr::get( - builder.getContext(), dlts, id_map, mlir::AffineMap(), 32, 32); + builder.getContext(), lts, id_map, mlir::AffineMap(), 32, 32); } } return TypeT::get(shape, element_type_or.value(), encoding); diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc index d5709b91b80378..e10d8b7cff00f8 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -82,7 +82,7 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { } std::optional> ConvertDimLevelType( - mlir::sparse_tensor::DimLevelType lt) { + mlir::sparse_tensor::LevelType lt) { auto f = mlir::sparse_tensor::getLevelFormat(lt); if (!f) return std::nullopt; From 23b87eb7c8cc0bfc417ae09d14ad575f445fa3a5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 29 Nov 2023 10:13:49 -0800 Subject: [PATCH 186/381] [XLA:CPU] Add a direct implementation of AllGather, rather than lowering AllGather to AllReduce. 
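For reference, a minimal sketch of the semantics the new in-process rendezvous provides (the free-standing helper and its names below are illustrative only, not part of this change): every participant contributes one contiguous chunk, and each participant's output buffer ends up holding all chunks concatenated in rank order.

  #include <cstddef>
  #include <cstring>
  #include <vector>

  // Reference all-gather: inputs[r] is rank r's chunk of `chunk_bytes` bytes;
  // outputs[r] must hold inputs.size() * chunk_bytes bytes and receives every
  // rank's chunk, in rank order.
  void ReferenceAllGather(const std::vector<const void*>& inputs,
                          const std::vector<void*>& outputs,
                          std::size_t chunk_bytes) {
    for (std::size_t rank = 0; rank < outputs.size(); ++rank) {
      char* out = static_cast<char*>(outputs[rank]);
      for (std::size_t src = 0; src < inputs.size(); ++src, out += chunk_bytes) {
        std::memcpy(out, inputs[src], chunk_bytes);
      }
    }
  }

This mirrors the memcpy loop in CpuAllGatherRendezvous below; in addition, the layout-assignment change constrains the gather dimension to be the most major dimension in the layout.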
PiperOrigin-RevId: 586379273 --- third_party/xla/xla/service/cpu/BUILD | 3 +- .../xla/service/cpu/collectives_interface.h | 9 ++- .../xla/xla/service/cpu/cpu_compiler.cc | 2 - .../xla/service/cpu/cpu_layout_assignment.cc | 10 ++++ .../xla/xla/service/cpu/cpu_runtime.cc | 39 ++++++++++++ third_party/xla/xla/service/cpu/cpu_runtime.h | 7 +++ .../xla/service/cpu/in_process_collectives.cc | 59 +++++++++++++++++++ .../xla/service/cpu/in_process_collectives.h | 4 ++ third_party/xla/xla/service/cpu/ir_emitter.cc | 46 +++++++++++++++ third_party/xla/xla/service/cpu/ir_emitter.h | 1 + .../xla/xla/service/cpu/simple_orc_jit.cc | 1 + 11 files changed, 176 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index ef51b94dcd80ad..44dbf2f4db410d 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -250,7 +250,6 @@ cc_library( "//xla/runtime:executable", "//xla/runtime:jit_executable", "//xla/service:algebraic_simplifier", - "//xla/service:all_gather_decomposer", "//xla/service:all_reduce_promotion", "//xla/service:all_to_all_decomposer", "//xla/service:batch_dot_simplification", @@ -1351,7 +1350,9 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features", + "//xla:shape_util", "//xla:util", + "//xla/hlo/ir:hlo", "//xla/service:computation_layout", "//xla/service:layout_assignment", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h index bd518db3a780bc..4191df1d831fa2 100644 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -59,9 +59,14 @@ class CollectivesCommunicator { // The all-to-all chunks are passed separately and do not have to be // contiguous in memory. virtual absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffer, - absl::Span output_buffer, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) = 0; + + // Performs an all-gather. + virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index ae2a2cd5672ae6..e2723c77c18a7b 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -115,7 +115,6 @@ limitations under the License. #include "xla/runtime/executable.h" #include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" -#include "xla/service/all_gather_decomposer.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_to_all_decomposer.h" #include "xla/service/batch_dot_simplification.h" @@ -685,7 +684,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index 8b124ddaa60397..9b2b8331e2d3fd 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -18,9 +18,12 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/map_util.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" +#include "xla/shape_util.h" #include "tsl/platform/errors.h" namespace xla { @@ -126,6 +129,13 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); + } else if (instruction->opcode() == HloOpcode::kAllGather) { + // XLA:CPU can only support all-gathers where the gather dimension is the + // most major dimension in the layout. + auto ag = Cast(instruction); + TF_RETURN_IF_ERROR(SetInstructionLayout( + ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()), + ag)); } else { for (int64_t operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index a4090d67fc8d9b..81bd76652a3ac6 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -142,6 +142,7 @@ extern const char* const kTracingStartSymbolName = extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; +extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -345,6 +346,34 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void AllGatherImpl(const ExecutableRunOptions* run_options, + int32_t channel_id_present, int64_t op_id, + const void* replica_groups_str, + int32_t replica_groups_str_size, int64_t buffer_size, + void* source_buffer, void* destination_buffer) { + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + ParseReplicaGroupsOnly(replica_groups_serialized).value(); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + /*use_global_device_ids=*/std::nullopt, op_id); + + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); + + CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->AllGather(rendezvous_key, buffer_size, + source_buffer, destination_buffer, + DefaultCollectiveTimeout())); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -503,6 +532,16 @@ void __xla_cpu_runtime_AllToAll(const xla::ExecutableRunOptions* run_options, destination_buffers); } +void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, + int32_t channel_id_present, int64_t op_id, + const void* replica_groups_str, + int32_t replica_groups_str_size, + int64_t 
buffer_size, void* source_buffer, + void* destination_buffer) { + return xla::cpu::runtime::AllGatherImpl( + run_options, channel_id_present, op_id, replica_groups_str, + replica_groups_str_size, buffer_size, source_buffer, destination_buffer); +} void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 361a116e7c300f..9429242d5f1b86 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -84,6 +84,7 @@ extern const char* const kReplicaIdSymbolName; extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; +extern const char* const kAllGatherSymbolName; extern const char* const kOneDnnMatMulSymbolName; // All symbol names for XLA CPU runtime functions need to start with this @@ -195,6 +196,12 @@ extern void __xla_cpu_runtime_AllToAll( int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size, void** source_buffers, void** destination_buffers); +extern void __xla_cpu_runtime_AllGather( + const xla::ExecutableRunOptions* run_options, int32_t channel_id_present, + int64_t op_id, const void* replica_groups_str, + int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer, + void* destination_buffer); + // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index fc13c6c3870377..78eee1aa74f731 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -340,6 +340,44 @@ class CpuAllToAllRendezvous } }; +struct AllGatherParticipantData : ParticipantData { + AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + const void* source_buffer; + void* destination_buffer; + size_t chunk_size; + + std::string ToString() const override { + return absl::StrFormat( + "AllGatherParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_size=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_size); + } +}; + +class CpuAllGatherRendezvous + : public Rendezvous { + public: + explicit CpuAllGatherRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + absl::StatusOr RunCollectiveOp( + const AllGatherParticipantData& p) override { + int world_size = p.rendezvous_key.global_devices.size(); + char* out = static_cast(p.destination_buffer); + for (int i = 0; i < world_size; ++i, out += p.chunk_size) { + std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); + } + return nullptr; + } +}; + } // namespace struct InProcessCollectivesState { @@ -349,6 +387,8 @@ struct InProcessCollectivesState { collective_permute_rendezvous_map; RefcountingHashMap all_to_all_rendezvous_map; + RefcountingHashMap + all_gather_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ -429,6 +469,25 @@ absl::Status InProcessCollectivesCommunicator::AllToAll( .status(); } 
+absl::Status InProcessCollectivesCommunicator::AllGather( + const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + AllGatherParticipantData participant(key, rank_); + participant.chunk_size = chunk_bytes; + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuAllGatherRendezvous::SubmitParticipant( + [&] { + return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h index fb25fd3528d606..aaedc474fa39b2 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -55,6 +55,10 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { absl::Span output_buffers, absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + private: InProcessCollectivesState* state_; int rank_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 18523dec844113..f5d7a4c2c40fab 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -1344,6 +1344,52 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { return OkStatus(); } +Status IrEmitter::HandleAllGather(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); + + std::string replica_groups = + ReplicaGroupsToString(instruction->replica_groups()); + int32_t replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + std::vector input_buffer_ptrs; + std::vector output_buffer_ptrs; + + const HloInstruction* op = instruction->operand(0); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice in_slice, + assignment_.GetUniqueSlice(op, {})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(instruction, {})); + const Shape& operand_shape = op->shape(); + CHECK(op->shape().IsArray()) + << "Operand to all-gather must be arrays: " << instruction->ToString(); + llvm::Value* output_buffer = EmitBufferPointer(out_slice, operand_shape); + llvm::Value* input_buffer = GetEmittedValueFor(op); + int64_t buffer_size = in_slice.size(); + + EmitCallToFunc( + runtime::kAllGatherSymbolName, + { + /*run_options=*/GetExecutableRunOptionsArgument(), + /*channel_id_present=*/ + b_.getInt32( + static_cast(instruction->channel_id().has_value())), + /*op_id=*/ + b_.getInt64(instruction->channel_id().has_value() + ? 
*instruction->channel_id() + : instruction->GetModule()->unique_id()), + /*replica_groups_str=*/replica_groups_v, + /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), + /*buffer_size=*/b_.getInt64(buffer_size), + /*source_buffer=*/input_buffer, + /*destination_buffer=*/output_buffer, + }, + b_.getVoidTy()); + + llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_); + return OkStatus(); +} + Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast(crs); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr)); diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 3a194d054cb5fd..a0e31671ab4375 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -134,6 +134,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // special in some way are handled explicitly in HandleFoo methods. Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllGather(HloInstruction* instruction) override; Status HandleAllToAll(HloInstruction* instruction) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant) override; diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 8895b4f6451d5a..2e27a7c810869d 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -485,6 +485,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AllReduce); REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); + REGISTER_CPU_RUNTIME_SYMBOL(AllGather); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); From 4d6482958fd624b500600674275697fb052656d5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 10:17:31 -0800 Subject: [PATCH 187/381] User sharding annotations can sometimes be invalid wrt to the shape of the instruction. When this happens, and a consumer instruction generates sharding startegies for itself based on the strategies for such an operand with an invalid sharding annotation, we don't want to drop that strategy if only one exists. 
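A self-contained toy illustration of the guard being added (toy types below, not the real StrategyGroup/ShardingStrategy classes): filtering the generated strategies by an existing sharding annotation must not leave the instruction with zero strategies when that annotation matches nothing.

  #include <string>
  #include <utility>
  #include <vector>

  struct ToyStrategy {
    std::string output_sharding;
  };

  // Keep only the strategies matching `existing_sharding`, but fall back to
  // the original list when nothing matches (e.g. the annotation is invalid
  // for the instruction's shape), so a strategy can still be synthesized.
  void KeepMatchingStrategiesIfAny(std::vector<ToyStrategy>& strategies,
                                   const std::string& existing_sharding) {
    std::vector<ToyStrategy> matching;
    for (const ToyStrategy& s : strategies) {
      if (s.output_sharding == existing_sharding) matching.push_back(s);
    }
    if (!matching.empty()) strategies = std::move(matching);
  }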
PiperOrigin-RevId: 586380530 --- .../auto_sharding/auto_sharding.cc | 7 +++-- .../auto_sharding/auto_sharding_test.cc | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index b9fca89d5770c7..b77b1a134b8b80 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1457,7 +1457,8 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, for (int64_t i = 0; i < shape.rank(); ++i) { if (tile_assignment.dim(i) > 1 && tile_assignment.dim(i) > shape.dimensions(i)) { - VLOG(1) << "Removing invalid strategy: " << strategy.ToString(); + VLOG(1) << "May remove invalid strategy if valid ones exist: " + << strategy.ToString(); is_strategy_valid = false; break; } @@ -1466,7 +1467,9 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, new_vector.push_back(strategy); } } - strategy_group->strategies = std::move(new_vector); + if (!new_vector.empty()) { + strategy_group->strategies = std::move(new_vector); + } } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 7b772c51bd0c34..88700280b6b3cd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1561,6 +1561,32 @@ ENTRY %entry { EXPECT_TRUE(changed); } +TEST_F(AutoShardingTest, ReshapeWithInvalidUserSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param.0 = bf16[24,16,16]{2,1,0} parameter(0), sharding={devices=[32,1,1]<=[32]} + %reshape = bf16[1,24,16,16]{3,2,1,0} reshape(%param.0) + %copy = bf16[1,24,16,16]{3,2,1,0} copy(%reshape) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {32, 1}; + option.device_mesh_ids.resize(32); + std::iota(option.device_mesh_ids.begin(), option.device_mesh_ids.end(), 0); + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + HloInstruction* reshape = FindInstruction(module.get(), "reshape"); + EXPECT_THAT(reshape, op::Sharding("{devices=[1,32,1,1]<=[32]}")); +} + TEST_F(AutoShardingTest, Broadcast) { const char* const hlo_string = R"( HloModule module From cc27487b56781c794395704b0284a7c00c0f177d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 10:22:39 -0800 Subject: [PATCH 188/381] Some cleanup: Inline a function used once, remove some dead code, and compactify some code. 
PiperOrigin-RevId: 586382339 --- .../auto_sharding/auto_sharding.cc | 59 +++++-------------- 1 file changed, 16 insertions(+), 43 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index b77b1a134b8b80..de3b237510c236 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1217,18 +1217,6 @@ StatusOr> CreateAllStrategiesGroup( return strategy_group; } -StatusOr> CreateParameterStrategyGroup( - const HloInstruction* ins, const Shape& shape, size_t instruction_id, - StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, const AutoShardingOption& option, - double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const CallGraph& call_graph, bool only_allow_divisible) { - return CreateAllStrategiesGroup( - ins, shape, instruction_id, strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, call_graph, - only_allow_divisible, option.allow_replicated_parameters); -} - // The sharding is replicated or the total number of tiles is over or equal to // the total number of devices. If returns true, this sharding is likely // provided by users. @@ -1290,26 +1278,20 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( cluster_env.device_mesh_.num_elements())) { // Sharding provided by XLA users, we need to keep them. strategy_group->following = nullptr; - std::vector strategy_indices; + std::vector new_strategies; for (size_t i = 0; i < strategy_group->strategies.size(); i++) { if (strategy_group->strategies[i].output_sharding == existing_sharding) { - strategy_indices.push_back(i); + VLOG(1) << "Keeping strategy index: " << i; + ShardingStrategy found_strategy = strategy_group->strategies[i]; + new_strategies.push_back(found_strategy); } } - if (!strategy_indices.empty()) { - VLOG(1) << "Keeping strategy indices: " - << spmd::ToString(strategy_indices); + if (!new_strategies.empty()) { // Stores other strategies in the map, removes them in the vector and // only keeps the one we found. 
pretrimmed_strategy_map[strategy_group->node_idx] = strategy_group->strategies; - std::vector new_strategies; - for (int32_t found_strategy_index : strategy_indices) { - ShardingStrategy found_strategy = - strategy_group->strategies[found_strategy_index]; - new_strategies.push_back(found_strategy); - } strategy_group->strategies.clear(); strategy_group->strategies = new_strategies; } else { @@ -1317,9 +1299,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( std::string name = ToStringSimple(existing_sharding); std::vector> resharding_costs; std::vector> input_shardings; - if (strategy_group->in_nodes.empty()) { - resharding_costs = {}; - } else { + if (!strategy_group->in_nodes.empty()) { HloInstruction* ins = instructions.at(strategy_group->instruction_id); for (size_t i = 0; i < strategy_group->in_nodes.size(); i++) { HloInstruction* operand = @@ -1331,15 +1311,13 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( input_shardings.push_back(input_sharding_or.value()); } - StrategyGroup* operand_strategy_group; - Shape operand_shape; + StrategyGroup* operand_strategy_group = + strategy_map.at(operand).get(); + Shape operand_shape = operand->shape(); if (ins->opcode() == HloOpcode::kGetTupleElement) { operand_strategy_group = - strategy_map.at(operand)->childs[ins->tuple_index()].get(); - operand_shape = operand->shape().tuple_shapes(ins->tuple_index()); - } else { - operand_strategy_group = strategy_map.at(operand).get(); - operand_shape = operand->shape(); + operand_strategy_group->childs[ins->tuple_index()].get(); + operand_shape = operand_shape.tuple_shapes(ins->tuple_index()); } resharding_costs.push_back( ReshardingCostVector(operand_strategy_group, operand_shape, @@ -1770,10 +1748,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kRngBitGenerator: case HloOpcode::kRng: { strategy_group = - CreateParameterStrategyGroup( - ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible) + CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, + strategy_groups, cluster_env, strategy_map, + option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + option.allow_replicated_parameters) .value(); break; } @@ -1905,11 +1884,6 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - const HloInstruction* operand = ins->operand(0); - - const StrategyGroup* operand_strategies = - strategy_map.at(operand).get(); - CHECK(!operand_strategies->is_tuple); if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, cluster_env, strategy_map, strategy_group, @@ -2364,7 +2338,6 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); } - break; } case HloOpcode::kConditional: From 7d4df8994e12d12b673ee0a17df101c755437579 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 29 Nov 2023 10:52:56 -0800 Subject: [PATCH 189/381] Reenable `BitcastConvert` tests in JAX CI now that the tests pass again PiperOrigin-RevId: 586392427 --- third_party/xla/.kokoro/jax/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index 5affd26e0f74bd..305c1806c106ea 100644 --- 
a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -87,7 +87,7 @@ build_and_test_on_rbe_gpu() { --test_output=errors \ --test_env=JAX_SKIP_SLOW_TESTS=1 \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ - --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow|LaxTest.testBitcastConvertType" \ + --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow" \ --test_tag_filters=-multiaccelerator \ -- //tests:gpu_tests //tests:backend_independent_tests } From 7f89ef60e10161db3061f2def13f9cb6fc5ccae0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 11:01:15 -0800 Subject: [PATCH 190/381] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/1bdc3ca2c91c108235a08d603a629970533649f9. PiperOrigin-RevId: 586395075 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index e501613187629a..40f4624639e1f4 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "58f2ec4dc891dc0bc0815a8c2d1caf196bfc13d5" - TFRT_SHA256 = "7525f0bb63fc3c0cf2df7ce1b09949510b42ffa669cc38ed83f6948a078d6633" + TFRT_COMMIT = "1bdc3ca2c91c108235a08d603a629970533649f9" + TFRT_SHA256 = "dd9c04c8907c217cefd8336908fcbc2adff00a3126fbc0f6af4182e22ce87395" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index e501613187629a..40f4624639e1f4 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "58f2ec4dc891dc0bc0815a8c2d1caf196bfc13d5" - TFRT_SHA256 = "7525f0bb63fc3c0cf2df7ce1b09949510b42ffa669cc38ed83f6948a078d6633" + TFRT_COMMIT = "1bdc3ca2c91c108235a08d603a629970533649f9" + TFRT_SHA256 = "dd9c04c8907c217cefd8336908fcbc2adff00a3126fbc0f6af4182e22ce87395" tf_http_archive( name = "tf_runtime", From 2d6d24d62b6eb1e44dae182984844342343a74f7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 11:03:28 -0800 Subject: [PATCH 191/381] Adds the ability to dump solver request protos. 
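A hedged sketch of how a dumped request could be read back for offline inspection. The include path, the xla::spmd namespace, and the standalone tool itself are assumptions based on this change, not an officially supported utility; the accessors used are the standard protobuf getters for fields the request already carries.

  #include <fstream>
  #include <iostream>

  // Assumed generated header for AutoShardingSolverRequest.
  #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h"

  int main(int argc, char** argv) {
    if (argc != 2) {
      std::cerr << "usage: inspect_request <solver_request_*.proto>\n";
      return 1;
    }
    xla::spmd::AutoShardingSolverRequest request;  // namespace is an assumption
    std::ifstream in(argv[1], std::ios::binary);
    if (!in || !request.ParseFromIstream(&in)) {
      std::cerr << "failed to parse " << argv[1] << "\n";
      return 1;
    }
    std::cout << "module: " << request.module_name() << "\n"
              << "nodes: " << request.num_nodes() << "\n"
              << "memory budget: " << request.memory_budget() << " bytes\n";
    return 0;
  }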
PiperOrigin-RevId: 586395795 --- .../experimental/auto_sharding/auto_sharding.cc | 10 ++++++---- .../experimental/auto_sharding/auto_sharding.h | 2 +- .../auto_sharding/auto_sharding.proto | 1 + .../auto_sharding/auto_sharding_impl.cc | 4 ++-- .../auto_sharding/auto_sharding_solver.cc | 16 +++++++++++++++- .../auto_sharding/auto_sharding_wrapper.h | 2 +- 6 files changed, 26 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index de3b237510c236..b2947dd85c750f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2479,7 +2479,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // NOLINTEND AutoShardingSolverResult CallSolver( - const HloLiveRange& hlo_live_range, + const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, @@ -2489,6 +2489,7 @@ AutoShardingSolverResult CallSolver( sharding_propagation_solution) { // Serialize edges and edge costs to 1d numpy arrays AutoShardingSolverRequest request; + request.set_module_name(hlo_module.name()); request.set_num_nodes(strategy_groups.size()); request.set_memory_budget(option.memory_budget_per_device); request.mutable_s_len()->Add(cost_graph.node_lens_.begin(), @@ -4415,9 +4416,10 @@ StatusOr AutoShardingImplementation::RunAutoSharding( std::vector e_val; double objective = -1.0; if (!option_.load_solution_vector) { - auto solver_result = Solve( - *hlo_live_range, liveness_node_set, strategy_map, strategy_groups, - cost_graph, alias_set, option_, sharding_propagation_solution); + auto solver_result = + Solve(*module, *hlo_live_range, liveness_node_set, strategy_map, + strategy_groups, cost_graph, alias_set, option_, + sharding_propagation_solution); if (solver_result.skip_auto_sharding) { return AutoShardingResult::kModuleUnchangedNoShardingPerfomed; } else if (!solver_result.status.ok()) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index fa1a19c8d34f3f..5af6ad3d35cc8e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -212,7 +212,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, // The high-level "recipe" for solving an Auto Sharding problem. 
AutoShardingSolverResult Solve( - const HloLiveRange& hlo_live_range, + const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const AutoShardingOption& option, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto index 5683b4a02f957b..095a3fce35e3e5 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.proto @@ -58,4 +58,5 @@ message AutoShardingSolverRequest { bool crash_at_infinity_costs_check = 21; bool compute_iis = 22; double saltiplier = 23; + string module_name = 24; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 546470f6ad7f25..21281269eb5350 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -31,13 +31,13 @@ namespace xla { namespace spmd { AutoShardingSolverResult Solve( - const HloLiveRange& hlo_live_range, + const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const AutoShardingOption& option, const absl::flat_hash_map& sharding_propagation_solution) { - return CallSolver(hlo_live_range, liveness_node_set, strategy_map, + return CallSolver(hlo_module, hlo_live_range, liveness_node_set, strategy_map, strategy_groups, cost_graph, alias_set, /*s_hint*/ {}, /*compute_iis*/ true, option.solver_timeout_in_seconds, option, sharding_propagation_solution); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 6b1790f8b78070..1847af9e87aed9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -495,6 +495,19 @@ AutoShardingSolverResult CallORToolsSolver( LOG(ERROR) << write_status.message(); } } + // Exports the solver request proto for debugging. + bool dump_solver_request = false; + if (dump_solver_request) { + uint64_t solver_request_fprint = + tsl::Fingerprint64(request.SerializeAsString()); + auto write_status = file::SetBinaryProto( + // Modify this file path if needed. 
+ absl::StrCat("/tmp/solver_request_", solver_request_fprint, ".proto"), + request, file::Defaults()); + if (!write_status.ok()) { + LOG(ERROR) << write_status.message(); + } + } #endif if (request.has_solver_timeout()) { solver->SetTimeLimit( @@ -510,7 +523,8 @@ AutoShardingSolverResult CallORToolsSolver( << "Total instructions: " << request.num_nodes() << "\n" << "Memory budget: " << request.memory_budget() / (1024 * 1024 * 1024) << "GB\n" - << "Number of ILP constraints: " << solver->NumConstraints(); + << "Number of ILP constraints: " << solver->NumConstraints() << "\n" + << "Module name: " << request.module_name(); return SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, *solver); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 0308d636d6835f..008a4184e00b96 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -36,7 +36,7 @@ namespace spmd { // A wrapper around the solver that converts the given objects into a // combinatorial optimization problem & solves it. AutoShardingSolverResult CallSolver( - const HloLiveRange& hlo_live_range, + const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, From 78ca91bbefbeaf2d2d36e9f241b79df56bb4923e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 11:36:06 -0800 Subject: [PATCH 192/381] Skip two tests in tensorflow/python/module:module_test, on Python 3.12. For details, see the thread below, particularly starting with the linked comment: https://github.com/GrahamDumpleton/wrapt/issues/231#issuecomment-1570187961 PiperOrigin-RevId: 586406063 --- tensorflow/python/module/module_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py index 64972f3f850768..bcfa84b14d1507 100644 --- a/tensorflow/python/module/module_test.py +++ b/tensorflow/python/module/module_test.py @@ -17,6 +17,8 @@ import abc import collections import itertools +import sys +import unittest from absl.testing import parameterized @@ -514,6 +516,8 @@ class DangerousModule(module.Module): self.assertLen(mod.variables, 1) self.assertEqual(mod.variables[0], mod.normal_variable) + @unittest.skipIf(sys.version_info.major == 3 and sys.version_info.minor == 12, + reason="b/313658911: _TupleWrapper __dict__ attribute error") def test_with_path(self): mod = module.Module() mod.w = variables.Variable(1.) @@ -531,6 +535,8 @@ def test_with_path(self): ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"], ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},) + @unittest.skipIf(sys.version_info.major == 3 and sys.version_info.minor == 12, + reason="b/313658911: _TupleWrapper __dict__ attribute error") def test_cycles_with_path(self): mod = module.Module() mod.w = variables.Variable(1.) From 254200866239d101b22b0347478de037f19ce350 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 29 Nov 2023 11:38:49 -0800 Subject: [PATCH 193/381] Fastpath for setting disable-jit. 
PiperOrigin-RevId: 586406799 --- third_party/xla/xla/python/jax_jit.cc | 10 ++++++++++ third_party/xla/xla/python/xla_client.py | 2 +- third_party/xla/xla/python/xla_extension/jax_jit.pyi | 3 +++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index 754e31c1620cea..cd9407fce89ccf 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -334,6 +334,16 @@ void BuildJaxjitSubmodule(py::module& m) { "thread_local_state", [&]() { return &ThreadLocalJitState(); }, py::return_value_policy::reference); + jitlib.def( + "swap_thread_local_state_disable_jit", + [&](std::optional value) -> std::optional { + auto tls = &ThreadLocalJitState(); + auto result = tls->disable_jit; + tls->disable_jit = value; + return result; + }, + py::return_value_policy::reference); + jitlib.def("jit_is_disabled", &GetDisableJit); jitlib.def("get_enable_x64", &GetEnableX64); jitlib.def("set_thread_local_state_initialization_callback", diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 4a476977f18118..919eb677336970 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 217 +_version = 218 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/third_party/xla/xla/python/xla_extension/jax_jit.pyi b/third_party/xla/xla/python/xla_extension/jax_jit.pyi index e495b5fe8db9a2..9bf5e30d6c8907 100644 --- a/third_party/xla/xla/python/xla_extension/jax_jit.pyi +++ b/third_party/xla/xla/python/xla_extension/jax_jit.pyi @@ -39,6 +39,9 @@ def get_enable_x64() -> bool: ... def set_thread_local_state_initialization_callback( function: Callable[[], None]): ... +def swap_thread_local_state_disable_jit( + value: Optional[bool]) -> Optional[bool]: ... + class ArgSignature: dtype: np.dtype shape: Tuple[int, ...] From fb54a8650ee59939ca7e12def6e546b77b2d58a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 11:49:00 -0800 Subject: [PATCH 194/381] When iterative solving is turned on, shardings annotations from a previous iteration associated with a partial mesh conflict with the current larger mesh. This CL handles this case when calling sharding propagation to infer operand shardings given a the sharding for an instruction. 
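A toy illustration of the sanitizing rule applied here (toy types and the validity predicate below are illustrative, not the real HloSharding/Validate API): before asking sharding propagation for an operand's sharding, every operand whose recorded annotation fails validation against the current mesh is cleared, not just the single operand being queried.

  #include <functional>
  #include <optional>
  #include <string>
  #include <vector>

  struct ToyOperand {
    std::optional<std::string> sharding;  // recorded annotation, if any
  };

  // Drop every annotation that is invalid under the current mesh so stale
  // shardings from a previous partial-mesh iteration cannot leak into
  // propagation.
  void ClearInvalidShardings(
      std::vector<ToyOperand>& operands,
      const std::function<bool(const std::string&)>& is_valid_for_current_mesh) {
    for (ToyOperand& op : operands) {
      if (op.sharding.has_value() && !is_valid_for_current_mesh(*op.sharding)) {
        op.sharding.reset();
      }
    }
  }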
PiperOrigin-RevId: 586409784 --- .../auto_sharding/auto_sharding_util.cc | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 88305332ff2f6a..d5965d057576a0 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -75,17 +75,29 @@ std::optional GetInputSharding(const HloInstruction* ins, int64_t num_devices) { auto ins_clone = ins->Clone(); ins_clone->set_sharding(output_sharding); - auto operand_clone = operand->Clone(); - if (operand_clone->has_sharding() && - !operand_clone->sharding() - .Validate(operand_clone->shape(), num_devices) - .ok()) { - operand_clone->clear_sharding(); - } - auto s = ins_clone->ReplaceOperandWith(op_index, operand_clone.get()); - CHECK_OK(s); - return ShardingPropagation::GetShardingFromUser(*operand_clone, *ins_clone, - 10, true, call_graph); + + std::vector> operands; + for (size_t i = 0; i < ins->operand_count(); ++i) { + const HloInstruction* operand = ins->operand(i); + if (i != op_index && + (!operand->has_sharding() || + operand->sharding().Validate(operand->shape(), num_devices).ok())) { + continue; + } + std::unique_ptr operand_clone = operand->Clone(); + if (operand_clone->has_sharding() && + !operand_clone->sharding() + .Validate(operand_clone->shape(), num_devices) + .ok()) { + operand_clone->clear_sharding(); + } + CHECK_OK(ins_clone->ReplaceOperandWith(i, operand_clone.get())); + operands.push_back(std::move(operand_clone)); + } + + auto result = ShardingPropagation::GetShardingFromUser( + *ins_clone->operand(op_index), *ins_clone, 10, true, call_graph); + return result; } // Return whether the instruction is an activation from another pipeline stage. @@ -1135,7 +1147,7 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( absl::c_iota(axes, 0); bool found = false; do { - auto transposed_mesh = Transpose(mesh, axes); + Array transposed_mesh = Transpose(mesh, axes); if (std::equal(transposed_mesh.begin(), transposed_mesh.end(), spec.tile_assignment().array().begin())) { found = true; From 8ec04137079caae4f6296603ff4e711e1c061237 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 11:51:22 -0800 Subject: [PATCH 195/381] Move large functions in auto_sharding_dot_handler.cc out of class definitions PiperOrigin-RevId: 586410406 --- .../auto_sharding_dot_handler.cc | 1451 +++++++++-------- 1 file changed, 767 insertions(+), 684 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index d574a68758142c..bde495a71447dc 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -46,6 +47,7 @@ limitations under the License. 
namespace xla { namespace spmd { +namespace { using DimMap = StableHashMap; using MeshDims = absl::Span; @@ -79,26 +81,7 @@ class HandlerBase { void AppendNewStrategy(const std::string& name, const HloSharding& output_spec, absl::Span input_specs, - double compute_cost, double communication_cost) { - std::vector> resharding_costs; - - for (int i = 0; i < ins_->operand_count(); ++i) { - const HloInstruction* operand = ins_->operand(i); - resharding_costs.push_back( - ReshardingCostVector(strategy_map_.at(operand).get(), - operand->shape(), input_specs[i], cluster_env_)); - } - - strategy_group_->strategies.push_back(ShardingStrategy({ - name, - output_spec, - compute_cost, - communication_cost, - GetBytes(ins_->shape()) / output_spec.NumTiles(), - resharding_costs, - {input_specs.begin(), input_specs.end()}, - })); - } + double compute_cost, double communication_cost); bool CheckDims(const HloInstruction* ins, const DimMap& dim_map) const { for (const auto& [tensor_dim, mesh_dim] : dim_map) { @@ -124,81 +107,17 @@ class HandlerBase { } // Given lhs and rhs dim maps, infers a sharding for the output by relying on - // the sharding_propagation pass. Given that this is a relatively new change - // (as of 11/2023), we also take an optional expected output dim map as an - // argument, to verify that sharding propagation in fact infers the sharding - // we expect (and to crash if it doesn't). - // TODO(b/309638633) As we build more confidence in this, we should remove - // this expected_output_dim_map argument and fully rely on sharding - // propagation. + // the sharding_propagation pass. void MaybeAppend( const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const std::optional& expected_output_dim_map, const Array& device_mesh, double compute_cost = 0, const std::optional>& - communication_cost_fn = std::nullopt) { - HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh); - HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh); - if (std::optional output_spec = - GetShardingFromUser(lhs_spec, rhs_spec); - output_spec.has_value()) { - if (expected_output_dim_map.has_value()) { - HloSharding expected_output_spec = - CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); - // TODO(b/308687597) Once the bug is resolved, we ideally either want - // have a CHECK statement verifying that the sharding inferred by - // sharding propagation is in fact what we expect, or we trust sharding - // propagation's results without the check. b/308687597 currently - // prevents us from doing so. AutoShardingTest.LargeSize in - // //third_party/tensorflow/compiler/xla/hlo/experimental/auto_sharding:auto_sharding_test - // currently fails due to the issue. - if (ins_->opcode() == HloOpcode::kDot && - *output_spec != expected_output_spec) { - output_spec = expected_output_spec; - LOG(ERROR) - << "The sharding inferred by sharding propagation in this case " - "does not match the expected sharding for the dot " - "instruction. This may be related to b/308687597. 
Given this " - "mismatch, we continue with the expected sharding"; - } - } - double communication_cost = 0; - if (communication_cost_fn.has_value()) { - communication_cost = communication_cost_fn.value()(*output_spec); - } - AppendNewStrategy(name, *output_spec, {lhs_spec, rhs_spec}, compute_cost, - communication_cost); - } else { - LOG(FATAL) << "Sharding propagation could not infer output sharding"; - } - } + communication_cost_fn = std::nullopt); std::optional GetShardingFromUser(const HloSharding& lhs_spec, - const HloSharding& rhs_spec) { - std::unique_ptr ins_clone = ins_->Clone(); - std::unique_ptr lhs_clone = lhs_->Clone(); - std::unique_ptr rhs_clone = rhs_->Clone(); - ins_clone->clear_sharding(); - lhs_clone->set_sharding(lhs_spec); - rhs_clone->set_sharding(rhs_spec); - CHECK_OK(ins_clone->ReplaceOperandWith(0, lhs_clone.get())); - CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get())); - if (ins_->opcode() == HloOpcode::kConvolution) { - xla::InferConvolutionShardingFromOperands( - ins_clone.get(), call_graph_, 10, - /* may_combine_partial_sharding */ true, /* is_spmd */ true); - } else { - xla::InferDotShardingFromOperands( - ins_clone.get(), call_graph_, - dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()), - /* may_combine_partial_sharding/ */ true, /* is_spmd */ true); - } - if (!ins_clone->has_sharding()) { - return std::nullopt; - } - return ins_clone->sharding(); - } + const HloSharding& rhs_spec); // Enumerates combinations of the given mesh + tensor dimensions. void Enumerate(std::function split_func, @@ -240,26 +159,10 @@ class HandlerBase { class DotHandler : public HandlerBase { public: DotHandler(std::unique_ptr& strategy_group, - StrategyMap& strategy_map, const HloInstruction* ins, + StrategyMap& strategy_map, const HloDotInstruction* ins, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), - is_dot_(true), - space_base_dim_( - ins->dot_dimension_numbers().lhs_batch_dimensions_size()), - lhs_con_dims_( - ins->dot_dimension_numbers().lhs_contracting_dimensions()), - rhs_con_dims_( - ins->dot_dimension_numbers().rhs_contracting_dimensions()), - lhs_batch_dims_(ins->dot_dimension_numbers().lhs_batch_dimensions()), - rhs_batch_dims_(ins->dot_dimension_numbers().rhs_batch_dimensions()) { - std::tie(lhs_space_dims_, rhs_space_dims_) = GetSpaceDims( - lhs_->shape(), rhs_->shape(), ins->dot_dimension_numbers()); - CHECK_EQ(lhs_con_dims_.size(), rhs_con_dims_.size()); - CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size()); - } + const AutoShardingOption& option, const CallGraph& call_graph); DotHandler( std::unique_ptr& strategy_group, StrategyMap& strategy_map, @@ -267,674 +170,854 @@ class DotHandler : public HandlerBase { const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, - const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), - is_dot_(false), - space_base_dim_(-1) { - CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); - - for (auto dim_idx : conv_as_dot_dims.batch_dims) { - if (dim_idx.lhs >= 0) lhs_batch_dims_.Add(dim_idx.lhs); - if (dim_idx.rhs >= 0) rhs_batch_dims_.Add(dim_idx.rhs); - } + const CallGraph& call_graph); - for (auto dim_idx : 
conv_as_dot_dims.contracting_dims) { - if (dim_idx.lhs >= 0) lhs_con_dims_.Add(dim_idx.lhs); - if (dim_idx.rhs >= 0) rhs_con_dims_.Add(dim_idx.rhs); - } + void SplitLhsSpaceRhsSpace(); - for (auto dim_idx : conv_as_dot_dims.lhs_non_contracting_dims) { - if (dim_idx.lhs >= 0) lhs_space_dims_.Add(dim_idx.lhs); - } + void SplitLhsSpaceOnly(); - for (auto dim_idx : conv_as_dot_dims.rhs_non_contracting_dims) { - if (dim_idx.rhs >= 0) rhs_space_dims_.Add(dim_idx.rhs); - } + void SplitRhsSpaceOnly(); + + void SplitLhsSpaceBothContract(); + + void SplitRhsSpaceBothContract(); + + void SplitOneBatchDim(); + + void SplitTwoBatchDims(); + + void SplitBatchDimLhsSpace(); + + void SplitBatchDimRhsSpace(); + + void SplitBatchDimBothContract(); + + void SplitBothContractTwoDims(); + + void RecomputeSplitBothContract(); + + void Add1DDataParallel(); + + void Add1DBatchSplit(); + + Status RegisterStrategies(); + + // Dimension information + bool is_dot_; + int64_t space_base_dim_; + tsl::protobuf::RepeatedField lhs_space_dims_, rhs_space_dims_; + tsl::protobuf::RepeatedField lhs_con_dims_; + tsl::protobuf::RepeatedField rhs_con_dims_; + tsl::protobuf::RepeatedField lhs_batch_dims_; + tsl::protobuf::RepeatedField rhs_batch_dims_; +}; + +class ConvHandler : public HandlerBase { + public: + ConvHandler(std::unique_ptr& strategy_group, + StrategyMap& strategy_map, const HloInstruction* ins, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingOption& option, const CallGraph& call_graph); + + void SplitLhsBatchRhsOutchannel(); + + void SplitLhsBatchBothInchannel(); + + void SplitRhsOutchannelBothInchannel(); + + void Add1DDataParallel(); + + void SplitDepthwise(bool forward); + + Status RegisterStrategies(); + + // Dimension information + const ConvolutionDimensionNumbers& conv_dnums_; + int64_t lhs_batch_dim_, lhs_in_channel_dim_; + int64_t rhs_in_channel_dim_, rhs_out_channel_dim_; + int64_t out_batch_dim_, out_out_channel_dim_; +}; + +/************** HandlerBase function definitions **************/ + +void HandlerBase::AppendNewStrategy(const std::string& name, + const HloSharding& output_spec, + absl::Span input_specs, + double compute_cost, + double communication_cost) { + std::vector> resharding_costs; + + for (int i = 0; i < ins_->operand_count(); ++i) { + const HloInstruction* operand = ins_->operand(i); + resharding_costs.push_back( + ReshardingCostVector(strategy_map_.at(operand).get(), operand->shape(), + input_specs[i], cluster_env_)); } - void SplitLhsSpaceRhsSpace() { - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_space_dims_[e.j], e.mesh_dims[1]}}; - std::string name = absl::StrFormat("SS = SR x RS @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); + strategy_group_->strategies.push_back(ShardingStrategy({ + name, + output_spec, + compute_cost, + communication_cost, + GetBytes(ins_->shape()) / output_spec.NumTiles(), + resharding_costs, + {input_specs.begin(), input_specs.end()}, + })); +} - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = - DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, - {space_base_dim_ + - static_cast(lhs_space_dims_.size()) + e.j, - e.mesh_dims[1]}}; +// Given lhs and rhs dim maps, infers a sharding for the output by relying on +// the sharding_propagation pass. 
Given that this is a relatively new change +// (as of 11/2023), we also take an optional expected output dim map as an +// argument, to verify that sharding propagation in fact infers the sharding +// we expect (and to crash if it doesn't). +// TODO(b/309638633) As we build more confidence in this, we should remove +// this expected_output_dim_map argument and fully rely on sharding +// propagation. +void HandlerBase::MaybeAppend( + const std::string& name, const DimMap& lhs_dim_map, + const DimMap& rhs_dim_map, + const std::optional& expected_output_dim_map, + const Array& device_mesh, double compute_cost, + const std::optional>& + communication_cost_fn) { + HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh); + HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh); + if (std::optional output_spec = + GetShardingFromUser(lhs_spec, rhs_spec); + output_spec.has_value()) { + if (expected_output_dim_map.has_value()) { + HloSharding expected_output_spec = + CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); + // TODO(b/308687597) Once the bug is resolved, we ideally either want + // have a CHECK statement verifying that the sharding inferred by + // sharding propagation is in fact what we expect, or we trust sharding + // propagation's results without the check. b/308687597 currently + // prevents us from doing so. AutoShardingTest.LargeSize in + // //third_party/tensorflow/compiler/xla/hlo/experimental/auto_sharding:auto_sharding_test + // currently fails due to the issue. + if (ins_->opcode() == HloOpcode::kDot && + *output_spec != expected_output_spec) { + output_spec = expected_output_spec; + LOG(ERROR) + << "The sharding inferred by sharding propagation in this case " + "does not match the expected sharding for the dot " + "instruction. This may be related to b/308687597. 
Given this " + "mismatch, we continue with the expected sharding"; } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - Enumerate(func, lhs_space_dims_.size(), rhs_space_dims_.size()); + } + double communication_cost = 0; + if (communication_cost_fn.has_value()) { + communication_cost = communication_cost_fn.value()(*output_spec); + } + AppendNewStrategy(name, *output_spec, {lhs_spec, rhs_spec}, compute_cost, + communication_cost); + } else { + LOG(FATAL) << "Sharding propagation could not infer output sharding"; } +} - void SplitLhsSpaceOnly() { - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, - {lhs_space_dims_[e.j], e.mesh_dims[1]}}; - std::string name = absl::StrFormat("SSR = SSR x RR @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, - {space_base_dim_ + e.j, e.mesh_dims[1]}}; - } - MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_); - }; - EnumerateHalf(func, lhs_space_dims_.size(), lhs_space_dims_.size()); +std::optional HandlerBase::GetShardingFromUser( + const HloSharding& lhs_spec, const HloSharding& rhs_spec) { + std::unique_ptr ins_clone = ins_->Clone(); + std::unique_ptr lhs_clone = lhs_->Clone(); + std::unique_ptr rhs_clone = rhs_->Clone(); + ins_clone->clear_sharding(); + lhs_clone->set_sharding(lhs_spec); + rhs_clone->set_sharding(rhs_spec); + CHECK_OK(ins_clone->ReplaceOperandWith(0, lhs_clone.get())); + CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get())); + if (ins_->opcode() == HloOpcode::kConvolution) { + xla::InferConvolutionShardingFromOperands( + ins_clone.get(), call_graph_, 10, + /* may_combine_partial_sharding */ true, /* is_spmd */ true); + } else { + xla::InferDotShardingFromOperands( + ins_clone.get(), call_graph_, + dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()), + /* may_combine_partial_sharding/ */ true, /* is_spmd */ true); + } + if (!ins_clone->has_sharding()) { + return std::nullopt; } + return ins_clone->sharding(); +} - void SplitRhsSpaceOnly() { - auto func = [this](const Enumeration& e) { - const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[0]}, - {rhs_space_dims_[e.j], e.mesh_dims[1]}}; - std::string name = absl::StrFormat("RSS = RR x RSS @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = - DimMap{{space_base_dim_ + - static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[0]}, - {space_base_dim_ + - static_cast(lhs_space_dims_.size()) + e.j, - e.mesh_dims[1]}}; - } - MaybeAppend(name, {}, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func, rhs_space_dims_.size(), rhs_space_dims_.size()); +/************** DotHandler function definitions **************/ + +DotHandler::DotHandler(std::unique_ptr& strategy_group, + StrategyMap& strategy_map, const HloDotInstruction* ins, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingOption& option, + const CallGraph& call_graph) + : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, + option, call_graph), + is_dot_(true), + space_base_dim_(ins->dot_dimension_numbers().lhs_batch_dimensions_size()), + lhs_con_dims_(ins->dot_dimension_numbers().lhs_contracting_dimensions()), + rhs_con_dims_(ins->dot_dimension_numbers().rhs_contracting_dimensions()), + 
lhs_batch_dims_(ins->dot_dimension_numbers().lhs_batch_dimensions()), + rhs_batch_dims_(ins->dot_dimension_numbers().rhs_batch_dimensions()) { + std::tie(lhs_space_dims_, rhs_space_dims_) = + GetSpaceDims(lhs_->shape(), rhs_->shape(), ins->dot_dimension_numbers()); + CHECK_EQ(lhs_con_dims_.size(), rhs_con_dims_.size()); + CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size()); +} + +DotHandler::DotHandler( + std::unique_ptr& strategy_group, StrategyMap& strategy_map, + const HloConvolutionInstruction* ins, + const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, + const CallGraph& call_graph) + : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, + option, call_graph), + is_dot_(false), + space_base_dim_(-1) { + CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); + + for (auto dim_idx : conv_as_dot_dims.batch_dims) { + if (dim_idx.lhs >= 0) lhs_batch_dims_.Add(dim_idx.lhs); + if (dim_idx.rhs >= 0) rhs_batch_dims_.Add(dim_idx.rhs); } - void SplitLhsSpaceBothContract() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = - absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, - {lhs_con_dims_[e.j], e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.j], e.mesh_dims[1]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}}; - } + for (auto dim_idx : conv_as_dot_dims.contracting_dims) { + if (dim_idx.lhs >= 0) lhs_con_dims_.Add(dim_idx.lhs); + if (dim_idx.rhs >= 0) rhs_con_dims_.Add(dim_idx.rhs); + } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - Enumerate(func, lhs_space_dims_.size(), lhs_con_dims_.size()); + for (auto dim_idx : conv_as_dot_dims.lhs_non_contracting_dims) { + if (dim_idx.lhs >= 0) lhs_space_dims_.Add(dim_idx.lhs); } - void SplitRhsSpaceBothContract() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1) return; - std::string name = - absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); - const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, - {rhs_con_dims_[e.j], e.mesh_dims[0]}}; - const DimMap lhs_dim_map = {{lhs_con_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = - DimMap{{space_base_dim_ + - static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[1]}}; - } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - Enumerate(func, rhs_space_dims_.size(), lhs_con_dims_.size()); + for (auto dim_idx : conv_as_dot_dims.rhs_non_contracting_dims) { + if (dim_idx.rhs >= 0) 
rhs_space_dims_.Add(dim_idx.rhs); } +} + +void DotHandler::SplitLhsSpaceRhsSpace() { + auto func = [this](const Enumeration& e) { + const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_space_dims_[e.j], e.mesh_dims[1]}}; + std::string name = + absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); + + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{ + {space_base_dim_ + e.i, e.mesh_dims[0]}, + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, + e.mesh_dims[1]}}; + } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + Enumerate(func, lhs_space_dims_.size(), rhs_space_dims_.size()); +} - void SplitOneBatchDim() { - if (absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) != 1) { +void DotHandler::SplitLhsSpaceOnly() { + auto func = [this](const Enumeration& e) { + const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, + {lhs_space_dims_[e.j], e.mesh_dims[1]}}; + std::string name = absl::StrFormat("SSR = SSR x RR @ {%s}", + absl::StrJoin(e.mesh_dims, ",")); + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, + {space_base_dim_ + e.j, e.mesh_dims[1]}}; + } + MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_); + }; + EnumerateHalf(func, lhs_space_dims_.size(), lhs_space_dims_.size()); +} + +void DotHandler::SplitRhsSpaceOnly() { + auto func = [this](const Enumeration& e) { + const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[0]}, + {rhs_space_dims_[e.j], e.mesh_dims[1]}}; + std::string name = absl::StrFormat("RSS = RR x RSS @ {%s}", + absl::StrJoin(e.mesh_dims, ",")); + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{ + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[0]}, + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, + e.mesh_dims[1]}}; + } + MaybeAppend(name, {}, rhs_dim_map, out_dim_map, device_mesh_); + }; + EnumerateHalf(func, rhs_space_dims_.size(), rhs_space_dims_.size()); +} + +void DotHandler::SplitLhsSpaceBothContract() { + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) return; + std::string name = + absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", + absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); + const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, + {lhs_con_dims_[e.j], e.mesh_dims[1]}}; + const DimMap rhs_dim_map = {{rhs_con_dims_[e.j], e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}}; } - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_batch_dims_[e.i], e.j}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[e.i], e.j}}; - std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", e.i, e.j); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{e.i, e.j}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); }; - Enumerate(func, lhs_batch_dims_.size(), device_mesh_.num_dimensions()); - } + MaybeAppend(name, lhs_dim_map, 
rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); + }; + Enumerate(func, lhs_space_dims_.size(), lhs_con_dims_.size()); +} - void SplitTwoBatchDims() { - if (lhs_batch_dims_.size() != 2) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - const DimMap lhs_dim_map = {{lhs_batch_dims_[0], e.mesh_dims[0]}, - {lhs_batch_dims_[1], e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[0], e.mesh_dims[0]}, - {rhs_batch_dims_[1], e.mesh_dims[1]}}; - std::string name = absl::StrFormat("Sb = Sb x Sb @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{0, e.mesh_dims[0]}, {1, e.mesh_dims[1]}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); +void DotHandler::SplitRhsSpaceBothContract() { + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1) return; + std::string name = + absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", + absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); + const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, + {rhs_con_dims_[e.j], e.mesh_dims[0]}}; + const DimMap lhs_dim_map = {{lhs_con_dims_[e.j], e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{ + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[1]}}; + } + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); }; - EnumerateHalf(func, lhs_batch_dims_.size(), lhs_batch_dims_.size()); + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); + }; + Enumerate(func, rhs_space_dims_.size(), lhs_con_dims_.size()); +} + +void DotHandler::SplitOneBatchDim() { + if (absl::c_count_if(device_mesh_.dimensions(), + [](int64_t size) { return size > 1; }) != 1) { + return; } + auto func = [this](const Enumeration& e) { + const DimMap lhs_dim_map = {{lhs_batch_dims_[e.i], e.j}}; + const DimMap rhs_dim_map = {{rhs_batch_dims_[e.i], e.j}}; + std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", e.i, e.j); + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{e.i, e.j}}; + } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + Enumerate(func, lhs_batch_dims_.size(), device_mesh_.num_dimensions()); +} - void SplitBatchDimLhsSpace() { - if (lhs_batch_dims_.empty()) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = absl::StrFormat("SbSi = SbSi x SbR @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[1]}, - {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{e.j, e.mesh_dims[0]}, - {space_base_dim_ + e.i, e.mesh_dims[1]}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); +void DotHandler::SplitTwoBatchDims() { + if (lhs_batch_dims_.size() != 2) return; + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + 
device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + const DimMap lhs_dim_map = {{lhs_batch_dims_[0], e.mesh_dims[0]}, + {lhs_batch_dims_[1], e.mesh_dims[1]}}; + const DimMap rhs_dim_map = {{rhs_batch_dims_[0], e.mesh_dims[0]}, + {rhs_batch_dims_[1], e.mesh_dims[1]}}; + std::string name = + absl::StrFormat("Sb = Sb x Sb @ {%s}", absl::StrJoin(e.mesh_dims, ",")); + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{0, e.mesh_dims[0]}, {1, e.mesh_dims[1]}}; + } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + EnumerateHalf(func, lhs_batch_dims_.size(), lhs_batch_dims_.size()); +} + +void DotHandler::SplitBatchDimLhsSpace() { + if (lhs_batch_dims_.empty()) return; + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + std::string name = absl::StrFormat("SbSi = SbSi x SbR @ {%s}", + absl::StrJoin(e.mesh_dims, ",")); + const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[1]}, + {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{e.j, e.mesh_dims[0]}, + {space_base_dim_ + e.i, e.mesh_dims[1]}}; + } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + Enumerate(func, lhs_space_dims_.size(), lhs_batch_dims_.size()); +} + +void DotHandler::SplitBatchDimRhsSpace() { + if (lhs_batch_dims_.empty()) return; + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + std::string name = absl::StrFormat("SbSj = SbR x SbSj @ {%s}", + absl::StrJoin(e.mesh_dims, ",")); + const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, + {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; + const DimMap lhs_dim_map = {{lhs_batch_dims_[e.j], e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{ + {e.j, e.mesh_dims[0]}, + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[1]}}; + } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + Enumerate(func, rhs_space_dims_.size(), lhs_batch_dims_.size()); +} + +void DotHandler::SplitBatchDimBothContract() { + if (lhs_batch_dims_.empty()) return; + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + std::string name = + absl::StrFormat("SbR = SbSk x SbSk @ {%s} (allreduce @ %d}", + absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); + const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[1]}, + {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; + } + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); }; - Enumerate(func, lhs_space_dims_.size(), lhs_batch_dims_.size()); - } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); + }; + Enumerate(func, lhs_con_dims_.size(), lhs_batch_dims_.size()); +} - void SplitBatchDimRhsSpace() { - if (lhs_batch_dims_.empty()) return; - auto func = [this](const 
Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = absl::StrFormat("SbSj = SbR x SbSj @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, - {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap lhs_dim_map = {{lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = - DimMap{{e.j, e.mesh_dims[0]}, - {space_base_dim_ + - static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[1]}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); +void DotHandler::SplitBothContractTwoDims() { + if (lhs_con_dims_.size() < 2 || rhs_con_dims_.size() < 2) return; + auto func = [this](const Enumeration& e) { + // Applies when there are more than one contracting dimension. + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + std::string name = absl::StrFormat("RR = SS x SS @ {%s} (allreduce @ {%s}}", + absl::StrJoin(e.mesh_dims, ","), + absl::StrJoin(e.mesh_dims, ", ")); + const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}, + {lhs_con_dims_[e.j], e.mesh_dims[1]}}; + const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}, + {rhs_con_dims_[e.j], e.mesh_dims[1]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{}; + } + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0], + e.mesh_dims[1]); }; - Enumerate(func, rhs_space_dims_.size(), lhs_batch_dims_.size()); - } + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); + }; + EnumerateHalf(func, lhs_con_dims_.size(), lhs_con_dims_.size()); +} - void SplitBatchDimBothContract() { - if (lhs_batch_dims_.empty()) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = - absl::StrFormat("SbR = SbSk x SbSk @ {%s} (allreduce @ %d}", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[1]}, - {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; +void DotHandler::RecomputeSplitBothContract() { + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", + e.mesh_dims[0], e.mesh_dims[0]); + const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}}; + std::optional out_dim_map = std::nullopt; + if (is_dot_) { + out_dim_map = DimMap{}; + } + double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape()); + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, + compute_cost, communication_cost_fn); + }; + Enumerate(func, lhs_con_dims_.size(), 1); +} + +void DotHandler::Add1DDataParallel() { + if (device_mesh_.dim(0) > 1 && + 
absl::c_count_if(device_mesh_.dimensions(), + [](int64_t size) { return size > 1; }) > 1) { + int mesh_dim = 0; + int64_t num_devices = device_mesh_1d_.dim(mesh_dim); + + // Si = Si x R @ 0 + for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { + const DimMap lhs_dim_map = {{lhs_space_dims_[i], mesh_dim}}; + if (lhs_->shape().dimensions(lhs_space_dims_[i]) < num_devices) { + continue; + } + if (option_.only_allow_divisible_intermediate && + !IsDivisible(lhs_->shape().dimensions(lhs_space_dims_[i]), + num_devices)) { + continue; + } + std::string name = absl::StrFormat("Si = Si x R @ %d", mesh_dim); std::optional out_dim_map = std::nullopt; if (is_dot_) { - out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; + out_dim_map = DimMap{{space_base_dim_ + i, mesh_dim}}; } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - Enumerate(func, lhs_con_dims_.size(), lhs_batch_dims_.size()); - } + MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); + } - void SplitBothContractTwoDims() { - if (lhs_con_dims_.size() < 2 || rhs_con_dims_.size() < 2) return; - auto func = [this](const Enumeration& e) { - // Applies when there are more than one contracting dimension. - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = absl::StrFormat( - "RR = SS x SS @ {%s} (allreduce @ {%s}}", - absl::StrJoin(e.mesh_dims, ","), absl::StrJoin(e.mesh_dims, ", ")); - const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}, - {lhs_con_dims_[e.j], e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}, - {rhs_con_dims_[e.j], e.mesh_dims[1]}}; + // R = Sk x Sk @ (allreduce @ 0) + for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { + const DimMap lhs_dim_map = {{lhs_con_dims_[i], mesh_dim}}; + const DimMap rhs_dim_map = {{rhs_con_dims_[i], mesh_dim}}; + if (lhs_->shape().dimensions(lhs_con_dims_[i]) < num_devices) { + continue; + } + if (option_.only_allow_divisible_intermediate && + !IsDivisible(lhs_->shape().dimensions(lhs_con_dims_[i]), + num_devices)) { + continue; + } + std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", + mesh_dim, mesh_dim); std::optional out_dim_map = std::nullopt; if (is_dot_) { out_dim_map = DimMap{}; } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + auto communication_cost_fn = [this, + mesh_dim](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0], - e.mesh_dims[1]); + return cluster_env_.AllReduceCost(memory_cost, mesh_dim); }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - EnumerateHalf(func, lhs_con_dims_.size(), lhs_con_dims_.size()); + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_, + 0, communication_cost_fn); + } } +} - void RecomputeSplitBothContract() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", - e.mesh_dims[0], e.mesh_dims[0]); - const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], 
e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}}; +void DotHandler::Add1DBatchSplit() { + if (device_mesh_.dim(0) > 1 && + absl::c_count_if(device_mesh_.dimensions(), + [](int64_t size) { return size > 1; }) > 1) { + int mesh_dim = 0; + for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { + const DimMap lhs_dim_map = {{lhs_batch_dims_[i], mesh_dim}}; + const DimMap rhs_dim_map = {{rhs_batch_dims_[i], mesh_dim}}; + std::string name = + absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); std::optional out_dim_map = std::nullopt; if (is_dot_) { - out_dim_map = DimMap{}; + out_dim_map = DimMap{{i, mesh_dim}}; } - double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape()); - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, - compute_cost, communication_cost_fn); - }; - Enumerate(func, lhs_con_dims_.size(), 1); + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_); + } } +} - void Add1DDataParallel() { - if (device_mesh_.dim(0) > 1 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - int mesh_dim = 0; - int64_t num_devices = device_mesh_1d_.dim(mesh_dim); - - // Si = Si x R @ 0 - for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { - const DimMap lhs_dim_map = {{lhs_space_dims_[i], mesh_dim}}; - if (lhs_->shape().dimensions(lhs_space_dims_[i]) < num_devices) { - continue; - } - if (option_.only_allow_divisible_intermediate && - !IsDivisible(lhs_->shape().dimensions(lhs_space_dims_[i]), - num_devices)) { - continue; - } - std::string name = absl::StrFormat("Si = Si x R @ %d", mesh_dim); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{space_base_dim_ + i, mesh_dim}}; - } - MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); - } +Status DotHandler::RegisterStrategies() { + // SS = SR x RS + // Split lhs space dim and rhs space dim. + SplitLhsSpaceRhsSpace(); - // R = Sk x Sk @ (allreduce @ 0) - for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { - const DimMap lhs_dim_map = {{lhs_con_dims_[i], mesh_dim}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[i], mesh_dim}}; - if (lhs_->shape().dimensions(lhs_con_dims_[i]) < num_devices) { - continue; - } - if (option_.only_allow_divisible_intermediate && - !IsDivisible(lhs_->shape().dimensions(lhs_con_dims_[i]), - num_devices)) { - continue; - } - std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", - mesh_dim, mesh_dim); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{}; - } - auto communication_cost_fn = [this, mesh_dim]( - const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, mesh_dim); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, - device_mesh_1d_, 0, communication_cost_fn); - } - } + // SSR = SSR x RR + // Split lhs space dims only if it has more than 1 space dims. 
+ if (lhs_space_dims_.size() > 1) { + SplitLhsSpaceOnly(); } - - void Add1DBatchSplit() { - if (device_mesh_.dim(0) > 1 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - int mesh_dim = 0; - for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { - const DimMap lhs_dim_map = {{lhs_batch_dims_[i], mesh_dim}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[i], mesh_dim}}; - std::string name = - absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{i, mesh_dim}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, - device_mesh_1d_); - } - } + // RSS = RR x RSS + // Split rhs space dims only if it has more than 1 space dims. + if (rhs_space_dims_.size() > 1) { + SplitRhsSpaceOnly(); } - Status RegisterStrategies() { - // SS = SR x RS - // Split lhs space dim and rhs space dim. - SplitLhsSpaceRhsSpace(); + // SR = SS x SR + // Split lhs space dim and both contracting dims. + SplitLhsSpaceBothContract(); - // SSR = SSR x RR - // Split lhs space dims only if it has more than 1 space dims. - if (lhs_space_dims_.size() > 1) { - SplitLhsSpaceOnly(); - } - // RSS = RR x RSS - // Split rhs space dims only if it has more than 1 space dims. - if (rhs_space_dims_.size() > 1) { - SplitRhsSpaceOnly(); - } + // RS = RS x SS + // Split rhs space dim and both contracting dims. + SplitRhsSpaceBothContract(); - // SR = SS x SR - // Split lhs space dim and both contracting dims. - SplitLhsSpaceBothContract(); + // RR = SS x SS + // Split two contracting dims on lhs and rhs. + SplitBothContractTwoDims(); - // RS = RS x SS - // Split rhs space dim and both contracting dims. - SplitRhsSpaceBothContract(); + // RR = RS x SR + // This is a special case where we allow splitting only one dim in the + // multi-dimensional mesh case. This allows some recomputation + // (e.g., the dense layer in the LM_head of BERT). + RecomputeSplitBothContract(); - // RR = SS x SS - // Split two contracting dims on lhs and rhs. - SplitBothContractTwoDims(); + // Add 1d data parallel in multi-dimensional mesh + if (option_.allow_mixed_mesh_shape) { + Add1DDataParallel(); + } - // RR = RS x SR - // This is a special case where we allow spliting only one dim in the - // multi-dimensional mesh case. This allows some recomputation - // (e.g., the dense layer in the LM_head of BERT). - RecomputeSplitBothContract(); + if (option_.batch_matmul_always_split_batch && !lhs_batch_dims_.empty() && + cluster_env_.non_zero_mesh_dims_.size() > 1) { + // If there is a batch dim and the device mesh is multi-dimensional, + // always split on batch dim. Clear all old strategies. + strategy_group_->strategies.clear(); + } - // Add 1d data parallel in multi-dimensional mesh - if (option_.allow_mixed_mesh_shape) { - Add1DDataParallel(); - } + // Sb = Sb x Sb + // Split one batch dim. Only used for 1d mesh + SplitOneBatchDim(); - if (option_.batch_matmul_always_split_batch && !lhs_batch_dims_.empty() && - cluster_env_.non_zero_mesh_dims_.size() > 1) { - // If there is a batch dim and the device mesh is multi-dimensional, - // always split on batch dim. Clear all old strategies. - strategy_group_->strategies.clear(); - } + // SbSi = SbSi x SbR + // Split batch dim and lhs space dim + SplitBatchDimLhsSpace(); - // Sb = Sb x Sb - // Split one batch dim. 
Only used for 1d mesh - SplitOneBatchDim(); - - // SbSi = SbSi x SbR - // Split batch dim and lhs space dim - SplitBatchDimLhsSpace(); - - // SbSj = SbR x SbSj - // Split batch dim and rhs space dim - SplitBatchDimRhsSpace(); - - // SbSj = SbR x SbSj - // Split batch dim and contracting dim - SplitBatchDimBothContract(); - - if (option_.batch_matmul_always_split_batch && - lhs_batch_dims_.size() == 2 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - // If there are two batch dims, always split on these two dims. - // Clear all old strategies. - strategy_group_->strategies.clear(); - } + // SbSj = SbR x SbSj + // Split batch dim and rhs space dim + SplitBatchDimRhsSpace(); - // Sb = Sb x Sb - // Split batch dims. - SplitTwoBatchDims(); + // SbSj = SbR x SbSj + // Split batch dim and contracting dim + SplitBatchDimBothContract(); - if (option_.allow_mixed_mesh_shape) { - Add1DBatchSplit(); - } + if (option_.batch_matmul_always_split_batch && lhs_batch_dims_.size() == 2 && + absl::c_count_if(device_mesh_.dimensions(), + [](int64_t size) { return size > 1; }) > 1) { + // If there are two batch dims, always split on these two dims. + // Clear all old strategies. + strategy_group_->strategies.clear(); + } - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option_.force_batch_dim_to_mesh_dim >= 0 && - batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, - cluster_env_, batch_map_, option_)); - } + // Sb = Sb x Sb + // Split batch dims. + SplitTwoBatchDims(); - return OkStatus(); + if (option_.allow_mixed_mesh_shape) { + Add1DBatchSplit(); } - // Dimension information - bool is_dot_; - int64_t space_base_dim_; - tsl::protobuf::RepeatedField lhs_space_dims_, rhs_space_dims_; - tsl::protobuf::RepeatedField lhs_con_dims_; - tsl::protobuf::RepeatedField rhs_con_dims_; - tsl::protobuf::RepeatedField lhs_batch_dims_; - tsl::protobuf::RepeatedField rhs_batch_dims_; -}; - -// Register strategies for dot instructions. -Status HandleDot(std::unique_ptr& strategy_group, - StrategyGroups& strategy_groups, StrategyMap& strategy_map, - const HloInstruction* ins, size_t instruction_id, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option, - const CallGraph& call_graph) { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, - strategy_groups); + // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies + // and only keep the data parallel strategies. 
+ if (option_.force_batch_dim_to_mesh_dim >= 0 && + batch_map_.contains(GetBatchDimMapKey(ins_))) { + TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, + cluster_env_, batch_map_, option_)); + } - DotHandler handler(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph); - TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return OkStatus(); } -class ConvHandler : public HandlerBase { - public: - ConvHandler(std::unique_ptr& strategy_group, - StrategyMap& strategy_map, const HloInstruction* ins, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), - conv_dnums_(ins->convolution_dimension_numbers()) { - lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); - lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension(); - rhs_in_channel_dim_ = conv_dnums_.kernel_input_feature_dimension(); - rhs_out_channel_dim_ = conv_dnums_.kernel_output_feature_dimension(); - out_batch_dim_ = conv_dnums_.output_batch_dimension(); - out_out_channel_dim_ = conv_dnums_.output_feature_dimension(); - } +/************** ConvHandler function definitions **************/ + +ConvHandler::ConvHandler(std::unique_ptr& strategy_group, + StrategyMap& strategy_map, const HloInstruction* ins, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingOption& option, + const CallGraph& call_graph) + : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, + option, call_graph), + conv_dnums_(ins->convolution_dimension_numbers()) { + lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); + lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension(); + rhs_in_channel_dim_ = conv_dnums_.kernel_input_feature_dimension(); + rhs_out_channel_dim_ = conv_dnums_.kernel_output_feature_dimension(); + out_batch_dim_ = conv_dnums_.output_batch_dimension(); + out_out_channel_dim_ = conv_dnums_.output_feature_dimension(); +} - void SplitLhsBatchRhsOutchannel() { - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_batch_dim_, e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; - std::string name = absl::StrFormat("SS = SR x RS @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, - {out_out_channel_dim_, e.mesh_dims[1]}}; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func); +Status ConvHandler::RegisterStrategies() { + // For 1D sharding + if ((ins_->feature_group_count() == + lhs_->shape().dimensions(lhs_in_channel_dim_) && + ins_->feature_group_count() == + rhs_->shape().dimensions(rhs_out_channel_dim_))) { + // for depthwise conv + // SS = SS x S + // Split batch dim and channel dim + SplitDepthwise(true); + } else if ((ins_->batch_group_count() == + lhs_->shape().dimensions(lhs_batch_dim_) && + ins_->batch_group_count() == + rhs_->shape().dimensions(rhs_out_channel_dim_))) { + // for depthwise conv filter_backward + // SS = SS x S + // Split batch dim and channel dim + SplitDepthwise(false); } - void SplitLhsBatchBothInchannel() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - const DimMap lhs_dim_map = {{lhs_batch_dim_, e.mesh_dims[0]}, - {lhs_in_channel_dim_, e.mesh_dims[1]}}; 
- const DimMap rhs_dim_map = {{rhs_in_channel_dim_, e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}}; - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - EnumerateHalf(func); - } + // SS = SR x RS + // Split lhs batch dim and rhs out_channel dim. + SplitLhsBatchRhsOutchannel(); - void SplitRhsOutchannelBothInchannel() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1) return; - const DimMap lhs_dim_map = {{lhs_in_channel_dim_, e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_in_channel_dim_, e.mesh_dims[0]}, - {rhs_out_channel_dim_, e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); - const DimMap out_dim_map = {{out_out_channel_dim_, e.mesh_dims[1]}}; - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - EnumerateHalf(func); - } + // SR = SS x SR + // Split lhs batch dim and both in_channel dims. + SplitLhsBatchBothInchannel(); - void Add1DDataParallel() { - if (device_mesh_.dim(0) > 1 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - int mesh_dim = 0; - int64_t num_devices = device_mesh_1d_.dim(mesh_dim); - - // Si = Si x R @ 0 - if (lhs_->shape().dimensions(lhs_batch_dim_) % num_devices == 0) { - const DimMap lhs_dim_map = {{lhs_batch_dim_, mesh_dim}}; - std::string name = absl::StrFormat("Si = Si x R @ 0"); - const DimMap out_dim_map = {{out_batch_dim_, mesh_dim}}; - MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); - } + // RS = RS x SS + // Split rhs out_channel dim and both in_channel dims. + SplitRhsOutchannelBothInchannel(); - // R = Sk x Sk @ (allreduce @ 0) - if (lhs_->shape().dimensions(lhs_in_channel_dim_) % num_devices == 0 && - rhs_->shape().dimensions(rhs_in_channel_dim_) % num_devices == 0) { - const DimMap lhs_dim_map = {{lhs_in_channel_dim_, mesh_dim}}; - const DimMap rhs_dim_map = {{rhs_in_channel_dim_, mesh_dim}}; - std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", - mesh_dim, mesh_dim); - const DimMap out_dim_map = {}; - auto communication_cost_fn = [this](const HloSharding& output_spec) { - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - return cluster_env_.AllReduceCost(memory_cost, 0) + - cluster_env_.AllReduceCost(memory_cost, 1); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, - device_mesh_1d_, 0, communication_cost_fn); - } - } + // Add 1d data parallel in multi-dimensional mesh + if (option_.allow_mixed_mesh_shape) { + Add1DDataParallel(); } - void SplitDepthwise(bool forward) { - auto func = [this, forward](const Enumeration& e) { - const DimMap lhs_dim_map = { - {lhs_batch_dim_, e.mesh_dims[forward ? 0 : 1]}, - {lhs_in_channel_dim_, e.mesh_dims[forward ? 
1 : 0]}}; - const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; - std::string name = absl::StrFormat("SS = SS x RS @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, - {out_out_channel_dim_, e.mesh_dims[1]}}; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func); + // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies + // and only keep the data parallel strategies. + if (option_.force_batch_dim_to_mesh_dim >= 0 && + batch_map_.contains(GetBatchDimMapKey(ins_))) { + TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, + cluster_env_, batch_map_, option_)); } - Status RegisterStrategies() { - // For 1D sharding - if ((ins_->feature_group_count() == - lhs_->shape().dimensions(lhs_in_channel_dim_) && - ins_->feature_group_count() == - rhs_->shape().dimensions(rhs_out_channel_dim_))) { - // for depthwise conv - // SS = SS x S - // Split batch dim and channel dim - SplitDepthwise(true); - } else if ((ins_->batch_group_count() == - lhs_->shape().dimensions(lhs_batch_dim_) && - ins_->batch_group_count() == - rhs_->shape().dimensions(rhs_out_channel_dim_))) { - // for depthwise conv filter_backward - // SS = SS x S - // Split batch dim and channel dim - SplitDepthwise(false); - } + return OkStatus(); +} - // SS = SR x RS - // Split lhs batch dim and rhs out_channel dim. - SplitLhsBatchRhsOutchannel(); +void ConvHandler::SplitLhsBatchRhsOutchannel() { + auto func = [this](const Enumeration& e) { + const DimMap lhs_dim_map = {{lhs_batch_dim_, e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; + std::string name = + absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); + const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, + {out_out_channel_dim_, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + EnumerateHalf(func); +} - // SR = SS x SR - // Split lhs batch dim and both in_channel dims. - SplitLhsBatchBothInchannel(); +void ConvHandler::SplitLhsBatchBothInchannel() { + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || + device_mesh_.dim(e.mesh_dims[1]) <= 1) + return; + const DimMap lhs_dim_map = {{lhs_batch_dim_, e.mesh_dims[0]}, + {lhs_in_channel_dim_, e.mesh_dims[1]}}; + const DimMap rhs_dim_map = {{rhs_in_channel_dim_, e.mesh_dims[1]}}; + std::string name = + absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", + absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); + const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); + }; + EnumerateHalf(func); +} - // RS = RS x SS - // Split rhs out_channel dim and both in_channel dims. 
- SplitRhsOutchannelBothInchannel(); +void ConvHandler::SplitRhsOutchannelBothInchannel() { + auto func = [this](const Enumeration& e) { + if (device_mesh_.dim(e.mesh_dims[0]) <= 1) return; + const DimMap lhs_dim_map = {{lhs_in_channel_dim_, e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_in_channel_dim_, e.mesh_dims[0]}, + {rhs_out_channel_dim_, e.mesh_dims[1]}}; + std::string name = + absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", + absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); + const DimMap out_dim_map = {{out_out_channel_dim_, e.mesh_dims[1]}}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); + }; + EnumerateHalf(func); +} - // Add 1d data parallel in multi-dimensional mesh - if (option_.allow_mixed_mesh_shape) { - Add1DDataParallel(); +void ConvHandler::Add1DDataParallel() { + if (device_mesh_.dim(0) > 1 && + absl::c_count_if(device_mesh_.dimensions(), + [](int64_t size) { return size > 1; }) > 1) { + int mesh_dim = 0; + int64_t num_devices = device_mesh_1d_.dim(mesh_dim); + + // Si = Si x R @ 0 + if (lhs_->shape().dimensions(lhs_batch_dim_) % num_devices == 0) { + const DimMap lhs_dim_map = {{lhs_batch_dim_, mesh_dim}}; + std::string name = absl::StrFormat("Si = Si x R @ 0"); + const DimMap out_dim_map = {{out_batch_dim_, mesh_dim}}; + MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); } - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option_.force_batch_dim_to_mesh_dim >= 0 && - batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, - cluster_env_, batch_map_, option_)); + // R = Sk x Sk @ (allreduce @ 0) + if (lhs_->shape().dimensions(lhs_in_channel_dim_) % num_devices == 0 && + rhs_->shape().dimensions(rhs_in_channel_dim_) % num_devices == 0) { + const DimMap lhs_dim_map = {{lhs_in_channel_dim_, mesh_dim}}; + const DimMap rhs_dim_map = {{rhs_in_channel_dim_, mesh_dim}}; + std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", + mesh_dim, mesh_dim); + const DimMap out_dim_map = {}; + auto communication_cost_fn = [this](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, 0) + + cluster_env_.AllReduceCost(memory_cost, 1); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_, + 0, communication_cost_fn); } - - return OkStatus(); } +} - // Dimension information - const ConvolutionDimensionNumbers& conv_dnums_; - int64_t lhs_batch_dim_, lhs_in_channel_dim_; - int64_t rhs_in_channel_dim_, rhs_out_channel_dim_; - int64_t out_batch_dim_, out_out_channel_dim_; -}; +void ConvHandler::SplitDepthwise(bool forward) { + auto func = [this, forward](const Enumeration& e) { + const DimMap lhs_dim_map = { + {lhs_batch_dim_, e.mesh_dims[forward ? 0 : 1]}, + {lhs_in_channel_dim_, e.mesh_dims[forward ? 
1 : 0]}}; + const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; + std::string name = + absl::StrFormat("SS = SS x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); + const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, + {out_out_channel_dim_, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); + }; + EnumerateHalf(func); +} + +} // namespace // Register strategies for dot instructions. +Status HandleDot(std::unique_ptr& strategy_group, + StrategyGroups& strategy_groups, StrategyMap& strategy_map, + const HloInstruction* ins, size_t instruction_id, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingOption& option, + const CallGraph& call_graph) { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + strategy_groups); + + DotHandler handler(strategy_group, strategy_map, Cast(ins), + cluster_env, batch_map, option, call_graph); + TF_RETURN_IF_ERROR(handler.RegisterStrategies()); + return OkStatus(); +} + +// Register strategies for convolution instructions. Status HandleConv(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, From 812596bc441ffe0a347c8f9e990e0ea1d939b866 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 11:57:41 -0800 Subject: [PATCH 196/381] [XLA] No functional change: Refactoring cost analysis for fusions. PiperOrigin-RevId: 586412178 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/hlo_cost_analysis.cc | 115 ++++++++++++------ .../xla/xla/service/hlo_cost_analysis.h | 39 +++++- 3 files changed, 114 insertions(+), 41 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 18750c8c4f9a89..aef5a3f3e428dd 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4101,6 +4101,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:util", diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index 8207046fa6388a..9fe36cd2eb16ef 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/window_util.h" @@ -222,7 +223,11 @@ Status HloCostAnalysis::FusionCalculateUtilizations( // instruction. 
for (const HloInstruction* instr : fusion->fused_instructions_computation()->instructions()) { - hlo_properties_[instr][kUtilizationKey] = 1.f; + if (ShouldFilterFusionInstruction(fusion, instr)) { + hlo_properties_[instr][kUtilizationKey] = 0.f; + } else { + hlo_properties_[instr][kUtilizationKey] = 1.f; + } } return OkStatus(); } @@ -1009,28 +1014,11 @@ Status HloCostAnalysis::HandleRngGetAndUpdateState( return OkStatus(); } -Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { - VLOG(8) << "Processing fusion " << fusion->ToString(); - - if (fusion->IsCustomFusion()) { - for (const HloInstruction* hlo : - fusion->fused_instructions_computation()->instructions()) { - if (hlo->opcode() == HloOpcode::kGather) { - return HandleGather(hlo); - } - if (hlo->opcode() == HloOpcode::kScatter) { - return HandleScatter(hlo); - } - } - } - TF_ASSIGN_OR_RETURN( - current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation())); - +Status HloCostAnalysis::FusionProcessOutputBytesAccessed( + const HloInstruction* fusion) { // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to // prevent intermediate data from touching slow memory. - current_properties_[kBytesAccessedKey] = 0; ShapeUtil::ForEachSubshape( fusion->shape(), [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) { @@ -1039,7 +1027,17 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { } const HloInstruction* root = fusion->fused_expression_root(); - if (shape_index.size() == 1 && root->opcode() == HloOpcode::kTuple) { + + auto further_examine_index = + shape_index.size() == 1 && root->opcode() == HloOpcode::kTuple; + if (further_examine_index && + ShouldFilterFusionOutputIndex(fusion, shape_index)) { + current_properties_.set_output_bytes_accessed(shape_index, 0); + hlo_properties_[root->operand(shape_index[0])] + [GetOperandUtilizationKey(0)] = 0; + return; + } + if (further_examine_index) { root = root->operand(shape_index[0]); } @@ -1072,6 +1070,9 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { } for (int i = 0; i < shape.tuple_shapes_size(); ++i) { const Shape& subshape = shape.tuple_shapes(i); + if (!subshape.IsTuple() && ShouldFilterFusionOutputIndex(fusion, {i})) { + continue; + } ShapeIndex subshape_index(shape_index); subshape_index.push_back(i); bytes_accessed += @@ -1082,27 +1083,20 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_[GetOutputBytesAccessedKey()] = 0; propagate_output_size_to_parent(fusion->shape(), {}); } + return OkStatus(); +} - TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion)); - - // Count memory access to all large constants. 
- for (const HloInstruction* instr : - fusion->fused_instructions_computation()->instructions()) { - if (instr->opcode() == HloOpcode::kConstant && - ShapeUtil::ElementsIn(instr->shape()) > - immediate_constant_max_elements()) { - float utilization = hlo_properties_[instr][kUtilizationKey]; - if (!options_.count_multiple_input_accesses) { - utilization = fmin(utilization, 1.0); - } - current_properties_[kBytesAccessedKey] += - GetShapeSize(instr->shape()) * utilization; - } - } - +Status HloCostAnalysis::FusionProcessOperandBytesRead( + const HloInstruction* fusion) { for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) { const HloInstruction* operand = fusion->fused_parameter(i); int64_t operand_size = 0; + if (ShouldFilterFusionInput(fusion, i)) { + current_properties_.set_operand_bytes_accessed(i, operand_size); + current_properties_.set_operand_utilization( + i, hlo_properties_[operand][kUtilizationKey]); + continue; + } if (!operand->shape().IsTuple()) { operand_size = FusionParameterReadBytes(operand); } else { @@ -1131,6 +1125,51 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_.set_operand_utilization( i, hlo_properties_[operand][kUtilizationKey]); } + return OkStatus(); +} + +Status HloCostAnalysis::FusionCountConstantsMemoryAccess( + const HloInstruction* fusion) { + // Count memory access to all large constants. + for (const HloInstruction* instr : + fusion->fused_instructions_computation()->instructions()) { + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::ElementsIn(instr->shape()) > + immediate_constant_max_elements()) { + float utilization = hlo_properties_[instr][kUtilizationKey]; + if (!options_.count_multiple_input_accesses) { + utilization = fmin(utilization, 1.0); + } + current_properties_[kBytesAccessedKey] += + GetShapeSize(instr->shape()) * utilization; + } + } + return OkStatus(); +} + +Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { + VLOG(8) << "Processing fusion " << fusion->ToString(); + + if (fusion->IsCustomFusion()) { + for (const HloInstruction* hlo : + fusion->fused_instructions_computation()->instructions()) { + if (hlo->opcode() == HloOpcode::kGather) { + return HandleGather(hlo); + } + if (hlo->opcode() == HloOpcode::kScatter) { + return HandleScatter(hlo); + } + } + } + TF_ASSIGN_OR_RETURN( + current_properties_, + ProcessSubcomputation(fusion->fused_instructions_computation())); + + current_properties_[kBytesAccessedKey] = 0; + TF_RETURN_IF_ERROR(FusionProcessOutputBytesAccessed(fusion)); + TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion)); + TF_RETURN_IF_ERROR(FusionCountConstantsMemoryAccess(fusion)); + TF_RETURN_IF_ERROR(FusionProcessOperandBytesRead(fusion)); return OkStatus(); } diff --git a/third_party/xla/xla/service/hlo_cost_analysis.h b/third_party/xla/xla/service/hlo_cost_analysis.h index 8305b0fadd215c..705eb2437860fb 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.h +++ b/third_party/xla/xla/service/hlo_cost_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_COST_ANALYSIS_H_ #define XLA_SERVICE_HLO_COST_ANALYSIS_H_ +#include #include #include #include @@ -247,7 +248,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // props[kFlopsKey] gets optimized to `return flops_` just fine. // Getters/setters for more complex properties like operand utilization, - // where we have a fastpath for e.g. operand 0/1 + shape_index {}. + // where we have a fastpath, e.g., operand 0/1 + shape_index {}. 
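// (Editorial sketch, not part of this patch.) The refactoring above splits
// HandleFusion into FusionProcessOutputBytesAccessed,
// FusionCalculateUtilizations, FusionCountConstantsMemoryAccess and
// FusionProcessOperandBytesRead, and the ShouldFilterFusion* hooks declared
// further down in this header let a subclass exclude parts of a fusion from
// the cost. A minimal subclass could look like the following; the name
// MyCostAnalysis and the filtering policies are purely illustrative:
//
//   class MyCostAnalysis : public HloCostAnalysis {
//    public:
//     using HloCostAnalysis::HloCostAnalysis;
//
//    protected:
//     bool ShouldFilterFusionInput(const HloInstruction* fusion,
//                                  int64_t input_index) override {
//       return input_index == 0;  // e.g. ignore reads of operand 0
//     }
//     bool ShouldFilterFusionInstruction(
//         const HloInstruction* fusion,
//         const HloInstruction* instruction) override {
//       // e.g. give fused constants zero utilization
//       return instruction->opcode() == HloOpcode::kConstant;
//     }
//   };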
float operand_utilization(int64_t operand, const ShapeIndex& shape_index = {}) { if (operand == 0 && shape_index.empty()) { @@ -571,6 +572,37 @@ class HloCostAnalysis : public ConstDfsHloVisitor { const DotDimensionNumbers& dnums); protected: + // Computes the bytes accessed based on the outputs produced by the fusion + // instruction. + virtual Status FusionProcessOutputBytesAccessed(const HloInstruction* fusion); + + // Computes the bytes accessed (read) based on the inputs consumed by the + // fusion instruction. + virtual Status FusionProcessOperandBytesRead(const HloInstruction* fusion); + + // Computes memory access to all larger constants in the fusion instruction. + virtual Status FusionCountConstantsMemoryAccess(const HloInstruction* fusion); + + // Allows exclusion of certain types of inputs from bytes accessed during + // FusionProcessOperandBytesRead. + virtual bool ShouldFilterFusionInput(const HloInstruction* fusion, + int64_t input_index) { + return false; + } + + // Allows exclusion of certain instructions from FusionCalculateUtilizations. + virtual bool ShouldFilterFusionInstruction( + const HloInstruction* fusion, const HloInstruction* instruction) { + return false; + } + + // Allows exclusion of certain types of output from bytes written during + // FusionProcessOutputBytesAccessed. + virtual bool ShouldFilterFusionOutputIndex(const HloInstruction* fusion, + const ShapeIndex& output_index) { + return false; + } + typedef absl::flat_hash_map HloToProperties; @@ -588,7 +620,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // given hlo. The cost of visited sub HLO instructions is saved to // hlo_properties_, which will be used by functions such as // flop_count(hlo_instruction) to return cost of a particular HLO instruction. - StatusOr ProcessSubcomputation(HloComputation* computation); + virtual StatusOr ProcessSubcomputation( + HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); @@ -615,7 +648,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // bottleneck. bool current_should_compute_bottleneck_time_; - // The properties of the currently visited instruction. A HandleFoo method can + // The properties of the currently visited instruction. A HandleFoo method // modify these to change the default values computed in Preprocess. Properties current_properties_; From 7aeb9c8df5b2db6edcde1550e92b196db143b749 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 29 Nov 2023 11:59:06 -0800 Subject: [PATCH 197/381] Changed Version of Bazel to version 6.4.0 PiperOrigin-RevId: 586412544 --- .bazelversion | 2 +- tensorflow/core/kernels/mlir_generated/build_defs.bzl | 1 + tensorflow/tools/ci_build/release/common.sh | 2 +- third_party/xla/.bazelversion | 2 +- third_party/xla/third_party/tsl/.bazelversion | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.bazelversion b/.bazelversion index b536fbc5061305..204ac7c926e437 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1,2 +1,2 @@ -6.1.0 +6.4.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index fe071b2375ec0e..a3535c4ac93080 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -217,6 +217,7 @@ _gen_kernel_bin_rule = rule( outputs = {"kernel": "%{name}_kernel.o"}, toolchains = use_cpp_toolchain(), implementation = _gen_kernel_bin_impl, + provides = [CcInfo], ) # Returns the shape string (e.g. "4x4" or "16Bx2") as comma-separated integers. diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh index e5a490ac177e59..33cb30c1df1381 100644 --- a/tensorflow/tools/ci_build/release/common.sh +++ b/tensorflow/tools/ci_build/release/common.sh @@ -17,7 +17,7 @@ # Keeps Bazel versions of the build scripts. # LINT.IfChange -LATEST_BAZEL_VERSION=6.1.0 +LATEST_BAZEL_VERSION=6.4.0 # LINT.ThenChange( # //tensorflow/opensource_only/.bazelversion, # //tensorflow/tools/ci_build/install/install_bazel.sh, diff --git a/third_party/xla/.bazelversion b/third_party/xla/.bazelversion index b536fbc5061305..204ac7c926e437 100644 --- a/third_party/xla/.bazelversion +++ b/third_party/xla/.bazelversion @@ -1,2 +1,2 @@ -6.1.0 +6.4.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/.bazelversion b/third_party/xla/third_party/tsl/.bazelversion index b536fbc5061305..204ac7c926e437 100644 --- a/third_party/xla/third_party/tsl/.bazelversion +++ b/third_party/xla/third_party/tsl/.bazelversion @@ -1,2 +1,2 @@ -6.1.0 +6.4.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file From 0749e6a0a851976c9bd51520625f3d7a61a9c232 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 29 Nov 2023 12:40:57 -0800 Subject: [PATCH 198/381] [XLA:CPU] Add a direct implementation of ReduceScatter, instead of lowering ReduceScatter to AllReduce+DynamicSlice. 
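For reference, a reduce-scatter with a SUM reduction leaves each participant
with only its own chunk of the element-wise sum of all participants' inputs.
A minimal sketch of these semantics (editorial; the helper name and the use of
float buffers with a SUM reduction are illustrative, not taken from this
commit):

#include <cstddef>
#include <vector>

// Reference semantics: rank r receives reduce(inputs)[r*chunk, (r+1)*chunk).
std::vector<float> ReferenceReduceScatterSum(
    const std::vector<std::vector<float>>& inputs,  // one full buffer per rank
    int rank, std::size_t chunk_elems) {
  std::vector<float> out(chunk_elems, 0.0f);
  for (const std::vector<float>& in : inputs) {
    for (std::size_t i = 0; i < chunk_elems; ++i) {
      out[i] += in[rank * chunk_elems + i];  // reduce, then keep own chunk
    }
  }
  return out;
}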
PiperOrigin-RevId: 586424242 --- third_party/xla/xla/service/cpu/BUILD | 3 +- .../xla/service/cpu/collectives_interface.h | 6 + .../xla/xla/service/cpu/cpu_compiler.cc | 2 - .../xla/service/cpu/cpu_layout_assignment.cc | 7 + .../xla/xla/service/cpu/cpu_runtime.cc | 77 ++++- third_party/xla/xla/service/cpu/cpu_runtime.h | 8 + .../xla/service/cpu/in_process_collectives.cc | 280 +++++++++++++----- .../xla/service/cpu/in_process_collectives.h | 6 + third_party/xla/xla/service/cpu/ir_emitter.cc | 98 ++++-- .../xla/xla/service/cpu/simple_orc_jit.cc | 1 + 10 files changed, 381 insertions(+), 107 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 44dbf2f4db410d..a385177618b2db 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -305,7 +305,6 @@ cc_library( "//xla/service:optimization_barrier_expander", "//xla/service:qr_expander", "//xla/service:reduce_decomposer", - "//xla/service:reduce_scatter_decomposer", "//xla/service:reshape_decomposer", "//xla/service:reshape_mover", "//xla/service:result_caster", @@ -902,6 +901,7 @@ cc_library( "//xla:shape_util", "//xla:statusor", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -911,6 +911,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h index 4191df1d831fa2..e2c8190c81b986 100644 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -67,6 +67,12 @@ class CollectivesCommunicator { virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, void* output_buffer, absl::Duration timeout) = 0; + + // Performs a reduce-scatter + virtual absl::Status ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index e2723c77c18a7b..598281a271924b 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -191,7 +191,6 @@ limitations under the License. #include "xla/service/optimization_barrier_expander.h" #include "xla/service/qr_expander.h" #include "xla/service/reduce_decomposer.h" -#include "xla/service/reduce_scatter_decomposer.h" #include "xla/service/reshape_decomposer.h" #include "xla/service/reshape_mover.h" #include "xla/service/result_caster.h" @@ -685,7 +684,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); // Inline computations with a single call site. 
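// (Editorial sketch.) How a caller is expected to drive the new
// CollectivesCommunicator::ReduceScatter entry point declared in
// collectives_interface.h above; the wrapper name and the SUM/F32 choices
// are illustrative only.

#include <cstddef>

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "xla/service/cpu/collectives_interface.h"

absl::Status RunReduceScatterSum(xla::cpu::CollectivesCommunicator* comm,
                                 const xla::RendezvousKey& key,
                                 const float* input, float* output,
                                 std::size_t chunk_elems) {
  // Each rank contributes num_ranks * chunk_elems elements and receives the
  // reduced chunk corresponding to its own rank.
  return comm->ReduceScatter(key, xla::ReductionKind::SUM, xla::F32,
                             chunk_elems, input, output,
                             absl::InfiniteDuration());
}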
diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index 9b2b8331e2d3fd..71d7073dd53046 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -129,6 +129,13 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); + } else if (instruction->opcode() == HloOpcode::kReduceScatter) { + // XLA:CPU can only support reduce-scatter where the scatter dimension + // is the most major dimension in the layout. + auto ars = Cast(instruction); + TF_RETURN_IF_ERROR(SetInstructionLayout( + ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()), + ars)); } else if (instruction->opcode() == HloOpcode::kAllGather) { // XLA:CPU can only support all-gathers where the gather dimension is the // most major dimension in the layout. diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 81bd76652a3ac6..185ea948d2b196 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -29,6 +29,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -46,6 +49,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -143,6 +147,8 @@ extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; +extern const char* const kReduceScatterSymbolName = + "__xla_cpu_runtime_ReduceScatter"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -315,6 +321,19 @@ CollectivesInterface* GetInProcessCollectivesImpl() { absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); } +absl::StatusOr RankInGlobalDevices( + absl::Span devices, GlobalDeviceId device) { + auto it = absl::c_find(devices, device); + if (it == devices.end()) { + return InvalidArgument( + "Device %d not present in global devices %s.", device.value(), + absl::StrJoin(devices, ", ", [](std::string* out, GlobalDeviceId id) { + absl::StrAppend(out, id.value()); + })); + } + return std::distance(devices.begin(), it); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllToAllImpl(const ExecutableRunOptions* run_options, int32_t channel_id_present, int64_t op_id, @@ -331,9 +350,7 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != 
rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -361,9 +378,7 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -374,6 +389,35 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void ReduceScatterImpl(const ExecutableRunOptions* run_options, + const void* replica_groups_str, + int32_t replica_groups_str_size, + int32_t channel_id_present, int64_t op_id, + int32_t reduction_kind, int32_t element_type, + int64_t chunk_elems, void* input_buffer, + void* output_buffer) { + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + ParseReplicaGroupsOnly(replica_groups_serialized).value(); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + /*use_global_device_ids=*/std::nullopt, op_id); + + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + + CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->ReduceScatter( + rendezvous_key, static_cast(reduction_kind), + static_cast(element_type), chunk_elems, input_buffer, + output_buffer, DefaultCollectiveTimeout())); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -399,9 +443,7 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, CHECK((num_buffers > 1 && shape.IsTuple()) || (num_buffers == 1 && LayoutUtil::IsDenseArray(shape))); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -450,9 +492,7 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, {}, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -542,6 +582,19 @@ void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, run_options, channel_id_present, op_id, replica_groups_str, replica_groups_str_size, buffer_size, source_buffer, destination_buffer); } + +void __xla_cpu_runtime_ReduceScatter( + const 
xla::ExecutableRunOptions* run_options, + const void* replica_groups_str, int32_t replica_groups_str_size, + int32_t channel_id_present, int64_t op_id, int32_t reduction_kind, + int32_t element_type, int64_t chunk_elems, void* input_buffer, + void* output_buffer) { + return xla::cpu::runtime::ReduceScatterImpl( + run_options, replica_groups_str, replica_groups_str_size, + channel_id_present, op_id, reduction_kind, element_type, chunk_elems, + input_buffer, output_buffer); +} + void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 9429242d5f1b86..dd00571cb2e8dc 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -85,6 +85,7 @@ extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; extern const char* const kAllGatherSymbolName; +extern const char* const kReduceScatterSymbolName; extern const char* const kOneDnnMatMulSymbolName; // All symbol names for XLA CPU runtime functions need to start with this @@ -202,6 +203,13 @@ extern void __xla_cpu_runtime_AllGather( int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer, void* destination_buffer); +void __xla_cpu_runtime_ReduceScatter( + const xla::ExecutableRunOptions* run_options, + const void* replica_groups_str, int32_t replica_groups_str_size, + int32_t channel_id_present, int64_t op_id, int32_t reduction_kind, + int32_t element_type, int64_t chunk_elems, void* input_buffer, + void* output_buffer); + // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index 78eee1aa74f731..ed30082be82e8e 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { @@ -93,7 +94,7 @@ template constexpr bool always_false_v = false; template -void Reduce(absl::Span acc, absl::Span const> inputs) { +void ReduceHelper(absl::Span acc, absl::Span inputs) { // TODO(penporn): make sure this gets vectorized. 
if constexpr (reduction_kind == ReductionKind::SUM) { for (size_t j = 0; j < inputs.size(); ++j) { @@ -124,6 +125,49 @@ void Reduce(absl::Span acc, absl::Span const> inputs) { } } +template +absl::Status ReduceScatter(ReductionKind reduction_kind, + absl::Span inputs, void* output, + int64_t num_elems) { + using T = typename primitive_util::PrimitiveTypeToNative::type; + T initial_value = GetInitialValue(reduction_kind); + + absl::Span out_chunk = + absl::MakeSpan(reinterpret_cast(output), num_elems); + for (int64_t i = 0; i < num_elems; ++i) { + out_chunk[i] = initial_value; + } + + absl::Span input_chunks( + reinterpret_cast(inputs.data()), inputs.size()); + switch (reduction_kind) { + case ReductionKind::SUM: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::PRODUCT: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); +} + class CpuAllReduceRendezvous : public Rendezvous { public: @@ -146,110 +190,86 @@ class CpuAllReduceRendezvous return nullptr; } + auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); + int64_t chunk_offset = start_elem * bytes_per_elem; + int64_t chunk_bytes = chunk_elems * bytes_per_elem; + void* reduce_output = + reinterpret_cast(me.destination_data) + chunk_offset; + + std::vector inputs; + inputs.reserve(world_size); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_data) + + chunk_offset); + } + switch (me.primitive_type) { case S8: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case PRED: case U8: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, 
chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case C64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case C128: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; default: return absl::UnimplementedError("Unexpected datatype"); } - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; + // All-gather the reduced chunks. for (const auto& p : participants_) { if (p->local_rank != me.local_rank) { - std::memcpy( - reinterpret_cast(p->destination_data) + chunk_offset, - reinterpret_cast(me.destination_data) + chunk_offset, - chunk_bytes); + std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, + reduce_output, chunk_bytes); } } return nullptr; } - - template - absl::Status DoAllReduce(const AllReduceParticipantData& me, - int64_t start_elem, int64_t num_elems) { - using T = typename primitive_util::PrimitiveTypeToNative::type; - T initial_value = GetInitialValue(me.reduction_kind); - T* acc = reinterpret_cast(me.destination_data); - for (int64_t i = start_elem; i < start_elem + num_elems; ++i) { - acc[i] = initial_value; - } - - absl::Span out_chunk = absl::MakeSpan( - reinterpret_cast(me.destination_data) + start_elem, num_elems); - std::vector> inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(absl::Span( - reinterpret_cast(p->source_data) + start_elem, num_elems)); - } - switch (me.reduction_kind) { - case ReductionKind::SUM: - Reduce(out_chunk, inputs); - break; - case ReductionKind::PRODUCT: - Reduce(out_chunk, inputs); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - Reduce(out_chunk, inputs); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - Reduce(out_chunk, inputs); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); - } }; struct CollectivePermuteParticipantData : ParticipantData { @@ -378,6 +398,109 @@ class CpuAllGatherRendezvous } }; +struct ReduceScatterParticipantData : ParticipantData { + ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + ReductionKind reduction_kind; + PrimitiveType element_type; + const void* source_buffer; + void* destination_buffer; + size_t chunk_elems; + + std::string ToString() const override { + return absl::StrFormat( + "ReduceScatterParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_elems=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_elems); + } +}; + +class CpuReduceScatterRendezvous + : public Rendezvous { + public: + explicit CpuReduceScatterRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* 
collectives_; + absl::StatusOr RunCollectiveOp( + const ReduceScatterParticipantData& me) override { + auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); + int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; + + std::vector inputs; + inputs.reserve(participants_.size()); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_buffer) + + chunk_offset); + } + + switch (me.element_type) { + case S8: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case C64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case C128: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + default: + return absl::UnimplementedError("Unexpected datatype"); + } + + return nullptr; + } +}; + } // namespace struct InProcessCollectivesState { @@ -389,6 +512,8 @@ struct InProcessCollectivesState { all_to_all_rendezvous_map; RefcountingHashMap all_gather_rendezvous_map; + RefcountingHashMap + reduce_scatter_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ -488,6 +613,27 @@ absl::Status InProcessCollectivesCommunicator::AllGather( .status(); } +absl::Status InProcessCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + ReduceScatterParticipantData participant(key, rank_); + participant.element_type = element_type; + participant.reduction_kind = reduction_kind; + participant.chunk_elems = chunk_elems; + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuReduceScatterRendezvous::SubmitParticipant( + [&] { + return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} 
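// (Editorial note.) With the restructuring above, CpuAllReduceRendezvous no
// longer carries a per-type DoAllReduce: each participant reduces only its
// own chunk through the shared ReduceScatter helper and then copies that
// chunk into every peer's destination buffer, i.e. the all-reduce is carried
// out as a reduce-scatter followed by an all-gather of the per-rank chunks.
// Conceptually, for participant r with world size W over N total elements:
//   chunk_r      = [r*N/W, (r+1)*N/W)
//   out[chunk_r] = reduce_k(in_k[chunk_r])        // reduce-scatter step
//   copy out[chunk_r] into every peer's buffer    // all-gather step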
InProcessCollectives::~InProcessCollectives() = default; diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h index aaedc474fa39b2..f80baf38c4ebdc 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -59,6 +59,12 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { const void* input_buffer, void* output_buffer, absl::Duration timeout) override; + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + private: InProcessCollectivesState* state_; int rank_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index f5d7a4c2c40fab..46ae3978aaa2e6 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -1169,35 +1169,36 @@ Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) { return OkStatus(); } +// Data types supported by ReduceScatter and AllReduce. +static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) { + // TODO(cheshire): Fix duplication wrt. cpu_runtime + switch (datatype) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case C64: + case C128: + return true; + default: + return false; + } +} + Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { CHECK_GE(crs->operand_count(), 1); PrimitiveType datatype = crs->operand(0)->shape().element_type(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); - bool is_datatype_supported = [&] { - // TODO(cheshire): Fix duplication wrt. 
cpu_runtime - switch (datatype) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case S32: - case U32: - case S64: - case U64: - case F16: - case F32: - case F64: - case C64: - case C128: - return true; - default: - return false; - } - }(); - - if (!is_datatype_supported) { + if (!DataTypeIsSupportedByReduceScatter(datatype)) { return Unimplemented("AllReduce for datatype '%s' is not supported", primitive_util::LowercasePrimitiveTypeName(datatype)); } @@ -1285,7 +1286,54 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { } Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { - return Unimplemented("ReduceScatter is not implemented on CPU."); + CHECK_EQ(rs->operand_count(), 1); + PrimitiveType datatype = rs->operand(0)->shape().element_type(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs)); + + if (!DataTypeIsSupportedByReduceScatter(datatype)) { + return Unimplemented("ReduceScatter for datatype '%s' is not supported", + primitive_util::LowercasePrimitiveTypeName(datatype)); + } + + if (!MatchReductionComputation(rs->to_apply()).has_value()) { + return Unimplemented("ReduceScatter for computation '%s' is not supported", + rs->to_apply()->ToString()); + } + + std::string replica_groups = ReplicaGroupsToString(rs->replica_groups()); + int32_t replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + Shape shape = rs->operand(0)->shape(); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, + assignment_.GetUniqueSlice(rs->operand(0), {})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + assignment_.GetUniqueSlice(rs, {})); + llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape); + llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape); + + EmitCallToFunc( + runtime::kReduceScatterSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*replica_groups_str=*/replica_groups_v, + /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), + + /*channel_id_present=*/ + b_.getInt32(static_cast(rs->channel_id().has_value())), + /*op_id=*/ + b_.getInt64(rs->channel_id().has_value() ? 
*rs->channel_id() + : rs->GetModule()->unique_id()), + /*reduction_kind=*/ + b_.getInt32( + static_cast(*MatchReductionComputation(rs->to_apply()))), + /*element_type=*/ + b_.getInt32(static_cast(datatype)), + /*shape=*/b_.getInt64(ShapeUtil::ElementsIn(rs->shape())), + /*input_buffer=*/input_buffer, + /*output_buffer=*/output_buffer}, + b_.getVoidTy()); + + return OkStatus(); } Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 2e27a7c810869d..0cc07a27246f77 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -486,6 +486,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); REGISTER_CPU_RUNTIME_SYMBOL(AllGather); + REGISTER_CPU_RUNTIME_SYMBOL(ReduceScatter); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); From bf163a7d4536a86fa4099d0215685bb3290bc04b Mon Sep 17 00:00:00 2001 From: shawnwang18 <35983922+shawnwang18@users.noreply.github.com> Date: Wed, 29 Nov 2023 12:59:59 -0800 Subject: [PATCH 199/381] PR #7370: [XLA:GPU] fix command_buffer_thunk_test failure Imported from GitHub PR https://github.com/openxla/xla/pull/7370 Copybara import of the project: -- 2764c45902d16c7259a46d5ce926e2283ea171ce by Shawn Wang : fix command_buffer_thunk_test failure Merging this change closes #7370 PiperOrigin-RevId: 586429505 --- third_party/xla/xla/service/gpu/buffer_allocations.cc | 9 ++++++--- third_party/xla/xla/stream_executor/cuda/cuda_driver.cc | 2 ++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.cc b/third_party/xla/xla/service/gpu/buffer_allocations.cc index a1b8a7214f62b9..38c832d40654d6 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/buffer_allocations.cc @@ -87,14 +87,17 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size()); if (base.is_external_allocation_marker()) { - auto cmd_buffer_base = command_buffer->GetAllocationAddress( - buffer_slice.allocation()->index()); + auto cmd_buffer_base = + command_buffer->GetAllocationAddress(buffer_slice.index()); CHECK(cmd_buffer_base.ok()) << "Get allocation address from command_buffer failed"; CHECK(!cmd_buffer_base.value().is_null()) << "Allocation is not yet allocated by command buffer for slice: " << buffer_slice.ToString(); - return cmd_buffer_base.value(); + return se::DeviceMemoryBase( + static_cast(cmd_buffer_base.value().opaque()) + + buffer_slice.offset(), + buffer_slice.size()); } return se::DeviceMemoryBase( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 19fb144b9fc11e..e60e17d641ce6f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -952,6 +952,8 @@ static CUmemAllocationType ToCudaAllocationType( #if CUDA_VERSION >= 12000 mem_pool_props.maxSize = max_pool_size; #endif // CUDA_VERSION >= 12000 + // cuda graph requires reserved space initialized to 0 + memset(mem_pool_props.reserved, 0, sizeof(mem_pool_props.reserved)); params.accessDescCount = 1; params.bytesize = size; From d9e9f33d19146ffa6a30eaa9fa3a64ad242e088e Mon Sep 17 
00:00:00 2001 From: Sergei Lebedev Date: Wed, 29 Nov 2023 13:10:25 -0800 Subject: [PATCH 200/381] [xla:ffi] Added error reporting to existing decoders MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_BufferArgX1 6.05ns ± 1% 6.72ns ± 3% +10.95% (p=0.000 n=18+17) BM_BufferArgX4 11.5ns ± 6% 13.4ns ± 8% +15.73% (p=0.000 n=18+20) BM_TupleOfI32Attrs 36.8ns ± 1% 42.4ns ± 3% +15.11% (p=0.000 n=19+17) name old time/op new time/op delta BM_BufferArgX1 6.05ns ± 1% 6.72ns ± 3% +10.95% (p=0.000 n=18+17) BM_BufferArgX4 11.5ns ± 6% 13.4ns ± 8% +15.71% (p=0.000 n=18+20) BM_TupleOfI32Attrs 36.8ns ± 1% 42.4ns ± 3% +15.11% (p=0.000 n=19+17) PiperOrigin-RevId: 586432525 --- third_party/xla/xla/ffi/BUILD | 2 - third_party/xla/xla/ffi/api/BUILD | 6 +- third_party/xla/xla/ffi/api/api.h | 179 +++++++++++++++++++----- third_party/xla/xla/ffi/api/ffi.h | 33 ++++- third_party/xla/xla/ffi/api/ffi_test.cc | 52 ++++++- 5 files changed, 219 insertions(+), 53 deletions(-) diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 9b1b2cc67b442e..702552a61ada23 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -89,12 +89,10 @@ xla_cc_test( name = "ffi_test", srcs = ["ffi_test.cc"], deps = [ - ":api", ":call_frame", ":ffi", ":ffi_api", "//xla:xla_data_proto_cc", - "//xla/ffi/api:c_api", "//xla/service:executable", "//xla/stream_executor:device_memory", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index 6965ae7705fa5a..fa35ce81f57128 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -72,17 +72,15 @@ xla_cc_test( name = "ffi_test", srcs = ["ffi_test.cc"], deps = [ - ":api", ":ffi", "//xla:xla_data_proto_cc", "//xla/ffi:call_frame", "//xla/ffi:ffi_api", - "//xla/ffi/api:c_api", - "//xla/service:executable", "//xla/stream_executor:device_memory", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 9ebc01805d4b74..b37a170f57638d 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -25,12 +25,14 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include #include +#include #include // This is a header-only base C++ library that defines templates for decoding @@ -55,6 +57,22 @@ limitations under the License. #include "xla/ffi/api/c_api.h" +#if __has_attribute(always_inline) +#define XLA_ATTRIBUTE_ALWAYS_INLINE inline __attribute__((always_inline)) +#elif defined(_MSC_VER) +#define XLA_ATTRIBUTE_ALWAYS_INLINE __forceinline +#else +#define XLA_ATTRIBUTE_ALWAYS_INLINE inline +#endif + +#if __has_attribute(noinline) +#define XLA_ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) +#elif defined(_MSC_VER) +#define XLA_ATTRIBUTE_NEVER_INLINE __declspec(noinline) +#else +#define XLA_ATTRIBUTE_NEVER_INLINE +#endif + namespace xla::ffi { // Forward declare template defined below. 
@@ -354,7 +372,11 @@ class DiagnosticEngine; class InFlightDiagnostic { public: explicit InFlightDiagnostic(DiagnosticEngine* engine, std::string s) - : engine_(engine), stream_(std::move(s)) {} + : engine_(engine) { + stream_ << s; + } + InFlightDiagnostic(const InFlightDiagnostic&) = delete; + InFlightDiagnostic& operator=(const InFlightDiagnostic&) = delete; ~InFlightDiagnostic(); @@ -364,14 +386,12 @@ class InFlightDiagnostic { return *this; } - operator std::nullopt_t() const { // NOLINT + template + operator std::optional() const { // NOLINT return std::nullopt; } private: - InFlightDiagnostic& operator=(const InFlightDiagnostic&) = delete; - InFlightDiagnostic& operator=(InFlightDiagnostic&&) = delete; - DiagnosticEngine* engine_; std::stringstream stream_; }; @@ -386,14 +406,14 @@ class DiagnosticEngine { return InFlightDiagnostic(this, std::move(message)); } - std::string Result() const { return s_; } + std::string Result() const { return acc_; } private: friend class InFlightDiagnostic; - void append(std::string s) { s_.append(std::move(s)); } + void append(std::string s) { acc_.append(std::move(s)); } - std::string s_; + std::string acc_; }; inline InFlightDiagnostic::~InFlightDiagnostic() { @@ -422,6 +442,7 @@ struct DecodingContext { template struct Decode { + XLA_ATTRIBUTE_ALWAYS_INLINE static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, DiagnosticEngine& diagnostic) { int64_t idx = offsets.args++; @@ -470,6 +491,52 @@ struct internal::Decode> { } }; +//===----------------------------------------------------------------------===// +// Expected +//===----------------------------------------------------------------------===// + +// Forward declare. +template +class Unexpected; + +// TODO(slebedev): Replace with `std::expected` when C++23 is available. +template +class Expected { + public: + Expected(T value) : data_(std::move(value)) {} // NOLINT + Expected(Unexpected u); // NOLINT + + operator bool() const { // NOLINT + return has_value(); + } + T operator*() const { return value(); } + T* operator->() const { return &value(); } + + bool has_value() const { return std::holds_alternative(data_); } + T value() const { return std::get(data_); } + E error() const { return std::get(data_); } + + private: + std::variant data_; +}; + +template +class Unexpected { + public: + explicit Unexpected(E error) : error_(std::move(error)) {} + + private: + template + friend class Expected; + + E error_; +}; + +Unexpected(const char*) -> Unexpected; + +template +Expected::Expected(Unexpected u) : data_(std::move(u.error_)) {} + //===----------------------------------------------------------------------===// // Type-safe wrapper for accessing a variable number of arguments. //===----------------------------------------------------------------------===// @@ -485,14 +552,19 @@ class RemainingArgs { bool empty() const { return size() == 0; } template - std::optional get(size_t index) const { + Expected get(size_t index) const { size_t idx = offset_ + index; - if (idx >= args_->num_args) return std::nullopt; + if (idx >= args_->num_args) { + return Unexpected("Index out of range."); + } - // TODO(slebedev): Expose the collected diagnostic to the caller. 
DiagnosticEngine diagnostic; - return ArgDecoding::Decode(args_->types[idx], args_->args[idx], - diagnostic); + auto value_opt = + ArgDecoding::Decode(args_->types[idx], args_->args[idx], diagnostic); + if (!value_opt.has_value()) { + return Unexpected(diagnostic.Result()); + } + return *value_opt; } private: @@ -524,15 +596,25 @@ class Dictionary { } template - std::optional get(std::string_view name) const { + Expected get(std::string_view name) const { + DiagnosticEngine diagnostic; + auto value_opt = get(name, diagnostic); + if (!value_opt.has_value()) { + return Unexpected(diagnostic.Result()); + } + return *value_opt; + } + + template + std::optional get(std::string_view name, + DiagnosticEngine& diagnostic) const { size_t idx = Find(name); - if (idx >= attrs_->num_attrs) return std::nullopt; + if (idx >= attrs_->num_attrs) { + return diagnostic.Emit("Unexpected attribute: ") << name; + } XLA_FFI_AttrType attr_type = attrs_->types[idx]; void* attr = attrs_->attrs[idx]; - - // TODO(slebedev): Expose the collected diagnostic to the caller. - DiagnosticEngine diagnostic; return AttrDecoding::Decode(attr_type, attr, diagnostic); } @@ -710,8 +792,8 @@ class Handler : public Ffi { private: template - XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame, - std::index_sequence) const { + XLA_ATTRIBUTE_ALWAYS_INLINE XLA_FFI_Error* Call( + const XLA_FFI_CallFrame* call_frame, std::index_sequence) const { // A helper structure to allow each decoder find the correct offset. internal::DecodingOffsets offsets; @@ -791,13 +873,29 @@ class Handler : public Ffi { // Builtin attributes decoding //===----------------------------------------------------------------------===// +inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_AttrType type) { + switch (type) { + case XLA_FFI_AttrType_I32: + return os << "int32"; + case XLA_FFI_AttrType_I64: + return os << "int64"; + case XLA_FFI_AttrType_F32: + return os << "float"; + case XLA_FFI_AttrType_STRING: + return os << "string"; + case XLA_FFI_AttrType_DICTIONARY: + return os << "dictionary"; + } +} + #define XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(T, TYPE) \ template <> \ struct AttrDecoding { \ static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ - DiagnosticEngine&) { \ + DiagnosticEngine& diagnostic) { \ if (type != TYPE) { \ - return std::nullopt; \ + return diagnostic.Emit("Wrong attribute type: expected ") \ + << TYPE << " but got " << type; \ } \ \ return *reinterpret_cast(attr); \ @@ -813,9 +911,11 @@ XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(float, XLA_FFI_AttrType_F32); template <> struct AttrDecoding { static std::optional Decode(XLA_FFI_AttrType type, - void* attr, DiagnosticEngine&) { + void* attr, + DiagnosticEngine& diagnostic) { if (type != XLA_FFI_AttrType_STRING) { - return std::nullopt; + return diagnostic.Emit("Wrong attribute type: expected ") + << XLA_FFI_AttrType_STRING << " but got " << type; } auto* span = reinterpret_cast(attr); @@ -826,9 +926,10 @@ struct AttrDecoding { template <> struct AttrDecoding { static std::optional Decode(XLA_FFI_AttrType type, void* attr, - DiagnosticEngine&) { + DiagnosticEngine& diagnostic) { if (type != XLA_FFI_AttrType_DICTIONARY) { - return std::nullopt; + return diagnostic.Emit("Wrong attribute type: expected ") + << XLA_FFI_AttrType_DICTIONARY << " but got " << type; } auto* attrs = reinterpret_cast(attr); @@ -856,16 +957,21 @@ template struct DecodeDictionaryAttr { static constexpr size_t kSize = sizeof...(Ts); + XLA_ATTRIBUTE_ALWAYS_INLINE static std::optional Decode(const 
XLA_FFI_Attrs* attrs, - std::array names) { - return Decode(attrs, names, std::make_index_sequence{}); + std::array names, + DiagnosticEngine& diagnostic) { + return Decode(attrs, names, std::make_index_sequence{}, diagnostic); } template - static std::optional Decode(const XLA_FFI_Attrs* attrs, - std::array names, - std::index_sequence) { - if (kSize != attrs->num_attrs) return std::nullopt; + XLA_ATTRIBUTE_ALWAYS_INLINE static std::optional Decode( + const XLA_FFI_Attrs* attrs, std::array names, + std::index_sequence, DiagnosticEngine& diagnostic) { + if (kSize != attrs->num_attrs) { + return diagnostic.Emit("Wrong number of attributes: expected ") + << kSize << " attributes but got " << attrs->num_attrs; + } // TODO(ezhulenev): We rely on dictionary to lookup struct members by name // at run time, however it can become really expensive. We should @@ -877,7 +983,8 @@ struct DecodeDictionaryAttr { // constructor. Add benchmarks first to know what to improve! Dictionary dict(attrs); - std::tuple...> members = {dict.get(names[Is])...}; + std::tuple...> members = { + dict.get(names[Is], diagnostic)...}; bool all_decoded = (std::get(members).has_value() && ...); if (!all_decoded) return std::nullopt; @@ -910,15 +1017,17 @@ auto DictionaryDecoder(Members... m) { template <> \ struct AttrDecoding { \ static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ - DiagnosticEngine&) { \ + DiagnosticEngine& diagnostic) { \ if (type != XLA_FFI_AttrType_DICTIONARY) { \ + diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_DICTIONARY << " but got " << type; \ return std::nullopt; \ } \ \ auto decoder = internal::DictionaryDecoder(__VA_ARGS__); \ return decltype(decoder)::Decode( \ reinterpret_cast(attr), \ - internal::StructMemberNames(__VA_ARGS__)); \ + internal::StructMemberNames(__VA_ARGS__), diagnostic); \ } \ } diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 9982335fe12b47..113209123415bd 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -53,6 +54,14 @@ enum class DataType : uint8_t { BF16 = XLA_FFI_DataType_BF16, }; +inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { + static constexpr const char* kDataTypeNames[] = { + "PRED", "S8", "S16", "S32", "S64", "U8", "U16", + "U32", "U64", "F16", "F32", "F64", "BF16", + }; + return os << kDataTypeNames[static_cast(dtype)]; +} + //===----------------------------------------------------------------------===// // Span is non-owning view into contiguous values of type `T`. 
//===----------------------------------------------------------------------===// @@ -146,17 +155,31 @@ struct BufferBase { // Arguments decoding //===----------------------------------------------------------------------===// +inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_ArgType type) { + switch (type) { + case XLA_FFI_ArgType_BUFFER: + return os << "buffer"; + } +} + template struct ArgDecoding> { + XLA_ATTRIBUTE_ALWAYS_INLINE static std::optional> Decode(XLA_FFI_ArgType type, - void* arg, DiagnosticEngine&) { - if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt; + void* arg, + DiagnosticEngine& diagnostic) { + if (type != XLA_FFI_ArgType_BUFFER) { + return diagnostic.Emit("Wrong argument type: expected ") + << XLA_FFI_ArgType_BUFFER << " but got " << type; + } auto* buf = reinterpret_cast(arg); - // TODO(slebedev): Emit a user-friendly error instead. - if (static_cast(buf->dtype) != dtype) return std::nullopt; + if (auto actual_dtype = static_cast(buf->dtype); + actual_dtype != dtype) { + return diagnostic.Emit("Wrong buffer dtype: expected ") + << dtype << " but got " << actual_dtype; + } auto* data = static_cast::Type*>(buf->data); - return BufferBase{data, Span(buf->dims, buf->rank)}; } }; diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index 7b4715bdfc4d41..212d7ea855c24f 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -25,11 +25,20 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" namespace xla::ffi { +namespace { + +using ::testing::HasSubstr; +using ::tsl::testing::StatusIs; + +} // namespace + TEST(FfiTest, DataTypeEnumValue) { // Verify that xla::PrimitiveType and xla::ffi::DataType use the same // integer value for encoding data types. 
@@ -62,18 +71,47 @@ TEST(FfiTest, BufferArgument) { builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto fn = [&](BufferBase buffer) { - EXPECT_EQ(buffer.data, storage.data()); - EXPECT_EQ(buffer.dimensions.size(), 2); - return Error::Success(); - }; - - auto handler = Ffi::Bind().Arg>().To(fn); + auto handler = + Ffi::Bind().Arg>().To([&](auto buffer) { + EXPECT_EQ(buffer.data, storage.data()); + EXPECT_EQ(buffer.dimensions.size(), 2); + return Error::Success(); + }); auto status = Call(*handler, call_frame); TF_ASSERT_OK(status); } +TEST(FfiTest, MissingBufferArgument) { + CallFrameBuilder builder; + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg>().To( + [](auto) { return Error::Success(); }); + auto status = Call(*handler, call_frame); + + EXPECT_THAT(status, StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("Wrong number of arguments"))); +} + +TEST(FfiTest, WrongTypeBufferArgument) { + std::vector storage(4, 0.0); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(std::int32_t)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg>().To( + [](auto) { return Error::Success(); }); + auto status = Call(*handler, call_frame); + + EXPECT_THAT( + status, + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("Wrong buffer dtype: expected F64 but got S64"))); +} + //===----------------------------------------------------------------------===// // Performance benchmarks are below. //===----------------------------------------------------------------------===// From e2b224e8d0a501b40db01bea3db19cef818fae1c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 29 Nov 2023 13:32:44 -0800 Subject: [PATCH 201/381] [XLA:CPU] Add a direct implementation of ReduceScatter, instead of lowering ReduceScatter to AllReduce+DynamicSlice. 
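For readers skimming the patch series, the sketch below shows what a reduce-scatter with a SUM reduction computes in a single process: participant r element-wise reduces the r-th chunk of every participant's input and keeps only that chunk. This is a minimal illustration only; the helper name ReduceScatterSum, the float element type, and the vector-of-vectors signature are assumptions for clarity and do not correspond to the XLA:CPU runtime entry points touched here.

#include <cstddef>
#include <vector>

// Each of the num_ranks inputs holds num_ranks * chunk_elems elements; rank r
// ends up with the element-wise sum of everyone's r-th chunk.
std::vector<std::vector<float>> ReduceScatterSum(
    const std::vector<std::vector<float>>& inputs, std::size_t chunk_elems) {
  const std::size_t num_ranks = inputs.size();
  std::vector<std::vector<float>> outputs(
      num_ranks, std::vector<float>(chunk_elems, 0.0f));
  for (std::size_t rank = 0; rank < num_ranks; ++rank) {
    const std::size_t offset = rank * chunk_elems;
    for (const std::vector<float>& input : inputs) {  // reduce across ranks
      for (std::size_t i = 0; i < chunk_elems; ++i) {
        outputs[rank][i] += input[offset + i];
      }
    }
  }
  return outputs;
}

A lowering through AllReduce+DynamicSlice computes the full reduced buffer on every participant and then slices out chunk r, whereas a direct implementation only has to materialize each participant's own chunk.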
PiperOrigin-RevId: 586438695 --- third_party/xla/xla/service/cpu/BUILD | 3 +- .../xla/service/cpu/collectives_interface.h | 6 - .../xla/xla/service/cpu/cpu_compiler.cc | 2 + .../xla/service/cpu/cpu_layout_assignment.cc | 7 - .../xla/xla/service/cpu/cpu_runtime.cc | 77 +---- third_party/xla/xla/service/cpu/cpu_runtime.h | 8 - .../xla/service/cpu/in_process_collectives.cc | 280 +++++------------- .../xla/service/cpu/in_process_collectives.h | 6 - third_party/xla/xla/service/cpu/ir_emitter.cc | 98 ++---- .../xla/xla/service/cpu/simple_orc_jit.cc | 1 - 10 files changed, 107 insertions(+), 381 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index a385177618b2db..44dbf2f4db410d 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -305,6 +305,7 @@ cc_library( "//xla/service:optimization_barrier_expander", "//xla/service:qr_expander", "//xla/service:reduce_decomposer", + "//xla/service:reduce_scatter_decomposer", "//xla/service:reshape_decomposer", "//xla/service:reshape_mover", "//xla/service:result_caster", @@ -901,7 +902,6 @@ cc_library( "//xla:shape_util", "//xla:statusor", "//xla:types", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -911,7 +911,6 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h index e2c8190c81b986..4191df1d831fa2 100644 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -67,12 +67,6 @@ class CollectivesCommunicator { virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, void* output_buffer, absl::Duration timeout) = 0; - - // Performs a reduce-scatter - virtual absl::Status ReduceScatter( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, - void* output_buffer, absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 598281a271924b..e2723c77c18a7b 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -191,6 +191,7 @@ limitations under the License. #include "xla/service/optimization_barrier_expander.h" #include "xla/service/qr_expander.h" #include "xla/service/reduce_decomposer.h" +#include "xla/service/reduce_scatter_decomposer.h" #include "xla/service/reshape_decomposer.h" #include "xla/service/reshape_mover.h" #include "xla/service/result_caster.h" @@ -684,6 +685,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); // Inline computations with a single call site. 
diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index 71d7073dd53046..9b2b8331e2d3fd 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -129,13 +129,6 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (instruction->opcode() == HloOpcode::kReduceScatter) { - // XLA:CPU can only support reduce-scatter where the scatter dimension - // is the most major dimension in the layout. - auto ars = Cast(instruction); - TF_RETURN_IF_ERROR(SetInstructionLayout( - ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()), - ars)); } else if (instruction->opcode() == HloOpcode::kAllGather) { // XLA:CPU can only support all-gathers where the gather dimension is the // most major dimension in the layout. diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 185ea948d2b196..81bd76652a3ac6 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -29,9 +29,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -49,7 +46,6 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -147,8 +143,6 @@ extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; -extern const char* const kReduceScatterSymbolName = - "__xla_cpu_runtime_ReduceScatter"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -321,19 +315,6 @@ CollectivesInterface* GetInProcessCollectivesImpl() { absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); } -absl::StatusOr RankInGlobalDevices( - absl::Span devices, GlobalDeviceId device) { - auto it = absl::c_find(devices, device); - if (it == devices.end()) { - return InvalidArgument( - "Device %d not present in global devices %s.", device.value(), - absl::StrJoin(devices, ", ", [](std::string* out, GlobalDeviceId id) { - absl::StrAppend(out, id.value()); - })); - } - return std::distance(devices.begin(), it); -} - ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllToAllImpl(const ExecutableRunOptions* run_options, int32_t channel_id_present, int64_t op_id, @@ -350,7 +331,9 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + auto it = 
absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -378,7 +361,9 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -389,35 +374,6 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY -void ReduceScatterImpl(const ExecutableRunOptions* run_options, - const void* replica_groups_str, - int32_t replica_groups_str_size, - int32_t channel_id_present, int64_t op_id, - int32_t reduction_kind, int32_t element_type, - int64_t chunk_elems, void* input_buffer, - void* output_buffer) { - GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view replica_groups_serialized( - static_cast(replica_groups_str), replica_groups_str_size); - std::vector group = - ParseReplicaGroupsOnly(replica_groups_serialized).value(); - RendezvousKey rendezvous_key = - GetRendezvousKey(run_options, device, group, channel_id_present, - /*use_global_device_ids=*/std::nullopt, op_id); - - int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - - CollectivesInterface* collectives = GetInProcessCollectivesImpl(); - - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); - TF_CHECK_OK(communicator->ReduceScatter( - rendezvous_key, static_cast(reduction_kind), - static_cast(element_type), chunk_elems, input_buffer, - output_buffer, DefaultCollectiveTimeout())); -} - ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -443,7 +399,9 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, CHECK((num_buffers > 1 && shape.IsTuple()) || (num_buffers == 1 && LayoutUtil::IsDenseArray(shape))); - int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -492,7 +450,9 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, {}, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -582,19 +542,6 @@ void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, run_options, channel_id_present, op_id, replica_groups_str, replica_groups_str_size, buffer_size, source_buffer, destination_buffer); } - -void __xla_cpu_runtime_ReduceScatter( - const xla::ExecutableRunOptions* 
run_options, - const void* replica_groups_str, int32_t replica_groups_str_size, - int32_t channel_id_present, int64_t op_id, int32_t reduction_kind, - int32_t element_type, int64_t chunk_elems, void* input_buffer, - void* output_buffer) { - return xla::cpu::runtime::ReduceScatterImpl( - run_options, replica_groups_str, replica_groups_str_size, - channel_id_present, op_id, reduction_kind, element_type, chunk_elems, - input_buffer, output_buffer); -} - void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index dd00571cb2e8dc..9429242d5f1b86 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -85,7 +85,6 @@ extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; extern const char* const kAllGatherSymbolName; -extern const char* const kReduceScatterSymbolName; extern const char* const kOneDnnMatMulSymbolName; // All symbol names for XLA CPU runtime functions need to start with this @@ -203,13 +202,6 @@ extern void __xla_cpu_runtime_AllGather( int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer, void* destination_buffer); -void __xla_cpu_runtime_ReduceScatter( - const xla::ExecutableRunOptions* run_options, - const void* replica_groups_str, int32_t replica_groups_str_size, - int32_t channel_id_present, int64_t op_id, int32_t reduction_kind, - int32_t element_type, int64_t chunk_elems, void* input_buffer, - void* output_buffer); - // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index ed30082be82e8e..78eee1aa74f731 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -39,7 +39,6 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/status_macros.h" #include "xla/util.h" -#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { @@ -94,7 +93,7 @@ template constexpr bool always_false_v = false; template -void ReduceHelper(absl::Span acc, absl::Span inputs) { +void Reduce(absl::Span acc, absl::Span const> inputs) { // TODO(penporn): make sure this gets vectorized. 
if constexpr (reduction_kind == ReductionKind::SUM) { for (size_t j = 0; j < inputs.size(); ++j) { @@ -125,49 +124,6 @@ void ReduceHelper(absl::Span acc, absl::Span inputs) { } } -template -absl::Status ReduceScatter(ReductionKind reduction_kind, - absl::Span inputs, void* output, - int64_t num_elems) { - using T = typename primitive_util::PrimitiveTypeToNative::type; - T initial_value = GetInitialValue(reduction_kind); - - absl::Span out_chunk = - absl::MakeSpan(reinterpret_cast(output), num_elems); - for (int64_t i = 0; i < num_elems; ++i) { - out_chunk[i] = initial_value; - } - - absl::Span input_chunks( - reinterpret_cast(inputs.data()), inputs.size()); - switch (reduction_kind) { - case ReductionKind::SUM: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::PRODUCT: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); -} - class CpuAllReduceRendezvous : public Rendezvous { public: @@ -190,86 +146,110 @@ class CpuAllReduceRendezvous return nullptr; } - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; - void* reduce_output = - reinterpret_cast(me.destination_data) + chunk_offset; - - std::vector inputs; - inputs.reserve(world_size); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_data) + - chunk_offset); - } - switch (me.primitive_type) { case S8: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case PRED: case U8: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case S16: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case U16: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case S32: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case U32: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case S64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case U64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case F16: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case F32: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, 
inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case F64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case C64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; case C128: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); + TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); break; default: return absl::UnimplementedError("Unexpected datatype"); } - // All-gather the reduced chunks. + auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); + int64_t chunk_offset = start_elem * bytes_per_elem; + int64_t chunk_bytes = chunk_elems * bytes_per_elem; for (const auto& p : participants_) { if (p->local_rank != me.local_rank) { - std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, - reduce_output, chunk_bytes); + std::memcpy( + reinterpret_cast(p->destination_data) + chunk_offset, + reinterpret_cast(me.destination_data) + chunk_offset, + chunk_bytes); } } return nullptr; } + + template + absl::Status DoAllReduce(const AllReduceParticipantData& me, + int64_t start_elem, int64_t num_elems) { + using T = typename primitive_util::PrimitiveTypeToNative::type; + T initial_value = GetInitialValue(me.reduction_kind); + T* acc = reinterpret_cast(me.destination_data); + for (int64_t i = start_elem; i < start_elem + num_elems; ++i) { + acc[i] = initial_value; + } + + absl::Span out_chunk = absl::MakeSpan( + reinterpret_cast(me.destination_data) + start_elem, num_elems); + std::vector> inputs; + inputs.reserve(participants_.size()); + for (const auto& p : participants_) { + inputs.push_back(absl::Span( + reinterpret_cast(p->source_data) + start_elem, num_elems)); + } + switch (me.reduction_kind) { + case ReductionKind::SUM: + Reduce(out_chunk, inputs); + break; + case ReductionKind::PRODUCT: + Reduce(out_chunk, inputs); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + Reduce(out_chunk, inputs); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + Reduce(out_chunk, inputs); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); + } }; struct CollectivePermuteParticipantData : ParticipantData { @@ -398,109 +378,6 @@ class CpuAllGatherRendezvous } }; -struct ReduceScatterParticipantData : ParticipantData { - ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - ReductionKind reduction_kind; - PrimitiveType element_type; - const void* source_buffer; - void* destination_buffer; - size_t chunk_elems; - - std::string ToString() const override { - return absl::StrFormat( - "ReduceScatterParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_elems=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_elems); - } -}; - -class CpuReduceScatterRendezvous - : public Rendezvous { - public: - explicit CpuReduceScatterRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - 
absl::StatusOr RunCollectiveOp( - const ReduceScatterParticipantData& me) override { - auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); - int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; - - std::vector inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_buffer) + - chunk_offset); - } - - switch (me.element_type) { - case S8: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case S16: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case U16: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case S32: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case U32: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case S64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case U64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case F16: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case F32: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case F64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case C64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case C128: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - default: - return absl::UnimplementedError("Unexpected datatype"); - } - - return nullptr; - } -}; - } // namespace struct InProcessCollectivesState { @@ -512,8 +389,6 @@ struct InProcessCollectivesState { all_to_all_rendezvous_map; RefcountingHashMap all_gather_rendezvous_map; - RefcountingHashMap - reduce_scatter_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ -613,27 +488,6 @@ absl::Status InProcessCollectivesCommunicator::AllGather( .status(); } -absl::Status InProcessCollectivesCommunicator::ReduceScatter( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - ReduceScatterParticipantData participant(key, rank_); - participant.element_type = element_type; - participant.reduction_kind = reduction_kind; - participant.chunk_elems = chunk_elems; - participant.source_buffer = input_buffer; - participant.destination_buffer = output_buffer; - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuReduceScatterRendezvous::SubmitParticipant( - [&] { - return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; 
diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h index f80baf38c4ebdc..aaedc474fa39b2 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -59,12 +59,6 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { const void* input_buffer, void* output_buffer, absl::Duration timeout) override; - absl::Status ReduceScatter(const RendezvousKey& key, - ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - private: InProcessCollectivesState* state_; int rank_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 46ae3978aaa2e6..f5d7a4c2c40fab 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -1169,36 +1169,35 @@ Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) { return OkStatus(); } -// Data types supported by ReduceScatter and AllReduce. -static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) { - // TODO(cheshire): Fix duplication wrt. cpu_runtime - switch (datatype) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case S32: - case U32: - case S64: - case U64: - case F16: - case F32: - case F64: - case C64: - case C128: - return true; - default: - return false; - } -} - Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { CHECK_GE(crs->operand_count(), 1); PrimitiveType datatype = crs->operand(0)->shape().element_type(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); - if (!DataTypeIsSupportedByReduceScatter(datatype)) { + bool is_datatype_supported = [&] { + // TODO(cheshire): Fix duplication wrt. 
cpu_runtime + switch (datatype) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case C64: + case C128: + return true; + default: + return false; + } + }(); + + if (!is_datatype_supported) { return Unimplemented("AllReduce for datatype '%s' is not supported", primitive_util::LowercasePrimitiveTypeName(datatype)); } @@ -1286,54 +1285,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { } Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { - CHECK_EQ(rs->operand_count(), 1); - PrimitiveType datatype = rs->operand(0)->shape().element_type(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs)); - - if (!DataTypeIsSupportedByReduceScatter(datatype)) { - return Unimplemented("ReduceScatter for datatype '%s' is not supported", - primitive_util::LowercasePrimitiveTypeName(datatype)); - } - - if (!MatchReductionComputation(rs->to_apply()).has_value()) { - return Unimplemented("ReduceScatter for computation '%s' is not supported", - rs->to_apply()->ToString()); - } - - std::string replica_groups = ReplicaGroupsToString(rs->replica_groups()); - int32_t replica_groups_size = replica_groups.size(); - llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); - - Shape shape = rs->operand(0)->shape(); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, - assignment_.GetUniqueSlice(rs->operand(0), {})); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - assignment_.GetUniqueSlice(rs, {})); - llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape); - llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape); - - EmitCallToFunc( - runtime::kReduceScatterSymbolName, - {/*run_options=*/GetExecutableRunOptionsArgument(), - /*replica_groups_str=*/replica_groups_v, - /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), - - /*channel_id_present=*/ - b_.getInt32(static_cast(rs->channel_id().has_value())), - /*op_id=*/ - b_.getInt64(rs->channel_id().has_value() ? 
*rs->channel_id() - : rs->GetModule()->unique_id()), - /*reduction_kind=*/ - b_.getInt32( - static_cast(*MatchReductionComputation(rs->to_apply()))), - /*element_type=*/ - b_.getInt32(static_cast(datatype)), - /*shape=*/b_.getInt64(ShapeUtil::ElementsIn(rs->shape())), - /*input_buffer=*/input_buffer, - /*output_buffer=*/output_buffer}, - b_.getVoidTy()); - - return OkStatus(); + return Unimplemented("ReduceScatter is not implemented on CPU."); } Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 0cc07a27246f77..2e27a7c810869d 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -486,7 +486,6 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); REGISTER_CPU_RUNTIME_SYMBOL(AllGather); - REGISTER_CPU_RUNTIME_SYMBOL(ReduceScatter); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); From b0cc761464989edc819fce9ec8bbc26399ff89c5 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 29 Nov 2023 14:03:26 -0800 Subject: [PATCH 202/381] Remove obsolete TODO PiperOrigin-RevId: 586447290 --- third_party/xla/.kokoro/jax/build.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index 305c1806c106ea..d533623a868f06 100644 --- a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -74,7 +74,6 @@ build_and_test_on_rbe_gpu() { # Runs non-multiaccelerator tests with one GPU apiece. # It appears --run_under needs an absolute path. - # TODO(ddunleavy): reenable `LaxTest.testBitcastConvertType` bazel \ test \ --verbose_failures=true \ From 3a5c401e70a7e9957e4c55cedf50fd15d70e0b67 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 29 Nov 2023 14:12:05 -0800 Subject: [PATCH 203/381] [XLA:CPU] Add a direct implementation of AllGather, rather than lowering AllGather to AllReduce. 
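For comparison with the reduce-scatter change above, the sketch below shows the semantics of all-gather in a single process: every participant contributes one chunk, and every participant receives the concatenation of all chunks in rank order. This is an illustration only; the helper name AllGatherChunks and the float/vector types are assumptions, not the runtime API this change touches.

#include <vector>

// chunks[r] is rank r's contribution; every rank receives the same
// concatenation of all contributions, ordered by rank.
std::vector<std::vector<float>> AllGatherChunks(
    const std::vector<std::vector<float>>& chunks) {
  std::vector<float> gathered;
  for (const std::vector<float>& chunk : chunks) {
    gathered.insert(gathered.end(), chunk.begin(), chunk.end());
  }
  // Each participant gets an identical copy of the gathered buffer.
  return std::vector<std::vector<float>>(chunks.size(), gathered);
}

A lowering through AllReduce typically zero-pads each contribution to the full output shape and sums, so a direct implementation avoids reducing data that is simply being copied.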
PiperOrigin-RevId: 586449725 --- third_party/xla/xla/service/cpu/BUILD | 3 +- .../xla/service/cpu/collectives_interface.h | 9 +-- .../xla/xla/service/cpu/cpu_compiler.cc | 2 + .../xla/service/cpu/cpu_layout_assignment.cc | 10 ---- .../xla/xla/service/cpu/cpu_runtime.cc | 39 ------------ third_party/xla/xla/service/cpu/cpu_runtime.h | 7 --- .../xla/service/cpu/in_process_collectives.cc | 59 ------------------- .../xla/service/cpu/in_process_collectives.h | 4 -- third_party/xla/xla/service/cpu/ir_emitter.cc | 46 --------------- third_party/xla/xla/service/cpu/ir_emitter.h | 1 - .../xla/xla/service/cpu/simple_orc_jit.cc | 1 - 11 files changed, 5 insertions(+), 176 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 44dbf2f4db410d..ef51b94dcd80ad 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -250,6 +250,7 @@ cc_library( "//xla/runtime:executable", "//xla/runtime:jit_executable", "//xla/service:algebraic_simplifier", + "//xla/service:all_gather_decomposer", "//xla/service:all_reduce_promotion", "//xla/service:all_to_all_decomposer", "//xla/service:batch_dot_simplification", @@ -1350,9 +1351,7 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features", - "//xla:shape_util", "//xla:util", - "//xla/hlo/ir:hlo", "//xla/service:computation_layout", "//xla/service:layout_assignment", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h index 4191df1d831fa2..bd518db3a780bc 100644 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -59,14 +59,9 @@ class CollectivesCommunicator { // The all-to-all chunks are passed separately and do not have to be // contiguous in memory. virtual absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, + absl::Span input_buffer, + absl::Span output_buffer, absl::Duration timeout) = 0; - - // Performs an all-gather. - virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index e2723c77c18a7b..ae2a2cd5672ae6 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -115,6 +115,7 @@ limitations under the License. #include "xla/runtime/executable.h" #include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" +#include "xla/service/all_gather_decomposer.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_to_all_decomposer.h" #include "xla/service/batch_dot_simplification.h" @@ -684,6 +685,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index 9b2b8331e2d3fd..8b124ddaa60397 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -18,12 +18,9 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/map_util.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" -#include "xla/shape_util.h" #include "tsl/platform/errors.h" namespace xla { @@ -129,13 +126,6 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (instruction->opcode() == HloOpcode::kAllGather) { - // XLA:CPU can only support all-gathers where the gather dimension is the - // most major dimension in the layout. - auto ag = Cast(instruction); - TF_RETURN_IF_ERROR(SetInstructionLayout( - ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()), - ag)); } else { for (int64_t operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 81bd76652a3ac6..a4090d67fc8d9b 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -142,7 +142,6 @@ extern const char* const kTracingStartSymbolName = extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; -extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -346,34 +345,6 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } -ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY -void AllGatherImpl(const ExecutableRunOptions* run_options, - int32_t channel_id_present, int64_t op_id, - const void* replica_groups_str, - int32_t replica_groups_str_size, int64_t buffer_size, - void* source_buffer, void* destination_buffer) { - GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view replica_groups_serialized( - static_cast(replica_groups_str), replica_groups_str_size); - std::vector group = - ParseReplicaGroupsOnly(replica_groups_serialized).value(); - RendezvousKey rendezvous_key = - GetRendezvousKey(run_options, device, group, channel_id_present, - /*use_global_device_ids=*/std::nullopt, op_id); - - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); - - CollectivesInterface* collectives = GetInProcessCollectivesImpl(); - - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); - TF_CHECK_OK(communicator->AllGather(rendezvous_key, buffer_size, - source_buffer, destination_buffer, - DefaultCollectiveTimeout())); -} - ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -532,16 +503,6 @@ void __xla_cpu_runtime_AllToAll(const xla::ExecutableRunOptions* run_options, destination_buffers); } -void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, - int32_t channel_id_present, int64_t op_id, - const void* replica_groups_str, - int32_t replica_groups_str_size, - int64_t 
buffer_size, void* source_buffer, - void* destination_buffer) { - return xla::cpu::runtime::AllGatherImpl( - run_options, channel_id_present, op_id, replica_groups_str, - replica_groups_str_size, buffer_size, source_buffer, destination_buffer); -} void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 9429242d5f1b86..361a116e7c300f 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -84,7 +84,6 @@ extern const char* const kReplicaIdSymbolName; extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; -extern const char* const kAllGatherSymbolName; extern const char* const kOneDnnMatMulSymbolName; // All symbol names for XLA CPU runtime functions need to start with this @@ -196,12 +195,6 @@ extern void __xla_cpu_runtime_AllToAll( int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size, void** source_buffers, void** destination_buffers); -extern void __xla_cpu_runtime_AllGather( - const xla::ExecutableRunOptions* run_options, int32_t channel_id_present, - int64_t op_id, const void* replica_groups_str, - int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer, - void* destination_buffer); - // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index 78eee1aa74f731..fc13c6c3870377 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -340,44 +340,6 @@ class CpuAllToAllRendezvous } }; -struct AllGatherParticipantData : ParticipantData { - AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - const void* source_buffer; - void* destination_buffer; - size_t chunk_size; - - std::string ToString() const override { - return absl::StrFormat( - "AllGatherParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_size); - } -}; - -class CpuAllGatherRendezvous - : public Rendezvous { - public: - explicit CpuAllGatherRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllGatherParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - char* out = static_cast(p.destination_buffer); - for (int i = 0; i < world_size; ++i, out += p.chunk_size) { - std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); - } - return nullptr; - } -}; - } // namespace struct InProcessCollectivesState { @@ -387,8 +349,6 @@ struct InProcessCollectivesState { collective_permute_rendezvous_map; RefcountingHashMap all_to_all_rendezvous_map; - RefcountingHashMap - all_gather_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ -469,25 +429,6 @@ absl::Status InProcessCollectivesCommunicator::AllToAll( .status(); } 
-absl::Status InProcessCollectivesCommunicator::AllGather( - const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - AllGatherParticipantData participant(key, rank_); - participant.chunk_size = chunk_bytes; - participant.source_buffer = input_buffer; - participant.destination_buffer = output_buffer; - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllGatherRendezvous::SubmitParticipant( - [&] { - return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h index aaedc474fa39b2..fb25fd3528d606 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -55,10 +55,6 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { absl::Span output_buffers, absl::Duration timeout) override; - absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - private: InProcessCollectivesState* state_; int rank_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index f5d7a4c2c40fab..18523dec844113 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -1344,52 +1344,6 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { return OkStatus(); } -Status IrEmitter::HandleAllGather(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); - - std::string replica_groups = - ReplicaGroupsToString(instruction->replica_groups()); - int32_t replica_groups_size = replica_groups.size(); - llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); - - std::vector input_buffer_ptrs; - std::vector output_buffer_ptrs; - - const HloInstruction* op = instruction->operand(0); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice in_slice, - assignment_.GetUniqueSlice(op, {})); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, - assignment_.GetUniqueSlice(instruction, {})); - const Shape& operand_shape = op->shape(); - CHECK(op->shape().IsArray()) - << "Operand to all-gather must be arrays: " << instruction->ToString(); - llvm::Value* output_buffer = EmitBufferPointer(out_slice, operand_shape); - llvm::Value* input_buffer = GetEmittedValueFor(op); - int64_t buffer_size = in_slice.size(); - - EmitCallToFunc( - runtime::kAllGatherSymbolName, - { - /*run_options=*/GetExecutableRunOptionsArgument(), - /*channel_id_present=*/ - b_.getInt32( - static_cast(instruction->channel_id().has_value())), - /*op_id=*/ - b_.getInt64(instruction->channel_id().has_value() - ? 
*instruction->channel_id() - : instruction->GetModule()->unique_id()), - /*replica_groups_str=*/replica_groups_v, - /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), - /*buffer_size=*/b_.getInt64(buffer_size), - /*source_buffer=*/input_buffer, - /*destination_buffer=*/output_buffer, - }, - b_.getVoidTy()); - - llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_); - return OkStatus(); -} - Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast(crs); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr)); diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index a0e31671ab4375..3a194d054cb5fd 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -134,7 +134,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // special in some way are handled explicitly in HandleFoo methods. Status DefaultAction(HloInstruction* hlo) override; - Status HandleAllGather(HloInstruction* instruction) override; Status HandleAllToAll(HloInstruction* instruction) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant) override; diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 2e27a7c810869d..8895b4f6451d5a 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -485,7 +485,6 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AllReduce); REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); - REGISTER_CPU_RUNTIME_SYMBOL(AllGather); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); From 73f5d2fab5c0de118c08bc70f076dfaf5f8d3b98 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 29 Nov 2023 14:20:12 -0800 Subject: [PATCH 204/381] Remove tensorflow namespace from `tsl/platform/status_matchers.h` PiperOrigin-RevId: 586452165 --- .../third_party/tsl/tsl/platform/status_matchers.h | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h index bddf2529771f1e..ee2144dca8a698 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h @@ -98,19 +98,12 @@ limitations under the License. // Status status = OkStatus(); // EXPECT_THAT(status, IsOk()); -namespace tensorflow { -namespace error { -// TODO(ddunleavy) Move this to TSL. This stays here until error_codes proto -// is moved to TSL due to an ADL issue +namespace tsl { + inline void PrintTo(const tsl::error::Code code, std::ostream* os) { *os << Code_Name(code); } -} // namespace error -} // namespace tensorflow - -namespace tsl { - template void PrintTo(const StatusOr& status_or, std::ostream* os) { *os << ::testing::PrintToString(status_or.status()); From ba9f6615601dfbf816b2836329874a0397a289f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 14:21:56 -0800 Subject: [PATCH 205/381] [XLA:LatencyHidingScheduler] Prefer picking the instruction with less serial-resource conflicts. 
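The tie-breaking rule described below boils down to counting, for each serial resource a candidate instruction would occupy, how many occupiers of that resource are already in flight, and preferring the candidate with the smaller total. The stand-alone sketch that follows is illustrative only; the function name and the plain-map signature are assumptions and do not reproduce the ReadySetLt/AsyncTracker code in this change.

#include <cstdint>
#include <unordered_map>
#include <vector>

// resources_in_flight maps a serial resource id to the number of occupiers
// currently in flight; occupied lists the serial resources the candidate
// instruction would occupy.
int64_t NumConflictingSerialResources(
    const std::unordered_map<int64_t, int64_t>& resources_in_flight,
    const std::vector<int64_t>& occupied) {
  int64_t conflicts = 0;
  for (int64_t resource : occupied) {
    auto it = resources_in_flight.find(resource);
    if (it != resources_in_flight.end()) conflicts += it->second;
  }
  return conflicts;  // The candidate with the smaller count is preferred.
}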
This CL adds the new rule kLessSerialResourceConflict, which encourages picking the instruction whose serial resource conflict is smaller than the other alternative. The conflict is computed as the sum of the number of conflicting resources in flight. The new rule is placed after kLessStall, which means if a conflicting instruction creates less stall than a non-conflicting instruction, it will still be picked. The new rule is useful when there are enough non-resource-conflicting zero-stall alternatives that we can overlap the async collectives with. PiperOrigin-RevId: 586452690 --- .../xla/service/latency_hiding_scheduler.cc | 35 +++++++++++++++++++ .../xla/service/latency_hiding_scheduler.h | 6 ++++ 2 files changed, 41 insertions(+) diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index f76d438aeed193..84f8888cf763e7 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -407,6 +407,14 @@ AsyncTracker::GetOccupiedShareableResourcesFromVector( return {}; } +// For now, only the target-defined resources have serial hazard type, so +// this async tracker does not know which resources are serial. +absl::InlinedVector +AsyncTracker::GetOccupiedSerialResourcesFromVector( + const ResourcesVector& resources) const { + return {}; +} + BufferInfoTracker::BufferInfoTracker( const HloModule* module, const HloAliasAnalysis* alias_analysis, const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes) { @@ -772,6 +780,20 @@ class ReadySetLt { b_ready_interval < a_ready_interval, b, "kLessStall")) { return *value; } + if (sched_state_.config.resource_serializing) { + // Prioritize scheduling the instruction which has less serial-resource + // conflicts with the resources in flight. 
+ const int64_t a_num_conflicting_resources = + GetNumConflictingSerialResources(a); + const int64_t b_num_conflicting_resources = + GetNumConflictingSerialResources(b); + if (auto value = DefaultSchedulerCore::ChooseBestCandidate( + a_num_conflicting_resources < b_num_conflicting_resources, a, + b_num_conflicting_resources < a_num_conflicting_resources, b, + "kLessSerialResourceConflict")) { + return *value; + } + } if (sched_state_.config.aggressive_scheduling_policies) { // If an instruction releasing a resource is not resource constrained and // has an async depth of 0, delay it as much as possible to avoid @@ -995,6 +1017,19 @@ class ReadySetLt { } return *cand.pressure_change; } + int64_t GetNumConflictingSerialResources( + DefaultSchedulerCore::ScheduleCandidate& cand) const { + auto resources = + sched_state_.async_tracker->GetOccupiedSerialResourcesFromVector( + cand.node->GetResources()); + int64_t num_conflicting_resources = 0; + for (int64_t resource : resources) { + if (!sched_state_.resources_in_flight.contains(resource)) continue; + num_conflicting_resources += + sched_state_.resources_in_flight.at(resource); + } + return num_conflicting_resources; + } }; } // namespace diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 32d229a31a12fa..ca6ccf54f4c498 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -108,6 +108,7 @@ struct SchedulerConfig { bool aggressive_scheduling_policies = false; bool enable_release_start_policy = false; bool resource_sharing = false; + bool resource_serializing = false; bool depth_based_memory_pressure_reduction = false; int64_t rerun = 0; }; @@ -237,6 +238,11 @@ class AsyncTracker { GetOccupiedShareableResourcesFromVector( const ResourcesVector& resources) const; + // Returns the list of the occupied serial resources filtered from the given + // resources vector. + virtual absl::InlinedVector GetOccupiedSerialResourcesFromVector( + const ResourcesVector& resources) const; + inline CanonicalAsyncOp GetCanonicalAsyncOp(const HloInstruction& hlo) const { return get_canonical_async_op_(hlo); } From 9da5526ee0e967a829c730f9851f7d83871a1dde Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Wed, 29 Nov 2023 15:12:36 -0800 Subject: [PATCH 206/381] Remove math_ops.py's indirect dependency on resource_variable_ops.py by replacing the isinstance checks with checks in the C++ layer. 
PiperOrigin-RevId: 586466978 --- tensorflow/python/ops/BUILD | 2 +- tensorflow/python/ops/math_ops.py | 18 +++++++++--------- tensorflow/python/ops/resource_variable_ops.py | 2 -- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 7fabdd432da41f..4b1892b922edc9 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -2081,6 +2081,7 @@ py_strict_library( "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops/numpy_ops:np_dtypes", "//tensorflow/python/platform:tf_logging", + "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", @@ -2115,7 +2116,6 @@ py_strict_library( ":array_ops", ":array_ops_gen", ":handle_data_util", - ":math_ops", ":resource_variable_ops_gen", ":state_ops", ":state_ops_gen", diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 7645eaaae39311..07c60e959e77d4 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -95,6 +95,7 @@ # pylint: enable=wildcard-import from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch @@ -233,11 +234,6 @@ def linspace_nd(start, stop, num, name=None, axis=0): tf_export(v1=["arg_min"])(dispatch.add_dispatch_support(arg_min)) -# This is set by resource_variable_ops.py. It is included in this way since -# there is a circular dependency between math_ops and resource_variable_ops -_resource_variable_type = None - - def _set_doc(doc): def _decorator(func): @@ -997,8 +993,9 @@ def cast(x, dtype, name=None): """ base_type = dtypes.as_dtype(dtype).base_dtype - if isinstance( - x, (tensor_lib.Tensor, _resource_variable_type)) and base_type == x.dtype: + if ( + isinstance(x, tensor_lib.Tensor) or _pywrap_utils.IsResourceVariable(x) + ) and base_type == x.dtype: return x with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): @@ -3755,9 +3752,12 @@ def matmul(a, f"`adjoint_b`={adjoint_b}.") if context.executing_eagerly(): - if not isinstance(a, (ops.EagerTensor, _resource_variable_type)): + if not ( + isinstance(a, ops.EagerTensor) or _pywrap_utils.IsResourceVariable(a) + ): a = ops.convert_to_tensor(a, name="a") - if not isinstance(b, (ops.EagerTensor, _resource_variable_type)): + if not isinstance(b, ops.EagerTensor) or _pywrap_utils.IsResourceVariable( + b): b = ops.convert_to_tensor(b, dtype_hint=a.dtype.base_dtype, name="b") else: a = ops.convert_to_tensor(a, name="a") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 6e1a6b6280b10a..d5b776e5997a30 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -49,7 +49,6 @@ from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import handle_data_util -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables # go/tf-wildcard-import @@ -2333,7 +2332,6 @@ def __init__( # pylint: disable=super-init-not-called _pywrap_utils.RegisterType("ResourceVariable", ResourceVariable) 
-math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): From e0693f70a4be76925111b407428435b976435f60 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 29 Nov 2023 15:13:18 -0800 Subject: [PATCH 207/381] [xla:ffi] Parameterized ffi::BaseBuffer with its rank MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_BufferArgX1 6.76ns ± 3% 6.72ns ± 1% ~ (p=0.518 n=19+16) BM_BufferArgX4 13.0ns ± 9% 12.9ns ± 8% ~ (p=0.092 n=18+20) BM_TupleOfI32Attrs 45.1ns ± 1% 43.3ns ± 1% -3.90% (p=0.000 n=16+17) name old time/op new time/op delta BM_BufferArgX1 6.76ns ± 3% 6.72ns ± 1% ~ (p=0.529 n=19+16) BM_BufferArgX4 13.0ns ± 9% 12.9ns ± 8% ~ (p=0.093 n=18+20) BM_TupleOfI32Attrs 45.1ns ± 1% 43.3ns ± 1% -3.90% (p=0.000 n=16+17) PiperOrigin-RevId: 586467144 --- third_party/xla/xla/ffi/api/ffi.h | 43 +++++++++------- third_party/xla/xla/ffi/api/ffi_test.cc | 65 +++++++++++++++---------- 2 files changed, 64 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 113209123415bd..a3e4b1fbbdbde0 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -129,23 +130,25 @@ struct PtrType { // clang-format off template <> struct PtrType { using Type = bool; }; -template <> struct PtrType { using Type = std::uint8_t; }; -template <> struct PtrType { using Type = std::uint16_t; }; -template <> struct PtrType { using Type = std::uint32_t; }; -template <> struct PtrType { using Type = std::uint64_t; }; -template <> struct PtrType { using Type = std::int8_t; }; -template <> struct PtrType { using Type = std::int16_t; }; -template <> struct PtrType { using Type = std::int32_t; }; -template <> struct PtrType { using Type = std::int64_t; }; -template <> struct PtrType { using Type = std::uint16_t; }; +template <> struct PtrType { using Type = uint8_t; }; +template <> struct PtrType { using Type = uint16_t; }; +template <> struct PtrType { using Type = uint32_t; }; +template <> struct PtrType { using Type = uint64_t; }; +template <> struct PtrType { using Type = int8_t; }; +template <> struct PtrType { using Type = int16_t; }; +template <> struct PtrType { using Type = int32_t; }; +template <> struct PtrType { using Type = int64_t; }; +template <> struct PtrType { using Type = uint16_t; }; template <> struct PtrType { using Type = float; }; template <> struct PtrType { using Type = double; }; -template <> struct PtrType { using Type = std::uint16_t; }; +template <> struct PtrType { using Type = uint16_t; }; // clang-format on +inline constexpr size_t kDynamicRank = std::numeric_limits::max(); + } // namespace internal -template +template struct BufferBase { typename internal::PtrType::Type* data; Span dimensions; @@ -162,12 +165,11 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_ArgType type) { } } -template -struct ArgDecoding> { +template +struct ArgDecoding> { XLA_ATTRIBUTE_ALWAYS_INLINE - static std::optional> Decode(XLA_FFI_ArgType type, - void* arg, - DiagnosticEngine& diagnostic) { + static std::optional> Decode( + XLA_FFI_ArgType type, void* arg, DiagnosticEngine& diagnostic) { if (type != XLA_FFI_ArgType_BUFFER) { return diagnostic.Emit("Wrong argument type: expected ") << XLA_FFI_ArgType_BUFFER << " but got " << type; @@ -180,7 
+182,14 @@ struct ArgDecoding> { } auto* data = static_cast::Type*>(buf->data); - return BufferBase{data, Span(buf->dims, buf->rank)}; + if constexpr (rank != internal::kDynamicRank) { + if (buf->rank != rank) { + diagnostic.Emit("Wrong buffer rank: expected ") + << rank << " but got " << buf->rank; + return std::nullopt; + } + } + return BufferBase{data, Span(buf->dims, rank)}; } }; diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index 212d7ea855c24f..5e081313d59bef 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -72,7 +72,7 @@ TEST(FfiTest, BufferArgument) { auto call_frame = builder.Build(); auto handler = - Ffi::Bind().Arg>().To([&](auto buffer) { + Ffi::Bind().Arg>().To([&](auto buffer) { EXPECT_EQ(buffer.data, storage.data()); EXPECT_EQ(buffer.dimensions.size(), 2); return Error::Success(); @@ -86,7 +86,7 @@ TEST(FfiTest, MissingBufferArgument) { CallFrameBuilder builder; auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -94,6 +94,23 @@ TEST(FfiTest, MissingBufferArgument) { HasSubstr("Wrong number of arguments"))); } +TEST(FfiTest, WrongRankBufferArgument) { + std::vector storage(4, 0.0); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(std::int32_t)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg>().To( + [](auto) { return Error::Success(); }); + auto status = Call(*handler, call_frame); + + EXPECT_THAT(status, + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("Wrong buffer rank: expected 1 but got 2"))); +} + TEST(FfiTest, WrongTypeBufferArgument) { std::vector storage(4, 0.0); se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(std::int32_t)); @@ -102,7 +119,7 @@ TEST(FfiTest, WrongTypeBufferArgument) { builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -134,12 +151,11 @@ static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { void BM_BufferArgX1(benchmark::State& state) { auto call_frame = WithBufferArgs(1).Build(); - auto fn = [](BufferBase buffer) { - benchmark::DoNotOptimize(buffer); - return Error::Success(); - }; - - auto handler = Ffi::Bind().Arg>().To(fn); + auto handler = + Ffi::Bind().Arg>().To([](auto buffer) { + benchmark::DoNotOptimize(buffer); + return Error::Success(); + }); for (auto _ : state) { CHECK_OK(Call(*handler, call_frame)); } @@ -154,21 +170,18 @@ BENCHMARK(BM_BufferArgX1); void BM_BufferArgX4(benchmark::State& state) { auto call_frame = WithBufferArgs(4).Build(); - auto fn = [](BufferBase b0, BufferBase b1, - BufferBase b2, BufferBase b3) { - benchmark::DoNotOptimize(b0); - benchmark::DoNotOptimize(b1); - benchmark::DoNotOptimize(b2); - benchmark::DoNotOptimize(b3); - return Error::Success(); - }; - auto handler = Ffi::Bind() - .Arg>() - .Arg>() - .Arg>() - .Arg>() - .To(fn); + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .To([](auto b0, auto b1, auto b2, auto b3) { + benchmark::DoNotOptimize(b0); + benchmark::DoNotOptimize(b1); + benchmark::DoNotOptimize(b2); + benchmark::DoNotOptimize(b3); + return Error::Success(); + }); for (auto _ : state) 
{ CHECK_OK(Call(*handler, call_frame)); @@ -205,12 +218,10 @@ void BM_TupleOfI32Attrs(benchmark::State& state) { builder.AddAttributes(attrs.Build()); auto call_frame = builder.Build(); - auto fn = [](TupleOfI32 tuple) { + auto handler = Ffi::Bind().Attrs().To([](auto tuple) { benchmark::DoNotOptimize(tuple); return Error::Success(); - }; - - auto handler = Ffi::Bind().Attrs().To(fn); + }); for (auto _ : state) { CHECK_OK(Call(*handler, call_frame)); From 44a6a97857f55e20a3654b6d7b812edef918cf91 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 15:31:15 -0800 Subject: [PATCH 208/381] [XLA] Allow moving for HloSharding in xla::HloInstruction::set_sharding. PiperOrigin-RevId: 586471723 --- third_party/xla/xla/hlo/ir/hlo_instruction.h | 4 ++-- third_party/xla/xla/service/hlo_parser.cc | 2 +- third_party/xla/xla/service/sharding_propagation.cc | 13 +++++++------ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 5145203e11da55..3de91ccd73a3a3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1663,8 +1663,8 @@ class HloInstruction { } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. - void set_sharding(const HloSharding& sharding) { - set_sharding(std::make_shared(sharding)); + void set_sharding(HloSharding sharding) { + set_sharding(std::make_shared(std::move(sharding))); } void set_sharding(std::shared_ptr sharding) { sharding_ = std::move(sharding); diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index b44c48994af247..3a31abe766ba88 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -1366,7 +1366,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, // normalizing tuple sharding. HloSharding hlo_sharding = HloSharding::FromProto(sharding.value()).value(); hlo_sharding = hlo_sharding.NormalizeTupleSharding(instruction->shape()); - instruction->set_sharding(hlo_sharding); + instruction->set_sharding(std::move(hlo_sharding)); } if (parameter_replication) { int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index b72ac39b0870a9..87415c7e2d1abe 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -1598,13 +1598,14 @@ StatusOr ProcessShardingInstruction( absl::c_sort(unspec_dims); unspecified_dims->emplace(instruction, std::move(unspec_dims)); } else if (!instruction->operand(0)->has_sharding()) { - instruction->mutable_operand(0)->set_sharding(sharding); + instruction->mutable_operand(0)->set_sharding(std::move(sharding)); } } else if (instruction->has_sharding()) { // Handle shard group in parameters/outputs. 
process_shard_group_instruction(instruction, instruction->sharding()); HloSharding sharding = instruction->sharding(); - instruction->set_sharding(sharding.ClearShardGroup()); + sharding.ClearShardGroup(); + instruction->set_sharding(std::move(sharding)); } } } @@ -1669,7 +1670,7 @@ int64_t ComputeNonRootUsers(const HloInstruction* instr) { operand_sharding = HloSharding::SingleTuple(operand->shape(), *sharding); } - operand->set_sharding(operand_sharding); + operand->set_sharding(std::move(operand_sharding)); } } return OkStatus(); @@ -2256,7 +2257,7 @@ bool ShardingPropagation::InferShardingFromOperands( HloSharding new_sharding = operand->sharding().GetSubSharding( operand->shape(), {instruction->tuple_index()}); if (new_sharding.IsManual()) { - instruction->set_sharding(new_sharding); + instruction->set_sharding(std::move(new_sharding)); return true; } return MaybeImproveInstructionSharding( @@ -2772,7 +2773,7 @@ bool ShardingPropagation::InferShardingFromUsers( ShardingPropagation::GetShardingFromUser( *instruction, *user, aggressiveness, is_spmd, call_graph); if (user_sharding && user_sharding->IsManual()) { - instruction->set_sharding(*user_sharding); + instruction->set_sharding(std::move(*user_sharding)); return true; } } @@ -3331,7 +3332,7 @@ StatusOr ShardingPropagation::Run( root_sharding.tuple_elements()[i] = saved_root_shardings[i]; } } - root_instruction->set_sharding(root_sharding); + root_instruction->set_sharding(std::move(root_sharding)); } auto params = module->entry_computation()->parameter_instructions(); if (allow_spmd_sharding_propagation_to_parameters_ && From 7b451a1295560459bcdeadb141109d06bb7ef806 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 15:31:29 -0800 Subject: [PATCH 209/381] Add EvalOrPattern to StablehloRefineShapes pass PiperOrigin-RevId: 586471778 --- third_party/stablehlo/temporary.patch | 72 ++++++++++++++++++- .../xla/third_party/stablehlo/temporary.patch | 72 ++++++++++++++++++- 2 files changed, 142 insertions(+), 2 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 11801eab6222a0..0014e5c9384b16 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -2509,7 +2509,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp -@@ -0,0 +1,1293 @@ +@@ -0,0 +1,1308 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
@@ -2818,6 +2818,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + } +}; + ++struct EvalOrOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(OrOp op, ++ PatternRewriter& rewriter) const override { ++ auto resultType = op.getType(); ++ if (!resultType.getElementType().isInteger(1)) ++ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ ++ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { ++ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); ++ }); ++ } ++}; ++ +struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RemOp op, @@ -3764,6 +3778,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); @@ -3803,4 +3818,59 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +@@ -340,6 +340,19 @@ + %1 = stablehlo.constant dense<2> : tensor + %2 = stablehlo.multiply %0, %1 : tensor + func.return %2 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @eval_or ++func.func @eval_or() -> tensor { ++ // CHECK-NOT: stablehlo.or ++ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor ++ // CHECK: return [[RESULT]] ++ %0 = stablehlo.constant dense : tensor ++ %1 = stablehlo.constant dense : tensor ++ %2 = stablehlo.or %0, %1 : tensor ++ func.return %2 : tensor + } + + // ----- +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -304,6 +304,20 @@ + } + }; + ++struct EvalOrOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(OrOp op, ++ PatternRewriter& rewriter) const override { ++ auto resultType = op.getType(); ++ if (!resultType.getElementType().isInteger(1)) ++ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ ++ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { ++ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); ++ }); ++ } ++}; ++ + struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RemOp op, +@@ -1165,6 +1179,7 @@ + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index d14f4225f35aba..6ed8468c4819c5 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -2509,7 +2509,7 @@ diff --ruN 
a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp -@@ -0,0 +1,1293 @@ +@@ -0,0 +1,1308 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. @@ -2818,6 +2818,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + } +}; + ++struct EvalOrOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(OrOp op, ++ PatternRewriter& rewriter) const override { ++ auto resultType = op.getType(); ++ if (!resultType.getElementType().isInteger(1)) ++ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ ++ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { ++ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); ++ }); ++ } ++}; ++ +struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RemOp op, @@ -3764,6 +3778,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); @@ -3803,4 +3818,59 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +@@ -340,6 +340,19 @@ + %1 = stablehlo.constant dense<2> : tensor + %2 = stablehlo.multiply %0, %1 : tensor + func.return %2 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @eval_or ++func.func @eval_or() -> tensor { ++ // CHECK-NOT: stablehlo.or ++ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor ++ // CHECK: return [[RESULT]] ++ %0 = stablehlo.constant dense : tensor ++ %1 = stablehlo.constant dense : tensor ++ %2 = stablehlo.or %0, %1 : tensor ++ func.return %2 : tensor + } + + // ----- +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -304,6 +304,20 @@ + } + }; + ++struct EvalOrOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(OrOp op, ++ PatternRewriter& rewriter) const override { ++ auto resultType = op.getType(); ++ if (!resultType.getElementType().isInteger(1)) ++ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ ++ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { ++ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); ++ }); ++ } ++}; ++ + struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult 
matchAndRewrite(RemOp op, +@@ -1165,6 +1179,7 @@ + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); From bfed4a531796cd2bf39913159a3c92a63f1dbefe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 15:42:35 -0800 Subject: [PATCH 210/381] [XLA] Allow tuple_shardings move in the HloSharding ctor. PiperOrigin-RevId: 586474576 --- third_party/xla/xla/hlo/ir/hlo_sharding.cc | 5 +++-- third_party/xla/xla/hlo/ir/hlo_sharding.h | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index 0f1cb985618b08..8e06a925d45a61 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -361,7 +361,7 @@ HloSharding HloSharding::Tuple(const Shape& tuple_shape, << "Flat list has " << flattened_list.size() << ", required " << RequiredLeaves(tuple_shape); } - return HloSharding(flattened_list); + return HloSharding(std::move(flattened_list)); } HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, @@ -811,7 +811,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, HloSharding::FromProto(tuple_sharding_proto)); tuple_shardings.push_back(sharding); } - return HloSharding(tuple_shardings).SetShardGroupFromProto(proto); + return std::move( + HloSharding(std::move(tuple_shardings)).SetShardGroupFromProto(proto)); } else if (proto.type() == OpSharding::REPLICATED) { return Replicate(metadata).SetShardGroupFromProto(proto); } else if (proto.type() == OpSharding::MANUAL) { diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.h b/third_party/xla/xla/hlo/ir/hlo_sharding.h index 430f95799a7bad..8b943b75b00c00 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.h @@ -568,8 +568,8 @@ class HloSharding { manual_(false), unknown_(false), replicate_on_last_tile_dim_(false) {} - explicit HloSharding(const std::vector& tuple_shardings) - : tuple_elements_(tuple_shardings), + explicit HloSharding(std::vector tuple_shardings) + : tuple_elements_(std::move(tuple_shardings)), replicated_(false), maximal_(false), tuple_(true), From 5759fac45ae353892c558928d85f0710fb3cfc2c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 29 Nov 2023 15:50:53 -0800 Subject: [PATCH 211/381] [xla:gpu] Move external allocation implementation details from StreamExecutor to XLA:GPU level This is a high level XLA implementation detail that should not leak deep into StreamExecutor PiperOrigin-RevId: 586476378 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/buffer_allocations.cc | 43 +++------- .../xla/xla/service/gpu/buffer_allocations.h | 41 ++++++--- .../xla/xla/service/gpu/runtime3/BUILD | 1 + .../gpu/runtime3/command_buffer_cmd.cc | 83 ++++++++++++++++--- .../gpu/runtime3/command_buffer_thunk_test.cc | 7 +- .../xla/xla/stream_executor/device_memory.h | 16 +--- 7 files changed, 119 insertions(+), 73 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 151fef12956b2a..b4d2f868fcc914 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -758,6 +758,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gpu_constants", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:types", diff --git 
a/third_party/xla/xla/service/gpu/buffer_allocations.cc b/third_party/xla/xla/service/gpu/buffer_allocations.cc index 38c832d40654d6..193586bc9a4807 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/buffer_allocations.cc @@ -15,17 +15,12 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" -#include -#include +#include +#include -#include "xla/map_util.h" -#include "xla/service/gpu/gpu_constants.h" -#include "xla/status_macros.h" -#include "xla/types.h" -#include "xla/util.h" -#include "tsl/lib/gtl/map_util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" namespace xla { namespace gpu { @@ -79,30 +74,14 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( buffer_slice.size()); } -se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( +StatusOr BufferAllocations::GetDeviceAddress( const BufferAllocation::Slice& buffer_slice, - const se::CommandBuffer* command_buffer) const { + const ExternalAllocations& external_allocations) const { + // Check if base memory address is an external allocation. se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index()); - CHECK_LE(buffer_slice.offset(), base.size()); - CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size()); - - if (base.is_external_allocation_marker()) { - auto cmd_buffer_base = - command_buffer->GetAllocationAddress(buffer_slice.index()); - CHECK(cmd_buffer_base.ok()) - << "Get allocation address from command_buffer failed"; - CHECK(!cmd_buffer_base.value().is_null()) - << "Allocation is not yet allocated by command buffer for slice: " - << buffer_slice.ToString(); - return se::DeviceMemoryBase( - static_cast(cmd_buffer_base.value().opaque()) + - buffer_slice.offset(), - buffer_slice.size()); - } - - return se::DeviceMemoryBase( - static_cast(base.opaque()) + buffer_slice.offset(), - buffer_slice.size()); + return reinterpret_cast(base.opaque()) == kExternalAllocationMarker + ? external_allocations.GetDeviceAddress(buffer_slice) + : GetDeviceAddress(buffer_slice); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.h b/third_party/xla/xla/service/gpu/buffer_allocations.h index e805a66053c0cb..37d2eb8c2d2f54 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.h +++ b/third_party/xla/xla/service/gpu/buffer_allocations.h @@ -16,17 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_ #define XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_ -#include +#include +#include #include #include #include -#include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" +#include "xla/status.h" #include "xla/statusor.h" -#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" @@ -37,6 +38,28 @@ namespace gpu { // allocated device buffers. 
class BufferAllocations { public: + // This special address is used to indicate that the allocation is not + // allocated at construction time and instead will be lazily allocated and + // owned by the XLA executable itself (we use this special marker to handle + // buffer allocations allocated within command buffers, which for CUDA + // backends means that buffer allocation is done via memory allocation node). + // + // TODO(ezhulenev): Replace magic bit pattern with std::optional or + // std::variant to distinguish external allocations from a regular ones. + static constexpr uintptr_t kExternalAllocationMarker = 0xDEADBEEF; + + // A virtual base class for external allocations that provides a mapping + // from a buffer index to an externally-managed device memory. + class ExternalAllocations { + public: + virtual ~ExternalAllocations() = default; + + // Return a device address for a given buffer slice. Returns error if + // corresponding allocation is not yet allocated. + virtual StatusOr GetDeviceAddress( + BufferAllocation::Slice buffer_slice) const = 0; + }; + BufferAllocations(absl::Span buffers, int device_ordinal, se::DeviceMemoryAllocator* memory_allocator) @@ -69,14 +92,12 @@ class BufferAllocations { se::DeviceMemoryBase GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const; - // For buffers that are lazily allocated through command buffer, this is - // indicated by specifying a special buffer address - // (LAZY_ALLOCATE_ADDRESS_MARKER), the real buffer address is tracked in - // CommandBuffer, this API will fetches the real address from CommandBuffer - // runtime. - se::DeviceMemoryBase GetDeviceAddress( + // Finds an allocation for a given buffer slice, and if it happens to be an + // external allocation resolves it using user-provided external allocations. + // Returns error if external allocations do not have an address for a slice. + StatusOr GetDeviceAddress( const BufferAllocation::Slice& buffer_slice, - const se::CommandBuffer* command_buffer) const; + const ExternalAllocations& external_allocations) const; // Tears down all buffers allocated by this object that are not in // `live_addresses`. diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index 45e0951d39e004..bb44a6442f7e7e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -38,6 +38,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 63841534daeed4..ebe03fb5b719c4 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" @@ -30,6 +32,7 @@ limitations under the License. 
#include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/status.h" +#include "xla/statusor.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" @@ -41,6 +44,47 @@ limitations under the License. namespace xla::gpu { +//===----------------------------------------------------------------------===// +// External buffer allocations managed by a command buffer. +//===----------------------------------------------------------------------===// + +// TODO(ezhulenev): External allocations should be managed by the command buffer +// thunk, as the command buffer itself should not know anything about XLA buffer +// allocation index or buffer slice. + +namespace { +class CommandBufferAllocations : public BufferAllocations::ExternalAllocations { + public: + explicit CommandBufferAllocations(se::CommandBuffer* command_buffer); + + StatusOr GetDeviceAddress( + BufferAllocation::Slice buffer_slice) const override; + + private: + se::CommandBuffer* command_buffer_; +}; +} // namespace + +CommandBufferAllocations::CommandBufferAllocations( + se::CommandBuffer* command_buffer) + : command_buffer_(command_buffer) {} + +StatusOr CommandBufferAllocations::GetDeviceAddress( + BufferAllocation::Slice buffer_slice) const { + TF_ASSIGN_OR_RETURN( + auto base, command_buffer_->GetAllocationAddress(buffer_slice.index())); + + if (base.is_null()) { + return absl::InternalError(absl::StrCat("Memory for a buffer #", + buffer_slice.index(), + " is not yet allocated")); + } + + return se::DeviceMemoryBase( + static_cast(base.opaque()) + buffer_slice.offset(), + buffer_slice.size()); +} + //===----------------------------------------------------------------------===// // CommandBufferCmdSequence //===----------------------------------------------------------------------===// @@ -128,10 +172,13 @@ Status LaunchCmd::Record(const RecordParams& params, "Kernel not loaded on a command buffer executor"); } + CommandBufferAllocations external_allocations(command_buffer); + absl::InlinedVector buffers; for (const BufferAllocation::Slice& arg : args_) { - se::DeviceMemoryBase buf = - params.buffer_allocations->GetDeviceAddress(arg, command_buffer); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase buf, + params.buffer_allocations->GetDeviceAddress(arg, external_allocations)); VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); buffers.push_back(buf); } @@ -165,10 +212,16 @@ Status MemcpyDeviceToDeviceCmd::Record(const RecordParams& params, se::CommandBuffer* command_buffer) { VLOG(5) << "MemcpyDeviceToDeviceCmd: dst=" << dst_ << ", src=" << src_ << ", num_bytes=" << num_bytes_; - se::DeviceMemoryBase dst = - params.buffer_allocations->GetDeviceAddress(dst_, command_buffer); - se::DeviceMemoryBase src = - params.buffer_allocations->GetDeviceAddress(src_, command_buffer); + + CommandBufferAllocations external_allocations(command_buffer); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase dst, + params.buffer_allocations->GetDeviceAddress(dst_, external_allocations)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase src, + params.buffer_allocations->GetDeviceAddress(src_, external_allocations)); + return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_); } @@ -255,12 +308,18 @@ Status GemmCmd::Record(const RecordParams& params, se::DeviceMemoryBase workspace(nullptr, 0); - se::DeviceMemoryBase lhs = - params.buffer_allocations->GetDeviceAddress(lhs_buffer_, command_buffer); - se::DeviceMemoryBase rhs = - 
params.buffer_allocations->GetDeviceAddress(rhs_buffer_, command_buffer); - se::DeviceMemoryBase out = params.buffer_allocations->GetDeviceAddress( - output_buffer_, command_buffer); + CommandBufferAllocations external_allocations(command_buffer); + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs, + params.buffer_allocations->GetDeviceAddress( + lhs_buffer_, external_allocations)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs, + params.buffer_allocations->GetDeviceAddress( + rhs_buffer_, external_allocations)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase out, + params.buffer_allocations->GetDeviceAddress( + output_buffer_, external_allocations)); + TF_ASSIGN_OR_RETURN( auto nested_buffer, se::CommandBuffer::Trace(params.executor, [&](se::Stream* stream) { diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 66a9c17d92259e..f1569fc6f3d560 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -140,10 +140,9 @@ TEST(CommandBufferThunkTest, MemallocCmd) { // Prepare arguments: a=42, b=0 se::DeviceMemory a = executor->AllocateArray(length, 0); stream.ThenMemset32(&a, 42, byte_length); - - se::DeviceMemory b = - se::DeviceMemory::MakeExternalAllocationFromByteSize( - byte_length); + se::DeviceMemory b(se::DeviceMemoryBase( + reinterpret_cast(BufferAllocations::kExternalAllocationMarker), + byte_length)); se::DeviceMemory c = executor->AllocateArray(length, 0); BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator()); diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index bd8957ab14b648..fe3df0067f337d 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -26,6 +26,7 @@ limitations under the License. #include +#include #include #include "xla/stream_executor/platform/port.h" @@ -35,11 +36,6 @@ namespace stream_executor { class DeviceMemoryAllocator; class StreamExecutor; -// This special address is used to indicate that the allocation is not ready -// when constructing DeviceMemory object, and will be lazily allocated by -// an external allocator (e.g. command buffer for GPU backend). -inline constexpr uintptr_t kExternalAllocationMarker = 0xDEADBEEF; - // void*-analogous device memory allocation. For the typed variation, see // DeviceMemory. // @@ -62,10 +58,6 @@ class DeviceMemoryBase { // A `== nullptr` convenience method is also provided. bool is_null() const { return opaque_ == nullptr; } - bool is_external_allocation_marker() const { - return reinterpret_cast(opaque_) == kExternalAllocationMarker; - } - bool operator==(std::nullptr_t other) const { return is_null(); } bool operator!=(std::nullptr_t other) const { return !is_null(); } @@ -154,12 +146,6 @@ class DeviceMemory final : public DeviceMemoryBase { return DeviceMemory(opaque, bytes); } - static DeviceMemory MakeExternalAllocationFromByteSize( - uint64_t bytes) { - return DeviceMemory( - reinterpret_cast(kExternalAllocationMarker), bytes); - } - // Resets the DeviceMemory data, in MakeFromByteSize fashion. // This simply clobbers the prior values. void ResetFromByteSize(void *opaque, uint64_t bytes) { From 2b0099bf57979c487fe60cb76a954d531f91509f Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 29 Nov 2023 15:58:29 -0800 Subject: [PATCH 212/381] [XLA] Do not create temporary vector in HloSharding::GetSubSharding. PiperOrigin-RevId: 586478058 --- third_party/xla/xla/hlo/ir/hlo_sharding.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index 8e06a925d45a61..b72b9ab933ae85 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -1044,9 +1044,10 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, } if (sub_shape->IsTuple()) { auto begin_it = tuple_elements_.begin() + sharding_index; - std::vector sub_shardings( - begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); - return HloSharding::Tuple(*sub_shape, sub_shardings); + return HloSharding::Tuple( + *sub_shape, + absl::MakeConstSpan( + &*begin_it, &*(begin_it + ShapeUtil::GetLeafCount(*sub_shape)))); } else { return tuple_elements_[sharding_index]; } From cbd7da34404892f7555e7e0d05ea1324b0f8bc8b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 16:08:20 -0800 Subject: [PATCH 213/381] [XLA] Remove an unnecessary HloSharding assignment. PiperOrigin-RevId: 586480698 --- third_party/xla/xla/service/sharding_propagation.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 87415c7e2d1abe..c9c15af999e02a 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -1938,8 +1938,8 @@ std::optional ShardingPropagation::GetShardingFromUser( case HloOpcode::kSort: { HloSharding user_sharding = user.sharding(); if (user_sharding.IsTuple()) { - return user_sharding = user_sharding.GetSubSharding( - user.shape(), {user.operand_index(&instruction)}); + return user_sharding.GetSubSharding(user.shape(), + {user.operand_index(&instruction)}); } return user_sharding; } From 474f4473c4621c773f2a9d098bc9cffa07173db7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 16:15:53 -0800 Subject: [PATCH 214/381] Use clang+NVCC compilers for all XLA GPU jobs. PiperOrigin-RevId: 586482573 --- third_party/xla/.kokoro/linux/build.sh | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/third_party/xla/.kokoro/linux/build.sh b/third_party/xla/.kokoro/linux/build.sh index 9444f5a12bfea1..635af61a6d3ed5 100644 --- a/third_party/xla/.kokoro/linux/build.sh +++ b/third_party/xla/.kokoro/linux/build.sh @@ -26,10 +26,6 @@ function is_linux_gpu_job() { [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/.*gpu.* ]] } -function is_use_nvcc() { - [[ "${USE_NVCC:-}" == "true" ]] -} - # Pull the container (in case it was updated since the instance started) and # store its SHA in the Sponge log. 
docker pull "$DOCKER_IMAGE" @@ -54,11 +50,7 @@ if is_linux_gpu_job ; then TAGS_FILTER="$TAGS_FILTER,gpu,requires-gpu-nvidia,-no_gpu" ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute" RC_FILE="/usertools/gpu.bazelrc" - if is_use_nvcc ; then - RBE_CONFIG="rbe_linux_cuda_nvcc" - else - RBE_CONFIG="rbe_linux_cuda" - fi + RBE_CONFIG="rbe_linux_cuda_nvcc" echo "***NOTE: nvidia-smi lists the highest CUDA version the driver supports, which may be different than the version of CUDA actually used!!***" nvidia-smi else From 413a3d47d339fee1bd4e0b0baa8ed6b6629f6ac1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 16:21:04 -0800 Subject: [PATCH 215/381] [XLA] Move HloSharding where possible. PiperOrigin-RevId: 586483855 --- third_party/xla/xla/hlo/ir/hlo_sharding.cc | 34 ++++++++++--------- .../xla/xla/service/sharding_propagation.cc | 12 +++---- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index b72b9ab933ae85..9497747e6ece24 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -809,23 +809,23 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { TF_ASSIGN_OR_RETURN(HloSharding sharding, HloSharding::FromProto(tuple_sharding_proto)); - tuple_shardings.push_back(sharding); + tuple_shardings.push_back(std::move(sharding)); } return std::move( HloSharding(std::move(tuple_shardings)).SetShardGroupFromProto(proto)); } else if (proto.type() == OpSharding::REPLICATED) { - return Replicate(metadata).SetShardGroupFromProto(proto); + return std::move(Replicate(metadata).SetShardGroupFromProto(proto)); } else if (proto.type() == OpSharding::MANUAL) { - return Manual(metadata).SetShardGroupFromProto(proto); + return std::move(Manual(metadata).SetShardGroupFromProto(proto)); } else if (proto.type() == OpSharding::UNKNOWN) { - return Unknown(metadata).SetShardGroupFromProto(proto); + return std::move(Unknown(metadata).SetShardGroupFromProto(proto)); } else if (proto.tile_assignment_devices().size() == 1) { - return HloSharding(proto.tile_assignment_devices(0), metadata) - .SetShardGroupFromProto(proto); + return std::move(HloSharding(proto.tile_assignment_devices(0), metadata) + .SetShardGroupFromProto(proto)); } else if (!proto.iota_reshape_dims().empty() && absl::c_all_of(proto.iota_reshape_dims(), [](int64_t d) { return d == 1; })) { - return HloSharding(0, metadata).SetShardGroupFromProto(proto); + return std::move(HloSharding(0, metadata).SetShardGroupFromProto(proto)); } TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL) @@ -882,15 +882,17 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, }; if (!subgroup_types.empty()) { TF_RET_CHECK(!proto.replicate_on_last_tile_dim()); - return Subgroup(create_tile_assignment(), subgroup_types, metadata) - .SetShardGroupFromProto(proto); - } - return proto.replicate_on_last_tile_dim() - ? 
PartialTile(create_tile_assignment(), metadata) - .SetShardGroupFromProto(proto) - : HloSharding(create_tile_assignment(), - /*replicate_on_last_tile_dim=*/false, metadata) - .SetShardGroupFromProto(proto); + return std::move( + Subgroup(create_tile_assignment(), subgroup_types, metadata) + .SetShardGroupFromProto(proto)); + } + if (proto.replicate_on_last_tile_dim()) { + return std::move(PartialTile(create_tile_assignment(), metadata) + .SetShardGroupFromProto(proto)); + } + return std::move(HloSharding(create_tile_assignment(), + /*replicate_on_last_tile_dim=*/false, metadata) + .SetShardGroupFromProto(proto)); } OpSharding HloSharding::ToProto() const { diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index c9c15af999e02a..e78b351a7eed79 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -183,9 +183,9 @@ bool MaybeImproveInstructionSharding(HloSharding sharding, HloInstruction* instruction, bool may_combine_partial_sharding, bool allow_aggressive_resharding = false) { - if (auto new_sharding = ReturnImprovedSharding(sharding, instruction, - may_combine_partial_sharding, - allow_aggressive_resharding)) { + if (auto new_sharding = ReturnImprovedSharding( + std::move(sharding), instruction, may_combine_partial_sharding, + allow_aggressive_resharding)) { instruction->set_sharding(std::move(*new_sharding)); return true; } @@ -200,8 +200,8 @@ bool MaybeImproveInstructionSubSharding( bool allow_aggressive_resharding = false) { if (instruction->shape().IsTuple()) { if (auto new_sub_sharding = ReturnImprovedSubSharding( - sharding, instruction, index, may_combine_partial_sharding, - allow_aggressive_resharding)) { + std::move(sharding), instruction, index, + may_combine_partial_sharding, allow_aggressive_resharding)) { HloSharding new_sharding = instruction->has_sharding() ? instruction->sharding() @@ -217,7 +217,7 @@ bool MaybeImproveInstructionSubSharding( } } CHECK(index.size() == 1 && index[0] == 0); - return MaybeImproveInstructionSharding(sharding, instruction, + return MaybeImproveInstructionSharding(std::move(sharding), instruction, may_combine_partial_sharding, allow_aggressive_resharding); } From a0f9e8bc13eab5e04a0c125581a9990aa4ac2be6 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Wed, 29 Nov 2023 16:31:48 -0800 Subject: [PATCH 216/381] Remove `alwayslink = True` for stablehlo bridge passes. PiperOrigin-RevId: 586486498 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 4773cc2c884d1e..0c3f41c4596a01 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -132,9 +132,6 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", ], - # Alwayslink is required for registering the MLIR passes. - # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. - alwayslink = True, ) td_library( @@ -255,8 +252,6 @@ cc_library( "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", ], - # Force link to ensure ConvertTFQuantOpsToMHLOPass is registered. - alwayslink = True, ) tf_cc_test( From 467205fb71fd61a7643cca924e0d34640bdc5195 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 29 Nov 2023 16:38:33 -0800 Subject: [PATCH 217/381] [XLA] Change the prototype of ReturnImprovedShardingImpl to minmize copies. PiperOrigin-RevId: 586488071 --- .../xla/xla/service/sharding_propagation.cc | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index e78b351a7eed79..052c2eba7615ab 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -102,12 +102,11 @@ bool IsShardingStrictlyBetter(const HloSharding& lhs, const HloSharding& rhs) { // Implementation for returning a improved sharding from another sharding. std::optional ReturnImprovedShardingImpl( - HloSharding from, std::optional to_improved, + HloSharding from, const HloSharding* to_improved, const Shape& to_improved_shape, bool may_combine_partial_sharding, bool allow_aggressive_resharding = false) { // Always allow improve the sharding if it's straightly better. - if (to_improved.has_value() && - IsShardingStrictlyBetter(from, to_improved.value())) { + if (to_improved != nullptr && IsShardingStrictlyBetter(from, *to_improved)) { return from; } // We don't want to propagate tile maximal shardings. @@ -115,7 +114,7 @@ std::optional ReturnImprovedShardingImpl( return std::nullopt; } // Any sharding is better then no sharding. - if (!to_improved.has_value()) { + if (to_improved == nullptr) { return from; } // We don't want to propagate manual shardings. @@ -123,18 +122,17 @@ std::optional ReturnImprovedShardingImpl( return std::nullopt; } int64_t sharding_tiles = from.NumTiles(); - if (hlo_sharding_util::MergeSharding(to_improved.value(), &from, + if (hlo_sharding_util::MergeSharding(*to_improved, &from, may_combine_partial_sharding)) { // Override existing tiled sharding only when the new sharding is compatible // with the existing one. This avoids unexpected resharding when `sharding` // just has more tiles than existing sharding but they are not mergeable. if (!allow_aggressive_resharding && to_improved_shape.IsArray() && - !to_improved.value().IsTileMaximal() && - from.NumTiles() == sharding_tiles) { - if (!hlo_sharding_util::IsSubTilingOrEqualSharding( - to_improved_shape, from, to_improved.value())) { + !to_improved->IsTileMaximal() && from.NumTiles() == sharding_tiles) { + if (!hlo_sharding_util::IsSubTilingOrEqualSharding(to_improved_shape, + from, *to_improved)) { VLOG(10) << "Not merging because of different device distribution"; - VLOG(10) << "Instr sharding: " << to_improved.value().ToString(); + VLOG(10) << "Instr sharding: " << to_improved->ToString(); VLOG(10) << "New sharding " << from.ToString(); return std::nullopt; } @@ -150,10 +148,8 @@ std::optional ReturnImprovedSharding( bool may_combine_partial_sharding, bool allow_aggressive_resharding = false) { return ReturnImprovedShardingImpl( - sharding, - instruction->has_sharding() - ? std::optional(instruction->sharding()) - : std::nullopt, + std::move(sharding), + instruction->has_sharding() ? &instruction->sharding() : nullptr, instruction->shape(), may_combine_partial_sharding, allow_aggressive_resharding); } @@ -164,14 +160,20 @@ std::optional ReturnImprovedSubSharding( HloSharding sharding, HloInstruction* instruction, const ShapeIndex& index, bool may_combine_partial_sharding, bool allow_aggressive_resharding = false) { - return ReturnImprovedShardingImpl( - sharding, - instruction->has_sharding() - ? 
std::optional(instruction->sharding().GetSubSharding( - instruction->shape(), index)) - : std::nullopt, - ShapeUtil::GetSubshape(instruction->shape(), index), - may_combine_partial_sharding, allow_aggressive_resharding); + if (instruction->has_sharding()) { + const HloSharding to_improved = + instruction->sharding().GetSubSharding(instruction->shape(), index); + return ReturnImprovedShardingImpl( + std::move(sharding), &to_improved, + ShapeUtil::GetSubshape(instruction->shape(), index), + may_combine_partial_sharding, allow_aggressive_resharding); + + } else { + return ReturnImprovedShardingImpl( + std::move(sharding), nullptr, + ShapeUtil::GetSubshape(instruction->shape(), index), + may_combine_partial_sharding, allow_aggressive_resharding); + } } // Updates the sharding of the specified instruction with the specified sharding From c8ec05fb85a423fb1561f086b9072976b389a0d5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 16:44:39 -0800 Subject: [PATCH 218/381] [XLA] Optimize HloValue::ComputeUses(). PiperOrigin-RevId: 586489430 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/hlo_value.cc | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index aef5a3f3e428dd..8ed5f82e417593 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4295,6 +4295,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/hlo_value.cc b/third_party/xla/xla/service/hlo_value.cc index 6f2d759067787c..9969d94ab16e9e 100644 --- a/third_party/xla/xla/service/hlo_value.cc +++ b/third_party/xla/xla/service/hlo_value.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -169,9 +170,12 @@ HloValue::Uses HloValue::ComputeUses() const { Uses uses; // Build vector of HloUses for the value. for (const HloPosition& position : positions_) { - for (HloInstruction* user : position.instruction->users()) { - for (int64_t i = 0; i < user->operand_count(); ++i) { - if (user->operand(i) != position.instruction) { + for (HloInstruction* const user : position.instruction->users()) { + int i = -1; + for (const auto& operand : user->operands()) { + ++i; + + if (operand != position.instruction) { continue; } @@ -181,10 +185,12 @@ HloValue::Uses HloValue::ComputeUses() const { root_positions.contains(user)) { HloUse new_use{user, i, position.index}; +#ifndef NDEBUG // The new use must not already exist in uses. 
for (const HloUse& use : uses) { DCHECK_NE(use, new_use); } +#endif // NDEBUG uses.push_back(std::move(new_use)); } From c7e84a2cd09400d537a39c63e134fb7d4ea7665a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 29 Nov 2023 16:53:21 -0800 Subject: [PATCH 219/381] [xla:gpu] Own command buffer allocations at a Thunk level PiperOrigin-RevId: 586491456 --- .../xla/xla/service/gpu/runtime3/BUILD | 20 +++++ .../runtime3/command_buffer_allocations.cc | 67 +++++++++++++++ .../gpu/runtime3/command_buffer_allocations.h | 51 ++++++++++++ .../gpu/runtime3/command_buffer_cmd.cc | 83 +++++-------------- .../service/gpu/runtime3/command_buffer_cmd.h | 2 + .../gpu/runtime3/command_buffer_thunk.cc | 15 +--- .../gpu/runtime3/command_buffer_thunk.h | 34 ++++---- .../gpu/runtime3/command_buffer_thunk_test.cc | 2 +- .../xla/xla/stream_executor/command_buffer.cc | 10 +-- .../xla/xla/stream_executor/command_buffer.h | 13 +-- .../stream_executor/gpu/gpu_command_buffer.cc | 30 ++----- .../stream_executor/gpu/gpu_command_buffer.h | 8 +- .../stream_executor_internal.h | 7 +- 13 files changed, 198 insertions(+), 144 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc create mode 100644 third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index bb44a6442f7e7e..76a3e04dbc68f6 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -16,12 +16,31 @@ package_group( # Command Buffer Integration #===-------------------------------------------------------------------------------------------===// +cc_library( + name = "command_buffer_allocations", + srcs = ["command_buffer_allocations.cc"], + hdrs = ["command_buffer_allocations.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:status", + "//xla:statusor", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "command_buffer_cmd", srcs = ["command_buffer_cmd.cc"], hdrs = ["command_buffer_cmd.h"], visibility = ["//visibility:public"], deps = [ + ":command_buffer_allocations", "//xla:status", "//xla:statusor", "//xla:types", @@ -114,6 +133,7 @@ cc_library( hdrs = ["command_buffer_thunk.h"], visibility = ["//visibility:public"], deps = [ + ":command_buffer_allocations", ":command_buffer_cmd", "//xla:status", "//xla:statusor", diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc new file mode 100644 index 00000000000000..0aa46ffec10f29 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/runtime3/command_buffer_allocations.h"
+
+#include 
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "xla/stream_executor/device_memory.h"
+
+namespace xla::gpu {
+
+StatusOr<se::DeviceMemoryBase> CommandBufferAllocations::GetDeviceAddress(
+    BufferAllocation::Slice buffer_slice) const {
+  auto base = allocs_.find(buffer_slice.index());
+  if (base == allocs_.end()) {
+    return absl::InternalError(absl::StrCat("Command buffer allocation #",
+                                            buffer_slice.index(),
+                                            " was not allocated"));
+  }
+
+  char* ptr = static_cast<char*>(const_cast<void*>(base->second.opaque()));
+  return se::DeviceMemoryBase(ptr + buffer_slice.offset(), buffer_slice.size());
+}
+
+Status CommandBufferAllocations::AddAllocation(BufferAllocation::Index index,
+                                               se::DeviceMemoryBase memory) {
+  VLOG(2) << "Add command buffer allocation: index=" << index
+          << "; ptr=" << memory.opaque();
+
+  auto emplaced = allocs_.try_emplace(index, std::move(memory));
+  if (emplaced.second == false) {
+    return absl::InternalError(absl::StrCat("Command buffer allocation #",
+                                            index, " was already allocated"));
+  }
+  return OkStatus();
+}
+
+Status CommandBufferAllocations::EraseAllocation(
+    BufferAllocation::Index index) {
+  VLOG(2) << "Erase command buffer allocation: index=" << index;
+
+  if (allocs_.erase(index) == 0) {
+    return absl::InternalError(absl::StrCat("Command buffer allocation #",
+                                            index, " was not allocated"));
+  }
+  return OkStatus();
+}
+
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h
new file mode 100644
index 00000000000000..3435dc0d69434c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h
@@ -0,0 +1,51 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_
+#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/gpu/buffer_allocations.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "xla/stream_executor/device_memory.h"
+
+namespace xla::gpu {
+
+// CommandBufferAllocations tracks external buffer allocations done via the
+// CommandBuffer API and owned by the XLA executable (via instantiated command
+// buffers and memory allocation Gpu graph nodes).
+class CommandBufferAllocations : public BufferAllocations::ExternalAllocations {
+ public:
+  StatusOr<se::DeviceMemoryBase> GetDeviceAddress(
+      BufferAllocation::Slice buffer_slice) const override;
+
+  // Adds an external allocation for a given buffer index. Returns error if
+  // allocation already exists.
+  Status AddAllocation(BufferAllocation::Index index,
+                       se::DeviceMemoryBase memory);
+
+  // Erases an external allocation for a given buffer index. Returns error if
+  // allocation does not exist.
+  Status EraseAllocation(BufferAllocation::Index index);
+
+ private:
+  absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase> allocs_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_
diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc
index ebe03fb5b719c4..854f0bd790e404 100644
--- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc
+++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc
@@ -24,7 +24,6 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/buffer_allocations.h"
@@ -32,7 +31,6 @@ limitations under the License.
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/stream_executor_util.h"
 #include "xla/status.h"
-#include "xla/statusor.h"
 #include "xla/stream_executor/command_buffer.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/kernel.h"
@@ -44,47 +42,6 @@ limitations under the License.
 
 namespace xla::gpu {
 
-//===----------------------------------------------------------------------===//
-// External buffer allocations managed by a command buffer.
-//===----------------------------------------------------------------------===//
-
-// TODO(ezhulenev): External allocations should be managed by the command buffer
-// thunk, as the command buffer itself should not know anything about XLA buffer
-// allocation index or buffer slice.
- -namespace { -class CommandBufferAllocations : public BufferAllocations::ExternalAllocations { - public: - explicit CommandBufferAllocations(se::CommandBuffer* command_buffer); - - StatusOr GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const override; - - private: - se::CommandBuffer* command_buffer_; -}; -} // namespace - -CommandBufferAllocations::CommandBufferAllocations( - se::CommandBuffer* command_buffer) - : command_buffer_(command_buffer) {} - -StatusOr CommandBufferAllocations::GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const { - TF_ASSIGN_OR_RETURN( - auto base, command_buffer_->GetAllocationAddress(buffer_slice.index())); - - if (base.is_null()) { - return absl::InternalError(absl::StrCat("Memory for a buffer #", - buffer_slice.index(), - " is not yet allocated")); - } - - return se::DeviceMemoryBase( - static_cast(base.opaque()) + buffer_slice.offset(), - buffer_slice.size()); -} - //===----------------------------------------------------------------------===// // CommandBufferCmdSequence //===----------------------------------------------------------------------===// @@ -172,13 +129,11 @@ Status LaunchCmd::Record(const RecordParams& params, "Kernel not loaded on a command buffer executor"); } - CommandBufferAllocations external_allocations(command_buffer); - absl::InlinedVector buffers; for (const BufferAllocation::Slice& arg : args_) { - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase buf, - params.buffer_allocations->GetDeviceAddress(arg, external_allocations)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buf, + params.buffer_allocations->GetDeviceAddress( + arg, *params.command_buffer_allocations)); VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); buffers.push_back(buf); } @@ -213,14 +168,12 @@ Status MemcpyDeviceToDeviceCmd::Record(const RecordParams& params, VLOG(5) << "MemcpyDeviceToDeviceCmd: dst=" << dst_ << ", src=" << src_ << ", num_bytes=" << num_bytes_; - CommandBufferAllocations external_allocations(command_buffer); - - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase dst, - params.buffer_allocations->GetDeviceAddress(dst_, external_allocations)); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase src, - params.buffer_allocations->GetDeviceAddress(src_, external_allocations)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase dst, + params.buffer_allocations->GetDeviceAddress( + dst_, *params.command_buffer_allocations)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase src, + params.buffer_allocations->GetDeviceAddress( + src_, *params.command_buffer_allocations)); return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_); } @@ -272,8 +225,14 @@ Status AllocateCmd::Record(const RecordParams& params, // Memory allocation address is returned on graph creation, and there is no // update operation VLOG(5) << "AllocationCmd: index=" << allocation_->index(); - return command_buffer->Allocate(se::CommandBuffer::AllocIndexSize{ - allocation_->index(), static_cast(allocation_->size())}); + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, + command_buffer->Allocate(allocation_->size())); + + TF_RETURN_IF_ERROR(params.command_buffer_allocations->AddAllocation( + allocation_->index(), buffer)); + + return OkStatus(); } CommandBufferCmd::Slices AllocateCmd::slices() { return {}; } @@ -308,17 +267,15 @@ Status GemmCmd::Record(const RecordParams& params, se::DeviceMemoryBase workspace(nullptr, 0); - CommandBufferAllocations external_allocations(command_buffer); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs, params.buffer_allocations->GetDeviceAddress( - 
lhs_buffer_, external_allocations)); + lhs_buffer_, *params.command_buffer_allocations)); TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs, params.buffer_allocations->GetDeviceAddress( - rhs_buffer_, external_allocations)); + rhs_buffer_, *params.command_buffer_allocations)); TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase out, params.buffer_allocations->GetDeviceAddress( - output_buffer_, external_allocations)); + output_buffer_, *params.command_buffer_allocations)); TF_ASSIGN_OR_RETURN( auto nested_buffer, diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index ab894bb7952747..699c9a43283fb4 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -29,6 +29,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime3/command_buffer_allocations.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" #include "xla/stream_executor/command_buffer.h" @@ -61,6 +62,7 @@ class CommandBufferCmd { struct RecordParams { se::StreamExecutor* executor; const BufferAllocations* buffer_allocations; + CommandBufferAllocations* command_buffer_allocations; }; // Prepares a command for recording on a given executor. We split it into a diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc index b621a3f9c0f38d..3638935f2c0c9e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc @@ -76,11 +76,11 @@ Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(ExecutorCommandBuffer * cmd_buffer, GetOrCreateCommandBuffer(executor)); - CommandBufferCmd::RecordParams record_params = {executor, - params.buffer_allocations}; - absl::MutexLock lock(&cmd_buffer->mutex); + CommandBufferCmd::RecordParams record_params = { + executor, params.buffer_allocations, &cmd_buffer->allocations}; + if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, record_params)) { TF_RETURN_IF_ERROR( commands_.Record(record_params, &cmd_buffer->command_buffer)); @@ -105,13 +105,4 @@ CommandBufferThunk::GetOrCreateCommandBuffer(se::StreamExecutor* executor) { return &emplaced.first->second; } -StatusOr CommandBufferThunk::GetLazyAllocationAddress( - const ExecuteParams& params, int64_t index) { - se::StreamExecutor* executor = params.stream->parent(); - TF_ASSIGN_OR_RETURN(ExecutorCommandBuffer * cmd_buffer, - GetOrCreateCommandBuffer(executor)); - absl::MutexLock lock(&cmd_buffer->mutex); - return cmd_buffer->command_buffer.GetAllocationAddress(index); -} - } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h index 47d7298545a1ff..e5e2a3ed6abda1 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/base/thread_annotations.h" #include "absl/container/node_hash_map.h" #include "absl/synchronization/mutex.h" +#include "xla/service/gpu/runtime3/command_buffer_allocations.h" #include "xla/service/gpu/runtime3/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" @@ -56,16 +57,25 @@ class CommandBufferThunk : public Thunk { absl::Mutex mutex; se::CommandBuffer command_buffer ABSL_GUARDED_BY(mutex); - // Mapping from buffer allocation index to the device memory passed at that - // index to the last call of `commands_.Record(...)` for `command_buffer`. - // We can just use a vector instead of map because `BufferAllocation::Index` - // is a unique identifier assigned contiguously and thus can be used as - // array index. + // TODO(ezhulenev): We need to move command buffer allocations all the way + // up to the GpuExecutable as we can have Allocate and Free commands in + // different command buffers. Consider making it a part of + // BufferAllocations (as std::unique_ptr member). + + // Memory allocations performed by a `command_buffer`. + CommandBufferAllocations allocations ABSL_GUARDED_BY(mutex); + + // Mapping from buffer allocation index to the device memory passed at + // that index to the last call of `commands_.Record(...)` for + // `command_buffer`. We can just use a vector instead of map because + // `BufferAllocation::Index` is a unique identifier assigned + // contiguously and thus can be used as array index. // - // If no device memory addresses changed from a previous call to `Record`, - // we can skip command buffer update and simply submit it for execution on a - // stream. All other pieces of information (like thread and block sizes) - // captured by commands at construction time and do not change. + // If no device memory addresses changed from a previous call to + // `Record`, we can skip command buffer update and simply submit it for + // execution on a stream. All other pieces of information (like thread + // and block sizes) captured by commands at construction time and do not + // change. std::vector recorded_allocs ABSL_GUARDED_BY(mutex); }; @@ -73,12 +83,6 @@ class CommandBufferThunk : public Thunk { StatusOr GetOrCreateCommandBuffer( se::StreamExecutor* executor); - // Return the allocation address that was lazilly allocated inside command - // buffer. This API is required when the buffers are allocated inside command - // buffer but will be consumed by non-command buffer operations. - StatusOr GetLazyAllocationAddress( - const ExecuteParams& params, int64_t index); - // Command sequence that initializes command buffers on each executor. CommandBufferCmdSequence commands_; diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index f1569fc6f3d560..f976cdecdbf36b 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -110,7 +110,7 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer. // 4. MemCopyDEviceToDevice from "b" to "c" inside command buffer. // 5. Verify that region "c" has the same content as "a". 
-TEST(CommandBufferThunkTest, MemallocCmd) { +TEST(CommandBufferThunkTest, AllocateCmd) { se::StreamExecutor* executor = CudaExecutor(); se::Stream stream(executor); diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 734a3faeb6f946..551a667843d818 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -137,13 +138,8 @@ tsl::Status CommandBuffer::Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, return implementation_->Memset(dst, bit_pattern, num_elements); } -tsl::Status CommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { - return implementation_->Allocate(alloc); -} - -tsl::StatusOr CommandBuffer::GetAllocationAddress( - int64_t index) const { - return implementation_->GetAllocationAddress(index); +tsl::StatusOr CommandBuffer::Allocate(size_t bytes) { + return implementation_->Allocate(bytes); } tsl::Status CommandBuffer::If(StreamExecutor* executor, DeviceMemory pred, diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 54349127dcfeed..647d51283f8eab 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -55,10 +55,6 @@ class CommandBuffer { public: // Builder constructs nested command buffers owned by a parent command buffer. using Builder = std::function; - struct AllocIndexSize { - int64_t index; - uint64_t size; - }; ~CommandBuffer(); CommandBuffer(CommandBuffer&&); @@ -179,13 +175,8 @@ class CommandBuffer { //--------------------------------------------------------------------------// - // Adds a device memory allocation command to the command buffer, allocated - // address is tracked by command buffer runtime. - tsl::Status Allocate(AllocIndexSize alloc); - - // Get the device address for allocations previously allocated through - // Allocate command. - tsl::StatusOr GetAllocationAddress(int64_t index) const; + // Adds a device memory allocation command to the command buffer. + tsl::StatusOr Allocate(size_t bytes); // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. 
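The API change above inverts the ownership of lazily allocated device memory: Allocate() now returns the se::DeviceMemoryBase directly instead of stashing it under a BufferAllocation::Index inside the command buffer, and the caller records it in the thunk-owned CommandBufferAllocations. Below is a minimal sketch of the new calling convention, condensed from AllocateCmd::Record earlier in this patch; the surrounding command buffer, allocations table, and BufferAllocation are assumed to already exist, and the helper name is illustrative, not part of the patch.

// Sketch only. Requires "xla/service/gpu/runtime3/command_buffer_allocations.h",
// "xla/stream_executor/command_buffer.h" and "tsl/platform/statusor.h".
tsl::Status RecordLazyAllocation(stream_executor::CommandBuffer* cmd_buffer,
                                 xla::gpu::CommandBufferAllocations* allocs,
                                 const xla::BufferAllocation* alloc) {
  // The device address is produced when the allocation node is added to the
  // underlying GPU graph and is returned immediately.
  TF_ASSIGN_OR_RETURN(stream_executor::DeviceMemoryBase buffer,
                      cmd_buffer->Allocate(alloc->size()));
  // The index-to-address mapping now lives at the thunk level, not inside the
  // command buffer.
  return allocs->AddAllocation(alloc->index(), buffer);
}
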
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 0742b62c994d7b..643ae4879df369 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -302,7 +302,7 @@ tsl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } -tsl::Status GpuCommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { +tsl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { TF_RETURN_IF_ERROR(CheckNotFinalized()); if (state_ == State::kCreate) { @@ -314,44 +314,30 @@ tsl::Status GpuCommandBuffer::Allocate(CommandBuffer::AllocIndexSize alloc) { node, graph_, absl::MakeSpan(deps), GpuDriver::MemAccessFlags::kReadWrite, GpuDriver::MemLocationType::kDevice, parent_->device_ordinal(), - GpuDriver::MemAllocationType::kPinned, alloc.size, &ptr)); + GpuDriver::MemAllocationType::kPinned, bytes, &ptr)); // For CUDA impl, VA range is reserved when adding memory allocation node. CHECK(ptr) << "CUDA graph memory allocation node returned nullptr"; VLOG(2) << "Setting device memory base with opaque pointer " << reinterpret_cast(ptr) << " device ordinal: " << parent_->device_ordinal(); - allocations_map_[alloc.index] = - DeviceMemoryBase{reinterpret_cast(ptr), alloc.size}; - return tsl::OkStatus(); + return DeviceMemoryBase(reinterpret_cast(ptr), bytes); } if (state_ == State::kUpdate) { // Memory allocation node implemented through CUDA graph does not allocate // new memory region on update, just return the memory region allocated // during the create step. - TF_ASSIGN_OR_RETURN( - AllocationResult params, - GpuDriver::GraphGetMemAllocNodeParams(nodes_[update_state_.node_idx])); - update_state_.node_idx++; - allocations_map_[alloc.index] = - DeviceMemoryBase{reinterpret_cast(params.first), params.second}; - return tsl::OkStatus(); + TF_ASSIGN_OR_RETURN(AllocationResult params, + GpuDriver::GraphGetMemAllocNodeParams( + nodes_[update_state_.node_idx++])); + return DeviceMemoryBase(reinterpret_cast(params.first), + params.second); } return UnsupportedStateError(state_); } -tsl::StatusOr GpuCommandBuffer::GetAllocationAddress( - int64_t index) const { - if (allocations_map_.contains(index)) { - return allocations_map_.at(index); - } else { - return absl::InternalError( - absl::StrCat("Allocation is not yet allocated: ", index)); - } -} - //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 3a5269f283b81d..5223e9d24e7cb5 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -63,10 +63,7 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { CommandBuffer::BitPattern bit_pattern, size_t num_elements) override; - tsl::Status Allocate(CommandBuffer::AllocIndexSize alloc) override; - - tsl::StatusOr GetAllocationAddress( - int64_t index) const override; + tsl::StatusOr Allocate(size_t bytes) override; tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) override; @@ -222,9 +219,6 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // error. 
tsl::Status CheckPrimary(); - // Keep tracks of allocations that is performed by allocation command. - absl::flat_hash_map allocations_map_; - // Returns OK status if the number of command buffers is equal to the expected // one, otherwise returns internal error. tsl::Status CheckNumCommandBuffers( diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 62289f7cedacf0..8c17dbdc155603 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -147,12 +147,7 @@ class CommandBufferInterface { size_t num_elements) = 0; // Adds a device memory allocation node to the command buffer. - virtual tsl::Status Allocate(CommandBuffer::AllocIndexSize alloc) = 0; - - // Get the device address for allocations performed through command buffer - // Allocate command. - virtual tsl::StatusOr GetAllocationAddress( - int64_t index) const = 0; + virtual tsl::StatusOr Allocate(size_t bytes) = 0; // For all conditional command APIs defined below, nested command buffers // constructed for conditional branches owned by *this and should never be From 5dd7256797aa29e9d92801e51547ee33fa4a9d7d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 17:09:31 -0800 Subject: [PATCH 220/381] Integrate LLVM at llvm/llvm-project@f688e0901213 Updates LLVM usage to match [f688e0901213](https://github.com/llvm/llvm-project/commit/f688e0901213) PiperOrigin-RevId: 586494799 --- third_party/llvm/generated.patch | 59 ++++++-------------------------- third_party/llvm/workspace.bzl | 4 +-- 2 files changed, 13 insertions(+), 50 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 51af6376f76b2e..ce1937af46e5d5 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,49 +1,12 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/compiler-rt/lib/msan/msan_interceptors.cpp b/compiler-rt/lib/msan/msan_interceptors.cpp ---- a/compiler-rt/lib/msan/msan_interceptors.cpp -+++ b/compiler-rt/lib/msan/msan_interceptors.cpp -@@ -244,20 +244,23 @@ - #endif - - #if !SANITIZER_FREEBSD && !SANITIZER_NETBSD --// This function actually returns a struct by value, but we can't unpoison a --// temporary! The following is equivalent on all supported platforms but --// aarch64 (which uses a different register for sret value). We have a test --// to confirm that. --INTERCEPTOR(void, mallinfo, __sanitizer_struct_mallinfo *sret) { --#ifdef __aarch64__ -- uptr r8; -- asm volatile("mov %0,x8" : "=r" (r8)); -- sret = reinterpret_cast<__sanitizer_struct_mallinfo*>(r8); --#endif -- REAL(memset)(sret, 0, sizeof(*sret)); -+ -+template -+static NOINLINE void clear_mallinfo(T *sret) { -+ ENSURE_MSAN_INITED(); -+ internal_memset(sret, 0, sizeof(*sret)); - __msan_unpoison(sret, sizeof(*sret)); - } --#define MSAN_MAYBE_INTERCEPT_MALLINFO INTERCEPT_FUNCTION(mallinfo) -+ -+// Interceptor relies on NRVO and assumes that sret will be pre-allocated in -+// caller frame. 
-+INTERCEPTOR(__sanitizer_struct_mallinfo, mallinfo) { -+ __sanitizer_struct_mallinfo sret; -+ clear_mallinfo(&sret); -+ return sret; -+} -+ -+# define MSAN_MAYBE_INTERCEPT_MALLINFO INTERCEPT_FUNCTION(mallinfo) - #else - #define MSAN_MAYBE_INTERCEPT_MALLINFO - #endif -diff -ruN --strip-trailing-cr a/compiler-rt/test/msan/Linux/mallinfo.cpp b/compiler-rt/test/msan/Linux/mallinfo.cpp ---- a/compiler-rt/test/msan/Linux/mallinfo.cpp -+++ b/compiler-rt/test/msan/Linux/mallinfo.cpp -@@ -1,5 +1,4 @@ - // RUN: %clangxx_msan -O0 -g %s -o %t && %run %t --// UNSUPPORTED: aarch64-target-arch - - #include - #include +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +@@ -594,6 +594,7 @@ + name = "__support_bit", + hdrs = ["src/__support/bit.h"], + deps = [ ++ ":__support_cpp_type_traits", + ":__support_macros_attributes", + ], + ) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index e5364bd9ec0ab3..62f02cded785c5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "3287ae8f6520ef81570377c1fb4c7147782a13ef" - LLVM_SHA256 = "87c55be01fb53ab0f2ce03bca419a5b08393247c28ecc8b1facd78bd3e7614da" + LLVM_COMMIT = "f688e0901213726feb9b26cedc61919413cbf59c" + LLVM_SHA256 = "b8885c22a9b77f9c91a316b21d71414a7b48dae38513f170da1554002e85b030" tf_http_archive( name = name, From a492d5819d8ab8ef1393e6e90dbb38fdca6bd36d Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Wed, 29 Nov 2023 17:29:00 -0800 Subject: [PATCH 221/381] [stream_executor] Use new CUDA runtime API for TopK PiperOrigin-RevId: 586498567 --- third_party/xla/xla/service/gpu/runtime/BUILD | 12 ++ .../xla/xla/service/gpu/runtime/topk.cc | 6 +- .../xla/service/gpu/runtime/topk_kernel.cc | 96 +++++----- .../xla/xla/service/gpu/runtime/topk_kernel.h | 11 +- .../service/gpu/runtime/topk_kernel_test.cc | 165 ++++++++++-------- third_party/xla/xla/stream_executor/stream.h | 17 ++ 6 files changed, 170 insertions(+), 137 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index c154f5956af650..57548606ea9e31 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -358,8 +358,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gpu_kernel_helper", + ":support", "//xla:shape_util", + "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/runtime:memref_view", "//xla/stream_executor", # build_cleaner: keep "//xla/stream_executor:platform", "//xla/stream_executor/gpu:gpu_stream_header", @@ -367,6 +371,9 @@ cc_library( "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ ":topk_kernel_cuda", ]) + if_rocm_is_configured([ @@ -422,11 +429,16 @@ xla_cc_test( ":gpu_kernel_helper", ":topk_kernel", "//xla:xla_data_proto_cc", + "//xla/stream_executor", # build_cleaner: keep + "//xla/stream_executor:multi_platform_manager", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/gpu:gpu_types_header", 
"//xla/stream_executor/host:host_platform", "@com_google_absl//absl/log:check", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", diff --git a/third_party/xla/xla/service/gpu/runtime/topk.cc b/third_party/xla/xla/service/gpu/runtime/topk.cc index 385dad8b865be1..d21d08e98d88f7 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk.cc +++ b/third_party/xla/xla/service/gpu/runtime/topk.cc @@ -47,9 +47,9 @@ static absl::Status TopkImpl(const ServiceExecutableRunOptions* run_options, size_t batch_size = has_batch ? data.sizes[0] : 1; size_t n = has_batch ? data.sizes[1] : data.sizes[0]; size_t k = has_batch ? top_elements.sizes[1] : top_elements.sizes[0]; - return RunTopk(se::gpu::AsGpuStreamValue(run_options->stream()), data.dtype, - data.data, n, top_elements.data, - static_cast(indices.data), k, batch_size); + return RunTopk(run_options->stream(), data.dtype, GetDeviceAddress(data), n, + GetDeviceAddress(top_elements), GetDeviceAddress(indices), k, + batch_size); } XLA_RUNTIME_DEFINE_CUSTOM_CALL( diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc b/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc index f241f5e4cbf084..7a65107ea9e6eb 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc +++ b/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc @@ -20,21 +20,25 @@ limitations under the License. #include "xla/service/gpu/runtime/topk_kernel.h" #include +#include #include "absl/numeric/bits.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" +#include "Eigen/Core" // from @eigen_archive #include "xla/primitive_util.h" #include "xla/service/gpu/runtime/gpu_kernel_helper.h" #include "xla/service/gpu/runtime/topk_kernel_common.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/stream.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { namespace { -using ::stream_executor::gpu::GpuStreamHandle; +using se::gpu::GpuStreamHandle; size_t NumThreads(size_t n, size_t k, size_t batch_size) { // Estimate number of threads per block that can run concurrently given the @@ -47,38 +51,6 @@ size_t NumThreads(size_t n, size_t k, size_t batch_size) { return std::min(threads_per_block, min_slice); } -// Helper type for converting the untyped arguments of RunTopk to TypedTopk -template -struct TopkArgs { - TopkArgs(GpuStreamHandle stream, PrimitiveType dtype, T* data, - size_t num_elements, T* top_elements, uint32_t* top_indices, - size_t k, size_t batch_size) - : stream(stream), - dtype(dtype), - data(data), - num_elements(num_elements), - top_elements(top_elements), - top_indices(top_indices), - k(k), - batch_size(batch_size) {} - - template - TopkArgs Convert() const { - return TopkArgs(stream, dtype, static_cast(data), num_elements, - static_cast(top_elements), top_indices, k, - batch_size); - } - - GpuStreamHandle stream; - PrimitiveType dtype; - T* data; - size_t num_elements; - T* top_elements; - uint32_t* top_indices; - size_t k; - size_t batch_size; -}; - template absl::StatusOr GetKernel(int n, int k) { if (k <= 1) return GetTopKKernelForK(n); @@ -90,44 +62,56 @@ absl::StatusOr GetKernel(int n, int k) { } template -absl::Status TypedTopK(TopkArgs args) { - int num_threads = NumThreads(args.num_elements, args.k, args.batch_size); +absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase 
data, + size_t num_elements, se::DeviceMemoryBase top_elements, + se::DeviceMemoryBase top_indices, size_t k, + size_t batch_size) { + constexpr size_t max_kv_size = sizeof(uint64_t); + // Allocate shmem assuming we have a full reduction. + int shmem_size = absl::bit_ceil(k) * max_kv_size * WAVEFRONT_SIZE; + int num_threads = NumThreads(num_elements, k, batch_size); if (num_threads == 0) { return absl::FailedPreconditionError( "Invalid kernel parameters. This is likely a bug in the " "TopkSpecializer."); } + se::StreamExecutor* executor = stream->parent(); + se::DeviceMemory data_typed(data); + se::DeviceMemory top_elements_typed(top_elements); + se::DeviceMemory top_indices_typed(top_indices); + + TF_ASSIGN_OR_RETURN(void* kernel_symbol, GetKernel(num_elements, k)); + TF_ASSIGN_OR_RETURN( + auto kernel, + (executor + ->CreateTypedKernel, size_t, se::DeviceMemory, + se::DeviceMemory, size_t>( + "topk", kernel_symbol))); + + TF_RETURN_IF_ERROR(stream->ThenLaunch( + se::ThreadDim(num_threads, 1, 1), se::BlockDim(batch_size, 1, 1), + shmem_size, *kernel, data_typed, num_elements, top_elements_typed, + top_indices_typed, k)); - TF_ASSIGN_OR_RETURN(void* kernel, GetKernel(args.num_elements, args.k)); - int blocks_per_grid = args.batch_size; - constexpr size_t max_kv_size = sizeof(uint64_t); - // Allocate shmem assuming we have a full reduction. - int shmem_size = absl::bit_ceil(args.k) * max_kv_size * WAVEFRONT_SIZE; - void* kernel_args[] = {&args.data, &args.num_elements, &args.top_elements, - &args.top_indices, &args.k}; - auto launch_status = gpuLaunchKernel(kernel, blocks_per_grid, num_threads, - kernel_args, shmem_size, args.stream); - if (launch_status != gpuSuccess) { - return absl::InternalError(absl::StrCat("Failed to launch kernel: ", - gpuGetErrorString(launch_status))); - } return absl::OkStatus(); } } // namespace -absl::Status RunTopk(GpuStreamHandle stream, PrimitiveType dtype, void* data, - size_t num_elements, void* top_elements, - uint32_t* top_indices, size_t k, size_t batch_size) { +absl::Status RunTopk(se::Stream* stream, PrimitiveType dtype, + se::DeviceMemoryBase data, size_t num_elements, + se::DeviceMemoryBase top_elements, + se::DeviceMemoryBase top_indices, size_t k, + size_t batch_size) { VLOG(2) << "TopK: " << primitive_util::LowercasePrimitiveTypeName(dtype) << ", n: " << num_elements << ", k: " << k << ", bs: " << batch_size; - auto args = TopkArgs(stream, dtype, data, num_elements, top_elements, - top_indices, k, batch_size); switch (dtype) { case PrimitiveType::F32: - return TypedTopK(args.Convert()); + return TypedTopK(stream, data, num_elements, top_elements, + top_indices, k, batch_size); case PrimitiveType::BF16: - return TypedTopK(args.Convert()); + return TypedTopK( + stream, data, num_elements, top_elements, top_indices, k, batch_size); default: return absl::UnimplementedError("GpuTopK not implemented for this dtype"); } diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel.h b/third_party/xla/xla/service/gpu/runtime/topk_kernel.h index 6afd56d35fe9d6..5be0d121d4f31b 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel.h +++ b/third_party/xla/xla/service/gpu/runtime/topk_kernel.h @@ -20,9 +20,11 @@ limitations under the License. 
#include #include "absl/status/status.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla::gpu { @@ -33,9 +35,10 @@ namespace xla::gpu { // - top_indices: [batch_size, k] u32 // Where `top_elements` contains the largest elements of the input, and // `top_indices` their original indices. -absl::Status RunTopk(::tensorflow::se::gpu::GpuStreamHandle stream, - PrimitiveType dtype, void* data, size_t num_elements, - void* top_elements, uint32_t* top_indices, size_t k, +absl::Status RunTopk(se::Stream* stream, PrimitiveType dtype, + se::DeviceMemoryBase data, size_t num_elements, + se::DeviceMemoryBase top_elements, + se::DeviceMemoryBase top_indices, size_t k, size_t batch_size); } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc b/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc index 73a8e5f0356c28..9bc2ccff274f8e 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc @@ -26,9 +26,14 @@ limitations under the License. #include "absl/log/check.h" #include "absl/random/random.h" #include "absl/strings/substitute.h" +#include "absl/time/time.h" #include "Eigen/Core" // from @eigen_archive #include "xla/service/gpu/runtime/gpu_kernel_helper.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/stream.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -36,45 +41,31 @@ limitations under the License. 
namespace xla::gpu { namespace { -using ::stream_executor::gpu::GpuStreamHandle; +using se::gpu::GpuStreamHandle; using ::testing::Combine; using ::testing::Values; -#define CUDA_CHECK(s) \ - do { \ - CHECK_EQ(s, gpuSuccess) << gpuGetErrorString(s); \ - } while (0) - -template -T* AllocateGpuBuffer(int num_elements) { - void* buffer; - CUDA_CHECK(gpuMalloc(&buffer, num_elements * sizeof(T))); - return static_cast(buffer); -} - template -std::vector RandomFillRange(void* buffer, int num_elements, T start, T end) { +std::vector RandomVecRange(int num_elements, T start, T end) { std::vector local; local.reserve(num_elements); thread_local absl::BitGen gen; for (int i = 0; i < num_elements; ++i) { local.push_back(absl::Uniform(gen, start, end)); } - CUDA_CHECK(gpuMemcpy(buffer, local.data(), num_elements * sizeof(T), - gpuMemcpyHostToDevice)); return local; } template -std::vector RandomFill(void* buffer, int num_elements) { - return RandomFillRange(buffer, num_elements, static_cast(0), - static_cast(num_elements)); +std::vector RandomVec(int num_elements) { + return RandomVecRange(num_elements, static_cast(0), + static_cast(num_elements)); } template -std::vector RandomFillNegative(void* buffer, int num_elements) { - return RandomFillRange(buffer, num_elements, -static_cast(num_elements), - static_cast(0)); +std::vector RandomVecNegative(int num_elements) { + return RandomVecRange(num_elements, -static_cast(num_elements), + static_cast(0)); } PrimitiveType Get(float) { return PrimitiveType::F32; } @@ -92,62 +83,84 @@ using TopkTest = ::testing::TestWithParam>; // utilities to simplify the test logic. TEST_P(TopkTest, TopKFloat) { using T = float; + + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + const auto [n_kb, k, batch_size, offset] = GetParam(); const size_t n = n_kb * 1024 + offset; - T* input_buffer = AllocateGpuBuffer(n * batch_size); - auto source = RandomFill(input_buffer, n * batch_size); - T* output_values = AllocateGpuBuffer(k * batch_size); - auto* output_indices = - static_cast(AllocateGpuBuffer(k * batch_size)); - GpuStreamHandle stream; - CUDA_CHECK(gpuStreamCreate(&stream)); - ASSERT_TRUE(RunTopk(stream, Get(T()), input_buffer, n, output_values, + + se::DeviceMemory input_buffer = + executor->AllocateArray(n * batch_size, 0); + se::DeviceMemory output_values = + executor->AllocateArray(k * batch_size, 0); + se::DeviceMemory output_indices = + executor->AllocateArray(k * batch_size, 0); + + auto source = RandomVec(n * batch_size); + stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); + + ASSERT_TRUE(RunTopk(&stream, Get(T()), input_buffer, n, output_values, output_indices, k, batch_size) .ok()); std::vector got(k); - CUDA_CHECK(gpuStreamSynchronize(stream)); + ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - CUDA_CHECK(gpuMemcpy(got.data(), &output_values[k * i], k * sizeof(T), - gpuMemcpyDeviceToHost)); + stream.ThenMemcpy(got.data(), + executor->GetSubBuffer(&output_values, k * i, k), + k * sizeof(T)); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) - << " k=" << k << ", batch_size=" << batch_size; + << " k=" << k << ", batch_size=" << batch_size << " i=" << i; } - 
CUDA_CHECK(gpuFree(input_buffer)); - CUDA_CHECK(gpuFree(output_indices)); - CUDA_CHECK(gpuFree(output_values)); } TEST_P(TopkTest, TopKPackedNegative) { using T = float; + + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + const auto [n_kb, k, batch_size, offset] = GetParam(); const size_t n = n_kb * 1024 + offset; - T* input_buffer = AllocateGpuBuffer(n * batch_size); - auto source = RandomFillNegative(input_buffer, n * batch_size); - T* output_values = AllocateGpuBuffer(k * batch_size); - auto* output_indices = - static_cast(AllocateGpuBuffer(k * batch_size)); - GpuStreamHandle stream; - CUDA_CHECK(gpuStreamCreate(&stream)); - ASSERT_TRUE(RunTopk(stream, Get(T()), input_buffer, n, output_values, + + se::DeviceMemory input_buffer = + executor->AllocateArray(n * batch_size, 0); + se::DeviceMemory output_values = + executor->AllocateArray(k * batch_size, 0); + se::DeviceMemory output_indices = + executor->AllocateArray(k * batch_size, 0); + + auto source = RandomVecNegative(n * batch_size); + stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); + + ASSERT_TRUE(RunTopk(&stream, Get(T()), input_buffer, n, output_values, output_indices, k, batch_size) .ok()); std::vector got(k); - CUDA_CHECK(gpuStreamSynchronize(stream)); + ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - CUDA_CHECK(gpuMemcpy(got.data(), &output_values[k * i], k * sizeof(T), - gpuMemcpyDeviceToHost)); + stream.ThenMemcpy(got.data(), + executor->GetSubBuffer(&output_values, k * i, k), + k * sizeof(T)); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); slice.resize(k); EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) - << " k=" << k << ", batch_size=" << batch_size; + << " k=" << k << ", batch_size=" << batch_size << " i=" << i; } - CUDA_CHECK(gpuFree(input_buffer)); - CUDA_CHECK(gpuFree(output_indices)); - CUDA_CHECK(gpuFree(output_values)); } INSTANTIATE_TEST_SUITE_P(TopkTests, TopkTest, @@ -167,39 +180,43 @@ INSTANTIATE_TEST_SUITE_P(TopkTests, TopkTest, template void BM_SmallTopk(benchmark::State& state) { using T = float; + size_t k = K; size_t batch_size = state.range(0); size_t n = state.range(1) * 1024; state.SetLabel( absl::Substitute("n=$0Ki k=$1 batch_size=$2", n / 1024, k, batch_size)); - void* input_buffer = AllocateGpuBuffer(n * batch_size); - auto source = RandomFill(input_buffer, n); - void* output_values = AllocateGpuBuffer(k); - auto* output_indices = static_cast(AllocateGpuBuffer(k)); - GpuStreamHandle stream; - CUDA_CHECK(gpuStreamCreate(&stream)); + + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + se::DeviceMemory input_buffer = + executor->AllocateArray(n * batch_size, 0); + se::DeviceMemory output_values = executor->AllocateArray(k, 0); + se::DeviceMemory output_indices = + executor->AllocateArray(k, 0); + + auto source = RandomVec(n); + stream.ThenMemcpy(&input_buffer, source.data(), n * sizeof(T)); + for (auto _ : state) { - gpuEvent_t start, stop; - CUDA_CHECK(gpuEventCreate(&start)); - CUDA_CHECK(gpuEventCreate(&stop)); - CUDA_CHECK(gpuEventRecord(start, stream)); - 
CHECK_OK(RunTopk(stream, Get(T()), input_buffer, n, output_values, + auto timer = se::gpu::GpuTimer::Create(se::gpu::AsGpuStream(&stream)); + CHECK_OK(timer.status()); + CHECK_OK(RunTopk(&stream, Get(T()), input_buffer, n, output_values, output_indices, k, batch_size)); - CUDA_CHECK(gpuGetLastError()); - CUDA_CHECK(gpuEventRecord(stop, stream)); - CUDA_CHECK(gpuEventSynchronize(stop)); - float milliseconds = 0; - CUDA_CHECK(gpuEventElapsedTime(&milliseconds, start, stop)); - state.SetIterationTime(static_cast(milliseconds) / 1000); - CUDA_CHECK(gpuEventDestroy(start)); - CUDA_CHECK(gpuEventDestroy(stop)); + CHECK_OK(stream.BlockHostUntilDone()); + auto timer_duration = timer.value().GetElapsedDuration(); + CHECK_OK(timer_duration.status()); + state.SetIterationTime(ToDoubleMicroseconds(timer_duration.value())); } size_t items_processed = batch_size * n * state.iterations(); state.SetItemsProcessed(items_processed); state.SetBytesProcessed(items_processed * sizeof(T)); - CUDA_CHECK(gpuFree(input_buffer)); - CUDA_CHECK(gpuFree(output_values)); - CUDA_CHECK(gpuFree(output_indices)); } BENCHMARK(BM_SmallTopk<1>)->RangePair(1, 512, 16, 1024)->UseManualTime(); diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index abe4806f96e00d..3021720d859801 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -182,6 +182,12 @@ class Stream { tsl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, const TypedKernel &kernel, Args... args); + // Same as above, with an explicit argument for shared memory size in bytes. + template + tsl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, + int32_t shmem_bytes, + const TypedKernel &kernel, Args... args); + // Create a dependency for this stream's next work on the other stream // completing. Does not take ownership of other, and other must not be // null. @@ -1550,6 +1556,17 @@ inline tsl::Status Stream::ThenLaunch(ThreadDim thread_dims, return ::tsl::OkStatus(); } +template +inline tsl::Status Stream::ThenLaunch(ThreadDim thread_dims, + BlockDim block_dims, int32_t shmem_bytes, + const TypedKernel &kernel, + Args... 
args) { + auto kernel_args = PackKernelArgs(shmem_bytes, args...); + TF_RETURN_IF_ERROR( + parent_->Launch(this, thread_dims, block_dims, kernel, *kernel_args)); + return ::tsl::OkStatus(); +} + template inline tsl::StatusOr>> Stream::AllocateTemporaryArray(uint64_t element_count) { From 961ca34ed9eb267431adfbf70d1b26c3d98c2a2c Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Wed, 29 Nov 2023 17:46:48 -0800 Subject: [PATCH 222/381] [xla:gpu] Add the AOT compilation pipeline for thunk runtime #7360 PiperOrigin-RevId: 586502257 --- third_party/xla/xla/service/gpu/BUILD | 7 ++ .../xla/xla/service/gpu/executable.proto | 7 ++ .../service/gpu/gpu_aot_compilation_test.cc | 53 +++++++++ .../xla/xla/service/gpu/gpu_compiler.cc | 107 +++++++++++++----- .../xla/xla/service/gpu/gpu_compiler.h | 3 + .../mhlo_to_lhlo_with_xla.cc | 3 +- 6 files changed, 150 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b4d2f868fcc914..44e20e53fd9496 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3354,13 +3354,20 @@ xla_cc_test( ":amdgpu_compiler_impl", ]) + [ ":gpu_transfer_manager", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", "//xla/service:compiler", + "//xla/service:executable", "//xla/service:gpu_plugin", "//xla/service:platform_util", "//xla/stream_executor:multi_platform_manager", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_headers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/executable.proto b/third_party/xla/xla/service/gpu/executable.proto index 1e4d5d81ba3662..3a9e08a8665116 100644 --- a/third_party/xla/xla/service/gpu/executable.proto +++ b/third_party/xla/xla/service/gpu/executable.proto @@ -37,3 +37,10 @@ message XlaRuntimeGpuExecutableProto { // Constants required by the serialized executable. repeated ConstantInfoProto constants = 5; } + +message CompilationResultProto { + HloModuleProto hlo_module = 1; + BufferAssignmentProto buffer_assignment = 2; + string asm_text = 3; + bytes binary = 4; +} diff --git a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc index f2afc80c08c4ad..db5e3d9436d653 100644 --- a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc @@ -14,12 +14,22 @@ limitations under the License. 
==============================================================================*/ #include +#include #include +#include +#include #include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/service/compiler.h" +#include "xla/service/executable.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor_pimpl.h" +#include "tsl/platform/statusor.h" #if GOOGLE_CUDA #include "xla/service/gpu/nvptx_compiler.h" @@ -78,6 +88,49 @@ ENTRY main { aot_result->LoadExecutable(compiler, stream_exec)); } +TEST_F(GpuAotCompilationTest, LoadExecutableForThunkRuntime) { + const absl::string_view hlo_string = R"( +HloModule Test + +ENTRY main { + a = f32[100, 200]{1,0} parameter(0) + ROOT b = f32[100, 200]{0,1} copy(a) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + DebugOptions debug_options; + debug_options.set_xla_gpu_enable_xla_runtime_executable(false); + module->mutable_config().set_debug_options(debug_options); + + auto compiler = backend().compiler(); + auto name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::MultiPlatformManager::PlatformWithName(name)); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + // Compile AOT. + auto module_group = std::make_unique(std::move(module)); + AotCompilationOptions aot_options(compiler->PlatformId()); + aot_options.set_executor(stream_exec); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> aot_results, + compiler->CompileAheadOfTime(std::move(module_group), aot_options)); + + // Serialize-deserialize AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + aot_results[0]->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + // Load Executable from AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + aot_result->LoadExecutable(compiler, stream_exec)); +} + TEST_F(GpuAotCompilationTest, AotCompilationWithoutGpuDevice) { const absl::string_view hlo_string = R"( HloModule Test diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 4a5a8f39929025..5db19c289ec71c 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -333,19 +333,38 @@ class GpuAotCompilationResult : public AotCompilationResult { class GpuThunkAotCompilationResult : public AotCompilationResult { public: - // TODO(anlunx): Add SerializeAsString(). 
+ GpuThunkAotCompilationResult(HloModule* hlo_module, + BufferAssignment* buffer_assignment, + std::string_view asm_text, + absl::Span binary) { + *proto_.mutable_hlo_module() = hlo_module->ToProto(); + *proto_.mutable_buffer_assignment() = buffer_assignment->ToProto(); + proto_.set_asm_text(std::string(asm_text)); + proto_.set_binary(binary.data(), binary.size()); + } + + explicit GpuThunkAotCompilationResult(CompilationResultProto proto) + : proto_(proto) {} + + StatusOr SerializeAsString() const override { + return proto_.SerializeAsString(); + } + + static StatusOr> FromString( + const std::string& serialized) { + CompilationResultProto proto; + if (!proto.ParseFromString(serialized)) { + return InternalError( + "Failed to parse serialized GpuThunkAotCompilationResult."); + } + return std::make_unique(proto); + } + StatusOr> LoadExecutable( Compiler* compiler, se::StreamExecutor* stream_exec) override; private: - std::unique_ptr hlo_module_; - std::unique_ptr buffer_assignment_; - std::string asm_text_; - std::vector binary_; - - // We can call LoadExecutable only once because buffer_assignment_ is - // moved to GpuExecutable when LoadExecutable is called. - bool loadable_ = true; + CompilationResultProto proto_; }; } // end anonymous namespace @@ -383,10 +402,22 @@ StatusOr> GpuAotCompilationResult::LoadExecutable( StatusOr> GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, se::StreamExecutor* stream_exec) { - if (!loadable_) { - return InternalError("The AOT compilation result is not loadable."); - } - loadable_ = false; + // Recreate HloModule from proto. + TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, + HloModule::CreateModuleConfigFromProto( + proto_.hlo_module(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(proto_.hlo_module(), hlo_module_config)); + + // Recreate BufferAssignment from proto. + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer_assignment, + BufferAssignment::FromProto(proto_.buffer_assignment(), hlo_module.get(), + compiler->BufferSizeBytesFunction(), + /*can_share_buffer=*/nullptr)); + + std::vector binary(proto_.binary().begin(), proto_.binary().end()); // Build the executable, which should be a thunk sequence. 
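Taken together with the serialization hunk above, LoadExecutable completes the round trip that CompilationResultProto makes possible: the thunk-based compilation result is serialized on the build machine and a GpuExecutable is rebuilt from the proto next to a device. A condensed sketch of that flow, following the LoadExecutableForThunkRuntime test added earlier in this patch; `compiler`, `stream_exec`, and `module_group` are assumed to be set up as in that test, and error handling is collapsed into value() for brevity.

// Compile ahead of time with the thunk runtime, i.e. with
// xla_gpu_enable_xla_runtime_executable set to false on the module config.
xla::AotCompilationOptions aot_options(compiler->PlatformId());
aot_options.set_executor(stream_exec);
auto aot_results =
    compiler->CompileAheadOfTime(std::move(module_group), aot_options).value();

// Serialize: the result packs the HloModule, the BufferAssignment, the asm
// text and the binary into CompilationResultProto.
std::string serialized = aot_results[0]->SerializeAsString().value();

// Later, rebuild an executable from the serialized result.
auto reloaded = compiler->LoadAotCompilationResult(serialized).value();
auto executable = reloaded->LoadExecutable(compiler, stream_exec).value();
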
TF_ASSIGN_OR_RETURN( @@ -399,22 +430,28 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, auto mlir_context = std::make_unique(registry); llvm::LLVMContext llvm_context; auto llvm_module = std::make_unique("", llvm_context); - IrEmitterContext ir_emitter_context( - hlo_module_.get(), buffer_assignment_.get(), platform_name, - gpu_device_info, mlir_context.get(), llvm_module.get(), - /*emit_ir_from_hlo=*/true); + auto* gpu_compiler = dynamic_cast(compiler); + if (gpu_compiler == nullptr) { + return InternalError("Compiler is not a GpuCompiler."); + } + llvm_module->setTargetTriple(gpu_compiler->target_triple()); + llvm_module->setDataLayout(gpu_compiler->data_layout()); + IrEmitterContext ir_emitter_context(hlo_module.get(), buffer_assignment.get(), + platform_name, gpu_device_info, + mlir_context.get(), llvm_module.get(), + /*emit_ir_from_hlo=*/true); mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( - mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module_->name()); + mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module->name()); std::vector ordered_allocations; absl::flat_hash_map operation_map; - TF_RETURN_IF_ERROR(HloToLhloModule(*buffer_assignment_, *hlo_module_, + TF_RETURN_IF_ERROR(HloToLhloModule(*buffer_assignment, *hlo_module, *mlir_module, &ordered_allocations, &operation_map)); ir_emitter_context.set_allocations(ordered_allocations); auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); auto entry_function = mlir::cast( - mlir_module->lookupSymbol(hlo_module_->entry_computation()->name())); + mlir_module->lookupSymbol(hlo_module->entry_computation()->name())); // TODO(anlunx): EmitLmhloRegion emits fusion kernels. We need to make sure // ptx and cubin already contain emission results and disable kernel emission // here. 
@@ -429,33 +466,33 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, std::vector constants = std::move(ir_emitter_context.constants()); TF_ASSIGN_OR_RETURN(auto output_info, - GetOutputInfo(*hlo_module_, *buffer_assignment_)); - const Shape& output_shape = hlo_module_->result_shape(); + GetOutputInfo(*hlo_module, *buffer_assignment)); + const Shape& output_shape = hlo_module->result_shape(); std::function buffer_assignment_dumper = [] { return std::string(); }; bool enable_persistent_temp_buffers = - hlo_module_->config() + hlo_module->config() .debug_options() .xla_gpu_enable_persistent_temp_buffers(); int64_t debug_buffer_assignment_show_max = - hlo_module_->config() + hlo_module->config() .debug_options() .xla_debug_buffer_assignment_show_max(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, GpuExecutable::Create(GpuExecutable::Params{ - /*asm_text=*/asm_text_, - /*binary=*/binary_, + /*asm_text=*/proto_.asm_text(), + /*binary=*/binary, /*gpu_version=*/gpu_device_info.gpu_compute_capability(), /*executable=*/std::move(thunk_sequence), /*constants=*/std::move(constants), /*output_info=*/std::move(output_info), - /*module_name=*/std::move(hlo_module_->name()), + /*module_name=*/std::move(hlo_module->name()), /*output_shape=*/std::move(output_shape), /*mlir_allocations=*/std::nullopt, - /*buffer_assignment=*/std::move(buffer_assignment_), + /*buffer_assignment=*/std::move(buffer_assignment), /*enable_persistent_temp_buffers=*/enable_persistent_temp_buffers, /*debug_buffer_assignment_show_max=*/debug_buffer_assignment_show_max, /*debug_module=*/std::unique_ptr(), @@ -1859,6 +1896,14 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, CompileToBackendResult(module.get(), &llvm_context, options.executor(), {options.device_allocator()}, gpu_device_info)); + if (!IsXlaRuntimeExecutableEnabled(module->config())) { + // Create GpuThunkAotCompilationResult if thunk runtime is enabled. + results.emplace_back(std::make_unique( + module.get(), res.compile_module_results.buffer_assignment.get(), + res.backend_result.asm_text, res.backend_result.binary)); + continue; + } + const auto* program = std::get_if( &res.compile_module_results.executable); if (!program) { @@ -2030,7 +2075,11 @@ GpuCompiler::LoadAotCompilationResult( StatusOr> GpuCompiler::LoadAotCompilationResultStatic( const std::string& serialized_aot_result) { - return GpuAotCompilationResult::FromString(serialized_aot_result); + // TODO(anlunx): Remove the code that loads a GpuAotCompilationResult when we + // convert to thunk runtime. 
+ auto result = GpuAotCompilationResult::FromString(serialized_aot_result); + if (result.ok()) return result; + return GpuThunkAotCompilationResult::FromString(serialized_aot_result); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index f112332125adb3..0232b6b27ef8e6 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -102,6 +102,9 @@ class GpuCompiler : public LLVMCompiler { Status RunPostSchedulingPipelines(HloModule* module, int64_t scheduler_mem_limit) const; + std::string target_triple() const { return target_triple_; } + std::string data_layout() const { return data_layout_; } + protected: struct BackendCompileResult { std::string asm_text; diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 86b741ca76b16b..53289bcf357427 100644 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -2522,7 +2522,8 @@ tsl::Status HloToLhloModule( TF_RETURN_IF_ERROR(emitter.Initialize(ordered_allocations)); const xla::HloInstructionSequence* schedule = - assignment.hlo_ordering().SequentialOrder(*computation); + &hlo_module.schedule().sequence(computation); + if (!schedule) { return tsl::errors::Unimplemented( "Missing sequential order for the computation"); From ca6672c79c08aabfd411acb8166933e4879b022e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 17:47:23 -0800 Subject: [PATCH 223/381] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/8f915f25e8b17d2509bb6c7f199a45f2a5e6736c. PiperOrigin-RevId: 586502386 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 40f4624639e1f4..7a59477c03eb4e 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "1bdc3ca2c91c108235a08d603a629970533649f9" - TFRT_SHA256 = "dd9c04c8907c217cefd8336908fcbc2adff00a3126fbc0f6af4182e22ce87395" + TFRT_COMMIT = "8f915f25e8b17d2509bb6c7f199a45f2a5e6736c" + TFRT_SHA256 = "6d0cc4221d9bb6739bf16a03da482abc348f6143395726595d89e3f12158a0ea" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 40f4624639e1f4..7a59477c03eb4e 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "1bdc3ca2c91c108235a08d603a629970533649f9" - TFRT_SHA256 = "dd9c04c8907c217cefd8336908fcbc2adff00a3126fbc0f6af4182e22ce87395" + TFRT_COMMIT = "8f915f25e8b17d2509bb6c7f199a45f2a5e6736c" + TFRT_SHA256 = "6d0cc4221d9bb6739bf16a03da482abc348f6143395726595d89e3f12158a0ea" tf_http_archive( name = "tf_runtime", From 086f52fa8970fc20ff84c4932a8d290f81ccb482 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 29 Nov 2023 18:15:48 -0800 Subject: [PATCH 224/381] [XLA:GPU] Improve errors if callbacks are not provided to the CollectivePipeliner. Fix an apparently missing callback if --xla_gpu_enable_pipelined_p2p is enabled. PiperOrigin-RevId: 586507418 --- third_party/xla/xla/service/collective_pipeliner.cc | 3 +++ third_party/xla/xla/service/collective_pipeliner.h | 2 +- third_party/xla/xla/service/gpu/gpu_compiler.cc | 7 ++++--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 6d5e2f5f2a63ec..57a4c45d0110bf 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -2258,6 +2258,8 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, StatusOr CollectivePipeliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + CHECK(config_.acceptable_formatting); + CHECK(config_.should_process); bool changed = false; std::vector while_loop_instructions; for (HloComputation* computation : module->MakeComputationPostOrder()) { @@ -2303,6 +2305,7 @@ StatusOr CollectivePipeliner::Run( } } if (config_.pipelining_direction == PipeliningDirection::kForward) { + CHECK(config_.reuse_pipelined_op_buffer); TF_RETURN_IF_ERROR(TransformLoopForward( loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.pipeline_use_tree, config_.process_different_sized_ops, diff --git a/third_party/xla/xla/service/collective_pipeliner.h b/third_party/xla/xla/service/collective_pipeliner.h index 01b167dcc45835..8553dd96521931 100644 --- a/third_party/xla/xla/service/collective_pipeliner.h +++ b/third_party/xla/xla/service/collective_pipeliner.h @@ -75,7 +75,7 @@ class CollectivePipeliner : public HloModulePass { bool process_different_sized_ops = false; PipeliningDirection pipelining_direction = PipeliningDirection::kForward; HloPredicate should_process; - // Filter acceptable formatting ops for for forward piplining to discard + // Filter acceptable formatting ops for for forward pipelining to discard // cases that pipeline formatting operations that we don't want to support. 
HloPredicate acceptable_formatting; // If the pipelined op has same input/output size the we reuse the same diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 5db19c289ec71c..8c5f4b03674b13 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -393,8 +393,8 @@ StatusOr> GpuAotCompilationResult::LoadExecutable( return GpuExecutable::LoadFromObjFile( std::move(hlo_module), xla_runtime_executable.obj_file(), - xla_runtime_executable.mlir_module(), - GetDebugOptionsFromFlags(), xla_runtime_gpu_executable_.gpu_asm_text(), + xla_runtime_executable.mlir_module(), GetDebugOptionsFromFlags(), + xla_runtime_gpu_executable_.gpu_asm_text(), xla_runtime_gpu_executable_.gpu_binary(), std::move(constants), GetGpuVersion(executor), executor); } @@ -1112,7 +1112,8 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, /*process_different_sized_ops=*/true, /*pipelining_direction=*/ CollectivePipeliner::PipeliningDirection::kBackward, - /*should_process=*/may_pipeline_p2p}; + /*should_process=*/may_pipeline_p2p, + /*acceptable_formatting=*/[](const HloInstruction*) { return true; }}; pipeline.AddPass(config); } From 49184b68a61e33a15e1bb5d7afff2a137c766f68 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Nov 2023 18:34:16 -0800 Subject: [PATCH 225/381] Relocates the evaluation output into the core solver. PiperOrigin-RevId: 586510253 --- .../auto_sharding/auto_sharding.cc | 27 +---------------- .../auto_sharding/auto_sharding_solver.cc | 29 +++++++++++++++++-- .../auto_sharding_solver_test.cc | 4 +++ 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index b2947dd85c750f..3de8fbeb86ce7d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2657,32 +2657,7 @@ AutoShardingSolverResult CallSolver( PopulateTemporalValues(cost_graph, request); - const AutoShardingSolverResult result = CallORToolsSolver(request); - if (result.status.ok()) { - const AutoShardingEvaluation evaluation = Evaluate(request, result); - LOG(INFO) << "Total Communication Cost: " - << evaluation.total.communication_cost - << " (lower bound: " << evaluation.lower_bound.communication_cost - << ")"; - LOG(INFO) << "Total Computation Cost: " << evaluation.total.computation_cost - << " (lower bound: " << evaluation.lower_bound.computation_cost - << ")"; - LOG(INFO) << "Total Resharding Cost: " << evaluation.total.resharding_cost - << " (lower bound: " << evaluation.lower_bound.resharding_cost - << ")"; - LOG(INFO) << "Total Overbudget Cost: " << evaluation.total.overbudget_cost - << " (lower bound: " << evaluation.lower_bound.overbudget_cost - << ")"; - LOG(INFO) << "Total Makespan Cost: " << evaluation.total.makespan_cost - << " (lower bound: " << evaluation.lower_bound.makespan_cost - << ")"; - LOG(INFO) << "Total Cost: " << evaluation.total.cost() - << " (lower bound: " << evaluation.lower_bound.cost() << ")"; - LOG(INFO) << "Total Departures: " << evaluation.total_departures; - LOG(INFO) << "Total Makespan: " << evaluation.total_makespan; - LOG(INFO) << "Total Violations: " << evaluation.violation_codes.size(); - } - return result; + return CallORToolsSolver(request); } void CheckHloSharding(const HloInstructionSequence& sequence, 
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 1847af9e87aed9..8b73e6a43bc002 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -525,8 +525,33 @@ AutoShardingSolverResult CallORToolsSolver( << "GB\n" << "Number of ILP constraints: " << solver->NumConstraints() << "\n" << "Module name: " << request.module_name(); - return SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, - *solver); + auto result = SolveAndExtractSolution(request, s, e, overbudget_var, + makespan_var, *solver); + if (result.status.ok()) { + const AutoShardingEvaluation evaluation = Evaluate(request, result); + LOG(INFO) << "Total Communication Cost: " + << evaluation.total.communication_cost + << " (lower bound: " << evaluation.lower_bound.communication_cost + << ")"; + LOG(INFO) << "Total Computation Cost: " << evaluation.total.computation_cost + << " (lower bound: " << evaluation.lower_bound.computation_cost + << ")"; + LOG(INFO) << "Total Resharding Cost: " << evaluation.total.resharding_cost + << " (lower bound: " << evaluation.lower_bound.resharding_cost + << ")"; + LOG(INFO) << "Total Overbudget Cost: " << evaluation.total.overbudget_cost + << " (lower bound: " << evaluation.lower_bound.overbudget_cost + << ")"; + LOG(INFO) << "Total Makespan Cost: " << evaluation.total.makespan_cost + << " (lower bound: " << evaluation.lower_bound.makespan_cost + << ")"; + LOG(INFO) << "Total Cost: " << evaluation.total.cost() + << " (lower bound: " << evaluation.lower_bound.cost() << ")"; + LOG(INFO) << "Total Departures: " << evaluation.total_departures; + LOG(INFO) << "Total Makespan: " << evaluation.total_makespan; + LOG(INFO) << "Total Violations: " << evaluation.violation_codes.size(); + } + return result; } AutoShardingSolverResult SolveAndExtractSolution( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 1c54cd7aae11e8..d2b63f690f3321 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -217,6 +217,10 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { 6000, 6100, 6200, 6300, 7000, 7100, 7200, 7300}}; AddCosts(request.mutable_resharding_costs(), r); + const CostMatrix t = {{50000, 51000, 52000, 53000, + 60000, 61000, 62000, 63000, + 70000, 71000, 72000, 73000}}; + AddCosts(request.mutable_duration_costs(), t); const AutoShardingSolverResult result = CallORToolsSolver(request); From a4506661c474e31197f8519a97dc9b8f92f23c40 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 29 Nov 2023 18:51:58 -0800 Subject: [PATCH 226/381] Lower tensor.from_elements and shape.broadcast ops in ShapeLegalizeToHLO PiperOrigin-RevId: 586512876 --- .../shape_legalize_to_hlo.cc | 89 ++++++++++++++++++- .../Dialect/mhlo/shape_legalize_to_hlo.mlir | 79 +++++++++++++++- 2 files changed, 166 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc index b5b45d35818d04..48a1fc1f722810 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include @@ -23,10 +25,12 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" @@ -252,6 +256,45 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern { } }; +struct ConvertShapeBroadcastOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(shape::BroadcastOp op, + PatternRewriter& rewriter) const override { + // Only support broadcasting for two 1D tensors with same size. + if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = shape1.getType().dyn_cast(); + auto tensorType2 = shape2.getType().dyn_cast(); + if (!tensorType1 || !tensorType2 || tensorType1.getRank() != 1 || + tensorType2.getRank() != 1 || + tensorType1.getDimSize(0) != tensorType2.getDimSize(0)) + return failure(); + + // By definition, broadcasted dims are: + // result[i] = lhs[i] if lhs[i] == rhs[i] + // = lhs[i] if rhs[i] == 1 + // = rhs[i] if lhs[i] == 1 + // + // We assume that there is shape.cstr_broadcastable check done elsewhere to + // make sure the shapes are broadcastable, then we can calculate broadcast + // result simply using MaxOp. In case the shapes are not broadcastable, the + // result extent tensor is undefined according to spec. So this + // implementation is technically correct. 
+ auto broadcasted = + rewriter.create(op->getLoc(), shape1, shape2); + + auto broadcastedIndex = castToIndex(rewriter, op.getLoc(), broadcasted); + if (!broadcastedIndex || + broadcastedIndex.getType() != op.getResult().getType()) + return rewriter.notifyMatchFailure(op, "cast to index failed"); + rewriter.replaceOp(op, broadcastedIndex); + return success(); + } +}; + struct ConvertTensorDimPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::DimOp op, @@ -271,6 +314,48 @@ struct ConvertTensorDimPattern : public OpRewritePattern { } }; +struct ConvertTensorFromElementsPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::FromElementsOp op, + PatternRewriter& rewriter) const override { + // We only handle 1D tensor with index types. tensor.from_elements spec + // allows the same element type only for all input/output. + auto tensorType = + op.getResult().getType().dyn_cast_or_null(); + if (!tensorType || tensorType.getRank() != 1) { + return failure(); + } + if (!hasIndexStyle(op.getResult())) return failure(); + + SmallVector elementI32x1; + for (size_t i = 0; i < op.getElements().size(); ++i) { + if (auto constIndex = dyn_cast_or_null( + op.getElements()[i].getDefiningOp())) { + elementI32x1.push_back(rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({1}, rewriter.getI32Type()), + static_cast(constIndex.value())))); + } else { + elementI32x1.push_back(rewriter.create( + op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()), + castToI32(rewriter, op->getLoc(), op.getElements()[i]))); + } + } + Value tensorI32 = + rewriter.create(op.getLoc(), elementI32x1, + /*dimension=*/0); + + tensorI32 = hasI32Style(op.getResult()) + ? tensorI32 + : castToIndex(rewriter, op.getLoc(), tensorI32); + if (!tensorI32 || tensorI32.getType() != op.getResult().getType()) + return rewriter.notifyMatchFailure(op, "cast to index failed"); + rewriter.replaceOp(op, tensorI32); + return success(); + } +}; + template struct CastOperandsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -326,7 +411,7 @@ struct ShapeLegalizeToHloPass // Most of these ops are convertible to MHLO, although the representation is // going to be pretty laborious for many of them. Luckily, canonicalization // is able to remove unnecessary cruft. At the moment, this pass is a - // work in progress, so now all of these ops are supported. + // work in progress, so not all of these ops are supported. // // The only problem (and a big problem at that) are the ops involved in // shape constraints: cstr* ops as well as shape.assuming*. 
Since HLO does @@ -357,8 +442,10 @@ struct ShapeLegalizeToHloPass patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add>(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir index fcdd268c0c4208..f81af2fa052715 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir @@ -89,7 +89,6 @@ func.func @shape_of_ranked_to_shape(%arg0: tensor) -> !shape.shape { func.return %0 : !shape.shape } - // ----- // CHECK-LABEL: func.func @tensor_dim @@ -109,3 +108,81 @@ func.func @tensor_dim_dynamic(%arg0: tensor, %arg1: index) -> index { %dim = tensor.dim %arg0, %arg1 : tensor func.return %dim : index } + +// ----- + +// CHECK-LABEL: func.func @tensor_from_elements +func.func @tensor_from_elements(%arg0: index) -> tensor<2xindex> { + %c0 = arith.constant 0 : index + %0 = tensor.from_elements %arg0, %c0 : tensor<2xindex> + func.return %0 : tensor<2xindex> + // CHECK: %[[ELEMENT1_SCALAR:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor + // CHECK-NEXT: %[[ELEMENT1:.*]] = mhlo.reshape %[[ELEMENT1_SCALAR]] : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[ELEMENT2:.*]] = mhlo.constant dense<0> : tensor<1xi32> + // CHECK-NEXT: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[ELEMENT1]], %[[ELEMENT2]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[CONCAT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONCAT]] : tensor<2xi32> to tensor<2xindex> + // CHECK-NEXT: return %[[CONCAT_INDEX]] : tensor<2xindex> +} + +// ----- + +func.func @tensor_from_elements_i8(%arg0: i8) -> tensor<2xi8> { + %c0 = arith.constant 0 : i8 + // expected-error@+1 {{failed to legalize operation 'tensor.from_elements' that was explicitly marked illegal}} + %0 = tensor.from_elements %arg0, %c0 : tensor<2xi8> + func.return %0 : tensor<2xi8> +} + +// ----- + +func.func @tensor_from_elements_rank2(%arg0: index) -> tensor<2x1xindex> { + %c0 = arith.constant 0 : index + // expected-error@+1 {{failed to legalize operation 'tensor.from_elements' that was explicitly marked illegal}} + %0 = tensor.from_elements %arg0, %c0 : tensor<2x1xindex> + func.return %0 : tensor<2x1xindex> +} + +// ----- + +// CHECK-LABEL: func.func @shape_broadcast +func.func @shape_broadcast(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>) -> tensor<4xindex> { + %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> + func.return %0 : tensor<4xindex> + // CHECK: %[[LHS:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<4xindex> to tensor<4xi32> + // CHECK-NEXT: %[[RHS:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<4xindex> to tensor<4xi32> + // CHECK-NEXT: %[[BROADCAST:.*]] = mhlo.maximum %[[LHS]], %[[RHS]] : tensor<4xi32> + // CHECK-NEXT: %[[BROADCAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[BROADCAST]] : tensor<4xi32> to tensor<4xindex> + // CHECK-NEXT: return %[[BROADCAST_INDEX]] : tensor<4xindex> +} + +// ----- + +func.func @shape_broadcast_result_shape(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>) -> !shape.shape { + // expected-error@+1 {{failed to legalize operation 'shape.broadcast' 
that was explicitly marked illegal}} + %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<4xindex> -> !shape.shape + func.return %0 : !shape.shape +} + +// ----- + +func.func @shape_broadcast_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) -> !shape.shape { + // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} + %0 = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> !shape.shape + func.return %0 : !shape.shape +} + +// ----- + +func.func @shape_broadcast_different_dims(%arg0: tensor<4xindex>, %arg1: tensor<6xindex>) -> tensor<6xindex> { + // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} + %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<6xindex> -> tensor<6xindex> + func.return %0 : tensor<6xindex> +} + +// ----- + +func.func @shape_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) -> tensor<4xindex> { + // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} + %0 = shape.broadcast %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> + func.return %0 : tensor<4xindex> +} From f85e22b6b574a93ecba7f4beb676b440f1a894bf Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 29 Nov 2023 19:25:18 -0800 Subject: [PATCH 227/381] Internal infrastructure change PiperOrigin-RevId: 586517769 --- third_party/stablehlo/temporary.patch | 124 +----------------- .../xla/third_party/stablehlo/temporary.patch | 124 +----------------- 2 files changed, 4 insertions(+), 244 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 0014e5c9384b16..8abb8b476d8c90 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -156,127 +156,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/experimental/BUILD b/stablehlo/stablehlo/experimental/BUILD ---- stablehlo/stablehlo/experimental/BUILD -+++ stablehlo/stablehlo/experimental/BUILD -@@ -0,0 +1,117 @@ -+load("//third_party/llvm/llvm-project/mlir:tblgen.bzl", "gentbl_cc_library") -+load("//third_party/tensorflow:tensorflow.google.bzl", "get_compatible_with_portable") -+load("//third_party/tensorflow/core/platform:rules_cc.bzl", "cc_library") -+ -+package( -+ default_applicable_licenses = ["//third_party/stablehlo:license"], # copybara:comment -+ default_visibility = ["//learning/brain/mlir:stablehlo_friends"], -+ licenses = ["notice"], -+) -+ -+filegroup( -+ name = "stablehlo_experimental_filegroup", -+ srcs = glob(["**"]), -+) -+ -+cc_library( -+ name = "experimental_base", -+ srcs = [ -+ "dialect/Base.cpp", -+ ], -+ hdrs = [ -+ "dialect/Base.h", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ "//third_party/llvm/llvm-project/llvm:Support", -+ "//third_party/llvm/llvm-project/mlir:IR", -+ ], -+) -+ -+cc_library( -+ name = "experimental_stablehlo_ops", -+ srcs = [ -+ "dialect/StablehloOps.cpp", -+ ], -+ hdrs = [ -+ "dialect/StablehloOps.h", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ ":experimental_base", -+ "//third_party/llvm/llvm-project/llvm:Support", -+ "//third_party/llvm/llvm-project/mlir:FuncDialect", -+ 
"//third_party/llvm/llvm-project/mlir:IR", -+ "//third_party/llvm/llvm-project/mlir:Support", -+ "//third_party/stablehlo:stablehlo_ops", -+ ], -+) -+ -+gentbl_cc_library( -+ name = "experimental_stablehlo_pass_inc_gen", -+ compatible_with = get_compatible_with_portable(), -+ tbl_outs = [ -+ ( -+ [ -+ "-gen-pass-decls", -+ ], -+ "transforms/Passes.h.inc", -+ ), -+ ], -+ tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen", -+ td_file = "transforms/Passes.td", -+ deps = ["//third_party/llvm/llvm-project/mlir:PassBaseTdFiles"], -+) -+ -+cc_library( -+ name = "experimental_stablehlo_passes", -+ srcs = [ -+ "transforms/StablehloCanonicalizeDynamism.cpp", -+ "transforms/StablehloRefineShapes.cpp", -+ ], -+ hdrs = [ -+ "transforms/Passes.h", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ ":experimental_stablehlo_ops", -+ ":experimental_stablehlo_pass_inc_gen", -+ "//third_party/llvm/llvm-project/llvm:Support", -+ "//third_party/llvm/llvm-project/mlir:FuncDialect", -+ "//third_party/llvm/llvm-project/mlir:IR", -+ "//third_party/llvm/llvm-project/mlir:InferTypeOpInterface", -+ "//third_party/llvm/llvm-project/mlir:Pass", -+ "//third_party/llvm/llvm-project/mlir:Support", -+ "//third_party/llvm/llvm-project/mlir:TransformUtils", -+ "//third_party/llvm/llvm-project/mlir:Transforms", -+ "//third_party/stablehlo:base", -+ "//third_party/stablehlo:chlo_ops", -+ "//third_party/stablehlo:stablehlo_ops", -+ "//third_party/stablehlo:stablehlo_ops_inc_gen", -+ "//third_party/stablehlo:stablehlo_type_inference", -+ ], -+) -+ -+cc_binary( -+ name = "experimental-stablehlo-opt", -+ srcs = [ -+ "tools/StablehloOptMain.cpp", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ ":experimental_stablehlo_passes", -+ "//third_party/llvm/llvm-project/mlir:AllExtensions", -+ "//third_party/llvm/llvm-project/mlir:AllPassesAndDialects", -+ "//third_party/llvm/llvm-project/mlir:MlirOptLib", -+ "//third_party/llvm/llvm-project/mlir:TosaDialect", -+ "//third_party/stablehlo:interpreter_ops", -+ "//third_party/stablehlo:register", -+ "//third_party/stablehlo:stablehlo_passes", -+ "//third_party/stablehlo:test_utils", -+ "//third_party/stablehlo:tosa_passes", -+ ], -+) diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -1291,7 +1170,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.h b/stablehlo diff --ruN a/stablehlo/stablehlo/experimental/tests/BUILD.bazel b/stablehlo/stablehlo/experimental/tests/BUILD.bazel --- stablehlo/stablehlo/experimental/tests/BUILD.bazel +++ stablehlo/stablehlo/experimental/tests/BUILD.bazel -@@ -0,0 +1,58 @@ +@@ -0,0 +1,59 @@ +# Copyright 2023 The StableHLO Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -1338,6 +1217,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/BUILD.bazel b/stablehlo/stab + "lit.site.cfg.py", + "//:stablehlo-opt", + "//:stablehlo-translate", ++ "//stablehlo/experimental:experimental-stablehlo-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ] + glob(["%s.bc" % src]), diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 6ed8468c4819c5..8abb8b476d8c90 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -156,127 +156,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/experimental/BUILD b/stablehlo/stablehlo/experimental/BUILD ---- stablehlo/stablehlo/experimental/BUILD -+++ stablehlo/stablehlo/experimental/BUILD -@@ -0,0 +1,117 @@ -+load("//third_party/llvm/llvm-project/mlir:tblgen.bzl", "gentbl_cc_library") -+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable") -+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -+ -+package( -+ default_applicable_licenses = ["//third_party/stablehlo:license"], # copybara:comment -+ default_visibility = ["//learning/brain/mlir:stablehlo_friends"], -+ licenses = ["notice"], -+) -+ -+filegroup( -+ name = "stablehlo_experimental_filegroup", -+ srcs = glob(["**"]), -+) -+ -+cc_library( -+ name = "experimental_base", -+ srcs = [ -+ "dialect/Base.cpp", -+ ], -+ hdrs = [ -+ "dialect/Base.h", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ "//third_party/llvm/llvm-project/llvm:Support", -+ "//third_party/llvm/llvm-project/mlir:IR", -+ ], -+) -+ -+cc_library( -+ name = "experimental_stablehlo_ops", -+ srcs = [ -+ "dialect/StablehloOps.cpp", -+ ], -+ hdrs = [ -+ "dialect/StablehloOps.h", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ ":experimental_base", -+ "//third_party/llvm/llvm-project/llvm:Support", -+ "//third_party/llvm/llvm-project/mlir:FuncDialect", -+ "//third_party/llvm/llvm-project/mlir:IR", -+ "//third_party/llvm/llvm-project/mlir:Support", -+ "//third_party/stablehlo:stablehlo_ops", -+ ], -+) -+ -+gentbl_cc_library( -+ name = "experimental_stablehlo_pass_inc_gen", -+ compatible_with = get_compatible_with_portable(), -+ tbl_outs = [ -+ ( -+ [ -+ "-gen-pass-decls", -+ ], -+ "transforms/Passes.h.inc", -+ ), -+ ], -+ tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen", -+ td_file = "transforms/Passes.td", -+ deps = ["//third_party/llvm/llvm-project/mlir:PassBaseTdFiles"], -+) -+ -+cc_library( -+ name = "experimental_stablehlo_passes", -+ srcs = [ -+ "transforms/StablehloCanonicalizeDynamism.cpp", -+ "transforms/StablehloRefineShapes.cpp", -+ ], -+ hdrs = [ -+ "transforms/Passes.h", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ ":experimental_stablehlo_ops", -+ ":experimental_stablehlo_pass_inc_gen", -+ "//third_party/llvm/llvm-project/llvm:Support", -+ "//third_party/llvm/llvm-project/mlir:FuncDialect", -+ "//third_party/llvm/llvm-project/mlir:IR", -+ "//third_party/llvm/llvm-project/mlir:InferTypeOpInterface", -+ "//third_party/llvm/llvm-project/mlir:Pass", -+ "//third_party/llvm/llvm-project/mlir:Support", -+ 
"//third_party/llvm/llvm-project/mlir:TransformUtils", -+ "//third_party/llvm/llvm-project/mlir:Transforms", -+ "//third_party/stablehlo:base", -+ "//third_party/stablehlo:chlo_ops", -+ "//third_party/stablehlo:stablehlo_ops", -+ "//third_party/stablehlo:stablehlo_ops_inc_gen", -+ "//third_party/stablehlo:stablehlo_type_inference", -+ ], -+) -+ -+cc_binary( -+ name = "experimental-stablehlo-opt", -+ srcs = [ -+ "tools/StablehloOptMain.cpp", -+ ], -+ compatible_with = get_compatible_with_portable(), -+ includes = ["../.."], -+ deps = [ -+ ":experimental_stablehlo_passes", -+ "//third_party/llvm/llvm-project/mlir:AllExtensions", -+ "//third_party/llvm/llvm-project/mlir:AllPassesAndDialects", -+ "//third_party/llvm/llvm-project/mlir:MlirOptLib", -+ "//third_party/llvm/llvm-project/mlir:TosaDialect", -+ "//third_party/stablehlo:interpreter_ops", -+ "//third_party/stablehlo:register", -+ "//third_party/stablehlo:stablehlo_passes", -+ "//third_party/stablehlo:test_utils", -+ "//third_party/stablehlo:tosa_passes", -+ ], -+) diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -1291,7 +1170,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.h b/stablehlo diff --ruN a/stablehlo/stablehlo/experimental/tests/BUILD.bazel b/stablehlo/stablehlo/experimental/tests/BUILD.bazel --- stablehlo/stablehlo/experimental/tests/BUILD.bazel +++ stablehlo/stablehlo/experimental/tests/BUILD.bazel -@@ -0,0 +1,58 @@ +@@ -0,0 +1,59 @@ +# Copyright 2023 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -1338,6 +1217,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/BUILD.bazel b/stablehlo/stab + "lit.site.cfg.py", + "//:stablehlo-opt", + "//:stablehlo-translate", ++ "//stablehlo/experimental:experimental-stablehlo-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ] + glob(["%s.bc" % src]), From 7f0aff419f9da840c23a836e204b3e01168b1a3c Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Wed, 29 Nov 2023 19:38:43 -0800 Subject: [PATCH 228/381] [stream_executor][NFC] More doc for While and For command PiperOrigin-RevId: 586519546 --- third_party/xla/xla/stream_executor/command_buffer.h | 7 +++++-- .../xla/stream_executor/cuda/cuda_command_buffer_test.cc | 2 ++ .../stream_executor/cuda/cuda_conditional_kernels.cu.cc | 5 +++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 647d51283f8eab..3ac378acbdf171 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -154,14 +154,17 @@ class CommandBuffer { std::vector branches); // Adds a conditional operation that will execute a command buffer constructed - // by the `body_builder` exactly `num_iteration` times. + // by the `body_builder` exactly `num_iteration` times. This means the + // condition is known at compile time (`num_iteration` < `loop_counter`), and + // does not require a `cond_builder`. 
tsl::Status For(StreamExecutor* executor, int32_t num_iteration, DeviceMemory loop_counter, Builder body_builder); // Adds a conditional operation that will execute a command buffer constructed // by the `cond_builder` that must update `pred` value, and then depending on // the value might execute command buffer constructed by `body_builder` and - // `cond_builder`. Will continue while `pred` value is `true`. + // `cond_builder`. Will continue while `pred` value (which is continously + // updated by `cond_builder`) is `true`. // // In pseudocode: // diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 1464df79e8e6d1..d3064b4bdde443 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -686,6 +686,8 @@ TEST(CudaCommandBufferTest, ConditionalWhile) { int64_t byte_length = sizeof(int32_t) * length; // Prepare arguments: a=1, b=0, loop_counter=0, pred=false + // Value of `pred` is not important, as it will be updated by `cond_builder` + // below. DeviceMemory pred = executor->AllocateArray(1, 0); DeviceMemory loop_counter = executor->AllocateArray(1, 0); DeviceMemory a = executor->AllocateArray(length, 0); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc index 09475192d460bf..0ce8dbd4a4fa32 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -22,8 +22,9 @@ namespace stream_executor { namespace cuda { namespace { -// In all kernels defined above we set conditional handle value to `1` when we -// want to execute a CUDA graph tied to it, and to `0` otherwise. +// In all kernels defined below we set conditional handle value to `1` when we +// want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the +// graph will keep being executed until the conditional handle becomes `0`. #if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) && \ CUDA_VERSION >= 12030 From 314936814144b4cb6790660291ef069c4212b38a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 01:02:04 -0800 Subject: [PATCH 229/381] Update GraphDef version to 1696. PiperOrigin-RevId: 586585209 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index c7699beb5be775..79142b54bc27ec 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1695 // Updated: 2023/11/29 +#define TF_GRAPH_DEF_VERSION 1696 // Updated: 2023/11/30 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 67a62c4acf2df1f9b03e9b81e87eaa32d866192f Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 30 Nov 2023 01:02:09 -0800 Subject: [PATCH 230/381] compat: Update forward compatibility horizon to 2023-11-30 PiperOrigin-RevId: 586585231 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 288f3ffe741104..138d020c136a53 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 30) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 5d2c0b15c139178c5c38f3ffa70ce1e1f2e232b3 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 30 Nov 2023 01:52:42 -0800 Subject: [PATCH 231/381] Slightly improve power of 2 rounding in TreeReductionEmitter. When splitting a reduction dimension n into k, n / k, it is better if we can avoid padding to make n divisible by k. So use the power of 2 condition only as a tie breaker. Example HLO: Add { x.1 = f32[] parameter(0) y.1 = f32[] parameter(1) ROOT add.1 = f32[] add(x.1, y.1) } ENTRY reduce.1 { parameter = f32[4,73984,33]{2,1,0} parameter(0) init_value = f32[] constant(0) ROOT reduce = f32[4,33] reduce(parameter, init_value), dimensions={1}, to_apply=Add } We should split into [4, 288, 256, 33] instead of [4, 256, 288, 33] or [4, 272, 272, 33] PiperOrigin-RevId: 586597696 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/reduction_utils.cc | 35 +++++++---- .../xla/xla/service/gpu/reduction_utils.h | 5 ++ .../gpu/tests/tree_reduction_rewriter_test.cc | 40 ++++++------- .../service/gpu/tree_reduction_rewriter.cc | 59 ++++++++++--------- 5 files changed, 81 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 44e20e53fd9496..2e63d3680bf7d4 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1152,6 +1152,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ + ":ir_emission_utils", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/reduction_utils.cc b/third_party/xla/xla/service/gpu/reduction_utils.cc index a8576a5e4f1c3e..153d094c159667 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.cc +++ b/third_party/xla/xla/service/gpu/reduction_utils.cc @@ -17,11 +17,13 @@ limitations under the License. 
#include #include +#include #include "absl/algorithm/container.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/layout_util.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/util.h" @@ -96,17 +98,26 @@ Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { return {1, 128, 1}; } +int64_t ReductionDimensionRaceFreeBound( + const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions) { + Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); + if (reduction_dimensions.is_row_reduction) { + return MinThreadsXRowReduction(hlo_module_config) * reduction_tiling[2]; + } + return WarpSize() * reduction_tiling[1]; +} + static bool IsUnnestedReductionFasterThanElemental( const ReductionDimensions& reduction_dimensions) { - const int kWarpSize = 32; if (reduction_dimensions.is_row_reduction) { // For row reduction, the tile block is 1 x tile_size_x, and we are reducing // along tile_size_x which needs to be large enough to make the tiling // implementation efficient. // For very small reductions with a power-of-two size, we can fit multiple // reductions inside a single warp, which is more efficient than a loop. - return (reduction_dimensions.dimensions[2] >= kWarpSize) || - ((kWarpSize % reduction_dimensions.dimensions[2]) == 0); + return (reduction_dimensions.dimensions[2] >= WarpSize()) || + ((WarpSize() % reduction_dimensions.dimensions[2]) == 0); } // For column reduction, the tile block is tile_size_y x tile_size_x, and we @@ -117,10 +128,10 @@ static bool IsUnnestedReductionFasterThanElemental( // Rule generated by sweeping the search space of small column reductions. bool prefer_elemental_emitter = - (major_size < kWarpSize) || - (major_size < 2 * kWarpSize && minor_size < kWarpSize) || - (major_size < 4 * kWarpSize && minor_size < 8) || - (major_size < 8 * kWarpSize && minor_size < 3); + (major_size < WarpSize()) || + (major_size < 2 * WarpSize() && minor_size < WarpSize()) || + (major_size < 4 * WarpSize() && minor_size < 8) || + (major_size < 8 * WarpSize() && minor_size < 3); return !prefer_elemental_emitter; } @@ -153,18 +164,18 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, const ReductionDimensions& reduction_dimensions) { - const int kWarpSize = 32; - Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); if (reduction_dimensions.is_row_reduction) { return reduction_dimensions.dimensions[2] <= - MinThreadsXRowReduction(hlo_module_config) * - reduction_tiling[2] && + ReductionDimensionRaceFreeBound(hlo_module_config, + reduction_dimensions) && reduction_dimensions.dimensions[0] <= BatchedReductionRaceFreeBound(); } // Column reduction. 
- return reduction_dimensions.dimensions[1] <= kWarpSize * reduction_tiling[1]; + return reduction_dimensions.dimensions[1] <= + ReductionDimensionRaceFreeBound(hlo_module_config, + reduction_dimensions); } ReductionDimensions GetReductionKindAndContiguousComponents( diff --git a/third_party/xla/xla/service/gpu/reduction_utils.h b/third_party/xla/xla/service/gpu/reduction_utils.h index 610200c8884642..6d20c5d4032cff 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.h +++ b/third_party/xla/xla/service/gpu/reduction_utils.h @@ -59,6 +59,11 @@ ReductionDimensions GetReductionKindAndContiguousComponents( // Get tiling per thread for the given reduction in dimensions [D, H, W]. Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions); +// How big the reduction dimension can be to be race free. +int64_t ReductionDimensionRaceFreeBound( + const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions); + // Returns whether the given reduction can be safely generated without atomics : // that is, at most one block will write to every output element. bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, diff --git a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc index 4cfd03a38a66ee..d49464b612f713 100644 --- a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc @@ -60,7 +60,7 @@ add { } ENTRY main { - input = f32[50000] parameter(0) + input = f32[50021] parameter(0) zero = f32[] constant(0) ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add } @@ -68,9 +68,9 @@ ENTRY main { CheckTreeRewriter(hlo, R"( -// CHECK: [[pad_0:%[^ ]+]] = f32[50048]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_48 -// CHECK: [[bitcast_3:%[^ ]+]] = f32[128,391]{1,0} bitcast([[pad_0]]) -// CHECK: [[reduce_4:%[^ ]+]] = f32[128]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]] +// CHECK: [[pad_0:%[^ ]+]] = f32[50022]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1 +// CHECK: [[bitcast_3:%[^ ]+]] = f32[397,126]{1,0} bitcast([[pad_0]]) +// CHECK: [[reduce_4:%[^ ]+]] = f32[397]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]] // CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]] )"); } @@ -120,9 +120,9 @@ ENTRY main { CheckTreeRewriter(hlo, R"( // CHECK: [[input_0:%[^ ]+]] = f32[50048]{0} parameter(0) -// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,391]{1,0} bitcast([[input_0]]) +// CHECK: [[bitcast_1:%[^ ]+]] = f32[391,128]{1,0} bitcast([[input_0]]) // CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0) -// CHECK: [[reduce_3:%[^ ]+]] = f32[128]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] +// CHECK: [[reduce_3:%[^ ]+]] = f32[391]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] // CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]] )"); } @@ -269,7 +269,7 @@ add { } ENTRY main { - input = f32[10302,100] parameter(0) + input = f32[10303,100] parameter(0) zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add } @@ -277,11 +277,11 @@ ENTRY main { CheckTreeRewriter(hlo, R"( -// CHECK: [[input_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0) +// CHECK: [[input_0:%[^ ]+]] = 
f32[10303,100]{1,0} parameter(0) // CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0) -// CHECK: [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_2x0_0 -// CHECK: [[bitcast_1:%[^ ]+]] = f32[64,161,100]{2,1,0} bitcast([[pad_0]]) -// CHECK: [[reduce_3:%[^ ]+]] = f32[64,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] +// CHECK: [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1x0_0 +// CHECK: [[bitcast_1:%[^ ]+]] = f32[161,64,100]{2,1,0} bitcast([[pad_0]]) +// CHECK: [[reduce_3:%[^ ]+]] = f32[161,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] // CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]] )"); } @@ -362,8 +362,8 @@ argmax { } ENTRY main { - input = f32[2,100000] parameter(0) - idxs = u32[2,100000] iota(), iota_dimension=0 + input = f32[2,100003] parameter(0) + idxs = u32[2,100003] iota(), iota_dimension=0 zero = f32[] constant(0) zero_idx = u32[] constant(0) @@ -376,14 +376,14 @@ ENTRY main { CheckTreeRewriter(hlo, R"( -// CHECK: [[pad_0:%[^ ]+]] = f32[2,100096]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_96 -// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,256,391]{2,1,0} bitcast([[pad_0]]) +// CHECK: [[pad_0:%[^ ]+]] = f32[2,100005]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_2 +// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,339,295]{2,1,0} bitcast([[pad_0]]) // CHECK: [[zero_idx_4:%[^ ]+]] = u32[] constant(0) -// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100096]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_96 -// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,256,391]{2,1,0} bitcast([[pad_1_5]]) -// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,256]{1,0}, u32[2,256]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]] -// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,256]{1,0} get-tuple-element([[reduce_8]]), index=0 -// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,256]{1,0} get-tuple-element([[reduce_8]]), index=1 +// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100005]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_2 +// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,339,295]{2,1,0} bitcast([[pad_1_5]]) +// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,339]{1,0}, u32[2,339]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]] +// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,339]{1,0} get-tuple-element([[reduce_8]]), index=0 +// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,339]{1,0} get-tuple-element([[reduce_8]]), index=1 // CHECK: ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]] )"); } diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc index bd37ce540540fd..b5dd3047fe727c 100644 --- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -115,32 +116,36 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { // We do this by splitting the input shape [a, n, b] into [a, k, n / k, b]. 
// // We want to choose k to be roughly equal to sqrt(n) so that we process - // "most of" the reduction in the first step. We also want k to be a power - // of 2, so that the GPU kernel doesn't spend all its time doing slow - // integer divmods to compute indices into the shape [a,k,n/k,b]. This - // means we may need to pad n so that n is divisible by k. - // - // Thus we consider two options for k: - // - // k1 = round_up_pow2(sqrt(n)) - // k2 = round_down_pow2(sqrt(n)) - // - // and we choose the value of k that results in the least amount of padding. - int64_t k1 = absl::bit_ceil(static_cast(std::ceil(std::sqrt(n)))); - int64_t k2 = - absl::bit_floor(static_cast(std::floor(std::sqrt(n)))); - int64_t padded_n_k1 = RoundUpTo(n, k1); - int64_t padded_n_k2 = RoundUpTo(n, k2); - - int64_t k; - int64_t padded_n; - if (padded_n_k1 < padded_n_k2) { - k = k1; - padded_n = padded_n_k1; - } else { - k = k2; - padded_n = padded_n_k2; + // "most of" the reduction in the first step. But it is also important that + // we choose a value of k with the least amount of padding we need to add to + // n to make it divisible by k. We search for the best value of n / k + // between sqrt(n)/2 and sqrt(n). If there are several possible values for + // n / k that result in the minimum amount of padding, we also want n / k to + // be a power of 2, so that the GPU kernel doesn't spend all its time doing + // slow integer divmods to compute indices into the shape [a,k,n/k,b]. + // Note that by searching in the range between sqrt(n)/2 and sqrt(n), we + // will have a power of 2 in that range. + uint64_t n_div_k = static_cast(std::floor(std::sqrt(n))); + int64_t race_free_bound = ReductionDimensionRaceFreeBound( + hlo->GetModule()->config(), reduction_dimensions); + if (n_div_k > race_free_bound) { + // This means we need more than one split. It is best to limit the n/k + // dimension to the maximum size that doesn't require further splitting. + // Otherwise we might choose a rather small reduce dimension size for the + // first step (in the worst case, sqrt(race_free_bound + 1)). + n_div_k = race_free_bound; + } + uint64_t minimum_padding = (n_div_k - n % n_div_k) % n_div_k; + uint64_t best_k = (n + minimum_padding) / n_div_k; + for (uint64_t i = n_div_k - 1; i > n_div_k / 2; --i) { + uint64_t padding = (i - n % i) % i; + if (padding < minimum_padding || + (padding == minimum_padding && absl::has_single_bit(i))) { + minimum_padding = padding; + best_k = (n + padding) / i; + } } + uint64_t padded_n = n + minimum_padding; // Pad reduced dimension to the required number of elements. bool no_padding_necessary = n == padded_n; @@ -179,8 +184,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { for (int64_t dim_idx = 0; dim_idx < padded[0]->shape().dimensions_size(); dim_idx++) { if (dim_idx == reduced_input_dimension) { - reshaped_dimensions.push_back(k); - reshaped_dimensions.push_back(padded_n / k); + reshaped_dimensions.push_back(best_k); + reshaped_dimensions.push_back(padded_n / best_k); } else { reshaped_dimensions.push_back(padded[0]->shape().dimensions(dim_idx)); } From 1b83acbb147cf7d17d922985275e6005b31d101e Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 30 Nov 2023 01:54:46 -0800 Subject: [PATCH 232/381] [XLA] [NFC] Minor HTML rendering QoL improvements Cutoff an overly long list of control-predecessors and show total shape size. 
PiperOrigin-RevId: 586598174 --- .../xla/xla/service/hlo_graph_dumper.cc | 26 ++++++++++++++++--- .../xla/xla/service/hlo_graph_dumper.h | 3 +++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index bcb761531d2a5f..9d58c165844689 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -644,6 +644,11 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { return false; } + if (subcomp->WhileCallInstruction() != nullptr && + !hlo_render_options_.show_while_subcomputations) { + return false; + } + // Show the subcomputation if we're showing any of its members. return absl::c_any_of( subcomp->instructions(), @@ -1392,11 +1397,12 @@ std::string HloDotDumper::GetInstructionNodeExtraInfo( for (const auto& line : instr->ExtraAttributesToString( HloPrintOptions().set_print_subcomputation_mode( HloPrintOptions::PrintSubcomputationMode::kOff))) { - // Some instructions have giant device identifier fields, so truncate their - // length to 128. + // Some instructions have giant device identifier or control-predecessor + // fields, so truncate their length to 128. constexpr int kMaxDeviceIdFieldLen = 128; if ((absl::StartsWith(line, "replica_groups=") || - absl::StartsWith(line, "source_target_pairs=")) && + absl::StartsWith(line, "source_target_pairs=") || + absl::StartsWith(line, "control-predecessors=")) && line.length() > kMaxDeviceIdFieldLen) { lines.push_back(HtmlLikeStringSanitize( StrCat(line.substr(0, kMaxDeviceIdFieldLen - 3), "..."))); @@ -1575,8 +1581,20 @@ NodeFilter MakeNodeRadiusAroundFilter( // are not interesting to the graph at hand. if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { + // Special logic for handling bitcasts: since sometimes bitcasts are not + // fused, they create a lot of extra nodes in the graph, with exactly + // one input and output. Adding such nodes does not "really" increase + // the size of the graph (since they don't add extra information), and + // stopping the rendering early cuts off important information (you + // almost never want the rendering to be cutoff at the bitcast: you'd + // like to see it's parent). if (!nodes.contains(operand)) { - worklist.push_back({operand, depth + 1}); + int new_depth = (operand->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kBitcast) + ? depth + : depth + 1; + + worklist.push_back({operand, new_depth}); } } } diff --git a/third_party/xla/xla/service/hlo_graph_dumper.h b/third_party/xla/xla/service/hlo_graph_dumper.h index 12b589abb93ce0..1be176c331dc7f 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.h +++ b/third_party/xla/xla/service/hlo_graph_dumper.h @@ -72,6 +72,9 @@ struct HloRenderOptions { // Include the fusion subcomputations in the rendered graph. bool show_fusion_subcomputations = true; + + // Include the while subcomputations in the rendered graph. + bool show_while_subcomputations = true; }; // Renders an HLO module as a human-readable visual graph. From a8d000505e924cac9e8c6bfee544912292957d7e Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 30 Nov 2023 02:02:22 -0800 Subject: [PATCH 233/381] Add support for simplifying reverse. 
Drive-by fixes: - Fix indexing maps for reverse - Make GetIndexingMapsForEntryComputation harder to misuse - Fix FusionExponentialDuplication test PiperOrigin-RevId: 586599815 --- .../xla/service/gpu/model/tile_analysis.cc | 38 +++++++---- .../service/gpu/model/tile_analysis_test.cc | 64 ++++++++++++++++--- 2 files changed, 81 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 771ec1494e48e6..a4e41693298868 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/tile_analysis.h" +#include #include #include #include @@ -530,11 +531,11 @@ StatusOr ComputeReverseOpIndexing( exprs.push_back(dim_expr); continue; } - auto dim_size = getAffineConstantExpr(output_dim, mlir_context); + auto dim_bound = getAffineConstantExpr(output_dim - 1, mlir_context); auto neg_dim_expr = getAffineBinaryOpExpr( AffineExprKind::Mul, getAffineConstantExpr(-1, mlir_context), dim_expr); exprs.push_back( - getAffineBinaryOpExpr(AffineExprKind::Add, neg_dim_expr, dim_size)); + getAffineBinaryOpExpr(AffineExprKind::Add, neg_dim_expr, dim_bound)); } IndexingMap indexing_map{ @@ -637,6 +638,11 @@ std::string ToStringImpl(const T& value) { return ss.str(); } +int64_t FloorDiv(int64_t dividend, int64_t divisor) { + return dividend / divisor - + (((dividend >= 0) != (divisor >= 0) && dividend % divisor) ? 1 : 0); +} + struct IndexingMapSimplifier { struct Bounds { int64_t lower; @@ -650,7 +656,6 @@ struct IndexingMapSimplifier { switch (expr.getKind()) { case AffineExprKind::Constant: { int64_t value = mlir::cast(expr).getValue(); - CHECK_GE(value, 0); return bounds[expr] = {value, value}; } case AffineExprKind::DimId: { @@ -673,12 +678,15 @@ struct IndexingMapSimplifier { switch (expr.getKind()) { case AffineExprKind::Add: return result = {lhs.lower + rhs.lower, lhs.upper + rhs.upper}; - case AffineExprKind::Mul: - return result = {lhs.lower * rhs.lower, lhs.upper * rhs.upper}; + case AffineExprKind::Mul: { + int64_t a = lhs.lower * rhs.lower; + int64_t b = lhs.upper * rhs.upper; + return result = {std::min(a, b), std::max(a, b)}; + } case AffineExprKind::Mod: { CHECK_EQ(rhs.lower, rhs.upper) << "RHS of mod must be a constant"; int64_t m = rhs.lower; - if (lhs.upper < m) { + if (0 <= lhs.lower && lhs.upper < m) { return result = lhs; } return result = {0, m - 1}; @@ -687,7 +695,9 @@ struct IndexingMapSimplifier { CHECK_EQ(rhs.lower, rhs.upper) << "RHS of floor_div must be a constant"; int64_t d = rhs.lower; - return result = {lhs.lower / d, lhs.upper / d}; + int a = FloorDiv(lhs.lower, d); + int b = FloorDiv(lhs.upper, d); + return result = {std::min(a, b), std::max(a, b)}; } default: // We don't use ceildiv, so we don't support it. @@ -706,7 +716,7 @@ struct IndexingMapSimplifier { auto rhs = BoundsInclusive(mod.getRHS()); // a % b where b is always larger than a? - if (lhs.upper < rhs.lower) return lhs_simplified; + if (0 <= lhs.lower && lhs.upper < rhs.lower) return lhs_simplified; // The logic below assumes we have a constant RHS. 
if (rhs.lower != rhs.upper) return mod; @@ -744,7 +754,7 @@ struct IndexingMapSimplifier { auto lhs = BoundsInclusive(lhs_simplified); auto rhs = BoundsInclusive(div.getRHS()); - if (lhs.upper < rhs.lower) { + if (0 <= lhs.lower && lhs.upper < rhs.lower) { return getAffineConstantExpr(0, mlir_context); } @@ -752,6 +762,12 @@ struct IndexingMapSimplifier { if (rhs.lower != rhs.upper) return div; int64_t d = rhs.lower; + int64_t a = FloorDiv(lhs.lower, d); + int64_t b = FloorDiv(lhs.upper, d); + if (a == b) { + return getAffineConstantExpr(a, mlir_context); + } + AffineExpr extracted = getAffineConstantExpr(0, mlir_context); auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { if (auto multiplier = GetConstantRhsMultiplier(expr)) { @@ -865,9 +881,9 @@ bool IndexingMap::Simplify(absl::Span dimension_sizes) { return false; } - affine_map = + affine_map = mlir::simplifyAffineMap( AffineMap::get(affine_map.getNumDims(), affine_map.getNumSymbols(), - results, affine_map.getContext()); + results, affine_map.getContext())); return true; } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 7dadcc852974c0..b79d99ab8158f9 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -63,6 +63,13 @@ class TileAnalysisTest : public HloTestBase { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); + for (auto* operand : root->operands()) { + TF_RET_CHECK(operand->opcode() == HloOpcode::kParameter || + operand->opcode() == HloOpcode::kConstant) + << "If there are multiple instructions, they need to be wrapped in a " + "fusion."; + } + return ComputeInstructionIndexing(root, operand_id, &mlir_context_); } mlir::MLIRContext mlir_context_; @@ -202,7 +209,8 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( HloModule test_module - ENTRY entry_computation { + + fused_computation { p0 = f32[4] parameter(0) p1 = f32[4] parameter(1) add0 = f32[4] add(p0, p1) @@ -212,14 +220,27 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { slice2.0 = f32[2] slice(add1), slice={[0:2]} slice2.1 = f32[2] slice(add1), slice={[1:3]} ROOT add2 = f32[2] add(slice2.0, slice2.1) - })")); + } + + ENTRY entry_computation { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + ROOT fusion = f32[2] fusion(p0, p1), kind=kLoop, calls=fused_computation + })")); EXPECT_THAT( input_indexing.operand_indexing_maps, - ElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))), - MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))))); + UnorderedElementsAre( + MatchOperandIndexing( + 0, UnorderedElementsAre( + MatchIndexingMap("(d0) -> (d0)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 1)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 2)", std::vector{}))), + MatchOperandIndexing( + 1, + UnorderedElementsAre( + MatchIndexingMap("(d0) -> (d0)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 1)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 2)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { @@ -617,17 +638,40 @@ TEST_F(TileAnalysisTest, ReverseOp) { HloModule m ENTRY e { p0 = f32[1, 17, 9, 9] parameter(0) - ROOT reverse = f32[1, 17, 9, 9] 
reverse(p0), dimensions={1, 2} + ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} } )")); - // TODO(b/313840171): Support simplifying this. + EXPECT_FALSE(input_indexing.Simplify({1, 17, 9, 9})); EXPECT_THAT(input_indexing.operand_indexing_maps, ElementsAre(MatchOperandIndexing( 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, -d1 + 17, -d2 + 9, d3)", + "(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)", std::vector{}))))); } +TEST_F(TileAnalysisTest, ReverseReshape) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + HloModule m + fused_computation { + p0 = f32[10, 11] parameter(0) + reverse.0 = f32[10, 11] reverse(p0), dimensions={0, 1} + reshape.0 = f32[110] reshape(reverse.0) + reverse.1 = f32[110] reverse(reshape.0), dimensions={0} + ROOT reshape.1 = f32[10, 11] reshape(reverse.1) + } + ENTRY e { + p0 = f32[10, 11] parameter(0) + ROOT fusion = f32[10, 11] fusion(p0), kind=kLoop, calls=fused_computation + } + )")); + EXPECT_TRUE(input_indexing.Simplify({10, 11})); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(MatchOperandIndexing( + 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); +} + TEST_F(TileAnalysisTest, SliceOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( From f2e5769cf40ae43a16ddc177e8c2dcf3d0e7e3df Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 30 Nov 2023 02:42:11 -0800 Subject: [PATCH 234/381] [TileAnalysis] Add an example for a fused dot. PiperOrigin-RevId: 586610448 --- .../xla/service/gpu/model/tile_analysis.cc | 1 + .../service/gpu/model/tile_analysis_test.cc | 90 +++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index a4e41693298868..7fe45637bf2cfa 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -212,6 +212,7 @@ StatusOr ComputeFusionOpIndexing( for (const auto& operand_indexing : instr_indexing.operand_indexing_maps) { const HloInstruction* producer_instr = instr->operand(operand_indexing.operand_id); + if (producer_instr->IsConstant()) continue; // If the producer is a fusion op parameter, store the result. 
if (auto parameter = DynCast(producer_instr)) { parameter_indexing_maps[parameter->parameter_number()].insert( diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index b79d99ab8158f9..518c0c9e8f3e2d 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -182,6 +182,96 @@ TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { "(d0) -> (d0)", std::vector{}))))); } +TEST_F(TileAnalysisTest, FusionOpWithDot) { + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetIndexingMapsForEntryComputation(R"( + f { + p0 = s8[3,12288,6,128]{3,2,1,0} parameter(0) + bitcast1 = s8[3,6,128,12288]{2,1,3,0} bitcast(p0) + copy1 = s8[3,6,128,12288]{3,2,1,0} copy(bitcast1) + bitcast2 = s8[2304,12288]{1,0} bitcast(copy1) + convert1 = bf16[2304,12288]{1,0} convert(bitcast2) + bitcast3 = bf16[2304,16,768]{2,1,0} bitcast(convert1) + p3 = bf16[16,12288]{1,0} parameter(3) + convert2 = f32[16,12288]{1,0} convert(p3) + p4 = bf16[16,12288]{1,0} parameter(4) + convert3 = f32[16,12288]{1,0} convert(p4) + add1 = f32[16,12288]{1,0} add(convert2, convert3) + p2 = bf16[16]{0} parameter(2) + convert15 = f32[16]{0} convert(p2) + rsqrt = f32[16]{0} rsqrt(convert15) + convert4 = bf16[16]{0} convert(rsqrt) + bcast1 = bf16[16,12288]{1,0} broadcast(convert4), dimensions={0} + convert5 = f32[16,12288]{1,0} convert(bcast1) + multiply1 = f32[16,12288]{1,0} multiply(add1, convert5) + p1 = bf16[12288]{0} parameter(1) + convert6 = f32[12288]{0} convert(p1) + c1 = bf16[] constant(1) + bcast2 = bf16[12288]{0} broadcast(c1), dimensions={} + convert7 = f32[12288]{0} convert(bcast2) + add2 = f32[12288]{0} add(convert6, convert7) + convert8 = bf16[12288]{0} convert(add2) + bcast3 = bf16[16,12288]{1,0} broadcast(convert8), dimensions={1} + convert9 = f32[16,12288]{1,0} convert(bcast3) + multiply2 = f32[16,12288]{1,0} multiply(multiply1, convert9) + convert10 = bf16[16,12288]{1,0} convert(multiply2) + bcast4 = bf16[16,16,768]{2,1,0} bitcast(convert10) + dot = bf16[16,2304,16]{2,1,0} dot(bitcast3, bcast4), + lhs_batch_dims={1}, lhs_contracting_dims={2}, + rhs_batch_dims={1}, rhs_contracting_dims={2} + bcast5 = bf16[16,3,6,128,16]{4,3,2,1,0} bitcast(dot) + copy2 = bf16[16,3,6,128,16]{3,2,4,1,0} copy(bcast5) + convert13 = f32[16,3,6,128,16]{3,2,4,1,0} convert(copy2) + p5 = bf16[3,6,128]{2,1,0} parameter(5) + bcast6 = bf16[3,6,128,16]{2,1,3,0} broadcast(p5), dimensions={0,1,2} + convert11 = f32[3,6,128,16]{2,1,3,0} convert(bcast6) + bcast7 = f32[16,3,6,128,16]{3,2,4,1,0} broadcast(convert11), + dimensions={1,2,3,4} + multiply3 = f32[16,3,6,128,16]{3,2,4,1,0} multiply(convert13, bcast7) + convert12 = bf16[16,3,6,128,16]{3,2,4,1,0} convert(multiply3) + ROOT bcast8 = bf16[16,16,3,1,6,128]{5,4,1,3,2,0} bitcast(convert12) + } + ENTRY e { + p0 = s8[3,12288,6,128]{3,2,1,0} parameter(0) + p1 = bf16[12288]{0} parameter(1) + p2 = bf16[16]{0} parameter(2) + p3 = bf16[16,12288]{1,0} parameter(3) + p4 = bf16[16,12288]{1,0} parameter(4) + p5 = bf16[3,6,128]{2,1,0} parameter(5) + ROOT fusion = bf16[16,16,3,1,6,128]{5,4,1,3,2,0} + fusion(p0, p1, p2, p3, p4, p5), kind=kLoop, calls=f + } + )")); + EXPECT_TRUE(input_indexing.Simplify({16, 16, 3, 1, 6, 128})); + + EXPECT_THAT( + input_indexing.operand_indexing_maps, + UnorderedElementsAre( + MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> " + "(d2 + d3, d0 * 768 + s0, d4, d5)", + std::vector{768}))), + 
MatchOperandIndexing( + 1, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0)", + std::vector{768}))), + MatchOperandIndexing( + 2, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5) -> (d1)", std::vector{}))), + MatchOperandIndexing( + 3, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0)", + std::vector{768}))), + MatchOperandIndexing( + 4, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0)", + std::vector{768}))), + MatchOperandIndexing( + 5, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5) -> (d2 + d3, d4, d5)", + std::vector{}))))); +} + TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, GetIndexingMapsForEntryComputation(R"( From 52c26c6865b04666bd23247f109bf1db73953283 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 30 Nov 2023 03:51:53 -0800 Subject: [PATCH 235/381] [XLA:GPU] Don't look for roots in non-fusion computation. Creating HloFusionAnalysis doesn't make sense for non-fusion computations, but we still do that for legacy reasons. Looking for root in big non-fusion computations can be very expensive, but we also never use those roots anyway. Just disable them to save ~10% compile time. PiperOrigin-RevId: 586625378 --- .../service/gpu/hlo_fusion_analysis_test.cc | 18 +++++++---- .../xla/xla/service/gpu/hlo_traversal.cc | 30 ++++++++++++++++++- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc index 0a02760077547a..9b1bf89afd94c8 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc @@ -69,24 +69,30 @@ TEST_F(HloFusionAnalysisTest, ReductionWithMultipleUsers) { ROOT add = f32[] add(p0, p1) } - ENTRY main { + fused_computation { %p0 = f32[1024] parameter(0) %p1 = f32[] parameter(1) %reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add %negate = f32[] negate(%reduce) %log = f32[] log(%reduce) ROOT %tuple = (f32[], f32[]) tuple(%negate, %log) + } + + ENTRY main { + %p0 = f32[1024] parameter(0) + %p1 = f32[] parameter(1) + ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kLoop, calls=fused_computation })") .value(); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); TF_ASSERT_OK_AND_ASSIGN( - auto analysis, - HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), - HloFusionAdaptor::ForComputation(module->entry_computation()), - &device_info)); + auto analysis, HloFusionAnalysis::Create( + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()), + &device_info)); // This fusion cannot use the reduction emitter because the reduce has two // users. 
EXPECT_EQ(analysis.GetEmitterFusionKind(), diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.cc b/third_party/xla/xla/service/gpu/hlo_traversal.cc index dc0b1dcb57a10b..ac5286e88ad685 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal.cc @@ -78,6 +78,27 @@ class HloComputationFusion : public HloFusionAdaptor { public: explicit HloComputationFusion(const HloComputation* computation) : computation_(computation) { + // HloFusionAdaptor should only be created for fusion computations, which + // usually have only a few roots, but there is a case when we can create it for + // non-fusion computations with thousands of roots. It happens inside + // `FindNonTrivialHero` and it gets very expensive. Calling + // `FindNonTrivialHero` also doesn't make sense on a non-fusion computation, + // but `InstructionFusion` and `FusionMerger` depend on this behavior in + // `IsProducerConsumerFusible`. + // + // `FindNonTrivialHero` only calls `ContainsInstruction` and doesn't use + // information about roots, so we can skip looking for roots as a performance + // optimization. + // TODO(shyshkov): Clean this up once priority fusion is fully launched. + if (computation->IsFusionComputation()) { + roots_ = FindRoots(computation); + } + } + + static absl::InlinedVector FindRoots( + const HloComputation* computation) { + absl::InlinedVector roots; + std::function get_roots; absl::flat_hash_set roots_set; get_roots = [&](const HloInstruction* instr) { @@ -88,11 +109,13 @@ class HloComputationFusion : public HloFusionAdaptor { } else { HloInstructionAdaptor wrapped{*instr}; if (roots_set.insert(wrapped).second) { - roots_.push_back(wrapped); + roots.push_back(wrapped); } } }; get_roots(computation->root_instruction()); + + return roots; } bool ContainsInstruction(HloInstructionAdaptor instruction) const override { @@ -100,6 +123,11 @@ class HloComputationFusion : public HloFusionAdaptor { } absl::InlinedVector GetRoots() const override { + CHECK(!roots_.empty()) << "No roots found in the computation. HloFusionAdaptor was likely " "created for a non-fusion computation: " << computation_->ToString(); + return roots_; } From 7ce56ce029ef6536677109c96018521d3d3bc87c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 30 Nov 2023 04:18:04 -0800 Subject: [PATCH 236/381] Fix reduce-reduce check. This time, allow small reductions. Thanks to shyshkov@ for pointing out the mistake. PiperOrigin-RevId: 586631263 --- .../xla/xla/service/gpu/priority_fusion.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 0c0c4c69a58775..7e804b271250ce 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -538,13 +538,21 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, // Avoid fusing reduce into reduce. Our cost model doesn't currently // understand this case due to a lack of tiling analysis. // TODO(b/312200883): Remove this.
- auto contains_reduce = [&](const HloInstruction* instr) { - return HloAnyOf({HloInstructionAdaptor{*instr}}, - *HloFusionAdaptor::ForInstruction(instr), [](auto node) { - return node.opcode() == HloOpcode::kReduce; - }); + auto contains_signficant_reduce = [&](const HloInstruction* instr) { + auto fusion = HloFusionAdaptor::ForInstruction(instr); + return HloAnyOf(fusion->GetRoots(), *fusion, [](auto node) { + if (node.opcode() != HloOpcode::kReduce) return false; + + int64_t reduction_size = + ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) / + ShapeUtil::ElementsIn(node.shape()); + + // Small reductions are emitted using the elemental emitter anyway. + return reduction_size >= 16; + }); }; - if (contains_reduce(producer) && contains_reduce(consumer)) { + if (contains_signficant_reduce(producer) && + contains_signficant_reduce(consumer)) { return "both the producer and the consumer contain a reduce"; } From 33f9004e09389a3480e3421c6905601efd3e7f37 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Thu, 30 Nov 2023 05:29:35 -0800 Subject: [PATCH 237/381] Integrate LLVM at llvm/llvm-project@511ba45a47d6 Updates LLVM usage to match [511ba45a47d6](https://github.com/llvm/llvm-project/commit/511ba45a47d6) PiperOrigin-RevId: 586645883 --- third_party/llvm/generated.patch | 11 ----------- third_party/llvm/workspace.bzl | 4 ++-- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index ce1937af46e5d5..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,12 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -594,6 +594,7 @@ - name = "__support_bit", - hdrs = ["src/__support/bit.h"], - deps = [ -+ ":__support_cpp_type_traits", - ":__support_macros_attributes", - ], - ) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 62f02cded785c5..163586edceb97a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "f688e0901213726feb9b26cedc61919413cbf59c" - LLVM_SHA256 = "b8885c22a9b77f9c91a316b21d71414a7b48dae38513f170da1554002e85b030" + LLVM_COMMIT = "511ba45a47d6f9e48ad364181830c9fb974135b2" + LLVM_SHA256 = "23b4e703adbb219853d3d375379e15beefea13a9062715f31e50140c2bafe540" tf_http_archive( name = name, From 43c17d7a67c0c80c075a5974cd1da05124287017 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 30 Nov 2023 05:47:45 -0800 Subject: [PATCH 238/381] [XLA] New interactive HTML visualization mode, available through hlo-opt Features: - Start exploration at the root of the instruction, move up - Renders client-side, so can handle truly huge HLOs - Click on the instruction name to expand it and interactively move around the visualization To use: $ hlo-opt --stage=html yourfile.hlo PiperOrigin-RevId: 586649078 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/hlo_graph_dumper.cc | 96 ++++++++++++++++--- .../xla/xla/service/hlo_graph_dumper.h | 2 + third_party/xla/xla/tools/hlo_opt/BUILD | 9 ++ third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 7 +- 5 files changed, 99 insertions(+), 16 deletions(-) diff --git 
a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 8ed5f82e417593..8503cec2199642 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5511,6 +5511,7 @@ cc_library( "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor", "//xla/stream_executor:dnn", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index 9d58c165844689..575c62b55369e4 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -22,14 +22,17 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" @@ -98,8 +101,9 @@ class NodeFilter { NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {} explicit NodeFilter( - std::function filter) - : filter_(std::move(filter)) {} + std::function filter, + std::optional num_rendered = std::nullopt) + : filter_(std::move(filter)), num_rendered_(num_rendered) {} bool Show(const HloInstruction* instr) const { return filter_(instr) != kHideNode; @@ -120,8 +124,12 @@ class NodeFilter { result == kSomeUsersOmitted; } + // Returns an optionally recorded number of nodes which will be rendered. + std::optional GetNumRendered() const { return num_rendered_; } + private: std::function filter_; + std::optional num_rendered_; }; // We arbitrarily set this as the boundary between "large" and "small" @@ -1661,17 +1669,19 @@ NodeFilter MakeNodeRadiusAroundFilter( // Highlight the root node. nodes[root] = kHighlightNode; - return NodeFilter([=](const HloInstruction* instr) { - auto it = nodes.find(instr); - if (it != nodes.end()) { - return it->second; - } - // Show all nodes in subcomputations. - if (instr->parent() != root->parent()) { - return kNormalNode; - } - return kHideNode; - }); + return NodeFilter( + [=](const HloInstruction* instr) { + auto it = nodes.find(instr); + if (it != nodes.end()) { + return it->second; + } + // Show all nodes in subcomputations. + if (instr->parent() != root->parent()) { + return kNormalNode; + } + return kHideNode; + }, + nodes.size()); } // Gets a node filter that includes nodes on all paths from `from` to `to`. If @@ -1927,6 +1937,14 @@ StatusOr WrapFusionExplorer( } document.getElementById('performance_note').innerText = `Rendering took ${(performance.now() - render_start).toFixed(2)}ms`; + + // Change cursor. + let text_nodes = document.getElementsByTagName("text"); + for (var el of text_nodes) { + if (title_to_id.has(el.innerHTML)) { + el.style.cursor = "pointer"; + } + } }; if (renderCache[dot_ptr]) { render_callback(renderCache[dot_ptr]); @@ -1997,6 +2015,20 @@ StatusOr WrapFusionExplorer( renderFrameList(); renderCurrentFrame(); }); + + window.title_to_id = new Map(); + for (let i=0; i < frames.length; i++) { + title_to_id.set(frames[i][1], i); + } + + // Navigate to next elements on click. 
+ document.addEventListener("click", (event) => { + let txt = event.target.innerHTML; + if (title_to_id.has(txt)) { + let id = title_to_id.get(txt); + window.location.hash = `#frame${id}`; + } + }); }); //--> @@ -2109,6 +2141,44 @@ StatusOr RenderGraph(const HloComputation& computation, return WrapDotInFormat(computation, rendered_dot, format); } +StatusOr RenderAllComputationsToHtml(const HloModule& module) { + FusionVisualizerProgress progress; + + std::vector instrs = + module.entry_computation()->MakeInstructionPostOrder(); + absl::c_reverse(instrs); + for (const HloInstruction* instr : instrs) { + if (absl::c_linear_search( + std::vector{HloOpcode::kConstant, + HloOpcode::kGetTupleElement}, + instr->opcode())) { + continue; + } + + HloRenderOptions opts; + opts.show_fusion_subcomputations = true; + opts.show_backend_config = true; + opts.show_while_subcomputations = instr->opcode() == HloOpcode::kWhile; + + // Dynamically adjusts the radius with a magical cutoff of 100. + static constexpr int64_t max_nodes_to_render = 100; + absl::flat_hash_set render_boundary; + + NodeFilter filter = MakeNodeRadiusAroundFilter(instr, 2, render_boundary); + if (filter.GetNumRendered().value_or(1) > max_nodes_to_render) { + filter = MakeNodeRadiusAroundFilter(instr, 1, render_boundary); + } + + std::string dot = + HloDotDumper(module.entry_computation(), instr->name(), + module.config().debug_options(), opts, filter) + .Dump(); + progress.AddState(dot, instr->name(), std::nullopt); + } + + return WrapFusionExplorer(progress, module.name()); +} + StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, HloRenderOptions hlo_render_options, diff --git a/third_party/xla/xla/service/hlo_graph_dumper.h b/third_party/xla/xla/service/hlo_graph_dumper.h index 1be176c331dc7f..fca549b550fdb1 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.h +++ b/third_party/xla/xla/service/hlo_graph_dumper.h @@ -90,6 +90,8 @@ StatusOr RenderGraph(const HloComputation& computation, RenderedGraphFormat format, HloRenderOptions hlo_render_options = {}); +StatusOr RenderAllComputationsToHtml(const HloModule& module); + // Like RenderGraph, but renders only nodes "near" the given node in the graph. // // The number of nodes dumped is controlled by the radius parameter, which diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index 6d516a6d780c37..b64a6de29fbdb8 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -33,6 +33,7 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/service:compiler", + "//xla/service:platform_util", "//xla/stream_executor:platform", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -51,6 +52,7 @@ cc_library( "//xla:types", "//xla/service:compiler", "//xla/service:dump", + "//xla/service:hlo_graph_dumper", "//xla/service:platform_util", "//xla/service/gpu:executable_proto_cc", "//xla/stream_executor/cuda:cuda_platform_id", @@ -118,3 +120,10 @@ filegroup( ], visibility = ["//visibility:public"], ) + +exports_files( + glob([ + "gpu_specs/*.txtpb", + ]), + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 9f6b72b8edd03a..55283f6cffec06 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "xla/service/dump.h" #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/hlo_graph_dumper.h" #include "xla/service/platform_util.h" #include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" @@ -100,9 +101,9 @@ struct GpuOptProvider : public OptProvider { TF_ASSIGN_OR_RETURN( std::unique_ptr optimized_module, compiler->RunHloPasses(std::move(module), executor, opts)); - return RenderGraph(optimized_module->name(), *optimized_module, - RenderedGraphFormat::kHtml, - /*show_fusion_subcomputations=*/false); + TF_ASSIGN_OR_RETURN(std::string computations, + RenderAllComputationsToHtml(*optimized_module)); + return computations; } // Unimplemented stage. From 5cdd8ee0a0a31ce9cb799eeda1fbcd4a02c9a939 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 06:04:36 -0800 Subject: [PATCH 239/381] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/bf11f0d876fef436c5ea018a5b13eb2e89e53b7a. PiperOrigin-RevId: 586652359 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 7a59477c03eb4e..8cdfbfc358b1fb 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "8f915f25e8b17d2509bb6c7f199a45f2a5e6736c" - TFRT_SHA256 = "6d0cc4221d9bb6739bf16a03da482abc348f6143395726595d89e3f12158a0ea" + TFRT_COMMIT = "bf11f0d876fef436c5ea018a5b13eb2e89e53b7a" + TFRT_SHA256 = "f78926aaefd521c80154ff9c2c85ef77c0bb5ee34f5af4f24f3a349d22b41ff8" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 7a59477c03eb4e..8cdfbfc358b1fb 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "8f915f25e8b17d2509bb6c7f199a45f2a5e6736c" - TFRT_SHA256 = "6d0cc4221d9bb6739bf16a03da482abc348f6143395726595d89e3f12158a0ea" + TFRT_COMMIT = "bf11f0d876fef436c5ea018a5b13eb2e89e53b7a" + TFRT_SHA256 = "f78926aaefd521c80154ff9c2c85ef77c0bb5ee34f5af4f24f3a349d22b41ff8" tf_http_archive( name = "tf_runtime", From 882f1e199432b9461f69603365443598383d4c02 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 30 Nov 2023 06:24:32 -0800 Subject: [PATCH 240/381] [xla:gpu] Split CUTLASS gemm kernel compilation into a separate target Restructure build file so that targets compiling CUTLASS templates into device code do not have any dependencies on XLA PiperOrigin-RevId: 586656972 --- third_party/xla/xla/service/gpu/kernels/BUILD | 77 ++++++++++++++----- ...cu.cc => cutlass_gemm_custom_kernel.cu.cc} | 26 ++----- ..._kernel.h => cutlass_gemm_custom_kernel.h} | 10 +-- ....cc => cutlass_gemm_custom_kernel_test.cc} | 2 +- .../gpu/kernels/cutlass_gemm_fusion.cc | 2 +- .../gpu/kernels/cutlass_gemm_kernels.cu.h | 43 +++++++++++ .../cutlass_gemm_kernels_f32xf32_to_f32.cu.cc | 22 ++++++ 7 files changed, 137 insertions(+), 45 deletions(-) rename third_party/xla/xla/service/gpu/kernels/{cutlass_gemm_kernel.cu.cc => cutlass_gemm_custom_kernel.cu.cc} (62%) rename third_party/xla/xla/service/gpu/kernels/{cutlass_gemm_kernel.h => cutlass_gemm_custom_kernel.h} (78%) rename third_party/xla/xla/service/gpu/kernels/{cutlass_gemm_test.cc => cutlass_gemm_custom_kernel_test.cc} (97%) create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index d08e19e189af96..cd18016d849b2a 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -82,7 +82,7 @@ cc_library( # ":custom_fusion", # ":custom_fusion_pattern", # ":custom_kernel", -# ":cutlass_gemm_kernel", +# ":cutlass_gemm_custom_kernel", # "@com_google_absl//absl/status", # "//xla:shape_util", # "//xla:status", @@ -114,13 +114,19 @@ cc_library( # ], # ) # +# #===--------------------------------------------------------------------------------------------===# +# # CUTLASS Gemm <-> xla::gpu::kernel::CustomKernel adaptor +# #===--------------------------------------------------------------------------------------------===# +# # cuda_library( -# name = "cutlass_gemm_kernel", -# srcs = ["cutlass_gemm_kernel.cu.cc"], -# hdrs = ["cutlass_gemm_kernel.h"], +# name = "cutlass_gemm_custom_kernel", +# srcs = ["cutlass_gemm_custom_kernel.cu.cc"], +# hdrs = ["cutlass_gemm_custom_kernel.h"], # visibility = ["//visibility:private"], # deps = [ # ":custom_kernel", +# ":cutlass_gemm_kernels", +# ":cutlass_gemm_kernels_header", # ":cutlass_gemm_universal", # "@com_google_absl//absl/status", # "@com_google_absl//absl/strings", @@ -131,25 +137,12 @@ cc_library( # ], # ) # -# cuda_library( -# name = "cutlass_gemm_universal", -# hdrs = ["cutlass_gemm_universal.cu.h"], -# visibility = ["//visibility:private"], -# deps = [ -# "@com_google_absl//absl/status", -# "@com_google_absl//absl/strings", -# "//third_party/gpus/cutlass", -# "//xla:statusor", -# "//xla/stream_executor", -# ], -# ) -# # xla_test( -# name = "cutlass_gemm_test", -# srcs = if_cuda_is_configured(["cutlass_gemm_test.cc"]), +# name = "cutlass_gemm_custom_kernel_test", +# srcs = 
if_cuda_is_configured(["cutlass_gemm_custom_kernel_test.cc"]), # backends = ["gpu"], # deps = [ -# ":cutlass_gemm_kernel", +# ":cutlass_gemm_custom_kernel", # "//xla:types", # "//xla:xla_data_proto_cc", # "//xla/stream_executor", @@ -164,4 +157,48 @@ cc_library( # ], # ) # +# #===--------------------------------------------------------------------------------------------===# +# # CUTLASS GemmUniversal-base kernels <-> StreamExecutor adaptor +# #===--------------------------------------------------------------------------------------------===# +# +# cuda_library( +# name = "cutlass_gemm_universal", +# hdrs = ["cutlass_gemm_universal.cu.h"], +# visibility = ["//visibility:private"], +# deps = [ +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings", +# "//third_party/gpus/cutlass", +# "//xla:statusor", +# "//xla/stream_executor", +# ], +# ) +# +# #===--------------------------------------------------------------------------------------------===# +# # CUTLASS Gemm kernels implementation +# #===--------------------------------------------------------------------------------------------===# +# +# # We split each individual kernel into a separate targets to compile them all in parallel. We also +# # do not have any dependencies except CUTLASS itself to reduce the number of recompilations. +# +# cuda_library( +# name = "cutlass_gemm_kernels", +# visibility = ["//visibility:private"], +# deps = [":cutlass_gemm_kernels_f32xf32_to_f32"], +# ) +# +# cuda_library( +# name = "cutlass_gemm_kernels_header", +# hdrs = ["cutlass_gemm_kernels.cu.h"], +# visibility = ["//visibility:private"], +# deps = ["//third_party/gpus/cutlass"], +# ) +# +# cuda_library( +# name = "cutlass_gemm_kernels_f32xf32_to_f32", +# srcs = ["cutlass_gemm_kernels_f32xf32_to_f32.cu.cc"], +# visibility = ["//visibility:private"], +# deps = [":cutlass_gemm_kernels_header"], +# ) +# # copybara:uncomment_end(google-only) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc similarity index 62% rename from third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc rename to third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc index a84686a4617b3e..1dc483cc3ac604 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include #include #include "absl/status/status.h" #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h" #include "xla/service/gpu/kernels/cutlass_gemm_universal.cu.h" #include "xla/statusor.h" #include "xla/stream_executor/kernel_spec.h" @@ -27,30 +28,19 @@ limitations under the License. namespace xla::gpu::kernel { -using F32xF32toF32 = - cutlass::gemm::device::GemmUniversal; - -//===----------------------------------------------------------------------===// -// Adaptor from a CUTLASS GemmUniversal to a CustomKernel. 
-//===----------------------------------------------------------------------===// - template -StatusOr LoadCutlassGemmUniversal(int32_t m, int32_t n, - int32_t k) { +static StatusOr LoadCutlassGemmUniversal(int32_t m, int32_t n, + int32_t k) { using Kernel = typename Gemm::GemmKernel; cutlass::gemm::GemmCoord problem_size = {m, n, k}; - // TODO(ezhulenev): We should generate more descriptive names for custom - // kernels, i.e. include tile and dimensions sizes, dtypes, etc. se::MultiKernelLoaderSpec spec( /*arity=*/1, gemm_universal::ArgsPacking(problem_size)); - spec.AddInProcessSymbol(reinterpret_cast(cutlass::Kernel2), - "cutlass_universal_gemm"); + spec.AddInProcessSymbol(internal::GetCutlassGemmKernel(), + "cutlass_gemm"); - return CustomKernel("cutlass_gemm:f32<-f32xf32", std::move(spec), + return CustomKernel("cutlass_gemm", std::move(spec), gemm_universal::BlockDim(problem_size), gemm_universal::ThreadDim(), sizeof(typename Kernel::SharedStorage)); @@ -62,7 +52,7 @@ StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, return absl::InvalidArgumentError( "Currently cutlass gemm kernel supports only F32 data type"); - return LoadCutlassGemmUniversal(m, n, k); + return LoadCutlassGemmUniversal(m, n, k); } } // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h similarity index 78% rename from third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.h rename to third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h index 41cf68e5619a3c..364f8b9aba73a6 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_H_ -#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ + +#include #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/statusor.h" @@ -22,11 +24,9 @@ limitations under the License. namespace xla::gpu::kernel { -// A reference implementation GEMM kernel written in CUTLASS based on -// `00_basic_gemm` example. StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, int32_t n, int32_t k); } // namespace xla::gpu::kernel -#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_H_ +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/kernels/cutlass_gemm_test.cc rename to third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index 5d73ffd7be2b04..e627cb449b6fb5 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include #include -#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index ac58ba9c2d8cae..7c4854f2d877c6 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_fusion.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/status.h" diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h new file mode 100644 index 00000000000000..f70bb49db4ac72 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNELS_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNELS_CU_H_ + +#include "third_party/gpus/cutlass/include/cutlass/gemm/device/gemm_universal.h" + +namespace xla::gpu::kernel { + +struct CutlassGemmKernels { + using F32xF32toF32 = + cutlass::gemm::device::GemmUniversal; +}; + +namespace internal { + +template +void* GetCutlassGemmKernel() { + return reinterpret_cast(cutlass::Kernel2); +} + +// Extern templates for all supported CUTLASS Gemm kernels. +extern template void* GetCutlassGemmKernel(); + +} // namespace internal +} // namespace xla::gpu::kernel + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNELS_CU_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc new file mode 100644 index 00000000000000..1dc1a9e1aa0fe7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc @@ -0,0 +1,22 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h" + +namespace xla::gpu::kernel::internal { + +template void* GetCutlassGemmKernel(); + +} // namespace xla::gpu::kernel::internal From 0d8a6a4eadc1604955f5a56d850b15e989a8dc7c Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 30 Nov 2023 06:39:00 -0800 Subject: [PATCH 241/381] [XLA:GPU] Tiling propagation: generalize handling of concatenations. Store offsets along the concatenated dimensions in the dimension orders of operands to reflect that they may be incompatible with other uses of the same operands. This allows fusing concatenations which are not the only users of their inputs. PiperOrigin-RevId: 586659844 --- .../service/gpu/gemm_rewriter_triton_test.cc | 10 +++---- .../xla/xla/service/gpu/ir_emitter_triton.cc | 26 ++++++++---------- .../service/gpu/triton_tiling_propagation.cc | 19 ++++++------- .../service/gpu/triton_tiling_propagation.h | 27 +++++++++++++++++++ 4 files changed, 53 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 8913782398cca7..e54bb3497825fa 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -845,7 +845,7 @@ e { EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(1), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128, - /*slice_start=*/0, /*sliced_count=*/128, + /*slice_start=*/-1536, /*sliced_count=*/128, /*subfragments=*/ElementsAre(128)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, @@ -856,7 +856,7 @@ e { EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, computation->parameter_instruction(2), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256, - /*slice_start=*/0, + /*slice_start=*/-1536 - 128, /*sliced_count=*/256, /*subfragments=*/ElementsAre(256)))); } @@ -923,7 +923,7 @@ e { } TEST_F(GemmRewriterTritonLevel2Test, - TwoConcatenationsOfSameParametersAreNotFused) { + DifferentConcatenationOfSameParametersIsNotFused) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( e { @@ -945,8 +945,8 @@ e { .Run(module.get()) .value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Concatenate(), m::Concatenate(), - m::Parameter())))); + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter(), + m::Parameter(), m::Parameter())))); } } // namespace diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 93f511506eea37..818d4530100aeb 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -1234,18 +1234,16 @@ class MatMulEmitterHelper { CHECK_EQ(bases.size(), hlo->operand_count()); concat_boundaries.reserve(hlo->operand_count() - 1); - int64_t accumulated_size = 0; for (int i = 0; i < hlo->operand_count() - 1; ++i) { - const int64_t operand_size = + const TensorIterationSpec::IterationSpecFragment& fragment = analysis_.IterSpec(side.scope, hlo->operand(i), concat_dim_idx) - ->at(0) - .count; - if (operand_size % properties.block_size != 0) { + ->at(0); + if (fragment.sliced_count % properties.block_size != 0) 
{ return UncompilableMatmul( "Operand is not divisible by the block size."); } - accumulated_size += operand_size; - concat_boundaries.push_back(Cst32(accumulated_size)); + concat_boundaries.push_back( + Cst32(-fragment.slice_start + fragment.sliced_count)); } concat_dim_pid_offset = @@ -1285,10 +1283,8 @@ class MatMulEmitterHelper { specs.push_back( analysis_.IterSpec(side.scope, input, properties.index)); input_strides.push_back(Cst64(specs.back()->at(0).stride)); - input_offsets.push_back(b_.create( - pid_offset, input_offsets.empty() - ? Cst32(0) - : concat_boundaries[input_offsets.size() - 1])); + input_offsets.push_back(b_.create( + pid_offset, Cst32(specs.back()->at(0).slice_start))); input_bounds.push_back(Cst64(specs.back()->at(0).count)); } strides.push_back(EmitMultiSelect(b_, concat_dim_pid_offset, @@ -1300,10 +1296,10 @@ class MatMulEmitterHelper { EmitMultiSelect(b_, pid_offset, concat_boundaries, input_bounds)); } else { block_offsets.push_back(pid_offset); - int64_t count = specs.back()->at(0).count; + int64_t count = specs.front()->at(0).count; if (side.scope == TritonFusionAnalysis::Scope::OUTPUT && properties.index == dims_.out_lhs_noncontracting_dim_idx && - specs.back()->size() == 1 && + specs.front()->size() == 1 && dims_.lhs_noncontracting_split.has_value()) { // Dimension of the output produced by the non-contracting LHS one // is logically split, major part is addressed using pid_batch. @@ -1314,7 +1310,7 @@ class MatMulEmitterHelper { boundary_checks.push_back(bounds.size() - 1); } } - tensor_offsets.push_back(Cst32(specs.back()->at(0).slice_start)); + tensor_offsets.push_back(Cst32(specs.front()->at(0).slice_start)); block_dims.push_back(properties.block_size); dim_order.emplace(dim_order.begin(), dim_order.size()); }; @@ -1396,7 +1392,7 @@ class MatMulEmitterHelper { } Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); } - Value Cst32(int64_t v) { return CreateConst(b_, i32_ty_, v); } + Value Cst32(int32_t v) { return CreateConst(b_, i32_ty_, v); } Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); } ImplicitLocOpBuilder& b_; diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index ec45285cb517f8..4db8c75024798e 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -633,6 +633,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( } DimOrderMap dst_dim_orders; + int64_t concat_accumulated_size = 0; for (const HloInstruction* dst : GetDestHlos(hlo, direction)) { DimensionOrder& dst_dim_order = dst_dim_orders.insert({dst, DimensionOrder()}).first->second; @@ -685,13 +686,19 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( } else if (hlo.opcode() == HloOpcode::kConcatenate) { dst_logical.resize(src_logical.size()); for (int i = 0; i < src_logical.size(); ++i) { - dst_logical[i] = src_logical[i]; if (i == hlo.concatenate_dimension()) { if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { return FusionDecision("Unsupported concatenation."); } - dst_logical[i][0]->set_count(dst->shape().dimensions(i)); - dst_logical[i][0]->set_slice(0, dst->shape().dimensions(i)); + const Fragment& src_fragment = *src_logical[i][0]; + Fragment& dst_fragment = new_fragments.emplace_back( + src_fragment.dst_dim_number(), dst->shape().dimensions(i)); + dst_fragment.set_slice(-concat_accumulated_size, + dst->shape().dimensions(i)); + concat_accumulated_size += 
dst->shape().dimensions(i); + dst_logical[i].push_back(&dst_fragment); + } else { + dst_logical[i] = src_logical[i]; } } } else if (hlo.opcode() == HloOpcode::kCopy) { @@ -863,12 +870,6 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { return "Unsupported concatenation."; } - if (absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) { - return operand->user_count() > 1; - })) { - return FusionDecision( - "Concatenation has to be the only user of its inputs."); - } if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { // In the current simple implementation of concatenation the size of // each of its inputs along the concatenated dimension has to be diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h index 08445f73b76080..a24085318612a3 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.h +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.h @@ -34,6 +34,33 @@ limitations under the License. namespace xla { namespace gpu { +// Illustration explaining why slice_start for concatenations is negative: + +// Slice +// ===== +// input +// [--------------------------] +// . . . +// . offset . +// |------> output . +// [--------] +// +// output[x] = input[x + offset] + +// Concatenation +// ============= +// +// input_n +// [......][--------][........] +// . . +// offset . . +// <-------| . +// . . . +// . . output . +// [--------------------------] +// +// output[x] = input_n[x - offset] + class TensorIterationSpec { public: // Description of basic iteration: `count` elements separated by `stride` From 45374d5137889cf50c4acfd45163ccd865801116 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Thu, 30 Nov 2023 06:54:23 -0800 Subject: [PATCH 242/381] Add overload of `ModifyGraphWithDelegate` that takes a `TfLiteOpaqueDelegate*`. This is similar to the existing overload of `InterpreterBuilder::AddDelegate`. When building with TFLITE_USE_OPAQUE_DELEGATE, this overload is needed to use ModifyGraphWithDelegate with delegates created using TfLiteOpaqueDelegateCreate. InterpreterBuilder::AddDelegate remains preferred over Interpreter::ModifyGraphWithDelegate, but this new overload is useful because it reduces the burden for converting code to use the opaque delegate APIs. 
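A rough usage sketch of the new overload (illustrative only, not part of this change; the no-op Prepare callback, the helper name, and the include paths are assumptions):

#include "tensorflow/lite/core/c/common.h"   // assumed location of the opaque-delegate C API
#include "tensorflow/lite/interpreter.h"

// Attaches an opaque delegate to an already-built interpreter via the new
// Interpreter::ModifyGraphWithDelegate(TfLiteOpaqueDelegate*) overload.
// Returns the delegate on success; the caller must keep it alive for the
// lifetime of the interpreter and free it with TfLiteOpaqueDelegateDelete.
TfLiteOpaqueDelegate* ApplyNoopOpaqueDelegate(tflite::Interpreter& interpreter) {
  TfLiteOpaqueDelegateBuilder builder{};
  builder.Prepare = [](TfLiteOpaqueContext* context,
                       TfLiteOpaqueDelegate* delegate,
                       void* data) -> TfLiteStatus { return kTfLiteOk; };
  TfLiteOpaqueDelegate* delegate = TfLiteOpaqueDelegateCreate(&builder);
  // The overload takes the opaque type directly, so the same call compiles
  // whether or not TFLITE_USE_OPAQUE_DELEGATE is defined.
  if (interpreter.ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
    TfLiteOpaqueDelegateDelete(delegate);
    return nullptr;
  }
  return delegate;
}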
PiperOrigin-RevId: 586663060 --- tensorflow/lite/core/c/BUILD | 27 +++++++++++++++++++ tensorflow/lite/core/c/c_api_test.cc | 4 +++ tensorflow/lite/core/interpreter.h | 1 + .../lite/core/interpreter_experimental.cc | 6 +++++ tensorflow/lite/delegates/delegate_test.cc | 6 +++-- 5 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 35fcea37f9d62d..70999f0b24cf7a 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -190,6 +190,33 @@ cc_test( ], ) +cc_test( + name = "c_api_test_with_opaque_delegate", + size = "small", + srcs = ["c_api_test.cc"], + copts = tflite_copts(), + data = [ + "//tensorflow/lite:testdata/2_subgraphs.bin", + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/add_quantized.bin", + "//tensorflow/lite:testdata/custom_sinh.bin", + ], + local_defines = ["TFLITE_USE_OPAQUE_DELEGATE"], + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_types", + ":common", + "//tensorflow/lite:string_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core:subgraph", + "//tensorflow/lite/delegates:delegate_test_util", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "selectively_built_c_api_test", size = "small", diff --git a/tensorflow/lite/core/c/c_api_test.cc b/tensorflow/lite/core/c/c_api_test.cc index abb0083e12578c..189cd9815f8ebf 100644 --- a/tensorflow/lite/core/c/c_api_test.cc +++ b/tensorflow/lite/core/c/c_api_test.cc @@ -291,6 +291,7 @@ TEST(CApiSimple, TfLiteInterpreterGetTensor) { TfLiteInterpreterDelete(interpreter); } +#if !TFLITE_USE_OPAQUE_DELEGATE TEST(CApiSimple, Delegate) { TfLiteModel* model = TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); @@ -316,6 +317,7 @@ TEST(CApiSimple, Delegate) { EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); TfLiteInterpreterDelete(interpreter); } +#endif TEST(CApiSimple, DelegateExternal_GetExecutionPlan) { TfLiteModel* model = @@ -409,6 +411,7 @@ TEST(CApiSimple, DelegateExternal_MarkSubgraphAsDelegationSkippable) { TfLiteOpaqueDelegateDelete(opaque_delegate); } +#if !TFLITE_USE_OPAQUE_DELEGATE TEST(CApiSimple, DelegateFails) { TfLiteModel* model = TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); @@ -428,6 +431,7 @@ TEST(CApiSimple, DelegateFails) { TfLiteInterpreterOptionsDelete(options); TfLiteModelDelete(model); } +#endif struct DelegateState { bool delegate_prepared; diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index 334de115286d93..ed9d798f34753b 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -580,6 +580,7 @@ class Interpreter { /// 5. kTfLiteError: Unexpected/runtime failure. \n /// \warning This is an experimental API and subject to change. \n TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); + TfLiteStatus ModifyGraphWithDelegate(TfLiteOpaqueDelegateStruct* delegate); // Owning handle to a TfLiteDelegate instance. 
using TfLiteDelegatePtr = diff --git a/tensorflow/lite/core/interpreter_experimental.cc b/tensorflow/lite/core/interpreter_experimental.cc index e04b1d3e7c675d..016d45df977955 100644 --- a/tensorflow/lite/core/interpreter_experimental.cc +++ b/tensorflow/lite/core/interpreter_experimental.cc @@ -84,6 +84,12 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { return ModifyGraphWithDelegateImpl(delegate); } +TfLiteStatus Interpreter::ModifyGraphWithDelegate( + TfLiteOpaqueDelegateStruct* delegate) { + return ModifyGraphWithDelegateImpl( + reinterpret_cast(delegate)); +} + bool Interpreter::HasDelegates() { return primary_subgraph().HasDelegates(); } TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index dbfb3de9f4cfb7..560b2b4c65b940 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -51,7 +51,8 @@ using test_utils::TestTwoDelegates; namespace { TEST_F(TestDelegate, NullDelegate) { - EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(nullptr), + TfLiteOpaqueDelegate* delegate = nullptr; + EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteDelegateError); } @@ -1488,7 +1489,8 @@ TEST_P(TestFP16Delegation, NonDelegatedInterpreterWorks) { } TEST_F(TestFP16Delegation, NullDelegate) { - EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(nullptr), + TfLiteOpaqueDelegate* delegate = nullptr; + EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteDelegateError); // Verify that resulting interpreter still works, despite null delegate. ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); From e4a313537fe9dfc0099ff92671656afc3366a90d Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Thu, 30 Nov 2023 08:37:21 -0800 Subject: [PATCH 243/381] Disable map_fusion experiment due to suspected errors. PiperOrigin-RevId: 586685567 --- tensorflow/core/data/dataset_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 8024db152df9d7..fd785c0f51647f 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -1007,7 +1007,7 @@ REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("reduce_array_record_dataset_memory_usage", RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<50>, +REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<0>, AllTasks); } // namespace } // namespace data From fb55ba38c3fe79553eac7fc19403fb90420350c1 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Thu, 30 Nov 2023 08:49:31 -0800 Subject: [PATCH 244/381] Remove `alwayslink = True` for the target `quantization/stablehlo:passes`. PiperOrigin-RevId: 586688349 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 0c3f41c4596a01..555a53b5011814 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -25,7 +25,6 @@ package( licenses = ["notice"], ) -# TODO(b/264218457): Add quantize and post_quantize passes. 
cc_library( name = "passes", srcs = [ @@ -104,9 +103,6 @@ cc_library( "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", ], - # Alwayslink is required for registering the MLIR passes. - # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. - alwayslink = True, ) cc_library( From 297e7ec123aa397ac55f7330b42b091258d5ce1f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 30 Nov 2023 09:24:58 -0800 Subject: [PATCH 245/381] [XLA:CPU] Enforce a major-to-minor layout constraint on the TopK custom call. The emitter depends on this layout, but layout assignment doesn't enforce it. This bug was revealed by a change adding an AllGather op that enforced a different layout constraint on one such TopK operator. PiperOrigin-RevId: 586697397 --- .../xla/service/cpu/cpu_layout_assignment.cc | 22 ++++++++++++++----- third_party/xla/xla/service/cpu/ir_emitter.cc | 11 ++++++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index 8b124ddaa60397..dbfda4a5362a68 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -15,7 +15,10 @@ limitations under the License. #include "xla/service/cpu/cpu_layout_assignment.h" +#include #include +#include +#include #include "absl/container/flat_hash_map.h" #include "xla/map_util.h" @@ -78,12 +81,17 @@ static optional ShouldMakeOperandColumnMajor( return it->second ? operand_idx : nullopt; } -static Shape RowMajorShape(const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; +static Shape RowMajorShape(Shape shape) { + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex& index) { + if (!subshape->IsArray()) { + return; + } + std::vector dimension_order(subshape->dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *subshape->mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + }); + return shape; } static Shape ColMajorShape(const Shape& old_shape) { @@ -103,6 +111,8 @@ static bool OperandsAndResultMustHaveRowMajorLayout( } else if (instr.opcode() == HloOpcode::kDot) { return DotOperandsAndResultMustHaveRowMajorLayout(instr, target_machine_features); + } else if (instr.opcode() == HloOpcode::kCustomCall) { + return instr.custom_call_target() == "TopK"; } return false; } diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 18523dec844113..15446b21998770 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2378,13 +2378,16 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) { const HloInstruction* input = hlo->operand(0); const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back(); const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2; - TF_RET_CHECK(input->shape().element_type() == F32); + TF_RET_CHECK(input->shape().element_type() == F32) << hlo->ToString(); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( - hlo->shape().tuple_shapes(0).layout())); + hlo->shape().tuple_shapes(0).layout())) + << hlo->ToString(); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( - 
hlo->shape().tuple_shapes(1).layout())); + hlo->shape().tuple_shapes(1).layout())) + << hlo->ToString(); TF_RET_CHECK( - LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout())); + LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout())) + << hlo->ToString(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice, assignment_.GetUniqueSlice(hlo->operand(0), {})); From dc6be3572a059cb98a2e2763920d394375cb0861 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Thu, 30 Nov 2023 09:26:06 -0800 Subject: [PATCH 246/381] [stream_executor][NFC] Remove unused SharedDeviceMemory PiperOrigin-RevId: 586697692 --- .../xla/xla/stream_executor/device_memory.h | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index fe3df0067f337d..43cb0ac5245438 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -167,26 +167,6 @@ class DeviceMemory final : public DeviceMemoryBase { DeviceMemory(void *opaque, uint64_t size) : DeviceMemoryBase(opaque, size) {} }; -// A class to encapsulate the type and size of a dynamic shared memory -// buffer. Because the buffer exists solely on the device and is not copyable -// to the host, memory objects of this type do not maintain buffer pointers -// on the host. -template -class SharedDeviceMemory final : public DeviceMemoryBase { - public: - explicit SharedDeviceMemory(uint64_t elem_count) - : DeviceMemoryBase(nullptr, elem_count * kElemSize) {} - - static constexpr size_t kElemSize = sizeof(ElemT); - - // Returns the number of elements of type ElemT that constitute this - // allocation. - uint64_t ElementCount() const { return size() / kElemSize; } - - // Returns whether this is a single-element allocation. - bool IsScalar() const { return ElementCount() == 1; } -}; - // Host-side representation of packed-and-aligned vector datatypes on the device // side. Since these can appear in device kernel signatures, we support // launching them with these datatypes in launch signatures. From 14855f56cef9510583eb5aa69ef9fa2aafa3873a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 30 Nov 2023 09:59:57 -0800 Subject: [PATCH 247/381] [XLA:CPU] Add a direct implementation of AllGather, rather than lowering AllGather to AllReduce (attempt 2). We didn't handle the "use_global_device_ids" argument correctly in the first attempt. Fix that and add a test. 
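As a rough standalone model of the semantics the new in-process implementation provides
(a sketch only; the function and variable names below are invented for illustration and do
not appear in this change):

  #include <vector>

  // Every participant contributes one equally sized chunk; each participant
  // receives the chunks of all participants, concatenated in the order the
  // ranks appear in the group, so all ranks end up with identical output.
  std::vector<char> AllGatherModel(
      const std::vector<std::vector<char>>& per_rank_chunks) {
    std::vector<char> output;
    for (const std::vector<char>& chunk : per_rank_chunks) {
      output.insert(output.end(), chunk.begin(), chunk.end());
    }
    return output;
  }

The actual entry point added below is __xla_cpu_runtime_AllGather, which looks up the
caller's rank in the rendezvous key and then performs the equivalent per-rank copies inside
CpuAllGatherRendezvous. With use_global_device_ids=true the replica group entries are
flattened (replica, partition) ids rather than plain replica ids, which is the mode the new
test exercises.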
PiperOrigin-RevId: 586707142 --- .../xla/xla/service/collective_ops_utils.cc | 13 ++++ third_party/xla/xla/service/cpu/BUILD | 3 +- .../xla/service/cpu/collectives_interface.h | 9 ++- .../xla/xla/service/cpu/cpu_compiler.cc | 2 - .../xla/service/cpu/cpu_layout_assignment.cc | 10 ++++ .../xla/xla/service/cpu/cpu_runtime.cc | 41 +++++++++++++ third_party/xla/xla/service/cpu/cpu_runtime.h | 7 +++ .../xla/service/cpu/in_process_collectives.cc | 59 +++++++++++++++++++ .../xla/service/cpu/in_process_collectives.h | 4 ++ third_party/xla/xla/service/cpu/ir_emitter.cc | 51 ++++++++++++++++ third_party/xla/xla/service/cpu/ir_emitter.h | 1 + .../xla/xla/service/cpu/simple_orc_jit.cc | 1 + .../xla/xla/tests/collective_ops_test.cc | 28 +++++++++ 13 files changed, 224 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 87d9d9445cdb5e..46d0cc6d36b6b4 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -369,6 +369,11 @@ StatusOr> GetParticipatingDevices( device_assignment.LogicalIdForDevice(device_id)); int current_replica_id = logical_id.replica_id; int current_partition_id = logical_id.computation_id; + TF_RET_CHECK(0 <= current_replica_id && current_replica_id < replica_count) + << current_replica_id << " " << replica_count; + TF_RET_CHECK(0 <= current_partition_id && + current_partition_id < partition_count) + << current_partition_id << " " << partition_count; std::vector participants; switch (group_mode) { @@ -384,6 +389,8 @@ StatusOr> GetParticipatingDevices( // partition. participants.reserve(participating_replicas.size()); for (int replica_id : participating_replicas) { + TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) + << replica_id << " " << replica_count; participants.emplace_back( device_assignment(replica_id, current_partition_id)); } @@ -398,6 +405,8 @@ StatusOr> GetParticipatingDevices( partition_count, replica_groups)); participants.reserve(participating_partitions.size()); for (int partition_id : participating_partitions) { + TF_RET_CHECK(0 <= partition_id && partition_id < partition_count) + << partition_id << " " << partition_count; participants.emplace_back( device_assignment(current_replica_id, partition_id)); } @@ -412,6 +421,8 @@ StatusOr> GetParticipatingDevices( replica_groups)); participants.reserve(participating_replicas.size() * partition_count); for (int replica_id : participating_replicas) { + TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) + << replica_id << " " << replica_count; for (int partition_id = 0; partition_id < partition_count; ++partition_id) { participants.emplace_back( @@ -441,6 +452,8 @@ StatusOr> GetParticipatingDevices( for (int flattened_id : participating_flattened_ids) { // Map from flattened id back to replica_id, partition_id. 
int replica_id = flattened_id / partition_count; + TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) + << replica_id << " " << replica_count; int partition_id = flattened_id % partition_count; participants.emplace_back(device_assignment(replica_id, partition_id)); } diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index ef51b94dcd80ad..44dbf2f4db410d 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -250,7 +250,6 @@ cc_library( "//xla/runtime:executable", "//xla/runtime:jit_executable", "//xla/service:algebraic_simplifier", - "//xla/service:all_gather_decomposer", "//xla/service:all_reduce_promotion", "//xla/service:all_to_all_decomposer", "//xla/service:batch_dot_simplification", @@ -1351,7 +1350,9 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features", + "//xla:shape_util", "//xla:util", + "//xla/hlo/ir:hlo", "//xla/service:computation_layout", "//xla/service:layout_assignment", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h index bd518db3a780bc..4191df1d831fa2 100644 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -59,9 +59,14 @@ class CollectivesCommunicator { // The all-to-all chunks are passed separately and do not have to be // contiguous in memory. virtual absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffer, - absl::Span output_buffer, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) = 0; + + // Performs an all-gather. + virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index bc59af6706534d..06a62b4f8ca69e 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -115,7 +115,6 @@ limitations under the License. #include "xla/runtime/executable.h" #include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" -#include "xla/service/all_gather_decomposer.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_to_all_decomposer.h" #include "xla/service/batch_dot_simplification.h" @@ -685,7 +684,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index dbfda4a5362a68..eb59a195d572cf 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -21,9 +21,12 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/map_util.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" +#include "xla/shape_util.h" #include "tsl/platform/errors.h" namespace xla { @@ -136,6 +139,13 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); + } else if (instruction->opcode() == HloOpcode::kAllGather) { + // XLA:CPU can only support all-gathers where the gather dimension is the + // most major dimension in the layout. + auto ag = Cast(instruction); + TF_RETURN_IF_ERROR(SetInstructionLayout( + ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()), + ag)); } else { for (int64_t operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index a4090d67fc8d9b..912c4db69aacd2 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -142,6 +142,7 @@ extern const char* const kTracingStartSymbolName = extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; +extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -345,6 +346,34 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void AllGatherImpl(const ExecutableRunOptions* run_options, + int32_t channel_id_present, int32_t use_global_device_ids, + int64_t op_id, const void* replica_groups_str, + int32_t replica_groups_str_size, int64_t buffer_size, + void* source_buffer, void* destination_buffer) { + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + ParseReplicaGroupsOnly(replica_groups_serialized).value(); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + use_global_device_ids, op_id); + + auto it = absl::c_find(rendezvous_key.global_devices, device); + CHECK(it != rendezvous_key.global_devices.end()); + int rank = std::distance(rendezvous_key.global_devices.begin(), it); + + CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->AllGather(rendezvous_key, buffer_size, + source_buffer, destination_buffer, + DefaultCollectiveTimeout())); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -503,6 +532,18 @@ void __xla_cpu_runtime_AllToAll(const xla::ExecutableRunOptions* run_options, destination_buffers); } +void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, + int32_t channel_id_present, + int32_t use_global_device_ids, int64_t op_id, + const void* replica_groups_str, + 
int32_t replica_groups_str_size, + int64_t buffer_size, void* source_buffer, + void* destination_buffer) { + return xla::cpu::runtime::AllGatherImpl( + run_options, channel_id_present, use_global_device_ids, op_id, + replica_groups_str, replica_groups_str_size, buffer_size, source_buffer, + destination_buffer); +} void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 361a116e7c300f..06a05cda713592 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -84,6 +84,7 @@ extern const char* const kReplicaIdSymbolName; extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; +extern const char* const kAllGatherSymbolName; extern const char* const kOneDnnMatMulSymbolName; // All symbol names for XLA CPU runtime functions need to start with this @@ -195,6 +196,12 @@ extern void __xla_cpu_runtime_AllToAll( int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size, void** source_buffers, void** destination_buffers); +extern void __xla_cpu_runtime_AllGather( + const xla::ExecutableRunOptions* run_options, int32_t channel_id_present, + int32_t use_global_device_ids, int64_t op_id, + const void* replica_groups_str, int32_t replica_groups_str_size, + int64_t buffer_size, void* source_buffer, void* destination_buffer); + // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index fc13c6c3870377..78eee1aa74f731 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -340,6 +340,44 @@ class CpuAllToAllRendezvous } }; +struct AllGatherParticipantData : ParticipantData { + AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + const void* source_buffer; + void* destination_buffer; + size_t chunk_size; + + std::string ToString() const override { + return absl::StrFormat( + "AllGatherParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_size=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_size); + } +}; + +class CpuAllGatherRendezvous + : public Rendezvous { + public: + explicit CpuAllGatherRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + absl::StatusOr RunCollectiveOp( + const AllGatherParticipantData& p) override { + int world_size = p.rendezvous_key.global_devices.size(); + char* out = static_cast(p.destination_buffer); + for (int i = 0; i < world_size; ++i, out += p.chunk_size) { + std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); + } + return nullptr; + } +}; + } // namespace struct InProcessCollectivesState { @@ -349,6 +387,8 @@ struct InProcessCollectivesState { collective_permute_rendezvous_map; RefcountingHashMap all_to_all_rendezvous_map; + RefcountingHashMap + all_gather_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ 
-429,6 +469,25 @@ absl::Status InProcessCollectivesCommunicator::AllToAll( .status(); } +absl::Status InProcessCollectivesCommunicator::AllGather( + const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + AllGatherParticipantData participant(key, rank_); + participant.chunk_size = chunk_bytes; + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuAllGatherRendezvous::SubmitParticipant( + [&] { + return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h index fb25fd3528d606..aaedc474fa39b2 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -55,6 +55,10 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { absl::Span output_buffers, absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + private: InProcessCollectivesState* state_; int rank_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 15446b21998770..eea576506d99cf 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -1344,6 +1344,57 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { return OkStatus(); } +Status IrEmitter::HandleAllGather(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); + + std::string replica_groups = + ReplicaGroupsToString(instruction->replica_groups()); + int32_t replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + std::vector input_buffer_ptrs; + std::vector output_buffer_ptrs; + + const HloInstruction* op = instruction->operand(0); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice in_slice, + assignment_.GetUniqueSlice(op, {})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(instruction, {})); + const Shape& operand_shape = op->shape(); + CHECK(op->shape().IsArray()) + << "Operand to all-gather must be arrays: " << instruction->ToString(); + llvm::Value* output_buffer = EmitBufferPointer(out_slice, operand_shape); + llvm::Value* input_buffer = GetEmittedValueFor(op); + int64_t buffer_size = in_slice.size(); + + bool use_global_device_ids = + Cast(instruction)->use_global_device_ids(); + + EmitCallToFunc( + runtime::kAllGatherSymbolName, + { + /*run_options=*/GetExecutableRunOptionsArgument(), + /*channel_id_present=*/ + b_.getInt32( + static_cast(instruction->channel_id().has_value())), + /*use_global_device_ids=*/ + b_.getInt32(static_cast(use_global_device_ids)), + /*op_id=*/ + b_.getInt64(instruction->channel_id().has_value() + ? 
*instruction->channel_id() + : instruction->GetModule()->unique_id()), + /*replica_groups_str=*/replica_groups_v, + /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), + /*buffer_size=*/b_.getInt64(buffer_size), + /*source_buffer=*/input_buffer, + /*destination_buffer=*/output_buffer, + }, + b_.getVoidTy()); + + llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_); + return OkStatus(); +} + Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast(crs); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr)); diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 3a194d054cb5fd..a0e31671ab4375 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -134,6 +134,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // special in some way are handled explicitly in HandleFoo methods. Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllGather(HloInstruction* instruction) override; Status HandleAllToAll(HloInstruction* instruction) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant) override; diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 8895b4f6451d5a..2e27a7c810869d 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -485,6 +485,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AllReduce); REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); + REGISTER_CPU_RUNTIME_SYMBOL(AllGather); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index 29d7e5781bfa43..d87dae6e86dc7f 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -983,6 +983,34 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) { } } +XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0_UseGlobalDevices) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + id = u32[] replica-id() + id2 = u32[1, 2] broadcast(id), dimensions={} + a0 = u32[1, 2] constant({{10, 15}}) + a1 = u32[1, 2] add(id2, a0) + allgather = u32[2, 2] all-gather(a1), dimensions={0}, use_global_device_ids=true, channel_id=7, replica_groups={{0, 1}} + ROOT out = u32[4] reshape(allgather) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16}, result); + } +} + XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) { const char* const kModuleStr = R"( HloModule test From cbf5fefcb3d1ddbd60b925547ae2cbb73eb6a2e6 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Thu, 30 Nov 2023 10:06:32 -0800 Subject: [PATCH 248/381] Import openai/triton from GitHub. 
PiperOrigin-RevId: 586709597 --- third_party/triton/cl584230333.patch | 14 -------------- third_party/triton/workspace.bzl | 5 ++--- .../xla/third_party/triton/cl584230333.patch | 14 -------------- third_party/xla/third_party/triton/workspace.bzl | 5 ++--- 4 files changed, 4 insertions(+), 34 deletions(-) delete mode 100644 third_party/triton/cl584230333.patch delete mode 100644 third_party/xla/third_party/triton/cl584230333.patch diff --git a/third_party/triton/cl584230333.patch b/third_party/triton/cl584230333.patch deleted file mode 100644 index d8399eadf8f4d0..00000000000000 --- a/third_party/triton/cl584230333.patch +++ /dev/null @@ -1,14 +0,0 @@ -==== triton/lib/Dialect/Triton/IR/Dialect.cpp#6 - /google/src/cloud/jreiffers/mlir_26a0b277369adc31b162b1cc38b1a712bc10c1a0_1700552908/triton/lib/Dialect/Triton/IR/Dialect.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/Triton/IR/Dialect.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ triton/lib/Dialect/Triton/IR/Dialect.cpp 2023-11-21 01:58:04.000000000 -0800 -@@ -64,8 +64,7 @@ - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. -- void handleTerminator(Operation *op, -- ArrayRef valuesToRepl) const final { -+ void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only return needs to be handled here. - auto returnOp = cast(op); - diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 834668f112a38d..08ce69188c40d5 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl584018112" - TRITON_SHA256 = "a0f2461af9fbcf576cef08e0b83ab7a1caa3cfe2041c60b2809cbd495ff14f08" + TRITON_COMMIT = "cl586277651" + TRITON_SHA256 = "4941438a65ce53b1586b193d2f410b2b120ef1d32cd666f55f10055a913574fe" tf_http_archive( name = "triton", @@ -17,6 +17,5 @@ def repo(): patch_file = [ "//third_party/triton:b304456327.patch", "//third_party/triton:cl568176943.patch", - "//third_party/triton:cl584230333.patch", ], ) diff --git a/third_party/xla/third_party/triton/cl584230333.patch b/third_party/xla/third_party/triton/cl584230333.patch deleted file mode 100644 index d8399eadf8f4d0..00000000000000 --- a/third_party/xla/third_party/triton/cl584230333.patch +++ /dev/null @@ -1,14 +0,0 @@ -==== triton/lib/Dialect/Triton/IR/Dialect.cpp#6 - /google/src/cloud/jreiffers/mlir_26a0b277369adc31b162b1cc38b1a712bc10c1a0_1700552908/triton/lib/Dialect/Triton/IR/Dialect.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/Triton/IR/Dialect.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ triton/lib/Dialect/Triton/IR/Dialect.cpp 2023-11-21 01:58:04.000000000 -0800 -@@ -64,8 +64,7 @@ - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. -- void handleTerminator(Operation *op, -- ArrayRef valuesToRepl) const final { -+ void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only return needs to be handled here. 
- auto returnOp = cast(op); - diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 834668f112a38d..08ce69188c40d5 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl584018112" - TRITON_SHA256 = "a0f2461af9fbcf576cef08e0b83ab7a1caa3cfe2041c60b2809cbd495ff14f08" + TRITON_COMMIT = "cl586277651" + TRITON_SHA256 = "4941438a65ce53b1586b193d2f410b2b120ef1d32cd666f55f10055a913574fe" tf_http_archive( name = "triton", @@ -17,6 +17,5 @@ def repo(): patch_file = [ "//third_party/triton:b304456327.patch", "//third_party/triton:cl568176943.patch", - "//third_party/triton:cl584230333.patch", ], ) From 442bf89b296e3409729e19f5df92b78dde1e3377 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 10:26:12 -0800 Subject: [PATCH 249/381] Moves the BuildStrategyAndCost() method into its own file (auto_sharding_strategy.cc) PiperOrigin-RevId: 586717385 --- .../xla/hlo/experimental/auto_sharding/BUILD | 1 + .../auto_sharding/auto_sharding.cc | 809 +--------------- .../auto_sharding/auto_sharding.h | 138 +++ .../auto_sharding/auto_sharding_strategy.cc | 861 ++++++++++++++++++ 4 files changed, 1003 insertions(+), 806 deletions(-) create mode 100644 third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index b871a6012fbb5a..88009f91b60778 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -22,6 +22,7 @@ cc_library( srcs = [ "auto_sharding.cc", "auto_sharding_dot_handler.cc", + "auto_sharding_strategy.cc", ], hdrs = [ "auto_sharding.h", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 3de8fbeb86ce7d..f33cf2954a7c12 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -610,7 +610,7 @@ void AddReplicatedStrategy( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, double replicated_penalty, - absl::flat_hash_set operands_to_consider_all_strategies_for = {}) { + absl::flat_hash_set operands_to_consider_all_strategies_for) { HloSharding replicated_strategy = HloSharding::Replicate(); HloSharding output_spec = replicated_strategy; double memory_cost = GetBytes(shape) / output_spec.NumTiles(); @@ -791,7 +791,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, bool only_allow_divisible, const CallGraph& call_graph, int64_t partition_dimensions, - const std::vector& tensor_dims = {}) { + const std::vector& tensor_dims) { const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { BuildStrategyAndCostForOp(ins, shape, device_mesh, cluster_env, @@ -1261,7 +1261,7 @@ bool ShardingIsConsistent(const HloSharding& partial_sharding, void TrimOrGenerateStrategiesBasedOnExistingSharding( const Shape& output_shape, StrategyGroup* strategy_group, const StrategyMap& strategy_map, - const std::vector instructions, + const std::vector& 
instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, StableHashMap>& pretrimmed_strategy_map, @@ -1525,8 +1525,6 @@ void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, } } -// Enumerates sharding strategies for elementwise operators by following -// strategies of an operand of the elementwise op. std::unique_ptr CreateElementwiseOperatorStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, @@ -1585,9 +1583,6 @@ std::unique_ptr CreateElementwiseOperatorStrategies( return strategy_group; } -// Enumerates sharding strategies for reshape operators. The function does so by -// essentially reshaping the sharding of the operand in a manner similar to the -// tensor reshape itself. std::unique_ptr CreateReshapeStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, @@ -1680,804 +1675,6 @@ std::unique_ptr CreateReshapeStrategies( return strategy_group; } -// NOLINTBEGIN(readability/fn_size) -// TODO(zhuohan): Decompose this function into smaller pieces -// Build possible sharding strategies and their costs for all instructions. -StatusOr> -BuildStrategyAndCost(const HloInstructionSequence& sequence, - const HloModule* module, - const absl::flat_hash_map& - instruction_execution_counts, - const InstructionDepthMap& depth_map, - const InstructionBatchDimMap& batch_dim_map, - const AliasMap& alias_map, - const ClusterEnvironment& cluster_env, - AutoShardingOption& option, const CallGraph& call_graph, - const HloCostAnalysis& hlo_cost_analysis, - bool trying_multiple_mesh_shapes) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; - StrategyMap strategy_map; - // This map stores all of the trimmed strategies due to user specified - // sharding. The key is the instruction id, the value is the strategies. This - // is useful when the operand is forced to use a user sharding, and the op - // doesn't need to strictly follow it. We restore the trimmed strategies in - // this situation. - StableHashMap> pretrimmed_strategy_map; - StrategyGroups strategy_groups; - AssociativeDotPairs associative_dot_pairs; - - const std::vector& instructions = sequence.instructions(); - - // Add penalty for replicated tensors - double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) + - cluster_env.AllReduceCost(1, 1)); - - int64_t max_depth = -1; - for (auto iter : depth_map) { - max_depth = std::max(max_depth, iter.second); - } - - // Register strategies and their costs for each instruction. - for (size_t instruction_id = 0; instruction_id < instructions.size(); - ++instruction_id) { - const HloInstruction* ins = instructions[instruction_id]; - VLOG(2) << "instruction_id = " << instruction_id << ": " - << ToAdaptiveString(ins); - std::unique_ptr strategy_group; - - HloOpcode opcode = ins->opcode(); - - bool only_allow_divisible; - if (IsEntryComputationInputOrOutput(module, ins)) { - // With IsEntryComputationInputOrOutput(module, ins) == true, entry - // computation's root instruction may still be unevenly sharded because it - // usually "follows" other instruction's sharding. If the instruction it - // follows is an intermediate instruction, it may be able to choose - // unevenly sharded strategiyes. 
Usually if we constraint input's sharding - // strategies, outputs would be constrained as welll, but if outputs are - // still unevely sharded in some cases, we need to fix the implementation - // in auto sharding. - only_allow_divisible = option.only_allow_divisible_input_output; - } else { - only_allow_divisible = option.only_allow_divisible_intermediate; - } - - switch (opcode) { - case HloOpcode::kParameter: - case HloOpcode::kRngBitGenerator: - case HloOpcode::kRng: { - strategy_group = - CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters) - .value(); - break; - } - case HloOpcode::kConstant: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); - break; - } - case HloOpcode::kScatter: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - // We follow the first operand (the array we're scattering into) - auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); - CHECK(!src_strategy_group->is_tuple); - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - HloSharding output_spec = - src_strategy_group->strategies[sid].output_sharding; - std::string name = ToStringSimple(output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); - - std::vector> input_shardings_optional( - {output_spec, std::nullopt, std::nullopt}); - std::vector> resharding_cost = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, output_spec, strategy_map, cluster_env, call_graph, - input_shardings_optional); - - for (const auto& sharding_optional : input_shardings_optional) { - CHECK(sharding_optional.has_value()); - } - - strategy_group->strategies.push_back(ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, memory_cost, - std::move(resharding_cost), input_shardings_optional})); - } - break; - } - case HloOpcode::kGather: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - const HloInstruction* indices = ins->operand(1); - const Shape& shape = ins->shape(); - for (int32_t index_dim = 0; index_dim < indices->shape().rank(); - index_dim++) { - // Shard on indices dimensions that correspond to output dimensions - // TODO(b/220935014) Shard the last dim of output (model dim) with - // AllGather cost and no follow. - if (index_dim == ins->gather_dimension_numbers().index_vector_dim()) { - continue; - } - for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { - // Split only when the tensor shape is divisible by device - // mesh. 
- if (device_mesh.dim(j) == 1 || - (only_allow_divisible && - !IsDivisible(shape.dimensions(index_dim), - device_mesh.dim(j)))) { - continue; - } - std::string name = absl::StrCat("S", index_dim, " @ ", j); - - HloSharding output_spec = - Tile(shape, {index_dim}, {j}, device_mesh); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(shape) / output_spec.NumTiles(); - std::optional input_spec = - hlo_sharding_util::ReshapeSharding(shape, indices->shape(), - output_spec); - if (!input_spec.has_value()) { // invalid reshape - continue; - } - std::vector> input_shardings_optional( - {std::nullopt, input_spec}); - std::vector> resharding_cost = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, output_spec, strategy_map, cluster_env, call_graph, - input_shardings_optional); - - strategy_group->strategies.push_back(ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, - memory_cost, std::move(resharding_cost), - input_shardings_optional})); - } - } - auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - HloSharding output_spec = - src_strategy_group->strategies[sid].output_sharding; - auto gather_parallel_dims = - hlo_sharding_util::GetGatherParallelBatchDims(*ins, call_graph); - absl::Span operand_parallel_dims; - if (gather_parallel_dims) { - operand_parallel_dims = absl::MakeConstSpan( - gather_parallel_dims->operand_parallel_dims); - } - HloSharding filtered_operand_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - output_spec, operand_parallel_dims); - auto maybe_from_data = hlo_sharding_util:: - GatherOutputShardingFromOperandOperandPassthroughDimensions( - filtered_operand_sharding, *ins); - if (!maybe_from_data) continue; - std::string name = ToStringSimple(*maybe_from_data); - double compute_cost = 0, communication_cost = 0; - double memory_cost = - GetBytes(ins->shape()) / maybe_from_data->NumTiles(); - std::vector> input_shardings_optional( - {*maybe_from_data, std::nullopt}); - std::vector> resharding_cost = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, *maybe_from_data, strategy_map, cluster_env, call_graph, - input_shardings_optional); - strategy_group->strategies.push_back(ShardingStrategy( - {name, *maybe_from_data, compute_cost, communication_cost, - memory_cost, std::move(resharding_cost), - input_shardings_optional})); - } - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0, - /* operands_to_consider_all_strategies_for */ {0}); - break; - } - case HloOpcode::kBroadcast: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - - if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, "", call_graph); - } else { - EnumerateAllPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategy_group, - batch_dim_map, only_allow_divisible, call_graph, - /*partitions*/ 2); - if (option.allow_mixed_mesh_shape) { - EnumerateAll1DPartition(ins, ins->shape(), - cluster_env.device_mesh_1d_, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, "1d", call_graph); - } - } - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - - break; - } - case HloOpcode::kReshape: 
{ - strategy_group = CreateReshapeStrategies( - instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups); - break; - } - case HloOpcode::kTranspose: - case HloOpcode::kReverse: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - - const HloInstruction* operand = ins->operand(0); - - // Create follow strategies - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(!src_strategy_group->is_tuple); - strategy_group->following = src_strategy_group; - - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - HloSharding output_spec = Undefined(); - auto input_spec = src_strategy_group->strategies[sid].output_sharding; - if (opcode == HloOpcode::kTranspose) { - output_spec = hlo_sharding_util::TransposeSharding( - input_spec, ins->dimensions()); - } else { - output_spec = hlo_sharding_util::ReverseSharding(input_spec, - ins->dimensions()); - } - - std::string name = ToStringSimple(output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); - auto resharding_costs = ReshardingCostVector( - src_strategy_group, operand->shape(), input_spec, cluster_env); - strategy_group->strategies.push_back( - ShardingStrategy({name, - output_spec, - compute_cost, - communication_cost, - memory_cost, - {resharding_costs}, - {input_spec}})); - } - break; - } - case HloOpcode::kPad: - case HloOpcode::kSlice: - case HloOpcode::kConcatenate: // TODO(zhuohan): revisit concatenate - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - int64_t follow_idx; - switch (opcode) { - // TODO(yuemmawang) Re-evaluate the follow_idx choices for the - // following 3. - case HloOpcode::kPad: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kConcatenate: - // Follow the operand according to the follow heuristics - follow_idx = ChooseOperandToFollow(strategy_map, depth_map, - alias_map, max_depth, ins) - .first; - break; - // The following types are better to follow the first operand. - case HloOpcode::kSlice: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - follow_idx = 0; - break; - default: - LOG(FATAL) << "Selecting follow index encounters an unhandled " - "instruction type: " + - ins->ToShortString(); - } - // Create follow strategies - const HloInstruction* operand = ins->operand(follow_idx); - StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); - CHECK(!src_strategy_group->is_tuple); - strategy_group->following = src_strategy_group; - - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - std::optional output_spec; - HloSharding input_spec = - src_strategy_group->strategies[sid].output_sharding; - - // Find output shardings. 
- switch (opcode) { - case HloOpcode::kPad: - case HloOpcode::kSlice: - case HloOpcode::kConcatenate: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - output_spec = PropagateDimwiseSharding( - input_spec, operand->shape(), ins->shape()); - break; - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - output_spec = PropagateReduceWindowSharding( - input_spec, operand->shape(), ins->window()); - break; - default: - LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); - } - - // Get a list of input shardings, each corresponds to an operand. - std::vector> input_shardings; - for (int64_t k = 0; k < ins->operand_count(); ++k) { - if (k == follow_idx || - ToString(ins->operand(k)->shape().dimensions()) == - ToString(operand->shape().dimensions())) { - input_shardings.push_back(input_spec); - } else { - input_shardings.push_back(std::nullopt); - } - } - if (!output_spec.has_value()) { - continue; - } - - std::string name = ToStringSimple(*output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); - std::vector> resharding_costs = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, *output_spec, strategy_map, cluster_env, call_graph, - input_shardings); - - strategy_group->strategies.push_back( - ShardingStrategy({name, - *output_spec, - compute_cost, - communication_cost, - memory_cost, - std::move(resharding_costs), - {input_spec}})); - } - - if (strategy_group->strategies.empty()) { - strategy_group->following = nullptr; - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); - } - - break; - } - case HloOpcode::kOptimizationBarrier: { - auto operand_strategies = strategy_map.at(ins->operand(0)).get(); - strategy_group = MaybeFollowInsStrategyGroup( - operand_strategies, ins->shape(), instruction_id, - /* have_memory_cost */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - break; - } - case HloOpcode::kBitcast: { - if (ins->shape() == ins->operand(0)->shape()) { - strategy_group = CreateElementwiseOperatorStrategies( - instruction_id, ins, strategy_map, cluster_env, depth_map, - alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, - associative_dot_pairs); - } else { - strategy_group = CreateReshapeStrategies( - instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups); - } - break; - } - // Unary elementwise operations. 
- case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kRoundNearestEven: - case HloOpcode::kCeil: - case HloOpcode::kClz: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kPopulationCount: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kRsqrt: - case HloOpcode::kLogistic: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kSqrt: - case HloOpcode::kCbrt: - case HloOpcode::kTan: - case HloOpcode::kTanh: - // Binary elementwise operations - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kCompare: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kXor: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - case HloOpcode::kStochasticConvert: - // Ternary elementwise operations. - case HloOpcode::kSelect: - case HloOpcode::kClamp: { - strategy_group = CreateElementwiseOperatorStrategies( - instruction_id, ins, strategy_map, cluster_env, depth_map, - alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, - associative_dot_pairs); - break; - } - case HloOpcode::kReduce: { - auto strategies_status = FollowReduceStrategy( - ins, ins->shape(), ins->operand(0), ins->operand(1), instruction_id, - strategy_map, strategy_groups, cluster_env, - option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes); - if (strategies_status.ok()) { - strategy_group = std::move(strategies_status.value()); - } else { - return strategies_status.status(); - } - break; - } - case HloOpcode::kDot: { - TF_RETURN_IF_ERROR(HandleDot( - strategy_group, strategy_groups, strategy_map, ins, instruction_id, - cluster_env, batch_dim_map, option, call_graph)); - if (option.allow_replicated_strategy_for_dot_and_conv) { - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, - GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, - sequence, hlo_cost_analysis)); - } - break; - } - case HloOpcode::kConvolution: { - TF_RETURN_IF_ERROR(HandleConv( - strategy_group, strategy_groups, strategy_map, ins, instruction_id, - cluster_env, batch_dim_map, option, call_graph)); - if (option.allow_replicated_strategy_for_dot_and_conv) { - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, - GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, - sequence, hlo_cost_analysis)); - } - break; - } - case HloOpcode::kRngGetAndUpdateState: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); - break; - } - case HloOpcode::kIota: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - if (cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, "", call_graph); - } - if 
(cluster_env.IsDeviceMesh2D()) { - // Split 2 dims - EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*parts*/ 2); - } - if (cluster_env.IsDeviceMesh3D()) { - // Split 3 dims - EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*parts*/ 3); - } - if (cluster_env.IsDeviceMesh2D() && option.allow_mixed_mesh_shape) { - // Split 1 dim, but for 1d flattened version of the 2d mesh - // For example, when the mesh shape is (2, 4), we add strategies for - // mesh shape (1, 8) here in addition. - EnumerateAll1DPartition(ins, ins->shape(), device_mesh_1d, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, " 1d", call_graph); - } - - // Replicate - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty * 5); - - break; - } - case HloOpcode::kTuple: { - strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->operand_count()); - for (size_t i = 0; i < ins->operand_count(); ++i) { - const HloInstruction* operand = ins->operand(i); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - auto child_strategies = MaybeFollowInsStrategyGroup( - src_strategy_group, operand->shape(), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); - } - break; - } - case HloOpcode::kGetTupleElement: { - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(src_strategy_group->is_tuple); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), - instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - break; - } - case HloOpcode::kCustomCall: { - auto generate_non_following_strategies = - [&](bool only_replicated, - absl::flat_hash_set - operands_to_consider_all_strategies_for = {}) { - if (ins->shape().IsTuple()) { - if (only_replicated) { - strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve( - ins->shape().tuple_shapes_size()); - for (size_t i = 0; i < ins->shape().tuple_shapes_size(); - ++i) { - std::unique_ptr child_strategies = - CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), - cluster_env, strategy_map, - child_strategies, replicated_penalty); - strategy_group->childs.push_back( - std::move(child_strategies)); - } - } else { - strategy_group = - CreateAllStrategiesGroup( - ins, ins->shape(), instruction_id, strategy_groups, - cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, true) - .value(); - } - } else { - if (only_replicated) { - strategy_group = CreateLeafStrategyGroup( - instruction_id, ins, strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, - strategy_map, strategy_group, - replicated_penalty); - } else { - strategy_group = - CreateAllStrategiesGroup( - ins, ins->shape(), instruction_id, strategy_groups, - cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, 
true) - .value(); - } - } - }; - - if (IsCustomCallMarker(ins)) { - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(src_strategy_group->is_tuple); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group, ins->shape(), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - } else if (ins->has_sharding()) { - generate_non_following_strategies(false); - } else if (OutputInputSameShapes(ins)) { - auto* partitioner = - GetCustomCallPartitioner(ins->custom_call_target()); - if (partitioner && partitioner->IsCustomCallShardable(ins)) { - // Follows operand 0's strategies if this custom-call op is - // shardable and has the same input and output sizes. - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group, ins->shape(), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - } - } else if (IsTopKCustomCall(ins)) { - generate_non_following_strategies(false, {0}); - } else { - // TODO (b/258723035) Handle CustomCall ops for GPUs in a better way. - generate_non_following_strategies(true); - } - break; - } - case HloOpcode::kWhile: { - strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); - const StrategyGroup* src_strategy_group = - strategy_map.at(ins->operand(0)).get(); - for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { - auto child_strategies = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[i].get(), - ins->shape().tuple_shapes().at(i), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); - } - break; - } - case HloOpcode::kConditional: - case HloOpcode::kInfeed: - case HloOpcode::kSort: { - strategy_group = - CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - /*create_replicated_strategies*/ true) - .value(); - break; - } - case HloOpcode::kOutfeed: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - break; - } - case HloOpcode::kAfterAll: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - break; - } - default: - LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); - } - RemoveDuplicatedStrategy(strategy_group); - if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { - // Finds the sharding strategy that aligns with the given sharding spec - // Do not merge nodes if this one instruction has annotations. 
- TrimOrGenerateStrategiesBasedOnExistingSharding( - ins->shape(), strategy_group.get(), strategy_map, instructions, - ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph, - option.nd_sharding_iteratively_strict_search_space); - } - if (!strategy_group->is_tuple && strategy_group->following) { - if (!LeafVectorsAreConsistent( - strategy_group->strategies, strategy_group->following->strategies, - /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { - // It confuses the solver if two instructions have different number of - // sharding strategies but share the same ILP variable. The solver - // would run much longer and/or return infeasible solutions. - // So if two strategies' strategiess are inconsistent, we unfollow - // them. - strategy_group->following = nullptr; - } - } else if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); i++) { - if (strategy_group->childs.at(i)->following && - !LeafVectorsAreConsistent( - strategy_group->childs.at(i)->strategies, - strategy_group->childs.at(i)->following->strategies, - /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { - strategy_group->childs.at(i)->following = nullptr; - } - } - } - RemoveInvalidShardingsWithShapes( - ins->shape(), strategy_group.get(), - /* instruction_has_user_sharding */ ins->has_sharding()); - - if (instruction_execution_counts.contains(ins)) { - ScaleCostsWithExecutionCounts(strategy_group.get(), - instruction_execution_counts.at(ins)); - } else { - VLOG(5) << "No execution count available for " << ins->name(); - } - XLA_VLOG_LINES(2, - absl::StrCat("strategies:\n", strategy_group->ToString())); - - // Debug options: forcibly set the strategy of some instructions. - if (option.force_strategy) { - std::vector inst_indices = option.force_strategy_inst_indices; - std::vector stra_names = option.force_strategy_stra_names; - CHECK_EQ(inst_indices.size(), stra_names.size()); - auto it = absl::c_find(inst_indices, strategy_group->node_idx); - if (it != inst_indices.end()) { - CHECK(!strategy_group->is_tuple); - std::vector new_strategies; - int64_t idx = it - inst_indices.begin(); - for (const auto& stra : strategy_group->strategies) { - if (stra.name == stra_names[idx]) { - new_strategies.push_back(stra); - } - } - strategy_group->strategies = std::move(new_strategies); - } - } - - // When trying out multiple mesh shapes in the presence of user specified - // sharding (as in - // AutoShardingTest.AutoShardingKeepUserShardingInputOutput), there may be a - // situation when we cannot generate any shardings for an instruction when - // the mesh shape we're trying does not match with the mesh shape used in - // user specified shardings. So we disable the check in that situation. - if (!trying_multiple_mesh_shapes) { - CHECK(strategy_group->is_tuple || !strategy_group->strategies.empty()) - << ins->ToString() << " does not have any valid strategies."; - } else if (!(strategy_group->is_tuple || - !strategy_group->strategies.empty())) { - return Status(absl::StatusCode::kFailedPrecondition, - "Could not generate any shardings for an instruction due " - "to mismatched mesh shapes."); - } - // Checks the shape of resharding_costs is valid. It will check fail if the - // shape is not as expected. - // CheckReshardingCostsShape(strategies.get()); - CheckMemoryCosts(strategy_group.get(), ins->shape()); - strategy_map[ins] = std::move(strategy_group); - } // end of for loop - - // If gradient accumulation is used, adjust the cost of all-reduce for - // gradient synchronization. 
- if (option.grad_acc_num_micro_batches > 1) { - // find gradient-computation instructions - std::vector grad_insts = - GetGradientComputationInstructions(instructions); - for (const HloInstruction* inst : grad_insts) { - StrategyGroup* stra_vector = strategy_map[inst].get(); - CHECK(!stra_vector->is_tuple); - - for (auto& stra : stra_vector->strategies) { - if (absl::StrContains(stra.name, "allreduce")) { - stra.communication_cost /= option.grad_acc_num_micro_batches; - } - } - } - } - - return std::make_tuple(std::move(strategy_map), std::move(strategy_groups), - std::move(associative_dot_pairs)); -} - -// NOLINTEND - AutoShardingSolverResult CallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 5af6ad3d35cc8e..66f12dd37ad95a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -19,7 +19,10 @@ limitations under the License. #include #include #include +#include #include +#include +#include #include #include "absl/container/flat_hash_map.h" @@ -37,6 +40,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" #include "xla/shape.h" #include "xla/statusor.h" @@ -223,6 +227,140 @@ AutoShardingSolverResult Solve( void PopulateTemporalValues(const CostGraph& cost_graph, AutoShardingSolverRequest& request); +void AddReplicatedStrategy( + const HloInstruction* ins, const Shape& shape, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, double replicated_penalty, + absl::flat_hash_set operands_to_consider_all_strategies_for = {}); + +void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape); + +// Choose an operand to follow. We choose to follow the operand with the +// highest priority. +std::pair ChooseOperandToFollow( + const StrategyMap& strategy_map, const InstructionDepthMap& depth_map, + const AliasMap& alias_map, int64_t max_depth, const HloInstruction* ins); + +StatusOr> CreateAllStrategiesGroup( + const HloInstruction* ins, const Shape& shape, size_t instruction_id, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, const AutoShardingOption& option, + double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, + const CallGraph& call_graph, bool only_allow_divisible, + bool create_replicated_strategies); + +// Enumerates sharding strategies for elementwise operators by following +// strategies of an operand of the elementwise op. +std::unique_ptr CreateElementwiseOperatorStrategies( + size_t instruction_id, const HloInstruction* ins, + const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, + const InstructionDepthMap& depth_map, const AliasMap& alias_map, + StableHashMap>& + pretrimmed_strategy_map, + int64_t max_depth, StrategyGroups& strategy_groups, + AssociativeDotPairs& associative_dot_pairs); + +// Factory functions for StrategyGroup. 
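+// CreateLeafStrategyGroupWithoutInNodes creates a leaf group with no in-nodes;
+// BuildStrategyAndCost uses it for instructions whose sharding does not follow
+// any operand, e.g. kConstant, kRngGetAndUpdateState and kIota.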
+std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( + size_t instruction_id, StrategyGroups& strategy_groups); + +// Enumerates sharding strategies for reshape operators. The function does so by +// essentially reshaping the sharding of the operand in a manner similar to the +// tensor reshape itself. +std::unique_ptr CreateReshapeStrategies( + size_t instruction_id, const HloInstruction* ins, + const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, + bool only_allow_divisible, double replicated_penalty, + const InstructionBatchDimMap& batch_dim_map, + const AutoShardingOption& option, StrategyGroups& strategy_groups); + +std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id); + +// Enumerate all 1d partition strategies. +void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, + const Array& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + bool only_allow_divisible, + const std::string& suffix, + const CallGraph& call_graph); + +// Enumerate all partitions recursively +void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, + const Array& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + const InstructionBatchDimMap& batch_dim_map, + bool only_allow_divisible, + const CallGraph& call_graph, + int64_t partition_dimensions, + const std::vector& tensor_dims = {}); + +StatusOr> FollowReduceStrategy( + const HloInstruction* ins, const Shape& output_shape, + const HloInstruction* operand, const HloInstruction* unit, + size_t instruction_id, StrategyMap& strategy_map, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + bool allow_mixed_mesh_shape, bool crash_at_error); + +void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + double replicated_penalty); + +std::vector> +GenerateReshardingCostsAndMissingShardingsForAllOperands( + const HloInstruction* ins, const HloSharding& output_sharding, + const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, + const CallGraph& call_graph, + std::vector>& input_shardings); + +bool LeafVectorsAreConsistent(const std::vector& one, + const std::vector& two, + bool is_reshape); + +std::unique_ptr MaybeFollowInsStrategyGroup( + const StrategyGroup* src_strategy_group, const Shape& shape, + size_t instruction_id, bool have_memory_cost, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + StableHashMap>& + pretrimmed_strategy_map); + +void RemoveInvalidShardingsWithShapes(const Shape& shape, + StrategyGroup* strategy_group, + bool instruction_has_user_sharding); + +void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, + int64_t execution_count); + +// Existing shardings refer to the HloSharding field in the given +// HloInstruction. +void TrimOrGenerateStrategiesBasedOnExistingSharding( + const Shape& output_shape, StrategyGroup* strategy_group, + const StrategyMap& strategy_map, + const std::vector& instructions, + const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, + StableHashMap>& + pretrimmed_strategy_map, + const CallGraph& call_graph, bool strict); + +// Build possible sharding strategies and their costs for all instructions. 
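+// Returns a tuple of (1) the per-instruction StrategyMap, (2) the flattened
+// list of StrategyGroups, and (3) the collected AssociativeDotPairs.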
+StatusOr> +BuildStrategyAndCost(const HloInstructionSequence& sequence, + const HloModule* module, + const absl::flat_hash_map& + instruction_execution_counts, + const InstructionDepthMap& depth_map, + const InstructionBatchDimMap& batch_dim_map, + const AliasMap& alias_map, + const ClusterEnvironment& cluster_env, + AutoShardingOption& option, const CallGraph& call_graph, + const HloCostAnalysis& hlo_cost_analysis, + bool trying_multiple_mesh_shapes); + } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc new file mode 100644 index 00000000000000..8a2468fd4fb2b9 --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -0,0 +1,861 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/experimental/auto_sharding/auto_sharding.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" +#include "xla/hlo/experimental/auto_sharding/cluster_environment.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/sharding_propagation.h" +#include "xla/shape.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace spmd { + +// NOLINTBEGIN(readability/fn_size) +// TODO(zhuohan): Decompose this function into smaller pieces +StatusOr> +BuildStrategyAndCost(const HloInstructionSequence& sequence, + const HloModule* module, + const absl::flat_hash_map& + instruction_execution_counts, + const InstructionDepthMap& depth_map, + const InstructionBatchDimMap& batch_dim_map, + const AliasMap& alias_map, + const ClusterEnvironment& cluster_env, + AutoShardingOption& option, const CallGraph& call_graph, + const HloCostAnalysis& hlo_cost_analysis, + bool trying_multiple_mesh_shapes) { + const Array& device_mesh = 
cluster_env.device_mesh_; + const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + StrategyMap strategy_map; + // This map stores all of the trimmed strategies due to user specified + // sharding. The key is the instruction id, the value is the strategies. This + // is useful when the operand is forced to use a user sharding, and the op + // doesn't need to strictly follow it. We restore the trimmed strategies in + // this situation. + StableHashMap> pretrimmed_strategy_map; + StrategyGroups strategy_groups; + AssociativeDotPairs associative_dot_pairs; + + const std::vector& instructions = sequence.instructions(); + + // Add penalty for replicated tensors + double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) + + cluster_env.AllReduceCost(1, 1)); + + int64_t max_depth = -1; + for (auto iter : depth_map) { + max_depth = std::max(max_depth, iter.second); + } + + // Register strategies and their costs for each instruction. + for (size_t instruction_id = 0; instruction_id < instructions.size(); + ++instruction_id) { + const HloInstruction* ins = instructions[instruction_id]; + VLOG(2) << "instruction_id = " << instruction_id << ": " + << ToAdaptiveString(ins); + std::unique_ptr strategy_group; + + HloOpcode opcode = ins->opcode(); + + bool only_allow_divisible; + if (IsEntryComputationInputOrOutput(module, ins)) { + // With IsEntryComputationInputOrOutput(module, ins) == true, entry + // computation's root instruction may still be unevenly sharded because it + // usually "follows" other instruction's sharding. If the instruction it + // follows is an intermediate instruction, it may be able to choose + // unevenly sharded strategiyes. Usually if we constraint input's sharding + // strategies, outputs would be constrained as welll, but if outputs are + // still unevely sharded in some cases, we need to fix the implementation + // in auto sharding. 
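+      // In short: entry-computation inputs and outputs use
+      // option.only_allow_divisible_input_output, while every other
+      // instruction uses option.only_allow_divisible_intermediate (see the
+      // else branch below).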
+ only_allow_divisible = option.only_allow_divisible_input_output; + } else { + only_allow_divisible = option.only_allow_divisible_intermediate; + } + + switch (opcode) { + case HloOpcode::kParameter: + case HloOpcode::kRngBitGenerator: + case HloOpcode::kRng: { + strategy_group = + CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, + strategy_groups, cluster_env, strategy_map, + option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + option.allow_replicated_parameters) + .value(); + break; + } + case HloOpcode::kConstant: { + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } + case HloOpcode::kScatter: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + // We follow the first operand (the array we're scattering into) + auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); + CHECK(!src_strategy_group->is_tuple); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + HloSharding output_spec = + src_strategy_group->strategies[sid].output_sharding; + std::string name = ToStringSimple(output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); + + std::vector> input_shardings_optional( + {output_spec, std::nullopt, std::nullopt}); + std::vector> resharding_cost = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, output_spec, strategy_map, cluster_env, call_graph, + input_shardings_optional); + + for (const auto& sharding_optional : input_shardings_optional) { + CHECK(sharding_optional.has_value()); + } + + strategy_group->strategies.push_back(ShardingStrategy( + {name, output_spec, compute_cost, communication_cost, memory_cost, + std::move(resharding_cost), input_shardings_optional})); + } + break; + } + case HloOpcode::kGather: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + const HloInstruction* indices = ins->operand(1); + const Shape& shape = ins->shape(); + for (int32_t index_dim = 0; index_dim < indices->shape().rank(); + index_dim++) { + // Shard on indices dimensions that correspond to output dimensions + // TODO(b/220935014) Shard the last dim of output (model dim) with + // AllGather cost and no follow. + if (index_dim == ins->gather_dimension_numbers().index_vector_dim()) { + continue; + } + for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { + // Split only when the tensor shape is divisible by device + // mesh. 
+ if (device_mesh.dim(j) == 1 || + (only_allow_divisible && + !IsDivisible(shape.dimensions(index_dim), + device_mesh.dim(j)))) { + continue; + } + std::string name = absl::StrCat("S", index_dim, " @ ", j); + + HloSharding output_spec = + Tile(shape, {index_dim}, {j}, device_mesh); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(shape) / output_spec.NumTiles(); + std::optional input_spec = + hlo_sharding_util::ReshapeSharding(shape, indices->shape(), + output_spec); + if (!input_spec.has_value()) { // invalid reshape + continue; + } + std::vector> input_shardings_optional( + {std::nullopt, input_spec}); + std::vector> resharding_cost = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, output_spec, strategy_map, cluster_env, call_graph, + input_shardings_optional); + + strategy_group->strategies.push_back(ShardingStrategy( + {name, output_spec, compute_cost, communication_cost, + memory_cost, std::move(resharding_cost), + input_shardings_optional})); + } + } + auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + HloSharding output_spec = + src_strategy_group->strategies[sid].output_sharding; + auto gather_parallel_dims = + hlo_sharding_util::GetGatherParallelBatchDims(*ins, call_graph); + absl::Span operand_parallel_dims; + if (gather_parallel_dims) { + operand_parallel_dims = absl::MakeConstSpan( + gather_parallel_dims->operand_parallel_dims); + } + HloSharding filtered_operand_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + output_spec, operand_parallel_dims); + auto maybe_from_data = hlo_sharding_util:: + GatherOutputShardingFromOperandOperandPassthroughDimensions( + filtered_operand_sharding, *ins); + if (!maybe_from_data) continue; + std::string name = ToStringSimple(*maybe_from_data); + double compute_cost = 0, communication_cost = 0; + double memory_cost = + GetBytes(ins->shape()) / maybe_from_data->NumTiles(); + std::vector> input_shardings_optional( + {*maybe_from_data, std::nullopt}); + std::vector> resharding_cost = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, *maybe_from_data, strategy_map, cluster_env, call_graph, + input_shardings_optional); + strategy_group->strategies.push_back(ShardingStrategy( + {name, *maybe_from_data, compute_cost, communication_cost, + memory_cost, std::move(resharding_cost), + input_shardings_optional})); + } + AddReplicatedStrategy( + ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0, + /* operands_to_consider_all_strategies_for */ {0}); + break; + } + case HloOpcode::kBroadcast: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + + if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { + EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, + cluster_env, strategy_map, strategy_group, + only_allow_divisible, "", call_graph); + } else { + EnumerateAllPartition(ins, ins->shape(), cluster_env.device_mesh_, + cluster_env, strategy_map, strategy_group, + batch_dim_map, only_allow_divisible, call_graph, + /*partitions*/ 2); + if (option.allow_mixed_mesh_shape) { + EnumerateAll1DPartition(ins, ins->shape(), + cluster_env.device_mesh_1d_, cluster_env, + strategy_map, strategy_group, + only_allow_divisible, "1d", call_graph); + } + } + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, replicated_penalty); + + break; + } + case HloOpcode::kReshape: 
{ + strategy_group = CreateReshapeStrategies( + instruction_id, ins, strategy_map, cluster_env, + only_allow_divisible, replicated_penalty, batch_dim_map, option, + strategy_groups); + break; + } + case HloOpcode::kTranspose: + case HloOpcode::kReverse: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + + const HloInstruction* operand = ins->operand(0); + + // Create follow strategies + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; + + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + HloSharding output_spec = Undefined(); + auto input_spec = src_strategy_group->strategies[sid].output_sharding; + if (opcode == HloOpcode::kTranspose) { + output_spec = hlo_sharding_util::TransposeSharding( + input_spec, ins->dimensions()); + } else { + output_spec = hlo_sharding_util::ReverseSharding(input_spec, + ins->dimensions()); + } + + std::string name = ToStringSimple(output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); + auto resharding_costs = ReshardingCostVector( + src_strategy_group, operand->shape(), input_spec, cluster_env); + strategy_group->strategies.push_back( + ShardingStrategy({name, + output_spec, + compute_cost, + communication_cost, + memory_cost, + {resharding_costs}, + {input_spec}})); + } + break; + } + case HloOpcode::kPad: + case HloOpcode::kSlice: + case HloOpcode::kConcatenate: // TODO(zhuohan): revisit concatenate + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + int64_t follow_idx; + switch (opcode) { + // TODO(yuemmawang) Re-evaluate the follow_idx choices for the + // following 3. + case HloOpcode::kPad: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kConcatenate: + // Follow the operand according to the follow heuristics + follow_idx = ChooseOperandToFollow(strategy_map, depth_map, + alias_map, max_depth, ins) + .first; + break; + // The following types are better to follow the first operand. + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + follow_idx = 0; + break; + default: + LOG(FATAL) << "Selecting follow index encounters an unhandled " + "instruction type: " + + ins->ToShortString(); + } + // Create follow strategies + const HloInstruction* operand = ins->operand(follow_idx); + StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; + + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + std::optional output_spec; + HloSharding input_spec = + src_strategy_group->strategies[sid].output_sharding; + + // Find output shardings. 
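+          // Dimwise ops (pad, slice, concatenate, dynamic-slice and
+          // dynamic-update-slice) keep the operand sharding dimension by
+          // dimension via PropagateDimwiseSharding; reduce-window and
+          // select-and-scatter propagate it through the window configuration
+          // via PropagateReduceWindowSharding.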
+ switch (opcode) { + case HloOpcode::kPad: + case HloOpcode::kSlice: + case HloOpcode::kConcatenate: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + output_spec = PropagateDimwiseSharding( + input_spec, operand->shape(), ins->shape()); + break; + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + output_spec = PropagateReduceWindowSharding( + input_spec, operand->shape(), ins->window()); + break; + default: + LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); + } + + // Get a list of input shardings, each corresponds to an operand. + std::vector> input_shardings; + for (int64_t k = 0; k < ins->operand_count(); ++k) { + if (k == follow_idx || + ToString(ins->operand(k)->shape().dimensions()) == + ToString(operand->shape().dimensions())) { + input_shardings.push_back(input_spec); + } else { + input_shardings.push_back(std::nullopt); + } + } + if (!output_spec.has_value()) { + continue; + } + + std::string name = ToStringSimple(*output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); + std::vector> resharding_costs = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, *output_spec, strategy_map, cluster_env, call_graph, + input_shardings); + + strategy_group->strategies.push_back( + ShardingStrategy({name, + *output_spec, + compute_cost, + communication_cost, + memory_cost, + std::move(resharding_costs), + {input_spec}})); + } + + if (strategy_group->strategies.empty()) { + strategy_group->following = nullptr; + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + } + + break; + } + case HloOpcode::kOptimizationBarrier: { + auto operand_strategies = strategy_map.at(ins->operand(0)).get(); + strategy_group = MaybeFollowInsStrategyGroup( + operand_strategies, ins->shape(), instruction_id, + /* have_memory_cost */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + break; + } + case HloOpcode::kBitcast: { + if (ins->shape() == ins->operand(0)->shape()) { + strategy_group = CreateElementwiseOperatorStrategies( + instruction_id, ins, strategy_map, cluster_env, depth_map, + alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, + associative_dot_pairs); + } else { + strategy_group = CreateReshapeStrategies( + instruction_id, ins, strategy_map, cluster_env, + only_allow_divisible, replicated_penalty, batch_dim_map, option, + strategy_groups); + } + break; + } + // Unary elementwise operations. 
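+      // All element-wise cases below (unary, binary and ternary) fall through
+      // to a single handler, CreateElementwiseOperatorStrategies, which
+      // follows the sharding strategies of a chosen operand.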
+ case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRoundNearestEven: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kTan: + case HloOpcode::kTanh: + // Binary elementwise operations + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kCompare: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kStochasticConvert: + // Ternary elementwise operations. + case HloOpcode::kSelect: + case HloOpcode::kClamp: { + strategy_group = CreateElementwiseOperatorStrategies( + instruction_id, ins, strategy_map, cluster_env, depth_map, + alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, + associative_dot_pairs); + break; + } + case HloOpcode::kReduce: { + auto strategies_status = FollowReduceStrategy( + ins, ins->shape(), ins->operand(0), ins->operand(1), instruction_id, + strategy_map, strategy_groups, cluster_env, + option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes); + if (strategies_status.ok()) { + strategy_group = std::move(strategies_status.value()); + } else { + return strategies_status.status(); + } + break; + } + case HloOpcode::kDot: { + TF_RETURN_IF_ERROR(HandleDot( + strategy_group, strategy_groups, strategy_map, ins, instruction_id, + cluster_env, batch_dim_map, option, call_graph)); + if (option.allow_replicated_strategy_for_dot_and_conv) { + AddReplicatedStrategy( + ins, ins->shape(), cluster_env, strategy_map, strategy_group, + GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, + sequence, hlo_cost_analysis)); + } + break; + } + case HloOpcode::kConvolution: { + TF_RETURN_IF_ERROR(HandleConv( + strategy_group, strategy_groups, strategy_map, ins, instruction_id, + cluster_env, batch_dim_map, option, call_graph)); + if (option.allow_replicated_strategy_for_dot_and_conv) { + AddReplicatedStrategy( + ins, ins->shape(), cluster_env, strategy_map, strategy_group, + GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, + sequence, hlo_cost_analysis)); + } + break; + } + case HloOpcode::kRngGetAndUpdateState: { + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } + case HloOpcode::kIota: { + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + strategy_groups); + if (cluster_env.IsDeviceMesh1D()) { + EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, + strategy_map, strategy_group, + only_allow_divisible, "", call_graph); + } + if 
(cluster_env.IsDeviceMesh2D()) { + // Split 2 dims + EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, + strategy_map, strategy_group, batch_dim_map, + only_allow_divisible, call_graph, /*parts*/ 2); + } + if (cluster_env.IsDeviceMesh3D()) { + // Split 3 dims + EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, + strategy_map, strategy_group, batch_dim_map, + only_allow_divisible, call_graph, /*parts*/ 3); + } + if (cluster_env.IsDeviceMesh2D() && option.allow_mixed_mesh_shape) { + // Split 1 dim, but for 1d flattened version of the 2d mesh + // For example, when the mesh shape is (2, 4), we add strategies for + // mesh shape (1, 8) here in addition. + EnumerateAll1DPartition(ins, ins->shape(), device_mesh_1d, + cluster_env, strategy_map, strategy_group, + only_allow_divisible, " 1d", call_graph); + } + + // Replicate + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, replicated_penalty * 5); + + break; + } + case HloOpcode::kTuple: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->operand_count()); + for (size_t i = 0; i < ins->operand_count(); ++i) { + const HloInstruction* operand = ins->operand(i); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group, operand->shape(), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + break; + } + case HloOpcode::kGetTupleElement: { + const HloInstruction* operand = ins->operand(0); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(src_strategy_group->is_tuple); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), + instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + break; + } + case HloOpcode::kCustomCall: { + auto generate_non_following_strategies = + [&](bool only_replicated, + absl::flat_hash_set + operands_to_consider_all_strategies_for = {}) { + if (ins->shape().IsTuple()) { + if (only_replicated) { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve( + ins->shape().tuple_shapes_size()); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); + ++i) { + std::unique_ptr child_strategies = + CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), + cluster_env, strategy_map, + child_strategies, replicated_penalty); + strategy_group->childs.push_back( + std::move(child_strategies)); + } + } else { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, + batch_dim_map, call_graph, only_allow_divisible, true) + .value(); + } + } else { + if (only_replicated) { + strategy_group = CreateLeafStrategyGroup( + instruction_id, ins, strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, + strategy_map, strategy_group, + replicated_penalty); + } else { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, + batch_dim_map, call_graph, only_allow_divisible, 
true) + .value(); + } + } + }; + + if (IsCustomCallMarker(ins)) { + const HloInstruction* operand = ins->operand(0); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(src_strategy_group->is_tuple); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group, ins->shape(), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + } else if (ins->has_sharding()) { + generate_non_following_strategies(false); + } else if (OutputInputSameShapes(ins)) { + auto* partitioner = + GetCustomCallPartitioner(ins->custom_call_target()); + if (partitioner && partitioner->IsCustomCallShardable(ins)) { + // Follows operand 0's strategies if this custom-call op is + // shardable and has the same input and output sizes. + const HloInstruction* operand = ins->operand(0); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group, ins->shape(), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + } + } else if (IsTopKCustomCall(ins)) { + generate_non_following_strategies(false, {0}); + } else { + // TODO (b/258723035) Handle CustomCall ops for GPUs in a better way. + generate_non_following_strategies(true); + } + break; + } + case HloOpcode::kWhile: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + const StrategyGroup* src_strategy_group = + strategy_map.at(ins->operand(0)).get(); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[i].get(), + ins->shape().tuple_shapes().at(i), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + break; + } + case HloOpcode::kConditional: + case HloOpcode::kInfeed: + case HloOpcode::kSort: { + strategy_group = + CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, + strategy_groups, cluster_env, strategy_map, + option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /*create_replicated_strategies*/ true) + .value(); + break; + } + case HloOpcode::kOutfeed: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, replicated_penalty); + break; + } + case HloOpcode::kAfterAll: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, replicated_penalty); + break; + } + default: + LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); + } + RemoveDuplicatedStrategy(strategy_group); + if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { + // Finds the sharding strategy that aligns with the given sharding spec + // Do not merge nodes if this one instruction has annotations. 
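+      // TrimOrGenerateStrategiesBasedOnExistingSharding keeps only the
+      // strategies that are compatible with ins->sharding(), generating a
+      // matching strategy if none of the enumerated ones fits.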
+ TrimOrGenerateStrategiesBasedOnExistingSharding( + ins->shape(), strategy_group.get(), strategy_map, instructions, + ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph, + option.nd_sharding_iteratively_strict_search_space); + } + if (!strategy_group->is_tuple && strategy_group->following) { + if (!LeafVectorsAreConsistent( + strategy_group->strategies, strategy_group->following->strategies, + /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { + // It confuses the solver if two instructions have different number of + // sharding strategies but share the same ILP variable. The solver + // would run much longer and/or return infeasible solutions. + // So if two strategies' strategiess are inconsistent, we unfollow + // them. + strategy_group->following = nullptr; + } + } else if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); i++) { + if (strategy_group->childs.at(i)->following && + !LeafVectorsAreConsistent( + strategy_group->childs.at(i)->strategies, + strategy_group->childs.at(i)->following->strategies, + /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { + strategy_group->childs.at(i)->following = nullptr; + } + } + } + RemoveInvalidShardingsWithShapes( + ins->shape(), strategy_group.get(), + /* instruction_has_user_sharding */ ins->has_sharding()); + + if (instruction_execution_counts.contains(ins)) { + ScaleCostsWithExecutionCounts(strategy_group.get(), + instruction_execution_counts.at(ins)); + } else { + VLOG(5) << "No execution count available for " << ins->name(); + } + XLA_VLOG_LINES(2, + absl::StrCat("strategies:\n", strategy_group->ToString())); + + // Debug options: forcibly set the strategy of some instructions. + if (option.force_strategy) { + std::vector inst_indices = option.force_strategy_inst_indices; + std::vector stra_names = option.force_strategy_stra_names; + CHECK_EQ(inst_indices.size(), stra_names.size()); + auto it = absl::c_find(inst_indices, strategy_group->node_idx); + if (it != inst_indices.end()) { + CHECK(!strategy_group->is_tuple); + std::vector new_strategies; + int64_t idx = it - inst_indices.begin(); + for (const auto& stra : strategy_group->strategies) { + if (stra.name == stra_names[idx]) { + new_strategies.push_back(stra); + } + } + strategy_group->strategies = std::move(new_strategies); + } + } + + // When trying out multiple mesh shapes in the presence of user specified + // sharding (as in + // AutoShardingTest.AutoShardingKeepUserShardingInputOutput), there may be a + // situation when we cannot generate any shardings for an instruction when + // the mesh shape we're trying does not match with the mesh shape used in + // user specified shardings. So we disable the check in that situation. + if (!trying_multiple_mesh_shapes) { + CHECK(strategy_group->is_tuple || !strategy_group->strategies.empty()) + << ins->ToString() << " does not have any valid strategies."; + } else if (!(strategy_group->is_tuple || + !strategy_group->strategies.empty())) { + return Status(absl::StatusCode::kFailedPrecondition, + "Could not generate any shardings for an instruction due " + "to mismatched mesh shapes."); + } + // Checks the shape of resharding_costs is valid. It will check fail if the + // shape is not as expected. + // CheckReshardingCostsShape(strategies.get()); + CheckMemoryCosts(strategy_group.get(), ins->shape()); + strategy_map[ins] = std::move(strategy_group); + } // end of for loop + + // If gradient accumulation is used, adjust the cost of all-reduce for + // gradient synchronization. 
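+  // With gradient accumulation, the all-reduce for gradient synchronization
+  // runs once per accumulation cycle instead of once per micro-batch, so its
+  // per-step communication cost is divided by the number of micro-batches.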
+ if (option.grad_acc_num_micro_batches > 1) { + // find gradient-computation instructions + std::vector grad_insts = + GetGradientComputationInstructions(instructions); + for (const HloInstruction* inst : grad_insts) { + StrategyGroup* stra_vector = strategy_map[inst].get(); + CHECK(!stra_vector->is_tuple); + + for (auto& stra : stra_vector->strategies) { + if (absl::StrContains(stra.name, "allreduce")) { + stra.communication_cost /= option.grad_acc_num_micro_batches; + } + } + } + } + + return std::make_tuple(std::move(strategy_map), std::move(strategy_groups), + std::move(associative_dot_pairs)); +} + +// NOLINTEND + +} // namespace spmd +} // namespace xla From d2a94e9ef9e32f1d7e463de76c2a3ba617c07d01 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Thu, 30 Nov 2023 12:51:20 -0800 Subject: [PATCH 250/381] [stream_executor][NFC] Cleanup topk_kernel refactor PiperOrigin-RevId: 586759885 --- third_party/xla/xla/service/gpu/runtime/topk_kernel.cc | 2 -- third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc b/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc index 7a65107ea9e6eb..3bc3dfc585b6b6 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc +++ b/third_party/xla/xla/service/gpu/runtime/topk_kernel.cc @@ -38,8 +38,6 @@ namespace xla::gpu { namespace { -using se::gpu::GpuStreamHandle; - size_t NumThreads(size_t n, size_t k, size_t batch_size) { // Estimate number of threads per block that can run concurrently given the // register footprint. diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc b/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc index 9bc2ccff274f8e..32d776574cbae9 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" @@ -212,7 +213,7 @@ void BM_SmallTopk(benchmark::State& state) { CHECK_OK(stream.BlockHostUntilDone()); auto timer_duration = timer.value().GetElapsedDuration(); CHECK_OK(timer_duration.status()); - state.SetIterationTime(ToDoubleMicroseconds(timer_duration.value())); + state.SetIterationTime(absl::ToDoubleMicroseconds(timer_duration.value())); } size_t items_processed = batch_size * n * state.iterations(); state.SetItemsProcessed(items_processed); From fec118a8b599d22c0b33caed1bde99be7b153f55 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 30 Nov 2023 12:52:14 -0800 Subject: [PATCH 251/381] [xla:ffi] Split BufferBase into a weakly typed BufferBase and a templated Buffer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also added BufferRN aliases for buffers with common ranks. 
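A minimal usage sketch (mirroring the test changes below; the handler names
here are illustrative only):

    // Weakly typed: buf.dtype, buf.data and buf.dimensions carry the runtime
    // type information and must be checked by the handler itself.
    auto weak = Ffi::Bind().Arg<BufferBase>().To([](auto buf) {
      return Error::Success();
    });

    // Strongly typed: decoding fails unless the argument really is a rank-2
    // f32 buffer, so the handler body can rely on the static type.
    auto strong = Ffi::Bind().Arg<BufferR2<DataType::F32>>().To([](auto buf) {
      return Error::Success();
    });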
name cpu/op BM_BufferBaseArgX1 6.51ns ± 4% BM_BufferBaseArgX4 12.8ns ±13% BM_BufferArgX1 6.71ns ± 1% BM_BufferArgX4 12.9ns ± 9% BM_TupleOfI32Attrs 45.6ns ± 1% name time/op BM_BufferBaseArgX1 6.51ns ± 4% BM_BufferBaseArgX4 12.8ns ±13% BM_BufferArgX1 6.71ns ± 1% BM_BufferArgX4 12.9ns ± 9% BM_TupleOfI32Attrs 45.6ns ± 1% PiperOrigin-RevId: 586760169 --- third_party/xla/xla/ffi/api/ffi.h | 37 +++++++++-- third_party/xla/xla/ffi/api/ffi_test.cc | 88 +++++++++++++++++++++---- 2 files changed, 108 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index a3e4b1fbbdbde0..f44f6327091b5a 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -116,6 +116,12 @@ class Error { // Arguments //===----------------------------------------------------------------------===// +struct BufferBase { + DataType dtype; + void* data; + Span dimensions; +}; + namespace internal { // A workaround for the fact that a static_assertion can be evaluated @@ -149,11 +155,19 @@ inline constexpr size_t kDynamicRank = std::numeric_limits::max(); } // namespace internal template -struct BufferBase { +struct Buffer { typename internal::PtrType::Type* data; Span dimensions; }; +// clang-format off +template using BufferR0 = Buffer; +template using BufferR1 = Buffer; +template using BufferR2 = Buffer; +template using BufferR3 = Buffer; +template using BufferR4 = Buffer; +// clang-format on + //===----------------------------------------------------------------------===// // Arguments decoding //===----------------------------------------------------------------------===// @@ -165,10 +179,25 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_ArgType type) { } } +template <> +struct ArgDecoding { + XLA_ATTRIBUTE_ALWAYS_INLINE + static std::optional Decode(XLA_FFI_ArgType type, void* arg, + DiagnosticEngine& diagnostic) { + if (type != XLA_FFI_ArgType_BUFFER) { + return diagnostic.Emit("Wrong argument type: expected ") + << XLA_FFI_ArgType_BUFFER << " but got " << type; + } + auto* buf = reinterpret_cast(arg); + return BufferBase{static_cast(buf->dtype), buf->data, + Span(buf->dims, buf->rank)}; + } +}; + template -struct ArgDecoding> { +struct ArgDecoding> { XLA_ATTRIBUTE_ALWAYS_INLINE - static std::optional> Decode( + static std::optional> Decode( XLA_FFI_ArgType type, void* arg, DiagnosticEngine& diagnostic) { if (type != XLA_FFI_ArgType_BUFFER) { return diagnostic.Emit("Wrong argument type: expected ") @@ -189,7 +218,7 @@ struct ArgDecoding> { return std::nullopt; } } - return BufferBase{data, Span(buf->dims, rank)}; + return Buffer{data, Span(buf->dims, rank)}; } }; diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index 5e081313d59bef..caa5fc5e25971b 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -63,6 +63,24 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::BF16), encoded(DataType::BF16)); } +TEST(FfiTest, BufferBaseArgument) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg().To([&](auto buffer) { + EXPECT_EQ(buffer.data, storage.data()); + EXPECT_EQ(buffer.dimensions.size(), 2); + return Error::Success(); + }); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); 
+} + TEST(FfiTest, BufferArgument) { std::vector storage(4, 0.0f); se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); @@ -72,7 +90,7 @@ TEST(FfiTest, BufferArgument) { auto call_frame = builder.Build(); auto handler = - Ffi::Bind().Arg>().To([&](auto buffer) { + Ffi::Bind().Arg>().To([&](auto buffer) { EXPECT_EQ(buffer.data, storage.data()); EXPECT_EQ(buffer.dimensions.size(), 2); return Error::Success(); @@ -86,7 +104,7 @@ TEST(FfiTest, MissingBufferArgument) { CallFrameBuilder builder; auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -102,7 +120,7 @@ TEST(FfiTest, WrongRankBufferArgument) { builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -119,7 +137,7 @@ TEST(FfiTest, WrongTypeBufferArgument) { builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -144,6 +162,51 @@ static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { return builder; } +//===----------------------------------------------------------------------===// +// BM_BufferBaseArgX1 +//===----------------------------------------------------------------------===// + +void BM_BufferBaseArgX1(benchmark::State& state) { + auto call_frame = WithBufferArgs(1).Build(); + + auto handler = Ffi::Bind().Arg().To([](auto buffer) { + benchmark::DoNotOptimize(buffer); + return Error::Success(); + }); + for (auto _ : state) { + CHECK_OK(Call(*handler, call_frame)); + } +} + +BENCHMARK(BM_BufferBaseArgX1); + +//===----------------------------------------------------------------------===// +// BM_BufferBaseArgX4 +//===----------------------------------------------------------------------===// + +void BM_BufferBaseArgX4(benchmark::State& state) { + auto call_frame = WithBufferArgs(4).Build(); + + auto handler = Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .To([](auto b0, auto b1, auto b2, auto b3) { + benchmark::DoNotOptimize(b0); + benchmark::DoNotOptimize(b1); + benchmark::DoNotOptimize(b2); + benchmark::DoNotOptimize(b3); + return Error::Success(); + }); + + for (auto _ : state) { + CHECK_OK(Call(*handler, call_frame)); + } +} + +BENCHMARK(BM_BufferBaseArgX4); + //===----------------------------------------------------------------------===// // BM_BufferArgX1 //===----------------------------------------------------------------------===// @@ -151,11 +214,10 @@ static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { void BM_BufferArgX1(benchmark::State& state) { auto call_frame = WithBufferArgs(1).Build(); - auto handler = - Ffi::Bind().Arg>().To([](auto buffer) { - benchmark::DoNotOptimize(buffer); - return Error::Success(); - }); + auto handler = Ffi::Bind().Arg>().To([](auto buffer) { + benchmark::DoNotOptimize(buffer); + return Error::Success(); + }); for (auto _ : state) { CHECK_OK(Call(*handler, call_frame)); } @@ -171,10 +233,10 @@ void BM_BufferArgX4(benchmark::State& state) { auto call_frame = WithBufferArgs(4).Build(); auto handler = Ffi::Bind() - .Arg>() - .Arg>() - .Arg>() - .Arg>() + .Arg>() + .Arg>() + 
.Arg>() + .Arg>() .To([](auto b0, auto b1, auto b2, auto b3) { benchmark::DoNotOptimize(b0); benchmark::DoNotOptimize(b1); From 14b24032a6b3f8be5297159c9f50b6f3ee1d35bc Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Thu, 30 Nov 2023 13:05:12 -0800 Subject: [PATCH 252/381] [HloValueSemanticsAnalysis] Handle OptimizationBarrier. PiperOrigin-RevId: 586763511 --- .../xla/xla/service/hlo_value_semantics_analysis.cc | 9 +++++++++ .../xla/xla/service/hlo_value_semantics_analysis.h | 1 + 2 files changed, 10 insertions(+) diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index 2e7f98cf5f357f..5720bfd49b6fba 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -1373,4 +1373,13 @@ Status HloValueSemanticsPropagation::HandleDomain(HloInstruction* domain) { return OkStatus(); } +Status HloValueSemanticsPropagation::HandleOptimizationBarrier( + HloInstruction* opt_barrier) { + HloInstruction* opt_barrier_operand = opt_barrier->mutable_operand(0); + const ShapeTree& operand_semantics = + analysis_->GetInstructionSemantics(opt_barrier_operand); + analysis_->DeepCopyHloValueSemantics(opt_barrier, operand_semantics); + return OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index 7037951eddaaf0..8cea61acf0e7ec 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -240,6 +240,7 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { Status HandleAsyncDone(HloInstruction* async_done) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleDomain(HloInstruction* domain) override; + Status HandleOptimizationBarrier(HloInstruction* opt_barrier) override; Status HandleRngBitGenerator(HloInstruction* rng_bit_generator) override; protected: From 2b918e2da507e500dd27baef7037773f09104ca8 Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Thu, 30 Nov 2023 13:15:42 -0800 Subject: [PATCH 253/381] No public description PiperOrigin-RevId: 586766512 --- tensorflow/core/data/dataset_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index fd785c0f51647f..3aa6c659b78f00 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -1006,7 +1006,7 @@ REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<50>, REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("reduce_array_record_dataset_memory_usage", - RandomJobSamplePercentage<0>, AllTasks); + RandomJobSamplePercentage<1>, AllTasks); REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<0>, AllTasks); } // namespace From 20c05898b344c0fe17fab13cc355b1e94e3864ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Thu, 30 Nov 2023 13:26:49 -0800 Subject: [PATCH 254/381] [XLA:GPU] Fix fusion parameter limit again and remove hard check I removed the hard check, because this seems to break frequently, but it's not really an error. Also had to fix some tests, because the fix changes the order of parameters in the fusion. Also fixed TensorIterationSpec::IterationSpecFragment::ToString() during the debugging. 
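For illustration, the adjusted parameter accounting is:

    NumAddedParameters(parameter)             == 0  // was operand_count() - 1 == -1
    NumAddedParameters(non-scalar constant)   == 0
    NumAddedParameters(5-operand concatenate) == 4  // operand_count() - 1

so parameters at the fusion boundary no longer reduce the running count, and
exceeding TritonFusionAnalysis::kMaxParameterPerDotScope is now reported with
LOG(WARNING) instead of failing the rewrite with TF_RET_CHECK.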
PiperOrigin-RevId: 586770496 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 15 +++---- .../service/gpu/gemm_rewriter_triton_test.cc | 41 +++++++++++++++---- .../service/gpu/triton_tiling_propagation.cc | 3 +- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index f618411e7f49de..55dc53d87c409d 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -134,8 +134,9 @@ void FuseDotOnly(HloInstruction& hlo, OldToNewHloMap& output_old_to_new_map, // an input. int64_t NumAddedParameters(const HloInstruction& hlo) { // Non-scalar constant is equivalent to a parameter: one input, one output. - if (hlo.opcode() == HloOpcode::kConstant && - !ShapeUtil::IsScalar(hlo.shape())) { + if (hlo.opcode() == HloOpcode::kParameter || + (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape()))) { return 0; } // All other instructions add all own inputs and remove own single output. @@ -153,7 +154,7 @@ void TryToFuseWithInputsRecursively(HloInstruction& root, HloComputation::Builder& builder) { // Instructions at the fusion edge that can either get fused too or // become parameters of the fusion. Used to track the number of parameters. - absl::flat_hash_set inputs; + absl::flat_hash_set inputs = {&root}; // Traverse all connected instructions that could be fused, analyze them and // collect ones that will be fused. absl::flat_hash_set to_fuse_set; @@ -251,10 +252,10 @@ StatusOr FuseDot(HloInstruction& dot, gpu_version, context, old_to_new_map, fusion_inputs, builder); const int new_parameters = fusion_inputs.size() - operand_count_before; - TF_RET_CHECK(new_parameters <= - TritonFusionAnalysis::kMaxParameterPerDotScope) - << "Too many new parameters: " << new_parameters << " > " - << TritonFusionAnalysis::kMaxParameterPerDotScope; + if (new_parameters > TritonFusionAnalysis::kMaxParameterPerDotScope) { + LOG(WARNING) << "Too many new parameters fused: " << new_parameters + << " > " << TritonFusionAnalysis::kMaxParameterPerDotScope; + } return context; }; diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index e54bb3497825fa..7946ca5f7e1087 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -547,6 +547,33 @@ ENTRY e { TritonFusionAnalysis::kMaxParameterPerDotScope + 1); } +TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseTooManyParametersForConcat) { + static_assert(TritonFusionAnalysis::kMaxParameterPerDotScope == 4, + "We have to update this test."); + // The concat shouldn't overgo the allowed parameter limit. 
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[3,3]{1,0} parameter(0) + b = f32[3,3]{1,0} parameter(1) + c = f32[3,3]{1,0} parameter(2) + d = f32[3,3]{1,0} parameter(3) + e = f32[3,3]{1,0} parameter(4) + f = f16[3,3]{1,0} parameter(5) + concat = f32[15,3]{1,0} concatenate(a, b, c, d, e), dimensions={0} + convert = f32[3,3]{1,0} convert(f) + ROOT dot = f32[15,3]{1,0} dot(concat, convert), lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + TritonFusionAnalysis::kMaxParameterPerDotScope + 1); +} + TEST_F(GemmRewriterTritonLevel2Test, InstructionsReachableFromMultipleOperandsAreHandledCorrectly) { static_assert(TritonFusionAnalysis::kMaxParameterPerDotScope == 4, @@ -827,34 +854,34 @@ e { TritonFusionAnalysis::Execute(*computation)); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(0), 0), + computation->parameter_instruction(1), 0), ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153, /*slice_start=*/0, /*sliced_count=*/153, /*subfragments=*/ElementsAre(153)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(0), 1), + computation->parameter_instruction(1), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536, /*slice_start=*/0, /*sliced_count=*/1536, /*subfragments=*/ElementsAre(1536)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(1), 0), + computation->parameter_instruction(2), 0), ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153, /*slice_start=*/0, /*sliced_count=*/153, /*subfragments=*/ElementsAre(153)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(1), 1), + computation->parameter_instruction(2), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128, /*slice_start=*/-1536, /*sliced_count=*/128, /*subfragments=*/ElementsAre(128)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(2), 0), + computation->parameter_instruction(3), 0), ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153, /*slice_start=*/0, /*sliced_count=*/153, /*subfragments=*/ElementsAre(153)))); EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(2), 1), + computation->parameter_instruction(3), 1), ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256, /*slice_start=*/-1536 - 128, /*sliced_count=*/256, @@ -945,7 +972,7 @@ e { .Run(module.get()) .value()); EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Concatenate(), m::Parameter(), + GmockMatch((m::Fusion(m::Parameter(), m::Concatenate(), m::Parameter(), m::Parameter())))); } diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 4db8c75024798e..61c9e0e05ec6cd 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -72,7 +72,8 @@ bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { 
std::string TensorIterationSpec::IterationSpecFragment::ToString() const { return absl::StrCat("{stride=", stride, ", count=", count, - ", slice_start=", slice_start, ", subfragments=[", + ", slice_start=", slice_start, + ", sliced_count=", sliced_count, ", subfragments=[", absl::StrJoin(subfragments, ", "), "]}"); } From 870c7fa947bdcee0c8e4da1e3be34695faa7073a Mon Sep 17 00:00:00 2001 From: Wilsin Gosti Date: Thu, 30 Nov 2023 13:28:13 -0800 Subject: [PATCH 255/381] #tf-data Add `buffer_output_elements` and `prefetch_input_elements` to `ParallelInterleaveV4` in xprof. PiperOrigin-RevId: 586770929 --- .../core/kernels/data/parallel_interleave_dataset_op.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 2ae28298d81f27..41e98622f18e5f 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -226,7 +226,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { {"cycle_length", strings::Printf("%lld", static_cast(cycle_length))}, {"deterministic", - deterministic.IsNondeterministic() ? "false" : "true"}}) { + deterministic.IsNondeterministic() ? "false" : "true"}, + {"buffer_output_elements", + strings::Printf("%lld", + static_cast(buffer_output_elements_))}, + {"prefetch_input_elements", + strings::Printf( + "%lld", static_cast(prefetch_input_elements_))}}) { input_->Ref(); } From d0366fc5d2d2de526db0c093f209f6cf505afcea Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 30 Nov 2023 13:36:31 -0800 Subject: [PATCH 256/381] Import ref_variable.py inline in variable_v1.py to ensure the tf1 variable creation function is always available. 
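For readers following along, the mechanism here is the usual trick of deferring an import into the
function body: the circular dependency between the two modules is broken at import time, while the
implementation is still guaranteed to be resolvable by the time the creator function is actually
called, so it no longer raises NotImplementedError when ref_variable has not been imported. A
minimal two-module sketch of the pattern (module and function names are illustrative, not the real
TensorFlow targets):

    # front.py (stands in for variable_v1.py)
    def create(**kwargs):
      # Importing inside the function breaks the front <-> impl import
      # cycle; by the time create() runs, impl can be imported safely.
      import impl  # pylint: disable=g-import-not-at-top
      return impl.create_impl(**kwargs)

    # impl.py (stands in for ref_variable.py)
    import front  # safe: front no longer needs impl at module load time

    def create_impl(**kwargs):
      return dict(kwargs)

The in-function import replaces the previous module-level monkey-patching
(set_variable_from_proto_fn and reassigning default_variable_creator), which only worked if
ref_variable happened to be imported first.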
PiperOrigin-RevId: 586773267 --- tensorflow/python/ops/BUILD | 27 +++++++++++---------------- tensorflow/python/ops/ref_variable.py | 6 ------ tensorflow/python/ops/variable_v1.py | 21 ++++++++------------- 3 files changed, 19 insertions(+), 35 deletions(-) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 4b1892b922edc9..784a8b306412ad 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3167,19 +3167,27 @@ py_strict_library( ], ) -py_strict_library( +alias( name = "ref_variable", - srcs = ["ref_variable.py"], + actual = ":variable_v1", +) + +py_strict_library( + name = "variable_v1", + srcs = [ + "ref_variable.py", + "variable_v1.py", + ], srcs_version = "PY3", deps = [ ":array_ops", ":array_ops_gen", + ":cond", ":resource_variable_ops", ":resource_variables_toggle", ":state_ops", ":state_ops_gen", ":variable_scope", - ":variable_v1", ":variables", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", @@ -3193,19 +3201,6 @@ py_strict_library( "//tensorflow/python/types:core", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", - ], -) - -py_strict_library( - name = "variable_v1", - srcs = ["variable_v1.py"], - srcs_version = "PY3", - deps = [ - ":cond", - ":state_ops", - ":variable_scope", - ":variables", - "//tensorflow/python/framework:ops", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_should_use", ], diff --git a/tensorflow/python/ops/ref_variable.py b/tensorflow/python/ops/ref_variable.py index 7e51288b48ef9d..241275b44da30f 100644 --- a/tensorflow/python/ops/ref_variable.py +++ b/tensorflow/python/ops/ref_variable.py @@ -97,9 +97,6 @@ def default_variable_creator(next_creator=None, **kwargs): shape=shape) -variable_v1.default_variable_creator = default_variable_creator - - def _to_proto_fn(v, export_scope=None): """Converts Variable and ResourceVariable to VariableDef for collections.""" return v.to_proto(export_scope=export_scope) @@ -1346,6 +1343,3 @@ def _restore_from_tensors(self, restored_tensors): # allowing instances of the class to be used as tensors. 
tensor_conversion_registry.register_tensor_conversion_function( RefVariable, RefVariable._TensorConversionFunction) # pylint: disable=protected-access - - -variable_v1.set_variable_from_proto_fn(RefVariable) diff --git a/tensorflow/python/ops/variable_v1.py b/tensorflow/python/ops/variable_v1.py index d7d4f0e5daeee9..f3cca80758e5cb 100644 --- a/tensorflow/python/ops/variable_v1.py +++ b/tensorflow/python/ops/variable_v1.py @@ -23,15 +23,6 @@ from tensorflow.python.util.tf_export import tf_export -_variable_from_proto_fn = None - - -def set_variable_from_proto_fn(variable_from_proto_fn): - """Set the variable class that variable proto defs will be converted to.""" - global _variable_from_proto_fn - _variable_from_proto_fn = variable_from_proto_fn - - @tf_export(v1=["is_variable_initialized"]) @tf_should_use.should_use_result def is_variable_initialized(variable): @@ -47,9 +38,12 @@ def is_variable_initialized(variable): return state_ops.is_variable_initialized(variable) -def default_variable_creator(_, **kwds): - del kwds - raise NotImplementedError("ref_variable needs to be imported") +def default_variable_creator(next_creator=None, **kwds): + from tensorflow.python.ops import ref_variable # pylint: disable=g-import-not-at-top + + return ref_variable.default_variable_creator( + next_creator=next_creator, **kwds + ) @tf_export(v1=["Variable"]) @@ -269,7 +263,8 @@ def initialized_value(self): @staticmethod def from_proto(variable_def, import_scope=None): - return _variable_from_proto_fn( + from tensorflow.python.ops import ref_variable # pylint: disable=g-import-not-at-top + return ref_variable.RefVariable( variable_def=variable_def, import_scope=import_scope) @classmethod From eba052ccadbc7395a5c72046b6f22bb045aaebad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 13:39:48 -0800 Subject: [PATCH 257/381] Modifies the return type of `WriteToString` to include whether it's saved as a chunked protobuf or regular protobuf. Callers need this information to load the serialized `std::string` correctly. PiperOrigin-RevId: 586774412 --- .../saved_model/image_format/internal_api.cc | 5 +++-- .../saved_model/image_format/internal_api.h | 8 +++++-- tensorflow/tools/proto_splitter/cc/BUILD | 1 + .../cc/composable_splitter_base.cc | 21 ++++++++++++------- .../cc/composable_splitter_base.h | 7 +++++-- .../cc/composable_splitter_test.cc | 14 +++++++++---- 6 files changed, 39 insertions(+), 17 deletions(-) diff --git a/tensorflow/cc/saved_model/image_format/internal_api.cc b/tensorflow/cc/saved_model/image_format/internal_api.cc index f9eea13682765b..db38d1786e59ea 100644 --- a/tensorflow/cc/saved_model/image_format/internal_api.cc +++ b/tensorflow/cc/saved_model/image_format/internal_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/image_format/internal_api.h" #include +#include #include "absl/log/log.h" #include "absl/status/status.h" @@ -105,7 +106,7 @@ absl::Status WriteSavedModel(SavedModel* saved_model_proto, #endif } -absl::StatusOr WriteSavedModelToString( +absl::StatusOr> WriteSavedModelToString( SavedModel* saved_model_proto) { #if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) tools::proto_splitter::SavedModelSplitter splitter(saved_model_proto); @@ -119,7 +120,7 @@ absl::StatusOr WriteSavedModelToString( #if !IS_OSS // TODO(b/311769337): Define the function unconditionally after tf oss // dependency is updated to protobuf v22.x. 
-absl::StatusOr WriteSavedModelToCord( +absl::StatusOr> WriteSavedModelToCord( SavedModel* saved_model_proto) { tools::proto_splitter::SavedModelSplitter splitter(saved_model_proto); return splitter.WriteToCord(); diff --git a/tensorflow/cc/saved_model/image_format/internal_api.h b/tensorflow/cc/saved_model/image_format/internal_api.h index 7a14b4d031972f..5c9b13d0f97364 100644 --- a/tensorflow/cc/saved_model/image_format/internal_api.h +++ b/tensorflow/cc/saved_model/image_format/internal_api.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CC_SAVED_MODEL_IMAGE_FORMAT_INTERNAL_API_H_ #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -43,10 +44,13 @@ absl::Status ReadSavedModel(const std::string& file_prefix, absl::Status WriteSavedModel(SavedModel* saved_model_proto, const std::string& file_prefix); // Writes the SavedModel proto to std::string -absl::StatusOr WriteSavedModelToString( +// The bool field record whether it's saved as a chunked protobuf (true) or +// regular protobuf (false) +absl::StatusOr> WriteSavedModelToString( SavedModel* saved_model_proto); #if !IS_OSS -absl::StatusOr WriteSavedModelToCord(SavedModel* saved_model_proto); +absl::StatusOr> WriteSavedModelToCord( + SavedModel* saved_model_proto); #endif // See above. The `debug_max_size` argument can be used to the maximum size to diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index e7c60cb3050b6b..1188ed94533864 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -49,6 +49,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", "@riegeli//riegeli/bytes:cord_writer", "@riegeli//riegeli/bytes:fd_writer", diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc index 76f44d9bb8ed21..b02c09c6fa8d62 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -45,6 +46,7 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #define IS_OSS true @@ -179,7 +181,8 @@ absl::Status ComposableSplitterBase::Write(std::string file_prefix) { return absl::OkStatus(); } -absl::StatusOr ComposableSplitterBase::WriteToString() { +absl::StatusOr> +ComposableSplitterBase::WriteToString() { TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); auto split_results = Split(); @@ -192,6 +195,8 @@ absl::StatusOr ComposableSplitterBase::WriteToString() { // Export regular pb. if (!message_->SerializeToString(&output)) return absl::InvalidArgumentError("Serialization to string failed"); + LOG(INFO) << "Splitter output written to string"; + return std::make_tuple(output, false); } else { // Export Riegeli / chunked file. 
using WriterType = riegeli::StringWriter<>; @@ -200,13 +205,14 @@ absl::StatusOr ComposableSplitterBase::WriteToString() { TF_RETURN_IF_ERROR(WriteToRecordWriter( writer, chunks, chunked_message, Version())); if (!writer.Close()) return writer.status(); + LOG(INFO) << "Splitter output written to string"; + return std::make_tuple(output, true); } - LOG(INFO) << "Splitter output written to string"; - return output; } #if !IS_OSS -absl::StatusOr ComposableSplitterBase::WriteToCord() { +absl::StatusOr> +ComposableSplitterBase::WriteToCord() { TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); auto split_results = Split(); @@ -219,6 +225,8 @@ absl::StatusOr ComposableSplitterBase::WriteToCord() { // Export regular pb. if (!message_->SerializeToCord(&output)) return absl::InvalidArgumentError("Serialization to absl::Cord failed"); + LOG(INFO) << "Splitter output written to absl::Cord"; + return std::make_tuple(output, false); } else { // Export Riegeli / chunked file. using WriterType = riegeli::CordWriter<>; @@ -227,10 +235,9 @@ absl::StatusOr ComposableSplitterBase::WriteToCord() { TF_RETURN_IF_ERROR(WriteToRecordWriter( writer, chunks, chunked_message, Version())); if (!writer.Close()) return writer.status(); + LOG(INFO) << "Splitter output written to absl::Cord"; + return std::make_tuple(output, true); } - LOG(INFO) << "Splitter output written to absl::Cord"; - - return output; } #endif diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h index 55ced8eb992ef3..a37a3c61ca0a02 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -66,9 +67,11 @@ class ComposableSplitterBase : public Splitter { // attach a `.pb` or `.cpb` (chunked pb) suffix depending on whether the // proto is split. absl::Status Write(std::string file_prefix) override; - absl::StatusOr WriteToString(); + // The bool field record whether it's saved as a chunked protobuf (true) or + // regular protobuf (false). + absl::StatusOr> WriteToString(); #if !IS_OSS - absl::StatusOr WriteToCord(); + absl::StatusOr> WriteToCord(); #endif VersionDef Version() override; diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc index 39930cc04a7ff0..8efdf36caee628 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc @@ -186,8 +186,11 @@ TEST(RepeatedStringSplitterTest, TestWriteToString) { std::vector strings = {"piece-1", "piece-2", "piece-3"}; auto message = SetUpRepeatedString(strings); RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); - - TF_ASSERT_OK_AND_ASSIGN(std::string string_output, splitter.WriteToString()); + auto string_output_results = splitter.WriteToString(); + TF_EXPECT_OK(string_output_results.status()); + std::string string_output = std::get<0>(string_output_results.value()); + bool is_chunked = std::get<1>(string_output_results.value()); + EXPECT_TRUE(is_chunked); // Look for the last chunk, which should contain a ChunkMetadata proto. 
riegeli::RecordReader> string_reader( std::forward_as_tuple(string_output)); @@ -200,8 +203,11 @@ TEST(RepeatedStringSplitterTest, TestWriteToCord) { std::vector strings = {"piece-1", "piece-2", "piece-3"}; auto message = SetUpRepeatedString(strings); RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); - - TF_ASSERT_OK_AND_ASSIGN(absl::Cord cord_output, splitter.WriteToCord()); + auto cord_output_results = splitter.WriteToCord(); + TF_EXPECT_OK(cord_output_results.status()); + absl::Cord cord_output = std::get<0>(cord_output_results.value()); + bool is_chunked = std::get<1>(cord_output_results.value()); + EXPECT_TRUE(is_chunked); // Look for the last chunk, which should contain a ChunkMetadata proto. riegeli::RecordReader> cord_reader( std::forward_as_tuple(&cord_output)); From 77ffe8b72f88baed697e4ad377cd7963d580db73 Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Thu, 30 Nov 2023 13:48:06 -0800 Subject: [PATCH 258/381] #tensorflow enable double-typed gauge cell PiperOrigin-RevId: 586776669 --- tensorflow/core/profiler/utils/tfstreamz_utils.cc | 3 +++ .../tsl/tsl/lib/monitoring/collected_metrics.h | 1 + .../tsl/tsl/lib/monitoring/collection_registry.h | 12 ++++++++++++ .../xla/third_party/tsl/tsl/lib/monitoring/gauge.h | 6 ++++-- .../third_party/tsl/tsl/lib/monitoring/metric_def.h | 13 ++++++++++++- 5 files changed, 32 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc index af957c54843ec7..0b32f5712edba5 100644 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc +++ b/tensorflow/core/profiler/utils/tfstreamz_utils.cc @@ -112,6 +112,9 @@ Status SerializeToXPlane(const std::vector& snapshots, xevent.AddStatValue(*metadata, *xplane.GetOrCreateStatMetadata( point->string_value)); break; + case monitoring::ValueType::kDouble: + xevent.AddStatValue(*metadata, point->double_value); + break; case monitoring::ValueType::kHistogram: xevent.AddStatValue(*metadata, point->histogram_value); break; diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h index 8582594922adf2..ba67299b57a952 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h @@ -90,6 +90,7 @@ struct Point { int64_t int64_value; string string_value; bool bool_value; + double double_value; HistogramProto histogram_value; Percentiles percentiles_value; diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h index d988d2f19f15ad..7af6c87e51f0bb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h @@ -352,6 +352,18 @@ inline void CollectValue(Percentiles value, Point* const point) { point->percentiles_value = std::move(value); } +template <> +inline void CollectValue(double value, Point* const point) { + point->value_type = ValueType::kDouble; + point->double_value = value; +} + +template <> +inline void CollectValue(std::function value_fn, Point* const point) { + point->value_type = ValueType::kDouble; + point->double_value = value_fn(); +} + // Used by the CollectionRegistry class to collect all the values of all the // metrics in the registry. 
This is an implementation detail of the // CollectionRegistry class, please do not depend on this. diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h index 93cbe9aa928df0..2552013f3d7150 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h @@ -296,8 +296,10 @@ Gauge* Gauge::New( std::is_same::value || std::is_same >::value || std::is_same >::value || - std::is_same >::value, - "Gauge only allows bool, int64, and string types."); + std::is_same >::value || + std::is_same >::value || + std::is_same::value, + "Gauge only allows bool, int64, double, and string types."); return new Gauge( MetricDef( std::forward(metric_def_args)...)); diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h index f8c21c360a2b09..ab454664691b1e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h @@ -47,7 +47,8 @@ enum class ValueType : int { kHistogram, kString, kBool, - kPercentiles + kPercentiles, + kDouble }; // Everything in the internal namespace is implementation details. Do not depend @@ -97,6 +98,16 @@ inline ValueType GetValueType>() { return ValueType::kBool; } +template <> +inline ValueType GetValueType() { + return ValueType::kDouble; +} + +template <> +inline ValueType GetValueType>() { + return ValueType::kDouble; +} + } // namespace internal // Abstract base class for a metric definition. From 73ac629772b5d0576178abf6eecb1871eccf3580 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 14:00:20 -0800 Subject: [PATCH 259/381] Adds logs statements to report the amount of time spent in the constraint solver. PiperOrigin-RevId: 586779921 --- .../hlo/experimental/auto_sharding/auto_sharding_solver.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 8b73e6a43bc002..e44f8ed7b707ec 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/util.h" @@ -560,7 +561,12 @@ AutoShardingSolverResult SolveAndExtractSolution( const std::vector>& e, const MPVariable* overbudget_var, const MPVariable* makespan_var, MPSolver& solver) { + absl::Time start_time = absl::Now(); auto status = solver.Solve(); + absl::Time end_time = absl::Now(); + auto duration = end_time - start_time; + LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms"; + if (status == operations_research::MPSolver::INFEASIBLE) { LOG(ERROR) << "MPSolver could not find any feasible solution."; #ifdef PLATFORM_GOOGLE From 2f0ff66f39bb2fbdeb76e89a9cf9c8db0bf7240c Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 30 Nov 2023 14:51:00 -0800 Subject: [PATCH 260/381] Import resource_variable_ops.py inline in variables.py to ensure the tf2 variable creation function is always available. 
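The practical effect mirrors the earlier ref_variable/variable_v1 change: default_variable_creator_v2
now resolves its implementation with an import inside the function body instead of relying on
resource_variable_ops having monkey-patched it at import time. A rough usage-level illustration
(this assumes the public tf.Variable entry point, which is not touched by this change):

    import tensorflow as tf

    # TF2 variable creation goes through the variable creator stack and
    # ends up in variables.default_variable_creator_v2, which now imports
    # resource_variable_ops lazily, so no explicit import of that module
    # is needed before this line.
    v = tf.Variable(3.0)
    v.assign_add(1.0)
    print(float(v.numpy()))  # 4.0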
PiperOrigin-RevId: 586794787 --- tensorflow/python/ops/BUILD | 70 ++++++++----------- .../python/ops/resource_variable_ops.py | 3 - tensorflow/python/ops/variables.py | 8 ++- 3 files changed, 33 insertions(+), 48 deletions(-) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 784a8b306412ad..e05d71cc3c6657 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -2108,48 +2108,9 @@ py_strict_library( ], ) -py_strict_library( +alias( name = "resource_variable_ops", - srcs = ["resource_variable_ops.py"], - srcs_version = "PY3", - deps = [ - ":array_ops", - ":array_ops_gen", - ":handle_data_util", - ":resource_variable_ops_gen", - ":state_ops", - ":state_ops_gen", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/core/function/trace_type", - "//tensorflow/python/checkpoint:tensor_callable", - "//tensorflow/python/client:pywrap_tf_session", - "//tensorflow/python/compat", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:record", - "//tensorflow/python/eager:tape", - "//tensorflow/python/framework:auto_control_deps_utils", - "//tensorflow/python/framework:composite_tensor", - "//tensorflow/python/framework:composite_tensor_gradient", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:cpp_shape_inference_proto_py", - "//tensorflow/python/framework:device", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:indexed_slices", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor", - "//tensorflow/python/framework:tensor_conversion_registry", - "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/saved_model:nested_structure_coder", - "//tensorflow/python/trackable:base", - "//tensorflow/python/types:core", - "//tensorflow/python/util:_pywrap_utils", - "//tensorflow/python/util:compat", - "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:tf_export", - "//third_party/py/numpy", - ], + actual = ":variables", ) py_strict_library( @@ -3140,30 +3101,55 @@ py_strict_library( py_strict_library( name = "variables", - srcs = ["variables.py"], + srcs = [ + "resource_variable_ops.py", + "variables.py", + ], srcs_version = "PY3", deps = [ ":array_ops", + ":array_ops_gen", ":array_ops_stack", ":control_flow_ops", + ":handle_data_util", ":math_ops", ":math_ops_gen", + ":resource_variable_ops_gen", ":state_ops", + ":state_ops_gen", "//tensorflow/core:protos_all_py", + "//tensorflow/core/function/trace_type", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/checkpoint:tensor_callable", + "//tensorflow/python/client:pywrap_tf_session", + "//tensorflow/python/compat", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:record", + "//tensorflow/python/eager:tape", + "//tensorflow/python/framework:auto_control_deps_utils", + "//tensorflow/python/framework:composite_tensor", + "//tensorflow/python/framework:composite_tensor_gradient", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:cpp_shape_inference_proto_py", + "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/saved_model:nested_structure_coder", 
"//tensorflow/python/trackable:base", + "//tensorflow/python/types:core", "//tensorflow/python/util:_pywrap_utils", + "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:object_identity", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_should_use", "//tensorflow/python/util:traceback_utils", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index d5b776e5997a30..f1ddb2c45f95bf 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -371,9 +371,6 @@ def default_variable_creator_v2(next_creator=None, **kwargs): ) -variables.default_variable_creator_v2 = default_variable_creator_v2 - - class BaseResourceVariable(variables.Variable, core.Tensor): """A python variable from an existing handle.""" diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 5208dd1c8229ae..6dc2a43739123c 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -44,9 +44,11 @@ from tensorflow.python.util.tf_export import tf_export -def default_variable_creator_v2(_, **kwds): - del kwds - raise NotImplementedError("resource_variable_ops needs to be imported") +def default_variable_creator_v2(next_creator=None, **kwds): + from tensorflow.python.ops import resource_variable_ops # pylint: disable=g-import-not-at-top + + return resource_variable_ops.default_variable_creator_v2( + next_creator=next_creator, **kwds) def _make_getter(captured_getter, captured_previous): From 2efd24eaef6a483c6d10361e8c695efc89ccde7f Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Thu, 30 Nov 2023 15:36:52 -0800 Subject: [PATCH 261/381] Turn static functions into regular free functions in `lift_quantizable_spots_as_functions.cc`. They are already in an anonymous namespace. PiperOrigin-RevId: 586807258 --- .../stablehlo/passes/lift_quantizable_spots_as_functions.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 383f2430c94eee..447652ca151282 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -42,7 +42,7 @@ namespace { // TODO - b/303543789: Move the helper functions below to a separate util. // Fetches the default or null attribute, used for pattern matching. -static Attribute DefaultOrNullAttr(OpBuilder& builder, Attribute& attr) { +Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { if (!attr) { return builder.getStringAttr(kNullAttributeValue); } @@ -51,7 +51,7 @@ static Attribute DefaultOrNullAttr(OpBuilder& builder, Attribute& attr) { // Checks whether the value of a constant equals the given float, regardless // of the tensor dimension. -static bool FloatValueEquals(const Attribute& attr, double value) { +bool FloatValueEquals(const Attribute& attr, const double value) { auto fp_attr = attr.dyn_cast_or_null(); if (!fp_attr) return false; @@ -101,7 +101,7 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { } // Remove all attr_map attributes. 
- module_op.walk([&](Operation* op) { op->removeAttr(kAttrMapAttribute); }); + module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); }); } } // namespace From f8c92f3aba10941705e2346ca7659ea61c9df906 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 15:48:22 -0800 Subject: [PATCH 262/381] Stop throwing exceptions if users pass `profiler_outdir` to `trace_export()` and warn users who don't pass it to `trace_on()`. This change avoids breaking users not throwing exceptions in these functions. However, users will need to pass `profiler_outdir` to `trace_on()` if they want profiling to work. PiperOrigin-RevId: 586810146 --- RELEASE.md | 6 +++++ tensorflow/python/ops/summary_ops_v2.py | 22 +++++++++++++++---- .../api/golden/v2/tensorflow.summary.pbtxt | 2 +- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index f2456bfcb0b1e9..8af503d972f2d4 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -9,6 +9,12 @@ * * +* `tf.summary.trace_on` now takes a `profiler_outdir` argument. This must be set + if `profiler` arg is set to `True`. + * `tf.summary.trace_export`'s `profiler_outdir` arg is now a no-op. Enabling + the profiler now requires setting `profiler_outdir` in `trace_on`. + + ### Known Caveats * diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index de8209370a9522..761f42885ada59 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -1367,15 +1367,23 @@ def trace_on(graph=True, profiler=False, profiler_outdir=None): # pylint: disab context.context().enable_graph_collection() if profiler: if profiler_outdir is None: - raise ValueError("Argument `profiler_outdir` is not specified.") - context.context().enable_run_metadata() - _profiler.start(profiler_outdir) + # TODO(b/149431324): Change this to throw a ValueError when Tensorflow + # major version advances. (current version is 2.15) + logging.warn( + "No `profiler_outdir` passed to trace_on(). Profiler won't be" + " enabled." + ) + else: + context.context().enable_run_metadata() + _profiler.start(profiler_outdir) _current_trace_context = _TraceContext(graph=graph, profiler=profiler) +# TODO(b/149431324): Delete `profiler_outdir` arg when Tensorflow major version +# advances. (current version is 2.15) @tf_export("summary.trace_export", v1=[]) -def trace_export(name, step=None): +def trace_export(name, step=None, profiler_outdir=None): """Stops and exports the active trace as a Summary and/or profile file. Stops the trace and exports all metadata collected during the trace to the @@ -1386,6 +1394,7 @@ def trace_export(name, step=None): step: Explicit `int64`-castable monotonic step value for this summary. If omitted, this defaults to `tf.summary.experimental.get_step()`, which must not be None. + profiler_outdir: This arg is a no-op. Please set this in trace_on(). Raises: ValueError: if a default writer exists, but no step was provided and @@ -1416,6 +1425,11 @@ def trace_export(name, step=None): run_metadata(name, run_meta, step) if profiler: + if profiler_outdir: + logging.warn( + "Ignoring `profiler_outdir` passed to trace_export(). Please pass it" + " to trace_on() instead." 
+ ) _profiler.stop() trace_off() diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt index 43efa9619844dc..1d36dacaff7eea 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt @@ -54,7 +54,7 @@ tf_module { } member_method { name: "trace_export" - argspec: "args=[\'name\', \'step\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'name\', \'step\', \'profiler_outdir\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "trace_off" From 69c72ef7b7da673da1c0b093ff28007708896be4 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Thu, 30 Nov 2023 16:09:01 -0800 Subject: [PATCH 263/381] Add a way to serialize CompiledMemoryStats. PiperOrigin-RevId: 586816226 --- third_party/xla/xla/pjrt/BUILD | 11 ++++++++- .../xla/xla/pjrt/executable_metadata.proto | 15 ++++++++++++ third_party/xla/xla/pjrt/pjrt_executable.h | 24 ++++++++++++++++++- 3 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/xla/pjrt/executable_metadata.proto diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index dd1e74b9528ce7..4853b646929a25 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -230,6 +230,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":compile_options_proto_cc", + ":executable_metadata_proto_cc", ":execute_options_proto_cc", ":pjrt_common", "//xla:shape_layout", @@ -253,7 +254,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], ) @@ -800,3 +800,12 @@ tf_proto_library( srcs = ["execute_options.proto"], visibility = ["//visibility:public"], ) + +tf_proto_library( + name = "executable_metadata_proto", + srcs = ["executable_metadata.proto"], + protodeps = [ + "//xla/service:hlo_proto", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/xla/pjrt/executable_metadata.proto b/third_party/xla/xla/pjrt/executable_metadata.proto new file mode 100644 index 00000000000000..c5995492d28fdb --- /dev/null +++ b/third_party/xla/xla/pjrt/executable_metadata.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package xla; + +import "xla/service/hlo.proto"; + +// Mirror of xla::CompiledMemoryStats. +message CompiledMemoryStatsProto { + int64 generated_code_size_in_bytes = 1; + int64 argument_size_in_bytes = 2; + int64 output_size_in_bytes = 3; + int64 alias_size_in_bytes = 4; + int64 temp_size_in_bytes = 5; + xla.HloProto hlo_proto = 6; +} diff --git a/third_party/xla/xla/pjrt/pjrt_executable.h b/third_party/xla/xla/pjrt/pjrt_executable.h index 0c603c20f1baf2..2a739484ade108 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.h +++ b/third_party/xla/xla/pjrt/pjrt_executable.h @@ -34,6 +34,7 @@ limitations under the License. #include "xla/client/executable_build_options.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/executable_metadata.pb.h" #include "xla/pjrt/execute_options.pb.h" #include "xla/pjrt/pjrt_common.h" #include "xla/service/compiler.h" @@ -44,7 +45,6 @@ limitations under the License. 
#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" namespace xla { @@ -275,6 +275,28 @@ struct CompiledMemoryStats { std::string serialized_hlo_proto = ""; std::string DebugString() const; + + CompiledMemoryStatsProto ToProto() { + CompiledMemoryStatsProto proto; + proto.set_generated_code_size_in_bytes(generated_code_size_in_bytes); + proto.set_argument_size_in_bytes(argument_size_in_bytes); + proto.set_output_size_in_bytes(output_size_in_bytes); + proto.set_alias_size_in_bytes(alias_size_in_bytes); + proto.set_temp_size_in_bytes(temp_size_in_bytes); + proto.mutable_hlo_proto()->ParseFromString(serialized_hlo_proto); + return proto; + } + + static CompiledMemoryStats FromProto(const CompiledMemoryStatsProto& proto) { + CompiledMemoryStats stats; + stats.generated_code_size_in_bytes = proto.generated_code_size_in_bytes(); + stats.argument_size_in_bytes = proto.argument_size_in_bytes(); + stats.output_size_in_bytes = proto.alias_size_in_bytes(); + stats.alias_size_in_bytes = proto.alias_size_in_bytes(); + stats.temp_size_in_bytes = proto.temp_size_in_bytes(); + stats.serialized_hlo_proto = proto.hlo_proto().SerializeAsString(); + return stats; + } }; class PjRtExecutable { From 595f5ea875648fcc274f874cbd1a5a67ee860c7d Mon Sep 17 00:00:00 2001 From: Grant Jensen Date: Thu, 30 Nov 2023 16:37:56 -0800 Subject: [PATCH 264/381] [tflite] Fix tflite selective_build_script. This file prints out a header file which is piped to ops_to_register.h. The problem lies in that when we import selective_registration_header_lib it prints "Using TensorFlow backend" which also gets piped to ops_to_register, but does not belong. PiperOrigin-RevId: 586824804 --- .../python/tools/print_selective_registration_header.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/tools/print_selective_registration_header.py b/tensorflow/python/tools/print_selective_registration_header.py index 8ae04c137e4eb4..6809ea62f51513 100644 --- a/tensorflow/python/tools/print_selective_registration_header.py +++ b/tensorflow/python/tools/print_selective_registration_header.py @@ -32,10 +32,17 @@ """ import argparse +import contextlib import sys from absl import app -from tensorflow.python.tools import selective_registration_header_lib + +# Import statement prints "Using TensorFlow backend" which gets piped to +# ops_to_register.h. 
Avoid this printing import statement to /dev/null +with open('/dev/null', 'w') as f, contextlib.redirect_stdout(f): + # pylint: disable=g-import-not-at-top + from tensorflow.python.tools import selective_registration_header_lib + # pylint: enable FLAGS = None From 7378f1e02b804094245638e8e1e05b80b0e71d1d Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Thu, 30 Nov 2023 16:48:01 -0800 Subject: [PATCH 265/381] Add avg time in op detail PiperOrigin-RevId: 586827165 --- tensorflow/core/profiler/convert/op_profile_builder.cc | 5 ++++- tensorflow/core/profiler/protobuf/op_profile.proto | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc index 2111ea4f56ac6e..124d4096518f95 100644 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ b/tensorflow/core/profiler/convert/op_profile_builder.cc @@ -185,9 +185,12 @@ void PopulateOpMetricsNode( // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts metrics->set_raw_time(op_metrics.time_ps()); metrics->set_raw_flops(op_metrics.flops()); + metrics->set_occurrences(op_metrics.occurrences()); + metrics->set_avg_time_ps( + SafeDivide(op_metrics.time_ps(), op_metrics.occurrences())); // Hack to approximate utilization for INT8/4 convolution HLOs: - // Since MXU BW is 2x/4x for INT8/4, multiply peak BW by the factor detemrined + // Since MXU BW is 2x/4x for INT8/4, multiply peak BW by the factor determined // by the computation size if (GetComputationSize(*node) == 8) { peak_gigaflops_per_second_per_core *= 2; diff --git a/tensorflow/core/profiler/protobuf/op_profile.proto b/tensorflow/core/profiler/protobuf/op_profile.proto index 9c29d1777eb38a..14ce2d203fb16a 100644 --- a/tensorflow/core/profiler/protobuf/op_profile.proto +++ b/tensorflow/core/profiler/protobuf/op_profile.proto @@ -82,6 +82,10 @@ message Metrics { // Total bytes accessed for each memory type. // Index into array using MemBwType enum. repeated double raw_bytes_accessed_array = 15; + // Number of executions. + uint32 occurrences = 16; + // Average "accumlated" time in picoseconds that the operation took. + double avg_time_ps = 17; reserved 1, 3, 4, 13, 14; } From c655e08610f192ba283d440271d5a1f6c1286030 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Thu, 30 Nov 2023 16:56:30 -0800 Subject: [PATCH 266/381] Rework ShardingCallback to be more user-friendly. 
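In practice this turns ShardingCallback from a frozen dataclass that wrapped a callback function plus
a description string into an abstract base class that policies subclass, exposing `description` as a
property and the sharding logic via `__call__`. A simplified before/after sketch of a custom policy
(the callback name in the "before" comment is hypothetical; the fuller example, including the
ShardableTensor fields it relies on, is in the new sharding_util.py docstring below):

    from tensorflow.python.checkpoint.sharding import sharding_util

    # Before: wrap a free function and a description string.
    # policy = sharding_util.ShardingCallback(
    #     _all_in_one_callback_impl,
    #     "Place all tensors in a single shard.")

    # After: subclass and implement `description` and `__call__`.
    class AllInOnePolicy(sharding_util.ShardingCallback):

      @property
      def description(self):
        return "Place all tensors in a single shard."

      def __call__(self, shardable_tensors):
        tensors = {}
        for st in shardable_tensors:
          tensors.setdefault(st.checkpoint_key, {})[st.slice_spec] = st.tensor
        return [tensors]

Keyword arguments for a policy now travel through the subclass constructor (see MaxShardSizePolicy),
and hashing is derived from `description` plus the hashable instance attributes rather than the
wrapped function's identity.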
PiperOrigin-RevId: 586828958 --- tensorflow/python/checkpoint/sharding/BUILD | 4 +- .../checkpoint/sharding/sharding_policies.py | 58 +++--- .../sharding/sharding_policies_test.py | 6 +- .../checkpoint/sharding/sharding_util.py | 150 ++++++++++++--- .../checkpoint/sharding/sharding_util_test.py | 174 ++++++++++-------- 5 files changed, 248 insertions(+), 144 deletions(-) diff --git a/tensorflow/python/checkpoint/sharding/BUILD b/tensorflow/python/checkpoint/sharding/BUILD index 82745bc1df79f3..99b64ec2568dbb 100644 --- a/tensorflow/python/checkpoint/sharding/BUILD +++ b/tensorflow/python/checkpoint/sharding/BUILD @@ -19,6 +19,7 @@ py_strict_library( deps = [ ":sharding_util", "//tensorflow/python/eager:context", + "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", @@ -41,6 +42,7 @@ tf_py_strict_test( "//tensorflow/python/checkpoint", "//tensorflow/python/checkpoint:graph_view", "//tensorflow/python/eager:test", + "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor_shape", @@ -79,6 +81,7 @@ tf_py_strict_test( "//tensorflow/python/checkpoint:graph_view", "//tensorflow/python/eager:remote", "//tensorflow/python/eager:test", + "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", @@ -86,7 +89,6 @@ tf_py_strict_test( "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:resource_variable_ops", - "//tensorflow/python/ops:variables", "//tensorflow/python/training:server_lib", "//tensorflow/python/training/saving:saveable_object", ], diff --git a/tensorflow/python/checkpoint/sharding/sharding_policies.py b/tensorflow/python/checkpoint/sharding/sharding_policies.py index 9a512759d6003a..ed9576aa85fa93 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_policies.py +++ b/tensorflow/python/checkpoint/sharding/sharding_policies.py @@ -21,6 +21,7 @@ from tensorflow.python.checkpoint.sharding import sharding_util from tensorflow.python.eager import context +from tensorflow.python.framework import device as device_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor as tensor_lib @@ -34,12 +35,11 @@ class ShardByDevicePolicy(sharding_util.ShardingCallback): """Policy that splits tensors into shards based on their device spec.""" - def __init__(self): - super().__init__( - self._device_callback_impl, - "Split tensors into shards based on their device spec.") + @property + def description(self) -> str: + return "Split tensors into shards based on their device spec." 
- def _device_callback_impl( + def __call__( self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: @@ -58,7 +58,8 @@ def _device_callback_impl( tensor = shardable_tensor.tensor checkpoint_key = shardable_tensor.checkpoint_key slice_spec = shardable_tensor.slice_spec - device = saveable_object_util.set_cpu0(shardable_tensor.device) + device = device_lib.DeviceSpec.from_string( + saveable_object_util.set_cpu0(shardable_tensor.device.to_string())) (tensors_by_device .setdefault(device, {}) @@ -66,12 +67,6 @@ def _device_callback_impl( return list(tensors_by_device.values()) - def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSlice]: - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter - _PartitionAxisAndSize = tuple[int, int] _OffsetAndShape = tuple[Sequence[int], Sequence[int]] @@ -81,11 +76,12 @@ class MaxShardSizePolicy(sharding_util.ShardingCallback): """Policy that splits tensors into shards with a max shard size.""" def __init__(self, max_shard_size: int): - super().__init__( - self._max_shard_size_callback_impl, - "Split tensors into shards with a max shard size.") self.max_shard_size = max_shard_size + @property + def description(self) -> str: + return "Split tensors into shards with a max shard size." + def _get_next_partition( self, shard_size_remaining: int, @@ -172,13 +168,14 @@ def _add_partition( for i in range(root_tensor_shape.rank)] slice_shape[min_axis] = part_size slice_size_in_bytes = int(math.prod(slice_shape)) * dtype_size - tensor_slice = array_ops.slice( - root_tensor, begin=slice_offset, size=slice_shape) + with ops.device(root_shardable_tensor.device): + tensor_slice = array_ops.slice( + root_tensor, begin=slice_offset, size=slice_shape) slice_spec = variables.Variable.SaveSliceInfo( full_name=checkpoint_key, full_shape=root_tensor_shape, var_offset=slice_offset, - var_shape=slice_shape).spec + var_shape=slice_shape).spec.strip() remaining_size = shard_size_remaining if slice_size_in_bytes > max_shard_size: logging.warning("Slice %s of tensor %s is a scalar of size %s bytes and " @@ -206,16 +203,13 @@ def _add_partition( return (remaining_size, (slice_offset, slice_shape)) - def _max_shard_size_callback_impl( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor], - max_shard_size: int + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSliceDict]: """Callback to split tensors into shards with a max shard size. Args: shardable_tensors: A list of ShardableTensors. - max_shard_size: Max size in bytes allowed for a checkpoint shard. Returns: List of shard dicts containing tensors. @@ -224,7 +218,7 @@ def _max_shard_size_callback_impl( tensors_by_shard = [] large_scalars = [] - shard_size_remaining = max_shard_size + shard_size_remaining = self.max_shard_size for shardable_tensor in shardable_tensors: root_tensor = shardable_tensor.tensor root_shape = shardable_tensor.shape @@ -249,13 +243,13 @@ def _max_shard_size_callback_impl( total_size = sum(sizes) dtype_size = max(sizes) - if (total_size > max_shard_size and + if (total_size > self.max_shard_size and (root_shape.rank is None or root_shape.rank == 0)): logging.warning("Tensor %s is a scalar of size %s bytes and cannot be " "partitioned into a shard of max shard size %s bytes. 
" "It will be added as an individual shard that exceeds " "the max shard size.", - checkpoint_key, total_size, max_shard_size) + checkpoint_key, total_size, self.max_shard_size) large_scalars.append( {checkpoint_key: {shardable_tensor.slice_spec: root_tensor}}) continue @@ -279,7 +273,7 @@ def _max_shard_size_callback_impl( working_tensor_offset=working_tensor_var_offset, part_axis_and_size=part_axis_and_size, shard_size_remaining=shard_size_remaining, - max_shard_size=max_shard_size, + max_shard_size=self.max_shard_size, tensors_by_shard=tensors_by_shard, large_scalars=large_scalars) @@ -287,7 +281,7 @@ def _max_shard_size_callback_impl( # Tensor partition couldn't fit in remaining shard space. Try again # with the next full shard. tensors_by_shard.append({}) - shard_size_remaining = max_shard_size + shard_size_remaining = self.max_shard_size else: working_tensor = array_ops.slice( root_tensor, begin=remaining_offset, size=remaining_shape) @@ -301,7 +295,7 @@ def _max_shard_size_callback_impl( full_name=checkpoint_key, full_shape=root_shape, var_offset=working_tensor_var_offset, - var_shape=working_tensor_shape).spec + var_shape=working_tensor_shape).spec.strip() if not tensors_by_shard: tensors_by_shard.append({}) (tensors_by_shard[-1] @@ -310,9 +304,3 @@ def _max_shard_size_callback_impl( shard_size_remaining -= working_tensor_size return tensors_by_shard + large_scalars - - def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] - ) -> Sequence[sharding_util.TensorSliceDict]: - return self.callback(shardable_tensors, self.max_shard_size) # pylint: disable=no-value-for-parameter diff --git a/tensorflow/python/checkpoint/sharding/sharding_policies_test.py b/tensorflow/python/checkpoint/sharding/sharding_policies_test.py index 9009b3ae2eba44..e437d9170b1f82 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_policies_test.py +++ b/tensorflow/python/checkpoint/sharding/sharding_policies_test.py @@ -22,6 +22,7 @@ from tensorflow.python.checkpoint.sharding import sharding_policies from tensorflow.python.checkpoint.sharding import sharding_util from tensorflow.python.eager import test +from tensorflow.python.framework import device as device_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -55,12 +56,15 @@ def _get_shardable_tensors(self, root): dtype=tensor_save_spec.dtype, device=tensor_save_spec.device) save_spec_tensor = tensor_save_spec.tensor + device = (device_lib.DeviceSpec.from_string(tensor_save_spec.device) + if isinstance(tensor_save_spec.device, str) + else tensor_save_spec.device) shardable_tensors.append( sharding_util.ShardableTensor( _tensor_save_spec=tensor_save_spec, tensor=save_spec_tensor, dtype=tensor_save_spec.dtype, - device=tensor_save_spec.device, + device=device, name=tensor_save_spec.name, shape=save_spec_tensor.shape, slice_spec=slice_spec, diff --git a/tensorflow/python/checkpoint/sharding/sharding_util.py b/tensorflow/python/checkpoint/sharding/sharding_util.py index ef96a044598508..89df27672fc31a 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_util.py +++ b/tensorflow/python/checkpoint/sharding/sharding_util.py @@ -14,8 +14,10 @@ # ============================================================================== """Data structures and utilities for checkpoint sharding.""" +import abc import dataclasses -from typing import Callable, Mapping, Sequence +import inspect +from typing import Hashable, MutableMapping, 
Sequence from tensorflow.python.framework import device as device_lib from tensorflow.python.framework import dtypes @@ -27,13 +29,19 @@ from tensorflow.python.training.saving import saveable_object -TensorSlice = Mapping[tensor_spec.TensorSpec, tensor_lib.Tensor] -TensorSliceDict = Mapping[str, TensorSlice] +TensorSlice = MutableMapping[tensor_spec.TensorSpec, tensor_lib.Tensor] +TensorSliceDict = MutableMapping[str, TensorSlice] @dataclasses.dataclass(frozen=True) class ShardableTensor: - """Tensor wrapper containing data necessary for sharding.""" + """Tensor wrapper containing data necessary for sharding. + + The tensor representation used as inputs to pre-made and custom + `tf.train.experiemental.ShardingCallback`s, which can be specified using the + `experimental_sharding_callback` option in `tf.train.CheckpointOptions`. + + """ _tensor_save_spec: saveable_object.SaveSpec tensor: tensor_lib.Tensor dtype: dtypes.DType @@ -47,22 +55,102 @@ class ShardableTensor: def __hash__(self) -> int: return hash((self.name, self.dtype, str(self.device), self.checkpoint_key)) + def __repr__(self) -> str: + return (f"\n{self.__class__.__name__}:\n" + f" _tensor_save_spec={self._tensor_save_spec!r}\n" + f" tensor={self.tensor!r}\n" + f" dtype={self.dtype!r}\n" + f" device={self.device!r}\n" + f" name={self.name!r}\n" + f" shape={self.shape!r}\n" + f" slice_spec={self.slice_spec!r}\n" + f" checkpoint_key={self.checkpoint_key!r}\n" + f" trackable={self.trackable!r}") -@dataclasses.dataclass(frozen=True) -class ShardingCallback: - """Checkpoint sharding callback function, along with a text description.""" - callback: Callable[ - [Sequence[ShardableTensor], ...], - Sequence[Mapping[ - str, Mapping[tensor_spec.TensorSpec, saveable_object.SaveSpec]]]] + +class ShardingCallback(abc.ABC): + """Checkpoint sharding callback function, along with a text description. + + A callback function wrapper that will be executed to determine how tensors + will be split into shards when the saver writes the checkpoint shards to disk. + + When calling the callback, it takes a list of + `tf.train.experimental.ShardableTensor`s as input (as well as any kwargs + defined by the `tf.train.experimental.ShardingCallback` subclass), and outputs + a `tensors_by_shard` dict. + + There are a few restrictions to keep in mind when creating a custom callback: + - Tensors must not be removed from the checkpoint. + - Tensors must not be reshaped. + - Tensor dtypes must not change. + - Tensors within a shard must belong to the same task. + Validation checks will be performed after the callback function is executed to + ensure these restrictions aren't violated. + + Here's an example of a simple custom callback: + + ``` + # Place all tensors in a single shard. + class AllInOnePolicy(tf.train.experimental.ShardingCallback): + @property + def description(self): + return "Place all tensors in a single shard." + + def __call__(self, shardable_tensors): + tensors = {} + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor_save_spec.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + + tensors.set_default(checkpoint_key, {})[slice_spec] = tensor + return [tensors] + + ckpt.save( + "path", + options=tf.train.CheckpointOptions( + experimental_sharding_callback=AllInOnePolicy())) + ``` + + The `description` attribute is used to identify the callback and to aid + debugging during saving and restoration. 
+ + To take in kwargs, simply define the constructor and pass them in: + + ``` + class ParameterPolicy(tf.train.experimental.ShardingCallback): + def __init__(self, custom_param): + self.custom_param = custom_param + ... + + ckpt.save( + "path", + options=tf.train.CheckpointOptions( + experimental_sharding_callback=ParameterPolicy(custom_param=...))) + ``` + + """ description: str + @property + @abc.abstractmethod + def description(self) -> str: + pass + + @abc.abstractmethod + def __call__( + self, shardable_tensors: Sequence[ShardableTensor] + ) -> Sequence[TensorSlice]: + pass + def __hash__(self) -> int: - if hasattr(self.callback, "__name__"): - callback_hash = hash((self.callback.__module__, self.callback.__name__)) - else: - callback_hash = id(self.callback) - return hash((callback_hash, self.description)) + hash_val = hash(self.description) + for attr_name, attr_val in inspect.getmembers(self): + if not inspect.ismethod(attr_val) and not inspect.isfunction(attr_val): + hash_val ^= hash(attr_name) + if isinstance(attr_val, Hashable): + hash_val ^= hash(attr_val) + return hash_val def validate_shards( @@ -71,21 +159,21 @@ def validate_shards( callback_description: str ) -> None: """Validates shards generated by the sharding_callback.""" - unseen_tensor_dict = { - (shardable_tensor.slice_spec, - shardable_tensor.checkpoint_key): shardable_tensor.tensor - for shardable_tensor in shardable_tensors - if shardable_tensor.tensor is not None} + unseen_tensor_dict = {} + for shardable_tensor in shardable_tensors: + unseen_tensor_dict.setdefault( + shardable_tensor.checkpoint_key, {} + )[shardable_tensor.slice_spec] = shardable_tensor.tensor seen_tensor_set = set() for shard_tensors in shards: task_tensor = None for checkpoint_key, tensor_slice_dict in shard_tensors.items(): for slice_spec, shard_tensor in tensor_slice_dict.items(): - shard_tensor_id = (slice_spec, checkpoint_key) + slice_spec = slice_spec.strip() # Validate uniqueness. - if shard_tensor_id in seen_tensor_set: + if (checkpoint_key, slice_spec) in seen_tensor_set: raise RuntimeError( "After executing the checkpoint sharding callback, multiple " "tensors with the same checkpoint key and slice spec were " @@ -95,7 +183,7 @@ def validate_shards( f" slice_spec: {slice_spec}\n") # Validate no added tensors. - if shard_tensor_id not in unseen_tensor_dict: + if checkpoint_key not in unseen_tensor_dict: raise RuntimeError( "After executing the checkpoint sharding callback, a tensor " "not originally in the object graph was found in the " @@ -105,7 +193,7 @@ def validate_shards( f" slice_spec: {slice_spec}\n") # Validate no shape change. - target_shape = unseen_tensor_dict[shard_tensor_id].shape + target_shape = unseen_tensor_dict[checkpoint_key][slice_spec].shape if shard_tensor.shape != target_shape: raise RuntimeError( "After executing the checkpoint sharding callback, a tensor " @@ -117,7 +205,7 @@ def validate_shards( f" new tensor_shape: {shard_tensor.shape}\n") # Validate no dtype change. 
- target_dtype = unseen_tensor_dict[shard_tensor_id].dtype + target_dtype = unseen_tensor_dict[checkpoint_key][slice_spec].dtype if shard_tensor.dtype != target_dtype: raise RuntimeError( "After executing the checkpoint sharding callback, a tensor " @@ -137,7 +225,7 @@ def validate_shards( else: task1 = device_lib.DeviceSpec.from_string(task_tensor.device).task task2 = device_lib.DeviceSpec.from_string(shard_tensor.device).task - if task1 != task2: + if task1 is not None and task2 is not None and task1 != task2: raise RuntimeError( "After executing the checkpoint sharding callback, tensors " "with different tasks were found in the same shard:\n" @@ -151,13 +239,15 @@ def validate_shards( f" slice_spec: {slice_spec}\n" f" task: {task2}\n") - del unseen_tensor_dict[shard_tensor_id] - seen_tensor_set.add(shard_tensor_id) + del unseen_tensor_dict[checkpoint_key][slice_spec] + if not unseen_tensor_dict[checkpoint_key]: + del unseen_tensor_dict[checkpoint_key] + seen_tensor_set.add((checkpoint_key, slice_spec)) # validate no tensor removal if unseen_tensor_dict: tensors_info = "" - for slice_spec, ckpt_key in unseen_tensor_dict: + for ckpt_key, slice_spec in unseen_tensor_dict.items(): tensors_info += " tensor:\n" tensors_info += f" checkpoint_key: {ckpt_key}\n" tensors_info += f" slice_spec: {slice_spec}\n" diff --git a/tensorflow/python/checkpoint/sharding/sharding_util_test.py b/tensorflow/python/checkpoint/sharding/sharding_util_test.py index 3bf92da2e94c2b..117a5e32102571 100644 --- a/tensorflow/python/checkpoint/sharding/sharding_util_test.py +++ b/tensorflow/python/checkpoint/sharding/sharding_util_test.py @@ -23,6 +23,7 @@ from tensorflow.python.checkpoint.sharding import sharding_util from tensorflow.python.eager import remote from tensorflow.python.eager import test +from tensorflow.python.framework import device as device_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor as tensor_lib @@ -30,7 +31,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables from tensorflow.python.training import server_lib from tensorflow.python.training.saving import saveable_object @@ -58,19 +58,51 @@ def _get_shardable_tensors(self, root): dtype=tensor_save_spec.dtype, device=tensor_save_spec.device) save_spec_tensor = tensor_save_spec.tensor + device = (device_lib.DeviceSpec.from_string(tensor_save_spec.device) + if isinstance(tensor_save_spec.device, str) + else tensor_save_spec.device) shardable_tensors.append( sharding_util.ShardableTensor( _tensor_save_spec=tensor_save_spec, tensor=save_spec_tensor, dtype=tensor_save_spec.dtype, - device=tensor_save_spec.device, + device=device, name=tensor_save_spec.name, shape=save_spec_tensor.shape, - slice_spec=slice_spec, + slice_spec=slice_spec.strip(), checkpoint_key=checkpoint_key, trackable=obj)) return shardable_tensors + def test_hash_ShardingCallback(self): + class BlankCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSlice]: + pass + + self.assertEqual(hash(BlankCallback()), hash(BlankCallback())) + + class ValueCallback(sharding_util.ShardingCallback): + def __init__(self, val): + self.val = val + + @property + def description(self): + return "value callback" + + def 
__call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSlice]: + pass + + self.assertEqual(hash(ValueCallback(1)), hash(ValueCallback(1))) + self.assertNotEqual(hash(ValueCallback(1)), hash(ValueCallback(2))) + def test_validate_shards_correct(self): root = module.Module() with ops.device("cpu:0"): @@ -119,23 +151,21 @@ def test_validate_shards_duplicate_tensor(self): root.v1 = v1 class DuplicateTensorCallback(sharding_util.ShardingCallback): - def __init__(self): - def sharding_callback_impl(shardable_tensors): - tensor = shardable_tensors[0].tensor - checkpoint_key = shardable_tensors[0].checkpoint_key - slice_spec = shardable_tensors[0].slice_spec - shards = [ - {checkpoint_key: {slice_spec: tensor}}, - {checkpoint_key: {slice_spec: tensor}} - ] - return shards - super().__init__(sharding_callback_impl, "duplicate tensor callback") + @property + def description(self): + return "duplicate tensor callback" def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter + tensor = shardable_tensors[0].tensor + checkpoint_key = shardable_tensors[0].checkpoint_key + slice_spec = shardable_tensors[0].slice_spec + shards = [ + {checkpoint_key: {slice_spec: tensor}}, + {checkpoint_key: {slice_spec: tensor}} + ] + return shards sharding_callback = DuplicateTensorCallback() shardable_tensors = self._get_shardable_tensors(root) @@ -154,19 +184,17 @@ def test_validate_shards_added_tensor(self): root.v0 = v0 class AddedTensorCallback(sharding_util.ShardingCallback): - def __init__(self): - def sharding_callback_impl(_): - checkpoint_key = "ADDED_TENSOR_ABC123" - slice_spec = variables.Variable.SaveSliceInfo() - tensor = tensor_lib.Tensor() - return [{checkpoint_key: {slice_spec: tensor}}] - super().__init__(sharding_callback_impl, "added tensor callback") + @property + def description(self): + return "added tensor callback" def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter + checkpoint_key = "ADDED_TENSOR_ABC123" + slice_spec = "" + tensor = tensor_lib.Tensor() + return [{checkpoint_key: {slice_spec: tensor}}] sharding_callback = AddedTensorCallback() shardable_tensors = self._get_shardable_tensors(root) @@ -184,24 +212,22 @@ def test_validate_shards_shape_change(self): root.v0 = v0 class ShapeChangeCallback(sharding_util.ShardingCallback): - def __init__(self): - def sharding_callback_impl(shardable_tensors): - shards = [] - for shardable_tensor in shardable_tensors: - tensor = shardable_tensor.tensor - checkpoint_key = shardable_tensor.checkpoint_key - slice_spec = shardable_tensor.slice_spec - if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": - tensor = array_ops.transpose(tensor) - shards.append({checkpoint_key: {slice_spec: tensor}}) - return shards - super().__init__(sharding_callback_impl, "shape change callback") + @property + def description(self): + return "shape change callback" def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: - return 
self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter + shards = [] + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": + tensor = array_ops.transpose(tensor) + shards.append({checkpoint_key: {slice_spec: tensor}}) + return shards sharding_callback = ShapeChangeCallback() shardable_tensors = self._get_shardable_tensors(root) @@ -219,24 +245,22 @@ def test_validate_shards_dtype_change(self): root.v0 = v0 class DtypeChangeCallback(sharding_util.ShardingCallback): - def __init__(self): - def sharding_callback_impl(shardable_tensors): - shards = [] - for shardable_tensor in shardable_tensors: - tensor = shardable_tensor.tensor - checkpoint_key = shardable_tensor.checkpoint_key - slice_spec = shardable_tensor.slice_spec - if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": - tensor = math_ops.cast(tensor, dtype=dtypes.int32) - shards.append({checkpoint_key: {slice_spec: tensor}}) - return shards - super().__init__(sharding_callback_impl, "dtype change callback") + @property + def description(self): + return "dtype change callback" def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter + shards = [] + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": + tensor = math_ops.cast(tensor, dtype=dtypes.int32) + shards.append({checkpoint_key: {slice_spec: tensor}}) + return shards sharding_callback = DtypeChangeCallback() shardable_tensors = self._get_shardable_tensors(root) @@ -262,22 +286,20 @@ def test_validate_shards_different_tasks(self): root.v1 = v1 class DifferentTasksCallback(sharding_util.ShardingCallback): - def __init__(self): - def sharding_callback_impl(shardable_tensors): - shard = {} - for shardable_tensor in shardable_tensors: - tensor = shardable_tensor.tensor - checkpoint_key = shardable_tensor.checkpoint_key - slice_spec = shardable_tensor.slice_spec - shard.setdefault(checkpoint_key, {})[slice_spec] = tensor - return [shard] - super().__init__(sharding_callback_impl, "different tasks callback") + @property + def description(self): + return "different tasks callback" def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter + shard = {} + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + shard.setdefault(checkpoint_key, {})[slice_spec] = tensor + return [shard] sharding_callback = DifferentTasksCallback() shardable_tensors = self._get_shardable_tensors(root) @@ -295,16 +317,14 @@ def test_validate_shards_tensor_removal(self): root.v0 = v0 class TensorRemovalCallback(sharding_util.ShardingCallback): - def __init__(self): - def sharding_callback_impl(_): - return [] - super().__init__(sharding_callback_impl, "tensor removal callback") + @property + def 
description(self): + return "tensor removal callback" def __call__( - self, - shardable_tensors: Sequence[sharding_util.ShardableTensor] + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] ) -> Sequence[sharding_util.TensorSlice]: - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter + return [] sharding_callback = TensorRemovalCallback() shardable_tensors = self._get_shardable_tensors(root) From cff48a3f8b0eb9aa6d5cc1bf5411fe9801d1978b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2023 17:54:37 -0800 Subject: [PATCH 267/381] Add comments to the CallORToolsSolver to specify what fields of the request object the problem formulation variables correspond to. PiperOrigin-RevId: 586842444 --- .../auto_sharding/auto_sharding_solver.cc | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index e44f8ed7b707ec..38865d3454bdeb 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -155,7 +155,12 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { return minimum_memory_budget_required_estimate; } -// We formulate the auto sharding process as the following ILP problem: +// Taking an auto-sharding problem (`request`) as an input, calls the OR tools +// CP-SAT solver and results a solution to the input problem. +// +// We formulate the auto sharding process as the following ILP problem +// (correspondences to the fields of the request parameter are specified in +// parenthesis): // Variables: // s[i]: Sharding strategy one-hot vector. // dim(s[i]) == # sharding strategies of the i-th XLA op @@ -163,19 +168,22 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { // e[i, j]: Strategy one-hot vector of edge i -> j. // dim(e[i, j]) == dim(s[i]) * dim(s[j]) // Constants: -// N: Number of total XLA ops -// M: Memory budget -// E: Edge set {(i, j)} -// L[t]: Index of live instructions at time t -// c[i]: Computation cost vector of instruction i +// N: Number of total XLA ops (request.num_nodes) +// M: Memory budget (request.memory_budget) +// E: Edge set {(i, j)} (request.edges) +// L[t]: Index of live instructions at time t (request.live) +// c[i]: Computation cost vector of instruction i (request.computation_costs) // d[i]: Communication cost vector of instruction i -// m[i]: Memory cost vector of instruction i +// (request.communication_costs) +// m[i]: Memory cost vector of instruction i (request.memory_costs) // dim(c[i]) == dim(d[i]) == dim(m[i]) == dim(s[i]) // r[i, j]: The resharding cost vector of edge i -> j +// (request.resharding_costs) // dim(e[i, j]) == dim(r[i, j]) -// A: Alias set {(i, j)} +// A: Alias set {(i, j)} (request.aliases) // v[i, j]: v[i, j](p, q) == 1 if strategy p is different than q, otherwise // v[i, j](p, q) == 0 +// (request.value_costs) // dim(e[i, j]) == dim(v[i, j]) // Problem: // Minimize sum_{0 <= i < N} s[i]^T * (c[i] + d[i]) @@ -200,6 +208,16 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { // Serialize parameters of the ILP problem as numpy arrays and call the python // solver. +// Beyond what is described, note the following: +// 1. 
We also enforce that certain HLO ops have the same sharding as some other +// HLO ops (think elementwise ops, for example). This information stored in +// request.s_follow, where if s_follow[i] >= 0, then instruction i is forced +// the share same sharding as s_follow[i]. +// 2. If request.overbudget_coeff is present, we turn the hard memory budget +// constraint into a soft constraint instead. +// 3. If request.makespan_coeff is present, the objective additionally includes +// a makespan term. This is experimental and turned off by default. +// 4. request.max_departures is used only for debugging and can be ignored AutoShardingSolverResult CallORToolsSolver( const AutoShardingSolverRequest& request) { size_t num_edges = request.edges_size(); From c810fe7b8ad63488e7e11b63c37030909f64f712 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 30 Nov 2023 19:42:47 -0800 Subject: [PATCH 268/381] [xla:gpu] Add support for control flow command buffer commands to Thunk PiperOrigin-RevId: 586861177 --- .../gpu/runtime3/command_buffer_cmd.cc | 174 +++++++++++- .../service/gpu/runtime3/command_buffer_cmd.h | 96 ++++++- .../gpu/runtime3/command_buffer_thunk_test.cc | 251 ++++++++++++++++++ .../xla/xla/stream_executor/command_buffer.h | 2 +- 4 files changed, 509 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 854f0bd790e404..c9c6b505971486 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -42,6 +43,27 @@ limitations under the License. namespace xla::gpu { +// Creates condition command buffer builder from a cmd sequence. +static se::CommandBuffer::Builder ConditionBuilder( + CommandBufferCmdSequence* commands, + const CommandBufferCmd::RecordParams* params) { + return [=](se::CommandBuffer* command_buffer) { + return commands->Record(*params, command_buffer, + CommandBufferCmdSequence::RecordMode::kConditional); + }; +} + +// Creates condition command buffer builders from a span of cmd sequences. 
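+//
+// Note that the builders returned by these helpers capture raw pointers to
+// the command sequences and to the record parameters, so both are assumed to
+// outlive any command buffer recording that invokes the builders.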
+static std::vector ConditionBuilders( + absl::Span commands, + const CommandBufferCmd::RecordParams* params) { + std::vector builders; + for (CommandBufferCmdSequence& cmd : commands) { + builders.push_back(ConditionBuilder(&cmd, params)); + } + return builders; +} + //===----------------------------------------------------------------------===// // CommandBufferCmdSequence //===----------------------------------------------------------------------===// @@ -186,12 +208,13 @@ CommandBufferCmd::Slices MemcpyDeviceToDeviceCmd::slices() { // IfCmd //===----------------------------------------------------------------------===// -IfCmd::IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_cmds) - : pred_(pred), then_cmds_(std::move(then_cmds)) {} +IfCmd::IfCmd(BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands) + : pred_(pred), then_commands_(std::move(then_commands)) {} Status IfCmd::Initialize(se::StreamExecutor* executor, ExecutableSource source) { - return then_cmds_.Initialize(executor, source); + return then_commands_.Initialize(executor, source); } Status IfCmd::Record(const RecordParams& params, @@ -199,17 +222,146 @@ Status IfCmd::Record(const RecordParams& params, se::DeviceMemoryBase pred = params.buffer_allocations->GetDeviceAddress(pred_); - return command_buffer->If( - params.executor, se::DeviceMemory(pred), - [&](se::CommandBuffer* then_cmd_buffer) { - return then_cmds_.Record( - params, then_cmd_buffer, - CommandBufferCmdSequence::RecordMode::kConditional); - }); + return command_buffer->If(params.executor, se::DeviceMemory(pred), + ConditionBuilder(&then_commands_, ¶ms)); } CommandBufferCmd::Slices IfCmd::slices() { - auto& slices = then_cmds_.slices(); + absl::flat_hash_set slices = {pred_}; + slices.insert(then_commands_.slices().begin(), then_commands_.slices().end()); + return {slices.begin(), slices.end()}; +} + +//===----------------------------------------------------------------------===// +// IfElseCmd +//===----------------------------------------------------------------------===// + +IfElseCmd::IfElseCmd(BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands, + CommandBufferCmdSequence else_commands) + : pred_(pred), + then_commands_(std::move(then_commands)), + else_commands_(std::move(else_commands)) {} + +Status IfElseCmd::Initialize(se::StreamExecutor* executor, + ExecutableSource source) { + TF_RETURN_IF_ERROR(then_commands_.Initialize(executor, source)); + TF_RETURN_IF_ERROR(else_commands_.Initialize(executor, source)); + return OkStatus(); +} + +Status IfElseCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase pred = + params.buffer_allocations->GetDeviceAddress(pred_); + + return command_buffer->IfElse(params.executor, se::DeviceMemory(pred), + ConditionBuilder(&then_commands_, ¶ms), + ConditionBuilder(&else_commands_, ¶ms)); +} + +CommandBufferCmd::Slices IfElseCmd::slices() { + absl::flat_hash_set slices = {pred_}; + slices.insert(then_commands_.slices().begin(), then_commands_.slices().end()); + slices.insert(else_commands_.slices().begin(), else_commands_.slices().end()); + return {slices.begin(), slices.end()}; +} + +//===----------------------------------------------------------------------===// +// CaseCmd +//===----------------------------------------------------------------------===// + +CaseCmd::CaseCmd(BufferAllocation::Slice index, + std::vector branches_commands) + : index_(index), branches_commands_(std::move(branches_commands)) {} + +Status 
CaseCmd::Initialize(se::StreamExecutor* executor, + ExecutableSource source) { + for (auto& branch : branches_commands_) { + TF_RETURN_IF_ERROR(branch.Initialize(executor, source)); + } + return OkStatus(); +} + +Status CaseCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase index = + params.buffer_allocations->GetDeviceAddress(index_); + + return command_buffer->Case( + params.executor, se::DeviceMemory(index), + ConditionBuilders(absl::MakeSpan(branches_commands_), ¶ms)); +} + +CommandBufferCmd::Slices CaseCmd::slices() { + absl::flat_hash_set slices = {index_}; + for (auto& branch : branches_commands_) { + slices.insert(branch.slices().begin(), branch.slices().end()); + } + return {slices.begin(), slices.end()}; +} + +//===----------------------------------------------------------------------===// +// ForCmd +//===----------------------------------------------------------------------===// + +ForCmd::ForCmd(int32_t num_iterations, BufferAllocation::Slice loop_counter, + CommandBufferCmdSequence body_commands) + : num_iterations_(num_iterations), + loop_counter_(loop_counter), + body_commands_(std::move(body_commands)) {} + +Status ForCmd::Initialize(se::StreamExecutor* executor, + ExecutableSource source) { + return body_commands_.Initialize(executor, source); +} + +Status ForCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase loop_counter = + params.buffer_allocations->GetDeviceAddress(loop_counter_); + + return command_buffer->For(params.executor, num_iterations_, + se::DeviceMemory(loop_counter), + ConditionBuilder(&body_commands_, ¶ms)); +} + +CommandBufferCmd::Slices ForCmd::slices() { + absl::flat_hash_set slices = {loop_counter_}; + slices.insert(body_commands_.slices().begin(), body_commands_.slices().end()); + return {slices.begin(), slices.end()}; +} + +//===----------------------------------------------------------------------===// +// WhileCmd +//===----------------------------------------------------------------------===// + +WhileCmd::WhileCmd(BufferAllocation::Slice pred, + CommandBufferCmdSequence cond_commands, + CommandBufferCmdSequence body_commands) + : pred_(pred), + cond_commands_(std::move(cond_commands)), + body_commands_(std::move(body_commands)) {} + +Status WhileCmd::Initialize(se::StreamExecutor* executor, + ExecutableSource source) { + return body_commands_.Initialize(executor, source); +} + +Status WhileCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase pred = + params.buffer_allocations->GetDeviceAddress(pred_); + + return command_buffer->While(params.executor, se::DeviceMemory(pred), + ConditionBuilder(&cond_commands_, ¶ms), + ConditionBuilder(&body_commands_, ¶ms)); +} + +CommandBufferCmd::Slices WhileCmd::slices() { + absl::flat_hash_set slices = {pred_}; + slices.insert(cond_commands_.slices().begin(), cond_commands_.slices().end()); + slices.insert(body_commands_.slices().begin(), body_commands_.slices().end()); return {slices.begin(), slices.end()}; } diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index 699c9a43283fb4..31be131338ecbe 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -197,7 +197,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { class IfCmd : public CommandBufferCmd { public: - 
IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_cmds); + IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_commands); Status Initialize(se::StreamExecutor* executor, ExecutableSource source) override; @@ -209,7 +209,99 @@ class IfCmd : public CommandBufferCmd { private: BufferAllocation::Slice pred_; - CommandBufferCmdSequence then_cmds_; + CommandBufferCmdSequence then_commands_; +}; + +//===----------------------------------------------------------------------===// +// IfElseCmd +//===----------------------------------------------------------------------===// + +class IfElseCmd : public CommandBufferCmd { + public: + IfElseCmd(BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands, + CommandBufferCmdSequence else_commands); + + Status Initialize(se::StreamExecutor* executor, + ExecutableSource source) override; + + Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + BufferAllocation::Slice pred_; + CommandBufferCmdSequence then_commands_; + CommandBufferCmdSequence else_commands_; +}; + +//===----------------------------------------------------------------------===// +// CaseCmd +//===----------------------------------------------------------------------===// + +class CaseCmd : public CommandBufferCmd { + public: + CaseCmd(BufferAllocation::Slice index, + std::vector branches_commands); + + Status Initialize(se::StreamExecutor* executor, + ExecutableSource source) override; + + Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + BufferAllocation::Slice index_; + std::vector branches_commands_; +}; + +//===----------------------------------------------------------------------===// +// ForCmd +//===----------------------------------------------------------------------===// + +class ForCmd : public CommandBufferCmd { + public: + ForCmd(int32_t num_iterations, BufferAllocation::Slice loop_counter, + CommandBufferCmdSequence body_commands); + + Status Initialize(se::StreamExecutor* executor, + ExecutableSource source) override; + + Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + int32_t num_iterations_; + BufferAllocation::Slice loop_counter_; + CommandBufferCmdSequence body_commands_; +}; + +//===----------------------------------------------------------------------===// +// WhileCmd +//===----------------------------------------------------------------------===// + +class WhileCmd : public CommandBufferCmd { + public: + WhileCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence cond_commands, + CommandBufferCmdSequence body_commands); + + Status Initialize(se::StreamExecutor* executor, + ExecutableSource source) override; + + Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + BufferAllocation::Slice pred_; + CommandBufferCmdSequence cond_commands_; + CommandBufferCmdSequence body_commands_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index f976cdecdbf36b..94ec906b84441d 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ 
b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -84,6 +84,7 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Execute command buffer thunk and verify that it copied the memory. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::vector dst(4, 0); @@ -96,6 +97,7 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -151,6 +153,7 @@ TEST(CommandBufferThunkTest, AllocateCmd) { // Execute command buffer thunk and verify that it copied the memory. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::vector dst(4, 0); @@ -204,6 +207,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::vector dst(4, 0); @@ -220,6 +224,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -232,6 +237,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -302,6 +308,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Execute command buffer thunk and verify that it executed a GEMM. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `out` data back to host. std::vector dst(6, 0); @@ -319,6 +326,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `updated_out` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -331,6 +339,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `updated_out` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -394,6 +403,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::vector dst(4, 0); @@ -417,6 +427,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -433,6 +444,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. 
std::fill(dst.begin(), dst.end(), 0); @@ -501,6 +513,7 @@ TEST(CommandBufferThunkTest, IfCmd) { // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `b` data back to host. std::vector dst(4, 0); @@ -517,6 +530,7 @@ TEST(CommandBufferThunkTest, IfCmd) { // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); // Copy `c` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -525,4 +539,241 @@ TEST(CommandBufferThunkTest, IfCmd) { ASSERT_EQ(dst, std::vector(4, 42 + 42)); } +TEST(CommandBufferThunkTest, IfElseCmd) { + se::StreamExecutor* executor = CudaExecutor(); + if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: pred=true, a=42, b=0 + se::DeviceMemory pred = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + stream.ThenMemcpy(&pred, &kTrue, 1); + stream.ThenMemset32(&a, 42, byte_length); + stream.ThenMemZero(&b, byte_length); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_p(&alloc_p, 0, 1); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + // Prepare commands sequence for `then` & `else` branches. + CommandBufferCmdSequence then_commands; + CommandBufferCmdSequence else_commands; + + { // Then: b = a + a + auto args = {slice_a, slice_a, slice_b}; + then_commands.Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + { // Else: b = b + b + auto args = {slice_b, slice_b, slice_b}; + else_commands.Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(slice_p, std::move(then_commands), + std::move(else_commands)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + CommandBufferCmd::ExecutableSource source = { + /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize(executor, source)); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Change branch to `else` and check that it updated the `b` buffer. 
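+  // With `pred` flipped to false, the `else` branch computes b = b + b, so
+  // the buffer is expected to hold 2 * (42 + 42) after this execution.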
+ constexpr bool kFalse = false; + stream.ThenMemcpy(&pred, &kFalse, 1); + + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + stream.ThenMemcpy(dst.data(), b, byte_length); + ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); +} + +TEST(CommandBufferThunkTest, CaseCmd) { + se::StreamExecutor* executor = CudaExecutor(); + if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: index=0, a=42, b=0 + se::DeviceMemory index = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + stream.ThenMemset32(&index, 0, sizeof(int32_t)); + stream.ThenMemset32(&a, 42, byte_length); + stream.ThenMemZero(&b, byte_length); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_i(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_i(&alloc_i, 0, sizeof(int32_t)); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + // Prepare commands sequence for branches. + std::vector branches(2); + + { // Case 0: b = a + a + auto args = {slice_a, slice_a, slice_b}; + branches[0].Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + { // Case 1: b = b + b + auto args = {slice_b, slice_b, slice_b}; + branches[1].Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(slice_i, std::move(branches)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({index, a, b}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + CommandBufferCmd::ExecutableSource source = { + /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize(executor, source)); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Change `index` to `1` and check that it updated the `b` buffer. 
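+  // Branch 1 computes b = b + b, so after this execution the buffer is
+  // expected to hold 2 * (42 + 42).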
+ stream.ThenMemset32(&index, 1, sizeof(int32_t)); + + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + stream.ThenMemcpy(dst.data(), b, byte_length); + ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); +} + +TEST(CommandBufferThunkTest, ForCmd) { + se::StreamExecutor* executor = CudaExecutor(); + if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: loop_cnt=0, a=1, b=0 + se::DeviceMemory loop_cnt = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + stream.ThenMemset32(&loop_cnt, 0, sizeof(int32_t)); + stream.ThenMemset32(&a, 1, byte_length); + stream.ThenMemZero(&b, byte_length); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_cnt(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_cnt(&alloc_cnt, 0, sizeof(int32_t)); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_b, slice_b}; // b = a + b + + // Prepare commands sequence for loop `body`. + CommandBufferCmdSequence body_commands; + body_commands.Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(/*num_iterations=*/10, slice_cnt, + std::move(body_commands)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({loop_cnt, a, b}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + CommandBufferCmd::ExecutableSource source = { + /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize(executor, source)); + + // Execute command buffer thunk and verify that it added the value 10 times. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + + ASSERT_EQ(dst, std::vector(4, 10)); +} + +TEST(CommandBufferThunkTest, WhileCmd) { + // TODO(ezhulenev): Find a way to test WhileCmd: add a test only TraceCmd that + // could allow us trace custom kernels to update while loop iterations. Or + // maybe add a CustomLaunchCmd and wrap loop update into custom kernel. +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 3ac378acbdf171..2867664854b513 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -163,7 +163,7 @@ class CommandBuffer { // Adds a conditional operation that will execute a command buffer constructed // by the `cond_builder` that must update `pred` value, and then depending on // the value might execute command buffer constructed by `body_builder` and - // `cond_builder`. 
Will continue while `pred` value (which is continously + // `cond_builder`. Will continue while `pred` value (which is continuously // updated by `cond_builder`) is `true`. // // In pseudocode: From 3a029b19c9c156cd68cab671b5ce95bde839f15e Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Thu, 30 Nov 2023 23:28:43 -0800 Subject: [PATCH 269/381] [HloValueSemanticsAnalysis] Use node_hash_map instead since the current implementation requires pointer stability. PiperOrigin-RevId: 586907653 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/hlo_value_semantics_analysis.cc | 2 +- third_party/xla/xla/service/hlo_value_semantics_analysis.h | 5 +++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 8503cec2199642..2b4fdcc3699224 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4405,6 +4405,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index 5720bfd49b6fba..e74d1e8ea7238e 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -314,7 +314,7 @@ Status EinsumDepthAnalysis::HandleGetTupleElement( if (operand_depth.IsLeaf(shape_index)) { ShapeIndex output_index = shape_index; output_index.pop_front(); - *depth_ptr = std::max(*depth_ptr, depth_tree.element(output_index)); + *depth_ptr = MergeDepth(*depth_ptr, depth_tree.element(output_index)); } }); return OkStatus(); diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index 8cea61acf0e7ec..895f0583390601 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -51,7 +52,7 @@ class HloPreOrderDFS { }; using EinsumDepthMap = - absl::flat_hash_map>; + absl::node_hash_map>; // The einsum depth is the length of the einsum dependency chain. And we // distinguish instructions that are used by root and that are not used by @@ -145,7 +146,7 @@ class HloValueSemantics { }; using HloValueSemanticsMap = - absl::flat_hash_map>; class HloValueSemanticsPropagation; From 86838744e3f072003b66c25654e4b4c5c7f7b269 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Thu, 30 Nov 2023 23:46:05 -0800 Subject: [PATCH 270/381] PR #7106: Pre-prepare profiler annotations for kernels/thunks and modules Imported from GitHub PR https://github.com/openxla/xla/pull/7106 There are two parts to this change: - `tsl::profiler::ScopedAnnotation` and friends can now accept annotations that are custom structs as well as plain `std::string` annotations. 
This allows optimisations when used to emit NVTX ranges, such as using [registered strings](https://nvidia.github.io/NVTX/doxygen/group___s_t_r_i_n_g___r_e_g_i_s_t_r_a_t_i_o_n.html). - Annotations for HLO module executions and individual kernel/thunk launches within are now prepared in advance, rather than on the fly. This makes it reasonable to generate more complex annotations, and this is done using the `op_name` metadata: ![Screenshot 2023-11-14 at 20 54 21](https://github.com/openxla/xla/assets/6459623/06821dbd-935d-4158-8cee-c1194e22a183) Prior to this change, the annotation title was simply `Thunk:#hlo_op=wrapped_slice.91#`. Also add annotations on some compilation stages, improving the profiling experience by making it obvious which time is spent autotuning kernels: ![Screenshot 2023-11-14 at 20 55 48](https://github.com/openxla/xla/assets/6459623/c1d3a644-894d-4bb2-aa64-c8aac1bec369) Copybara import of the project: -- d984c41a461bed177c6f6f3c8e1535b6c5545d6f by Olli Lupton : Pre-prepare NVTX annotations for kernels/modules. Generate improved names based on op_name metadata. Placeholder for future expansion of these pre-prepared annotation structs. -- 5d191bcf72b0659d2af15275891230265bcd7582 by Olli Lupton : Fix for non-NVTX builds. -- fc3a72aab089e79d3ef9b59dfaaf0b87ed542fb2 by Olli Lupton : Address explicit code review -- 427c0dea40facb91efcc231b840e39fd96e2f6bc by Olli Lupton : Apply rules from code review elsewhere in diff -- f92c4769257d941b2dd6fe70e6ab162cc74a397c by Olli Lupton : Fix test, do not throw exceptions Merging this change closes #7106 PiperOrigin-RevId: 586911081 --- .../third_party/tsl/tsl/profiler/lib/BUILD | 8 +- .../tsl/tsl/profiler/lib/nvtx_utils.h | 50 ++-- .../tsl/tsl/profiler/lib/scoped_annotation.h | 27 +-- .../profiler/lib/scoped_annotation_stack.h | 19 +- third_party/xla/xla/service/BUILD | 4 + third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/gpu_executable.cc | 50 ++-- .../xla/xla/service/gpu/gpu_executable.h | 3 + third_party/xla/xla/service/gpu/runtime/BUILD | 15 +- .../xla/xla/service/gpu/runtime/annotation.cc | 229 ++++++++++++++++++ .../xla/xla/service/gpu/runtime/annotation.h | 59 +++++ .../xla/xla/service/gpu/runtime/tracing.cc | 21 +- .../xla/xla/service/gpu/runtime/tracing.h | 4 + .../xla/xla/service/hlo_pass_pipeline.cc | 19 ++ third_party/xla/xla/service/service.cc | 8 + 15 files changed, 451 insertions(+), 66 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/runtime/annotation.cc create mode 100644 third_party/xla/xla/service/gpu/runtime/annotation.h diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index 70fe322adcda52..c23d63f5f4eddd 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -3,6 +3,10 @@ load("//tsl/platform:build_config_root.bzl", "if_static") load("//tsl:tsl.default.bzl", "filegroup") load("//tsl:tsl.bzl", "if_not_android", "set_external_visibility") load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load( + "//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) load( "//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts", @@ -252,8 +256,8 @@ cc_library( "//tsl/platform:macros", "//tsl/platform:types", "@com_google_absl//absl/strings", - ] + if_not_android([ - "//tsl/profiler/backends/cpu:annotation_stack", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", # NVTX headers ]), ) diff 
--git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h index 416d8293784551..e3eaaa08af79e8 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h @@ -24,18 +24,17 @@ limitations under the License. #if GOOGLE_CUDA #include "nvtx3/nvToolsExt.h" +#else +// Some typedef to help build without NVTX. +typedef void* nvtxEventAttributes_t; +typedef void* nvtxDomainHandle_t; +typedef void* nvtxStringHandle_t; #endif namespace tsl { namespace profiler { namespace nvtx { -// Some typedef to help build without NVTX. -#if !GOOGLE_CUDA -typedef void* nvtxEventAttributes_t; -typedef void* nvtxDomainHandle_t; -#endif - // A helper function that return the domains to use if NVTX profiling // is enabled. inline std::optional GetNVTXDomain() { @@ -65,15 +64,38 @@ inline bool RangesEnabled() { #endif } -// Note: The memory backing msg must persist until the result of this function -// has been consumed by an NVTX API. -inline void MakeAttributes(const char* msg, nvtxEventAttributes_t* result) { - *result = {0}; +// Two types of NVTX range annotation are supported, the older/simpler option +// is to use std::string and have the NVTX implementation copy a C-style +// string every time. The other option is to pass a struct implementing two +// methods: +// +// std::string_view Title() const; +// nvtxStringHandle_t NvtxRegisteredTitle() const; +// +// in which case NvtxRegisteredTitle() will be used when starting NVTX ranges, +// avoiding this string copy. +// The Title() method is needed because AnnotationStack::PushAnnotation(...) is +// the backend for some annotations when NVTX is not enabled, and it does not +// recognise registered strings. has_annotation_api_v +// distinguishes between the two types of annotation. 
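+//
+// As a sketch (this exact type is hypothetical, not defined in this header),
+// an annotation struct satisfying the second option could look like:
+//
+//   struct RegisteredAnnotation {
+//     std::string title;
+//     nvtxStringHandle_t handle;  // e.g. obtained via nvtxDomainRegisterStringA
+//     std::string_view Title() const { return title; }
+//     nvtxStringHandle_t NvtxRegisteredTitle() const { return handle; }
+//   };
+//
+// RangePush below detects this API via has_annotation_api_v and pushes the
+// registered string handle instead of copying a C string on every push.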
+template +inline constexpr bool has_annotation_api_v = + !std::is_same_v; + +template +void RangePush(nvtxDomainHandle_t domain, const AnnotationType& annotation) { #if GOOGLE_CUDA - result->version = NVTX_VERSION; - result->size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; - result->messageType = NVTX_MESSAGE_TYPE_ASCII; - result->message.ascii = msg; + nvtxEventAttributes_t attrs{}; + attrs.version = NVTX_VERSION; + attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + if constexpr (has_annotation_api_v>) { + attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED; + attrs.message.registered = annotation.NvtxRegisteredTitle(); + } else { + attrs.messageType = NVTX_MESSAGE_TYPE_ASCII; + attrs.message.ascii = annotation.c_str(); + } + ::nvtxDomainRangePushEx(domain, &attrs); #endif } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h index 643d7045428605..f047fafc4ebe3a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h @@ -53,10 +53,7 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), std::string{name}); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -74,9 +71,7 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -91,9 +86,7 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -109,15 +102,17 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - auto name = name_generator(); - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name_generator()); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - auto name = name_generator(); - old_length_ = AnnotationStack::PushAnnotation(name); + auto annotation = name_generator(); + if constexpr (tsl::profiler::nvtx::has_annotation_api_v< + std::decay_t>) { + old_length_ = AnnotationStack::PushAnnotation(annotation.Title()); + } else { + old_length_ = AnnotationStack::PushAnnotation(std::move(annotation)); + } } #endif } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h index 
f4e538f127c9bb..db46f7c99135e4 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h @@ -55,10 +55,7 @@ class ScopedAnnotationStack { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name); } else // NOLINT #endif if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -83,15 +80,17 @@ class ScopedAnnotationStack { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - auto name = name_generator(); - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name_generator()); } else // NOLINT #endif if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - return AnnotationStack::PushAnnotation(name_generator()); + auto annotation = name_generator(); + if constexpr (tsl::profiler::nvtx::has_annotation_api_v< + std::decay_t>) { + return AnnotationStack::PushAnnotation(annotation.Title()); + } else { + return AnnotationStack::PushAnnotation(std::move(annotation)); + } } #endif return kInvalidActivity; diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2b4fdcc3699224..b1df56793cdf59 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1116,6 +1116,7 @@ cc_library( name = "service", srcs = ["service.cc"], hdrs = ["service.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":allocation_tracker", @@ -1163,6 +1164,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/profiler/lib:scoped_annotation", ], alwayslink = 1, ) @@ -5100,6 +5102,7 @@ cc_library( hdrs = [ "hlo_pass_pipeline.h", ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":compilation_stats", @@ -5119,6 +5122,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/profiler/lib:scoped_annotation", ], ) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2e63d3680bf7d4..6b86f8ab1ba88a 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1018,6 +1018,7 @@ cc_library( "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/runtime:executable", "//xla/service/gpu/runtime:support", + "//xla/service/gpu/runtime:tracing", "//xla/service/gpu/runtime3:custom_call_thunk", "//xla/service/gpu/runtime3:fft_thunk", "//xla/stream_executor", diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 05283bbdc2b8a5..c9a757922e3356 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" #include "xla/service/gpu/runtime/executable.h" +#include "xla/service/gpu/runtime/tracing.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/thunk.h" #include "xla/service/hlo_parser.h" @@ -143,6 +144,9 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) *(uint64_t*)(&binary_[binary_.size() - 16]) = tsl::EnvTime::NowNanos(); *(uint64_t*)(&binary_[binary_.size() - 8]) = tsl::random::New64(); #endif + if (has_module()) { + annotation_info_.emplace(module()); + } if (has_module() && enable_debug_info_manager_) { XlaDebugInfoManager::Get()->RegisterModule(shared_module(), buffer_assignment_->ToProto()); @@ -227,15 +231,6 @@ Status ExecuteThunks(const std::string& module_name, ModuleIdentifier module_id, [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, tsl::profiler::TraceMeLevel::kInfo); - ScopedAnnotationAlways annotation([&] { - std::string module_id_str; - if (module_id >= 0) { - module_id_str = absl::StrFormat(",program_id=%d", module_id); - } - return absl::StrFormat("XlaModule:#hlo_module=%s%s#", module_name, - module_id_str); - }); - for (const std::unique_ptr& thunk : thunk_sequence) { // Annotate execution of this op if tracing was enabled when we started // running this module. If tracing is enabled *while* we're running the @@ -514,15 +509,6 @@ static Status ExecuteXlaRuntime(const std::string& module_name, [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, tsl::profiler::TraceMeLevel::kInfo); - ScopedAnnotationAlways annotation([&] { - std::string module_id_str; - if (module_id >= 0) { - module_id_str = absl::StrFormat(",program_id=%d", module_id); - } - return absl::StrFormat("XlaModule:#hlo_module=%s%s#", module_name, - module_id_str); - }); - auto executed = gpu_runtime_executable.Execute( run_options, asm_text, binary, buffer_allocations, gpu_lock, temp_buffer); if (!executed.ok()) return executed; @@ -755,6 +741,24 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( return std::move(result); } +namespace { +struct ModuleAnnotationManager { + ModuleAnnotationManager(const std::optional& annotations) { + if (annotations.has_value()) { + m_old_annotations = SetCurrentModuleAnnotations(&(*annotations)); + } + } + ~ModuleAnnotationManager() { + if (m_old_annotations.has_value()) { + SetCurrentModuleAnnotations(*m_old_annotations); + } + } + + private: + std::optional m_old_annotations; +}; +} // namespace + Status GpuExecutable::ExecuteThunksOrXlaRuntime( const ServiceExecutableRunOptions* run_options, const BufferAllocations& buffer_allocations, bool block_host_until_done, @@ -768,6 +772,15 @@ Status GpuExecutable::ExecuteThunksOrXlaRuntime( unique_id = module().unique_id(); } + ScopedAnnotationAlways annotation([&]() -> ModuleAnnotation { + if (annotation_info_) { + return annotation_info_->top_level; + } else { + return {module_name_, unique_id}; + } + }); + ModuleAnnotationManager set_current_kernel_annotations{annotation_info_}; + if (thunks_) { se::StreamExecutor* executor = run_options->stream()->parent(); Thunk::ExecutableSource executable_source = {text_, binary_}; @@ -952,6 +965,7 @@ GpuExecutable::GpuExecutable( output_info_(std::move(output_info)), enable_debug_info_manager_(true) { if (has_module()) { + annotation_info_.emplace(module()); XlaDebugInfoManager::Get()->RegisterModule(shared_module(), BufferAssignmentProto()); } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h 
b/third_party/xla/xla/service/gpu/gpu_executable.h index 060583867ed965..67c9eb45a158d1 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" +#include "xla/service/gpu/runtime/annotation.h" #include "xla/service/gpu/runtime/executable.h" #include "xla/service/gpu/thunk.h" #include "xla/service/hlo_execution_profile.h" @@ -307,6 +308,8 @@ class GpuExecutable : public Executable { // This object is also used for dumping debug info. std::unique_ptr buffer_assignment_; + std::optional annotation_info_; + bool enable_persistent_temp_buffers_ = false; absl::Mutex persistent_temp_buffers_mu_; diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 57548606ea9e31..ac1ebe27e21435 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -761,20 +761,27 @@ cc_library( cc_library( name = "tracing", - srcs = ["tracing.cc"], - hdrs = ["tracing.h"], + srcs = [ + "annotation.cc", + "tracing.cc", + ], + hdrs = [ + "annotation.h", + "tracing.h", + ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":support", - "//xla/runtime:custom_call", + "//xla/hlo/ir:hlo", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", "//xla/runtime:tracing", "//xla/runtime:type_id", - "@com_google_absl//absl/status", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/profiler/lib:nvtx_utils", "@local_tsl//tsl/profiler/lib:scoped_annotation_stack", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.cc b/third_party/xla/xla/service/gpu/runtime/annotation.cc new file mode 100644 index 00000000000000..8b3f56ec00b8e0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime/annotation.cc @@ -0,0 +1,229 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/runtime/annotation.h" + +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" + +namespace xla::gpu { + +namespace { +nvtxStringHandle_t RegisterString(const char* str) { +#if GOOGLE_CUDA + auto domain = tsl::profiler::nvtx::GetNVTXDomain(); + if (!domain) { + // NVTX not enabled, so don't bother registering strings with it + return {}; + } + std::string buffer{}; + constexpr auto max_length = 65330; + if (auto const length = std::strlen(str); length >= max_length) { + // nvbugs 4340868 + std::string_view suffix{"\n[truncated]\n"}; + buffer.reserve(max_length); + buffer.assign(str, str + length - suffix.size()); + buffer.append(suffix); + str = buffer.c_str(); + } + return nvtxDomainRegisterStringA(*domain, str); +#else + return {}; +#endif +} + +template +Status VisitInstAndCalledButNotOperands(Visitor& visitor, + const HloInstruction& inst) { + // Visit the given instruction, and the things it calls, but not its operands. + TF_RETURN_IF_ERROR(visitor.DefaultAction(&inst)); + for (const HloComputation* called : inst.called_computations()) { + const HloInstruction* const root = called->root_instruction(); + TF_RETURN_IF_ERROR(root->Accept(&visitor, false /* call_finish_visit */, + true /* ignore_control_predecessors */, + true /* cross_computation */)); + } + return OkStatus(); +} + +// Split `a` and `b` by `delim` into two lists of possibly-empty tokens, then +// rejoin the first N of those lists that match by `delim`. Note: it is +// unspecified which argument the return value points into. +std::string_view LongestPrefix(std::string_view a, std::string_view b, + char delim = '/') { + if (a.size() > b.size()) a.swap(b); // allow assumption that b is longer + for (auto start_a = a.begin(), iter_a = start_a, start_b = b.begin(), + iter_b = start_b; + ; ++iter_a, ++iter_b) { + if (iter_a == a.end() && (iter_b == b.end() || *iter_b == delim)) { + // reached both ends without finding a mismatch, or reached the end of `a` + // and not `b` but it was the end of the chunk in `b` + return a; + } + if (*iter_a != *iter_b) { + // mismatch in this chunk + return {a.begin(), + static_cast(std::distance(a.begin(), start_a))}; + } + if (*iter_a == delim) { + // end of this chunk, start the next one + start_a = iter_a; + start_b = iter_b; + } + } +} + +// Find the longest prefix among instructions' op_name metadata +// Chunk this by delimiting slashes, i.e. given a/b/cat and a/b/cabbage, the +// longest prefix is a/b not a/b/ca +class OpNamePrefixVisitor : public ConstDfsHloVisitorWithDefault { + public: + Status DefaultAction(const HloInstruction* inst) final { + auto const& op_name = inst->metadata().op_name(); + if (!op_name.empty()) { + prefix = prefix ? LongestPrefix(*prefix, op_name) : op_name; + } + return OkStatus(); + } + std::string_view longest_op_name_prefix() const { + return prefix.value_or(std::string_view{}); + } + + private: + std::optional prefix{}; +}; + +std::string_view GetLongestOpNamePrefix(const HloModule& mod) { + // In the presence of (at least) debug callbacks, calling Accept on the root + // instruction of the module may not reach all instructions in the module. 
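+  // Iterating over every computation and instruction directly avoids that
+  // problem. As with the visitor above, op_names such as "a/b/cat" and
+  // "a/b/cabbage" reduce to the common prefix "a/b" (chunked at '/'), and
+  // instructions without op_name metadata are simply skipped by
+  // DefaultAction.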
+ OpNamePrefixVisitor visitor{}; + for (const HloComputation* computation : mod.computations()) { + for (const HloInstruction* inst : computation->instructions()) { + if (!visitor.DefaultAction(inst).ok()) { + return {}; + } + } + } + return visitor.longest_op_name_prefix(); +} + +std::string_view GetLongestOpNamePrefix(const HloInstruction& inst) { + OpNamePrefixVisitor visitor{}; + if (!VisitInstAndCalledButNotOperands(visitor, inst).ok()) { + return {}; + } + return visitor.longest_op_name_prefix(); +} + +std::string MakeTitle(const HloModule& mod, std::string_view longest_prefix) { + if (longest_prefix.empty()) { + return absl::StrFormat("XlaModule:#hlo_module=%s,program_id=%d#", + mod.name(), mod.unique_id()); + } + return absl::StrFormat("XlaModule:#prefix=%s,hlo_module=%s,program_id=%d#", + longest_prefix, mod.name(), mod.unique_id()); +} +} // namespace + +ModuleAnnotation::ModuleAnnotation(std::string module_name_, int module_id_) + : longest_prefix{}, + title_str{ + module_id_ >= 0 + ? absl::StrFormat("XlaModule:#hlo_module=%s,program_id=%d", + module_name_, module_id_) + : absl::StrFormat("XlaModule:#hlo_module=%s", module_name_)}, + title{RegisterString(title_str.c_str())} {} + +ModuleAnnotation::ModuleAnnotation(const HloModule& mod) + : longest_prefix{GetLongestOpNamePrefix(mod)}, + title_str{MakeTitle(mod, longest_prefix)}, + title{RegisterString(title_str.c_str())} {} + +std::string_view ModuleAnnotation::longest_op_name_prefix() const { + return longest_prefix; +} + +std::string_view ModuleAnnotation::Title() const { return title_str; } + +nvtxStringHandle_t ModuleAnnotation::NvtxRegisteredTitle() const { + return title; +} + +namespace { +std::string MakeKernelName(std::string_view prefix, + const HloInstruction& inst) { + // Sometimes an instruction doesn't have metadata, but the computations that + // it calls do have metadata. Consider all of those metadata op_name entries + // and attach the longest prefix to this launch. + std::string_view op_name = GetLongestOpNamePrefix(inst); + if (op_name.empty()) { + return absl::StrFormat("Thunk:#hlo_op=%s#", inst.name()); + } else if (op_name.substr(0, prefix.size()) != prefix) { + // the op_name we got for this instruction does not start with the prefix + // that we thought was common to all instructions in the module + return absl::StrFormat("Thunk:#name=%s,hlo_op=%s#", op_name, inst.name()); + } else { + // remove the prefix that's in the parent module annotation + auto short_name = op_name.substr(prefix.size()); + // remove the leading / if there is one (prefix might be an empty string) + if (!short_name.empty() && short_name.front() == '/') { + short_name = short_name.substr(1); + } + return absl::StrFormat("Thunk:#name=%s,hlo_op=%s#", short_name, + inst.name()); + } +} +} // namespace + +KernelAnnotation::KernelAnnotation(const ModuleAnnotation& module_annotation, + const HloInstruction& inst) + : title_str{MakeKernelName(module_annotation.longest_op_name_prefix(), + inst)}, + title{RegisterString(title_str.c_str())} {} + +std::string_view KernelAnnotation::Title() const { return title_str; } + +nvtxStringHandle_t KernelAnnotation::NvtxRegisteredTitle() const { + return title; +} + +ModuleAnnotations::ModuleAnnotations(const HloModule& mod) : top_level{mod} { + // loop through `mod` and populate `kernels` (string -> KernelAnnotation map) + // with the information we want to attach to individual kernels. 
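+  // Building this map once at executable construction time lets the tracing
+  // hooks (see ActivityStart in tracing.cc) look up a ready-made annotation
+  // by instruction name instead of formatting a "Thunk:#hlo_op=...#" string
+  // on every launch.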
+  for (const HloComputation* computation :
+       mod.computations()) {  // top-level blocks in the module
+    for (const HloInstruction* inst :
+         computation->instructions()) {  // statements within block
+      // working assumption: only custom calls and fusions end up with NVTX
+      // ranges named after them. bad assumption [at least partially]: cuda
+      // graph launches are not handled correctly
+      switch (inst->opcode()) {
+        case HloOpcode::kCustomCall:
+        case HloOpcode::kFusion: {
+          // e.g. inst.name is "fusion.6", inst.opcode is "kFusion" and called
+          // is ["fused_computation.5"], in which case the content of
+          // "fused_computation.5" ends up under an NVTX range called
+          // "fusion.6". We want to construct a useful annotation for that NVTX
+          // range based on the content of `inst`, including `called` etc.
+          // FIXME: using try_emplace here was sensitive to
+          // https://github.com/abseil/abseil-cpp/issues/388.
+          kernels.insert({inst->name(), {top_level, *inst}});
+        } break;
+        default:
+          break;
+      }
+    }
+  }
+}
+}  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.h b/third_party/xla/xla/service/gpu/runtime/annotation.h
new file mode 100644
index 00000000000000..c454d06e74e109
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/runtime/annotation.h
@@ -0,0 +1,59 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_
+#define XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "tsl/profiler/lib/nvtx_utils.h"
+
+namespace xla::gpu {
+// Prepared information for the top level NVTX/profiler range covering an
+// HloModule
+struct ModuleAnnotation {
+  ModuleAnnotation(std::string module_name, int module_id);
+  ModuleAnnotation(const HloModule& mod);
+  std::string_view longest_op_name_prefix() const;
+  nvtxStringHandle_t NvtxRegisteredTitle() const;
+  std::string_view Title() const;
+
+ private:
+  std::string longest_prefix;
+  std::string title_str;
+  nvtxStringHandle_t title{};
+};
+
+// Prepared information for a kernel/thunk/fusion/... within an HloModule
+struct KernelAnnotation {
+  KernelAnnotation(const ModuleAnnotation& module_annotation,
+                   const HloInstruction& inst);
+  nvtxStringHandle_t NvtxRegisteredTitle() const;
+  std::string_view Title() const;
+
+ private:
+  std::string title_str;
+  nvtxStringHandle_t title{};
+};
+// Parsed/prepared information for an HloModule that gets propagated to NVTX
+// ranges/profilers/... at execution time.
+struct ModuleAnnotations { + ModuleAnnotations(const HloModule&); + ModuleAnnotation top_level; + absl::flat_hash_map kernels{}; +}; +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/tracing.cc b/third_party/xla/xla/service/gpu/runtime/tracing.cc index 49767a85500ba9..e5d986db4b6250 100644 --- a/third_party/xla/xla/service/gpu/runtime/tracing.cc +++ b/third_party/xla/xla/service/gpu/runtime/tracing.cc @@ -44,11 +44,23 @@ void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry) { // Tracing custom calls implementation. //===----------------------------------------------------------------------===// +namespace { +thread_local const ModuleAnnotations* current_annotations{}; +} + static absl::StatusOr ActivityStart(runtime::HloTrace annotation) { SetCurrentTracingScope(annotation.hlo_op); + if (current_annotations) { + // We know which HloModule we belong to, and may have pre-prepared + // annotation structs ready to use + const auto iter = current_annotations->kernels.find(annotation.hlo_op); + if (iter != current_annotations->kernels.end()) { + // Have a pre-prepared annotation, use it + return ScopedAnnotationStack::ActivityStart([&] { return iter->second; }); + } + } return ScopedAnnotationStack::ActivityStart([&] { - // We use the same tracing annotation scheme as the ThunkSequence (see - // implementation of `GetThunkInfo` in `ir_emitter_unnested.cc`). + // We use the same tracing annotation scheme as the ThunkSequence. return absl::StrFormat("Thunk:#hlo_op=%s#", annotation.hlo_op); }); } @@ -73,5 +85,10 @@ void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry) { registry.Register("xla.trace.activity_end", End); } +const ModuleAnnotations* SetCurrentModuleAnnotations( + const ModuleAnnotations* annotations) { + return std::exchange(current_annotations, annotations); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/tracing.h b/third_party/xla/xla/service/gpu/runtime/tracing.h index 7f5efe48accac4..7446411d035010 100644 --- a/third_party/xla/xla/service/gpu/runtime/tracing.h +++ b/third_party/xla/xla/service/gpu/runtime/tracing.h @@ -20,6 +20,7 @@ limitations under the License. #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/type_id.h" +#include "xla/service/gpu/runtime/annotation.h" namespace xla { namespace gpu { @@ -28,6 +29,9 @@ void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry); void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry); +const ModuleAnnotations* SetCurrentModuleAnnotations( + const ModuleAnnotations* annotations); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/hlo_pass_pipeline.cc b/third_party/xla/xla/service/hlo_pass_pipeline.cc index b795e65faf4b42..52fdca45649009 100644 --- a/third_party/xla/xla/service/hlo_pass_pipeline.cc +++ b/third_party/xla/xla/service/hlo_pass_pipeline.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla { @@ -146,6 +147,18 @@ Status HloPassPipeline::RunInvariantCheckers( return OkStatus(); } +namespace { +std::string UniqueId(const HloModule& mod) { + return std::to_string(mod.unique_id()); +} +std::string UniqueId(const HloModuleGroup& group) { + return absl::StrJoin(group.modules(), "-", + [](std::string* out, const HloModule* mod) { + out->append(std::to_string(mod->unique_id())); + }); +} +} // namespace + template StatusOr HloPassPipeline::RunPassesInternal( HloT* hlo, const DebugOptions& debug_options, @@ -157,6 +170,10 @@ StatusOr HloPassPipeline::RunPassesInternal( static constexpr absl::string_view kPipelineStart = "pipeline-start"; static constexpr absl::string_view kPipelineEnd = "pipeline-end"; std::string pipeline_name = std::string(name()); + tsl::profiler::ScopedAnnotation annotation{[&] { + return absl::StrFormat("XlaPassPipeline:#name=%s,module=%s,program_id=%s#", + pipeline_name, hlo->name(), UniqueId(*hlo)); + }}; TF_RETURN_IF_ERROR( RunInvariantCheckers(hlo, kPipelineStart, execution_threads)); @@ -176,6 +193,8 @@ StatusOr HloPassPipeline::RunPassesInternal( HloPassInterface* pass = passes[i]; XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name())); std::string pass_name = std::string(pass->name()); + tsl::profiler::ScopedAnnotation annotation{ + [&] { return "XlaPass:" + pass_name; }}; VLOG(1) << " HLO pass " << pass_name; VLOG(2) << " Module hash " << absl::HashOf(*hlo); if (!pass->IsPassPipeline()) { diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index 48b98e197190d5..caff663a1e19e1 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -62,6 +62,7 @@ limitations under the License. 
#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla { namespace { @@ -750,6 +751,10 @@ StatusOr> Service::BuildExecutable( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name()); + tsl::profiler::ScopedAnnotation annotation{[&] { + return absl::StrCat("XlaCompile:#module=", module_proto.name(), "#"); + }}; + TF_ASSIGN_OR_RETURN( std::unique_ptr module, CreateModuleFromProto(module_proto, *module_config, run_backend_only)); @@ -770,6 +775,9 @@ StatusOr> Service::BuildExecutable( std::move(module), executor, options)); } + tsl::profiler::ScopedAnnotation backend_annotation{[&] { + return absl::StrCat("XlaCompileBackend:#module=", module_proto.name(), "#"); + }}; TF_ASSIGN_OR_RETURN( std::unique_ptr executable, backend->compiler()->RunBackend(std::move(module), executor, options)); From cb792e6ec0f05e3eeee751a8d049b94f83534e8e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 30 Nov 2023 23:52:38 -0800 Subject: [PATCH 271/381] [stream_executor] Put CUDA C++ kernels under if_cuda guard PiperOrigin-RevId: 586912338 --- .../xla/xla/stream_executor/cuda/BUILD | 12 +++-- .../cuda/cuda_conditional_kernels.cc | 45 +++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index dc91bfeb6eef1a..d292bc57ace60b 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1,6 +1,6 @@ load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") load( "//xla:xla.bzl", "xla_cc_test", @@ -459,13 +459,19 @@ cuda_library( cuda_library( name = "cuda_conditional_kernels", - srcs = if_cuda_is_configured(["cuda_conditional_kernels.cu.cc"]), + srcs = if_cuda( + ["cuda_conditional_kernels.cu.cc"], + ["cuda_conditional_kernels.cc"], + ), local_defines = select({ ":graph_conditional_enabled": ["STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1"], "//conditions:default": [], }), visibility = ["//visibility:public"], - deps = ["@local_config_cuda//cuda:cuda_headers"], + deps = [ + "@com_google_absl//absl/log", + "@local_config_cuda//cuda:cuda_headers", + ], ) xla_test( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc new file mode 100644 index 00000000000000..4bf38d89c5ba2f --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cc @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/log/log.h" + +namespace stream_executor::gpu { + +void* GetSetIfConditionKernel() { + LOG(ERROR) << "XLA compiled without --config=cuda"; + return nullptr; +} + +void* GetSetIfElseConditionKernel() { + LOG(ERROR) << "XLA compiled without --config=cuda"; + return nullptr; +} + +void* GetSetCaseConditionKernel() { + LOG(ERROR) << "XLA compiled without --config=cuda"; + return nullptr; +} + +void* GetSetForConditionKernel() { + LOG(ERROR) << "XLA compiled without --config=cuda"; + return nullptr; +} + +void* GetSetWhileConditionKernel() { + LOG(ERROR) << "XLA compiled without --config=cuda"; + return nullptr; +} + +} // namespace stream_executor::gpu From f39fcefbd23200133529563e4a5115ab6df3eacd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 01:02:05 -0800 Subject: [PATCH 272/381] Update GraphDef version to 1697. PiperOrigin-RevId: 586926506 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 79142b54bc27ec..bc01dd0998da44 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1696 // Updated: 2023/11/30 +#define TF_GRAPH_DEF_VERSION 1697 // Updated: 2023/12/1 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From db58265d6d7f5df6ab1bc7851dad8cf78e0b2d65 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 01:02:05 -0800 Subject: [PATCH 273/381] compat: Update forward compatibility horizon to 2023-12-01 PiperOrigin-RevId: 586926512 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 138d020c136a53..1ee91ed50267a4 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 30) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 1) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 54af26eed209700cb429e202bdd95ee06e8f8310 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Fri, 1 Dec 2023 01:39:37 -0800 Subject: [PATCH 274/381] Refactor fused quantization patterns into quantize pass. Current Quantize{DotGeneral, Convolution}OpPattern matching takes place after post_quantize, which may be misleading. Instead, factor out patterns in a separate file and handle in quantize.cc. MLIR tests for fused patterns are left in quantize_composite_functions.mlir for module-level evaluation of functions. 
PiperOrigin-RevId: 586934993 --- .../mlir/quantization/stablehlo/BUILD | 17 +- .../stablehlo/passes/quantization_patterns.cc | 432 ++++++++++++++++++ ...tion_pattern.h => quantization_patterns.h} | 21 +- .../quantization/stablehlo/passes/quantize.cc | 5 +- .../passes/quantize_composite_functions.cc | 365 --------------- .../stablehlo/tests/quantize.mlir | 27 ++ .../tests/quantize_composite_functions.mlir | 92 ++-- 7 files changed, 528 insertions(+), 431 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc rename tensorflow/compiler/mlir/quantization/stablehlo/passes/{quantization_pattern.h => quantization_patterns.h} (97%) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 555a53b5011814..b8558350855c19 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -51,7 +51,7 @@ cc_library( ":lift_quantizable_spots_as_functions_fusion_inc_gen", ":lift_quantizable_spots_as_functions_simple_inc_gen", ":quantization_options_proto_cc", - ":quantization_pattern", + ":quantization_patterns", ":stablehlo_passes_inc_gen", ":stablehlo_type_utils", ":uniform_quantized_types", @@ -106,27 +106,38 @@ cc_library( ) cc_library( - name = "quantization_pattern", + name = "quantization_patterns", + srcs = ["passes/quantization_patterns.cc"], hdrs = [ - "passes/quantization_pattern.h", + "passes/quantization_patterns.h", ], compatible_with = get_compatible_with_portable(), deps = [ ":bridge_passes", + ":uniform_quantized_types", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc new file mode 100644 index 00000000000000..603e1ff771a204 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -0,0 +1,432 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +#define DEBUG_TYPE "populate-quantization-patterns" + +namespace mlir::quant::stablehlo { + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConvolutionOp; +using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::DynamicBroadcastInDimOp; +using ::mlir::stablehlo::UniformQuantizeOp; + +constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kEntryFuncAttrName = "_entry_function"; + +// Returns true if `type` is a TensorType with quantized elements. +bool IsQuantizedTensorType(const Type type) { + return type.isa() && + type.cast().getElementType().isa(); +} + +// Returns true if an op has adjacent bias or activation that can be fused +// together into the quantization function. +// TODO: b/307620428 - Consider using matchAndRewrite to check and apply +// patterns at the same time. Also add check for fusible activation or +// fusible patterns with dynamic shape. +bool HasFusibleQuantizationPattern(Operation& op) { + if (isa(op.getNextNode())) { + return true; + } + return false; +} + +// Returns dynamically broadcasted user op of an input op. Returns null if +// the op is used multiple times or the user op is not dynamically broadcasted. 
+// Dynamic shapes usually has the following pattern. In the example below, +// the input operand would be stablehlo.gemm_style op, and return value would +// be stablehlo.add op. +// +// ``` +// %2 = stablehlo.gemm_style(%0, %1) +// %3 = shape.shape_of %2 +// %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 +// %5 = stablehlo.add %2, %4 +// ``` +Operation* GetDynamicallyBroadcastedUserOp(Operation& op) { + if (!op.hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "Target op is used multiple times and will not be checked " + "for dynamic shape case.\n"); + return nullptr; + } + Operation& shapeof_op = *op.getNextNode(); + if (!isa(shapeof_op)) { + return nullptr; + } + Operation& broadcast_in_dims_op = *shapeof_op.getNextNode(); + if (!isa(broadcast_in_dims_op)) { + return nullptr; + } + return broadcast_in_dims_op.getNextNode(); +} + +// Checks if all inputs and outputs are quantized. +bool HasQuantizedOperandOrOutput(Operation& call_op) { + SmallVector arg_types; + for (const Value arg : call_op.getOperands()) { + arg_types.push_back(arg.getType()); + } + + SmallVector output_types; + for (const Value output : call_op.getResults()) { + output_types.push_back(output.getType()); + } + + return absl::c_all_of(arg_types, IsQuantizedTensorType) && + absl::c_all_of(output_types, IsQuantizedTensorType); +} + +// Gets the corresponding quantized function name from the given function name. +// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" +std::string GetQuantizedFunctionName(const StringRef func_name) { + return Twine(kQuantizedFuncPrefix) + .concat(func_name.rsplit(kCompositeFuncPrefix).second) + .str(); +} + +// Returns true if `xla_call_module_op` is quantized. To be considered +// quantized, it should meet three conditions: +// 1. At least one of the inputs or outputs should be a uniform quantized type. +// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. +// 3. It should also have the `kEntryFuncAttrName` attribute, which points to +// the function that `xla_call_module_op` represents. +bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { + return HasQuantizedOperandOrOutput(*xla_call_module_op) && + xla_call_module_op->hasAttr(kQuantTraitAttrName) && + xla_call_module_op->hasAttr(kEntryFuncAttrName); +} + +// Returns the entry function, i.e. the callee of `xla_call_module_op`. +func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, + SymbolTable symbol_table) { + const auto entry_function_symbol_ref = + xla_call_module_op->getAttrOfType(kEntryFuncAttrName); + + return dyn_cast_or_null( + symbol_table.lookup(entry_function_symbol_ref.getValue())); +} + +// Replaces the function type of `entry_func_op` to a quantized one, matching +// the input and output types of `xla_call_module_op`. +void SetQuantizedFunctionType(PatternRewriter& rewriter, + func::FuncOp& entry_func_op, + TF::XlaCallModuleOp xla_call_module_op) { + SmallVector arg_types; + SmallVector arg_locs; + for (const Value arg : xla_call_module_op.getArgs()) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + + SmallVector output_types; + for (const Value output : xla_call_module_op.getOutput()) { + output_types.push_back(output.getType()); + } + + entry_func_op.setFunctionType( + rewriter.getFunctionType(arg_types, output_types)); + + // Replace argument types and locs. 
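+  // The entry block arguments must be kept in sync with the new function
+  // type; otherwise the now-quantized entry function would fail verification.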
+ Block& entry = entry_func_op->getRegion(0).front(); + for (auto [arg, arg_type, arg_loc] : + llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { + arg.setType(arg_type); + arg.setLoc(arg_loc); + } +} + +// Creates a UniformQuantize op and sets it as return op. +void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, + func::FuncOp entry_func_op, + const Type func_result_type) { + // Add i32 -> i8 requantization. + UniformQuantizeOp uniform_quant_op = rewriter.create( + op.getLoc(), func_result_type, op.getResults()); + cast(entry_func_op.getBody().front().getTerminator()) + .setOperand(0, uniform_quant_op); +} + +// An interface representing patterns that quantizes an entry function's body. +// The entry function's signatures should have already been quantized at the +// point of rewriting. +class EntryFuncBodyQuantizationPattern { + public: + virtual ~EntryFuncBodyQuantizationPattern() = default; + + // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At + // this point `entry_func_op`'s signature has not been reset with quantized + // types. + virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; + + // Rewrites the `entry_func_op`'s body. + virtual void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const = 0; +}; + +// Gemm Style Op: glossary/gemm. +template +// Match for all gemm_style op and check for possible fusions. +LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { + // function must have input, filter, and optionally bias. + auto& operations = entry_func_op.getBody().front().getOperations(); + if (operations.size() != 2 && operations.size() != 3) { + return failure(); + } + if (!isa(operations.front())) { + return failure(); + } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { + LLVM_DEBUG(llvm::dbgs() + << "Currently gemm style ops quantization only supports static " + " shapes.\n"); + return failure(); + } else if (!isa( + operations.front().getResult(0).getType())) { + return failure(); + } + return success(); +} + +// Gemm Style Op: glossary/gemm. +template +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { + // Update the output type of the gemm_style op. + GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); + + const Type input_type = entry_func_op.getArgumentTypes()[0]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + const double input_scale = + getElementTypeOrSelf(input_type).cast().getScale(); + const double filter_scale = + getElementTypeOrSelf(filter_type).cast().getScale(); + const double result_scale = input_scale * filter_scale; + + // Define the intermediate output type, which is an i32 quantized type. + // This is intermediate because the final output type of the entry_func_op + // should be an i8 quantized type. 
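+  // With per-tensor scales, the integer accumulator of the GEMM carries the
+  // product of the input and filter scales (result_scale above) and a zero
+  // point of 0; the uniform_quantize appended at the end of the function
+  // rescales that accumulator to the i8 result type.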
+ const UniformQuantizedType gemm_style_quantized_element_type = + CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), + *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + Value gemm_style_op_result = gemm_style_op->getResult(0); + auto gemm_style_op_result_type = + gemm_style_op_result.getType().cast(); + const ArrayRef gemm_style_shape = + gemm_style_op_result_type.getShape(); + + const TensorType new_gemm_style_op_result_type = + gemm_style_op_result_type.cloneWith(gemm_style_shape, + gemm_style_quantized_element_type); + gemm_style_op_result.setType(new_gemm_style_op_result_type); + + rewriter.setInsertionPointAfter(gemm_style_op); + + Operation& next_op = *gemm_style_op->getNextNode(); + // If an op is used multiple times, do not apply quantization of fused + // patterns to prevent removal of dependee ops. + const bool should_quantize_without_fusion = + HasFusibleQuantizationPattern(*gemm_style_op.getOperation()) && + !gemm_style_op->hasOneUse(); + + // TODO: b/307620428 - Add support for dynamic shapes. + if (should_quantize_without_fusion || !isa(next_op)) { + // no bias + CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, + func_result_type); + return; + } + // bias fusion + Value bias_op = next_op.getOperand(1); + Value add_op_result = next_op.getResult(0); + const auto add_op_result_type = + add_op_result.getType().cast(); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. + const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, gemm_style_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = + rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); +} + +// Quantizes the entry function's body containing a `DotGeneralOp`. +class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeDotGeneralOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } +}; + +// Quantizes the entry function's body containing a `ConvolutionOp`. +class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeConvolutionOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } +}; + +// Converts `entry_func_op` to be quantized according to the respective +// inputs and outputs of `xla_call_module_op` that are possibly quantized. It +// signature (type) is reset to match that of `xla_call_module_op`. +// `entry_func_body_quantization_pattern` rewrites the function's body, based on +// the new signature. 
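+// For example, a callee named "composite_dot_general_fn_1" has its
+// "composite_" prefix replaced with "quantized_", and its body is rewritten by
+// `body_rewrite_pattern` so that the dot_general accumulates into an i32
+// quantized type before being requantized to the i8 result type.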
+void QuantizeEntryFuncOp( + MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); + + body_rewrite_pattern.rewrite(entry_func_op, rewriter); + + // Rename the function to be clear that the function has been quantized. + const std::string quantized_function_name = + GetQuantizedFunctionName(entry_func_op.getSymName()); + entry_func_op.setSymName(quantized_function_name); +} + +// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee +// is expected to remain unquantized (thus having a signature mismatch), and it +// is also quantized accordingly. +void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + ModuleOp module_op = xla_call_module_op->getParentOfType(); + SymbolTable symbol_table(module_op); + + func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); + QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, + body_rewrite_pattern); + + // Replace the XlaCallModuleOp with a new CallOp. + rewriter.setInsertionPoint(xla_call_module_op); + rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, + xla_call_module_op.getArgs()); +} + +// Pattern that mainly does two things: +// +// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. +// 2. Quantizes the callee function. +// +// The inputs of this pattern assumes an invalid IR, where even if a +// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) +// not only replaces the input and output tensor types into quantized ones, but +// also rewrites the body with a quantized equivalent. +// +// `FuncBodyRewritePatternT` defines how a function body is quantized and +// rewritten. +template >> +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) + : OpRewritePattern(&ctx) {} + + LogicalResult match(TF::XlaCallModuleOp op) const override { + ModuleOp module_op = op->getParentOfType(); + SymbolTable symbol_table(module_op); + + // Ignore unquantized ops. + if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + + func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); + if (!entry_func_op) { + op->emitError("Failed to find a valid entry function."); + return failure(); + } + + return FuncBodyRewritePatternT().match(entry_func_op); + } + + void rewrite(TF::XlaCallModuleOp xla_call_module_op, + PatternRewriter& rewriter) const override { + ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + *rewriter.getContext(), rewriter, xla_call_module_op, + FuncBodyRewritePatternT()); + } +}; + +} // namespace + +// TODO: b/307620428 - Increase fused op coverage for static range quantization. 
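+// Typical usage, mirroring QuantizePass::runOnOperation in quantize.cc (a
+// sketch only; the rest of the pattern-set setup is abbreviated):
+//   RewritePatternSet patterns(&ctx);
+//   PopulateFusedGemmStylePatterns(ctx, patterns);
+//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));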
+void PopulateFusedGemmStylePatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + patterns.add, + XlaCallModuleOpToCallOp>(ctx); +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h similarity index 97% rename from tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 3922374e402353..79daa9ce8b48b8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ #include #include @@ -90,6 +90,7 @@ class StableHloQuantizationPattern : public RewritePattern { : RewritePattern(RootOpT::getOperationName(), 300, context), quant_params_(quant_params) {} + private: LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override { llvm::SmallVector quantizing_ops; @@ -157,8 +158,8 @@ class StableHloQuantizationPattern : public RewritePattern { // Blocklist op is checked in advance for non-dynamic range quantization // case. if (!quant_params_.quant_spec.weight_quantization && - (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != - ops_blocklist.end())) { + (ops_blocklist.contains( + quantizing_op->getName().getStringRef().str()))) { return failure(); } @@ -261,9 +262,6 @@ class StableHloQuantizationPattern : public RewritePattern { return success(); } - private: - QuantPassSpec quant_params_; - // Checks whether the operation is connnected with a quantized composite // function. If not, the same-scale op will not be quantized. This decision is // based on the current assumption that the performance gain of the same-scale @@ -367,8 +365,15 @@ class StableHloQuantizationPattern : public RewritePattern { } return has_quantized_types; } + + QuantPassSpec quant_params_; }; +// Gemm Style Op: glossary/gemm. +// Populates conversion patterns to unfuse batch normalization operations. +void PopulateFusedGemmStylePatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + } // namespace mlir::quant::stablehlo -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index fffa0ad782c8d9..cb587416a45bff 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h" namespace mlir::quant::stablehlo { @@ -130,6 +130,9 @@ void QuantizePass::runOnOperation() { patterns.add( &ctx, quant_params); + // Support quantization for fused patterns containing Gemm Style ops. + PopulateFusedGemmStylePatterns(ctx, patterns); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // causing this to result in a convergence failure. Consider this as a diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index dd558a08bc642c..aef2f7bb9f3cd9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -86,362 +86,6 @@ class QuantizeCompositeFunctionsPass void runOnOperation() override; }; -// Returns true if `type` is a TensorType with quantized elements. -bool IsQuantizedTensorType(const Type type) { - return type.isa() && - type.cast().getElementType().isa(); -} - -// Returns true if an op has adjacent bias or activation that can be fused -// together into the quantization function. -// TODO: b/307620428 - Consider using matchAndRewrite to check and apply -// patterns at the same time. Also add check for fusible activation or -// fusible patterns with dynamic shape. -bool HasFusibleQuantizationPattern(Operation& op) { - if (isa(op.getNextNode())) { - return true; - } - return false; -} - -// Returns dynamically broadcasted user op of an input op. Returns null if -// the op is used multiple times or the user op is not dynamically broadcasted. -// Dynamic shapes usually has the following pattern. In the example below, -// the input operand would be stablehlo.gemm_style op, and return value would -// be stablehlo.add op. -// -// ``` -// %2 = stablehlo.gemm_style(%0, %1) -// %3 = shape.shape_of %2 -// %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 -// %5 = stablehlo.add %2, %4 -// ``` -Operation* GetDynamicallyBroadcastedUserOp(Operation& op) { - if (!op.hasOneUse()) { - LLVM_DEBUG(llvm::dbgs() - << "Target op is used multiple times and will not be checked " - "for dynamic shape case.\n"); - return nullptr; - } - Operation& shapeof_op = *op.getNextNode(); - if (!isa(shapeof_op)) { - return nullptr; - } - Operation& broadcast_in_dims_op = *shapeof_op.getNextNode(); - if (!isa(broadcast_in_dims_op)) { - return nullptr; - } - return broadcast_in_dims_op.getNextNode(); -} - -// Checks if all inputs and outputs are quantized. -bool HasQuantizedOperandOrOutput(Operation& call_op) { - SmallVector arg_types; - for (const Value arg : call_op.getOperands()) { - arg_types.push_back(arg.getType()); - } - - SmallVector output_types; - for (const Value output : call_op.getResults()) { - output_types.push_back(output.getType()); - } - - return absl::c_all_of(arg_types, IsQuantizedTensorType) && - absl::c_all_of(output_types, IsQuantizedTensorType); -} - -// Get the corresponding quantized function name from the given function name. 
-// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" -std::string GetQuantizedFunctionName(const StringRef func_name) { - return Twine(kQuantizedFuncPrefix) - .concat(func_name.rsplit(kCompositeFuncPrefix).second) - .str(); -} - -// Returns true if `xla_call_module_op` is quantized. To be considered -// quantized, it should meet three conditions: -// 1. At least one of the inputs or outputs should be a uniform quantized type. -// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. -// 3. It should also have the `kEntryFuncAttrName` attribute, which points to -// the function that `xla_call_module_op` represents. -bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { - return HasQuantizedOperandOrOutput(*xla_call_module_op) && - xla_call_module_op->hasAttr(kQuantTraitAttrName) && - xla_call_module_op->hasAttr(kEntryFuncAttrName); -} - -// Returns the entry function, i.e. the callee of `xla_call_module_op`. -func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, - SymbolTable symbol_table) { - const auto entry_function_symbol_ref = - xla_call_module_op->getAttrOfType(kEntryFuncAttrName); - - return dyn_cast_or_null( - symbol_table.lookup(entry_function_symbol_ref.getValue())); -} - -// Replaces the function type of `entry_func_op` to a quantized one, matching -// the input and output types of `xla_call_module_op`. -void SetQuantizedFunctionType(PatternRewriter& rewriter, - func::FuncOp entry_func_op, - TF::XlaCallModuleOp xla_call_module_op) { - SmallVector arg_types; - SmallVector arg_locs; - for (const Value arg : xla_call_module_op.getArgs()) { - arg_types.push_back(arg.getType()); - arg_locs.push_back(arg.getLoc()); - } - - SmallVector output_types; - for (const Value output : xla_call_module_op.getOutput()) { - output_types.push_back(output.getType()); - } - - entry_func_op.setFunctionType( - rewriter.getFunctionType(arg_types, output_types)); - - // Replace argument types and locs. - Block& entry = entry_func_op->getRegion(0).front(); - for (auto [arg, arg_type, arg_loc] : - llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { - arg.setType(arg_type); - arg.setLoc(arg_loc); - } -} - -// Creates a UniformQuantize op and sets it as return op. -void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, - func::FuncOp entry_func_op, - const Type func_result_type) { - // Add i32 -> i8 requantization. - UniformQuantizeOp uniform_quant_op = rewriter.create( - op.getLoc(), func_result_type, op.getResults()); - cast(entry_func_op.getBody().front().getTerminator()) - .setOperand(0, uniform_quant_op); -} - -// An interface representing patterns that quantizes an entry function's body. -// The entry function's signatures should have already been quantized at the -// point of rewriting. -class EntryFuncBodyQuantizationPattern { - public: - virtual ~EntryFuncBodyQuantizationPattern() = default; - - // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At - // this point `entry_func_op`'s signature has not been reset with quantized - // types. - virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; - - // Rewrites the `entry_func_op`'s body. - virtual void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const = 0; -}; - -// Gemm Style Op: glossary/gemm. -template -// Match for all gemm_style op and check for possible fusions. -LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { - // function must have input, filter, and optionally bias. 
- auto& operations = entry_func_op.getBody().front().getOperations(); - if (operations.size() != 2 && operations.size() != 3) { - return failure(); - } - if (!isa(operations.front())) { - return failure(); - } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { - LLVM_DEBUG(llvm::dbgs() - << "Currently gemm style ops quantization only supports static " - " shapes.\n"); - return failure(); - } else if (!isa( - operations.front().getResult(0).getType())) { - return failure(); - } - return success(); -} - -// Gemm Style Op: glossary/gemm. -template -void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { - // Update the output type of the gemm_style op. - GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); - - const Type input_type = entry_func_op.getArgumentTypes()[0]; - const Type filter_type = entry_func_op.getArgumentTypes()[1]; - const Type func_result_type = entry_func_op.getResultTypes()[0]; - - const double input_scale = - getElementTypeOrSelf(input_type).cast().getScale(); - const double filter_scale = - getElementTypeOrSelf(filter_type).cast().getScale(); - const double result_scale = input_scale * filter_scale; - - // Define the intermediate output type, which is an i32 quantized type. - // This is intermediate because the final output type of the entry_func_op - // should be an i8 quantized type. - const UniformQuantizedType gemm_style_quantized_element_type = - CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), - *rewriter.getContext(), result_scale, - /*zero_point=*/0); - - Value gemm_style_op_result = gemm_style_op->getResult(0); - auto gemm_style_op_result_type = - gemm_style_op_result.getType().cast(); - const ArrayRef gemm_style_shape = - gemm_style_op_result_type.getShape(); - - const TensorType new_gemm_style_op_result_type = - gemm_style_op_result_type.cloneWith(gemm_style_shape, - gemm_style_quantized_element_type); - gemm_style_op_result.setType(new_gemm_style_op_result_type); - - rewriter.setInsertionPointAfter(gemm_style_op); - - Operation& next_op = *gemm_style_op->getNextNode(); - // If an op is used multiple times, do not apply quantization of fused - // patterns to prevent removal of dependee ops. - const bool should_quantize_without_fusion = - HasFusibleQuantizationPattern(*gemm_style_op.getOperation()) && - !gemm_style_op->hasOneUse(); - - // TODO: b/307620428 - Add support for dynamic shapes. - if (should_quantize_without_fusion || !isa(next_op)) { - // no bias - CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, - func_result_type); - return; - } - // bias fusion - Value bias_op = next_op.getOperand(1); - Value add_op_result = next_op.getResult(0); - const auto add_op_result_type = - add_op_result.getType().cast(); - const ArrayRef add_op_shape = add_op_result_type.getShape(); - // For quantized bias add case, lhs, rhs, and result have the same types. - const TensorType new_add_op_result_type = add_op_result_type.cloneWith( - add_op_shape, gemm_style_quantized_element_type); - add_op_result.setType(new_add_op_result_type); - - AddOp bias_add_op = - rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); - - CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, - func_result_type); -} - -// Quantizes the entry function's body containing a `DotGeneralOp`. 
-class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { - public: - explicit QuantizeDotGeneralOpPattern() = default; - - LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); - } - - void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); - } -}; - -// Quantizes the entry function's body containing a `ConvolutionOp`. -class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { - public: - explicit QuantizeConvolutionOpPattern() = default; - - LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); - } - - void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); - } -}; - -// Converts `entry_func_op` to be quantized according to the respective -// inputs and outputs of `xla_call_module_op` that are possibly quantized. It -// signature (type) is reset to match that of `xla_call_module_op`. -// `entry_func_body_quantization_pattern` rewrites the function's body, based on -// the new signature. -void QuantizeEntryFuncOp( - MLIRContext& ctx, PatternRewriter& rewriter, - TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); - - body_rewrite_pattern.rewrite(entry_func_op, rewriter); - - // Rename the function to be clear that the function has been quantized. - const std::string quantized_function_name = - GetQuantizedFunctionName(entry_func_op.getSymName()); - entry_func_op.setSymName(quantized_function_name); -} - -// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee -// is expected to remain unquantized (thus having a signature mismatch), and it -// is also quantized accordingly. -void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - MLIRContext& ctx, PatternRewriter& rewriter, - TF::XlaCallModuleOp xla_call_module_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - ModuleOp module_op = xla_call_module_op->getParentOfType(); - SymbolTable symbol_table(module_op); - - func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); - QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, - body_rewrite_pattern); - - // Replace the XlaCallModuleOp with a new CallOp. - rewriter.setInsertionPoint(xla_call_module_op); - rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, - xla_call_module_op.getArgs()); -} - -// Pattern that mainly does two things: -// -// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. -// 2. Quantizes the callee function. -// -// The inputs of this pattern assumes an invalid IR, where even if a -// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) -// not only replaces the input and output tensor types into quantized ones, but -// also rewrites the body with a quantized equivalent. -// -// `FuncBodyRewritePatternT` defines how a function body is quantized and -// rewritten. 
-template >> -class XlaCallModuleOpToCallOp : public OpRewritePattern { - public: - explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) - : OpRewritePattern(&ctx) {} - - LogicalResult match(TF::XlaCallModuleOp op) const override { - ModuleOp module_op = op->getParentOfType(); - SymbolTable symbol_table(module_op); - - // Ignore unquantized ops. - if (!IsQuantizedXlaCallModuleOp(op)) return failure(); - - func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); - if (!entry_func_op) { - op->emitError("Failed to find a valid entry function."); - return failure(); - } - - return FuncBodyRewritePatternT().match(entry_func_op); - } - - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT()); - } -}; - void QuantizeCompositeFunctionsPass::runOnOperation() { MLIRContext& ctx = getContext(); @@ -464,15 +108,6 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { !pm_run_status.ok()) { signalPassFailure(); } - - // TODO - b/307839649: Move this as a separate pass. - RewritePatternSet patterns(&ctx); - patterns.add, - XlaCallModuleOpToCallOp>(ctx); - - if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { - signalPassFailure(); - } } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir index d1bfea7a236448..e794dded354da9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir @@ -1,5 +1,8 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize -verify-each=false | FileCheck %s +// Tests for PopulateFusedGemmStylePatterns are handled in +// quantize_composite_functions for module-level evaluation of functions. + // CHECK-LABEL: quantize_simple_xla_call_module func.func private @quantize_simple_xla_call_module(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { %0 = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> @@ -40,3 +43,27 @@ func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf3 // CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"() <{{{.*}}}> {{{.*}}} : () -> tensor<1x3x!quant.uniform> // CHECK: %[[DCAST_0:.*]] = "quantfork.dcast"(%[[XLACALLMODULE_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> () + +// ----- + +// Tests for emitting an error when there is no corresponding entry +// function to quantize (@composite_dot_general_fn). + +module attributes {tf_saved_model.semantics} { +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
+// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} + func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + %2 = "quantfork.dcast"(%1) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> + %3 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// expected-error @+2 {{Failed to find a valid entry function}} +// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %6 = "quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %7 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir index d38083150ba8be..649b0fece15d81 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir @@ -1,6 +1,9 @@ // RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ // RUN: -stablehlo-quantize-composite-functions | FileCheck %s + +// Tests that basic dot_general is properly quantized. + module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. @@ -39,7 +42,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that fused bias pattern is properly quantized. +// Tests that fused pattern for dot_general + bias is properly quantized. module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. @@ -78,7 +81,8 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that fused bias pattern with dynamic shape is not quantized. +// Tests that fused pattern for dot_general + bias with dynamic shape is +// not quantized. // TODO: b/307620428 - Add support for fused bias with dynamic shapes. module attributes {tf_saved_model.semantics} { @@ -106,57 +110,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests error when there are no corresponding entry function to quantize -// (@composite_dot_general_fn). - -module attributes {tf_saved_model.semantics} { -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
-// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> -// expected-error @+2 {{Failed to find a valid entry function}} -// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> - } -} - -// ----- - -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. - -module attributes {tf_saved_model.semantics} { - func.func private @not_quantized_without_stats(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } -// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is -// not quantized. - -// CHECK-LABEL: func.func private @not_quantized_without_stats -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] - - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -// Check that the composite_dot_general_fn is untouched. - -// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2xf32>, %[[ARG_3:.*]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] -// CHECK: return %[[DOT_GENERAL]] -} - -// ----- - -// Test basic convolution is quantized. +// Tests that basic convolution is properly quantized. 
module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. @@ -196,7 +150,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that fused bias pattern is properly quantized. +// Tests that fused pattern for convolution + bias is properly quantized. module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. @@ -231,3 +185,33 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> // CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> } + +// ----- + +// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. + +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. + +// CHECK-LABEL: func.func private @not_quantized_without_stats +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[XLA_CALL_MODULE_0]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Check that the composite_dot_general_fn is untouched. 
+ +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2xf32>, %[[ARG_3:.*]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] +// CHECK: return %[[DOT_GENERAL]] +} From 75571c610e62b34139da28d14fc549fab0f7cb5d Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Fri, 1 Dec 2023 02:17:11 -0800 Subject: [PATCH 275/381] [xla:gpu] Add a flag to tell ir_emitter to not emit LLVM kernels #7360 PiperOrigin-RevId: 586944336 --- .../xla/service/gpu/compile_module_to_llvm_ir.cc | 2 +- third_party/xla/xla/service/gpu/fusions/BUILD | 3 +++ .../xla/xla/service/gpu/fusions/fusion_emitter.cc | 14 +++++++++++--- third_party/xla/xla/service/gpu/gpu_compiler.cc | 3 ++- .../xla/xla/service/gpu/ir_emitter_context.h | 9 +++++++-- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index 2a64ef43f41d57..f1eb2414cc0dc9 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -373,7 +373,7 @@ StatusOr CompileModuleToLlvmIr( IrEmitterContext ir_emitter_context( hlo_module, results.buffer_assignment.get(), platform_name, gpu_device_info, mlir_context.get(), results.llvm_module.get(), - emit_from_hlo); + emit_from_hlo, /*emit_kernels=*/true); std::vector allocations; if (emit_from_hlo) { diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index bbd978a153dfb0..9fac85b5c5efb2 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -35,6 +35,7 @@ cc_library( hdrs = ["fusion_emitter.h"], visibility = ["//visibility:public"], deps = [ + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/mlir_hlo:lhlo", "//xla/service:elemental_ir_emitter", @@ -48,6 +49,8 @@ cc_library( "//xla/service/gpu:thunk", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 4d505db7981869..fa2e8bf496f204 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -15,12 +15,15 @@ limitations under the License. #include "xla/service/gpu/fusions/fusion_emitter.h" #include +#include #include #include #include #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Argument.h" @@ -33,6 +36,7 @@ limitations under the License. 
#include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -203,9 +207,13 @@ StatusOr KernelFusionEmitterBase::Emit( ir_emitter_context, suggested_kernel_name, kernel_arguments.args(), fusion.operand_count(), launch_dims, builder); - TF_RETURN_IF_ERROR(EmitKernel(ir_emitter_context, elemental_emitter, - fusion, launch_dims, std::move(inputs), - std::move(outputs), builder, i)); + if (ir_emitter_context.emit_kernels()) { + TF_RETURN_IF_ERROR(EmitKernel( + ir_emitter_context, elemental_emitter, fusion, launch_dims, + std::move(inputs), std::move(outputs), builder, i)); + } else { + VLOG(3) << "Skipped kernel compilation: " << suggested_kernel_name; + } // TODO(jreiffers): Return shmem_bytes from EmitKernel when // converting the Triton emitters to this infrastructure. return KernelReuseCache::Entry{kernel->getName().str(), launch_dims, diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 8c5f4b03674b13..e6699a7c872540 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -439,7 +439,8 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, IrEmitterContext ir_emitter_context(hlo_module.get(), buffer_assignment.get(), platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(), - /*emit_ir_from_hlo=*/true); + /*emit_ir_from_hlo=*/true, + /*emit_kernels=*/false); mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module->name()); std::vector ordered_allocations; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.h b/third_party/xla/xla/service/gpu/ir_emitter_context.h index 90614c715ee828..4ec71b8b0d61f9 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.h @@ -43,14 +43,15 @@ class IrEmitterContext { std::string platform_name, const se::DeviceDescription& gpu_device_info, mlir::MLIRContext* mlir_context, llvm::Module* llvm_module, - bool emit_ir_from_hlo) + bool emit_ir_from_hlo, bool emit_kernels) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), mlir_context_(mlir_context), llvm_module_(llvm_module), - emit_ir_from_hlo_(emit_ir_from_hlo) {} + emit_ir_from_hlo_(emit_ir_from_hlo), + emit_kernels_(emit_kernels) {} // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; IrEmitterContext& operator=(const IrEmitterContext&) = delete; @@ -102,6 +103,7 @@ class IrEmitterContext { } bool emit_ir_from_hlo() const { return emit_ir_from_hlo_; } + bool emit_kernels() const { return emit_kernels_; } private: const HloModule* hlo_module_; @@ -119,6 +121,9 @@ class IrEmitterContext { NameUniquer name_uniquer_; std::vector constants_; const bool emit_ir_from_hlo_; + + // We should not emit kernels when loading thunks from a compilation result. + const bool emit_kernels_; }; } // namespace gpu From f38c8ca0c8de4ac30095059cb1ed8d5b08722541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Fri, 1 Dec 2023 03:06:35 -0800 Subject: [PATCH 276/381] [XLA:GPU] Handle further propagation from a trivial sized tensor gracefully I just ran into this problem when testing my new "separated" fusion. 
PiperOrigin-RevId: 586954613 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../gpu/triton_fusion_analysis_test.cc | 33 +++++++++++++++++++ .../service/gpu/triton_tiling_propagation.cc | 6 ++++ 3 files changed, 40 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 6b86f8ab1ba88a..f7c6f2ed56da87 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1345,6 +1345,7 @@ xla_cc_test( deps = [ ":gemm_rewriter_triton", ":triton_fusion_analysis", + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index d2f67580234f60..17ac3b47cafcd3 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gemm_rewriter_triton.h" +#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -568,6 +569,38 @@ ENTRY e { /*subfragments=*/ElementsAre(30)))); } +TEST_F(TritonDotAnalysisTest, + HandlesFurtherPropagationFromTrivialSizedTensorGracefully) { + // We could probably support this better, just checking to avoid a crash for + // now. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_gemm_r { + a = f32[3,3]{1,0} parameter(0) + constant = f32[1,1]{1,0} constant({ {0} }) + broadcast = f32[1,1]{1,0} broadcast(constant), dimensions={0,1} + reshape = f32[] reshape(broadcast) + broadcast2 = f32[3,3]{1,0} broadcast(reshape), dimensions={} + ROOT dot = f32[3,3]{1,0} dot(a, broadcast2), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + a = f32[3,3]{1,0} parameter(0) + ROOT dot = f32[3,3]{1,0} fusion(a), kind=kCustom, calls=triton_gemm_r, + backend_config={kind: "__triton_gemm"} +} +)")); + + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + + StatusOr analysis = + TritonFusionAnalysis::Execute(*dot_computation); + // It can fail but shouldn't crash. + (void)analysis; +} + using TritonSoftmaxAnalysisTest = HloTestBase; TEST_F(TritonSoftmaxAnalysisTest, DegenerateBatchDimensionIsSupported) { diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 61c9e0e05ec6cd..b0f27751fa651d 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -611,6 +611,12 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( // full dimensions and matching by total size. std::vector> src_physical; src_physical.reserve(src.shape().rank()); + if (src_fragments_order.size() < src.shape().rank()) { + // It's not supported currently to further propagate dimensions after + // reaching a trivial sized tensor. We could probably support it, but now we + // just prevent crashing here. 
+ return FusionDecision("Cannot propagate further from trivial sized tensor"); + } auto src_fragment_it = src_fragments_order.begin(); for (int64_t dim_index : src.shape().layout().minor_to_major()) { const int64_t dim_size = src.shape().dimensions(dim_index); From 9983fe43696060e66b1e31553bdd1c19c74d3131 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Dec 2023 04:07:04 -0800 Subject: [PATCH 277/381] [XLA:CPU] Add a direct implementation of ReduceScatter, instead of lowering ReduceScatter to AllReduce+DynamicSlice. Second attempt: this time we handle use_global_device_ids. PiperOrigin-RevId: 586966387 --- third_party/xla/xla/service/cpu/BUILD | 4 +- .../xla/service/cpu/collectives_interface.h | 6 + .../xla/xla/service/cpu/cpu_compiler.cc | 2 - .../xla/service/cpu/cpu_layout_assignment.cc | 7 + .../xla/xla/service/cpu/cpu_runtime.cc | 79 ++++- third_party/xla/xla/service/cpu/cpu_runtime.h | 8 + .../xla/service/cpu/in_process_collectives.cc | 280 +++++++++++++----- .../xla/service/cpu/in_process_collectives.h | 6 + third_party/xla/xla/service/cpu/ir_emitter.cc | 103 +++++-- .../xla/xla/service/cpu/simple_orc_jit.cc | 1 + 10 files changed, 389 insertions(+), 107 deletions(-) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 44dbf2f4db410d..6f782102a843b6 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -305,7 +305,6 @@ cc_library( "//xla/service:optimization_barrier_expander", "//xla/service:qr_expander", "//xla/service:reduce_decomposer", - "//xla/service:reduce_scatter_decomposer", "//xla/service:reshape_decomposer", "//xla/service:reshape_mover", "//xla/service:result_caster", @@ -902,6 +901,7 @@ cc_library( "//xla:shape_util", "//xla:statusor", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -911,6 +911,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h index 4191df1d831fa2..e2c8190c81b986 100644 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ b/third_party/xla/xla/service/cpu/collectives_interface.h @@ -67,6 +67,12 @@ class CollectivesCommunicator { virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, void* output_buffer, absl::Duration timeout) = 0; + + // Performs a reduce-scatter + virtual absl::Status ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 06a62b4f8ca69e..63eae07969e96a 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -191,7 +191,6 @@ limitations under the License. 
#include "xla/service/optimization_barrier_expander.h" #include "xla/service/qr_expander.h" #include "xla/service/reduce_decomposer.h" -#include "xla/service/reduce_scatter_decomposer.h" #include "xla/service/reshape_decomposer.h" #include "xla/service/reshape_mover.h" #include "xla/service/result_caster.h" @@ -685,7 +684,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); // Inline computations with a single call site. diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc index eb59a195d572cf..44eb78c7e3eb03 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc @@ -139,6 +139,13 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); + } else if (instruction->opcode() == HloOpcode::kReduceScatter) { + // XLA:CPU can only support reduce-scatter where the scatter dimension + // is the most major dimension in the layout. + auto ars = Cast(instruction); + TF_RETURN_IF_ERROR(SetInstructionLayout( + ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()), + ars)); } else if (instruction->opcode() == HloOpcode::kAllGather) { // XLA:CPU can only support all-gathers where the gather dimension is the // most major dimension in the layout. diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 912c4db69aacd2..cd77c7555361d8 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -29,6 +29,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -46,6 +50,7 @@ limitations under the License. 
#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -143,6 +148,8 @@ extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; +extern const char* const kReduceScatterSymbolName = + "__xla_cpu_runtime_ReduceScatter"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -315,6 +322,19 @@ CollectivesInterface* GetInProcessCollectivesImpl() { absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); } +absl::StatusOr RankInGlobalDevices( + absl::Span devices, GlobalDeviceId device) { + auto it = absl::c_find(devices, device); + if (it == devices.end()) { + return InvalidArgument( + "Device %d not present in global devices %s.", device.value(), + absl::StrJoin(devices, ", ", [](std::string* out, GlobalDeviceId id) { + absl::StrAppend(out, id.value()); + })); + } + return std::distance(devices.begin(), it); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllToAllImpl(const ExecutableRunOptions* run_options, int32_t channel_id_present, int64_t op_id, @@ -331,9 +351,7 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -361,9 +379,7 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, use_global_device_ids, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -374,6 +390,36 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void ReduceScatterImpl(const ExecutableRunOptions* run_options, + const void* replica_groups_str, + int32_t replica_groups_str_size, + int32_t channel_id_present, + int32_t use_global_device_ids, int64_t op_id, + int32_t reduction_kind, int32_t element_type, + int64_t chunk_elems, void* input_buffer, + void* output_buffer) { + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + ParseReplicaGroupsOnly(replica_groups_serialized).value(); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + use_global_device_ids, op_id); + + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + + 
CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->ReduceScatter( + rendezvous_key, static_cast(reduction_kind), + static_cast(element_type), chunk_elems, input_buffer, + output_buffer, DefaultCollectiveTimeout())); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -399,9 +445,7 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, CHECK((num_buffers > 1 && shape.IsTuple()) || (num_buffers == 1 && LayoutUtil::IsDenseArray(shape))); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -450,9 +494,7 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, {}, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); CollectivesInterface* collectives = GetInProcessCollectivesImpl(); @@ -544,6 +586,19 @@ void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, replica_groups_str, replica_groups_str_size, buffer_size, source_buffer, destination_buffer); } + +void __xla_cpu_runtime_ReduceScatter( + const xla::ExecutableRunOptions* run_options, + const void* replica_groups_str, int32_t replica_groups_str_size, + int32_t channel_id_present, int32_t use_global_device_ids, int64_t op_id, + int32_t reduction_kind, int32_t element_type, int64_t chunk_elems, + void* input_buffer, void* output_buffer) { + return xla::cpu::runtime::ReduceScatterImpl( + run_options, replica_groups_str, replica_groups_str_size, + channel_id_present, use_global_device_ids, op_id, reduction_kind, + element_type, chunk_elems, input_buffer, output_buffer); +} + void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 06a05cda713592..171590b49d27a0 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -85,6 +85,7 @@ extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; extern const char* const kAllGatherSymbolName; +extern const char* const kReduceScatterSymbolName; extern const char* const kOneDnnMatMulSymbolName; // All symbol names for XLA CPU runtime functions need to start with this @@ -202,6 +203,13 @@ extern void __xla_cpu_runtime_AllGather( const void* replica_groups_str, int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer, void* destination_buffer); +void __xla_cpu_runtime_ReduceScatter( + const xla::ExecutableRunOptions* run_options, + const void* replica_groups_str, int32_t replica_groups_str_size, + int32_t channel_id_present, int32_t use_global_device_ids, int64_t op_id, + int32_t 
reduction_kind, int32_t element_type, int64_t chunk_elems, + void* input_buffer, void* output_buffer); + // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc index 78eee1aa74f731..ed30082be82e8e 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ b/third_party/xla/xla/service/cpu/in_process_collectives.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { @@ -93,7 +94,7 @@ template constexpr bool always_false_v = false; template -void Reduce(absl::Span acc, absl::Span const> inputs) { +void ReduceHelper(absl::Span acc, absl::Span inputs) { // TODO(penporn): make sure this gets vectorized. if constexpr (reduction_kind == ReductionKind::SUM) { for (size_t j = 0; j < inputs.size(); ++j) { @@ -124,6 +125,49 @@ void Reduce(absl::Span acc, absl::Span const> inputs) { } } +template +absl::Status ReduceScatter(ReductionKind reduction_kind, + absl::Span inputs, void* output, + int64_t num_elems) { + using T = typename primitive_util::PrimitiveTypeToNative::type; + T initial_value = GetInitialValue(reduction_kind); + + absl::Span out_chunk = + absl::MakeSpan(reinterpret_cast(output), num_elems); + for (int64_t i = 0; i < num_elems; ++i) { + out_chunk[i] = initial_value; + } + + absl::Span input_chunks( + reinterpret_cast(inputs.data()), inputs.size()); + switch (reduction_kind) { + case ReductionKind::SUM: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::PRODUCT: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); +} + class CpuAllReduceRendezvous : public Rendezvous { public: @@ -146,110 +190,86 @@ class CpuAllReduceRendezvous return nullptr; } + auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); + int64_t chunk_offset = start_elem * bytes_per_elem; + int64_t chunk_bytes = chunk_elems * bytes_per_elem; + void* reduce_output = + reinterpret_cast(me.destination_data) + chunk_offset; + + std::vector inputs; + inputs.reserve(world_size); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_data) + + chunk_offset); + } + switch (me.primitive_type) { case S8: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case PRED: case U8: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + 
TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case C64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case C128: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; default: return absl::UnimplementedError("Unexpected datatype"); } - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; + // All-gather the reduced chunks. 
for (const auto& p : participants_) { if (p->local_rank != me.local_rank) { - std::memcpy( - reinterpret_cast(p->destination_data) + chunk_offset, - reinterpret_cast(me.destination_data) + chunk_offset, - chunk_bytes); + std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, + reduce_output, chunk_bytes); } } return nullptr; } - - template - absl::Status DoAllReduce(const AllReduceParticipantData& me, - int64_t start_elem, int64_t num_elems) { - using T = typename primitive_util::PrimitiveTypeToNative::type; - T initial_value = GetInitialValue(me.reduction_kind); - T* acc = reinterpret_cast(me.destination_data); - for (int64_t i = start_elem; i < start_elem + num_elems; ++i) { - acc[i] = initial_value; - } - - absl::Span out_chunk = absl::MakeSpan( - reinterpret_cast(me.destination_data) + start_elem, num_elems); - std::vector> inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(absl::Span( - reinterpret_cast(p->source_data) + start_elem, num_elems)); - } - switch (me.reduction_kind) { - case ReductionKind::SUM: - Reduce(out_chunk, inputs); - break; - case ReductionKind::PRODUCT: - Reduce(out_chunk, inputs); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - Reduce(out_chunk, inputs); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - Reduce(out_chunk, inputs); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); - } }; struct CollectivePermuteParticipantData : ParticipantData { @@ -378,6 +398,109 @@ class CpuAllGatherRendezvous } }; +struct ReduceScatterParticipantData : ParticipantData { + ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + ReductionKind reduction_kind; + PrimitiveType element_type; + const void* source_buffer; + void* destination_buffer; + size_t chunk_elems; + + std::string ToString() const override { + return absl::StrFormat( + "ReduceScatterParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_elems=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_elems); + } +}; + +class CpuReduceScatterRendezvous + : public Rendezvous { + public: + explicit CpuReduceScatterRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + absl::StatusOr RunCollectiveOp( + const ReduceScatterParticipantData& me) override { + auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); + int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; + + std::vector inputs; + inputs.reserve(participants_.size()); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_buffer) + + chunk_offset); + } + + switch (me.element_type) { + case S8: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, 
me.destination_buffer, me.chunk_elems)); + break; + case S32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case C64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case C128: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + default: + return absl::UnimplementedError("Unexpected datatype"); + } + + return nullptr; + } +}; + } // namespace struct InProcessCollectivesState { @@ -389,6 +512,8 @@ struct InProcessCollectivesState { all_to_all_rendezvous_map; RefcountingHashMap all_gather_rendezvous_map; + RefcountingHashMap + reduce_scatter_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ -488,6 +613,27 @@ absl::Status InProcessCollectivesCommunicator::AllGather( .status(); } +absl::Status InProcessCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + ReduceScatterParticipantData participant(key, rank_); + participant.element_type = element_type; + participant.reduction_kind = reduction_kind; + participant.chunk_elems = chunk_elems; + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuReduceScatterRendezvous::SubmitParticipant( + [&] { + return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h index aaedc474fa39b2..f80baf38c4ebdc 100644 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ b/third_party/xla/xla/service/cpu/in_process_collectives.h @@ -59,6 +59,12 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { const void* input_buffer, void* output_buffer, absl::Duration timeout) override; + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + private: InProcessCollectivesState* state_; int rank_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index eea576506d99cf..5fcaa5592ea6a6 100644 --- 
a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -1169,35 +1169,36 @@ Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) { return OkStatus(); } +// Data types supported by ReduceScatter and AllReduce. +static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) { + // TODO(cheshire): Fix duplication wrt. cpu_runtime + switch (datatype) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case C64: + case C128: + return true; + default: + return false; + } +} + Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { CHECK_GE(crs->operand_count(), 1); PrimitiveType datatype = crs->operand(0)->shape().element_type(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); - bool is_datatype_supported = [&] { - // TODO(cheshire): Fix duplication wrt. cpu_runtime - switch (datatype) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case S32: - case U32: - case S64: - case U64: - case F16: - case F32: - case F64: - case C64: - case C128: - return true; - default: - return false; - } - }(); - - if (!is_datatype_supported) { + if (!DataTypeIsSupportedByReduceScatter(datatype)) { return Unimplemented("AllReduce for datatype '%s' is not supported", primitive_util::LowercasePrimitiveTypeName(datatype)); } @@ -1285,7 +1286,59 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { } Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { - return Unimplemented("ReduceScatter is not implemented on CPU."); + CHECK_EQ(rs->operand_count(), 1); + PrimitiveType datatype = rs->operand(0)->shape().element_type(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs)); + + if (!DataTypeIsSupportedByReduceScatter(datatype)) { + return Unimplemented("ReduceScatter for datatype '%s' is not supported", + primitive_util::LowercasePrimitiveTypeName(datatype)); + } + + if (!MatchReductionComputation(rs->to_apply()).has_value()) { + return Unimplemented("ReduceScatter for computation '%s' is not supported", + rs->to_apply()->ToString()); + } + + std::string replica_groups = ReplicaGroupsToString(rs->replica_groups()); + int32_t replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + Shape shape = rs->operand(0)->shape(); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, + assignment_.GetUniqueSlice(rs->operand(0), {})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + assignment_.GetUniqueSlice(rs, {})); + llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape); + llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape); + + bool use_global_device_ids = + Cast(rs)->use_global_device_ids(); + + EmitCallToFunc( + runtime::kReduceScatterSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*replica_groups_str=*/replica_groups_v, + /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), + + /*channel_id_present=*/ + b_.getInt32(static_cast(rs->channel_id().has_value())), + /*use_global_device_ids=*/ + b_.getInt32(static_cast(use_global_device_ids)), + /*op_id=*/ + b_.getInt64(rs->channel_id().has_value() ? 
*rs->channel_id() + : rs->GetModule()->unique_id()), + /*reduction_kind=*/ + b_.getInt32( + static_cast(*MatchReductionComputation(rs->to_apply()))), + /*element_type=*/ + b_.getInt32(static_cast(datatype)), + /*shape=*/b_.getInt64(ShapeUtil::ElementsIn(rs->shape())), + /*input_buffer=*/input_buffer, + /*output_buffer=*/output_buffer}, + b_.getVoidTy()); + + return OkStatus(); } Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 2e27a7c810869d..0cc07a27246f77 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -486,6 +486,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); REGISTER_CPU_RUNTIME_SYMBOL(AllGather); + REGISTER_CPU_RUNTIME_SYMBOL(ReduceScatter); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); From 564bedf6ce927ef16dff72b6c0b6f17dfe08afe8 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 1 Dec 2023 05:17:22 -0800 Subject: [PATCH 278/381] Fix layering violation in ir_emitter_unnested. The different fusion types are an implementation detail of .../fusion, and the higher level shouldn't depend on that. Also, reimplementing the logic of fusions.cc in ir_emitter_unnested is rather brittle (i.e., it's easy to change just one copy of the logic). PiperOrigin-RevId: 586979735 --- third_party/xla/xla/service/gpu/BUILD | 4 - third_party/xla/xla/service/gpu/fusions/BUILD | 3 +- .../xla/xla/service/gpu/fusions/fusions.cc | 73 +++++++++----- .../xla/xla/service/gpu/fusions/fusions.h | 5 +- .../xla/xla/service/gpu/ir_emission_utils.h | 2 - .../xla/service/gpu/ir_emitter_unnested.cc | 95 ++++--------------- .../xla/xla/service/gpu/ir_emitter_unnested.h | 9 +- 7 files changed, 77 insertions(+), 114 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index f7c6f2ed56da87..eddf9a3a0d6c15 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -323,12 +323,8 @@ cc_library( "//xla/service:name_uniquer", "//xla/service/gpu/fusions", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/fusions:input_slices", - "//xla/service/gpu/fusions:loop", - "//xla/service/gpu/fusions:reduction", "//xla/service/gpu/fusions:thunk_util", "//xla/service/gpu/fusions:tiling_util", - "//xla/service/gpu/fusions:transpose", "//xla/service/gpu/kernels:custom_fusion", "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/runtime3:command_buffer_cmd", diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 9fac85b5c5efb2..b0fdb57e7cd049 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -77,10 +77,9 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/mlir_hlo:lhlo", "//xla/service:buffer_assignment", - "//xla/service:elemental_ir_emitter", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter_context", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", ], diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index a24a5d55c28254..4de451b27ea142 100644 --- 
a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/types/span.h" #include "mlir/IR/Value.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/copy.h" @@ -35,21 +37,35 @@ limitations under the License. namespace xla { namespace gpu { +namespace { -bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) { - bool seen_instruction = false; - for (mlir::Operation& instr : fusion.getRegion().front()) { - if (mlir::isa(&instr)) { - continue; - } - if (seen_instruction) return false; - seen_instruction = true; +bool IsParameterOrGteOfParameter(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kParameter) { + return true; + } + if (instr->opcode() == HloOpcode::kGetTupleElement) { + return IsParameterOrGteOfParameter(instr->operand(0)); } - return seen_instruction; + return false; +} + +bool IsSingleInstructionFusion(const HloFusionAnalysis& analysis) { + return analysis.fusion_roots().size() == 1 && + absl::c_all_of(analysis.fusion_roots()[0]->operands(), + IsParameterOrGteOfParameter); +} + +bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) { + return absl::c_all_of( + analysis.fusion_roots(), [](const HloInstruction* root) { + return root->opcode() == HloOpcode::kDynamicUpdateSlice || + (root->opcode() == HloOpcode::kBitcast && + root->operand(0)->opcode() == HloOpcode::kDynamicUpdateSlice); + }); } +} // namespace + std::optional> GetFusionEmitter( HloFusionAnalysis& analysis, absl::Span allocations, @@ -58,23 +74,28 @@ std::optional> GetFusionEmitter( case HloFusionAnalysis::EmitterFusionKind::kInputSlices: return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { - if (!allocations.empty() && fusion_op != nullptr) { - bool is_single = IsSingleInstructionFusion(fusion_op); - if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - fusion_op, allocations)) { + bool is_single = IsSingleInstructionFusion(analysis); + if (!is_single && IsDynamicUpdateSliceFusion(analysis)) { + if (allocations.empty() || fusion_op == nullptr) { + return std::nullopt; + } + if (CanEmitFusedDynamicUpdateSliceInPlaceForGpu(fusion_op, + allocations)) { return std::make_unique(analysis); } - if (is_single && analysis.fusion_roots().size() == 1 && - analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) { - mlir::Value operand = GetHloOperands(fusion_op).front(); - mlir::Value output = GetHloOutputs(fusion_op).front(); - Shape operand_shape = GetShape(operand); - Shape output_shape = GetShape(output); - if (LayoutUtil::Equal(operand_shape.layout(), - output_shape.layout()) && - GetAllocationSlice(operand, allocations).ok()) { - return std::make_unique(operand, output); - } + } + if (is_single && analysis.fusion_roots().size() == 1 && + analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) { + if (!fusion_op) { + return std::nullopt; + } + mlir::Value operand = GetHloOperands(fusion_op).front(); + mlir::Value output = GetHloOutputs(fusion_op).front(); + Shape operand_shape = GetShape(operand); + Shape output_shape = GetShape(output); + if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && + GetAllocationSlice(operand, allocations).ok()) { + return std::make_unique(operand, output); } } return 
std::make_unique(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.h b/third_party/xla/xla/service/gpu/fusions/fusions.h index 82fc63400b411c..d1ead648edc925 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.h +++ b/third_party/xla/xla/service/gpu/fusions/fusions.h @@ -29,8 +29,9 @@ namespace gpu { // Returns the emitter for the given fusion. Returns nullopt if the fusion // type is not yet supported. -// `allocations` may be empty and `fusion_op` may be nullptr if buffer -// assignment didn't run yet. +// `allocations` may be empty and `fusion_op` may be nullptr if no LMHLO ops are +// available. In this case, this function will return nullopt if it cannot +// detect whether a loop fusion can be optimized. std::optional> GetFusionEmitter( HloFusionAnalysis& analysis, absl::Span allocations, diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index f3349762868c9c..ade59584899c43 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -127,8 +127,6 @@ StatusOr GetAllocationSlice( mlir::Value v, absl::Span allocations, std::string* constant_name = nullptr); -bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion); - bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( mlir::lmhlo::FusionOp fusion, absl::Span allocations); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 9e25ad8bddd53e..76bb3b0e17e2db 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -100,11 +100,7 @@ limitations under the License. #include "xla/service/gpu/fused_mha_thunk.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/fusions/input_slices.h" -#include "xla/service/gpu/fusions/loop.h" -#include "xla/service/gpu/fusions/reduction.h" #include "xla/service/gpu/fusions/thunk_util.h" -#include "xla/service/gpu/fusions/transpose.h" #include "xla/service/gpu/gemm_thunk.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_conv_runner.h" @@ -2161,68 +2157,30 @@ StatusOr IrEmitterUnnested::EmitTritonFusion( #endif // GOOGLE_CUDA -// Check if the fusion instruction should be emitted as an in place dynamic -// update slice or a memcpy fusion. The logic is copied from GetFusionEmitter. 
-bool IsSpecializedLoopFusion( - mlir::Operation* op, absl::Span allocations, - HloFusionAnalysis& analysis) { - auto fusion_op = mlir::cast(op); - if (!allocations.empty() && fusion_op != nullptr) { - bool is_single = IsSingleInstructionFusion(fusion_op); - if (!is_single && - CanEmitFusedDynamicUpdateSliceInPlaceForGpu(fusion_op, allocations)) { - return true; - } - if (is_single && analysis.fusion_roots().size() == 1 && - analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) { - mlir::Value operand = GetHloOperands(fusion_op).front(); - mlir::Value output = GetHloOutputs(fusion_op).front(); - Shape operand_shape = GetShape(operand); - Shape output_shape = GetShape(output); - if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && - GetAllocationSlice(operand, allocations).ok()) { - return true; - } - } - } - return false; -} - -StatusOr IrEmitterUnnested::GetFusionEmissionResult( - const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis) { +Status IrEmitterUnnested::EmitFusion( + const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis, + mlir::Operation* op, + const absl::flat_hash_map& + hlo_for_lmhlo) { FusionEmissionResult emission_result; switch (fusion_analysis.GetEmitterFusionKind()) { - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: { - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kLoop: { - // TODO(anlunx): Support MemcpyFusion and InPlaceDymaicUpdateSlice. - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kTranspose: { - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } + case HloFusionAnalysis::EmitterFusionKind::kInputSlices: + case HloFusionAnalysis::EmitterFusionKind::kLoop: + case HloFusionAnalysis::EmitterFusionKind::kTranspose: case HloFusionAnalysis::EmitterFusionKind::kReduction: { - auto emitter = std::make_unique(fusion_analysis); + auto emitter = GetFusionEmitter(fusion_analysis, {}, nullptr); + // TODO(anlunx): Support MemcpyFusion and InPlaceDynamicUpdateSlice and + // remove this fallback. 
+ if (!emitter) { + TF_RET_CHECK(op) + << "Fusion should have been handled by GetFusionEmitter, fallback " + "disabled because no lmhlo op is available."; + return EmitFusion(op, hlo_for_lmhlo); + } TF_ASSIGN_OR_RETURN( emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); + (*emitter)->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, + *instr, kernel_reuse_cache_, &b_)); break; } case HloFusionAnalysis::EmitterFusionKind::kTriton: { @@ -2254,13 +2212,6 @@ StatusOr IrEmitterUnnested::GetFusionEmissionResult( break; } - return emission_result; -} - -Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, - HloFusionAnalysis& fusion_analysis) { - TF_ASSIGN_OR_RETURN(FusionEmissionResult emission_result, - GetFusionEmissionResult(instr, fusion_analysis)); for (auto& thunk : emission_result.thunks) { AddThunkToThunkSequence(std::move(thunk)); } @@ -3500,11 +3451,7 @@ Status IrEmitterUnnested::EmitOp( ir_emitter_context_->gpu_device_info(); TF_ASSIGN_OR_RETURN(auto fusion_analysis, HloFusionAnalysis::Create(instr, &device_info)); - // TODO(anlunx): Add support for emitting specialized kLoops. - if (!IsSpecializedLoopFusion(op, ir_emitter_context_->allocations(), - fusion_analysis)) { - return EmitFusion(instr, fusion_analysis); - } + return EmitFusion(instr, fusion_analysis, op, hlo_for_lmhlo); } return EmitFusion(op, hlo_for_lmhlo); @@ -3645,7 +3592,7 @@ Status IrEmitterUnnested::EmitHloInstruction(const HloInstruction* instr) { ir_emitter_context_->gpu_device_info(); TF_ASSIGN_OR_RETURN(auto fusion_analysis, HloFusionAnalysis::Create(fusion, &device_info)); - TF_RETURN_IF_ERROR(EmitFusion(fusion, fusion_analysis)); + TF_RETURN_IF_ERROR(EmitFusion(fusion, fusion_analysis, nullptr, {})); return OkStatus(); } // We don't need to emit thunks for these operations because their semantics diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 6e9fa5e6f74183..edab1055255e70 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -163,14 +163,15 @@ class IrEmitterUnnested : public IrEmitter { #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM Status EmitCustomCallThunk(mlir::Operation* op); Status EmitFftThunk(mlir::Operation* op); - StatusOr GetFusionEmissionResult( - const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis); Status EmitFusion( mlir::Operation* op, const absl::flat_hash_map& hlo_for_lmhlo); - Status EmitFusion(const HloFusionInstruction* instr, - HloFusionAnalysis& fusion_analysis); + Status EmitFusion( + const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis, + mlir::Operation* op, + const absl::flat_hash_map& + hlo_for_lmhlo); Status EmitSelectAndScatter( mlir::Operation* op, const absl::flat_hash_map& From 842398272faeb549afcfc8c3687399cf47359885 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Fri, 1 Dec 2023 06:22:37 -0800 Subject: [PATCH 279/381] [XLA] [NFC] Generifying hlo-opt to allow easier extension for more platforms Adds a sample CPU implementation, also extends the GPU support to AMD. 
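To sketch the extension point this change introduces (an illustrative example, not part of the patch): a backend only has to subclass OptProvider, report its platform name, and register itself under that name. The hypothetical "interpreter" provider below simply mirrors the CpuOptProvider added in this commit; the class name, initializer name, and platform string are assumptions for illustration, while OptProvider, RegisterForPlatform, and REGISTER_MODULE_INITIALIZER come from the diffs below.

// Hedged sketch of a third platform provider, modeled on CpuOptProvider.
#include <memory>
#include <string>

#include "xla/stream_executor/platform/initialize.h"
#include "xla/tools/hlo_opt/opt_lib.h"

namespace xla {
namespace {

class InterpreterOptProvider : public OptProvider {
 public:
  // Only the platform name is mandatory; the "hlo" and "html" stages are
  // inherited from the OptProvider base class.
  std::string GetPlatformName() override { return "interpreter"; }
};

}  // namespace
}  // namespace xla

REGISTER_MODULE_INITIALIZER(interpreter_opt_provider, {
  xla::OptProvider::RegisterForPlatform(
      "interpreter", std::make_unique<xla::InterpreterOptProvider>());
});

Because both registration and lookup go through PlatformUtil::CanonicalPlatformName, users can keep passing familiar platform aliases on the command line while each provider registers under a single canonical key.
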
PiperOrigin-RevId: 586991292 --- .../xla/xla/service/hlo_graph_dumper.cc | 2 +- third_party/xla/xla/tools/hlo_opt/BUILD | 26 +++- third_party/xla/xla/tools/hlo_opt/cpu_hlo.hlo | 12 ++ third_party/xla/xla/tools/hlo_opt/cpu_opt.cc | 37 ++++++ third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 93 ++++---------- third_party/xla/xla/tools/hlo_opt/opt_lib.cc | 117 ++++++++++++++++-- third_party/xla/xla/tools/hlo_opt/opt_lib.h | 36 ++++-- third_party/xla/xla/tools/hlo_opt/opt_main.cc | 14 +-- 8 files changed, 242 insertions(+), 95 deletions(-) create mode 100644 third_party/xla/xla/tools/hlo_opt/cpu_hlo.hlo create mode 100644 third_party/xla/xla/tools/hlo_opt/cpu_opt.cc diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index 575c62b55369e4..c5215b14c77e64 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -1595,7 +1595,7 @@ NodeFilter MakeNodeRadiusAroundFilter( // the size of the graph (since they don't add extra information), and // stopping the rendering early cuts off important information (you // almost never want the rendering to be cutoff at the bitcast: you'd - // like to see it's parent). + // like to see its parent). if (!nodes.contains(operand)) { int new_depth = (operand->opcode() == HloOpcode::kBitcast || instr->opcode() == HloOpcode::kBitcast) diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index b64a6de29fbdb8..f41550dbc00164 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -8,7 +8,6 @@ load( "@local_tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") # hlo-opt tool. load( @@ -33,10 +32,15 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:hlo_graph_dumper", "//xla/service:platform_util", + "//xla/stream_executor", "//xla/stream_executor:platform", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", ], ) @@ -52,12 +56,14 @@ cc_library( "//xla:types", "//xla/service:compiler", "//xla/service:dump", + "//xla/service:executable", "//xla/service:hlo_graph_dumper", "//xla/service:platform_util", "//xla/service/gpu:executable_proto_cc", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/platform", "@com_google_absl//absl/container:flat_hash_map", + "@local_tsl//tsl/platform:statusor", ] + if_gpu_is_configured([ "//xla/service:gpu_plugin", "//xla/service/gpu:gpu_executable", @@ -67,12 +73,28 @@ cc_library( alwayslink = True, # Initializer needs to run. ) +cc_library( + name = "cpu_opt", + testonly = True, + srcs = ["cpu_opt.cc"], + visibility = ["//visibility:public"], + deps = [ + ":opt_lib", + "//xla/service:cpu_plugin", + "//xla/service:hlo_graph_dumper", + "//xla/stream_executor/host:host_platform", + "//xla/stream_executor/platform", + ], + alwayslink = True, # Initializer needs to run. 
+) + cc_library( name = "opt_main", testonly = True, srcs = ["opt_main.cc"], visibility = ["//visibility:public"], deps = [ + "cpu_opt", ":opt_lib", "//xla:debug_options_flags", "//xla:status", @@ -100,7 +122,7 @@ cc_library( ) glob_lit_tests( - name = "gpu_opt_tests", + name = "hlo_opt_tests", data = [":test_utilities"], default_tags = tf_cuda_tests_tags() + [ ], diff --git a/third_party/xla/xla/tools/hlo_opt/cpu_hlo.hlo b/third_party/xla/xla/tools/hlo_opt/cpu_hlo.hlo new file mode 100644 index 00000000000000..1c9b2a81a98946 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/cpu_hlo.hlo @@ -0,0 +1,12 @@ +// RUN: hlo-opt %s --platform=cpu --stage=hlo | FileCheck %s + +HloModule module + +ENTRY computation { +// CHECK: outer_dimension_partitions + p = f32[5000,6000]{1,0} parameter(0) + e = f32[5000,6000]{1,0} sqrt(p) + c = f32[6000,5000] transpose(p), dimensions={1,0} + r = f32[300,20,5000] reshape(c) + ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r) +} diff --git a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc new file mode 100644 index 00000000000000..949b25515f86a1 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc @@ -0,0 +1,37 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/stream_executor/platform/initialize.h" +#include "xla/tools/hlo_opt/opt_lib.h" + +namespace xla { + +namespace { + +class CpuOptProvider : public OptProvider { + public: + std::string GetPlatformName() override { return "cpu"; } +}; + +} // namespace +} // namespace xla + +REGISTER_MODULE_INITIALIZER(cpu_opt_provider, { + xla::OptProvider::RegisterForPlatform( + "cpu", std::make_unique()); +}); diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 55283f6cffec06..d5cb121e4c12d4 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -13,105 +13,64 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include +#include #include +#include #include "absl/container/flat_hash_map.h" #include "xla/debug_options_flags.h" #include "xla/service/compiler.h" #include "xla/service/dump.h" +#include "xla/service/executable.h" #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/gpu_executable.h" -#include "xla/service/hlo_graph_dumper.h" #include "xla/service/platform_util.h" #include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/tools/hlo_opt/opt_lib.h" #include "xla/types.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { -// TODO(cheshire): Switch CUDA/ROCM -static auto kGpuPlatformId = se::cuda::kCudaPlatformId; - -static StatusOr> ToGpuExecutable( - std::unique_ptr module, Compiler* compiler, - se::StreamExecutor* executor, const Compiler::CompileOptions& opts) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimized_module, - compiler->RunHloPasses(std::move(module), executor, opts)); - DebugOptions d = optimized_module->config().debug_options(); - d.set_xla_embed_ir_in_executable(true); - optimized_module->mutable_config().set_debug_options(d); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - compiler->RunBackend(std::move(optimized_module), executor, opts)); - return executable; -} - -struct GpuOptProvider : public OptProvider { +class GpuOptProvider : public OptProvider { + public: StatusOr> GenerateStage( std::unique_ptr module, absl::string_view s) override { - TF_ASSIGN_OR_RETURN( - se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(kGpuPlatformId)); - - TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(platform)); - DebugOptions debug_opts = GetDebugOptionsFromFlags(); - - Compiler::CompileOptions opts; - - se::StreamExecutor* executor = nullptr; - if (debug_opts.xla_gpu_target_config_filename().empty()) { - TF_ASSIGN_OR_RETURN(std::vector stream_executors, - PlatformUtil::GetStreamExecutors( - platform, /*allowed_devices=*/std::nullopt)); - executor = stream_executors[0]; - } - - if (s == "hlo") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimized_module, - compiler->RunHloPasses(std::move(module), executor, opts)); - return optimized_module->ToString(); - } else if (s == "llvm") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - ToGpuExecutable(std::move(module), compiler, executor, opts)); + if (s == "llvm") { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get()) ->ir_module_string(); } else if (s == "ptx") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - ToGpuExecutable(std::move(module), compiler, executor, opts)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get())->text(); } else if (s == "buffer-assignment") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - ToGpuExecutable(std::move(module), compiler, executor, opts)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get()) ->buffer_assignment() ->ToVerboseString(9999); - } else if (s == "html") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimized_module, - compiler->RunHloPasses(std::move(module), executor, opts)); - TF_ASSIGN_OR_RETURN(std::string computations, - RenderAllComputationsToHtml(*optimized_module)); - return computations; + } else { + 
// Delegate to base class. + TF_ASSIGN_OR_RETURN(std::optional out, + OptProvider::GenerateStage(std::move(module), s)); + return out; } - - // Unimplemented stage. - return std::nullopt; } - std::vector SupportedStages() override { - return {"hlo", "llvm", "ptx", "buffer-assignment", "html"}; + std::string GetPlatformName() override { return "gpu"; } + + std::set SupportedStages() override { + std::set supported = OptProvider::SupportedStages(); + supported.insert({"ptx", "llvm", "buffer-assignment"}); + return supported; } }; @@ -120,5 +79,5 @@ struct GpuOptProvider : public OptProvider { REGISTER_MODULE_INITIALIZER(gpu_opt_provider, { xla::OptProvider::RegisterForPlatform( - xla::kGpuPlatformId, std::make_unique()); + "gpu", std::make_unique()); }); diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/opt_lib.cc index 3bf19f43004f2d..62c93a0836a73d 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc +++ b/third_party/xla/xla/tools/hlo_opt/opt_lib.cc @@ -15,13 +15,33 @@ limitations under the License. #include "xla/tools/hlo_opt/opt_lib.h" +#include +#include +#include +#include +#include +#include + #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/compiler.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/platform_util.h" +#include "xla/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/types.h" +#include "tsl/platform/statusor.h" namespace xla { using ProviderMap = - absl::flat_hash_map>; + absl::flat_hash_map>; static absl::Mutex provider_mu(absl::kConstInit); static ProviderMap& GetProviderMap() { @@ -30,22 +50,103 @@ static ProviderMap& GetProviderMap() { } /*static*/ void OptProvider::RegisterForPlatform( - se::Platform::Id platform, - std::unique_ptr translate_provider) { + std::string platform, std::unique_ptr translate_provider) { absl::MutexLock l(&provider_mu); CHECK(!GetProviderMap().contains(platform)); - GetProviderMap()[platform] = std::move(translate_provider); + StatusOr canonical_name = + xla::PlatformUtil::CanonicalPlatformName(platform); + CHECK_OK(canonical_name); + GetProviderMap()[*canonical_name] = std::move(translate_provider); } -/*static*/ OptProvider* OptProvider::ProviderForPlatform( - se::Platform::Id platform) { +/*static*/ StatusOr OptProvider::ProviderForPlatform( + std::string platform) { absl::MutexLock l(&provider_mu); - auto it = GetProviderMap().find(platform); + + TF_ASSIGN_OR_RETURN(std::string canonical_name, + xla::PlatformUtil::CanonicalPlatformName(platform)); + auto it = GetProviderMap().find(canonical_name); if (it == GetProviderMap().end()) { - return nullptr; + return absl::UnimplementedError(absl::StrCat( + "Provider not found for platform ", platform, "; canonical expansion: ", + canonical_name, "; supported platforms are: ", + absl::StrJoin(GetProviderMap(), ", ", + [&](std::string* s, const auto& p) { + absl::StrAppend(s, p.first); + }))); } return it->second.get(); } +StatusOr OptProvider::GetExecutor() { + DebugOptions debug_opts = GetDebugOptionsFromFlags(); + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); + if (debug_opts.xla_gpu_target_config_filename().empty()) { + TF_ASSIGN_OR_RETURN(std::vector stream_executors, + 
PlatformUtil::GetStreamExecutors( + platform, /*allowed_devices=*/std::nullopt)); + return stream_executors[0]; + } + return nullptr; +} + +StatusOr> OptProvider::GenerateStage( + std::unique_ptr module, absl::string_view stage) { + if (stage == "hlo") { + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); + return optimized_module->ToString(); + } else if (stage == "html") { + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); + TF_ASSIGN_OR_RETURN(std::string cmps, + RenderAllComputationsToHtml(*optimized_module)); + return cmps; + } + + return std::nullopt; +} + +StatusOr OptProvider::GetCompiler() { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); + + TF_ASSIGN_OR_RETURN(Compiler * compiler, Compiler::GetForPlatform(platform)); + return compiler; +} + +StatusOr> OptProvider::GetOptimizedHlo( + std::unique_ptr input_module) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); + + DebugOptions debug_opts = GetDebugOptionsFromFlags(); + Compiler::CompileOptions opts; + TF_ASSIGN_OR_RETURN(Compiler * compiler, GetCompiler()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr optimized_module, + compiler->RunHloPasses(std::move(input_module), executor, opts)); + + DebugOptions d = optimized_module->config().debug_options(); + d.set_xla_embed_ir_in_executable(true); + optimized_module->mutable_config().set_debug_options(d); + return optimized_module; +} + +StatusOr> OptProvider::GetExecutable( + std::unique_ptr input_module) { + Compiler::CompileOptions opts; + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(input_module))); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); + TF_ASSIGN_OR_RETURN(Compiler * compiler, GetCompiler()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + compiler->RunBackend(std::move(optimized_module), executor, opts)); + return executable; +} + +std::set OptProvider::SupportedStages() { return {"hlo", "html"}; } + } // namespace xla diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.h b/third_party/xla/xla/tools/hlo_opt/opt_lib.h index 29ae5ba73c96d0..819423b50d31ef 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_lib.h +++ b/third_party/xla/xla/tools/hlo_opt/opt_lib.h @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include #include -#include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" +#include "xla/service/executable.h" #include "xla/statusor.h" #include "xla/stream_executor/platform.h" #include "xla/types.h" @@ -31,21 +32,42 @@ limitations under the License. namespace xla { // Platform-specific provider of `hlo_translate` functionality. -struct OptProvider { +class OptProvider { + public: // Generates textual output for a given stage on a given platform, returns // empty optional if the stage is not supported. virtual StatusOr> GenerateStage( - std::unique_ptr module, absl::string_view stage) = 0; + std::unique_ptr module, absl::string_view stage); virtual ~OptProvider() = default; - virtual std::vector SupportedStages() = 0; + // Returns a set of stages supported by the opt provider. + virtual std::set SupportedStages(); + // Registers a given provider for a given platform. 
static void RegisterForPlatform( - se::Platform::Id platform, - std::unique_ptr translate_provider); + std::string platform, std::unique_ptr translate_provider); - static OptProvider* ProviderForPlatform(se::Platform::Id platform); + // Gets a provider for a given platform. + static StatusOr ProviderForPlatform(std::string platform); + + protected: + // Returns platform name associated with the provider. + virtual std::string GetPlatformName() = 0; + + // Returns a stream executor for the provider (could be nullptr). + virtual StatusOr GetExecutor(); + + // Generates executable from a given input module. + StatusOr> GetExecutable( + std::unique_ptr input_module); + + // Generates optimized HLO. + StatusOr> GetOptimizedHlo( + std::unique_ptr input_module); + + // Gets a compiler associated with the provider. + virtual StatusOr GetCompiler(); }; } // namespace xla diff --git a/third_party/xla/xla/tools/hlo_opt/opt_main.cc b/third_party/xla/xla/tools/hlo_opt/opt_main.cc index f402fe866bc30b..1274a167281998 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_main.cc +++ b/third_party/xla/xla/tools/hlo_opt/opt_main.cc @@ -51,14 +51,14 @@ You can also pass in debug option flags for the HloModule. Usage: - bazel run opt -- --platform=[CUDA|CPU|Interpreter|...] path/to/hlo_module + bazel run opt -- --platform=[gpu|cpu|...] path/to/hlo_module )"; struct HloOptConfig { // Optional flags. bool help{false}; bool split_input_file{false}; - std::string platform{"cuda"}; + std::string platform{"gpu"}; std::string input_file{""}; std::string input_format{""}; std::string output_file{"-"}; @@ -109,14 +109,8 @@ StatusOr> GetModule(const HloOptConfig& opts, StatusOr TranslateToStage(int argc, char** argv, const HloOptConfig& opts) { - se::Platform* platform = - xla::PlatformUtil::GetPlatform(opts.platform).value(); - - OptProvider* provider = OptProvider::ProviderForPlatform(platform->id()); - if (provider == nullptr) { - return absl::UnimplementedError( - absl::StrCat("Provider not found for platform: ", platform->Name())); - } + TF_ASSIGN_OR_RETURN(OptProvider * provider, + OptProvider::ProviderForPlatform(opts.platform)); if (opts.list_stages) { return absl::StrJoin(provider->SupportedStages(), "\n"); From 51dd5af918c81d89eccfd6a4509138b562833869 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 07:24:19 -0800 Subject: [PATCH 280/381] Re-enable layering_check for target. PiperOrigin-RevId: 587004230 --- tensorflow/core/common_runtime/BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 69822584487fcd..5c0bf4f382285d 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -461,12 +461,13 @@ cc_library( srcs = ["collective_param_resolver_local.cc"], hdrs = ["collective_param_resolver_local.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":device_mgr", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) From 0367df242f49f414604662d17c0b9cbd56599281 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 1 Dec 2023 07:28:11 -0800 Subject: [PATCH 281/381] [XLA] Add pass to move cheap fusible computations into while loops to enable fusion. 
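As a usage sketch (illustrative only, not part of this change): the new pass is an ordinary HloModulePass, so a backend pipeline would typically schedule it ahead of WhileLoopSimplifier, which the pass's own header comment relies on to remove the loop-state element intentionally left behind after sinking. The function and pipeline names below are assumptions for illustration; WhileLoopFusibleSinking and WhileLoopSimplifier are the actual pass classes.

#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/while_loop_fusible_sinking.h"
#include "xla/service/while_loop_simplifier.h"
#include "xla/statusor.h"

namespace xla {

// Hypothetical pipeline fragment: sink loop-invariant fusible graphs
// (constant/iota/broadcast chains, etc.) into the while body so they can
// fuse with their users there, then let WhileLoopSimplifier drop the now
// trivially loop-invariant tuple element that the sinking pass leaves behind.
StatusOr<bool> RunWhileLoopFusibleSinkingExample(HloModule* module) {
  HloPassPipeline pipeline("while-loop-fusible-sinking-example");
  pipeline.AddPass<WhileLoopFusibleSinking>();
  pipeline.AddPass<WhileLoopSimplifier>();
  return pipeline.Run(module);
}

}  // namespace xla

Leaving the original invariant value in the loop state, rather than shrinking the tuple inside this pass, keeps the rewrite local and defers cleanup to the existing simplification pass.
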
PiperOrigin-RevId: 587005038 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 48 ++++ third_party/xla/xla/hlo/ir/hlo_instruction.h | 3 + third_party/xla/xla/service/BUILD | 33 +++ third_party/xla/xla/service/defuser.cc | 54 +--- .../xla/service/while_loop_fusible_sinking.cc | 232 ++++++++++++++++++ .../xla/service/while_loop_fusible_sinking.h | 82 +++++++ .../while_loop_fusible_sinking_test.cc | 114 +++++++++ 7 files changed, 513 insertions(+), 53 deletions(-) create mode 100644 third_party/xla/xla/service/while_loop_fusible_sinking.cc create mode 100644 third_party/xla/xla/service/while_loop_fusible_sinking.h create mode 100644 third_party/xla/xla/service/while_loop_fusible_sinking_test.cc diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 2ffc3116853298..5370489b0df6e3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2862,6 +2862,54 @@ Status HloInstruction::ReplaceOperandWithDifferentShape( return OkStatus(); } +// Copy all the instructions in the given fusion instruction into the fusion +// instruction's parent computation and replace the use of the fusion +// instruction with the copy of the fusion expression root. +Status HloInstruction::Defuse() { + if (opcode() != HloOpcode::kFusion) { + return OkStatus(); + } + VLOG(2) << "Defusing instruction: " << ToString(); + + HloComputation* fused_computation = fused_instructions_computation(); + + // A map from fused instruction to its defused clone. + absl::flat_hash_map + defused_instructions; + // Initialize map to contain the fusion instruction parameters mapping + // to the operands of the fusion instruction. + for (int64_t i = 0; i < operand_count(); ++i) { + defused_instructions[fused_computation->parameter_instruction(i)] = + mutable_operand(i); + } + + // Create a clone of each instruction of the fused computation in the same + // computation as the fusion instruction itself. + // TODO(b/68227302): Moving instruction to new computation rather than + // cloning and deleting. + for (HloInstruction* fused_instruction : + fused_computation->MakeInstructionPostOrder()) { + if (fused_instruction->opcode() == HloOpcode::kParameter) { + continue; + } + std::vector new_operands; + for (HloInstruction* operand : fused_instruction->operands()) { + new_operands.push_back(defused_instructions.at(operand)); + } + HloInstruction* defused_instruction = + parent()->AddInstruction(fused_instruction->CloneWithNewOperands( + fused_instruction->shape(), new_operands)); + defused_instructions[fused_instruction] = defused_instruction; + } + + TF_RETURN_IF_ERROR( + ReplaceAllUsesWith(defused_instructions.at(fused_expression_root()))); + + HloModule* module = GetModule(); + TF_RETURN_IF_ERROR(parent()->RemoveInstruction(this)); + return module->RemoveEmbeddedComputation(fused_computation); +} + Status HloInstruction::ReplaceUsesWith(absl::Span users, HloInstruction* new_producer) { TF_RET_CHECK( diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 3de91ccd73a3a3..21970c34bb1585 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1436,6 +1436,9 @@ class HloInstruction { Status ReplaceOperandWithDifferentShape(int64_t operand_num, HloInstruction* new_operand); + // Decomposes fusion back to individual parts. + Status Defuse(); + // Replaces all uses of this instruction with the new producer. 
If // new_producer is a user of this instruction then new_producer remains a use // of this instruction to avoid introducing cycles into the graph. diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b1df56793cdf59..5917512e839f1a 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6116,6 +6116,39 @@ xla_cc_test( ], ) +cc_library( + name = "while_loop_fusible_sinking", + srcs = ["while_loop_fusible_sinking.cc"], + hdrs = ["while_loop_fusible_sinking.h"], + visibility = ["//visibility:public"], + deps = [ + ":call_graph", + ":hlo_pass", + ":while_util", + "//xla:literal_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "while_loop_fusible_sinking_test", + srcs = ["while_loop_fusible_sinking_test.cc"], + deps = [ + ":while_loop_fusible_sinking", + "//xla:test", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/lib/core:status_test_util", + ], +) + cc_library( name = "despecializer", srcs = ["despecializer.cc"], diff --git a/third_party/xla/xla/service/defuser.cc b/third_party/xla/xla/service/defuser.cc index 77ba8ebe989198..e2833d1af91b5b 100644 --- a/third_party/xla/xla/service/defuser.cc +++ b/third_party/xla/xla/service/defuser.cc @@ -36,58 +36,6 @@ limitations under the License. namespace xla { -namespace { - -// Copy all the instructions in the given fusion instruction into the fusion -// instruction's parent computation and replace the use of the fusion -// instruction with the copy of the fusion expression root. -Status Defuse(HloInstruction* fusion_instruction) { - VLOG(2) << "Defusing instruction: " << fusion_instruction->ToString(); - - HloComputation* fused_computation = - fusion_instruction->fused_instructions_computation(); - - // A map from fused instruction to its defused clone. - absl::flat_hash_map - defused_instructions; - // Initialize map to contain the fusion instruction parameters mapping - // to the operands of the fusion instruction. - for (int64_t i = 0; i < fusion_instruction->operand_count(); ++i) { - defused_instructions[fused_computation->parameter_instruction(i)] = - fusion_instruction->mutable_operand(i); - } - - // Create a clone of each instruction of the fused computation in the same - // computation as the fusion instruction itself. - // TODO(b/68227302): Moving instruction to new computation rather than - // cloning and deleting. 
- for (HloInstruction* fused_instruction : - fused_computation->MakeInstructionPostOrder()) { - if (fused_instruction->opcode() == HloOpcode::kParameter) { - continue; - } - std::vector new_operands; - for (HloInstruction* operand : fused_instruction->operands()) { - new_operands.push_back(defused_instructions.at(operand)); - } - HloInstruction* defused_instruction = - fusion_instruction->parent()->AddInstruction( - fused_instruction->CloneWithNewOperands(fused_instruction->shape(), - new_operands)); - defused_instructions[fused_instruction] = defused_instruction; - } - - TF_RETURN_IF_ERROR(fusion_instruction->ReplaceAllUsesWith( - defused_instructions.at(fusion_instruction->fused_expression_root()))); - - HloModule* module = fusion_instruction->GetModule(); - TF_RETURN_IF_ERROR( - fusion_instruction->parent()->RemoveInstruction(fusion_instruction)); - return module->RemoveEmbeddedComputation(fused_computation); -} - -} // namespace - StatusOr Defuser::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -102,7 +50,7 @@ StatusOr Defuser::Run( TF_RET_CHECK(call_graph_node.caller_callsites().size() == 1); HloInstruction* fusion_instruction = call_graph_node.caller_callsites()[0].instruction(); - TF_RETURN_IF_ERROR(Defuse(fusion_instruction)); + TF_RETURN_IF_ERROR(fusion_instruction->Defuse()); changed = true; } return OkStatus(); diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc new file mode 100644 index 00000000000000..f659f3c9a02c16 --- /dev/null +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -0,0 +1,232 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_loop_fusible_sinking.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/call_graph.h" +#include "xla/service/while_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla { + +HloInstruction* WhileLoopFusibleSinking::GetSinkableFusion( + HloInstruction* while_operand) { + std::vector worklist; + worklist.push_back(while_operand); + HloInstruction* fusion = nullptr; + auto fuse = [&](HloInstruction* instr) -> bool { + if (!instr->IsFusible()) { + return false; + } + if (!fusion) { + fusion = instr->AddInstruction(instr->CreateFusion( + instr->shape(), HloInstruction::FusionKind::kLoop, instr)); + return true; + } + // The instruction has already been visited, just skip it. 
+ if (!fusion->IsUserOf(instr)) { + return false; + } + fusion->FuseInstruction(instr); + return true; + }; + std::vector new_operands; + while (!worklist.empty()) { + HloInstruction* to_process = worklist.back(); + worklist.pop_back(); + if (to_process->IsElementwise() && fuse(to_process)) { + for (auto* op : to_process->operands()) { + worklist.push_back(op); + } + continue; + } + switch (to_process->opcode()) { + case HloOpcode::kBroadcast: { + HloInstruction* op = to_process->mutable_operand(0); + if (fuse(to_process) && (op->opcode() == HloOpcode::kConstant || + op->opcode() == HloOpcode::kIota)) { + fuse(op); + } + break; + } + case HloOpcode::kConstant: + case HloOpcode::kIota: { + fuse(to_process); + break; + } + case HloOpcode::kReshape: + case HloOpcode::kTranspose: { + HloInstruction* op = to_process->mutable_operand(0); + if (fuse(to_process)) { + worklist.push_back(op); + } + break; + } + default: + if (fusion) { + fusion->parent()->RemoveInstruction(fusion).IgnoreError(); + } + return nullptr; + } + } + LOG(ERROR) << fusion->fused_instructions_computation()->ToString(); + return fusion; +} + +StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( + HloInstruction* while_instr) { + HloComputation* while_cond = while_instr->while_condition(); + HloComputation* while_body = while_instr->while_body(); + + // Don't try to mutate unflattened while loop computations. + if (call_graph_->GetNode(while_cond).callers().size() > 1 || + call_graph_->GetNode(while_body).callers().size() > 1) { + return false; + } + HloInstruction* init_value = while_instr->mutable_operand(0); + if (init_value->opcode() != HloOpcode::kTuple) { + return false; + } + + bool changed = false; + + absl::flat_hash_map> + conditional_gte_index_to_insts = + WhileUtil::GetGTEsMapForWhileConditional(*while_cond); + std::vector invariant_body_gtes = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + for (HloInstruction* invariant_body_gte : invariant_body_gtes) { + int64_t index = invariant_body_gte->tuple_index(); + HloInstruction* invariant_value = init_value->mutable_operand(index); + + if (init_value->IsRoot() || init_value->user_count() > 1) { + init_value = init_value->AddInstruction(init_value->Clone()); + TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(0, init_value)); + } + // Original value should be a fusible subgraph. 
+ HloInstruction* fusion = GetSinkableFusion(invariant_value); + if (fusion == nullptr) { + continue; + } + changed = true; + auto uses = while_instr->users(); + if (fusion->operand_count() > 0 && + (while_instr->IsRoot() || + absl::c_any_of(uses, [&](HloInstruction* use) { + return use->opcode() != HloOpcode::kGetTupleElement; + }))) { + std::vector gtes(init_value->operand_count()); + for (int64_t i = 0; i < gtes.size(); ++i) { + gtes[i] = while_instr->AddInstruction( + HloInstruction::CreateGetTupleElement(while_instr, i)); + } + HloInstruction* tuple = + while_instr->AddInstruction(HloInstruction::CreateTuple(gtes)); + if (while_instr->IsRoot()) { + while_instr->parent()->set_root_instruction(tuple); + } + if (!uses.empty()) { + TF_RETURN_IF_ERROR(while_instr->ReplaceUsesWith(uses, tuple)); + } + } + for (auto use : while_instr->users()) { + if (use->opcode() == HloOpcode::kGetTupleElement && + use->tuple_index() == index) { + TF_RETURN_IF_ERROR( + while_instr->parent()->ReplaceInstruction(use, invariant_value)); + } + } + + HloInstruction* root = while_body->root_instruction(); + HloInstruction* parameter = while_body->parameter_instruction(0); + std::vector tuple_indices(fusion->operand_count()); + int64_t next_index = init_value->operand_count(); + std::vector new_operands(fusion->operand_count()); + for (int64_t i = 0; i < fusion->operand_count(); ++i) { + init_value->AppendOperand(fusion->mutable_operand(i)); + parameter->mutable_shape()->mutable_tuple_shapes()->push_back( + fusion->mutable_operand(i)->shape()); + new_operands[i] = root->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, next_index++)); + root->AppendOperand(new_operands[i]); + } + *(init_value->mutable_shape()) = parameter->shape(); + *(while_instr->mutable_shape()) = parameter->shape(); + *(while_cond->parameter_instruction(0)->mutable_shape()) = + parameter->shape(); + *(root->mutable_shape()) = parameter->shape(); + auto cloned_fusion = while_body->AddInstruction( + fusion->CloneWithNewOperands(fusion->shape(), new_operands)); + TF_RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstruction(invariant_body_gte, cloned_fusion)); + TF_RETURN_IF_ERROR(cloned_fusion->Defuse()); + } + + return changed; +} + +StatusOr WhileLoopFusibleSinking::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + auto call_graph = CallGraph::Build(module, execution_threads); + call_graph_ = call_graph.get(); + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { + // Right now we don't particularly care about optimizing while-of-while + // patterns. If/When we do, we'll want to visit the outer while (while_0) + // before we visit the inner while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(fusible, ...), body=while_0_body, ...) + // } + // + // This will let us sink the fusible into the outer while first and then + // into the inner while in a single run of this pass. 
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + HloPredicateIsOp); + } + + for (HloInstruction* while_instr : while_instrs) { + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingFusiblesIntoWhileLoop(while_instr)); + changed |= result; + } + return changed; +} +} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.h b/third_party/xla/xla/service/while_loop_fusible_sinking.h new file mode 100644 index 00000000000000..0f6be3a40d51a7 --- /dev/null +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.h @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ +#define XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" + +namespace xla { + +// Sinks while loop invariant values that happen to be fusibles into the while +// loop body and conditional. This is probably not a win in isolation but may +// unlock further optimizations like fusible folding. +// +// state = (..., fusible_graph, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(v) +// state = (..., v, ...) +// } +// +// => +// +// state = (..., fusbile_graph, ..., fusible_graph_operands) +// while (pred(state)) { +// (..., v, ...) = state +// use(fusibile_graph) +// state = (..., v, ...) +// } +// +// Note that it leaves the `v` in place to keep that component of the state +// tuple trivially loop invariant. WhileLoopSimplifier will later get rid of +// `v`. +// +class WhileLoopFusibleSinking : public HloModulePass { + public: + WhileLoopFusibleSinking() = default; + + ~WhileLoopFusibleSinking() override = default; + + absl::string_view name() const override { + return "while-loop-fusible-sinking"; + } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Sink a fusible subgraph into a while loop. + StatusOr TrySinkingFusiblesIntoWhileLoop(HloInstruction* while_instr); + + // Creates a loop fusion instruction containing the computation to move into + // the while loop to avoid conflicts with actual instruction fusion, the loop + // fusion will be defused. 
+ HloInstruction* GetSinkableFusion(HloInstruction* while_operand); + + CallGraph* call_graph_; +}; +} // namespace xla + +#endif // XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc new file mode 100644 index 00000000000000..2aa5b623536e4c --- /dev/null +++ b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_loop_fusible_sinking.h" + +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; +using WhileLoopFusibleSinkingTest = HloTestBase; + +TEST_F(WhileLoopFusibleSinkingTest, SinkOneFusible) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] parameter(0) + const_1 = f32[2] iota(), iota_dimension=0 + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Iota()), _)); +} + +TEST_F(WhileLoopFusibleSinkingTest, SinkMask) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[5,7],f32[5,7]) parameter(0) + p_body.0 = get-tuple-element(p_body), index=0 + p_body.1 = get-tuple-element(p_body), index=1 + + add.0 = add(p_body.0, p_body.1) + ROOT root = tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[5,7],f32[5,7]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[5,7] parameter(0) + p = f32[5] parameter(1) + a = f32[5,7] iota(), iota_dimension=0 + b = f32[5,7] iota(), iota_dimension=1 + c = add(a, b) + d = f32[5,7] broadcast(p), dimensions={0} + mask = multiply(c,d) + while_init = tuple(const_0, mask) + ROOT while = while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + 
WhileLoopFusibleSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Multiply(op::Add(op::Iota(), op::Iota()), + op::Broadcast())), + _, _)); +} + +} // namespace +} // namespace xla From 74db590cc6a4171cc449f72c377479af9456fee3 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 1 Dec 2023 08:25:37 -0800 Subject: [PATCH 282/381] [XLA] Remove accidental log PiperOrigin-RevId: 587017053 --- third_party/xla/xla/service/while_loop_fusible_sinking.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index f659f3c9a02c16..37d5f4d229fa27 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -91,7 +91,6 @@ HloInstruction* WhileLoopFusibleSinking::GetSinkableFusion( return nullptr; } } - LOG(ERROR) << fusion->fused_instructions_computation()->ToString(); return fusion; } From 8e41c9133f84f2e62e046103d2881c47170beffd Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Fri, 1 Dec 2023 09:09:30 -0800 Subject: [PATCH 283/381] Refactor fused quantization patterns into quantize pass. Current Quantize{DotGeneral, Convolution}OpPattern matching takes place after post_quantize, which may be misleading. Instead, factor out patterns in a separate file and handle in quantize.cc. MLIR tests for fused patterns are left in quantize_composite_functions.mlir for module-level evaluation of functions. PiperOrigin-RevId: 587027352 --- .../mlir/quantization/stablehlo/BUILD | 17 +- ...tion_patterns.h => quantization_pattern.h} | 21 +- .../stablehlo/passes/quantization_patterns.cc | 432 ------------------ .../quantization/stablehlo/passes/quantize.cc | 5 +- .../passes/quantize_composite_functions.cc | 365 +++++++++++++++ .../stablehlo/tests/quantize.mlir | 27 -- .../tests/quantize_composite_functions.mlir | 92 ++-- 7 files changed, 431 insertions(+), 528 deletions(-) rename tensorflow/compiler/mlir/quantization/stablehlo/passes/{quantization_patterns.h => quantization_pattern.h} (97%) delete mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index b8558350855c19..555a53b5011814 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -51,7 +51,7 @@ cc_library( ":lift_quantizable_spots_as_functions_fusion_inc_gen", ":lift_quantizable_spots_as_functions_simple_inc_gen", ":quantization_options_proto_cc", - ":quantization_patterns", + ":quantization_pattern", ":stablehlo_passes_inc_gen", ":stablehlo_type_utils", ":uniform_quantized_types", @@ -106,38 +106,27 @@ cc_library( ) cc_library( - name = "quantization_patterns", - srcs = ["passes/quantization_patterns.cc"], + name = "quantization_pattern", hdrs = [ - "passes/quantization_patterns.h", + "passes/quantization_pattern.h", ], compatible_with = get_compatible_with_portable(), deps = [ ":bridge_passes", - ":uniform_quantized_types", - "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", 
"//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h similarity index 97% rename from tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h index 79daa9ce8b48b8..3922374e402353 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ #include #include @@ -90,7 +90,6 @@ class StableHloQuantizationPattern : public RewritePattern { : RewritePattern(RootOpT::getOperationName(), 300, context), quant_params_(quant_params) {} - private: LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override { llvm::SmallVector quantizing_ops; @@ -158,8 +157,8 @@ class StableHloQuantizationPattern : public RewritePattern { // Blocklist op is checked in advance for non-dynamic range quantization // case. if (!quant_params_.quant_spec.weight_quantization && - (ops_blocklist.contains( - quantizing_op->getName().getStringRef().str()))) { + (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != + ops_blocklist.end())) { return failure(); } @@ -262,6 +261,9 @@ class StableHloQuantizationPattern : public RewritePattern { return success(); } + private: + QuantPassSpec quant_params_; + // Checks whether the operation is connnected with a quantized composite // function. If not, the same-scale op will not be quantized. This decision is // based on the current assumption that the performance gain of the same-scale @@ -365,15 +367,8 @@ class StableHloQuantizationPattern : public RewritePattern { } return has_quantized_types; } - - QuantPassSpec quant_params_; }; -// Gemm Style Op: glossary/gemm. -// Populates conversion patterns to unfuse batch normalization operations. 
-void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns); - } // namespace mlir::quant::stablehlo -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc deleted file mode 100644 index 603e1ff771a204..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ /dev/null @@ -1,432 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" - -#define DEBUG_TYPE "populate-quantization-patterns" - -namespace mlir::quant::stablehlo { - -namespace { - -using ::mlir::stablehlo::AddOp; -using ::mlir::stablehlo::ConvolutionOp; -using ::mlir::stablehlo::DotGeneralOp; -using ::mlir::stablehlo::DynamicBroadcastInDimOp; -using ::mlir::stablehlo::UniformQuantizeOp; - -constexpr StringRef kCompositeFuncPrefix = "composite_"; -constexpr StringRef kQuantizedFuncPrefix = 
"quantized_"; -constexpr StringRef kEntryFuncAttrName = "_entry_function"; - -// Returns true if `type` is a TensorType with quantized elements. -bool IsQuantizedTensorType(const Type type) { - return type.isa() && - type.cast().getElementType().isa(); -} - -// Returns true if an op has adjacent bias or activation that can be fused -// together into the quantization function. -// TODO: b/307620428 - Consider using matchAndRewrite to check and apply -// patterns at the same time. Also add check for fusible activation or -// fusible patterns with dynamic shape. -bool HasFusibleQuantizationPattern(Operation& op) { - if (isa(op.getNextNode())) { - return true; - } - return false; -} - -// Returns dynamically broadcasted user op of an input op. Returns null if -// the op is used multiple times or the user op is not dynamically broadcasted. -// Dynamic shapes usually has the following pattern. In the example below, -// the input operand would be stablehlo.gemm_style op, and return value would -// be stablehlo.add op. -// -// ``` -// %2 = stablehlo.gemm_style(%0, %1) -// %3 = shape.shape_of %2 -// %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 -// %5 = stablehlo.add %2, %4 -// ``` -Operation* GetDynamicallyBroadcastedUserOp(Operation& op) { - if (!op.hasOneUse()) { - LLVM_DEBUG(llvm::dbgs() - << "Target op is used multiple times and will not be checked " - "for dynamic shape case.\n"); - return nullptr; - } - Operation& shapeof_op = *op.getNextNode(); - if (!isa(shapeof_op)) { - return nullptr; - } - Operation& broadcast_in_dims_op = *shapeof_op.getNextNode(); - if (!isa(broadcast_in_dims_op)) { - return nullptr; - } - return broadcast_in_dims_op.getNextNode(); -} - -// Checks if all inputs and outputs are quantized. -bool HasQuantizedOperandOrOutput(Operation& call_op) { - SmallVector arg_types; - for (const Value arg : call_op.getOperands()) { - arg_types.push_back(arg.getType()); - } - - SmallVector output_types; - for (const Value output : call_op.getResults()) { - output_types.push_back(output.getType()); - } - - return absl::c_all_of(arg_types, IsQuantizedTensorType) && - absl::c_all_of(output_types, IsQuantizedTensorType); -} - -// Gets the corresponding quantized function name from the given function name. -// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" -std::string GetQuantizedFunctionName(const StringRef func_name) { - return Twine(kQuantizedFuncPrefix) - .concat(func_name.rsplit(kCompositeFuncPrefix).second) - .str(); -} - -// Returns true if `xla_call_module_op` is quantized. To be considered -// quantized, it should meet three conditions: -// 1. At least one of the inputs or outputs should be a uniform quantized type. -// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. -// 3. It should also have the `kEntryFuncAttrName` attribute, which points to -// the function that `xla_call_module_op` represents. -bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { - return HasQuantizedOperandOrOutput(*xla_call_module_op) && - xla_call_module_op->hasAttr(kQuantTraitAttrName) && - xla_call_module_op->hasAttr(kEntryFuncAttrName); -} - -// Returns the entry function, i.e. the callee of `xla_call_module_op`. 
-func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, - SymbolTable symbol_table) { - const auto entry_function_symbol_ref = - xla_call_module_op->getAttrOfType(kEntryFuncAttrName); - - return dyn_cast_or_null( - symbol_table.lookup(entry_function_symbol_ref.getValue())); -} - -// Replaces the function type of `entry_func_op` to a quantized one, matching -// the input and output types of `xla_call_module_op`. -void SetQuantizedFunctionType(PatternRewriter& rewriter, - func::FuncOp& entry_func_op, - TF::XlaCallModuleOp xla_call_module_op) { - SmallVector arg_types; - SmallVector arg_locs; - for (const Value arg : xla_call_module_op.getArgs()) { - arg_types.push_back(arg.getType()); - arg_locs.push_back(arg.getLoc()); - } - - SmallVector output_types; - for (const Value output : xla_call_module_op.getOutput()) { - output_types.push_back(output.getType()); - } - - entry_func_op.setFunctionType( - rewriter.getFunctionType(arg_types, output_types)); - - // Replace argument types and locs. - Block& entry = entry_func_op->getRegion(0).front(); - for (auto [arg, arg_type, arg_loc] : - llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { - arg.setType(arg_type); - arg.setLoc(arg_loc); - } -} - -// Creates a UniformQuantize op and sets it as return op. -void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, - func::FuncOp entry_func_op, - const Type func_result_type) { - // Add i32 -> i8 requantization. - UniformQuantizeOp uniform_quant_op = rewriter.create( - op.getLoc(), func_result_type, op.getResults()); - cast(entry_func_op.getBody().front().getTerminator()) - .setOperand(0, uniform_quant_op); -} - -// An interface representing patterns that quantizes an entry function's body. -// The entry function's signatures should have already been quantized at the -// point of rewriting. -class EntryFuncBodyQuantizationPattern { - public: - virtual ~EntryFuncBodyQuantizationPattern() = default; - - // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At - // this point `entry_func_op`'s signature has not been reset with quantized - // types. - virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; - - // Rewrites the `entry_func_op`'s body. - virtual void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const = 0; -}; - -// Gemm Style Op: glossary/gemm. -template -// Match for all gemm_style op and check for possible fusions. -LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { - // function must have input, filter, and optionally bias. - auto& operations = entry_func_op.getBody().front().getOperations(); - if (operations.size() != 2 && operations.size() != 3) { - return failure(); - } - if (!isa(operations.front())) { - return failure(); - } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { - LLVM_DEBUG(llvm::dbgs() - << "Currently gemm style ops quantization only supports static " - " shapes.\n"); - return failure(); - } else if (!isa( - operations.front().getResult(0).getType())) { - return failure(); - } - return success(); -} - -// Gemm Style Op: glossary/gemm. -template -void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { - // Update the output type of the gemm_style op. 
- GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); - - const Type input_type = entry_func_op.getArgumentTypes()[0]; - const Type filter_type = entry_func_op.getArgumentTypes()[1]; - const Type func_result_type = entry_func_op.getResultTypes()[0]; - - const double input_scale = - getElementTypeOrSelf(input_type).cast().getScale(); - const double filter_scale = - getElementTypeOrSelf(filter_type).cast().getScale(); - const double result_scale = input_scale * filter_scale; - - // Define the intermediate output type, which is an i32 quantized type. - // This is intermediate because the final output type of the entry_func_op - // should be an i8 quantized type. - const UniformQuantizedType gemm_style_quantized_element_type = - CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), - *rewriter.getContext(), result_scale, - /*zero_point=*/0); - - Value gemm_style_op_result = gemm_style_op->getResult(0); - auto gemm_style_op_result_type = - gemm_style_op_result.getType().cast(); - const ArrayRef gemm_style_shape = - gemm_style_op_result_type.getShape(); - - const TensorType new_gemm_style_op_result_type = - gemm_style_op_result_type.cloneWith(gemm_style_shape, - gemm_style_quantized_element_type); - gemm_style_op_result.setType(new_gemm_style_op_result_type); - - rewriter.setInsertionPointAfter(gemm_style_op); - - Operation& next_op = *gemm_style_op->getNextNode(); - // If an op is used multiple times, do not apply quantization of fused - // patterns to prevent removal of dependee ops. - const bool should_quantize_without_fusion = - HasFusibleQuantizationPattern(*gemm_style_op.getOperation()) && - !gemm_style_op->hasOneUse(); - - // TODO: b/307620428 - Add support for dynamic shapes. - if (should_quantize_without_fusion || !isa(next_op)) { - // no bias - CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, - func_result_type); - return; - } - // bias fusion - Value bias_op = next_op.getOperand(1); - Value add_op_result = next_op.getResult(0); - const auto add_op_result_type = - add_op_result.getType().cast(); - const ArrayRef add_op_shape = add_op_result_type.getShape(); - // For quantized bias add case, lhs, rhs, and result have the same types. - const TensorType new_add_op_result_type = add_op_result_type.cloneWith( - add_op_shape, gemm_style_quantized_element_type); - add_op_result.setType(new_add_op_result_type); - - AddOp bias_add_op = - rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); - - CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, - func_result_type); -} - -// Quantizes the entry function's body containing a `DotGeneralOp`. -class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { - public: - explicit QuantizeDotGeneralOpPattern() = default; - - LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); - } - - void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); - } -}; - -// Quantizes the entry function's body containing a `ConvolutionOp`. 
-class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { - public: - explicit QuantizeConvolutionOpPattern() = default; - - LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); - } - - void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); - } -}; - -// Converts `entry_func_op` to be quantized according to the respective -// inputs and outputs of `xla_call_module_op` that are possibly quantized. It -// signature (type) is reset to match that of `xla_call_module_op`. -// `entry_func_body_quantization_pattern` rewrites the function's body, based on -// the new signature. -void QuantizeEntryFuncOp( - MLIRContext& ctx, PatternRewriter& rewriter, - TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); - - body_rewrite_pattern.rewrite(entry_func_op, rewriter); - - // Rename the function to be clear that the function has been quantized. - const std::string quantized_function_name = - GetQuantizedFunctionName(entry_func_op.getSymName()); - entry_func_op.setSymName(quantized_function_name); -} - -// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee -// is expected to remain unquantized (thus having a signature mismatch), and it -// is also quantized accordingly. -void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - MLIRContext& ctx, PatternRewriter& rewriter, - TF::XlaCallModuleOp xla_call_module_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - ModuleOp module_op = xla_call_module_op->getParentOfType(); - SymbolTable symbol_table(module_op); - - func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); - QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, - body_rewrite_pattern); - - // Replace the XlaCallModuleOp with a new CallOp. - rewriter.setInsertionPoint(xla_call_module_op); - rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, - xla_call_module_op.getArgs()); -} - -// Pattern that mainly does two things: -// -// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. -// 2. Quantizes the callee function. -// -// The inputs of this pattern assumes an invalid IR, where even if a -// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) -// not only replaces the input and output tensor types into quantized ones, but -// also rewrites the body with a quantized equivalent. -// -// `FuncBodyRewritePatternT` defines how a function body is quantized and -// rewritten. -template >> -class XlaCallModuleOpToCallOp : public OpRewritePattern { - public: - explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) - : OpRewritePattern(&ctx) {} - - LogicalResult match(TF::XlaCallModuleOp op) const override { - ModuleOp module_op = op->getParentOfType(); - SymbolTable symbol_table(module_op); - - // Ignore unquantized ops. 
- if (!IsQuantizedXlaCallModuleOp(op)) return failure(); - - func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); - if (!entry_func_op) { - op->emitError("Failed to find a valid entry function."); - return failure(); - } - - return FuncBodyRewritePatternT().match(entry_func_op); - } - - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT()); - } -}; - -} // namespace - -// TODO: b/307620428 - Increase fused op coverage for static range quantization. -void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns) { - patterns.add, - XlaCallModuleOpToCallOp>(ctx); -} - -} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index cb587416a45bff..fffa0ad782c8d9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h" namespace mlir::quant::stablehlo { @@ -130,9 +130,6 @@ void QuantizePass::runOnOperation() { patterns.add( &ctx, quant_params); - // Support quantization for fused patterns containing Gemm Style ops. - PopulateFusedGemmStylePatterns(ctx, patterns); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // causing this to result in a convergence failure. Consider this as a diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index aef2f7bb9f3cd9..dd558a08bc642c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -86,6 +86,362 @@ class QuantizeCompositeFunctionsPass void runOnOperation() override; }; +// Returns true if `type` is a TensorType with quantized elements. +bool IsQuantizedTensorType(const Type type) { + return type.isa() && + type.cast().getElementType().isa(); +} + +// Returns true if an op has adjacent bias or activation that can be fused +// together into the quantization function. +// TODO: b/307620428 - Consider using matchAndRewrite to check and apply +// patterns at the same time. Also add check for fusible activation or +// fusible patterns with dynamic shape. +bool HasFusibleQuantizationPattern(Operation& op) { + if (isa(op.getNextNode())) { + return true; + } + return false; +} + +// Returns dynamically broadcasted user op of an input op. Returns null if +// the op is used multiple times or the user op is not dynamically broadcasted. +// Dynamic shapes usually has the following pattern. In the example below, +// the input operand would be stablehlo.gemm_style op, and return value would +// be stablehlo.add op. 
+// +// ``` +// %2 = stablehlo.gemm_style(%0, %1) +// %3 = shape.shape_of %2 +// %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 +// %5 = stablehlo.add %2, %4 +// ``` +Operation* GetDynamicallyBroadcastedUserOp(Operation& op) { + if (!op.hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "Target op is used multiple times and will not be checked " + "for dynamic shape case.\n"); + return nullptr; + } + Operation& shapeof_op = *op.getNextNode(); + if (!isa(shapeof_op)) { + return nullptr; + } + Operation& broadcast_in_dims_op = *shapeof_op.getNextNode(); + if (!isa(broadcast_in_dims_op)) { + return nullptr; + } + return broadcast_in_dims_op.getNextNode(); +} + +// Checks if all inputs and outputs are quantized. +bool HasQuantizedOperandOrOutput(Operation& call_op) { + SmallVector arg_types; + for (const Value arg : call_op.getOperands()) { + arg_types.push_back(arg.getType()); + } + + SmallVector output_types; + for (const Value output : call_op.getResults()) { + output_types.push_back(output.getType()); + } + + return absl::c_all_of(arg_types, IsQuantizedTensorType) && + absl::c_all_of(output_types, IsQuantizedTensorType); +} + +// Get the corresponding quantized function name from the given function name. +// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" +std::string GetQuantizedFunctionName(const StringRef func_name) { + return Twine(kQuantizedFuncPrefix) + .concat(func_name.rsplit(kCompositeFuncPrefix).second) + .str(); +} + +// Returns true if `xla_call_module_op` is quantized. To be considered +// quantized, it should meet three conditions: +// 1. At least one of the inputs or outputs should be a uniform quantized type. +// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. +// 3. It should also have the `kEntryFuncAttrName` attribute, which points to +// the function that `xla_call_module_op` represents. +bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { + return HasQuantizedOperandOrOutput(*xla_call_module_op) && + xla_call_module_op->hasAttr(kQuantTraitAttrName) && + xla_call_module_op->hasAttr(kEntryFuncAttrName); +} + +// Returns the entry function, i.e. the callee of `xla_call_module_op`. +func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, + SymbolTable symbol_table) { + const auto entry_function_symbol_ref = + xla_call_module_op->getAttrOfType(kEntryFuncAttrName); + + return dyn_cast_or_null( + symbol_table.lookup(entry_function_symbol_ref.getValue())); +} + +// Replaces the function type of `entry_func_op` to a quantized one, matching +// the input and output types of `xla_call_module_op`. +void SetQuantizedFunctionType(PatternRewriter& rewriter, + func::FuncOp entry_func_op, + TF::XlaCallModuleOp xla_call_module_op) { + SmallVector arg_types; + SmallVector arg_locs; + for (const Value arg : xla_call_module_op.getArgs()) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + + SmallVector output_types; + for (const Value output : xla_call_module_op.getOutput()) { + output_types.push_back(output.getType()); + } + + entry_func_op.setFunctionType( + rewriter.getFunctionType(arg_types, output_types)); + + // Replace argument types and locs. + Block& entry = entry_func_op->getRegion(0).front(); + for (auto [arg, arg_type, arg_loc] : + llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { + arg.setType(arg_type); + arg.setLoc(arg_loc); + } +} + +// Creates a UniformQuantize op and sets it as return op. 
+void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, + func::FuncOp entry_func_op, + const Type func_result_type) { + // Add i32 -> i8 requantization. + UniformQuantizeOp uniform_quant_op = rewriter.create( + op.getLoc(), func_result_type, op.getResults()); + cast(entry_func_op.getBody().front().getTerminator()) + .setOperand(0, uniform_quant_op); +} + +// An interface representing patterns that quantizes an entry function's body. +// The entry function's signatures should have already been quantized at the +// point of rewriting. +class EntryFuncBodyQuantizationPattern { + public: + virtual ~EntryFuncBodyQuantizationPattern() = default; + + // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At + // this point `entry_func_op`'s signature has not been reset with quantized + // types. + virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; + + // Rewrites the `entry_func_op`'s body. + virtual void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const = 0; +}; + +// Gemm Style Op: glossary/gemm. +template +// Match for all gemm_style op and check for possible fusions. +LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { + // function must have input, filter, and optionally bias. + auto& operations = entry_func_op.getBody().front().getOperations(); + if (operations.size() != 2 && operations.size() != 3) { + return failure(); + } + if (!isa(operations.front())) { + return failure(); + } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { + LLVM_DEBUG(llvm::dbgs() + << "Currently gemm style ops quantization only supports static " + " shapes.\n"); + return failure(); + } else if (!isa( + operations.front().getResult(0).getType())) { + return failure(); + } + return success(); +} + +// Gemm Style Op: glossary/gemm. +template +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { + // Update the output type of the gemm_style op. + GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); + + const Type input_type = entry_func_op.getArgumentTypes()[0]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + const double input_scale = + getElementTypeOrSelf(input_type).cast().getScale(); + const double filter_scale = + getElementTypeOrSelf(filter_type).cast().getScale(); + const double result_scale = input_scale * filter_scale; + + // Define the intermediate output type, which is an i32 quantized type. + // This is intermediate because the final output type of the entry_func_op + // should be an i8 quantized type. + const UniformQuantizedType gemm_style_quantized_element_type = + CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), + *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + Value gemm_style_op_result = gemm_style_op->getResult(0); + auto gemm_style_op_result_type = + gemm_style_op_result.getType().cast(); + const ArrayRef gemm_style_shape = + gemm_style_op_result_type.getShape(); + + const TensorType new_gemm_style_op_result_type = + gemm_style_op_result_type.cloneWith(gemm_style_shape, + gemm_style_quantized_element_type); + gemm_style_op_result.setType(new_gemm_style_op_result_type); + + rewriter.setInsertionPointAfter(gemm_style_op); + + Operation& next_op = *gemm_style_op->getNextNode(); + // If an op is used multiple times, do not apply quantization of fused + // patterns to prevent removal of dependee ops. 
+ const bool should_quantize_without_fusion = + HasFusibleQuantizationPattern(*gemm_style_op.getOperation()) && + !gemm_style_op->hasOneUse(); + + // TODO: b/307620428 - Add support for dynamic shapes. + if (should_quantize_without_fusion || !isa(next_op)) { + // no bias + CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, + func_result_type); + return; + } + // bias fusion + Value bias_op = next_op.getOperand(1); + Value add_op_result = next_op.getResult(0); + const auto add_op_result_type = + add_op_result.getType().cast(); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. + const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, gemm_style_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = + rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); +} + +// Quantizes the entry function's body containing a `DotGeneralOp`. +class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeDotGeneralOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } +}; + +// Quantizes the entry function's body containing a `ConvolutionOp`. +class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeConvolutionOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } +}; + +// Converts `entry_func_op` to be quantized according to the respective +// inputs and outputs of `xla_call_module_op` that are possibly quantized. It +// signature (type) is reset to match that of `xla_call_module_op`. +// `entry_func_body_quantization_pattern` rewrites the function's body, based on +// the new signature. +void QuantizeEntryFuncOp( + MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); + + body_rewrite_pattern.rewrite(entry_func_op, rewriter); + + // Rename the function to be clear that the function has been quantized. + const std::string quantized_function_name = + GetQuantizedFunctionName(entry_func_op.getSymName()); + entry_func_op.setSymName(quantized_function_name); +} + +// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee +// is expected to remain unquantized (thus having a signature mismatch), and it +// is also quantized accordingly. 
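RewriteGemmStyleOp above gives the GEMM (and the fused bias add) an intermediate i32 accumulator type whose scale is the product of the input and filter scales, and relies on the trailing uniform_quantize to bring that accumulator back to the i8 output type. A standalone sketch of the underlying arithmetic; all scale values here are made up for illustration and do not come from this patch:

```
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double input_scale = 2.0e-2;   // assumed i8 input quantization scale
  const double filter_scale = 5.0e-3;  // assumed i8 filter quantization scale
  // The i32 accumulator carries the product of the two scales (zero point 0).
  const double accum_scale = input_scale * filter_scale;

  // One accumulator element: the sum of q_input * q_filter products is exact
  // in i32, and its real value is accum * accum_scale.
  const int32_t accum = 12345;
  const double real_value = accum * accum_scale;  // 1.2345

  // The trailing uniform_quantize rescales to the i8 output quantization.
  const double output_scale = 1.0e-1;  // assumed output scale, zero point 0
  const auto q_out =
      static_cast<int8_t>(std::lround(real_value / output_scale));  // 12

  std::printf("accum=%d real=%.4f q_out=%d\n", accum, real_value, q_out);
  return 0;
}
```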
+void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + ModuleOp module_op = xla_call_module_op->getParentOfType(); + SymbolTable symbol_table(module_op); + + func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); + QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, + body_rewrite_pattern); + + // Replace the XlaCallModuleOp with a new CallOp. + rewriter.setInsertionPoint(xla_call_module_op); + rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, + xla_call_module_op.getArgs()); +} + +// Pattern that mainly does two things: +// +// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. +// 2. Quantizes the callee function. +// +// The inputs of this pattern assumes an invalid IR, where even if a +// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) +// not only replaces the input and output tensor types into quantized ones, but +// also rewrites the body with a quantized equivalent. +// +// `FuncBodyRewritePatternT` defines how a function body is quantized and +// rewritten. +template >> +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) + : OpRewritePattern(&ctx) {} + + LogicalResult match(TF::XlaCallModuleOp op) const override { + ModuleOp module_op = op->getParentOfType(); + SymbolTable symbol_table(module_op); + + // Ignore unquantized ops. + if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + + func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); + if (!entry_func_op) { + op->emitError("Failed to find a valid entry function."); + return failure(); + } + + return FuncBodyRewritePatternT().match(entry_func_op); + } + + void rewrite(TF::XlaCallModuleOp xla_call_module_op, + PatternRewriter& rewriter) const override { + ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + *rewriter.getContext(), rewriter, xla_call_module_op, + FuncBodyRewritePatternT()); + } +}; + void QuantizeCompositeFunctionsPass::runOnOperation() { MLIRContext& ctx = getContext(); @@ -108,6 +464,15 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { !pm_run_status.ok()) { signalPassFailure(); } + + // TODO - b/307839649: Move this as a separate pass. + RewritePatternSet patterns(&ctx); + patterns.add, + XlaCallModuleOpToCallOp>(ctx); + + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir index e794dded354da9..d1bfea7a236448 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir @@ -1,8 +1,5 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize -verify-each=false | FileCheck %s -// Tests for PopulateFusedGemmStylePatterns are handled in -// quantize_composite_functions for module-level evaluation of functions. 
- // CHECK-LABEL: quantize_simple_xla_call_module func.func private @quantize_simple_xla_call_module(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { %0 = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> @@ -43,27 +40,3 @@ func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf3 // CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"() <{{{.*}}}> {{{.*}}} : () -> tensor<1x3x!quant.uniform> // CHECK: %[[DCAST_0:.*]] = "quantfork.dcast"(%[[XLACALLMODULE_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> () - -// ----- - -// Tests for emitting an error when there is no corresponding entry -// function to quantize (@composite_dot_general_fn). - -module attributes {tf_saved_model.semantics} { -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf32> - %1 = "quantfork.qcast"(%0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> - %2 = "quantfork.dcast"(%1) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> - %3 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - %4 = "quantfork.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// expected-error @+2 {{Failed to find a valid entry function}} -// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %6 = "quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> - %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - return %7 : tensor<1x3xf32> - } -} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir index 649b0fece15d81..d38083150ba8be 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir @@ -1,9 +1,6 @@ // RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ // RUN: -stablehlo-quantize-composite-functions | FileCheck %s - -// Tests that basic dot_general is properly quantized. - module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. @@ -42,7 +39,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that fused pattern for dot_general + bias is properly quantized. +// Tests that fused bias pattern is properly quantized. module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. 
@@ -81,8 +78,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that fused pattern for dot_general + bias with dynamic shape is -// not quantized. +// Tests that fused bias pattern with dynamic shape is not quantized. // TODO: b/307620428 - Add support for fused bias with dynamic shapes. module attributes {tf_saved_model.semantics} { @@ -110,7 +106,57 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that basic convolution is properly quantized. +// Tests error when there are no corresponding entry function to quantize +// (@composite_dot_general_fn). + +module attributes {tf_saved_model.semantics} { +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} + func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> +// expected-error @+2 {{Failed to find a valid entry function}} +// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +} + +// ----- + +// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. + +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. 
+ +// CHECK-LABEL: func.func private @not_quantized_without_stats +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[XLA_CALL_MODULE_0]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Check that the composite_dot_general_fn is untouched. + +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2xf32>, %[[ARG_3:.*]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] +// CHECK: return %[[DOT_GENERAL]] +} + +// ----- + +// Test basic convolution is quantized. module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. @@ -150,7 +196,7 @@ module attributes {tf_saved_model.semantics} { // ----- -// Tests that fused pattern for convolution + bias is properly quantized. +// Tests that fused bias pattern is properly quantized. module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. @@ -185,33 +231,3 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> // CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> } - -// ----- - -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. - -module attributes {tf_saved_model.semantics} { - func.func private @not_quantized_without_stats(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } -// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is -// not quantized. 
- -// CHECK-LABEL: func.func private @not_quantized_without_stats -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] - - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -// Check that the composite_dot_general_fn is untouched. - -// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2xf32>, %[[ARG_3:.*]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] -// CHECK: return %[[DOT_GENERAL]] -} From 1ccabf27d88ab867af916f2551697c4f644ca2ea Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 1 Dec 2023 09:30:35 -0800 Subject: [PATCH 284/381] Relax coalescing and input fusion heuristics. - Only tigger input fusion -> loop fusion heuristic for reductions, since that's where we observed issues. - Allow transposes of broadcast values. The model doesn't currently handle this case well. This entire logic is soon going to be replaced by a tiling analysis. PiperOrigin-RevId: 587033168 --- .../gpu/model/gpu_performance_model.cc | 29 +++++++++++++++++-- .../xla/xla/service/gpu/priority_fusion.cc | 10 +++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 7bd454134cebb4..9160e1d67a222c 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -275,8 +275,33 @@ bool IsReadCoalesced(const std::optional& fusion_analysis, // Transposing minor dimension breaks coalescing. if (analyzed_kind_or_reduction != HloFusionAnalysis::EmitterFusionKind::kTranspose) { - if (TransposesMinorDimension(producer)) return false; - if (consumer && TransposesMinorDimension(consumer)) return false; + auto is_broadcast = [&](const HloInstruction* instr) { + while (true) { + if (instr->opcode() == HloOpcode::kBroadcast) return true; + if (instr->operand_count() != 1) return false; + if (instr->opcode() != HloOpcode::kBitcast && !instr->IsElementwise()) { + return false; + } + instr = instr->operand(0); + } + }; + + auto is_bad_transpose = [&](const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kFusion) { + for (auto* instr : instr->fused_instructions()) { + // Hack: we allow transposes of broadcasts. + if (TransposesMinorDimension(instr) && + !is_broadcast(instr->operand(0))) { + return true; + } + } + return false; + } + return TransposesMinorDimension(instr); + }; + + if (is_bad_transpose(producer)) return false; + if (consumer && is_bad_transpose(consumer)) return false; } // Fusing two row reductions breaks coalescing. 
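The hunk above relaxes the minor-dimension-transpose check: a transposed value that ultimately originates from a broadcast (possibly through unary elementwise ops or bitcasts) is no longer treated as breaking coalescing. A self-contained toy version of that operand walk, using a stand-in instruction struct instead of real HLO so the traversal is explicit; the struct and opcode strings are illustrative only:

```
#include <string>
#include <vector>

// Minimal stand-in for HloInstruction.
struct ToyInstr {
  std::string opcode;                     // e.g. "broadcast", "bitcast", "add"
  bool elementwise = false;               // true for elementwise ops
  std::vector<const ToyInstr*> operands;  // producers
};

// Mirrors the `is_broadcast` lambda above: walk up through unary elementwise
// ops and bitcasts and report whether the chain starts at a broadcast.
bool ComesFromBroadcast(const ToyInstr* instr) {
  while (true) {
    if (instr->opcode == "broadcast") return true;
    if (instr->operands.size() != 1) return false;
    if (instr->opcode != "bitcast" && !instr->elementwise) return false;
    instr = instr->operands[0];
  }
}
```

A transpose whose operand chain satisfies this predicate is exempted from the penalty; all other minor-dimension transposes in the producer or consumer still disable the coalesced-read assumption.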
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 7e804b271250ce..55205823e2754d 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -560,12 +560,16 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, // switch it to the loop emitter. This often occurs during epilog fusion for // reductions, which suffer from limited emitter support. // TODO(b/312686229): Cost model should handle this. - auto analysis_fused = - AnalyzeProducerConsumerFusion(*producer, *consumer, device_info_); + const auto& analysis_fused = fusion_analysis_cache_.Get(*producer, *consumer); if (producer->IsInputFusion() && analysis_fused && analysis_fused->GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kLoop) { - return "fusion into output of an input fusion would create a loop fusion"; + const auto& analysis = fusion_analysis_cache_.Get(*producer); + if (!analysis || analysis->GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kReduction) { + return "fusion into output of a reduce fusion would create a loop " + "fusion"; + } } // Avoid cases where we'd create a fusion that hit limitations in ptxas. From 52b8b8ace4a93626d15faf5f2e2588ff67f87532 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 1 Dec 2023 09:52:10 -0800 Subject: [PATCH 285/381] [xla:gpu] Add support for fusing dynamic-update-slice into CUTLASS gemms PiperOrigin-RevId: 587038608 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/custom_fusion_rewriter.cc | 6 +- third_party/xla/xla/service/gpu/kernels/BUILD | 23 +++- .../xla/service/gpu/kernels/cutlass_gemm.h | 94 +++++++++++++ .../kernels/cutlass_gemm_custom_kernel.cu.cc | 29 +++-- .../gpu/kernels/cutlass_gemm_custom_kernel.h | 9 +- .../cutlass_gemm_custom_kernel_test.cc | 11 +- .../gpu/kernels/cutlass_gemm_fusion.cc | 123 ++++++++++++++++-- .../service/gpu/kernels/cutlass_gemm_fusion.h | 6 + .../gpu/kernels/cutlass_gemm_fusion_test.cc | 115 +++++++++++++++- ...niversal.cu.h => cutlass_gemm_kernel.cu.h} | 42 ++++-- .../gpu/kernels/cutlass_gemm_kernels.cu.h | 34 ++++- .../cutlass_gemm_kernels_f32xf32_to_f32.cu.cc | 4 +- third_party/xla/xla/tests/hlo_test_base.cc | 23 ++++ third_party/xla/xla/tests/hlo_test_base.h | 7 + 15 files changed, 473 insertions(+), 54 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h rename third_party/xla/xla/service/gpu/kernels/{cutlass_gemm_universal.cu.h => cutlass_gemm_kernel.cu.h} (82%) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index eddf9a3a0d6c15..2a2004472ba197 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2854,6 +2854,7 @@ cc_library( "//xla/service:hlo_pass", "//xla/service/gpu/kernels:custom_fusion_library", "//xla/service/gpu/kernels:custom_fusion_pattern", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc index 666b816650c58c..037446811d52aa 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -66,7 +67,10 @@ static StatusOr> GetPatternCaptures( // Collect instructions captured by a matched pattern. for (HloInstruction* instr : match.instructions) { for (HloInstruction* operand : instr->operands()) { - if (!instructions_set.contains(operand)) captures.push_back(operand); + if (!instructions_set.contains(operand) && + absl::c_find(captures, operand) == captures.end()) { + captures.push_back(operand); + } } } diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index cd18016d849b2a..7b4742e54f549b 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -82,6 +82,7 @@ cc_library( # ":custom_fusion", # ":custom_fusion_pattern", # ":custom_kernel", +# ":cutlass_gemm", # ":cutlass_gemm_custom_kernel", # "@com_google_absl//absl/status", # "//xla:shape_util", @@ -105,8 +106,12 @@ cc_library( # ":custom_fusion_pattern", # ":cutlass_gemm_fusion", # "@com_google_absl//absl/strings", +# "//xla:array", +# "//xla:array2d", +# "//xla:array3d", # "//xla:debug_options_flags", # "//xla:error_spec", +# "//xla:literal_util", # "//xla/service/gpu:custom_fusion_rewriter", # "//xla/tests:hlo_test_base", # "@local_tsl//tsl/platform:test", @@ -125,9 +130,10 @@ cc_library( # visibility = ["//visibility:private"], # deps = [ # ":custom_kernel", +# ":cutlass_gemm", +# ":cutlass_gemm_kernel", # ":cutlass_gemm_kernels", # ":cutlass_gemm_kernels_header", -# ":cutlass_gemm_universal", # "@com_google_absl//absl/status", # "@com_google_absl//absl/strings", # "//third_party/gpus/cutlass", @@ -161,11 +167,17 @@ cc_library( # # CUTLASS GemmUniversal-base kernels <-> StreamExecutor adaptor # #===--------------------------------------------------------------------------------------------===# # +# cc_library( +# name = "cutlass_gemm", +# hdrs = ["cutlass_gemm.h"], +# ) +# # cuda_library( -# name = "cutlass_gemm_universal", -# hdrs = ["cutlass_gemm_universal.cu.h"], +# name = "cutlass_gemm_kernel", +# hdrs = ["cutlass_gemm_kernel.cu.h"], # visibility = ["//visibility:private"], # deps = [ +# ":cutlass_gemm", # "@com_google_absl//absl/status", # "@com_google_absl//absl/strings", # "//third_party/gpus/cutlass", @@ -191,7 +203,10 @@ cc_library( # name = "cutlass_gemm_kernels_header", # hdrs = ["cutlass_gemm_kernels.cu.h"], # visibility = ["//visibility:private"], -# deps = ["//third_party/gpus/cutlass"], +# deps = [ +# ":cutlass_gemm", +# "//third_party/gpus/cutlass", +# ], # ) # # cuda_library( diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h new file mode 100644 index 00000000000000..17603f3ac0d324 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h @@ -0,0 +1,94 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_H_ + +#include +#include + +namespace xla::gpu::kernel::gemm_universal { + +// Indices of a custom fusion parameters corresponding to Gemm kernel arguments. +// +// Example: +// se::KernelArgsDeviceMemoryArray args = ... +// void* lhs = args->device_memory_ptr(indices.lhs); +// +// Custom fusion instruction can have parameters in arbitrary order, and we need +// a mapping from a custom kernel argument to the fusion instruction parameter. +struct ArgsIndices { + int64_t lhs; + int64_t rhs; + int64_t out; +}; + +// Following structs encode how a custom kernel arguments packing and a custom +// CUTLASS kernel itself can find dynamic-slice offsets at run time. +// +// Example: CUTLASS gemm with a dynamic-update-slice +// +// cutlass_gemm { +// p0 = f32[2,2]{1,0} parameter(0) +// p1 = f32[2,2,2]{2,1,0} parameter(1) +// p2 = s32[] parameter(2) <--- major dim offset +// p3 = s32[] parameter(3) <--- minor dims offset +// dot = f32[2,2]{1,0} dot(p0, p0) +// ... +// ROOT r = f32[2,2,2]{2,1,0} dynamic-update-slice(p1, ..., p2, p3, p3) +// } +// +// In this example `p2` parameter defines a dynamic slice offset along the +// major dimension (0-th dimension for a row major layout). In practice +// parameters can be passed to fusions in arbitrary order, and when we pack +// custom kernel arguments into device kernel parameters we need to know +// how to find correct device pointers in the list of fusion arguments. +// +// For this example: +// +// DynamicSliceIndices::out = 2 +// DynamicSliceParams::out = +// +// `DynamicSliceIndices` used in the host-code to fetch device memory pointers +// from arguments and pass it as `DynamicSliceParams` to a device kernel. +// +// Example: +// se::KernelArgsDeviceMemoryArray args = ... +// void* out_ptr = args->device_memory_ptr(*slice_indices.out); +// +// DynamicSliceParams params { // this struct passed to a kernel +// out_ptr, // kernel loads offset value from this pointer +// ... +// }; +// + +// TODO(ezhulenev): Support dynamic slices along all dimensions, today we assume +// that we can slice only along the leading dimension (batch). + +// Indices of a custom fusion parameters corresponding to dynamic slice offsets. +struct DynamicSliceIndices { + // Index of a dynamic slice offset along the major dimension. + std::optional out; +}; + +// Pointers to buffers (s32[] buffers in HLO) holding dynamic slice offsets. +struct DynamicSliceParams { + // Dynamic slice offset along the major dimension. + std::optional out; +}; + +} // namespace xla::gpu::kernel::gemm_universal + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc index 1dc483cc3ac604..19ea111789330a 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc @@ -20,39 +20,42 @@ limitations under the License. 
#include "absl/status/status.h" #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h" #include "xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h" -#include "xla/service/gpu/kernels/cutlass_gemm_universal.cu.h" #include "xla/statusor.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/xla_data.pb.h" -namespace xla::gpu::kernel { +namespace xla::gpu::kernel::gemm_universal { template -static StatusOr LoadCutlassGemmUniversal(int32_t m, int32_t n, - int32_t k) { +static StatusOr LoadCutlassGemmUniversal( + int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, + const DynamicSliceIndices& slices) { using Kernel = typename Gemm::GemmKernel; cutlass::gemm::GemmCoord problem_size = {m, n, k}; - se::MultiKernelLoaderSpec spec( - /*arity=*/1, gemm_universal::ArgsPacking(problem_size)); - spec.AddInProcessSymbol(internal::GetCutlassGemmKernel(), - "cutlass_gemm"); + auto packing = ArgsPacking(problem_size, indices, slices); + + se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing)); + spec.AddInProcessSymbol(GetCutlassGemmKernel(), "cutlass_gemm"); return CustomKernel("cutlass_gemm", std::move(spec), - gemm_universal::BlockDim(problem_size), - gemm_universal::ThreadDim(), + BlockDim(problem_size), ThreadDim(), sizeof(typename Kernel::SharedStorage)); } StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, - int32_t n, int32_t k) { + int32_t n, int32_t k, + const ArgsIndices& indices, + const DynamicSliceIndices& slices) { if (dtype != PrimitiveType::F32) return absl::InvalidArgumentError( "Currently cutlass gemm kernel supports only F32 data type"); - return LoadCutlassGemmUniversal(m, n, k); + return LoadCutlassGemmUniversal( + m, n, k, indices, slices); } -} // namespace xla::gpu::kernel +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h index 364f8b9aba73a6..345d5973c5720b 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -19,14 +19,17 @@ limitations under the License. #include #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" #include "xla/statusor.h" #include "xla/xla_data.pb.h" -namespace xla::gpu::kernel { +namespace xla::gpu::kernel::gemm_universal { StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, - int32_t n, int32_t k); + int32_t n, int32_t k, + const ArgsIndices& indices, + const DynamicSliceIndices& slices); -} // namespace xla::gpu::kernel +} // namespace xla::gpu::kernel::gemm_universal #endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index e627cb449b6fb5..f577c2b431c6fc 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" + #include #include #include -#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" @@ -27,7 +28,7 @@ limitations under the License. #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" -namespace xla::gpu::kernel { +namespace xla::gpu::kernel::gemm_universal { TEST(CutlassGemmKernelTest, SimpleGemm) { se::Platform* platform = @@ -41,7 +42,9 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { se::Kernel gemm(executor); // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. - auto custom_kernel = GetCutlassGemmKernel(PrimitiveType::F32, 4, 4, 4); + auto custom_kernel = + GetCutlassGemmKernel(PrimitiveType::F32, 4, 4, 4, + /*indices=*/{0, 1, 2}, /*slices=*/{}); TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); int64_t length = 4 * 4; @@ -75,4 +78,4 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { ASSERT_EQ(dst, expected); } -} // namespace xla::gpu::kernel +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 7c4854f2d877c6..9bc6c6d62ee867 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_fusion.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" #include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla::gpu { @@ -54,6 +54,18 @@ struct GemmWithUpcast { HloInstruction* lhs_upcast = nullptr; // HLO convert instr HloInstruction* rhs_upcast = nullptr; // HLO convert instr }; + +// Pattern for matching GEMM with surrounding dynamic-slice/update-slice. +struct GemmWithDynamicSlice { + explicit GemmWithDynamicSlice(HloDynamicUpdateSliceInstruction* update_slice) + : update_slice(update_slice) {} + + std::vector Instrs() { return {dot, bitcast, update_slice}; } + + HloInstruction* dot = nullptr; + HloInstruction* bitcast = nullptr; // result bitcast + HloInstruction* update_slice = nullptr; // update result slice +}; } // namespace // Returns OK if dot instruction is a simple 2D row-major gemm. 
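The comment above refers to MatchRowMajorGemm and MatchSimpleGemm, whose bodies lie outside the hunks included in this patch. As a rough sketch of the kind of check that comment describes (the exact rules are assumed here, not quoted from the source), a row-major gemm check over an HloDotInstruction could look like this:

static Status MatchRowMajorGemmSketch(HloDotInstruction* dot) {
  // Assumed contract: rank-2 operands contracted as [m,k] x [k,n] -> [m,n].
  if (dot->operand(0)->shape().rank() != 2 ||
      dot->operand(1)->shape().rank() != 2) {
    return absl::InternalError("operands must be rank-2 matrices");
  }
  const auto& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.rhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_contracting_dimensions(0) != 1 ||
      dnums.rhs_contracting_dimensions(0) != 0) {
    return absl::InternalError("only [m,k]x[k,n] row-major gemms supported");
  }
  return OkStatus();
}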
@@ -97,23 +109,39 @@ static Status MatchSimpleGemm(HloDotInstruction* dot, PrimitiveType dtype) { static StatusOr MatchGemmWithUpcast(HloDotInstruction* dot) { TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); - GemmWithUpcast matched(dot); + GemmWithUpcast match(dot); // C <- convert(A) * B if (Match(const_cast(dot->operand(0)), - m::Convert(&matched.lhs_upcast, m::Op()))) { - return matched; + m::Convert(&match.lhs_upcast, m::Op()))) { + return match; } // C <- A * convert(B) if (Match(const_cast(dot->operand(1)), - m::Convert(&matched.rhs_upcast, m::Op()))) { - return matched; + m::Convert(&match.rhs_upcast, m::Op()))) { + return match; } return absl::InternalError("unsupported gemm with upcasing"); } +// Returns matched GEMM with result used to update a slice. +static StatusOr MatchGemmWithDynamicUpdateSlice( + HloDynamicUpdateSliceInstruction* update_slice) { + GemmWithDynamicSlice match(update_slice); + + if (!Match( + const_cast(update_slice->operand(1)), + m::Bitcast(&match.bitcast, m::Dot(&match.dot, m::Op(), m::Op())))) { + return absl::InternalError("failed to match update slice instr"); + } + + TF_RETURN_IF_ERROR(MatchRowMajorGemm(Cast(match.dot))); + + return match; +} + //===----------------------------------------------------------------------===// // Cutlass Gemm Patterns //===----------------------------------------------------------------------===// @@ -131,6 +159,21 @@ std::optional CutlassGemmPattern::TryMatch( return Match{config, {instr}}; } +std::optional +CutlassGemmWithDynamicUpdateSlicePattern::TryMatch( + HloInstruction* instr) const { + auto* update_slice = DynCast(instr); + if (!update_slice) return std::nullopt; + + auto matched = MatchGemmWithDynamicUpdateSlice(update_slice); + if (!matched.ok()) return std::nullopt; + + CustomFusionConfig config; + config.set_name("cutlass_gemm_with_dynamic_update_slice"); + + return Match{config, matched->Instrs()}; +} + std::optional CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { auto* dot = DynCast(instr); @@ -167,15 +210,24 @@ class CutlassGemmFusion : public CustomFusion { auto dtype = dot->shape().element_type(); - auto& lhs_shape = dot->operand(0)->shape(); - auto& rhs_shape = dot->operand(1)->shape(); + auto* lhs = Cast(dot->operand(0)); + auto* rhs = Cast(dot->operand(1)); + + // Mapping from fusion arguments to gemm kernel arguments. 
+ kernel::gemm_universal::ArgsIndices indices = { + lhs->parameter_number(), rhs->parameter_number(), + computation->num_parameters()}; + + auto& lhs_shape = lhs->shape(); + auto& rhs_shape = rhs->shape(); size_t m = lhs_shape.dimensions(0); size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); TF_ASSIGN_OR_RETURN(auto kernel, - kernel::GetCutlassGemmKernel(dtype, m, n, k)); + kernel::gemm_universal::GetCutlassGemmKernel( + dtype, m, n, k, indices, /*slices=*/{})); return std::vector{std::move(kernel)}; } }; @@ -207,10 +259,61 @@ class CutlassGemmWithUpcastFusion : public CustomFusion { } }; +class CutlassGemmWithDynamicUpdateSliceFusion : public CustomFusion { + public: + StatusOr> LoadKernels( + const HloComputation* computation) const final { + auto* dus = DynCast( + computation->root_instruction()); + if (dus == nullptr) { + return absl::InternalError( + "cutlass_gemm_with_dynamic_update_slice requires ROOT operation to " + "be a dynamic update slice"); + } + + TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithDynamicUpdateSlice(dus)); + TF_RETURN_IF_ERROR(MatchSimpleGemm(Cast(matched.dot), + PrimitiveType::F32)); + + auto dtype = matched.dot->shape().element_type(); + + auto* lhs = Cast(matched.dot->operand(0)); + auto* rhs = Cast(matched.dot->operand(1)); + auto* out = Cast(matched.update_slice->operand(0)); + + // Mapping from fusion arguments to gemm kernel arguments. + kernel::gemm_universal::ArgsIndices args_indices = { + lhs->parameter_number(), rhs->parameter_number(), + out->parameter_number()}; + + // Mapping to a buffer that holds output slice offset. + auto* offset = + Cast(matched.update_slice->operand(2)); + + kernel::gemm_universal::DynamicSliceIndices slices; + slices.out = offset->parameter_number(); + + auto& lhs_shape = lhs->shape(); + auto& rhs_shape = rhs->shape(); + + size_t m = lhs_shape.dimensions(0); + size_t k = lhs_shape.dimensions(1); + size_t n = rhs_shape.dimensions(1); + + TF_ASSIGN_OR_RETURN(auto kernel, + kernel::gemm_universal::GetCutlassGemmKernel( + dtype, m, n, k, args_indices, slices)); + return std::vector{std::move(kernel)}; + } +}; + } // namespace xla::gpu -XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmPattern); +XLA_REGISTER_CUSTOM_FUSION_PATTERN( + ::xla::gpu::CutlassGemmWithDynamicUpdateSlicePattern); XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", ::xla::gpu::CutlassGemmFusion); XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm_with_upcast", ::xla::gpu::CutlassGemmWithUpcastFusion); +XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm_with_dynamic_update_slice", + ::xla::gpu::CutlassGemmWithDynamicUpdateSliceFusion); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h index f448b2d0a4d915..42c92520a2789a 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h @@ -29,6 +29,12 @@ class CutlassGemmPattern : public CustomFusionPattern { std::optional TryMatch(HloInstruction* instr) const override; }; +// Pattern matches simple row-major gemms with dynamic-update-slice. +class CutlassGemmWithDynamicUpdateSlicePattern : public CustomFusionPattern { + public: + std::optional TryMatch(HloInstruction* instr) const override; +}; + // Pattern matches mixed dtype gemms when one of the operands is upcasted to an // accumulator (output) dtype, i.e. BF16 <= BF16 x S8. 
class CutlassGemmWithUpcastPattern : public CustomFusionPattern { diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 541ba5c569b088..bf26adb5f43303 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -15,10 +15,15 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_fusion.h" +#include #include +#include "xla/array.h" +#include "xla/array2d.h" +#include "xla/array3d.h" #include "xla/debug_options_flags.h" #include "xla/error_spec.h" +#include "xla/literal_util.h" #include "xla/service/gpu/custom_fusion_rewriter.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/tests/hlo_test_base.h" @@ -115,6 +120,55 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f32[2,2,2], p1: f32[2,2], i: s32[]) -> f32[2,2,2] { + %p0 = f32[2,2,2]{2,1,0} parameter(0) + %p1 = f32[2,2]{1,0} parameter(1) + %i = s32[] parameter(2) + + %dot = f32[2,2]{1,0} dot(%p1, %p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + %bc = f32[1,2,2]{2,1,0} bitcast(%dot) + + ROOT %r = f32[2,2,2]{2,1,0} dynamic-update-slice(%p0, %bc, %i, %i, %i) + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2,2]{2,1,0} parameter + ; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter + ; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]]) + ; CHECK-DAG: [[CAST:%[^ ]+]] = f32[1,2,2]{2,1,0} bitcast([[DOT]]) + ; CHECK: ROOT [[DUS:%[^ ]+]] = f32[2,2,2]{2,1,0} dynamic-update-slice( + ; CHECK: [[P1]], [[CAST]], [[P2]], [[P2]], [[P2]] + ; CHECK: ) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[2,2,2]{2,1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{ + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: } + ; CHECK: } + ; CHECK: } + )"; + + CustomFusionPatternRegistry patterns; + patterns.Emplace(); + + CustomFusionRewriter pass(&patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + //===----------------------------------------------------------------------===// // Run And Compare Tests //===----------------------------------------------------------------------===// @@ -170,7 +224,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { gemm = (bf16[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1), custom_call_target="__cublas$gemm", backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} - ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element((bf16[16,8]{1,0}, s8[0]{0}) gemm), index=0 + ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element(gemm), index=0 })"; const char* hlo_text_custom_fusion = R"( @@ -195,4 +249,63 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { error_spec, 
/*run_hlo_passes=*/false)); } +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + + gemm.tuple = (f32[2,2]{1,0}, s8[0]{0}) custom-call(p1, p1), + custom_call_target="__cublas$gemm", + backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + gemm = f32[2,2]{1,0} get-tuple-element(gemm.tuple), index=0 + cast = f32[1,2,2]{2,1,0} bitcast(gemm) + + ROOT r = f32[2,2,2]{2,1,0} dynamic-update-slice(p0, cast, p2, p3, p3) + })"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm { + p0.1 = f32[2,2]{1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + dot.1 = f32[2,2]{1,0} dot(p0.1, p0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bc.1 = f32[1,2,2]{2,1,0} bitcast(dot.1) + ROOT r.1 = f32[2,2,2]{2,1,0} dynamic-update-slice(p1.1, bc.1, p2, p3, p3) + } + + ENTRY e { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT _ = f32[2,2,2]{2,1,0} fusion(p1, p0, p2, p3), kind=kCustom, + calls=%cutlass_gemm, + backend_config={"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}} + })"; + + Array3D p0_arr(2, 2, 2); + Array2D p1_arr({{0.0, 1.0}, {2.0, 3.0}}); + Array p2_arr({}, 1); + Array p3_arr({}, 0); + + auto p0 = LiteralUtil::CreateFromArray(p0_arr); + auto p1 = LiteralUtil::CreateFromArray(p1_arr); + auto p2 = LiteralUtil::CreateFromArray(p2_arr); + auto p3 = LiteralUtil::CreateFromArray(p3_arr); + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + {&p0, &p1, &p2, &p3}, error_spec, + /*run_hlo_passes=*/false)); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h similarity index 82% rename from third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h rename to third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h index d5758dbac810a9..1721c731e05670 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ -#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_CU_H_ #include #include @@ -28,6 +28,7 @@ limitations under the License. 
#include "third_party/gpus/cutlass/include/cutlass/gemm/gemm_enumerated_types.h" #include "third_party/gpus/cutlass/include/cutlass/gemm_coord.h" #include "third_party/gpus/cutlass/include/cutlass/layout/matrix.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" #include "xla/statusor.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -115,22 +116,31 @@ int64_t LdC(const cutlass::gemm::GemmCoord &problem_size) { using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; template -auto *DevicePtr(const se::KernelArgsDeviceMemoryArray *args) { - const void *opaque = args->device_memory_ptr(index); - +auto *ArgPtr(const se::KernelArgsDeviceMemoryArray *args, + const ArgsIndices &indices) { if constexpr (index == 0) { + const void *opaque = args->device_memory_ptr(indices.lhs); return static_cast(const_cast(opaque)); } else if constexpr (index == 1) { + const void *opaque = args->device_memory_ptr(indices.rhs); return static_cast(const_cast(opaque)); } else if constexpr (index == 2) { + const void *opaque = args->device_memory_ptr(indices.out); return static_cast(const_cast(opaque)); } else { static_assert(sizeof(Gemm) == 0, "illegal Gemm argument index"); } } +int32_t *SlicePtr(const se::KernelArgsDeviceMemoryArray *args, int64_t index) { + const void *opaque = args->device_memory_ptr(index); + return static_cast(const_cast(opaque)); +} + template -KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size) { +KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, + const ArgsIndices &indices, + const DynamicSliceIndices &slices) { using Arguments = typename Gemm::Arguments; using Kernel = typename Gemm::GemmKernel; using Params = typename Kernel::Params; @@ -156,9 +166,9 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size) { auto ldb = LdB(problem_size); auto ldc = LdC(problem_size); - auto ptr_a = DevicePtr(mem_args); - auto ptr_b = DevicePtr(mem_args); - auto ptr_c = DevicePtr(mem_args); + auto ptr_a = ArgPtr(mem_args, indices); + auto ptr_b = ArgPtr(mem_args, indices); + auto ptr_c = ArgPtr(mem_args, indices); auto mode = cutlass::gemm::GemmUniversalMode::kGemm; float alpha = 1.0, beta = 0.0; @@ -174,12 +184,22 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size) { // TODO(ezhulenev): Get number of SMs from a DeviceDescription and calculate // correct kernel occupancy using GpuRuntime. + + // Convert CUTLASS operation arguments to a device kernel parameters. Params params(arguments, /*device_sms=*/128, /*sm_occupancy=*/10); - return se::PackKernelArgs(args.number_of_shared_bytes(), params); + // Optionally set up dynamic slice parameters to allow kernel adjust buffer + // pointers passed via `params`. + DynamicSliceParams slice_params; + if (slices.out.has_value()) { + slice_params.out = SlicePtr(mem_args, *slices.out); + } + + return se::PackKernelArgs( + args.number_of_shared_bytes(), params, slice_params); }; } } // namespace xla::gpu::kernel::gemm_universal -#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_CU_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h index f70bb49db4ac72..c2471f1ceb39d3 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h @@ -17,8 +17,9 @@ limitations under the License. 
#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNELS_CU_H_ #include "third_party/gpus/cutlass/include/cutlass/gemm/device/gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" -namespace xla::gpu::kernel { +namespace xla::gpu::kernel::gemm_universal { struct CutlassGemmKernels { using F32xF32toF32 = @@ -27,17 +28,40 @@ struct CutlassGemmKernels { float, cutlass::layout::RowMajor>; }; -namespace internal { +// This entry point is based on `cutlass::Kernel2` template with an extra +// parameter to pass dynamic slices. +template +__global__ void Kernel(typename Gemm::Params params, + gemm_universal::DynamicSliceParams slices) { + extern __shared__ int SharedStorageBase[]; + typename Gemm::SharedStorage* shared_storage = + reinterpret_cast(SharedStorageBase); + + // Update output pointers to account for dynamic offsets. + if (slices.out.has_value()) { + auto m = params.problem_size.m(); + auto n = params.problem_size.n(); + + int32_t out_offset = **slices.out; + + char* ptr_c = reinterpret_cast(params.ptr_C); + char* ptr_d = reinterpret_cast(params.ptr_D); + + params.ptr_C = ptr_c + 4 * out_offset * (m * n); + params.ptr_D = ptr_d + 4 * out_offset * (m * n); + } + + Gemm::invoke(params, *shared_storage); +} template void* GetCutlassGemmKernel() { - return reinterpret_cast(cutlass::Kernel2); + return reinterpret_cast(Kernel); } // Extern templates for all supported CUTLASS Gemm kernels. extern template void* GetCutlassGemmKernel(); -} // namespace internal -} // namespace xla::gpu::kernel +} // namespace xla::gpu::kernel::gemm_universal #endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNELS_CU_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc index 1dc1a9e1aa0fe7..b4da3d222a6370 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h" -namespace xla::gpu::kernel::internal { +namespace xla::gpu::kernel::gemm_universal { template void* GetCutlassGemmKernel(); -} // namespace xla::gpu::kernel::internal +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index e9358e88f53be6..fc947372215ec1 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -591,6 +591,29 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( run_hlo_passes); } +::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + const absl::Span arguments, + const std::optional& error, bool run_hlo_passes) { + auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); + if (!module_0_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0_or_status.status().ToString(); + } + + auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1); + if (!module_1_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1_or_status.status().ToString(); + } + return RunAndCompareTwoModules(std::move(module_0_or_status).value(), + std::move(module_1_or_status).value(), + arguments, error, run_hlo_passes); +} + ::testing::AssertionResult HloTestBase::Run( string_view hlo_string, bool run_hlo_passes, ExecutionProfile* profile, const tsl::protobuf::Message* backend_config) { diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index e543e9fe7ca4b9..118fa8713637be 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -313,6 +313,13 @@ class HloTestBase : public ManifestCheckingTest { absl::string_view hlo_string_module_1, const std::optional& error, bool run_hlo_passes = true); + // Same as above but requires explicit arguments. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + absl::Span arguments, + const std::optional& error, bool run_hlo_passes = true); + // Executes an hlo module with fake inputs on multiple replicas. [[nodiscard]] ::testing::AssertionResult RunReplicated( const absl::string_view hlo_string, bool run_hlo_passes = true, From 57375f35cd85ae3a5276c3baab18983e214e8288 Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Fri, 1 Dec 2023 10:10:57 -0800 Subject: [PATCH 286/381] [HloValueSemanticsAnalysis] Fix EinsumDepthAnalysis::HandleWhile for handling instructions that are only used inside the loop body. 
PiperOrigin-RevId: 587043652 --- .../service/hlo_value_semantics_analysis.cc | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index e74d1e8ea7238e..af50110b6d6d84 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -371,7 +371,7 @@ Status EinsumDepthAnalysis::HandleCustomCall(HloInstruction* custom_call) { Status EinsumDepthAnalysis::HandleWhile(HloInstruction* xla_while) { auto depth_iter = einsum_depth_map_.find(xla_while); CHECK(depth_iter != einsum_depth_map_.end()); - const ShapeTree depth_tree = depth_iter->second; + const ShapeTree& depth_tree = depth_iter->second; int max_depth = GetMaxDepth(depth_tree); HloComputation* condition_computation = xla_while->while_condition(); HloInstruction* condition_root = condition_computation->root_instruction(); @@ -379,31 +379,36 @@ Status EinsumDepthAnalysis::HandleWhile(HloInstruction* xla_while) { TF_RETURN_IF_ERROR(HandleCalledComputation( *condition_computation, condition_depth, xla_while->operands())); HloComputation* body_computation = xla_while->while_body(); - TF_RETURN_IF_ERROR(HandleCalledComputation(*body_computation, depth_tree, - xla_while->operands())); - // Elements of while loop outputs may only be used within the while loop. - // Set the depth of the while body outputs to have the max of their original - // depth and their corresponding operand depth if their original depth was - // negative. Then recompute while loop instruction depths. - auto body_depth_iter = + bool run_depth_propagation_on_body = true; + const ShapeTree* root_depth_ptr = &depth_tree; + auto root_depth_iter = GetOrCreateDepthTree(body_computation->root_instruction()); - ShapeTree& body_depth = body_depth_iter->second; - // Note: while body computations have a single parameter. See - // ShapeVerifier::HandleWhile. - HloInstruction* operand = body_computation->parameter_instruction(0); - auto operand_depth = GetOrCreateDepthTree(operand)->second; - body_depth.ForEachMutableElement( - [&body_depth, &operand_depth](const ShapeIndex& shape_index, - int* depth_ptr) { - if (body_depth.IsLeaf(shape_index)) { - if (body_depth.element(shape_index) < 0 && + ShapeTree& root_depth = root_depth_iter->second; + while (run_depth_propagation_on_body) { + run_depth_propagation_on_body = false; + TF_RETURN_IF_ERROR(HandleCalledComputation( + *body_computation, *root_depth_ptr, xla_while->operands())); + // Elements of while loop outputs may only be used within the while loop. + // If such elements exist, we set its root depth to it operand depth. Then + // recompute while loop instruction depths. 
+ HloInstruction* operand = body_computation->parameter_instruction(0); + const ShapeTree& operand_depth = GetOrCreateDepthTree(operand)->second; + + root_depth.ForEachMutableElement( + [&run_depth_propagation_on_body, &root_depth, &operand_depth]( + const ShapeIndex& shape_index, int* depth_ptr) { + if (!root_depth.IsLeaf(shape_index)) { + return; + } + if (root_depth.element(shape_index) < 0 && operand_depth.element(shape_index) >= 0) { - *depth_ptr = 0; + *depth_ptr = operand_depth.element(shape_index); + run_depth_propagation_on_body = true; } - } - }); - return HandleCalledComputation(*body_computation, body_depth, - xla_while->operands()); + }); + root_depth_ptr = &root_depth; + } + return OkStatus(); } Status EinsumDepthAnalysis::HandleConditional(HloInstruction* conditional) { From 7c8072fb020bd770363a060ccaff5922956369ce Mon Sep 17 00:00:00 2001 From: Om Thakkar Date: Fri, 1 Dec 2023 10:24:55 -0800 Subject: [PATCH 287/381] adding a unit test for Conv + BiasAdd + Add + fusion --- .../core/grappler/optimizers/remapper_test.cc | 253 +++++++++++------- 1 file changed, 157 insertions(+), 96 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index d3a6652589381b..76c7098361d6f2 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -888,6 +888,163 @@ TEST_F(RemapperFuseConvWithBiasAndActivation, Conv3D_BF16) { RunTest<3, DT_BFLOAT16>(); } +class RemapperFuseConvWithBiasAndAddActivation : public RemapperTest { + public: + template + void RunTest() { + if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; + using ::tensorflow::ops::Placeholder; + + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto input_shape = Placeholder::Shape({8, 32, 32, 3}); + auto filter_shape = Placeholder::Shape({1, 1, 3, 128}); + auto bias_shape = Placeholder::Shape({128}); + auto add_shape = ops::Placeholder::Shape({8, 32, 32, 128}); + + auto input_t = GenerateRandomTensor({8, 32, 32, 3}); + auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); + auto bias_t = GenerateRandomTensor({128}); + auto add_t = GenerateRandomTensor({8, 32, 32, 128}); + + float leakyrelu_alpha = 0.5; + + std::vector strides = {1, 1, 1, 1}; + + if (dim == 3) { + input_shape = Placeholder::Shape({8, 4, 32, 32, 3}); + filter_shape = Placeholder::Shape({1, 1, 1, 3, 128}); + bias_shape = Placeholder::Shape({128}); + add_shape = ops::Placeholder::Shape({8, 4, 32, 32, 128}); + strides = {1, 1, 1, 1, 1}; + + input_t = GenerateRandomTensor({8, 4, 32, 32, 3}); + filter_t = GenerateRandomTensor({1, 1, 1, 3, 128}); + bias_t = GenerateRandomTensor({128}); + add_t = GenerateRandomTensor({8, 4, 32, 32, 128}); + } + + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); + auto input_add = + Placeholder(s.WithOpName("input_add"), DT_FLOAT, add_shape); + + if (dim == 2) { + auto conv = + ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); + auto add = ops::Add(s.WithOpName("add_op"), input_add, bias_add); + + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); + + if 
(activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, add)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, add)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, add)); + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); + return ops::Identity(fetch, + ops::internal::LeakyRelu(activate, add, attr)); + } + + return ops::Identity(fetch, bias); + }(); + } else if (dim == 3) { + auto conv = + ops::Conv3D(s.WithOpName("conv"), input, filter, strides, "SAME"); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); + auto add = ops::Add(s.WithOpName("add_op"), input_add, bias_add); + + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); + + if (activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, add)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, add)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, add)); + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); + return ops::Identity(fetch, + ops::internal::LeakyRelu(activate, add, attr)); + } + + return ops::Identity(fetch, bias); + }(); + } + + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, + {"filter", filter_t}, + {"bias", bias_t}, + {"input_add", add_t}}; + TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + Remapper optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "activation") { + if (dim == 2) { + EXPECT_EQ(node.op(), "_FusedConv2D"); + } else if (dim == 3) { + EXPECT_EQ(node.op(), "_FusedConv3D"); + } + ASSERT_GE(node.input_size(), 3); + EXPECT_EQ(node.input(0), "input"); + EXPECT_EQ(node.input(1), "filter"); + + EXPECT_EQ(node.attr().at("num_args").i(), 2); + EXPECT_EQ(node.input(2), "bias"); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + ASSERT_EQ(fused_ops.size(), 3); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ("Add", fused_ops[1]); + EXPECT_EQ(activation, fused_ops[2]); + found++; + } + } + EXPECT_EQ(found, 1); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + ASSERT_EQ(tensors_expected.size(), 1); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + } + } +}; + +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv2D_F32) { + RunTest<2, DT_FLOAT>(); +} +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv3D_F32) { + RunTest<3, DT_FLOAT>(); +} +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv2D_BF16) { + RunTest<2, DT_BFLOAT16>(); +} +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv3D_BF16) { + RunTest<3, DT_BFLOAT16>(); +} + class RemapperFuseConvWithSqueezeAndBias : public RemapperTest { public: template @@ -2255,102 +2412,6 @@ TEST_F(RemapperTest, FuseConv3DWithBiasAndAdd) { test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); } -TEST_F(RemapperTest, FuseConv3DWithBiasAndAddActivation) { - if (!IsMKLEnabled()) GTEST_SKIP() << 
"Test only applicable to oneDNN."; - using ::tensorflow::ops::Placeholder; - - for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - auto input_shape = Placeholder::Shape({8, 4, 32, 32, 3}); - auto filter_shape = Placeholder::Shape({1, 1, 1, 3, 128}); - auto bias_shape = Placeholder::Shape({128}); - auto add_shape = ops::Placeholder::Shape({8, 4, 32, 32, 128}); - - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); - auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); - auto input_add = - Placeholder(s.WithOpName("input_add"), DT_FLOAT, add_shape); - - float leakyrelu_alpha = 0.5; - - std::vector strides = {1, 1, 1, 1, 1}; - auto conv = - ops::Conv3D(s.WithOpName("conv"), input, filter, strides, "SAME"); - auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto add = ops::Add(s.WithOpName("add_op"), input_add, bias_add); - - ops::Identity fetch = [&]() -> ops::Identity { - auto activate = s.WithOpName("activation"); - auto fetch = s.WithOpName("fetch"); - - if (activation == "Relu") { - return ops::Identity(fetch, ops::Relu(activate, add)); - } else if (activation == "Relu6") { - return ops::Identity(fetch, ops::Relu6(activate, add)); - } else if (activation == "Elu") { - return ops::Identity(fetch, ops::Elu(activate, add)); - } else if (activation == "LeakyRelu") { - auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); - return ops::Identity(fetch, - ops::internal::LeakyRelu(activate, add, attr)); - } - - return ops::Identity(fetch, bias); - }(); - - auto input_t = GenerateRandomTensor({8, 4, 32, 32, 3}); - auto filter_t = GenerateRandomTensor({1, 1, 1, 3, 128}); - auto bias_t = GenerateRandomTensor({128}); - auto add_t = GenerateRandomTensor({8, 4, 32, 32, 128}); - - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, - {"filter", filter_t}, - {"bias", bias_t}, - {"input_add", add_t}}; - TF_ASSERT_OK(s.ToGraphDef(&item.graph)); - - // Place all nodes on CPU. - for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } - - Remapper optimizer(RewriterConfig::AGGRESSIVE); - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "activation") { - EXPECT_EQ(node.op(), "_FusedConv3D"); - ASSERT_GE(node.input_size(), 3); - EXPECT_EQ(node.input(0), "input"); - EXPECT_EQ(node.input(1), "filter"); - - EXPECT_EQ(node.attr().at("num_args").i(), 2); - EXPECT_EQ(node.input(2), "bias"); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - ASSERT_EQ(fused_ops.size(), 3); - EXPECT_EQ("BiasAdd", fused_ops[0]); - EXPECT_EQ("Add", fused_ops[1]); - EXPECT_EQ(activation, fused_ops[2]); - found++; - } - } - EXPECT_EQ(found, 1); - - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - ASSERT_EQ(tensors_expected.size(), 1); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); - } -} - // Conv2D + Add {6,} + Conv2D + Biasadd fusion. 
TEST_F(RemapperTest, FuseConv2DWithSemanticAdd) { if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to MKL."; From d5736156b7c5fa663f26ba94f064ebbe0597f9b7 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 1 Dec 2023 10:22:40 -0800 Subject: [PATCH 288/381] [stream_executor] Add GetSlice to DeviceMemory* Getting slice of DeviceMemory is a common operation, let's have an implementation in DeviceMemoryBase and DeviceMemory. PiperOrigin-RevId: 587047674 --- third_party/xla/xla/stream_executor/BUILD | 5 +++- .../xla/xla/stream_executor/device_memory.h | 23 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 66c87288ee136d..5e802a3f8bb8c1 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -179,7 +179,10 @@ cc_library( name = "device_memory", hdrs = ["device_memory.h"], visibility = ["//visibility:public"], - deps = ["//xla/stream_executor/platform"], + deps = [ + "//xla/stream_executor/platform", + "@com_google_absl//absl/log:check", + ], ) # TODO(ezhulenev): Merge this target into `stream_executor`. diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index 43cb0ac5245438..e4f3643402f037 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -29,6 +29,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "xla/stream_executor/platform/port.h" namespace stream_executor { @@ -89,6 +90,18 @@ class DeviceMemoryBase { return opaque() == other.opaque() && size() == other.size(); } + // Creates a memory region (slice) inside another allocated memory region. + // Offset and size are in bytes. + DeviceMemoryBase GetByteSlice(uint64_t offset_bytes, uint64_t size_bytes) { + DCHECK(offset_bytes + size_bytes <= size_) + << "requested slice allocation (offset + size) is greater " + << "than parent allocation size: (" << offset_bytes << " + " + << size_bytes << ") vs. (" << size_ << ")"; + + return DeviceMemoryBase( + reinterpret_cast(opaque_) + offset_bytes, size_bytes); + } + protected: friend class StreamExecutor; @@ -139,13 +152,21 @@ class DeviceMemory final : public DeviceMemoryBase { // Returns whether this is a single-element allocation. bool IsScalar() const { return ElementCount() == 1; } - // Create a typed area of DeviceMemory with a given opaque pointer and the + // Creates a typed area of DeviceMemory with a given opaque pointer and the // quantity of bytes in the allocation. This function is broken out to // distinguish bytes from an element count. static DeviceMemory MakeFromByteSize(void *opaque, uint64_t bytes) { return DeviceMemory(opaque, bytes); } + // Creates a memory region (slice) inside another allocated memory region. + // Offset and size are specified in terms of ElemT elements. + DeviceMemory GetSlice(uint64_t element_offset, + uint64_t element_count) { + return DeviceMemory(GetByteSlice(sizeof(ElemT) * element_offset, + sizeof(ElemT) * element_count)); + } + // Resets the DeviceMemory data, in MakeFromByteSize fashion. // This simply clobbers the prior values. void ResetFromByteSize(void *opaque, uint64_t bytes) { From 19c7beb663edc1d1cbe0f9f4dd020917e5f55e76 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Fri, 1 Dec 2023 10:24:52 -0800 Subject: [PATCH 289/381] Do MLIR verification after each pass. 
PiperOrigin-RevId: 587048389 --- tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc | 1 + tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc | 1 + tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc | 1 + tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc | 1 + 4 files changed, 4 insertions(+) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc index 2f8469ee3f6f69..bb27edab8aa88a 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc @@ -103,6 +103,7 @@ tensorflow::Status RunTFXLABridge( } PassManager bridge(module.getContext()); + bridge.enableVerifier(); ::tensorflow::applyTensorflowAndCLOptions(bridge); // Populate a passmanager with the list of passes that implement the bridge. diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc index 236282f625e20a..9d0b884ebbe85d 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc @@ -127,6 +127,7 @@ tensorflow::Status ExportFromTensorflowDialectToExecutor( ModuleOp module, llvm::StringRef module_name) { PassManager tf_to_executor(module.getContext()); ::tensorflow::applyTensorflowAndCLOptions(tf_to_executor); + tf_to_executor.enableVerifier(); AddTfDialectToExecutorPasses(tf_to_executor); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index e0cb8685552e3f..289d4d0faec78e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -73,6 +73,7 @@ tensorflow::Status RunTFXLABridge( } PassManager bridge(module.getContext()); + bridge.enableVerifier(); ::tensorflow::applyTensorflowAndCLOptions(bridge); // Populate a passmanager with the list of passes that implement the bridge. diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc index 69f1c0e20a5e1b..455a59d6607c49 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc @@ -126,6 +126,7 @@ tensorflow::Status ExportFromTensorflowDialectToExecutor( ModuleOp module, llvm::StringRef module_name) { PassManager tf_to_executor(module.getContext()); ::tensorflow::applyTensorflowAndCLOptions(tf_to_executor); + tf_to_executor.enableVerifier(); AddTfDialectToExecutorPasses(tf_to_executor); if (VLOG_IS_ON(1) || From 230bc28e5cd94f62e5ce2ce95e2ba726795fbe78 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Fri, 1 Dec 2023 11:04:18 -0800 Subject: [PATCH 290/381] Integrate LLVM at llvm/llvm-project@668865789620 Updates LLVM usage to match [668865789620](https://github.com/llvm/llvm-project/commit/668865789620) PiperOrigin-RevId: 587060157 --- third_party/llvm/generated.patch | 267 +++++++++++++++++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- 2 files changed, 269 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..16455e0725570e 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,268 @@ Auto generated patch. Do not edit or delete it, even if empty. 
+diff -ruN --strip-trailing-cr a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h +--- a/llvm/include/llvm/Target/TargetMachine.h ++++ b/llvm/include/llvm/Target/TargetMachine.h +@@ -239,7 +239,7 @@ + void setCodeModel(CodeModel::Model CM) { CMModel = CM; } + + void setLargeDataThreshold(uint64_t LDT) { LargeDataThreshold = LDT; } +- bool isLargeGlobalObject(const GlobalObject *GO) const; ++ bool isLargeData(const GlobalVariable *GV) const; + + bool isPositionIndependent() const; + +diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp +--- a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp ++++ b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp +@@ -616,7 +616,7 @@ + /// DataSections. + static StringRef getSectionPrefixForGlobal(SectionKind Kind, bool IsLarge) { + if (Kind.isText()) +- return IsLarge ? ".ltext" : ".text"; ++ return ".text"; + if (Kind.isReadOnly()) + return IsLarge ? ".lrodata" : ".rodata"; + if (Kind.isBSS()) +@@ -650,7 +650,10 @@ + Name = ".rodata.cst"; + Name += utostr(EntrySize); + } else { +- Name = getSectionPrefixForGlobal(Kind, TM.isLargeGlobalObject(GO)); ++ bool IsLarge = false; ++ if (auto *GV = dyn_cast(GO)) ++ IsLarge = TM.isLargeData(GV); ++ Name = getSectionPrefixForGlobal(Kind, IsLarge); + } + + bool HasPrefix = false; +@@ -770,8 +773,12 @@ + Group = C->getName(); + IsComdat = C->getSelectionKind() == Comdat::Any; + } +- if (TM.isLargeGlobalObject(GO)) +- Flags |= ELF::SHF_X86_64_LARGE; ++ if (auto *GV = dyn_cast(GO)) { ++ if (TM.isLargeData(GV)) { ++ assert(TM.getTargetTriple().getArch() == Triple::x86_64); ++ Flags |= ELF::SHF_X86_64_LARGE; ++ } ++ } + return {Group, IsComdat, Flags}; + } + +diff -ruN --strip-trailing-cr a/llvm/lib/Target/TargetMachine.cpp b/llvm/lib/Target/TargetMachine.cpp +--- a/llvm/lib/Target/TargetMachine.cpp ++++ b/llvm/lib/Target/TargetMachine.cpp +@@ -39,21 +39,13 @@ + + TargetMachine::~TargetMachine() = default; + +-bool TargetMachine::isLargeGlobalObject(const GlobalObject *GO) const { +- if (getTargetTriple().getArch() != Triple::x86_64) ++bool TargetMachine::isLargeData(const GlobalVariable *GV) const { ++ if (getTargetTriple().getArch() != Triple::x86_64 || GV->isThreadLocal()) + return false; + + if (getCodeModel() != CodeModel::Medium && getCodeModel() != CodeModel::Large) + return false; + +- if (isa(GO)) +- return getCodeModel() == CodeModel::Large; +- +- auto *GV = cast(GO); +- +- if (GV->isThreadLocal()) +- return false; +- + // Allowing large metadata sections in the presence of an explicit section is + // useful, even if GCC does not allow them. However, we should not mark + // certain well-known prefixes as large, because it would make the whole +diff -ruN --strip-trailing-cr a/llvm/lib/Target/X86/X86Subtarget.cpp b/llvm/lib/Target/X86/X86Subtarget.cpp +--- a/llvm/lib/Target/X86/X86Subtarget.cpp ++++ b/llvm/lib/Target/X86/X86Subtarget.cpp +@@ -83,20 +83,32 @@ + if (is64Bit()) { + // 64-bit ELF PIC local references may use GOTOFF relocations. + if (isTargetELF()) { +- CodeModel::Model CM = TM.getCodeModel(); +- assert(CM != CodeModel::Tiny && +- "Tiny codesize model not supported on X86"); +- // In the large code model, even referencing a global under the large data +- // threshold which is considered "small", we need to use GOTOFF. +- if (CM == CodeModel::Large) ++ switch (TM.getCodeModel()) { ++ // 64-bit small code model is simple: All rip-relative. 
++ case CodeModel::Tiny: ++ llvm_unreachable("Tiny codesize model not supported on X86"); ++ case CodeModel::Small: ++ case CodeModel::Kernel: ++ return X86II::MO_NO_FLAG; ++ ++ // The large PIC code model uses GOTOFF. ++ case CodeModel::Large: + return X86II::MO_GOTOFF; +- // Large objects use GOTOFF, otherwise use RIP-rel access. +- if (auto *GO = dyn_cast_or_null(GV)) +- return TM.isLargeGlobalObject(GO) ? X86II::MO_GOTOFF +- : X86II::MO_NO_FLAG; +- // For non-GlobalObjects, the small and medium code models treat them as +- // accessible with a RIP-rel access. +- return X86II::MO_NO_FLAG; ++ ++ // Medium is a hybrid: RIP-rel for code and non-large data, GOTOFF for ++ // remaining DSO local data. ++ case CodeModel::Medium: ++ // Constant pool and jump table handling pass a nullptr to this ++ // function so we need to use isa_and_nonnull. ++ if (isa_and_nonnull(GV)) ++ return X86II::MO_NO_FLAG; // All code is RIP-relative ++ if (auto *GVar = dyn_cast_or_null(GV)) { ++ if (TM.isLargeData(GVar)) ++ return X86II::MO_GOTOFF; ++ } ++ return X86II::MO_NO_FLAG; // Local symbols use GOTOFF. ++ } ++ llvm_unreachable("invalid code model"); + } + + // Otherwise, this is either a RIP-relative reference or a 64-bit movabsq, +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/code-model-elf.ll b/llvm/test/CodeGen/X86/code-model-elf.ll +--- a/llvm/test/CodeGen/X86/code-model-elf.ll ++++ b/llvm/test/CodeGen/X86/code-model-elf.ll +@@ -9,7 +9,6 @@ + ; RUN: llc -verify-machineinstrs < %s -relocation-model=pic -code-model=medium -large-data-threshold=1000 | FileCheck %s --check-prefix=CHECK --check-prefix=MEDIUM-SMALL-DATA-PIC + ; RUN: llc -verify-machineinstrs < %s -relocation-model=pic -code-model=medium | FileCheck %s --check-prefix=CHECK --check-prefix=MEDIUM-PIC + ; RUN: llc -verify-machineinstrs < %s -relocation-model=pic -code-model=large | FileCheck %s --check-prefix=CHECK --check-prefix=LARGE-PIC +-; RUN: llc -verify-machineinstrs < %s -relocation-model=pic -code-model=large -large-data-threshold=1000 | FileCheck %s --check-prefix=CHECK --check-prefix=LARGE-PIC + + ; Generated from this C source: + ; +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/code-model-elf-text-sections.ll b/llvm/test/CodeGen/X86/code-model-elf-text-sections.ll +--- a/llvm/test/CodeGen/X86/code-model-elf-text-sections.ll ++++ b/llvm/test/CodeGen/X86/code-model-elf-text-sections.ll +@@ -1,25 +0,0 @@ +-; RUN: llc < %s -relocation-model=pic -filetype=obj -code-model=small -o %t +-; RUN: llvm-readelf -S %t | FileCheck %s --check-prefix=SMALL +-; RUN: llc < %s -relocation-model=pic -filetype=obj -code-model=medium -o %t +-; RUN: llvm-readelf -S %t | FileCheck %s --check-prefix=SMALL +-; RUN: llc < %s -relocation-model=pic -filetype=obj -code-model=large -o %t +-; RUN: llvm-readelf -S %t | FileCheck %s --check-prefix=LARGE +- +-; RUN: llc < %s -relocation-model=pic -filetype=obj -code-model=small -function-sections -o %t +-; RUN: llvm-readelf -S %t | FileCheck %s --check-prefix=SMALL-DS +-; RUN: llc < %s -relocation-model=pic -filetype=obj -code-model=medium -function-sections -o %t +-; RUN: llvm-readelf -S %t | FileCheck %s --check-prefix=SMALL-DS +-; RUN: llc < %s -relocation-model=pic -filetype=obj -code-model=large -function-sections -o %t +-; RUN: llvm-readelf -S %t | FileCheck %s --check-prefix=LARGE-DS +- +-; SMALL: .text {{.*}} AX {{.*}} +-; SMALL-DS: .text.func {{.*}} AX {{.*}} +-; LARGE: .ltext {{.*}} AXl {{.*}} +-; LARGE-DS: .ltext.func {{.*}} AXl {{.*}} +- +-target datalayout = 
"e-m:e-i64:64-f80:128-n8:16:32:64-S128" +-target triple = "x86_64--linux" +- +-define void @func() { +- ret void +-} +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/pcsections.ll b/llvm/test/CodeGen/X86/pcsections.ll +--- a/llvm/test/CodeGen/X86/pcsections.ll ++++ b/llvm/test/CodeGen/X86/pcsections.ll +@@ -19,12 +19,12 @@ + ; CHECK: # %bb.0: # %entry + ; CHECK-NEXT: retq + ; CHECK-NEXT: .Lfunc_end0: +-; CHECK: .section section_no_aux,"awo",@progbits,.{{l?}}text ++; CHECK: .section section_no_aux,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base0: + ; DEFCM-NEXT: .long .Lfunc_begin0-.Lpcsection_base0 + ; LARGE-NEXT: .quad .Lfunc_begin0-.Lpcsection_base0 + ; CHECK-NEXT: .long .Lfunc_end0-.Lfunc_begin0 +-; CHECK-NEXT: .{{l?}}text ++; CHECK-NEXT: .text + entry: + ret void + } +@@ -35,7 +35,7 @@ + ; CHECK: # %bb.0: # %entry + ; CHECK-NEXT: retq + ; CHECK-NEXT: .Lfunc_end1: +-; CHECK: .section section_aux,"awo",@progbits,.{{l?}}text ++; CHECK: .section section_aux,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base1: + ; DEFCM-NEXT: .long .Lfunc_begin1-.Lpcsection_base1 + ; LARGE-NEXT: .quad .Lfunc_begin1-.Lpcsection_base1 +@@ -43,7 +43,7 @@ + ; CHECK-NEXT: .long 10 + ; CHECK-NEXT: .long 20 + ; CHECK-NEXT: .long 30 +-; CHECK-NEXT: .{{l?}}text ++; CHECK-NEXT: .text + entry: + ret void + } +@@ -56,22 +56,22 @@ + ; CHECK-NEXT: movq + ; CHECK-NEXT: retq + ; CHECK-NEXT: .Lfunc_end2: +-; CHECK: .section section_no_aux,"awo",@progbits,.{{l?}}text ++; CHECK: .section section_no_aux,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base2: + ; DEFCM-NEXT: .long .Lfunc_begin2-.Lpcsection_base2 + ; LARGE-NEXT: .quad .Lfunc_begin2-.Lpcsection_base2 + ; CHECK-NEXT: .long .Lfunc_end2-.Lfunc_begin2 +-; CHECK-NEXT: .section section_aux_42,"awo",@progbits,.{{l?}}text ++; CHECK-NEXT: .section section_aux_42,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base3: + ; DEFCM-NEXT: .long .Lpcsection0-.Lpcsection_base3 + ; LARGE-NEXT: .quad .Lpcsection0-.Lpcsection_base3 + ; CHECK-NEXT: .long 42 +-; CHECK-NEXT: .section section_aux_21264,"awo",@progbits,.{{l?}}text ++; CHECK-NEXT: .section section_aux_21264,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base4: + ; DEFCM-NEXT: .long .Lpcsection0-.Lpcsection_base4 + ; LARGE-NEXT: .quad .Lpcsection0-.Lpcsection_base4 + ; CHECK-NEXT: .long 21264 +-; CHECK-NEXT: .{{l?}}text ++; CHECK-NEXT: .text + entry: + %0 = load i64, ptr @bar, align 8, !pcsections !2 + ret i64 %0 +@@ -79,7 +79,7 @@ + + define void @multiple_uleb128() !pcsections !6 { + ; CHECK-LABEL: multiple_uleb128: +-; CHECK: .section section_aux,"awo",@progbits,.{{l?}}text ++; CHECK: .section section_aux,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base5: + ; DEFCM-NEXT: .long .Lfunc_begin3-.Lpcsection_base5 + ; LARGE-NEXT: .quad .Lfunc_begin3-.Lpcsection_base5 +@@ -87,13 +87,13 @@ + ; CHECK-NEXT: .byte 42 + ; CHECK-NEXT: .ascii "\345\216&" + ; CHECK-NEXT: .byte 255 +-; CHECK-NEXT: .section section_aux_21264,"awo",@progbits,.{{l?}}text ++; CHECK-NEXT: .section section_aux_21264,"awo",@progbits,.text + ; CHECK-NEXT: .Lpcsection_base6: + ; DEFCM-NEXT: .long .Lfunc_begin3-.Lpcsection_base6 + ; LARGE-NEXT: .quad .Lfunc_begin3-.Lpcsection_base6 + ; CHECK-NEXT: .long .Lfunc_end3-.Lfunc_begin3 + ; CHECK-NEXT: .long 21264 +-; CHECK-NEXT: .{{l?}}text ++; CHECK-NEXT: .text + entry: + ret void + } +diff -ruN --strip-trailing-cr a/llvm/test/ExecutionEngine/OrcLazy/debug-objects-elf-minimal.ll b/llvm/test/ExecutionEngine/OrcLazy/debug-objects-elf-minimal.ll +--- 
a/llvm/test/ExecutionEngine/OrcLazy/debug-objects-elf-minimal.ll ++++ b/llvm/test/ExecutionEngine/OrcLazy/debug-objects-elf-minimal.ll +@@ -44,7 +44,7 @@ + ; RUN: --generate=__dump_jit_debug_objects %s | llvm-objdump --section-headers - | \ + ; RUN: FileCheck --check-prefix=CHECK_LOAD_ADDR %s + ; +-; CHECK_LOAD_ADDR-NOT: {{[0-9]*}} .ltext {{.*}} 0000000000000000 TEXT ++; CHECK_LOAD_ADDR-NOT: {{[0-9]*}} .text {{.*}} 0000000000000000 TEXT + + target triple = "x86_64-unknown-unknown-elf" + diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 163586edceb97a..b65109291359b8 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "511ba45a47d6f9e48ad364181830c9fb974135b2" - LLVM_SHA256 = "23b4e703adbb219853d3d375379e15beefea13a9062715f31e50140c2bafe540" + LLVM_COMMIT = "668865789620f390fbad4d7093ed8ca6eb932c31" + LLVM_SHA256 = "8d7cbbe492a17656c09af1e79b802303f11cb47d64768760b70d52f11ed4d9da" tf_http_archive( name = name, From f4022868c6286f83be8fcbdbe2fa9ac15b1381ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 11:10:59 -0800 Subject: [PATCH 291/381] Set XlaCallModuleOp Attribute jax.uses_shape_polymorphism=true PiperOrigin-RevId: 587062867 --- ...lehlo_ops_in_main_function_with_xla_call_module_ops.cc | 8 ++++++++ ...hlo_ops_in_main_function_with_xla_call_module_ops.mlir | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 460cc8dd15f60e..edcfde78eecbb1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -45,6 +45,8 @@ namespace mlir::quant::stablehlo { namespace { constexpr StringRef kQuantizeTargetOpAttr = "tf_quant.composite_function"; +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; // Default version number for native serialization. constexpr int64_t kDefaultVersion = 9; @@ -187,6 +189,12 @@ void CreateXlaCallModuleOp(ValueRange inputs, ValueRange outputs, /*disabled_checks=*/empty_array_attr); xla_call_module_op->setAttr(TF::kStablehloEntryFunctionAttrName, SymbolRefAttr::get(stablehlo_func_op)); + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. 
+ xla_call_module_op->setAttr( + kStablehloModuleAttrsAttrName, + builder.getDictionaryAttr(builder.getNamedAttr( + kUsesShapePolymorphismAttr, builder.getBoolAttr(true)))); for (auto [original_output_value, xla_call_module_op_result_value] : llvm::zip_equal(outputs, xla_call_module_op->getResults())) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir index 036f0709611bf2..745d44282c9e0f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -91,7 +91,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %5 : tensor<1x1024xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]]) // CHECK: return %[[IDENTITY]] // CHECK } @@ -117,7 +117,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %3 : tensor<1x3xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0} + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) From f1dcd211e0f10c7124fe3e9dfec77279aa682e0e Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Fri, 1 Dec 2023 11:21:20 -0800 Subject: [PATCH 292/381] Factor out `EnableDebugging` for `pywrap_` files in quantization. 
PiperOrigin-RevId: 587066173 --- .../mlir/quantization/stablehlo/cc/BUILD | 16 ++++ .../quantization/stablehlo/cc/debugger.cc | 73 +++++++++++++++++++ .../mlir/quantization/stablehlo/cc/debugger.h | 50 +++++++++++++ .../mlir/quantization/stablehlo/python/BUILD | 7 +- .../stablehlo/python/pywrap_quantization.cc | 73 +++---------------- .../mlir/quantization/tensorflow/python/BUILD | 5 +- .../python/pywrap_quantize_model.cc | 67 ++--------------- 7 files changed, 164 insertions(+), 127 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 511757314e2da7..67e3eb9ca58515 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -73,3 +73,19 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", ], ) + +cc_library( + name = "debugger", + srcs = ["debugger.cc"], + hdrs = ["debugger.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc new file mode 100644 index 00000000000000..4588d5f00a7523 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::NodeDef; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::DebuggerOptions; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; + +} // namespace + +void EnableDebugging( + ExportedModel& exported_model, const DebuggerOptions& debugger_options, + const PyFunctionLibrary& py_function_library, + const absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& signature_def_map) { + // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by + // default to avoid logging data during calibration. + MutateNodeDefs(*exported_model.mutable_graph_def(), [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["enabled"].set_b(true); + } + }); + + if (debugger_options.debugger_type() == + DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { + // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. + // TODO: b/296916287 - Create a separate function for saving unquantized + // dump model. + py_function_library.SaveExportedModel( + debugger_options.unquantized_dump_model_path(), exported_model, + src_saved_model_path, tags, signature_def_map); + + // Update the `DumpTensor` ops' file name in `graph_def`. + MutateNodeDefs(*exported_model.mutable_graph_def(), [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["file_name"].set_s( + "quantized_tensor_data.pb"); + } + }); + } +} + +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h new file mode 100644 index 00000000000000..6bb427ecbdf1fd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace stablehlo::quantization { + +// Enables debugging on `exported_model` by updating the `DumpTensor` ops. +// +// Saves the current model to `debugger_options.unquantized_dump_model_path()` +// if the debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required because +// in whole-model debugging mode the `DumpTensor` ops for the unquantized +// tensors are only inserted in the unquantized model whereas `DumpTensor` ops +// for the quantized tensors are only inserted in the quantized model. Both +// models are required to be able to dump both quantized and unquantized tensors +// and compare them offline. +void EnableDebugging( + tensorflow::quantization::ExportedModel& exported_model, + const tensorflow::quantization::DebuggerOptions& debugger_options, + const tensorflow::quantization::PyFunctionLibrary& py_function_library, + absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& + signature_def_map); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 24a8248e750c56..199b0c1f89a964 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -86,7 +86,7 @@ tf_python_pybind_extension( srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], deps = [ - "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -95,9 +95,10 @@ tf_python_pybind_extension( "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:env", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:import_status_module", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index ca4518462779e5..e8765c2620e141 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -17,8 +17,11 @@ limitations under the 
License. #include #include "absl/container/flat_hash_map.h" -#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 @@ -26,7 +29,7 @@ limitations under the License. #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" @@ -36,74 +39,18 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tsl/platform/env.h" namespace py = pybind11; namespace { -using ::stablehlo::quantization::MutateNodeDefs; +using ::stablehlo::quantization::EnableDebugging; using ::stablehlo::quantization::io::CreateTmpDir; -using ::tensorflow::FunctionDef; -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; using ::tensorflow::SignatureDef; -using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::QuantizationOptions; -// TODO: b/312371048 - Factor out this function to a separate file. -// Enables debugging on `exported_model` by updating the `DumpTensor` ops. -// -// Saves the current model to `debugger_options.unquantized_dump_model_path()` -// if the debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required because -// in whole-model debugging mode the `DumpTensor` ops for the unquantized -// tensors are only inserted in the unquantized model whereas `DumpTensor` ops -// for the quantized tensors are only inserted in the quantized model. Both -// models are required to be able to dump both quantized and unquantized tensors -// and compare them offline. -ExportedModel EnableDebugging( - const ExportedModel& exported_model, - const DebuggerOptions& debugger_options, - const PyFunctionLibrary& py_function_library, - const absl::string_view src_saved_model_path, - const std::unordered_set& tags, - const absl::flat_hash_map& signature_def_map) { - ExportedModel debugger_enabled_exported_model = exported_model; - - // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by - // default to avoid logging data during calibration. - MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), - [](NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["enabled"].set_b(true); - } - }); - - if (debugger_options.debugger_type() == - DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { - // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. - // TODO: b/296916287 - Create a separate function for saving unquantized - // dump model. 
- py_function_library.SaveExportedModel( - debugger_options.unquantized_dump_model_path(), - debugger_enabled_exported_model, src_saved_model_path, tags, - signature_def_map); - - // Update the `DumpTensor` ops' file name in `graph_def`. - MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), - [](NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["file_name"].set_s( - "quantized_tensor_data.pb"); - } - }); - } - - return debugger_enabled_exported_model; -} - } // namespace PYBIND11_MODULE(pywrap_quantization, m) { @@ -161,10 +108,10 @@ PYBIND11_MODULE(pywrap_quantization, m) { representative_dataset); if (quantization_options.has_debugger_options()) { - calibrated_exported_model = EnableDebugging( - calibrated_exported_model, - quantization_options.debugger_options(), py_function_library, - src_saved_model_path, tags, signature_def_map); + EnableDebugging(calibrated_exported_model, + quantization_options.debugger_options(), + py_function_library, src_saved_model_path, tags, + signature_def_map); } const absl::StatusOr calibrated_saved_model_path = diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index d3c86b151e7912..71bf9b4dc34af0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -165,9 +165,11 @@ cc_library( cc_library( name = "py_function_lib", hdrs = ["py_function_lib.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", @@ -203,13 +205,12 @@ tf_python_pybind_extension( deps = [ ":py_function_lib", ":type_casters", - "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", - "//tensorflow/python/lib/core:pybind11_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 5f9c9254fe1940..db97a95f0aecd3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf -#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" @@ -38,15 +38,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" -#include "tensorflow/python/lib/core/pybind11_lib.h" + +namespace py = pybind11; namespace { -using ::stablehlo::quantization::MutateNodeDefs; +using ::stablehlo::quantization::EnableDebugging; using ::stablehlo::quantization::io::CreateTmpDir; -using ::tensorflow::NodeDef; using ::tensorflow::SignatureDef; -using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::QuantizationOptions; @@ -56,56 +55,6 @@ using ::tensorflow::quantization::QuantizePtqModelPreCalibration; using ::tensorflow::quantization::QuantizeQatModel; using ::tensorflow::quantization::QuantizeWeightOnly; -// Enables debugging on `exported_model` by updating the `DumpTensor` ops. -// -// Saves the current model to `debugger_options.unquantized_dump_model_path()` -// if the debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required because -// in whole-model debugging mode the `DumpTensor` ops for the unquantized -// tensors are only inserted in the unquantized model whereas `DumpTensor` ops -// for the quantized tensors are only inserted in the quantized model. Both -// models are required to be able to dump both quantized and unquantized tensors -// and compare them offline. -ExportedModel EnableDebugging( - const ExportedModel& exported_model, - const DebuggerOptions& debugger_options, - const PyFunctionLibrary& py_function_library, - const absl::string_view src_saved_model_path, - const std::unordered_set& tags, - const absl::flat_hash_map& signature_def_map) { - ExportedModel debugger_enabled_exported_model = exported_model; - - // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by - // default to avoid logging data during calibration. - MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), - [](NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["enabled"].set_b(true); - } - }); - - if (debugger_options.debugger_type() == - DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { - // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. - // TODO: b/296916287 - Create a separate function for saving unquantized - // dump model. - py_function_library.SaveExportedModel( - debugger_options.unquantized_dump_model_path(), - debugger_enabled_exported_model, src_saved_model_path, tags, - signature_def_map); - - // Update the `DumpTensor` ops' file name in `graph_def`. 
- MutateNodeDefs(*debugger_enabled_exported_model.mutable_graph_def(), - [](NodeDef& node_def) { - if (node_def.op() == "DumpTensor") { - (*node_def.mutable_attr())["file_name"].set_s( - "quantized_tensor_data.pb"); - } - }); - } - - return debugger_enabled_exported_model; -} - } // namespace PYBIND11_MODULE(pywrap_quantize_model, m) { @@ -305,10 +254,10 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { representative_dataset); if (quantization_options.has_debugger_options()) { - calibrated_exported_model = EnableDebugging( - calibrated_exported_model, - quantization_options.debugger_options(), py_function_library, - src_saved_model_path, tags, signature_def_map); + EnableDebugging(calibrated_exported_model, + quantization_options.debugger_options(), + py_function_library, src_saved_model_path, tags, + signature_def_map); } const absl::StatusOr calibrated_saved_model_path = From aa28566305528e21741465dd883d3f8116893571 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 11:39:11 -0800 Subject: [PATCH 293/381] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/86ccd9b2b59fd0de694123d63ba1b19bf9690b37. PiperOrigin-RevId: 587071230 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 8cdfbfc358b1fb..aa6254002ee6d1 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "bf11f0d876fef436c5ea018a5b13eb2e89e53b7a" - TFRT_SHA256 = "f78926aaefd521c80154ff9c2c85ef77c0bb5ee34f5af4f24f3a349d22b41ff8" + TFRT_COMMIT = "86ccd9b2b59fd0de694123d63ba1b19bf9690b37" + TFRT_SHA256 = "7827819ed2713be7e4aacb952ac16b205c96d827f87953e4d359f7322cb6af64" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 8cdfbfc358b1fb..aa6254002ee6d1 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "bf11f0d876fef436c5ea018a5b13eb2e89e53b7a" - TFRT_SHA256 = "f78926aaefd521c80154ff9c2c85ef77c0bb5ee34f5af4f24f3a349d22b41ff8" + TFRT_COMMIT = "86ccd9b2b59fd0de694123d63ba1b19bf9690b37" + TFRT_SHA256 = "7827819ed2713be7e4aacb952ac16b205c96d827f87953e4d359f7322cb6af64" tf_http_archive( name = "tf_runtime", From 0d321588ac405ad085745c5e685903ca422e3fab Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 1 Dec 2023 11:54:44 -0800 Subject: [PATCH 294/381] [xla:gpu] Add Async HLO Ops documentation PiperOrigin-RevId: 587075805 --- third_party/xla/docs/async_ops.md | 121 ++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 third_party/xla/docs/async_ops.md diff --git a/third_party/xla/docs/async_ops.md b/third_party/xla/docs/async_ops.md new file mode 100644 index 00000000000000..889272eecc4411 --- /dev/null +++ b/third_party/xla/docs/async_ops.md @@ -0,0 +1,121 @@ +# Async HLO Instructions + +1. Adding async operations to HLO is cumbersome (i.e. `all-reduce-start` and + `all-reduce-done`). +2. 
The start and done split may be inadequate for some of the asynchronous use + cases. + +To target the first shortcoming, we propose to introduce one last set of new +asynchronous opcodes: `kAsyncStart`, `kAsyncUpdate`, and `kAsyncDone`. The idea +is to create a generic asynchronous opcode that can wrap any HLO instruction. +The actual operation that will be performed asynchronously will be encoded using +a called computation that only has the instruction as its root and any +parameters for inputs. The in-flight input/output buffer handling and aliasing +can then be shared for any asynchronous operation. The async-start instruction’s +output shape will then be a tuple of the input operands, output values, and any +intermediate state that is needed for the `async-update` or `async-done` +instructions. + +``` +%async_op { + %param0 = f32[64] parameter(0) + ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo” +} + +%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), + calls=%async_op +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start), + calls=%async_op +``` + +In the representation above, only `async-start` has a called computation since +it is trivial to find what the `async-done` does by following its operand to +find the corresponding `async-start` to find the called computation. + +Today both `async-start` and `async-done` have a called computation attribute, +but long term we plan to keep it only for `async-start`, since it is trivial +to find what the `async-done` does by following its operand to find the +corresponding `async-start` to find the called computation. + +> [!NOTE] +> Tracked as b/302594825 internally. + +Also note +that the first element in the output tuple of `async-start` aliases with the +operand, so the buffer stays alive until at least the async-done instruction. +Similarly, the second element aliases with the output of `async-done`, and the +third element is the context state that is used to keep track of the +asynchronous operation. This representation also supports multiple tensors in +the asynchronous operation input and/or output and the aliasing works the same +way: + +``` +%async_op { + %param0 = f32[64] parameter(0) + %param1 = f32[64] parameter(1) + ROOT %op = (f32[32], f32[32]) op(f32[64] %param0, f32[64] %param1), + op_specific_attr=”foo” +} + +%async-start = ((f32[64], f32[64]), (f32[32], f32[32]), s32[]) + async-start(f32[64] %operand0, f32[64] %operand1), + calls=%async_op +%async-done = (f32[32], f32[32]) async-done(%async-start) +``` + +In addition, the op can further be decomposed into zero or more `async-update` +steps that perform intermediate computations. 
The input/output aliasing works +the same way with the `async-update` instruction and each `async-start` and +`async-update` instructions must have one user that is either another +`async-update` or an `async-done`: + +``` +%async_op { + %param0 = f32[64] parameter(0) + ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo” +} + +%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), + calls=%async_op +%async-update0 = (f32[64], f32[32], s32[]) async-update( + (f32[64], f32[32], s32[]) %async-start) +%async-update1 = (f32[64], f32[32], s32[]) async-update( + (f32[64], f32[32], s32[]) %async-update0) +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-update1) + +``` + +## Syntax sugar + +Since having a separate computation to define the operation that will be +performed asynchronously is a bit cumbersome, we also propose a syntax sugar to +automatically print and parse asynchronous operations as if they are first-class +opcodes. The idea is to treat the “-start”, “-update”, and “-done” suffixes +specially by automatically creating the computation and instruction (without the +suffix) when parsing. For example, the code snippet above can be pretty-printed +to the following and the two can be parsed to the same representation: + +``` +%op-start = (f32[64], f32[32], s32[]) op-start(f32[64] %operand), + op_specific_attr=”foo” +%op-update0 = (f32[64], f32[32], s32[]) op-update( + (f32[64], f32[32], s32[]) %op-start), + op_specific_attr=”foo” +%op-update1 = (f32[64], f32[32], s32[]) op-update( + (f32[64], f32[32], s32[]) %op-update0), + op_specific_attr=”foo” +%op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1), + op_specific_attr=”foo” + +``` + +In order not to create ambiguities, the verifier will not allow an operation to +be wrapped with async-start if we explicitly defined an opcode for that +operation with the “-start” and/or “-done” suffixes. This is also an escape +hatch in case we have any instructions that require HLO-level treatment that +doesn’t fit in the model described above (e.g. the aliasing input/output +buffers). So, initially, `copy-start`/`copy-done`, +`collective-permute-start`/`collective-permute-done` etc. will continue to use +their respective first-class opcodes instead of the new +`async-start`/`async-done` opcodes until we clean up the code to remove these +“-start”/”-done” opcodes. From 50b71dc9abf38c6b5b4e41ddcab26fbd41b25973 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 1 Dec 2023 11:58:56 -0800 Subject: [PATCH 295/381] [XLA:SPMD] Use CollectiveOpGroupMode::kFlattenedID in GetDefaultCollectiveOpsCreator if partition_subgroups.size() <= 1. Related CollectiveOpGroupMode * kCrossReplicaAndPartition: channel_id is set, use_global_device_ids = false * kFlattenedID: channel_id is set, use_global_device_ids = true kCrossReplicaAndPartition can be rewritten as kFlattenedID with flattened IDs in replica_groups. Before this cl, if partition_subgroups.size() <= 1, we use kCrossReplicaAndPartition. We unify the mode as kFlattenedID in this cl. 
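
To make the group rewrite concrete: a flattened device ID is replica_id * num_partitions + partition_id, so each old single-replica group in kCrossReplicaAndPartition mode (which implicitly spans all partitions of that replica) becomes an explicit flattened-ID group listing those partitions, as in the restructured loop below. A minimal standalone sketch of that construction follows; it uses plain vectors instead of the ReplicaGroup proto, and the function name is illustrative rather than part of the SPMD partitioner API.

```
#include <cstdint>
#include <iostream>
#include <vector>

// Builds one flattened-ID group per replica, covering all of its partitions:
// flattened_id = replica_id * num_partitions + partition_id.
std::vector<std::vector<int64_t>> FlattenedIdGroups(int64_t num_replicas,
                                                    int64_t num_partitions) {
  std::vector<std::vector<int64_t>> groups(num_replicas);
  for (int64_t rid = 0; rid < num_replicas; ++rid) {
    for (int64_t pid = 0; pid < num_partitions; ++pid) {
      groups[rid].push_back(rid * num_partitions + pid);
    }
  }
  return groups;
}

int main() {
  // With 2 replicas and 4 partitions this prints "0 1 2 3" and "4 5 6 7":
  // the same participants as before, now expressed with
  // use_global_device_ids = true.
  for (const auto& group : FlattenedIdGroups(/*num_replicas=*/2,
                                             /*num_partitions=*/4)) {
    for (int64_t id : group) std::cout << id << ' ';
    std::cout << '\n';
  }
  return 0;
}
```
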
PiperOrigin-RevId: 587076950
---
 .../xla/xla/service/spmd/spmd_partitioner.cc  | 40 ++++++++-----------
 1 file changed, 17 insertions(+), 23 deletions(-)

diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
index 068e94eac10fa7..a633dec5677b4a 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
@@ -4723,34 +4723,28 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
          SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
          const std::vector<std::vector<int64_t>>& partition_subgroups,
          int64_t channel_id) {
-        if (partition_subgroups.size() <= 1) {
-          std::vector<ReplicaGroup> groups(num_replicas);
-          // TODO(yuanzx): Unify subgroup definition with AllToAll.
-          for (int64_t i = 0; i < num_replicas; ++i) {
-            groups[i].add_replica_ids(i);
-          }
-          HloComputation* reduction_clone =
-              reduction->parent()->AddComputationAndUnifyNamesAndIds(
-                  reduction->Clone(), false);
-          HloInstruction* all_reduce =
-              b->AddInstruction(HloInstruction::CreateAllReduce(
-                  operand->shape(), {operand}, reduction_clone, groups,
-                  /*constrain_layout=*/false, channel_id,
-                  /*use_global_device_ids=*/false));
-          reduction_clone->SetCollectiveCallInstruction(all_reduce);
-          return all_reduce;
-        }
-
         std::vector<ReplicaGroup> device_groups;
-        device_groups.reserve(partition_subgroups.size() * num_replicas);
-        for (int64_t i = 0; i < num_replicas; ++i) {
-          for (const auto& pgroup : partition_subgroups) {
+        if (partition_subgroups.size() <= 1) {
+          device_groups.reserve(num_replicas);
+          for (int64_t rid = 0; rid < num_replicas; ++rid) {
             device_groups.emplace_back();
-            for (int64_t pid : pgroup) {
-              device_groups.back().add_replica_ids(i * num_partitions + pid);
+            for (int64_t pid = 0; pid < num_partitions; ++pid) {
+              device_groups.back().add_replica_ids(rid * num_partitions + pid);
+            }
+          }
+        } else {
+          device_groups.reserve(partition_subgroups.size() * num_replicas);
+          for (int64_t rid = 0; rid < num_replicas; ++rid) {
+            for (const auto& pgroup : partition_subgroups) {
+              device_groups.emplace_back();
+              for (int64_t pid : pgroup) {
+                device_groups.back().add_replica_ids(rid * num_partitions +
+                                                     pid);
+              }
             }
           }
         }
+
         HloComputation* reduction_clone =
             reduction->parent()->AddComputationAndUnifyNamesAndIds(
                 reduction->Clone(), false);

From 85a5052bb80dab0deb648ce60445b16695d5ebc0 Mon Sep 17 00:00:00 2001
From: Jackson Stokes 
Date: Fri, 1 Dec 2023 11:59:57 -0800
Subject: [PATCH 296/381] [XLA:GPU] Fix crash in triton tiling of pad ops.

Pad ops are expected to be tiled in conjunction with a split-k transform. This
adds a check to reject tiling of triton softmax fusions with pad ops.
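
For context on why pad ops are expected only alongside split-k: the split-k rewrite tiles a dot's contracting dimension K into [split_k, K / split_k] and, as assumed here, first pads K up to the next multiple of split_k when it does not divide evenly, which is where these generated pad ops come from. A tiny self-contained illustration of just that padding arithmetic, not tied to any XLA API:

```
#include <cstdint>
#include <iostream>

// Rounds the contracting dimension K up to a multiple of split_k, i.e. the
// padded size needed before reshaping K into [split_k, K / split_k].
int64_t PaddedContractingDim(int64_t k, int64_t split_k) {
  return (k + split_k - 1) / split_k * split_k;
}

int main() {
  // Example: K = 127 with split_k = 4 is padded to 128 and then tiled as
  // [4, 32]; with split_k = 1 no padding is needed.
  std::cout << PaddedContractingDim(127, 4) << '\n';  // prints 128
  std::cout << PaddedContractingDim(127, 1) << '\n';  // prints 127
  return 0;
}
```
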
PiperOrigin-RevId: 587077239 --- .../gpu/triton_fusion_analysis_test.cc | 32 +++++++++++++++++++ .../service/gpu/triton_tiling_propagation.cc | 4 +++ 2 files changed, 36 insertions(+) diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index 17ac3b47cafcd3..918b5ac93c5251 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -706,6 +706,38 @@ ENTRY main { EXPECT_FALSE(analysis.ok()); } +TEST_F(TritonSoftmaxAnalysisTest, PadWithinTritonSoftmaxIsNotSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_1 = f32[4,127]{1,0} parameter(0) + constant_0 = f32[] constant(0) + reduce = f32[4]{0} reduce(param_1, constant_0), dimensions={1}, to_apply=add + broadcast = f32[4,127]{1,0} broadcast(reduce), dimensions={0} + ROOT pad = f32[8,127]{1,0} pad(broadcast, constant_0), padding=0_4x0_0 +} + +ENTRY main { + param_0 = f32[4,127]{1,0} parameter(0) + ROOT fusion = f32[8,127]{1,0} fusion(param_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto analysis = TritonFusionAnalysis::Execute(*computation); + EXPECT_FALSE(analysis.ok()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index b0f27751fa651d..240888d09736d2 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -842,6 +842,10 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); } else if (hlo.opcode() == HloOpcode::kPad) { + if (std::holds_alternative(properties)) { + return "Pad ops are only supported when they are generated as part of " + "the split-k transform of dot fusions."; + } if (direction != TransformDirection::kOutputToInput) { return "Unsupported pad direction."; } From 7ba21c6a89283b016899ef7e20e842fdbe94edc9 Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Fri, 1 Dec 2023 12:00:00 -0800 Subject: [PATCH 297/381] [HloValueSemanticsAnalysis] Add a number of missing handlers. 
PiperOrigin-RevId: 587077258 --- .../service/hlo_value_semantics_analysis.cc | 94 ++++++++++++++++++- .../service/hlo_value_semantics_analysis.h | 4 + 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index af50110b6d6d84..c94b63ef56c02c 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -457,6 +457,51 @@ Status EinsumDepthAnalysis::HandleOutfeed(HloInstruction* outfeed) { return OkStatus(); } +Status EinsumDepthAnalysis::HandleCollectivePermuteStart( + HloInstruction* collective_permute_start) { + auto depth_iter = einsum_depth_map_.find(collective_permute_start); + CHECK(depth_iter != einsum_depth_map_.end()); + const ShapeTree& depth_tree = depth_iter->second; + for (int operand_index = 0; + operand_index < collective_permute_start->operand_count(); + ++operand_index) { + HloInstruction* operand = + collective_permute_start->mutable_operand(operand_index); + if (operand_index >= 2) { + TF_RETURN_IF_ERROR(SetInstructionDepth(operand, GetMaxDepth(depth_tree))); + continue; + } + auto operand_depth_iter = GetOrCreateDepthTree(operand); + ShapeTree& operand_depth = operand_depth_iter->second; + SetDepthFromTupleDepth(operand_depth, depth_tree, 1); + } + return OkStatus(); +} + +Status EinsumDepthAnalysis::HandleCollectivePermuteDone( + HloInstruction* collective_permute_done) { + auto depth_iter = einsum_depth_map_.find(collective_permute_done); + CHECK(depth_iter != einsum_depth_map_.end()); + const ShapeTree& depth_tree = depth_iter->second; + auto operand_depth_iter = + GetOrCreateDepthTree(collective_permute_done->mutable_operand(0)); + ShapeTree& operand_depth = operand_depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + operand_depth.ForEachMutableElement([&operand_depth, &depth_tree, max_depth]( + const ShapeIndex& index, int* depth) { + if (!operand_depth.IsLeaf(index)) { + return; + } + if (index.front() == 0 || index.front() == 1) { + ShapeIndex output_index = index; + output_index.pop_front(); + *depth = depth_tree.element(output_index); + } + *depth = max_depth; + }); + return OkStatus(); +} + std::string HloValueSemanticLabelToString(HloValueSemanticLabel label) { switch (label) { case HloValueSemanticLabel::kStatic: @@ -1165,7 +1210,9 @@ Status HloValueSemanticsPropagation::HandleWhile(HloInstruction* xla_while) { Status HloValueSemanticsPropagation::HandleCustomCall( HloInstruction* custom_call) { - if (custom_call->custom_call_target() == "Sharding") { + if (custom_call->custom_call_target() == "Sharding" || + custom_call->custom_call_target() == "SPMDFullToShardShape" || + custom_call->custom_call_target() == "SPMDShardToFullShape") { const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(custom_call->operand(0)); analysis_->DeepCopyHloValueSemantics(custom_call, operand_semantics); @@ -1345,11 +1392,52 @@ Status HloValueSemanticsPropagation::HandleAfterAll(HloInstruction* after_all) { Status HloValueSemanticsPropagation::HandleAsyncStart( HloInstruction* async_start) { - return Unimplemented("AsyncStart is not supported yet."); + const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {async_start, {}}); + ShapeTree semantics_shape_tree(async_start->shape(), + semantics); + for (int operand_index = 0; operand_index < async_start->operand_count(); + ++operand_index) { + 
HloInstruction* operand = async_start->mutable_operand(operand_index); + const ShapeTree& operand_semantics_tree = + analysis_->GetInstructionSemantics(operand); + analysis_->DeepCopyHloValueSemantics( + semantics_shape_tree, operand_semantics_tree, {}, {0, operand_index}); + } + std::vector operand_indices(async_start->operand_count()); + std::iota(operand_indices.begin(), operand_indices.end(), 0); + TF_ASSIGN_OR_RETURN( + HloValueSemantics output_semantics, + ComputeSemanticsFromOperands(async_start, operand_indices)); + semantics_shape_tree.ForEachMutableElement( + [&output_semantics, &semantics_shape_tree, this, async_start]( + const ShapeIndex& index, const HloValueSemantics** semantics_ptr) { + if (index.empty() || index.front() == 0) { + return; + } + if (!semantics_shape_tree.IsLeaf(index)) { + *semantics_ptr = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {async_start, {}}); + return; + } + if (index.front() == 1) { + *semantics_ptr = AddSemantics(output_semantics); + return; + } + if (index.front() == 2) { + *semantics_ptr = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {async_start, {}}); + } + }); + analysis_->SetHloValueSemantics(async_start, semantics_shape_tree); + return OkStatus(); } Status HloValueSemanticsPropagation::HandleAsyncDone( HloInstruction* async_done) { - return Unimplemented("AsyncDone is not supported yet."); + const ShapeTree& operand_semantics_tree = + analysis_->GetInstructionSemantics(async_done->operand(0)); + analysis_->DeepCopyHloValueSemantics(async_done, operand_semantics_tree, {1}); + return OkStatus(); } Status HloValueSemanticsPropagation::HandleInfeed(HloInstruction* infeed) { diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index 895f0583390601..6b8084aaf9af9f 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -85,6 +85,10 @@ class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { Status HandleConditional(HloInstruction* conditional) override; Status HandleAfterAll(HloInstruction* after_all) override; Status HandleOutfeed(HloInstruction* outfeed) override; + Status HandleCollectivePermuteStart( + HloInstruction* collective_permute_start) override; + Status HandleCollectivePermuteDone( + HloInstruction* collective_permute_done) override; const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } private: From 50c5cfbb08fc677c3937e678fefc94935ca02a07 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Dec 2023 12:00:12 -0800 Subject: [PATCH 298/381] [TSL] Trim dependencies of the CUDA stubs. * Avoid depending on tsl::Env, since that links in many other parts of TSL. We don't need this indirection here anyway, since the users are other parts of TSL. * Use absl::Status directly, instead of depending on tsl/platform/errors.h and tsl/platform/status.h. Those dependencies link in protobuf, which has a noticeable binary size impact. However we don't need the TSL error helpers here: just use ABSL errors. These lack a handful of conveniences, but for these particular messages the extra context information isn't helpful and it's not worth the binary size impact. Reduces the jaxlib CUDA wheel size by around 4MB, since multiple libraries in JAX link in these stubs. 
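
The stub pattern these changes slim down is essentially "dlopen the vendor library once, resolve symbols on demand", which is why a thin symbol-loading helper is sufficient and the full tsl::Env machinery is not needed. A self-contained sketch of that pattern in plain POSIX terms follows; the library soname and symbol name are illustrative only, and the real stubs resolve the handle through tsl/platform/dso_loader.h and load symbols with tsl::internal::GetSymbolFromLibrary, as in the diffs below.

```
#include <dlfcn.h>

#include <cstdio>

// Lazily opens the shared library once and caches the handle; stays nullptr
// if the library is not installed on the system.
void* GetDsoHandle() {
  static void* handle = dlopen("libcublasLt.so", RTLD_LAZY | RTLD_LOCAL);
  return handle;
}

// Resolves a symbol from the cached handle, or returns nullptr.
void* LoadSymbol(const char* symbol_name) {
  void* handle = GetDsoHandle();
  return handle != nullptr ? dlsym(handle, symbol_name) : nullptr;
}

int main() {
  // Probe for a cuBLASLt entry point without linking against the library.
  if (LoadSymbol("cublasLtCreate") == nullptr) {
    std::printf("cublasLtCreate is not available\n");
  }
  return 0;
}
```
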
PiperOrigin-RevId: 587077318 --- .../xla/third_party/tsl/tsl/cuda/BUILD.bazel | 34 +++-- .../third_party/tsl/tsl/cuda/cublasLt_stub.cc | 6 +- .../third_party/tsl/tsl/cuda/cublas_stub.cc | 6 +- .../xla/third_party/tsl/tsl/cuda/cuda_stub.cc | 6 +- .../third_party/tsl/tsl/cuda/cudart_stub.cc | 7 +- .../third_party/tsl/tsl/cuda/cudnn_stub.cc | 6 +- .../third_party/tsl/tsl/cuda/cufft_stub.cc | 6 +- .../third_party/tsl/tsl/cuda/cupti_stub.cc | 6 +- .../third_party/tsl/tsl/cuda/cusolver_stub.cc | 6 +- .../third_party/tsl/tsl/cuda/cusparse_stub.cc | 6 +- .../xla/third_party/tsl/tsl/cuda/nccl_stub.cc | 6 +- .../tsl/tsl/platform/default/BUILD | 10 +- .../tsl/platform/default/dlopen_checker.cc | 31 ++-- .../platform/default/dlopen_checker_stub.cc | 6 +- .../tsl/tsl/platform/default/dso_loader.cc | 132 ++++++++++-------- .../tsl/tsl/platform/default/dso_loader.h | 88 ++++++------ .../tsl/tsl/platform/default/load_library.cc | 26 ++-- .../tsl/tsl/platform/load_library.h | 13 +- .../tsl/tsl/platform/windows/BUILD | 5 +- .../tsl/tsl/platform/windows/load_library.cc | 23 +-- 20 files changed, 222 insertions(+), 207 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel b/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel index 57597e207686ff..6ccfd7a019a3ce 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel +++ b/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel @@ -2,10 +2,6 @@ # Stubs for dynamically loading CUDA. load("//tsl/cuda:stub.bzl", "cuda_stub") -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -44,7 +40,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -67,7 +64,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -90,7 +88,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -122,7 +121,8 @@ cc_library( "//tsl:is_cuda_enabled_and_oss": [ ":cuda", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:load_library", + "//tsl/platform:logging", "@com_google_absl//absl/container:flat_hash_set", "@local_config_cuda//cuda:cuda_headers", ], @@ -151,7 +151,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@local_config_cuda//cuda:cudnn_header", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -187,7 +188,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -213,7 +215,8 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cupti_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -237,7 +240,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -261,7 +265,8 @@ cc_library( 
deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -287,6 +292,7 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_nccl//:nccl_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc index df4e73bebc126c..d078aa2f2c55ee 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO. @@ -33,8 +34,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc index 814d64d75d8d61..fe3cec911ca186 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc @@ -24,7 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuBLAS API by forwarding to cuBLAS loaded from the DSO. // Note that it does not implement the v1 interface. @@ -43,8 +44,7 @@ void *GetDsoHandle() { void *LoadSymbol(const char *symbol_name) { void *symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc index a199d4cc700442..298d493db97d15 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the CUDA driver API by forwarding to CUDA loaded from the DSO. 
@@ -36,8 +37,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc index a3797b5c751cd8..5ec2fabd84a712 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc @@ -21,7 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" namespace { void *GetDsoHandle() { @@ -39,8 +40,8 @@ void *GetDsoHandle() { void *LoadSymbol(const char *symbol_name) { void *symbol = nullptr; - auto env = tsl::Env::Default(); - env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError(); + tsl::internal::GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol) + .IgnoreError(); return symbol; } diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc index f3cab179eb0b71..1c85b1ea684a28 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "third_party/gpus/cudnn/cudnn.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuDNN API by forwarding to cuDNN loaded from the DSO. @@ -38,8 +39,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc index 8f5c1b0d687337..275560027af19b 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cufftXt.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuFFT API by forwarding to cuFFT loaded from the DSO. 
@@ -37,8 +38,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc index 9e632010d83a7a..aab8217aa3ebe5 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the CUPTI API by forwarding to CUPTI loaded from the DSO. @@ -38,8 +39,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc index d11601b3bd4217..418ce47311d718 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolverSp.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cusolver API by forwarding to cusolver loaded from the DSO. @@ -38,8 +39,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc index 16141e51e2613b..8b545cd0c1c1d8 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusparse.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cusparse API by forwarding to cusparse loaded from the DSO. 
@@ -37,8 +38,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc index 0ebae2f3c2b2eb..462ab127ee446b 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc @@ -18,7 +18,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/nccl/nccl.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the nccl API by forwarding to nccl loaded from a DSO. @@ -40,8 +41,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index 428af13748f2cb..aac69570b88c4f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -82,12 +82,11 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:env", - "//tsl/platform:errors", + "//tsl/platform:load_library", "//tsl/platform:logging", "//tsl/platform:path", - "//tsl/platform:status", - "//tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -247,8 +246,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", + "@com_google_absl//absl/status", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc index 2d67789d8a0017..eb8fff80bfb6ac 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc @@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tsl/platform/default/dso_loader.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tsl { namespace internal { namespace DsoLoader { -Status TryDlopenCUDALibraries() { +absl::Status TryDlopenCUDALibraries() { namespace CachedLoader = ::tsl::internal::CachedDsoLoader; auto cudart_status = CachedLoader::GetCudaRuntimeDsoHandle(); auto cublas_status = CachedLoader::GetCublasDsoHandle(); @@ -36,14 +35,14 @@ Status TryDlopenCUDALibraries() { !cufft_status.status().ok() || !cusolver_status.status().ok() || !cusparse_status.status().ok() || !cudnn_status.status().ok() || !cublaslt_status.status().ok()) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all CUDA libraries.")); + return absl::Status(absl::StatusCode::kInternal, + absl::StrCat("Cannot dlopen all CUDA libraries.")); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } -Status TryDlopenROCmLibraries() { +absl::Status TryDlopenROCmLibraries() { auto rocblas_status = GetRocblasDsoHandle(); auto miopen_status = GetMiopenDsoHandle(); auto rocfft_status = GetHipfftDsoHandle(); @@ -57,32 +56,30 @@ Status TryDlopenROCmLibraries() { || !hipblaslt_status.status().ok() #endif ) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all ROCm libraries.")); + return absl::InternalError("Cannot dlopen all ROCm libraries."); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } -Status MaybeTryDlopenGPULibraries() { +absl::Status MaybeTryDlopenGPULibraries() { #if GOOGLE_CUDA return TryDlopenCUDALibraries(); #elif TENSORFLOW_USE_ROCM return TryDlopenROCmLibraries(); #else LOG(INFO) << "Not built with GPU enabled. Skip GPU library dlopen check."; - return tsl::OkStatus(); + return absl::OkStatus(); #endif } -Status TryDlopenTensorRTLibraries() { +absl::Status TryDlopenTensorRTLibraries() { auto nvinfer_status = GetNvInferDsoHandle(); auto nvinferplugin_status = GetNvInferPluginDsoHandle(); if (!nvinfer_status.status().ok() || !nvinferplugin_status.status().ok()) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all TensorRT libraries.")); + return absl::InternalError("Cannot dlopen all TensorRT libraries."); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc index 1d4b213427b5a0..67f734302835d8 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "tsl/platform/default/dso_loader.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" namespace tsl { namespace internal { namespace DsoLoader { // Skip check when GPU libraries are statically linked. 
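Both the real checker above and the statically-linked stub below now hand back absl::Status directly, so callers no longer need anything from tsl/platform/status.h. A hypothetical call site (not part of this patch) only needs the absl types:

absl::Status status = tsl::internal::DsoLoader::MaybeTryDlopenGPULibraries();
if (!status.ok()) {
  LOG(WARNING) << "GPU libraries could not be dlopen'd: " << status;
}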
-Status MaybeTryDlopenGPULibraries() { +absl::Status MaybeTryDlopenGPULibraries() { LOG(INFO) << "GPU libraries are statically linked, skip dlopen check."; - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace DsoLoader } // namespace internal diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc index eeff5d9e7ed94d..a835a81489367a 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc @@ -16,17 +16,18 @@ limitations under the License. #include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "third_party/gpus/cuda/cuda_config.h" #include "third_party/nccl/nccl_config.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/load_library.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "third_party/tensorrt/tensorrt_config.h" #if TENSORFLOW_USE_ROCM @@ -37,22 +38,23 @@ namespace tsl { namespace internal { namespace { -string GetCudaVersion() { return TF_CUDA_VERSION; } -string GetCudaRtVersion() { return TF_CUDART_VERSION; } -string GetCuptiVersion() { return TF_CUPTI_VERSION; } -string GetCudnnVersion() { return TF_CUDNN_VERSION; } -string GetCublasVersion() { return TF_CUBLAS_VERSION; } -string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } -string GetCufftVersion() { return TF_CUFFT_VERSION; } -string GetCusparseVersion() { return TF_CUSPARSE_VERSION; } -string GetNcclVersion() { return TF_NCCL_VERSION; } -string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } - -StatusOr GetDsoHandle(const string& name, const string& version) { - auto filename = Env::Default()->FormatLibraryFileName(name, version); +std::string GetCudaVersion() { return TF_CUDA_VERSION; } +std::string GetCudaRtVersion() { return TF_CUDART_VERSION; } +std::string GetCuptiVersion() { return TF_CUPTI_VERSION; } +std::string GetCudnnVersion() { return TF_CUDNN_VERSION; } +std::string GetCublasVersion() { return TF_CUBLAS_VERSION; } +std::string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } +std::string GetCufftVersion() { return TF_CUFFT_VERSION; } +std::string GetCusparseVersion() { return TF_CUSPARSE_VERSION; } +std::string GetNcclVersion() { return TF_NCCL_VERSION; } +std::string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } + +absl::StatusOr GetDsoHandle(const std::string& name, + const std::string& version) { + auto filename = tsl::internal::FormatLibraryFileName(name, version); void* dso_handle; - Status status = - Env::Default()->LoadDynamicLibrary(filename.c_str(), &dso_handle); + absl::Status status = + tsl::internal::LoadDynamicLibrary(filename.c_str(), &dso_handle); if (status.ok()) { VLOG(1) << "Successfully opened dynamic library " << filename; return dso_handle; @@ -66,12 +68,12 @@ StatusOr GetDsoHandle(const string& name, const string& version) { } #endif VLOG(1) << message; - return Status(absl::StatusCode::kFailedPrecondition, message); + return absl::Status(absl::StatusCode::kFailedPrecondition, message); } } // namespace namespace DsoLoader { -StatusOr GetCudaDriverDsoHandle() { +absl::StatusOr GetCudaDriverDsoHandle() { #if defined(PLATFORM_WINDOWS) return GetDsoHandle("nvcuda", ""); #elif 
defined(__APPLE__) @@ -85,31 +87,31 @@ StatusOr GetCudaDriverDsoHandle() { return GetDsoHandle("cuda", "1"); } -StatusOr GetCudaRuntimeDsoHandle() { +absl::StatusOr GetCudaRuntimeDsoHandle() { return GetDsoHandle("cudart", GetCudaRtVersion()); } -StatusOr GetCublasDsoHandle() { +absl::StatusOr GetCublasDsoHandle() { return GetDsoHandle("cublas", GetCublasVersion()); } -StatusOr GetCublasLtDsoHandle() { +absl::StatusOr GetCublasLtDsoHandle() { return GetDsoHandle("cublasLt", GetCublasVersion()); } -StatusOr GetCufftDsoHandle() { +absl::StatusOr GetCufftDsoHandle() { return GetDsoHandle("cufft", GetCufftVersion()); } -StatusOr GetCusolverDsoHandle() { +absl::StatusOr GetCusolverDsoHandle() { return GetDsoHandle("cusolver", GetCusolverVersion()); } -StatusOr GetCusparseDsoHandle() { +absl::StatusOr GetCusparseDsoHandle() { return GetDsoHandle("cusparse", GetCusparseVersion()); } -StatusOr GetCuptiDsoHandle() { +absl::StatusOr GetCuptiDsoHandle() { // Load specific version of CUPTI this is built. auto status_or_handle = GetDsoHandle("cupti", GetCuptiVersion()); if (status_or_handle.ok()) return status_or_handle; @@ -117,15 +119,15 @@ StatusOr GetCuptiDsoHandle() { return GetDsoHandle("cupti", ""); } -StatusOr GetCudnnDsoHandle() { +absl::StatusOr GetCudnnDsoHandle() { return GetDsoHandle("cudnn", GetCudnnVersion()); } -StatusOr GetNcclDsoHandle() { +absl::StatusOr GetNcclDsoHandle() { return GetDsoHandle("nccl", GetNcclVersion()); } -StatusOr GetNvInferDsoHandle() { +absl::StatusOr GetNvInferDsoHandle() { #if defined(PLATFORM_WINDOWS) return GetDsoHandle("nvinfer", ""); #else @@ -133,7 +135,7 @@ StatusOr GetNvInferDsoHandle() { #endif } -StatusOr GetNvInferPluginDsoHandle() { +absl::StatusOr GetNvInferPluginDsoHandle() { #if defined(PLATFORM_WINDOWS) return GetDsoHandle("nvinfer_plugin", ""); #else @@ -141,134 +143,142 @@ StatusOr GetNvInferPluginDsoHandle() { #endif } -StatusOr GetRocblasDsoHandle() { return GetDsoHandle("rocblas", ""); } +absl::StatusOr GetRocblasDsoHandle() { + return GetDsoHandle("rocblas", ""); +} -StatusOr GetMiopenDsoHandle() { return GetDsoHandle("MIOpen", ""); } +absl::StatusOr GetMiopenDsoHandle() { + return GetDsoHandle("MIOpen", ""); +} -StatusOr GetHipfftDsoHandle() { return GetDsoHandle("hipfft", ""); } +absl::StatusOr GetHipfftDsoHandle() { + return GetDsoHandle("hipfft", ""); +} -StatusOr GetRocrandDsoHandle() { return GetDsoHandle("rocrand", ""); } +absl::StatusOr GetRocrandDsoHandle() { + return GetDsoHandle("rocrand", ""); +} -StatusOr GetRocsolverDsoHandle() { +absl::StatusOr GetRocsolverDsoHandle() { return GetDsoHandle("rocsolver", ""); } #if TF_ROCM_VERSION >= 40500 -StatusOr GetHipsolverDsoHandle() { +absl::StatusOr GetHipsolverDsoHandle() { return GetDsoHandle("hipsolver", ""); } #endif -StatusOr GetRoctracerDsoHandle() { +absl::StatusOr GetRoctracerDsoHandle() { return GetDsoHandle("roctracer64", ""); } -StatusOr GetHipsparseDsoHandle() { +absl::StatusOr GetHipsparseDsoHandle() { return GetDsoHandle("hipsparse", ""); } -StatusOr GetHipblasltDsoHandle() { +absl::StatusOr GetHipblasltDsoHandle() { return GetDsoHandle("hipblaslt", ""); } -StatusOr GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); } +absl::StatusOr GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); } } // namespace DsoLoader namespace CachedDsoLoader { -StatusOr GetCudaDriverDsoHandle() { +absl::StatusOr GetCudaDriverDsoHandle() { static auto result = new auto(DsoLoader::GetCudaDriverDsoHandle()); return *result; } -StatusOr GetCudaRuntimeDsoHandle() { +absl::StatusOr 
GetCudaRuntimeDsoHandle() { static auto result = new auto(DsoLoader::GetCudaRuntimeDsoHandle()); return *result; } -StatusOr GetCublasDsoHandle() { +absl::StatusOr GetCublasDsoHandle() { static auto result = new auto(DsoLoader::GetCublasDsoHandle()); return *result; } -StatusOr GetCublasLtDsoHandle() { +absl::StatusOr GetCublasLtDsoHandle() { static auto result = new auto(DsoLoader::GetCublasLtDsoHandle()); return *result; } -StatusOr GetCufftDsoHandle() { +absl::StatusOr GetCufftDsoHandle() { static auto result = new auto(DsoLoader::GetCufftDsoHandle()); return *result; } -StatusOr GetCusolverDsoHandle() { +absl::StatusOr GetCusolverDsoHandle() { static auto result = new auto(DsoLoader::GetCusolverDsoHandle()); return *result; } -StatusOr GetCusparseDsoHandle() { +absl::StatusOr GetCusparseDsoHandle() { static auto result = new auto(DsoLoader::GetCusparseDsoHandle()); return *result; } -StatusOr GetCuptiDsoHandle() { +absl::StatusOr GetCuptiDsoHandle() { static auto result = new auto(DsoLoader::GetCuptiDsoHandle()); return *result; } -StatusOr GetCudnnDsoHandle() { +absl::StatusOr GetCudnnDsoHandle() { static auto result = new auto(DsoLoader::GetCudnnDsoHandle()); return *result; } -StatusOr GetRocblasDsoHandle() { +absl::StatusOr GetRocblasDsoHandle() { static auto result = new auto(DsoLoader::GetRocblasDsoHandle()); return *result; } -StatusOr GetMiopenDsoHandle() { +absl::StatusOr GetMiopenDsoHandle() { static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); return *result; } -StatusOr GetHipfftDsoHandle() { +absl::StatusOr GetHipfftDsoHandle() { static auto result = new auto(DsoLoader::GetHipfftDsoHandle()); return *result; } -StatusOr GetRocrandDsoHandle() { +absl::StatusOr GetRocrandDsoHandle() { static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); return *result; } -StatusOr GetRoctracerDsoHandle() { +absl::StatusOr GetRoctracerDsoHandle() { static auto result = new auto(DsoLoader::GetRoctracerDsoHandle()); return *result; } -StatusOr GetRocsolverDsoHandle() { +absl::StatusOr GetRocsolverDsoHandle() { static auto result = new auto(DsoLoader::GetRocsolverDsoHandle()); return *result; } #if TF_ROCM_VERSION >= 40500 -StatusOr GetHipsolverDsoHandle() { +absl::StatusOr GetHipsolverDsoHandle() { static auto result = new auto(DsoLoader::GetHipsolverDsoHandle()); return *result; } #endif -StatusOr GetHipsparseDsoHandle() { +absl::StatusOr GetHipsparseDsoHandle() { static auto result = new auto(DsoLoader::GetHipsparseDsoHandle()); return *result; } -StatusOr GetHipblasltDsoHandle() { +absl::StatusOr GetHipblasltDsoHandle() { static auto result = new auto(DsoLoader::GetHipblasltDsoHandle()); return *result; } -StatusOr GetHipDsoHandle() { +absl::StatusOr GetHipDsoHandle() { static auto result = new auto(DsoLoader::GetHipDsoHandle()); return *result; } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h index ee5b2b28af3486..6f72484d504f53 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h @@ -19,8 +19,8 @@ limitations under the License. 
#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_DSO_LOADER_H_ #define TENSORFLOW_TSL_PLATFORM_DEFAULT_DSO_LOADER_H_ -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" namespace tsl { namespace internal { @@ -28,65 +28,65 @@ namespace internal { namespace DsoLoader { // The following methods either load the DSO of interest and return a dlopen // handle or error status. -StatusOr GetCudaDriverDsoHandle(); -StatusOr GetCudaRuntimeDsoHandle(); -StatusOr GetCublasDsoHandle(); -StatusOr GetCublasLtDsoHandle(); -StatusOr GetCufftDsoHandle(); -StatusOr GetCusolverDsoHandle(); -StatusOr GetCusparseDsoHandle(); -StatusOr GetCuptiDsoHandle(); -StatusOr GetCudnnDsoHandle(); -StatusOr GetNcclDsoHandle(); -StatusOr GetNvInferDsoHandle(); -StatusOr GetNvInferPluginDsoHandle(); +absl::StatusOr GetCudaDriverDsoHandle(); +absl::StatusOr GetCudaRuntimeDsoHandle(); +absl::StatusOr GetCublasDsoHandle(); +absl::StatusOr GetCublasLtDsoHandle(); +absl::StatusOr GetCufftDsoHandle(); +absl::StatusOr GetCusolverDsoHandle(); +absl::StatusOr GetCusparseDsoHandle(); +absl::StatusOr GetCuptiDsoHandle(); +absl::StatusOr GetCudnnDsoHandle(); +absl::StatusOr GetNcclDsoHandle(); +absl::StatusOr GetNvInferDsoHandle(); +absl::StatusOr GetNvInferPluginDsoHandle(); -StatusOr GetRocblasDsoHandle(); -StatusOr GetMiopenDsoHandle(); -StatusOr GetHipfftDsoHandle(); -StatusOr GetRocrandDsoHandle(); -StatusOr GetRoctracerDsoHandle(); -StatusOr GetRocsolverDsoHandle(); -StatusOr GetHipsolverDsoHandle(); -StatusOr GetHipsparseDsoHandle(); -StatusOr GetHipDsoHandle(); +absl::StatusOr GetRocblasDsoHandle(); +absl::StatusOr GetMiopenDsoHandle(); +absl::StatusOr GetHipfftDsoHandle(); +absl::StatusOr GetRocrandDsoHandle(); +absl::StatusOr GetRoctracerDsoHandle(); +absl::StatusOr GetRocsolverDsoHandle(); +absl::StatusOr GetHipsolverDsoHandle(); +absl::StatusOr GetHipsparseDsoHandle(); +absl::StatusOr GetHipDsoHandle(); // The following method tries to dlopen all necessary GPU libraries for the GPU // platform TF is built with (CUDA or ROCm) only when these libraries should be // dynamically loaded. Error status is returned when any of the libraries cannot // be dlopened. -Status MaybeTryDlopenGPULibraries(); +absl::Status MaybeTryDlopenGPULibraries(); // The following method tries to dlopen all necessary TensorRT libraries when // these libraries should be dynamically loaded. Error status is returned when // any of the libraries cannot be dlopened. -Status TryDlopenTensorRTLibraries(); +absl::Status TryDlopenTensorRTLibraries(); } // namespace DsoLoader // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs // more than once. namespace CachedDsoLoader { // Cached versions of the corresponding DsoLoader methods above. 
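As the definitions earlier in this diff show, every cached accessor memoizes its first result with `static auto result = new auto(DsoLoader::Get...DsoHandle());`: the StatusOr (handle or error) is computed once, intentionally never destroyed, and handed back to every later caller, which is how each DSO is kept from being dlopen'ed more than once. A reduced sketch of that idiom, with a stand-in Open() in place of the real loader calls:

#include "absl/status/statusor.h"

// Stand-in for DsoLoader::Get*DsoHandle(); assume it performs the dlopen.
absl::StatusOr<void*> Open() { return static_cast<void*>(nullptr); }

absl::StatusOr<void*> CachedOpen() {
  // Heap-allocated and never freed on purpose: the cached result survives
  // static destruction, and Open() is never retried after a failure.
  static auto result = new auto(Open());
  return *result;
}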
-StatusOr GetCudaDriverDsoHandle(); -StatusOr GetCudaRuntimeDsoHandle(); -StatusOr GetCublasDsoHandle(); -StatusOr GetCublasLtDsoHandle(); -StatusOr GetCufftDsoHandle(); -StatusOr GetCusolverDsoHandle(); -StatusOr GetCusparseDsoHandle(); -StatusOr GetCuptiDsoHandle(); -StatusOr GetCudnnDsoHandle(); +absl::StatusOr GetCudaDriverDsoHandle(); +absl::StatusOr GetCudaRuntimeDsoHandle(); +absl::StatusOr GetCublasDsoHandle(); +absl::StatusOr GetCublasLtDsoHandle(); +absl::StatusOr GetCufftDsoHandle(); +absl::StatusOr GetCusolverDsoHandle(); +absl::StatusOr GetCusparseDsoHandle(); +absl::StatusOr GetCuptiDsoHandle(); +absl::StatusOr GetCudnnDsoHandle(); -StatusOr GetRocblasDsoHandle(); -StatusOr GetMiopenDsoHandle(); -StatusOr GetHipfftDsoHandle(); -StatusOr GetRocrandDsoHandle(); -StatusOr GetRocsolverDsoHandle(); -StatusOr GetHipsolverDsoHandle(); -StatusOr GetRoctracerDsoHandle(); -StatusOr GetHipsparseDsoHandle(); -StatusOr GetHipblasltDsoHandle(); -StatusOr GetHipDsoHandle(); +absl::StatusOr GetRocblasDsoHandle(); +absl::StatusOr GetMiopenDsoHandle(); +absl::StatusOr GetHipfftDsoHandle(); +absl::StatusOr GetRocrandDsoHandle(); +absl::StatusOr GetRocsolverDsoHandle(); +absl::StatusOr GetHipsolverDsoHandle(); +absl::StatusOr GetRoctracerDsoHandle(); +absl::StatusOr GetHipsparseDsoHandle(); +absl::StatusOr GetHipblasltDsoHandle(); +absl::StatusOr GetHipDsoHandle(); } // namespace CachedDsoLoader } // namespace internal diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc b/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc index f49adf2f7f257d..70961c8dc990ef 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc @@ -17,26 +17,26 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include + +#include "absl/status/status.h" namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle) { +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) { *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL); if (!*handle) { // Note that in C++17 std::string_view(nullptr) gives segfault! const char* error_msg = dlerror(); - return tsl::errors::NotFound(error_msg ? error_msg - : "(null error message)"); + return absl::NotFoundError(error_msg ? error_msg : "(null error message)"); } - return OkStatus(); + return absl::OkStatus(); } -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) { +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { // Check that the handle is not NULL to avoid dlsym's RTLD_DEFAULT behavior. if (!handle) { *symbol = nullptr; @@ -46,14 +46,14 @@ Status GetSymbolFromLibrary(void* handle, const char* symbol_name, if (!*symbol) { // Note that in C++17 std::string_view(nullptr) gives segfault! const char* error_msg = dlerror(); - return tsl::errors::NotFound(error_msg ? error_msg - : "(null error message)"); + return absl::NotFoundError(error_msg ? 
error_msg : "(null error message)"); } - return OkStatus(); + return absl::OkStatus(); } -string FormatLibraryFileName(const string& name, const string& version) { - string filename; +std::string FormatLibraryFileName(const std::string& name, + const std::string& version) { + std::string filename; #if defined(__APPLE__) if (version.size() == 0) { filename = "lib" + name + ".dylib"; diff --git a/third_party/xla/third_party/tsl/tsl/platform/load_library.h b/third_party/xla/third_party/tsl/tsl/platform/load_library.h index e46f85da0a7f9a..5a42f2a3439fd0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/load_library.h +++ b/third_party/xla/third_party/tsl/tsl/platform/load_library.h @@ -16,16 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_LOAD_LIBRARY_H_ #define TENSORFLOW_TSL_PLATFORM_LOAD_LIBRARY_H_ -#include "tsl/platform/status.h" +#include + +#include "absl/status/status.h" namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle); -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol); -string FormatLibraryFileName(const string& name, const string& version); +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle); +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol); +std::string FormatLibraryFileName(const std::string& name, + const std::string& version); } // namespace internal diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD index 7ff0f110fe6722..2bde9eb95b3b73 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD @@ -1,10 +1,9 @@ -load("//tsl:tsl.default.bzl", "filegroup") - # Tensorflow windows-specific implementations of tensorflow/core/platform libraries. load( "//tsl:tsl.bzl", "tsl_copts", ) +load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -144,7 +143,7 @@ cc_library( deps = [ ":wide_char", "//tsl/platform:errors", - "//tsl/platform:status", + "@com_google_absl//absl/status", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc index 0c47532dc687a7..66d2d62cf6e130 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include #include -#include "tsl/platform/errors.h" +#include "absl/status/status.h" #include "tsl/platform/windows/wide_char.h" #pragma comment(lib, "Shlwapi.lib") @@ -37,8 +37,8 @@ namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle) { - string file_name = library_filename; +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) { + std::string file_name = library_filename; std::replace(file_name.begin(), file_name.end(), '/', '\\'); std::wstring ws_file_name(tsl::Utf8ToWideChar(file_name)); @@ -46,26 +46,27 @@ Status LoadDynamicLibrary(const char* library_filename, void** handle) { HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH); if (!hModule) { - return tsl::errors::NotFound(file_name + " not found"); + return absl::NotFoundError(file_name + " not found"); } *handle = hModule; - return OkStatus(); + return absl::OkStatus(); } -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) { +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { FARPROC found_symbol; found_symbol = GetProcAddress((HMODULE)handle, symbol_name); if (found_symbol == NULL) { - return tsl::errors::NotFound(std::string(symbol_name) + " not found"); + return absl::NotFoundError(std::string(symbol_name) + " not found"); } *symbol = (void**)found_symbol; - return OkStatus(); + return absl::OkStatus(); } -string FormatLibraryFileName(const string& name, const string& version) { - string filename; +std::string FormatLibraryFileName(const std::string& name, + const std::string& version) { + std::string filename; if (version.size() == 0) { filename = name + ".dll"; } else { From 7d6546aa26340e1aa7a64805064d68c1cd5eaf4a Mon Sep 17 00:00:00 2001 From: Grant Jensen Date: Fri, 1 Dec 2023 12:28:18 -0800 Subject: [PATCH 299/381] [tflite] Fix tflite selective_build_script. Was previously producing: `macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]` PiperOrigin-RevId: 587085882 --- .../xla/third_party/tsl/tsl/platform/default/subprocess.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc b/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc index d750328ebf38fd..c786295c08e0e9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc @@ -30,7 +30,11 @@ limitations under the License. #include "tsl/platform/logging.h" // Android versions older than 28 do not have posix_spawn(). -#define USE_POSIX_SPAWN !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 +#if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 +#define USE_POSIX_SPAWN 1 +#else // defined(__ANDROID_API__) && __ANDROID_API__ < 28 +#define USE_POSIX_SPAWN 0 +#endif // !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 // 1) FYI from m3b@ about fork(): // A danger of calling fork() (as opposed to clone() or vfork()) is that if From c1be00f2e75053924af0ac5cf91ed530dd0ffcb5 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Fri, 1 Dec 2023 12:40:56 -0800 Subject: [PATCH 300/381] Remove tensorrt dependency when installing tensorflow[and-cuda] tensorrt's odd packaging is causing install headaches (TF 2.15.0 currently cannot be installed because it depends on tensorrt_ packages). 
Also, even if it is installed correctly, TF doesn't seem to be able to use it (it says: could not find tensorrt). For the time being, our solution is to remove this dependency for the simple [and-cuda] install method. Note: tensorrt is packaged differently than other Nvidia packages. The "tensorrt" package, when installed, fetches two other packages from Nvidia's PyPI repository that are not present on the main PyPI repository. PiperOrigin-RevId: 587089327 --- tensorflow/tools/pip_package/setup.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 3615a69b7f6fde..57faa7fb4ae7f2 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -179,9 +179,6 @@ def standard_or_nightly(standard, nightly): 'nvidia-cusparse-cu12 == 12.1.2.141', 'nvidia-nccl-cu12 == 2.18.3', 'nvidia-nvjitlink-cu12 == 12.2.140', - 'tensorrt == 8.6.1.post1', - 'tensorrt-bindings == 8.6.1', - 'tensorrt-libs == 8.6.1', ] DOCLINES = __doc__.split('\n') From 1a771bcb749beb4cecb42c6b08478de5c0bfa581 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 1 Dec 2023 14:06:19 -0800 Subject: [PATCH 301/381] [xla:gpu] Add bf16 CUTLASS gemm kernel Changes for pattern matching will be in the follow up changes PiperOrigin-RevId: 587114037 --- third_party/xla/xla/service/gpu/kernels/BUILD | 17 +++++++++++++- .../kernels/cutlass_gemm_custom_kernel.cu.cc | 18 +++++++++------ .../gpu/kernels/cutlass_gemm_kernel.cu.h | 9 +++++++- .../gpu/kernels/cutlass_gemm_kernels.cu.h | 18 ++++++++++----- ...tlass_gemm_kernels_bf16xbf16_to_bf16.cu.cc | 22 +++++++++++++++++++ .../cutlass_gemm_kernels_f32xf32_to_f32.cu.cc | 2 +- 6 files changed, 70 insertions(+), 16 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_bf16xbf16_to_bf16.cu.cc diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 7b4742e54f549b..4e96b185ee0301 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -196,7 +196,10 @@ cc_library( # cuda_library( # name = "cutlass_gemm_kernels", # visibility = ["//visibility:private"], -# deps = [":cutlass_gemm_kernels_f32xf32_to_f32"], +# deps = [ +# ":cutlass_gemm_kernels_bf16xbf16_to_bf16", +# ":cutlass_gemm_kernels_f32xf32_to_f32", +# ], # ) # # cuda_library( @@ -209,6 +212,18 @@ cc_library( # ], # ) # +# # CUTLASS requires all loops to be unrolled, and in some kernels defined below we force Clang/LLVM +# # to unroll them with extra compiler options because by default LLVM is not as aggressive with loop +# # unrolling as NVCC. 
+# +# cuda_library( +# name = "cutlass_gemm_kernels_bf16xbf16_to_bf16", +# srcs = ["cutlass_gemm_kernels_bf16xbf16_to_bf16.cu.cc"], +# copts = ["-mllvm -unroll-threshold=100000"], +# visibility = ["//visibility:private"], +# deps = [":cutlass_gemm_kernels_header"], +# ) +# # cuda_library( # name = "cutlass_gemm_kernels_f32xf32_to_f32", # srcs = ["cutlass_gemm_kernels_f32xf32_to_f32.cu.cc"], diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc index 19ea111789330a..de8868fc497012 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc @@ -39,7 +39,7 @@ static StatusOr LoadCutlassGemmUniversal( auto packing = ArgsPacking(problem_size, indices, slices); se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing)); - spec.AddInProcessSymbol(GetCutlassGemmKernel(), "cutlass_gemm"); + spec.AddInProcessSymbol(GetKernelSymbol(), "cutlass_gemm"); return CustomKernel("cutlass_gemm", std::move(spec), BlockDim(problem_size), ThreadDim(), @@ -50,12 +50,16 @@ StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices) { - if (dtype != PrimitiveType::F32) - return absl::InvalidArgumentError( - "Currently cutlass gemm kernel supports only F32 data type"); - - return LoadCutlassGemmUniversal( - m, n, k, indices, slices); + switch (dtype) { + case PrimitiveType::F32: + return LoadCutlassGemmUniversal( + m, n, k, indices, slices); + case PrimitiveType::BF16: + return LoadCutlassGemmUniversal( + m, n, k, indices, slices); + default: + return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type"); + } } } // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h index 1721c731e05670..0979e656c0387e 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h @@ -141,6 +141,7 @@ template KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, const ArgsIndices &indices, const DynamicSliceIndices &slices) { + using Accumulator = typename Gemm::ElementAccumulator; using Arguments = typename Gemm::Arguments; using Kernel = typename Gemm::GemmKernel; using Params = typename Kernel::Params; @@ -171,7 +172,13 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, auto ptr_c = ArgPtr(mem_args, indices); auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - float alpha = 1.0, beta = 0.0; + + // TODO(ezhulenev): We hardcode parameters for `LinearCombination` epilogue, + // however `Gemm` template can be compiled with arbitrary epilogues. We have + // to support custom epilogues in a way that does not leak cutlass types + // via the public API function signature. + Accumulator alpha{1.0}; + Accumulator beta{0.0}; // CUTLASS operation arguments. Arguments arguments(mode, problem_size, diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h index c2471f1ceb39d3..fbd5715b564a75 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h @@ -22,10 +22,15 @@ limitations under the License. 
namespace xla::gpu::kernel::gemm_universal { struct CutlassGemmKernels { - using F32xF32toF32 = - cutlass::gemm::device::GemmUniversal; + using F32xF32toF32 = cutlass::gemm::device::GemmUniversal< + float, cutlass::layout::RowMajor, // A + float, cutlass::layout::RowMajor, // B + float, cutlass::layout::RowMajor>; // C + + using BF16xBF16toBF16 = cutlass::gemm::device::GemmUniversal< + cutlass::bfloat16_t, cutlass::layout::RowMajor, // A + cutlass::bfloat16_t, cutlass::layout::RowMajor, // B + cutlass::bfloat16_t, cutlass::layout::RowMajor>; // C }; // This entry point is based on `cutlass::Kernel2` template with an extra @@ -55,12 +60,13 @@ __global__ void Kernel(typename Gemm::Params params, } template -void* GetCutlassGemmKernel() { +void* GetKernelSymbol() { return reinterpret_cast(Kernel); } // Extern templates for all supported CUTLASS Gemm kernels. -extern template void* GetCutlassGemmKernel(); +extern template void* GetKernelSymbol(); +extern template void* GetKernelSymbol(); } // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_bf16xbf16_to_bf16.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_bf16xbf16_to_bf16.cu.cc new file mode 100644 index 00000000000000..0373d222489540 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_bf16xbf16_to_bf16.cu.cc @@ -0,0 +1,22 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +template void* GetKernelSymbol(); + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc index b4da3d222a6370..85ad96bbfbbfa7 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels_f32xf32_to_f32.cu.cc @@ -17,6 +17,6 @@ limitations under the License. namespace xla::gpu::kernel::gemm_universal { -template void* GetCutlassGemmKernel(); +template void* GetKernelSymbol(); } // namespace xla::gpu::kernel::gemm_universal From 5252fa4a441c74e872b013a0ba1e448d3ed1abb7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 14:08:00 -0800 Subject: [PATCH 302/381] #tf-data Modify the in-memory symbolic checkpointing for Shuffle Op to only write the tensors changed in buffer instead of writing the whole buffer to memory checkpoint after every GetNextInternal call. 
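Concretely, the iterator now tracks which buffer slots have changed since the previous save and, when symbolic checkpointing is enabled, rewrites only those slots; in the diff below this is the new `checkpoint_indices_` set in the shuffle iterator plus the `UpdateCheckpointElements()` helper in serialization_utils. A self-contained sketch of the idea using standard containers only (not the real tf.data types):

#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// `buffer` models the shuffle buffer, `checkpoint` the serialized state,
// and `dirty` the slots touched since the last save.
void OnBufferWrite(int64_t index, std::string value,
                   std::vector<std::string>& buffer,
                   std::unordered_set<int64_t>& dirty) {
  buffer[index] = std::move(value);
  dirty.insert(index);  // this slot must be re-checkpointed
}

void SaveIncremental(const std::vector<std::string>& buffer,
                     std::unordered_set<int64_t>& dirty,
                     std::unordered_map<int64_t, std::string>& checkpoint) {
  for (int64_t i : dirty) {
    checkpoint[i] = buffer[i];  // rewrite only what changed
  }
  dirty.clear();  // the next save starts from a clean slate
}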
PiperOrigin-RevId: 587114690 --- tensorflow/core/data/BUILD | 4 +- tensorflow/core/data/serialization_utils.cc | 38 +++++++--- tensorflow/core/data/serialization_utils.h | 9 +++ .../core/data/serialization_utils_test.cc | 72 +++++++++++++++++++ tensorflow/core/kernels/data/BUILD | 4 +- .../core/kernels/data/shuffle_dataset_op.cc | 38 +++++++++- 6 files changed, 153 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 2c21b45cb6f2b9..171a6fa4d88e49 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_not_mobile", @@ -8,6 +7,7 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_protos_all", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -393,7 +393,9 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:stringpiece", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/tensorflow/core/data/serialization_utils.cc b/tensorflow/core/data/serialization_utils.cc index e07ec49b9137de..01b5a1289e1257 100644 --- a/tensorflow/core/data/serialization_utils.cc +++ b/tensorflow/core/data/serialization_utils.cc @@ -14,12 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/serialization_utils.h" +#include #include #include #include #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/data/compression_utils.h" @@ -30,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { namespace data { @@ -118,20 +121,39 @@ Status ReadElementsFromCheckpoint(IteratorContext* ctx, return OkStatus(); } +Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + int64_t index) { + const std::vector& element = elements[index]; + std::string element_prefix = absl::StrCat(key_prefix, "::", index); + TF_RETURN_IF_ERROR( + writer->WriteScalar(element_prefix, kNumComponents, element.size())); + for (int j = 0; j < element.size(); ++j) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j])); + } + return OkStatus(); +} + Status WriteElementsToCheckpoint( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements) { TF_RETURN_IF_ERROR( writer->WriteScalar(key_prefix, kNumElements, elements.size())); for (int i = 0; i < elements.size(); ++i) { - const std::vector& element = elements[i]; - std::string element_prefix = absl::StrCat(key_prefix, "::", i); - TF_RETURN_IF_ERROR( - writer->WriteScalar(element_prefix, kNumComponents, element.size())); - for (int j = 0; j < elements[i].size(); ++j) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j])); - } + TF_RETURN_IF_ERROR(WriteElement(writer, key_prefix, elements, i)); + } + return OkStatus(); +} + +Status UpdateCheckpointElements( + IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + const absl::flat_hash_set& checkpoint_indices) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(key_prefix, kNumElements, elements.size())); + for (int64_t i : checkpoint_indices) { + TF_RETURN_IF_ERROR(WriteElement(writer, key_prefix, elements, i)); } return OkStatus(); } diff --git a/tensorflow/core/data/serialization_utils.h b/tensorflow/core/data/serialization_utils.h index d5e83c32eb488f..b55dfdfb7eca8c 100644 --- a/tensorflow/core/data/serialization_utils.h +++ b/tensorflow/core/data/serialization_utils.h @@ -47,6 +47,15 @@ Status WriteElementsToCheckpoint( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements); +// Updates the dataset elements in the checkpoint for given `checkpoint_indices` +// using the given key prefix, assuming that vector of elements have +// checkpointed these before. The elements can be read back by passing the same +// key prefix to ReadElementsFromCheckpoint. +Status UpdateCheckpointElements( + IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + const absl::flat_hash_set& checkpoint_indices); + // Helper class for reading data from a vector of VariantTensorData objects. class VariantTensorDataReader : public IteratorStateReader { public: diff --git a/tensorflow/core/data/serialization_utils_test.cc b/tensorflow/core/data/serialization_utils_test.cc index ddd424c519841c..5de7acfdc30f53 100644 --- a/tensorflow/core/data/serialization_utils_test.cc +++ b/tensorflow/core/data/serialization_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/data/serialization_utils.h" +#include #include #include #include @@ -203,6 +204,16 @@ class ParameterizedIteratorStateVariantTest } }; +class ParemeterizedCheckpointIndicesTest + : public DatasetOpsTestBase, + public ::testing::WithParamInterface> { + protected: + absl::flat_hash_set GetCheckpointIndices() const { + absl::flat_hash_set checkpoint_indices = GetParam(); + return checkpoint_indices; + } +}; + std::vector> TestCases() { return { CreateTensors(TensorShape{1}, {{1}}), // int64 @@ -216,6 +227,18 @@ std::vector> TestCases() { }; } +std::vector> CheckpointIndicesTestCases() { + return { + {/*checkpoint_indices*/}, + {/*checkpoint_indices*/ 0}, + {/*checkpoint_indices*/ 0, 1}, + {/*checkpoint_indices*/ 0, 1, 2}, + {/*checkpoint_indices*/ 1, 3, 4}, + {/*checkpoint_indices*/ 1, 2, 3, 4}, + {/*checkpoint_indices*/ 0, 1, 2, 3, 4}, + }; +} + TEST_P(ParameterizedIteratorStateVariantTest, EncodeAndDecode) { VariantTensorData data = GetVariantTensorData(); TF_ASSERT_OK_AND_ASSIGN(VariantTensorData result, EncodeAndDecode(data)); @@ -236,9 +259,58 @@ TEST_P(ParameterizedIteratorStateVariantTest, DecodeUncompressed) { } } +TEST_P(ParemeterizedCheckpointIndicesTest, + CheckpointElementsRoundTripUsingIndices) { + std::vector> elements; + elements.push_back(CreateTensors(TensorShape({3}), {{1, 2, 3}})); + elements.push_back(CreateTensors(TensorShape({2}), {{4, 5}})); + elements.push_back( + CreateTensors(TensorShape({5}), {{6, 7, 8, 9, 10}})); + elements.push_back( + CreateTensors(TensorShape({4}), {{11, 12, 13, 14}})); + elements.push_back(CreateTensors(TensorShape({2}), {{15, 16}})); + VariantTensorDataWriter writer; + tstring test_prefix = full_name("test_prefix"); + // Generate checkpoint for entire buffer + absl::flat_hash_set checkpoint_indices_write = {0, 1, 2, 3, 4}; + TF_ASSERT_OK(WriteElementsToCheckpoint(&writer, test_prefix, elements)); + // Update the elements at checkpoint indices + for (auto index : GetCheckpointIndices()) { + elements.at(index) = CreateTensors(TensorShape({1}), {{1}}); + } + TF_ASSERT_OK(UpdateCheckpointElements(&writer, test_prefix, elements, + GetCheckpointIndices())); + std::vector data; + writer.GetData(&data); + + VariantTensorDataReader reader(data); + std::vector> read_elements; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ctx, + TestContext::Create()); + TF_ASSERT_OK(ReadElementsFromCheckpoint(ctx->iter_ctx(), &reader, test_prefix, + &read_elements)); + + ASSERT_EQ(elements.size(), read_elements.size()); + // Check if checkpoint state of entire buffer is as expected + for (int index = 0; index < elements.size(); ++index) { + std::vector& original = elements[index]; + std::vector& read = read_elements[index]; + + ASSERT_EQ(original.size(), read.size()); + for (int j = 0; j < original.size(); ++j) { + EXPECT_EQ(original[j].NumElements(), read[j].NumElements()); + EXPECT_EQ(original[j].flat()(0), read[j].flat()(0)); + } + } +} + INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedIteratorStateVariantTest, ::testing::ValuesIn(TestCases())); +INSTANTIATE_TEST_SUITE_P(Instantiation, ParemeterizedCheckpointIndicesTest, + ::testing::ValuesIn(CheckpointIndicesTestCases())); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 06f1bf2bb3c531..74eb029a9aedc1 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1,11 +1,11 @@ # Description: # OpKernels for tf.data 
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") # Definitions are loaded separately so that copybara can pattern match (and modify) each definition. load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_kernel_library") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -1172,7 +1172,9 @@ tf_kernel_library( "//tensorflow/core/data:dataset_utils", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:serialization_utils", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 5143182f0b1a90..cb2c28dbf1ea58 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/serialization_utils.h" @@ -203,6 +205,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { mutex_lock l(mu_); seed_generator_->GenerateSeeds(&seed_, &seed2_); ResetRngs(); + // Initialize checkpoint_indices_ to the entire buffer. + if (ctx->symbolic_checkpoint()) { + for (int64_t i = 0; i < buffer_->size(); ++i) { + checkpoint_indices_.insert(i); + } + } return OkStatus(); } @@ -229,6 +237,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { this->RecordBufferDequeue(ctx, *out_tensors); std::swap(buffer_->at(index), buffer_->at(slices_.front()->start % buffer_->size())); + checkpoint_indices_.insert(index); + checkpoint_indices_.insert(slices_.front()->start % buffer_->size()); slices_.front()->start++; num_elements_--; return OkStatus(); @@ -273,8 +283,20 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kEpoch, epoch_)); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kNumElements, num_elements_)); - TF_RETURN_IF_ERROR(WriteElementsToCheckpoint( - writer, absl::StrCat(prefix(), kColon, "buffer"), *buffer_)); + const std::string key_prefix = absl::StrCat(prefix(), kColon, "buffer"); + if (ctx->symbolic_checkpoint()) { + // When symbolic checkpointing is turned on, `writer` + // already contains checkpoint of the shuffle buffer created by the + // previous invocation of this instance and the indices that need to be + // updated are stored in `checkpoint_indices`. 
+ TF_RETURN_IF_ERROR(UpdateCheckpointElements( + writer, key_prefix, *buffer_, checkpoint_indices_)); + checkpoint_indices_.clear(); + } else { + TF_RETURN_IF_ERROR( + WriteElementsToCheckpoint(writer, key_prefix, *buffer_)); + } + TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kSlicesSize, slices_.size())); for (size_t i = 0; i < slices_.size(); ++i) { @@ -339,6 +361,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(ReadElementsFromCheckpoint( ctx, reader, absl::StrCat(prefix(), kColon, "buffer"), buffer_.get())); + if (ctx->symbolic_checkpoint()) { + DCHECK(checkpoint_indices_.empty()); + for (size_t i = 0; i < buffer_->size(); ++i) { + checkpoint_indices_.insert(i); + } + } for (const auto& element : *buffer_) { RecordBufferEnqueue(ctx, element); } @@ -502,9 +530,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { this->RecordBufferEnqueue(ctx, element); if (num_elements_ == buffer_->size()) { DCHECK(IsShuffleAll()); + checkpoint_indices_.insert(buffer_->size()); buffer_->push_back(element); } else { size_t index = slices_.back()->end % buffer_->size(); + checkpoint_indices_.insert(index); buffer_->at(index) = std::move(element); } num_elements_++; @@ -530,6 +560,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_); // Not owned. std::unique_ptr>> buffer_ TF_GUARDED_BY(mu_); + // Holds the indices of `buffer_` that have changed since the previous + // `SaveInternal()` and need to be updated in the MemoryCheckpoint + // (if symbolic checkpointing is used) in the next `SaveInternal()`. + absl::flat_hash_set checkpoint_indices_ TF_GUARDED_BY(mu_); std::unique_ptr input_impl_ TF_GUARDED_BY(mu_) = nullptr; int64_t epoch_ TF_GUARDED_BY(mu_) = 0; int64_t num_elements_ TF_GUARDED_BY(mu_) = 0; From a93819e78c9031fde665eb47033bef86ea35c02b Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Fri, 1 Dec 2023 14:39:40 -0800 Subject: [PATCH 303/381] [HloValueSemanticsAnalysis] Add handlers for async host send recv. PiperOrigin-RevId: 587125079 --- third_party/xla/xla/service/BUILD | 1 + .../service/hlo_value_semantics_analysis.cc | 262 +++++++++++++++++- .../service/hlo_value_semantics_analysis.h | 76 ++++- .../hlo_value_semantics_analysis_test.cc | 4 +- 4 files changed, 336 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 5917512e839f1a..18c5db752de9de 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4400,6 +4400,7 @@ cc_library( ":hlo_value", "//xla:shape_tree", "//xla:shape_util", + "//xla:side_effect_util", "//xla:status", "//xla:statusor", "//xla:util", diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index c94b63ef56c02c..cc407ea2bea75b 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -51,6 +51,34 @@ limitations under the License. 
namespace xla { +extern const char kXlaHostTransferRendezvousNameAttr[]; + +SendRecvGroupMap CreateSendRecvGroupMap(const HloModule& hlo_module) { + SendRecvGroupMap send_recv_group_map; + for (HloComputation* computation : hlo_module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kSend && + instruction->opcode() != HloOpcode::kRecv) { + continue; + } + std::string rendezvous = instruction->frontend_attributes().map().at( + kXlaHostTransferRendezvousNameAttr); + auto send_recv_iter = send_recv_group_map.find(rendezvous); + if (send_recv_iter == send_recv_group_map.end()) { + auto insert_success = send_recv_group_map.insert( + {rendezvous, SendRecvGroup{nullptr, nullptr}}); + send_recv_iter = insert_success.first; + } + if (instruction->opcode() == HloOpcode::kSend) { + send_recv_iter->second.send = instruction; + } else { + send_recv_iter->second.recv = instruction; + } + } + } + return send_recv_group_map; +} + bool HloPreOrderDFS::IsReady(const HloInstruction* instruction) const { for (HloInstruction* user : instruction->users()) { if (!visited_.contains(user)) { @@ -73,6 +101,24 @@ std::vector GetAllInstructionsWithZeroUsers( return results; } +StatusOr GetMatchingSendOrRecvFromMap( + HloInstruction* send_or_recv, const SendRecvGroupMap& send_recv_group_map) { + if (send_or_recv->opcode() != HloOpcode::kSend && + send_or_recv->opcode() != HloOpcode::kRecv) { + return InvalidArgument("Expecting only send or recv"); + } + std::string rendezvous = send_or_recv->frontend_attributes().map().at( + kXlaHostTransferRendezvousNameAttr); + auto send_recv_iter = send_recv_group_map.find(rendezvous); + if (send_recv_iter == send_recv_group_map.end()) { + return InternalError("Missing send or recv from send recv group."); + } + if (send_or_recv->opcode() == HloOpcode::kSend) { + return send_recv_iter->second.recv; + } + return send_recv_iter->second.send; +} + } // namespace Status HloPreOrderDFS::Run(const HloComputation& computation, @@ -124,8 +170,10 @@ Status EinsumDepthAnalysis::RunInternal( } StatusOr> EinsumDepthAnalysis::Run( - const HloComputation& computation) { - EinsumDepthAnalysis* analysis_ptr = new EinsumDepthAnalysis(); + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map) { + EinsumDepthAnalysis* analysis_ptr = + new EinsumDepthAnalysis(send_recv_group_map); std::unique_ptr analysis(analysis_ptr); TF_RETURN_IF_ERROR(analysis->RunInternal(computation, std::nullopt)); return analysis; @@ -502,6 +550,69 @@ Status EinsumDepthAnalysis::HandleCollectivePermuteDone( return OkStatus(); } +Status EinsumDepthAnalysis::HandleSend(HloInstruction* send) { + auto depth_iter = GetOrCreateDepthTree(send); + const ShapeTree& depth_tree = depth_iter->second; + HloInstruction* send_buffer = send->mutable_operand(0); + auto send_buffer_depth_iter = GetOrCreateDepthTree(send_buffer); + ShapeTree& send_buffer_depth = send_buffer_depth_iter->second; + SetDepthFromTupleDepth(send_buffer_depth, depth_tree, 0); + int max_depth = GetMaxDepth(depth_tree); + HloInstruction* token = send->mutable_operand(1); + return SetInstructionDepth(token, max_depth); +} + +Status EinsumDepthAnalysis::HandleRecv(HloInstruction* recv) { + auto depth_iter = GetOrCreateDepthTree(recv); + const ShapeTree& depth_tree = depth_iter->second; + TF_ASSIGN_OR_RETURN(HloInstruction * send, + GetMatchingSendOrRecvFromMap(recv, send_recv_group_map_)); + auto send_depth_iter = GetOrCreateDepthTree(send); + ShapeTree& send_depth = 
send_depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + send_depth.ForEachMutableElement([&depth_tree, &send_depth, max_depth]( + const ShapeIndex& index, int* depth) { + if (!send_depth.IsLeaf(index)) { + return; + } + if (index.front() == 0) { + *depth = MergeDepth(*depth, depth_tree.element(index)); + return; + } + *depth = MergeDepth(*depth, max_depth); + }); + return OkStatus(); +} + +Status EinsumDepthAnalysis::HandleSendDone(HloInstruction* send_done) { + HloInstruction* send = send_done->mutable_operand(0); + auto depth_iter = GetOrCreateDepthTree(send_done); + const ShapeTree& depth_tree = depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + return SetInstructionDepth(send, max_depth); +} + +Status EinsumDepthAnalysis::HandleRecvDone(HloInstruction* recv_done) { + auto depth_iter = GetOrCreateDepthTree(recv_done); + const ShapeTree& depth_tree = depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + HloInstruction* recv = recv_done->mutable_operand(0); + auto recv_depth_iter = GetOrCreateDepthTree(recv); + ShapeTree& recv_depth = recv_depth_iter->second; + recv_depth.ForEachMutableElement([&depth_tree, &recv_depth, max_depth]( + const ShapeIndex& index, int* depth) { + if (!recv_depth.IsLeaf(index)) { + return; + } + if (index.front() == 0) { + *depth = MergeDepth(*depth, depth_tree.element(index)); + return; + } + *depth = MergeDepth(*depth, max_depth); + }); + return OkStatus(); +} + std::string HloValueSemanticLabelToString(HloValueSemanticLabel label) { switch (label) { case HloValueSemanticLabel::kStatic: @@ -549,6 +660,7 @@ StatusOr> HloValueSemanticsAnalysis::Run(const HloModule& module) { std::unique_ptr value_semantics_analysis = absl::WrapUnique(new HloValueSemanticsAnalysis(module)); + value_semantics_analysis->InitializeSendRecvGroups(); TF_RETURN_IF_ERROR(value_semantics_analysis->InitializeEinsumDepth()); value_semantics_analysis->AnnotateWeights(); TF_RETURN_IF_ERROR( @@ -559,11 +671,26 @@ HloValueSemanticsAnalysis::Run(const HloModule& module) { Status HloValueSemanticsAnalysis::InitializeEinsumDepth() { TF_ASSIGN_OR_RETURN( std::unique_ptr einsum_depth_analysis, - EinsumDepthAnalysis::Run(*module_.entry_computation())); + EinsumDepthAnalysis::Run(*module_.entry_computation(), + send_recv_group_map_)); einsum_depth_map_ = einsum_depth_analysis->GetEinsumDepthMap(); return OkStatus(); } +void HloValueSemanticsAnalysis::InitializeSendRecvGroups() { + send_recv_group_map_ = CreateSendRecvGroupMap(module_); +} + +bool HloValueSemanticsAnalysis::HasSemanticsFor( + const HloInstruction* instruction) const { + return value_semantics_.contains(instruction); +} + +StatusOr HloValueSemanticsAnalysis::GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const { + return GetMatchingSendOrRecvFromMap(send_or_recv, send_recv_group_map_); +} + HloValueSemantics::Id HloValueSemanticsAnalysis::NextId() { return next_id_++; } const HloValueSemantics* HloValueSemanticsAnalysis::NewHloValueSemantics( @@ -1065,8 +1192,14 @@ HloValueSemanticsPropagation::ComputeSemanticsFromOperands( return semantics_vec.back(); } +#define RETURN_IF_ALREADY_PROPAGATED(instruction) \ + if (analysis_->HasSemanticsFor(instruction)) { \ + return OkStatus(); \ + } + Status HloValueSemanticsPropagation::DefaultAction( HloInstruction* instruction) { + RETURN_IF_ALREADY_PROPAGATED(instruction); std::vector operand_indices(instruction->operand_count()); std::iota(operand_indices.begin(), operand_indices.end(), 0); TF_ASSIGN_OR_RETURN( @@ -1085,6 +1218,7 @@ Status 
HloValueSemanticsPropagation::HandleParameter( } Status HloValueSemanticsPropagation::HandleConstant(HloInstruction* constant) { + RETURN_IF_ALREADY_PROPAGATED(constant); const HloValueSemantics* constant_semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {constant, {}}); ShapeTree semantics_shape_tree(constant->shape(), @@ -1094,6 +1228,7 @@ Status HloValueSemanticsPropagation::HandleConstant(HloInstruction* constant) { } Status HloValueSemanticsPropagation::HandleIota(HloInstruction* iota) { + RETURN_IF_ALREADY_PROPAGATED(iota); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {iota, {}}); ShapeTree semantics_shape_tree(iota->shape(), @@ -1104,6 +1239,7 @@ Status HloValueSemanticsPropagation::HandleIota(HloInstruction* iota) { Status HloValueSemanticsPropagation::HandlePartitionId( HloInstruction* partition_id) { + RETURN_IF_ALREADY_PROPAGATED(partition_id); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {partition_id, {}}); ShapeTree semantics_shape_tree( @@ -1113,6 +1249,7 @@ Status HloValueSemanticsPropagation::HandlePartitionId( } Status HloValueSemanticsPropagation::HandleReplicaId( HloInstruction* replica_id) { + RETURN_IF_ALREADY_PROPAGATED(replica_id); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {replica_id, {}}); ShapeTree semantics_shape_tree(replica_id->shape(), @@ -1132,6 +1269,7 @@ Status HloValueSemanticsPropagation::HandleRngBitGenerator( } Status HloValueSemanticsPropagation::HandleClamp(HloInstruction* clamp) { + RETURN_IF_ALREADY_PROPAGATED(clamp); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(clamp->operand(1)); analysis_->DeepCopyHloValueSemantics(clamp, operand_semantics); @@ -1139,6 +1277,7 @@ Status HloValueSemanticsPropagation::HandleClamp(HloInstruction* clamp) { } Status HloValueSemanticsPropagation::HandleTuple(HloInstruction* tuple) { + RETURN_IF_ALREADY_PROPAGATED(tuple); ShapeTree semantics_shape_tree(tuple->shape(), nullptr); for (int operand_index = 0; operand_index < tuple->operand_count(); @@ -1164,6 +1303,7 @@ Status HloValueSemanticsPropagation::HandleTuple(HloInstruction* tuple) { Status HloValueSemanticsPropagation::HandleGetTupleElement( HloInstruction* get_tuple_element) { + RETURN_IF_ALREADY_PROPAGATED(get_tuple_element); const HloInstruction* tuple = get_tuple_element->operand(0); int64_t tuple_index = get_tuple_element->tuple_index(); const ShapeTree& tuple_semantics = @@ -1177,6 +1317,7 @@ Status HloValueSemanticsPropagation::HandleGetTupleElement( } Status HloValueSemanticsPropagation::HandleCall(HloInstruction* call) { + RETURN_IF_ALREADY_PROPAGATED(call); HloComputation* computation = call->called_computations()[0]; TF_RETURN_IF_ERROR( analysis_->RunOnComputation(*computation, call->operands())); @@ -1187,6 +1328,7 @@ Status HloValueSemanticsPropagation::HandleCall(HloInstruction* call) { } Status HloValueSemanticsPropagation::HandleFusion(HloInstruction* fusion) { + RETURN_IF_ALREADY_PROPAGATED(fusion); HloComputation* computation = fusion->called_computations()[0]; TF_RETURN_IF_ERROR( analysis_->RunOnComputation(*computation, fusion->operands())); @@ -1197,6 +1339,7 @@ Status HloValueSemanticsPropagation::HandleFusion(HloInstruction* fusion) { } Status HloValueSemanticsPropagation::HandleWhile(HloInstruction* xla_while) { + RETURN_IF_ALREADY_PROPAGATED(xla_while); 
TF_RETURN_IF_ERROR(analysis_->RunOnComputation(*xla_while->while_condition(), xla_while->operands())); HloComputation* computation = xla_while->while_body(); @@ -1210,6 +1353,7 @@ Status HloValueSemanticsPropagation::HandleWhile(HloInstruction* xla_while) { Status HloValueSemanticsPropagation::HandleCustomCall( HloInstruction* custom_call) { + RETURN_IF_ALREADY_PROPAGATED(custom_call); if (custom_call->custom_call_target() == "Sharding" || custom_call->custom_call_target() == "SPMDFullToShardShape" || custom_call->custom_call_target() == "SPMDShardToFullShape") { @@ -1224,6 +1368,7 @@ Status HloValueSemanticsPropagation::HandleCustomCall( Status HloValueSemanticsPropagation::HandleConditional( HloInstruction* conditional) { + RETURN_IF_ALREADY_PROPAGATED(conditional); for (int i = 0; i < conditional->called_computations().size(); ++i) { TF_RETURN_IF_ERROR( analysis_->RunOnComputation(*conditional->called_computations()[i], @@ -1237,6 +1382,7 @@ Status HloValueSemanticsPropagation::HandleConditional( } Status HloValueSemanticsPropagation::HandleSelect(HloInstruction* select) { + RETURN_IF_ALREADY_PROPAGATED(select); TF_ASSIGN_OR_RETURN(HloValueSemantics semantics, ComputeSemanticsFromOperands(select, {1, 2})); const HloValueSemantics* semantics_ptr = AddSemantics(semantics); @@ -1248,6 +1394,7 @@ Status HloValueSemanticsPropagation::HandleSelect(HloInstruction* select) { Status HloValueSemanticsPropagation::HandleConcatenate( HloInstruction* concatenate) { + RETURN_IF_ALREADY_PROPAGATED(concatenate); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(concatenate->operand(0)); analysis_->DeepCopyHloValueSemantics(concatenate, operand_semantics); @@ -1256,6 +1403,7 @@ Status HloValueSemanticsPropagation::HandleConcatenate( Status HloValueSemanticsPropagation::HandleDynamicSlice( HloInstruction* dynamic_slice) { + RETURN_IF_ALREADY_PROPAGATED(dynamic_slice); const HloInstruction* dynamic_slice_operand = dynamic_slice->operand(0); const HloValueSemantics* operand_semantics = analysis_->GetSemantics(dynamic_slice_operand); @@ -1277,6 +1425,7 @@ Status HloValueSemanticsPropagation::HandleDynamicSlice( Status HloValueSemanticsPropagation::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + RETURN_IF_ALREADY_PROPAGATED(dynamic_update_slice); TF_ASSIGN_OR_RETURN( HloValueSemantics semantics, ComputeSemanticsFromOperands(dynamic_update_slice, {0, 1})); @@ -1289,6 +1438,7 @@ Status HloValueSemanticsPropagation::HandleDynamicUpdateSlice( Status HloValueSemanticsPropagation::HandleCopyStart( HloInstruction* copy_start) { + RETURN_IF_ALREADY_PROPAGATED(copy_start); ShapeTree semantics_shape_tree(copy_start->shape()); const ShapeTree& operand_semantics_shape_tree = analysis_->GetInstructionSemantics(copy_start->operand(0)); @@ -1317,6 +1467,7 @@ Status HloValueSemanticsPropagation::HandleCopyStart( } Status HloValueSemanticsPropagation::HandleCopyDone(HloInstruction* copy_done) { + RETURN_IF_ALREADY_PROPAGATED(copy_done); const ShapeTree& operand_semantics_shape_tree = analysis_->GetInstructionSemantics(copy_done->operand(0)); analysis_->DeepCopyHloValueSemantics(copy_done, operand_semantics_shape_tree, @@ -1325,6 +1476,7 @@ Status HloValueSemanticsPropagation::HandleCopyDone(HloInstruction* copy_done) { } Status HloValueSemanticsPropagation::HandleCollectivePermuteStart( HloInstruction* collective_permute_start) { + RETURN_IF_ALREADY_PROPAGATED(collective_permute_start); ShapeTree semantics_shape_tree( collective_permute_start->shape()); const ShapeTree& 
operand_semantics_shape_tree = @@ -1358,6 +1510,7 @@ Status HloValueSemanticsPropagation::HandleCollectivePermuteStart( } Status HloValueSemanticsPropagation::HandleCollectivePermuteDone( HloInstruction* collective_permute_done) { + RETURN_IF_ALREADY_PROPAGATED(collective_permute_done); const ShapeTree& operand_semantics_shape_tree = analysis_->GetInstructionSemantics(collective_permute_done->operand(0)); analysis_->DeepCopyHloValueSemantics(collective_permute_done, @@ -1365,6 +1518,7 @@ Status HloValueSemanticsPropagation::HandleCollectivePermuteDone( return OkStatus(); } Status HloValueSemanticsPropagation::HandleGather(HloInstruction* gather) { + RETURN_IF_ALREADY_PROPAGATED(gather); const ShapeTree& operand_semantics_shape_tree = analysis_->GetInstructionSemantics(gather->operand(0)); analysis_->DeepCopyHloValueSemantics(gather, operand_semantics_shape_tree); @@ -1372,6 +1526,7 @@ Status HloValueSemanticsPropagation::HandleGather(HloInstruction* gather) { } Status HloValueSemanticsPropagation::HandleScatter(HloInstruction* scatter) { + RETURN_IF_ALREADY_PROPAGATED(scatter); TF_ASSIGN_OR_RETURN(HloValueSemantics semantics, ComputeSemanticsFromOperands(scatter, {0, 2})); const HloValueSemantics* semantics_ptr = AddSemantics(semantics); @@ -1382,6 +1537,7 @@ Status HloValueSemanticsPropagation::HandleScatter(HloInstruction* scatter) { } Status HloValueSemanticsPropagation::HandleAfterAll(HloInstruction* after_all) { + RETURN_IF_ALREADY_PROPAGATED(after_all); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kTupleOrToken, {after_all, {}}); ShapeTree semantics_shape_tree(after_all->shape(), @@ -1392,6 +1548,7 @@ Status HloValueSemanticsPropagation::HandleAfterAll(HloInstruction* after_all) { Status HloValueSemanticsPropagation::HandleAsyncStart( HloInstruction* async_start) { + RETURN_IF_ALREADY_PROPAGATED(async_start); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kTupleOrToken, {async_start, {}}); ShapeTree semantics_shape_tree(async_start->shape(), @@ -1434,6 +1591,7 @@ Status HloValueSemanticsPropagation::HandleAsyncStart( } Status HloValueSemanticsPropagation::HandleAsyncDone( HloInstruction* async_done) { + RETURN_IF_ALREADY_PROPAGATED(async_done); const ShapeTree& operand_semantics_tree = analysis_->GetInstructionSemantics(async_done->operand(0)); analysis_->DeepCopyHloValueSemantics(async_done, operand_semantics_tree, {1}); @@ -1441,6 +1599,7 @@ Status HloValueSemanticsPropagation::HandleAsyncDone( } Status HloValueSemanticsPropagation::HandleInfeed(HloInstruction* infeed) { + RETURN_IF_ALREADY_PROPAGATED(infeed); ShapeTree semantics_shape_tree(infeed->shape(), nullptr); semantics_shape_tree.ForEachMutableElement( @@ -1459,6 +1618,7 @@ Status HloValueSemanticsPropagation::HandleInfeed(HloInstruction* infeed) { } Status HloValueSemanticsPropagation::HandleDomain(HloInstruction* domain) { + RETURN_IF_ALREADY_PROPAGATED(domain); HloInstruction* domain_operand = domain->mutable_operand(0); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(domain_operand); @@ -1468,6 +1628,7 @@ Status HloValueSemanticsPropagation::HandleDomain(HloInstruction* domain) { Status HloValueSemanticsPropagation::HandleOptimizationBarrier( HloInstruction* opt_barrier) { + RETURN_IF_ALREADY_PROPAGATED(opt_barrier); HloInstruction* opt_barrier_operand = opt_barrier->mutable_operand(0); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(opt_barrier_operand); @@ -1475,4 +1636,99 @@ 
Status HloValueSemanticsPropagation::HandleOptimizationBarrier( return OkStatus(); } +Status HloValueSemanticsPropagation::HandleSend(HloInstruction* send) { + RETURN_IF_ALREADY_PROPAGATED(send); + ShapeTree semantics_tree(send->shape(), nullptr); + HloInstruction* source_buffer = send->mutable_operand(0); + const ShapeTree& source_buffer_semantics = + analysis_->GetInstructionSemantics(source_buffer); + analysis_->DeepCopyHloValueSemantics(semantics_tree, source_buffer_semantics, + {}, {0}); + + semantics_tree.ForEachMutableElement( + [this, send, &semantics_tree](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (!index.empty()) { + if (index.front() == 1 && semantics_tree.IsLeaf(index)) { + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {send, index}); + return; + } + if (index.front() == 0) { + return; + } + } + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {send, index}); + }); + analysis_->SetHloValueSemantics(send, semantics_tree); + return OkStatus(); +} + +Status HloValueSemanticsPropagation::HandleRecv(HloInstruction* recv) { + // Since recv is not a prerequisite of send, we might have not propagated + // semantics to the corresponding send when we reach this recv. So we visit + // the send first before visiting this recv. + // We use RETURN_IF_ALREADY_PROPAGATED to avoid processing an HLO more than + // once. + RETURN_IF_ALREADY_PROPAGATED(recv); + TF_ASSIGN_OR_RETURN(HloInstruction * send, + analysis_->GetMatchingSendOrRecv(recv)); + TF_RETURN_IF_ERROR(send->Accept(this)); + ShapeTree semantics_tree(recv->shape(), nullptr); + const ShapeTree& send_buffer_semantics = + analysis_->GetInstructionSemantics(send); + analysis_->DeepCopyHloValueSemantics(semantics_tree, send_buffer_semantics, + {0}, {0}); + semantics_tree.ForEachMutableElement( + [this, recv, &semantics_tree](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (!index.empty()) { + if (index.front() == 1 && semantics_tree.IsLeaf(index)) { + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {recv, index}); + return; + } + if (index.front() == 0) { + return; + } + } + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {recv, index}); + }); + analysis_->SetHloValueSemantics(recv, semantics_tree); + return OkStatus(); +} + +Status HloValueSemanticsPropagation::HandleSendDone(HloInstruction* send_done) { + RETURN_IF_ALREADY_PROPAGATED(send_done); + const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {send_done, {}}); + ShapeTree send_done_semantics_tree( + send_done->shape(), semantics); + analysis_->SetHloValueSemantics(send_done, send_done_semantics_tree); + return OkStatus(); +} +Status HloValueSemanticsPropagation::HandleRecvDone(HloInstruction* recv_done) { + RETURN_IF_ALREADY_PROPAGATED(recv_done); + ShapeTree semantics_tree(recv_done->shape(), + nullptr); + HloInstruction* recv = recv_done->mutable_operand(0); + const ShapeTree& recv_semantics = + analysis_->GetInstructionSemantics(recv); + analysis_->DeepCopyHloValueSemantics(semantics_tree, recv_semantics, {0}, + {0}); + semantics_tree.ForEachMutableElement( + [this, recv_done](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (!index.empty() && index.front() == 0) { + return; + } + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {recv_done, index}); + }); + 
analysis_->SetHloValueSemantics(recv_done, semantics_tree); + return OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index 6b8084aaf9af9f..fa4d14ad829898 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -38,6 +38,15 @@ limitations under the License. namespace xla { +struct SendRecvGroup { + HloInstruction* send; + HloInstruction* recv; +}; + +using SendRecvGroupMap = absl::flat_hash_map; + +SendRecvGroupMap GetSendRecvGroupMap(const HloModule& hlo_module); + class HloPreOrderDFS { public: HloPreOrderDFS() = default; @@ -71,7 +80,8 @@ using EinsumDepthMap = class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { public: static StatusOr> Run( - const HloComputation& computation); + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map); ~EinsumDepthAnalysis() override = default; Status DefaultAction(HloInstruction* instruction) override; Status HandleTuple(HloInstruction* tuple) override; @@ -89,10 +99,15 @@ class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { HloInstruction* collective_permute_start) override; Status HandleCollectivePermuteDone( HloInstruction* collective_permute_done) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecvDone(HloInstruction* recv_done) override; const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } private: - EinsumDepthAnalysis() = default; + explicit EinsumDepthAnalysis(const SendRecvGroupMap& send_recv_group_map) + : send_recv_group_map_(send_recv_group_map) {} Status RunInternal(const HloComputation& computation, const std::optional>& root_depth); EinsumDepthMap::iterator GetOrCreateDepthTree(HloInstruction* instruction); @@ -104,6 +119,49 @@ class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { const ShapeTree& root_depth, absl::Span operands); EinsumDepthMap einsum_depth_map_; + const SendRecvGroupMap send_recv_group_map_; +}; + +using EinsumHeightMap = + absl::flat_hash_map>; + +class EinsumHeightAnalysis : public DfsHloVisitorWithDefault { + public: + static StatusOr> Run( + const HloComputation& computation); + ~EinsumHeightAnalysis() override = default; + Status DefaultAction(HloInstruction* instruction) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleCall(HloInstruction* call) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleOutfeed(HloInstruction* outfeed) override; + Status HandleCollectivePermuteStart( + HloInstruction* collective_permute_start) override; + Status HandleCollectivePermuteDone( + HloInstruction* collective_permute_done) override; + const EinsumHeightMap& GetEinsumHeightMap() const { + return einsum_height_map_; + } + + private: + EinsumHeightAnalysis() = default; + Status RunInternal(const HloComputation& computation, + absl::Span operands); + EinsumHeightMap::iterator 
GetOrCreateHeightTree(HloInstruction* instruction); + Status SetInstructionHeight(HloInstruction* instruction, int height); + Status SetInstructionHeight(HloInstruction* instruction, + const ShapeTree& height); + Status HandleHeightIncrementInstruction(HloInstruction* instruction); + Status HandleCalledComputation(const HloComputation& computation, + absl::Span operands); + EinsumHeightMap einsum_height_map_; + const SendRecvGroupMap send_recv_group_map_; }; // The comment below explains where the labels could originate from. Once @@ -159,6 +217,7 @@ class HloValueSemanticsAnalysis { static StatusOr> Run( const HloModule& module); virtual ~HloValueSemanticsAnalysis() = default; + bool HasSemanticsFor(const HloInstruction* instruction) const; const HloValueSemantics* GetSemantics(const HloInstruction* instruction, const ShapeIndex& index = {}) const; @@ -167,11 +226,19 @@ class HloValueSemanticsAnalysis { } const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } + const SendRecvGroupMap& GetSendRecvGroupMap() const { + return send_recv_group_map_; + } + + StatusOr GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const; protected: friend class HloValueSemanticsPropagation; explicit HloValueSemanticsAnalysis(const HloModule& module); Status InitializeEinsumDepth(); + // We match send and recv HLOs to propagate semantics from send to recv. + void InitializeSendRecvGroups(); void AnnotateWeights(); // Infer semantics for all instructions in the computation. Computation @@ -206,6 +273,7 @@ class HloValueSemanticsAnalysis { value_semantics_map_; HloValueSemantics::Id next_id_; EinsumDepthMap einsum_depth_map_; + SendRecvGroupMap send_recv_group_map_; }; class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { @@ -247,6 +315,10 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { Status HandleDomain(HloInstruction* domain) override; Status HandleOptimizationBarrier(HloInstruction* opt_barrier) override; Status HandleRngBitGenerator(HloInstruction* rng_bit_generator) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecvDone(HloInstruction* recv_done) override; protected: HloValueSemantics CopySemantics(const HloValueSemantics& semantics) const; diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc index fd1704f9192e82..3128115727da46 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc @@ -567,7 +567,7 @@ TEST_F(EinsumDepthAnalysisTest, MnistTrainingLoop) { /*num_partitions=*/1)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr einsum_depth_analysis, - EinsumDepthAnalysis::Run(*module->entry_computation())); + EinsumDepthAnalysis::Run(*module->entry_computation(), {})); const EinsumDepthMap& einsum_depth_map = einsum_depth_analysis->GetEinsumDepthMap(); HloComputation* computation = module->GetComputationWithName("body.49"); @@ -612,7 +612,7 @@ TEST_F(EinsumDepthAnalysisTest, HandleConditional) { ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr einsum_depth_analysis, - EinsumDepthAnalysis::Run(*module->entry_computation())); + EinsumDepthAnalysis::Run(*module->entry_computation(), {})); const EinsumDepthMap& einsum_depth_map = einsum_depth_analysis->GetEinsumDepthMap(); HloComputation* 
computation = module->GetComputationWithName("entry"); From 3987ee36b9b8e2c2f3172f5cf9efd847e877c533 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 1 Dec 2023 15:06:55 -0800 Subject: [PATCH 304/381] Switch build configs in Linux Arm64 continuous builds from native to cross-compile This switches over the regular Linux Arm64 continuous builds from building natively to cross-compiling on remote Linux x86 VMs. PiperOrigin-RevId: 587134081 --- .../envs/continuous_linux_arm64_cpu_py39_cross_compile | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile b/ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile new file mode 100644 index 00000000000000..23870d6c181bd3 --- /dev/null +++ b/ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile @@ -0,0 +1,6 @@ +# This environment is experimental and should not yet be used for production jobs +TFCI_BAZEL_COMMON_ARGS="--config rbe_cross_compile_linux_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 +TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.9 From 79c1c93ce9a3337a26a8c31e70e235474770bef2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 15:17:05 -0800 Subject: [PATCH 305/381] Replaced std::vector with absl::InlinedVector. PiperOrigin-RevId: 587136655 --- third_party/xla/xla/layout_util.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/layout_util.cc b/third_party/xla/xla/layout_util.cc index c4c2c5ece4eed9..0d9f5e24893a11 100644 --- a/third_party/xla/xla/layout_util.cc +++ b/third_party/xla/xla/layout_util.cc @@ -252,7 +252,8 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } - std::vector dimensions_in_layout(shape.rank(), false); + absl::InlinedVector dimensions_in_layout(shape.rank(), + false); for (int64_t i = 0; i < shape.rank(); ++i) { int64_t dim = layout.minor_to_major(i); if (dim < 0 || dim >= shape.rank()) { From 69fa8c46bf89e429b9b978ec56f106f2631dbb63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20K=C3=B6ppe?= Date: Fri, 1 Dec 2023 15:24:18 -0800 Subject: [PATCH 306/381] Fix missing header inclusion. Found by -Wundefined-func-template. PiperOrigin-RevId: 587138447 --- tensorflow/core/framework/variant.h | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index c4e23a8d07ba5e..152e0538f81bfe 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/strcat.h" From aef6b65b451260adc9aab708220cad3bb406788f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 15:43:05 -0800 Subject: [PATCH 307/381] No op change.
PiperOrigin-RevId: 587143760 --- tensorflow/core/data/BUILD | 23 ++++++++++- tensorflow/core/data/utils.cc | 7 ++++ tensorflow/core/data/utils.h | 5 +++ tensorflow/core/data/utils_test.cc | 66 ++++++++++++++++++++++++++++++ tensorflow/core/kernels/data/BUILD | 2 + 5 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/data/utils_test.cc diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 171a6fa4d88e49..b1c01427a0462f 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -39,6 +39,10 @@ exports_files([ "serialization_utils.h", "split_utils.cc", "split_utils.h", + "file_logger_client_no_op.h", + "file_logger_client_no_op.cc", + "file_logger_client_interface.h", + "file_logger_client_interface.cc", "stats_utils.cc", "stats_utils.h", "tfdataz_metrics.h", @@ -632,6 +636,8 @@ cc_library( hdrs = ["utils.h"], # copybara:uncomment copts = ["-Wthread-safety-analysis"], deps = [ + ":file_logger_client_interface", + ":file_logger_client_no_op", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", @@ -639,10 +645,25 @@ cc_library( ], ) +tf_cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + # copybara:uncomment extra_copts = ["-Wthread-safety-analysis"], + deps = [ + ":file_logger_client_interface", + ":file_logger_client_no_op", + ":utils", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "file_logger_client_interface", hdrs = ["file_logger_client_interface.h"], - # copybara:uncomment visibility = ["//learning/processing/tf_data_logger/client:__subpackages__"], + visibility = [ + "//learning/processing/tf_data_logger/client:__subpackages__", + "//tensorflow:internal", + ], ) cc_library( diff --git a/tensorflow/core/data/utils.cc b/tensorflow/core/data/utils.cc index 7d346dcbecd319..73f8a75587e97e 100644 --- a/tensorflow/core/data/utils.cc +++ b/tensorflow/core/data/utils.cc @@ -14,11 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/utils.h" +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "tensorflow/core/data/file_logger_client_interface.h" +#include "tensorflow/core/data/file_logger_client_no_op.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/protobuf/data_service.pb.h" @@ -44,5 +47,9 @@ absl::StatusOr DisableCompressionAtRuntime( return false; } +std::unique_ptr CreateFileLoggerClient() { + return std::make_unique(); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils.h b/tensorflow/core/data/utils.h index d80431c9680ab8..00fe795f9c7f3e 100644 --- a/tensorflow/core/data/utils.h +++ b/tensorflow/core/data/utils.h @@ -15,11 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DATA_UTILS_H_ #define TENSORFLOW_CORE_DATA_UTILS_H_ +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "tensorflow/core/data/file_logger_client_interface.h" #include "tensorflow/core/protobuf/data_service.pb.h" namespace tensorflow { @@ -48,6 +50,9 @@ std::string LocalityOptimizedPath(const std::string& path); absl::StatusOr DisableCompressionAtRuntime( const std::string& data_transfer_protocol, DeploymentMode deployment_mode); +// Creates a instance of a class derived from FileLoggerClientInterface. 
+std::unique_ptr CreateFileLoggerClient(); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils_test.cc b/tensorflow/core/data/utils_test.cc new file mode 100644 index 00000000000000..1f908acb278b59 --- /dev/null +++ b/tensorflow/core/data/utils_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/utils.h" + +#include + +#include +#include "tensorflow/core/data/file_logger_client_interface.h" +#include "tensorflow/core/data/file_logger_client_no_op.h" + +namespace tensorflow::data { +namespace { + +TEST(Util, CreateFileLoggerClient) { + std::unique_ptr client = CreateFileLoggerClient(); + EXPECT_NE(dynamic_cast(client.get()), nullptr); +} + +TEST(Util, DefaultDataTransferProtocol) { + EXPECT_EQ(DefaultDataTransferProtocol(), "grpc"); +} + +TEST(TranslateFileName, NoOp) { + constexpr char file[] = "/home/tfdata/file1"; + EXPECT_EQ(TranslateFileName(file), file); +} + +TEST(TranslateFileName, EmptyPath) { + constexpr char file[] = ""; + EXPECT_EQ(TranslateFileName(file), file); +} + +TEST(TranslateFileName, TfDataPath) { + constexpr char file[] = "tfdata/file1"; + EXPECT_EQ(TranslateFileName(file), file); +} + +TEST(LocalityOptimizedPath, NoOp) { + constexpr char file[] = "/home/tfdata/file1"; + EXPECT_EQ(LocalityOptimizedPath(file), file); +} + +TEST(LocalityOptimizedPath, EmptyPath) { + constexpr char file[] = ""; + EXPECT_EQ(LocalityOptimizedPath(file), file); +} + +TEST(LocalityOptimizedPath, TfDataPath) { + constexpr char file[] = "tfdata/file1"; + EXPECT_EQ(LocalityOptimizedPath(file), file); +} + +} // namespace +} // namespace tensorflow::data diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 74eb029a9aedc1..e6aa8509396fcf 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1514,6 +1514,8 @@ filegroup( "//tensorflow/core/data:captured_function.h", "//tensorflow/core/data:compression_utils.h", "//tensorflow/core/data:dataset_utils.h", + "//tensorflow/core/data:file_logger_client_interface.h", + "//tensorflow/core/data:file_logger_client_no_op.h", "//tensorflow/core/data:finalization_utils.h", "//tensorflow/core/data:metric_utils.h", "//tensorflow/core/data:name_utils.h", From f01a2719204bac1a52567e130b2dc5177d5e7abf Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Fri, 1 Dec 2023 15:43:23 -0800 Subject: [PATCH 308/381] [HloValueSemanticsAnalysis] Fix activation gradient classification for MoE. The idea is that activation-activation einsums are (almost) never activation gradient. 
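For intuition, a minimal standalone sketch of the classification rule described above. This is illustrative only: the enum and function below are simplified stand-ins, not the real HloValueSemanticLabel or MaybeCreateGradientSemantics, and the einsum-depth test is reduced to a single boolean.

  #include <iostream>

  // Simplified stand-in for HloValueSemanticLabel.
  enum class Label { kActivation, kActivationGradient, kWeightGradient };

  // An einsum is labeled kWeightGradient only when the depth analysis says it
  // feeds a gradient computation; otherwise it falls back to the label the
  // caller supplies. For activation-activation einsums the fallback is
  // kActivation, so they are (almost) never tagged as activation gradients.
  Label ClassifyEinsum(bool feeds_gradient, Label fallback_label) {
    return feeds_gradient ? Label::kWeightGradient : fallback_label;
  }

  int main() {
    Label l = ClassifyEinsum(/*feeds_gradient=*/false, Label::kActivation);
    std::cout << (l == Label::kActivation) << "\n";  // prints 1
    return 0;
  }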
PiperOrigin-RevId: 587143854 --- .../service/hlo_value_semantics_analysis.cc | 50 ++++++++++--------- .../service/hlo_value_semantics_analysis.h | 5 +- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index cc407ea2bea75b..7e1bd0e87e5af4 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -920,7 +920,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromStaticAndOther( instruction->opcode() == HloOpcode::kConvolution; if (is_dot_or_convolution && other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } return CopySemantics(other_semantics); } @@ -939,8 +940,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromRandomAndOther( } StatusOr -HloValueSemanticsPropagation::CreateGradientSemantics( - HloInstruction* gradient_candidate) const { +HloValueSemanticsPropagation::MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const { const EinsumDepthMap& einsum_depth_map = analysis_->GetEinsumDepthMap(); auto depth_iter = einsum_depth_map.find(gradient_candidate); CHECK(depth_iter != einsum_depth_map.end()); @@ -956,8 +958,7 @@ HloValueSemanticsPropagation::CreateGradientSemantics( return HloValueSemantics(HloValueSemanticLabel::kWeightGradient, {gradient_candidate, {}}); } - return HloValueSemantics(HloValueSemanticLabel::kActivationGradient, - {gradient_candidate, {}}); + return HloValueSemantics(fallback_label, {gradient_candidate, {}}); } StatusOr @@ -972,6 +973,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( instruction->opcode() == HloOpcode::kConvolution; if (other_semantics.label() == HloValueSemanticLabel::kWeight) { if (!is_dot_or_convolution) { + if (weight_semantics.origin() == other_semantics.origin()) { + return CopySemantics(other_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } return HloValueSemantics(HloValueSemanticLabel::kActivation, @@ -988,7 +992,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( // operand. if (OriginDependsOn(other_semantics, weight_semantics.origin(), /*recursive=*/true)) { - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } return CopySemanticsWithNewOrigin(other_semantics, instruction); } @@ -997,7 +1002,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( // which produce an Activation. The ActivationGradient to this Activation // could be used in an einsum with one of the Weights to compute // the WeightGradient for the other Weight. 
- return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient); return CopySemantics(other_semantics); @@ -1015,14 +1021,16 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot || instruction->opcode() == HloOpcode::kConvolution; if (!is_dot_or_convolution) { + if (activation_semantics.origin() == other_semantics.origin()) { + return CopySemantics(other_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } if (other_semantics.label() == HloValueSemanticLabel::kActivation) { // Like said above, since loss is classified as Activation, an einsum - // between an Activation X and an Activation Y could be WeightGradient or - // even ActivationGradient when either X or Y is the loss. This case is - // different from other Activation einsums because there must a dependency - // between X and Y. + // between an Activation X and an Activation Y could be WeightGradient if + // either X or Y is the loss. This case is different from other Activation + // einsums because there must a dependency between X and Y. bool other_depends_on_activation = OriginDependsOn( other_semantics, activation_semantics.origin(), /*recursive=*/true); bool activation_depends_on_other = @@ -1032,14 +1040,19 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( // If there is no dependency between the two Activations, the output must // be an Activation. if (other_depends_on_activation || activation_depends_on_other) { - return CreateGradientSemantics(instruction); + // We check if the einsum is actually weight gradient. If it is not, fall + // back to activation, since we expect the loss to be computed from an + // activation-weight einsum. + return MaybeCreateGradientSemantics(instruction, + HloValueSemanticLabel::kActivation); } return CopySemanticsWithNewOrigin(other_semantics, instruction); } if (other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { // An Activation-ActivationGradient einsum could be computing // WeightGradient or ActivationGradient. 
- return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient) << "instruction: " << instruction->ToString() @@ -1407,16 +1420,7 @@ Status HloValueSemanticsPropagation::HandleDynamicSlice( const HloInstruction* dynamic_slice_operand = dynamic_slice->operand(0); const HloValueSemantics* operand_semantics = analysis_->GetSemantics(dynamic_slice_operand); - const HloValueSemantics* semantics = nullptr; - if (operand_semantics->label() == HloValueSemanticLabel::kStatic || - operand_semantics->label() == HloValueSemanticLabel::kRandom || - operand_semantics->label() == HloValueSemanticLabel::kWeight) { - semantics = analysis_->NewHloValueSemantics(operand_semantics->label(), - {dynamic_slice, {}}); - } else { - HloValueSemantics semantics_value = CopySemantics(*operand_semantics); - semantics = AddSemantics(semantics_value); - } + const HloValueSemantics* semantics = AddSemantics(*operand_semantics); ShapeTree semantics_shape_tree( dynamic_slice->shape(), semantics); analysis_->SetHloValueSemantics(dynamic_slice, semantics_shape_tree); diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index fa4d14ad829898..634b13f21ed65c 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -346,8 +346,9 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { bool OriginDependsOn(const HloValueSemantics& semantics, const HloPosition& origin_dependence, bool recursive = false) const; - StatusOr CreateGradientSemantics( - HloInstruction* gradient_candidate) const; + StatusOr MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const; StatusOr ComputeSemanticsFromStaticAndOther( const HloValueSemantics& static_semantics, const HloValueSemantics& other_semantics, From 2326e80b5bcded31cdab020507c78aff34beefe8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 16:17:57 -0800 Subject: [PATCH 309/381] [XLA] Share the same underlying DFS stack across instructions. 
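The optimization is the usual scratch-buffer reuse pattern. A self-contained sketch is below; the Node type and function names are hypothetical stand-ins, not the XLA API, and visited-state tracking is omitted for brevity.

  #include <cstdio>
  #include <vector>

  struct Node {
    int id;
    std::vector<Node*> operands;
  };

  // Iterative traversal that reuses a caller-owned stack instead of
  // allocating a fresh std::vector for every root.
  void VisitFrom(Node* root, std::vector<Node*>* dfs_stack_scratch) {
    auto* dfs_stack = dfs_stack_scratch;
    dfs_stack->clear();
    dfs_stack->push_back(root);
    while (!dfs_stack->empty()) {
      Node* current = dfs_stack->back();
      dfs_stack->pop_back();
      std::printf("visit %d\n", current->id);
      // Push operands in reverse so the first operand is visited first.
      dfs_stack->insert(dfs_stack->end(), current->operands.rbegin(),
                        current->operands.rend());
    }
  }

  int main() {
    Node a{0, {}}, b{1, {}}, root{2, {&a, &b}};
    std::vector<Node*> scratch;
    scratch.reserve(8);       // sized once, reused across roots
    VisitFrom(&root, &scratch);
    VisitFrom(&a, &scratch);  // second traversal reuses the same storage
    return 0;
  }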
PiperOrigin-RevId: 587153954 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 41 ++++++++++++------- third_party/xla/xla/hlo/ir/hlo_computation.h | 6 ++- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 30af92643a7b1a..07ea166007098e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -417,23 +417,27 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, void HloComputation::ComputeInstructionPostOrder( HloInstruction* root, const ChannelDependencies& channel_dependencies, absl::flat_hash_map& visited, - std::vector& post_order) const { + std::vector& post_order, + std::vector* dfs_stack_scratch) const { ForEachInstructionPostOrderImpl( [&post_order](HloInstruction* hlo) { post_order.push_back(hlo); }, root, - channel_dependencies, visited); + channel_dependencies, visited, dfs_stack_scratch); } void HloComputation::ForEachInstructionPostOrderImpl( absl::FunctionRef func, HloInstruction* root, const ChannelDependencies& channel_dependencies, - absl::flat_hash_map& visited) const { - std::vector dfs_stack = {root}; - while (!dfs_stack.empty()) { - HloInstruction& current = *dfs_stack.back(); + absl::flat_hash_map& visited, + std::vector* dfs_stack_scratch) const { + auto* dfs_stack = dfs_stack_scratch; + dfs_stack->clear(); + dfs_stack->push_back(root); + while (!dfs_stack->empty()) { + HloInstruction& current = *dfs_stack->back(); auto [it, was_inserted] = visited.insert({¤t, kVisiting}); if (!was_inserted) { // We've already seen this instruction. - dfs_stack.pop_back(); + dfs_stack->pop_back(); if (it->second != kVisited) { DCHECK_EQ(current.parent(), this) << "Instruction " << current.name() @@ -451,7 +455,8 @@ void HloComputation::ForEachInstructionPostOrderImpl( if (¤t != root) { auto it = channel_dependencies.find(¤t); if (it != channel_dependencies.end()) { - dfs_stack.insert(dfs_stack.end(), it->second.begin(), it->second.end()); + dfs_stack->insert(dfs_stack->end(), it->second.begin(), + it->second.end()); } } @@ -459,11 +464,12 @@ void HloComputation::ForEachInstructionPostOrderImpl( // processed first. This will produce a more natural ordering and a nicer // result for things like HLO stringification. 
const HloInstruction::InstructionVector& operands = current.operands(); - dfs_stack.insert(dfs_stack.end(), operands.rbegin(), operands.rend()); + dfs_stack->insert(dfs_stack->end(), operands.rbegin(), operands.rend()); const std::vector& predecessors = current.control_predecessors(); - dfs_stack.insert(dfs_stack.end(), predecessors.begin(), predecessors.end()); + dfs_stack->insert(dfs_stack->end(), predecessors.begin(), + predecessors.end()); } } @@ -509,8 +515,9 @@ std::vector HloComputation::MakeInstructionPostOrderFrom( HloInstruction& postorder_root) const { std::vector post_order; absl::flat_hash_map visited; + std::vector dfs_stack_scratch; ComputeInstructionPostOrder(&postorder_root, ComputeChannelDependencies(), - visited, post_order); + visited, post_order, &dfs_stack_scratch); return post_order; } @@ -524,10 +531,12 @@ std::vector HloComputation::MakeInstructionPostOrder( post_order.reserve(instruction_count()); absl::flat_hash_map visited; visited.reserve(instruction_count()); + std::vector dfs_stack_scratch; + dfs_stack_scratch.reserve(instruction_count()); for (auto& instruction : instructions_) { if (instruction->users().empty()) { ComputeInstructionPostOrder(instruction.get(), channel_dependencies, - visited, post_order); + visited, post_order, &dfs_stack_scratch); } } CHECK_EQ(instructions_.size(), post_order.size()) @@ -602,11 +611,14 @@ void HloComputation::ForEachInstructionPostOrder( absl::FunctionRef func) const { absl::flat_hash_map visited; visited.reserve(instruction_count()); + std::vector dfs_stack_scratch; + dfs_stack_scratch.reserve(instruction_count()); auto channel_dependencies = ComputeChannelDependencies(); for (auto& instruction : instructions_) { if (instruction->users().empty()) { ForEachInstructionPostOrderImpl(func, instruction.get(), - channel_dependencies, visited); + channel_dependencies, visited, + &dfs_stack_scratch); } } } @@ -1405,12 +1417,13 @@ std::unique_ptr HloComputation::CloneInContext( // ourselves. std::vector postorder; absl::flat_hash_map visited; + std::vector dfs_stack; for (const auto& instr : instructions_) { - std::vector dfs_stack; const HloInstruction* new_instr = replace(instr.get()); if (!new_instr) { continue; } + dfs_stack.clear(); dfs_stack.push_back(new_instr); while (!dfs_stack.empty()) { diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 92d2e00013bb1a..2764e26aa67c3b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -820,12 +820,14 @@ class HloComputation { void ComputeInstructionPostOrder( HloInstruction* root, const ChannelDependencies& channel_dependencies, absl::flat_hash_map& visited, - std::vector& post_order) const; + std::vector& post_order, + std::vector* dfs_stack_scratch) const; void ForEachInstructionPostOrderImpl( absl::FunctionRef func, HloInstruction* root, const ChannelDependencies& channel_dependencies, - absl::flat_hash_map& visited) const; + absl::flat_hash_map& visited, + std::vector* dfs_stack_scratch) const; Status RemoveUnusedParametersImpl(bool allow_non_fusion); From 7052d20c1b7bc422d42e4d57740b06580bf31c65 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 1 Dec 2023 16:40:22 -0800 Subject: [PATCH 310/381] Refactor spmd_partitioner.cc. Merge PatternMatchMergeSharding and PatternMatchUnmergeSharding into one unified function PatternMatchMergeOrSplitSharding. 
PiperOrigin-RevId: 587160141 --- third_party/xla/xla/service/spmd/BUILD | 11 +- .../xla/xla/service/spmd/spmd_partitioner.cc | 330 +++++++----------- 2 files changed, 134 insertions(+), 207 deletions(-) diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index f80c4104cdf29f..19330164818afd 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -1,7 +1,7 @@ # Description: SPMD partitioning pass. -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") package( default_visibility = ["//visibility:public"], @@ -34,12 +34,16 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + "//xla:array", "//xla:comparison_util", "//xla:literal", "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:types", "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", @@ -50,12 +54,14 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:call_graph", + "//xla/service:computation_layout", "//xla/service:custom_call_sharding_helper", "//xla/service:dot_as_convolution_util", "//xla/service:flatten_call_graph", "//xla/service:hlo_cse", "//xla/service:hlo_dce", "//xla/service:hlo_lexer", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:hlo_pass_pipeline", "//xla/service:pattern_matcher", @@ -69,8 +75,11 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index a633dec5677b4a..61a194790ab58f 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -16,8 +16,10 @@ limitations under the License. #include "xla/service/spmd/spmd_partitioner.h" #include +#include #include #include +#include #include #include #include @@ -27,13 +29,16 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/array.h" #include "xla/comparison_util.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -42,20 +47,32 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/protobuf_util.h" +#include "xla/service/call_graph.h" +#include "xla/service/computation_layout.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/shape_inference.h" #include "xla/service/spmd/custom_call_handler.h" #include "xla/service/spmd/spmd_partitioner_util.h" #include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/types.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/numbers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -1766,19 +1783,21 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( namespace { -// Matching a pattern like [..,X,..,Y] -> [..,X*Y,..,1] or [..,X,..,Y] -> -// [..,1,..,X*Y]. +// Matching the following patterns, where X, Y, cannot be 1, Z can be 1. +// 1. [..,X,..,Y,..] -> [..,X*Y,..,1,..] +// 2. [..,Y,..,X,..] -> [..,1,..,X*Y,..] +// 3. [..,X*Y,..,Z,..] -> [..,X,..,Y*Z,..] +// 4. [..,Z,..,X*Y,..] -> [..,Y*Z,..,X,..] // Output tuple: -// - HloSharding: The original sharding with an extra dimension added of size 1. -// - HloSharding: The sharding with the dimension we want to merge moved in -// place of the dimension of size 1 we added. -// - int: Dimension in the input that is going to be merged with another -// dimension (becoming bigger). -// - int: Dimension in the input that is going to be merged into another -// dimension (becoming 1). -std::optional> -PatternMatchMergeSharding(const Shape& shape, const HloSharding& source, - const HloSharding& target) { +// - HloSharding: The original sharding with an extra dimension added of size 1 +// or Y. +// - HloSharding: The sharding with the new dimension added moved in the place +// where we expect the target dimension to be. +// - int64_t: The index of X. 
+std::optional> +PatternMatchMergeOrSplitSharding(const Shape& shape, const Shape& base_shape, + const HloSharding& source, + const HloSharding& target) { if (!source.IsTiled() || !target.IsTiled()) { return std::nullopt; } @@ -1791,170 +1810,103 @@ PatternMatchMergeSharding(const Shape& shape, const HloSharding& source, target.tile_assignment().dimensions()[target.TiledDataRank()])) { return std::nullopt; } - for (int i = 0; i < target.TiledDataRank(); ++i) { - if (source.tile_assignment().dim(i) < target.tile_assignment().dim(i) && - (target.tile_assignment().dim(i) % source.tile_assignment().dim(i)) == - 0) { - auto get_reshaped_sharding = - [&](int64_t target_idx) -> std::optional { - if (target.tile_assignment().dim(target_idx) != 1) { - return std::nullopt; - } - if (target.tile_assignment().dim(i) != - source.tile_assignment().dim(i) * - source.tile_assignment().dim(target_idx)) { - return std::nullopt; - } - if (shape.dimensions(i) % source.tile_assignment().dim(target_idx) != - 0) { - return std::nullopt; - } - return hlo_sharding_util::SplitShardingDimension( - source, i, source.tile_assignment().dim(i)); - }; - for (int j = i - 1; j >= 0; --j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Merge From Left"; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); - } - } - for (int j = i + 1; j < target.TiledDataRank(); ++j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Merge From Right"; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j + 1]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); - } - } - } - } - return std::nullopt; -} -// Matching a pattern like [..,X*Y,..,1] -> [..,X,..,Y] or [..,1,..,X*Y] -> -// [..,X,..,Y]. -// Output tuple: -// - HloSharding: The original sharding with an extra dimension added of size Y. -// - HloSharding: The sharding with the new dimension added moved in the place -// where we expect the target dimension to be. -// - int: Dimension in the input that is going to be unmerged (getting split). -// - int: Dimension in the input that is going to be the destination of the -// unmerged dimension. 
-std::optional> -PatternMatchUnmergeSharding(const Shape& shape, const Shape& base_shape, - const HloSharding& source, - const HloSharding& target) { - if (!source.IsTiled() || !target.IsTiled()) { - return std::nullopt; - } - if (source.TiledDataRank() != target.TiledDataRank()) { - return std::nullopt; + std::vector diff_index; + for (int64_t i = 0; i < target.TiledDataRank(); ++i) { + if (source.tile_assignment().dim(i) != target.tile_assignment().dim(i)) { + diff_index.push_back(i); + } } - if ((source.HasPartialReplication() ^ target.HasPartialReplication()) || - (source.HasPartialReplication() && - source.tile_assignment().dimensions()[source.TiledDataRank()] != - target.tile_assignment().dimensions()[target.TiledDataRank()])) { + if (diff_index.size() < 2) { return std::nullopt; } - for (int i = 0; i < target.TiledDataRank(); ++i) { - if (source.tile_assignment().dim(i) > target.tile_assignment().dim(i) && - target.tile_assignment().dim(i) != 1 && - base_shape.dimensions(i) % source.tile_assignment().dim(i) == 0 && - source.tile_assignment().dim(i) % target.tile_assignment().dim(i) == - 0) { - auto get_reshaped_sharding = - [&](int64_t target_dim) -> std::optional { - if (source.tile_assignment().dim(target_dim) == - target.tile_assignment().dim(target_dim) || - source.tile_assignment().dim(i) != - target.tile_assignment().dim(i) * - target.tile_assignment().dim(target_dim)) { - VLOG(10) << "Skipped for target dim different from dimension_size " - << target_dim - << " src size: " << source.tile_assignment().dim(i) - << " target size: " - << target.tile_assignment().dim(target_dim); - return std::nullopt; - } - return hlo_sharding_util::SplitShardingDimension( - source, i, target.tile_assignment().dim(i)); - }; - for (int j = i - 1; j >= 0; --j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Unmerge to Right i = " << i << ",j = " << j; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); + + // Iterate every pair of elements in diff_index. + for (int64_t diff_index_i = 0; diff_index_i < diff_index.size(); + ++diff_index_i) { + for (int64_t diff_index_j = diff_index_i + 1; + diff_index_j < diff_index.size(); ++diff_index_j) { + int64_t i = diff_index[diff_index_i]; + int64_t j = diff_index[diff_index_j]; + const std::vector is_one = {source.tile_assignment().dim(i) == 1, + source.tile_assignment().dim(j) == 1, + target.tile_assignment().dim(i) == 1, + target.tile_assignment().dim(j) == 1}; + int64_t new_dim_size; + switch (std::count(is_one.begin(), is_one.end(), true)) { + case 1: { + if (source.tile_assignment().dim(i) * + source.tile_assignment().dim(j) != + target.tile_assignment().dim(i) * + target.tile_assignment().dim(j)) { + continue; + } + if (source.tile_assignment().dim(i) == 1 || + target.tile_assignment().dim(i) == 1) { + std::swap(i, j); + // After the swap, we always have the following. 
+ // i is the dimension without size 1 in either source or target + // j is the dimension with size 1 in either source or target + } + if (target.tile_assignment().dim(j) == 1) { + // dim of size 1 is in the target + if (shape.dimensions(i) % source.tile_assignment().dim(j) != 0) { + continue; + } + new_dim_size = source.tile_assignment().dim(i); + } else { + // dim of size 1 is in the source + if (base_shape.dimensions(i) % source.tile_assignment().dim(i) != + 0) { + continue; + } + new_dim_size = target.tile_assignment().dim(i); + } + break; } - } - for (int j = i + 1; j < target.TiledDataRank(); ++j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Unmerge to Left i = " << i << ",j = " << j; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j + 1]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); + case 0: { + if (source.tile_assignment().dim(i) < + target.tile_assignment().dim(i)) { + std::swap(i, j); + // After the swap, we always have the following. + // source.tile_assignment().dim(i) > target.tile_assignment().dim(i) + // source.tile_assignment().dim(j) < target.tile_assignment().dim(j) + } + if (source.tile_assignment().dim(i) != + target.tile_assignment().dim(i) * + target.tile_assignment().dim(j)) { + continue; + } + if (base_shape.dimensions(i) % source.tile_assignment().dim(i) != 0) { + continue; + } + new_dim_size = target.tile_assignment().dim(i); + break; } + default: + continue; } + + auto reshaped_sharding = + hlo_sharding_util::SplitShardingDimension(source, i, new_dim_size); + std::vector dimensions( + reshaped_sharding.tile_assignment().dimensions().begin(), + reshaped_sharding.tile_assignment().dimensions().end()); + std::swap(dimensions[i + 1], dimensions[j + (j > i ? 1 : 0)]); + auto target_tile_assignment = + target.tile_assignment().Reshape(dimensions); + auto new_sharding = + source.HasPartialReplication() + ? 
HloSharding::PartialTile(target_tile_assignment, + source.metadata()) + : HloSharding::Tile(target_tile_assignment, source.metadata()); + VLOG(10) << "Reshaped sharding before: " << reshaped_sharding.ToString(); + VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); + return std::make_tuple(std::move(reshaped_sharding), + std::move(new_sharding), i); } } + return std::nullopt; } @@ -2046,10 +1998,9 @@ std::optional PartitionedHlo::TryComplexReshardHandling( const bool is_source_partially_replicated = sharding().ReplicateOnLastTileDim(); const bool is_target_partially_replicated = target.ReplicateOnLastTileDim(); - if (auto reshape = - PatternMatchMergeSharding(this->hlo()->shape(), sharding(), target)) { - auto& [before_sharding, new_reshaped_sharding, source_dim, target_dim] = - *reshape; + if (auto reshape = PatternMatchMergeOrSplitSharding( + this->hlo()->shape(), this->base_shape(), sharding(), target)) { + auto& [before_sharding, new_reshaped_sharding, source_dim] = *reshape; VLOG(10) << "Matched \"pattern_match_reshape()\": " << std::get<0>(*reshape).ToString(); VLOG(10) << "Original shape: " << hlo()->shape().ToString(); @@ -2077,39 +2028,6 @@ std::optional PartitionedHlo::TryComplexReshardHandling( } return reshaped; } - if (auto reshape = PatternMatchUnmergeSharding( - this->hlo()->shape(), this->base_shape(), sharding(), target)) { - auto& [before_sharding, new_reshaped_sharding, source_dim, target_dim] = - *reshape; - VLOG(10) << "Matched \"unmerge_sharding()\": " - << new_reshaped_sharding.ToString(); - VLOG(10) << "Original shape: " << hlo()->shape().ToString(); - VLOG(10) << "Base shape: " << base_shape().ToString(); - PartitionedHlo reshaped = SplitReshapeHelper( - *this, source_dim, this->hlo()->shape().dimensions(source_dim), - before_sharding); - VLOG(10) << "Reshaped shape: " << reshaped.hlo()->shape().ToString(); - VLOG(10) << "Reshaped base_shape: " << reshaped.base_shape().ToString(); - VLOG(10) << "Before sharding: " << before_sharding.ToString(); - VLOG(10) << "Reshaped: " << reshaped.hlo()->ToString(); - auto reshard = reshaped.ReshardNoCache(new_reshaped_sharding, - /*pad_value=*/std::nullopt, - /*allow_full_replication=*/false); - if (reshard.sharding() != new_reshaped_sharding) { - return std::nullopt; - } - auto reshaped_sharding = hlo_sharding_util::MergeShardingDimension( - reshard.sharding(), source_dim); - reshaped = MergeReshapeHelper(reshard, source_dim, reshaped_sharding); - if (reshaped.sharding() != target) { - reshaped = reshaped.ReshardNoCache(target, /*pad_value=*/std::nullopt, - /*allow_full_replication=*/false); - if (reshaped.sharding() != target) { - return std::nullopt; - } - } - return reshaped; - } if (auto intermediate_target = PatternMatchPartiallyReplicateDim(sharding(), target)) { VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": " From 75d45d31e2893195dc7a15b0319fa8922c7cca03 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2023 17:09:42 -0800 Subject: [PATCH 311/381] Lower shape.cstr_broadcastable op in ShapeLegalizeToHLO Lower the op to shape_assertion custom_calls. And replace the result shape.witness with a const_witness of true. This allows the subsequent shape.assuming regions to be removed. Also add an option to the pass to control whether to lower cstr_xxx ops. This makes the impl easier as we can reuse util functions in the file. And it makes sure existing use cases are not modified. 
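A rough sketch of the rewrite, as exercised by the lit test added in this
patch (SSA names are illustrative, and the ops that compute
%all_broadcastable are elided): a constraint such as

  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex>

is lowered to a per-dimension broadcastability check whose scalar result feeds
a side-effecting assertion, while the witness itself becomes trivially true so
that the enclosing shape.assuming region can later be canonicalized away:

  mhlo.custom_call @shape_assertion(%all_broadcastable)
      {error_message = "Shape assertion failed", has_side_effect = true}
      : (tensor<i1>) -> ()
  %0 = shape.const_witness true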
PiperOrigin-RevId: 587167546 --- .../mlir_hlo/mhlo/transforms/mhlo_passes.td | 4 + .../xla/xla/mlir_hlo/mhlo/transforms/passes.h | 3 +- .../shape_legalize_to_hlo.cc | 95 ++++++++++++++++--- .../mhlo/shape_cstr_legalize_to_hlo.mlir | 60 ++++++++++++ .../Dialect/mhlo/shape_legalize_to_hlo.mlir | 8 ++ 5 files changed, 158 insertions(+), 12 deletions(-) create mode 100644 third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index e29f17382437d3..5bf2e91fbd4afc 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -392,4 +392,8 @@ def ShapeLegalizeToHloPass : Pass<"shape-legalize-to-hlo", "func::FuncOp"> { compilation pipelines that use HLO operations to model dynamism. }]; let dependentDialects = ["mhlo::MhloDialect"]; + let options = [ + Option<"legalize_constraints_", "legalize-constraints", "bool", + /*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call"> + ]; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h index 946c56ae18ab1b..dd06b8ddd49bcb 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/passes.h @@ -196,7 +196,8 @@ std::unique_ptr> createHloLegalizeToStablehloPass(); std::unique_ptr> createStablehloLegalizeToHloPass(); // Legalizes from the Shape dialect to the MHLO dialect. -std::unique_ptr> createShapeLegalizeToHloPass(); +std::unique_ptr> createShapeLegalizeToHloPass( + bool legalizeConstraints = false); // Test passes. std::unique_ptr createTestInferShapedTypeMethodsPass(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc index 48a1fc1f722810..29c47aa7921417 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -268,8 +269,7 @@ struct ConvertShapeBroadcastOpPattern if (!shape1 || !shape2) return failure(); auto tensorType1 = shape1.getType().dyn_cast(); auto tensorType2 = shape2.getType().dyn_cast(); - if (!tensorType1 || !tensorType2 || tensorType1.getRank() != 1 || - tensorType2.getRank() != 1 || + if (!tensorType1 || !tensorType2 || tensorType1.getDimSize(0) != tensorType2.getDimSize(0)) return failure(); @@ -356,6 +356,68 @@ struct ConvertTensorFromElementsPattern } }; +struct ConvertCstrBroadcastableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter& rewriter) const override { + // Only support broadcasting for two 1D tensors with same size. 
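+    // In outline: both shape operands are cast to i32 tensors, a per-dimension
+    // "broadcastable" bit is computed (the dimension sizes are equal, or either
+    // one is 1), the bits are and-reduced into a scalar, and that scalar is
+    // passed to a side-effecting shape_assertion custom call. The constraint op
+    // itself is then replaced with a shape.const_witness true.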
+ if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = shape1.getType().dyn_cast(); + auto tensorType2 = shape2.getType().dyn_cast(); + if (!tensorType1 || !tensorType2 || + tensorType1.getDimSize(0) != tensorType2.getDimSize(0)) + return failure(); + + // Compute if each dim is broadcastable. A dim is broadcastable iff + // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1 + auto allOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({tensorType1.getDimSize(0)}, + rewriter.getI32Type()), + static_cast(1))); + Value dimSize1Is1 = rewriter.create( + op.getLoc(), shape1, allOne, ComparisonDirection::EQ); + Value dimSize2Is1 = rewriter.create( + op.getLoc(), shape2, allOne, ComparisonDirection::EQ); + Value eitherDimSizeIs1 = + rewriter.create(op.getLoc(), dimSize1Is1, dimSize2Is1); + Value dimSizeEq = rewriter.create( + op.getLoc(), shape1, shape2, ComparisonDirection::EQ); + Value dimBroadcastable = + rewriter.create(op.getLoc(), eitherDimSizeIs1, dimSizeEq); + + // Iterate over each dim to check that all dims are broadcastable. + auto boolType = RankedTensorType::get({1}, rewriter.getI1Type()); + Value allBroadcastable = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(boolType, true)); + for (auto i = 0; i < tensorType1.getDimSize(0); ++i) { + Value broadcastable = rewriter.create( + op.getLoc(), dimBroadcastable, rewriter.getI64TensorAttr(i), + rewriter.getI64TensorAttr(i + 1), rewriter.getI64TensorAttr(1)); + allBroadcastable = + rewriter.create(op.getLoc(), allBroadcastable, broadcastable); + } + Value allBroadcastableScalar = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()), + allBroadcastable); + + // Add CustomCallOp and replace Cstr op with const witness, which is useful + // for canonicalizer to remove the shape.assuming region. + auto customCall = rewriter.create( + op.getLoc(), TypeRange{}, ValueRange{allBroadcastableScalar}); + customCall.setCallTargetName("shape_assertion"); + customCall.setHasSideEffect(true); + customCall->setAttr("error_message", + rewriter.getStringAttr("Shape assertion failed")); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + template struct CastOperandsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -387,6 +449,12 @@ struct CastOperandsPattern : public OpRewritePattern { // needed to support bounded dynamism in MHLO export. struct ShapeLegalizeToHloPass : public impl::ShapeLegalizeToHloPassBase { + explicit ShapeLegalizeToHloPass(bool legalizeConstraints) + : impl::ShapeLegalizeToHloPassBase< + ShapeLegalizeToHloPass>::ShapeLegalizeToHloPassBase() { + this->legalize_constraints_ = legalizeConstraints; + } + void runOnOperation() override { // In order to make dynamic MHLO programs compatible with HLO, // we need to get rid of all non-MHLO ops as well as the two shape-related @@ -413,12 +481,10 @@ struct ShapeLegalizeToHloPass // is able to remove unnecessary cruft. At the moment, this pass is a // work in progress, so not all of these ops are supported. // - // The only problem (and a big problem at that) are the ops involved in - // shape constraints: cstr* ops as well as shape.assuming*. 
Since HLO does - // not support shape constraints, it is currently unclear what to do with - // them, unless they can be removed by --symbolic-shape-optimization. - // At the moment, this pass is a work in progress, so it does not provide - // an answer to this problem yet. + // When legalize_constraints_ is set true, cstr* ops are also legalized. + // A shape_assertion custom_call is used to check the constraint. And the + // shape.assuming region will consume a shape.const_witness that evaluate to + // true, so that it can be removed later in a canonicalizer pass. ConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalDialect(); @@ -429,6 +495,10 @@ struct ShapeLegalizeToHloPass }); target.addLegalOp(); target.addLegalOp(); + if (this->legalize_constraints_) { + target.addLegalOp(); + } // The patterns do what one might expect, converting between MLIR-style // and HLO-style shape computations. @@ -446,6 +516,9 @@ struct ShapeLegalizeToHloPass patterns.add>(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + if (this->legalize_constraints_) { + patterns.add(&getContext()); + } if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); @@ -454,9 +527,9 @@ struct ShapeLegalizeToHloPass } // namespace -std::unique_ptr> -createShapeLegalizeToHloPass() { - return std::make_unique(); +std::unique_ptr> createShapeLegalizeToHloPass( + bool legalizeConstraints) { + return std::make_unique(legalizeConstraints); } } // namespace mhlo diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir new file mode 100644 index 00000000000000..18d1667991edec --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-hlo-opt --shape-legalize-to-hlo=legalize-constraints=true --split-input-file --verify-diagnostics %s | FileCheck %s + +// ----- + +// CHECK-LABEL: func.func @shape_cstr_broadcastable +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[DIMS2]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = 
mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcastable_different_dims(%arg0: tensor<2xindex>, %arg1: tensor<3xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> + shape.assuming %0 { + } + func.return +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir index f81af2fa052715..6d3ee19148b667 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir @@ -186,3 +186,11 @@ func.func @shape_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tens %0 = shape.broadcast %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> func.return %0 : tensor<4xindex> } + +// ----- + +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) -> !shape.witness { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + func.return %0 : !shape.witness +} From 09a463fbac98383ce8d7a9902edda76b8ba884c1 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 1 Dec 2023 17:15:48 -0800 Subject: [PATCH 312/381] Deduplicate stablehlo/experimental passes PiperOrigin-RevId: 587169004 --- third_party/stablehlo/temporary.patch | 2120 ++++++----------- .../xla/third_party/stablehlo/temporary.patch | 2120 ++++++----------- 2 files changed, 1506 insertions(+), 2734 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8abb8b476d8c90..b9fde3ceefd71a 100644 --- a/third_party/stablehlo/temporary.patch +++ 
b/third_party/stablehlo/temporary.patch @@ -1,3 +1,14 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -889,6 +889,7 @@ + hdrs = [ + "stablehlo/transforms/MapStablehloToVhlo.h", + "stablehlo/transforms/Passes.h", ++ "stablehlo/transforms/StablehloRefineShapes.h", + ], + strip_include_prefix = ".", + deps = [ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -159,7 +170,7 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel -@@ -0,0 +1,113 @@ +@@ -0,0 +1,114 @@ +# Copyright 2023 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -243,6 +254,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/ + "//:chlo_ops", + "//:stablehlo_ops", + "//:stablehlo_ops_inc_gen", ++ "//:stablehlo_passes", + "//:stablehlo_type_inference", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", @@ -1826,7 +1838,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp b/stabl diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt --- stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +++ stablehlo/stablehlo/experimental/transforms/CMakeLists.txt -@@ -0,0 +1,38 @@ +@@ -0,0 +1,39 @@ +# Copyright 2023 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -1862,8 +1874,9 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stable + MLIRTransformUtils + ExperimentalStablehloOps + StablehloBase -+ StablehloTypeInference + StablehloOps ++ StablehloPasses ++ StablehloTypeInference +) diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/stablehlo/experimental/transforms/Passes.h --- stablehlo/stablehlo/experimental/transforms/Passes.h @@ -1944,7 +1957,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/s diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp -@@ -0,0 +1,441 @@ +@@ -0,0 +1,167 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2023 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); @@ -1960,14 +1973,12 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy +limitations under the License. 
+==============================================================================*/ + -+#include "llvm/ADT/DenseSet.h" ++#include ++ +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Matchers.h" -+#include "mlir/IR/Value.h" ++#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -1975,6 +1986,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/experimental/dialect/StablehloOps.h" +#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { @@ -1985,169 +1997,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + +namespace { + -+struct CanonicalizeCustomCallOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector refinements; -+ if (failed(hlo::getShapeRefinements(op.getLoc(), op, refinements))) -+ return rewriter.notifyMatchFailure(op, "expected valid refinements"); -+ auto indicesAttr = -+ op->getAttr("indices_of_shape_operands").cast(); -+ DenseSet indices(indicesAttr.value_begin(), -+ indicesAttr.value_end()); -+ -+ // Discard the indices_of_shape_operands attribute. -+ // We rely on the verification logic implemented in getShapeRefinements to -+ // make sure that its value is consistent with the result types. -+ // In the future, when we upgrade indices_of_shape_operands from an -+ // experiment to a full-fledged StableHLO feature, this logic will be moved -+ // to a proper verifier. -+ SmallVector newAttrs; -+ for (auto attr : op->getAttrs()) { -+ if (attr.getName() == "indices_of_shape_operands") continue; -+ if (attr.getName() == "operand_layouts") { -+ // Drop the operand_layouts that correspond to indices_of_shape_operands -+ ArrayAttr operandLayouts = op.getOperandLayoutsAttr(); -+ SmallVector newOperandLayouts; -+ for (unsigned i = 0; i < operandLayouts.size(); ++i) { -+ if (indices.contains(i)) continue; -+ newOperandLayouts.push_back(operandLayouts[i]); -+ } -+ attr = NamedAttribute(attr.getName(), -+ rewriter.getArrayAttr(newOperandLayouts)); -+ } -+ newAttrs.push_back(attr); -+ } -+ -+ // Discard the operands that correspond to indices_of_shape_operands. -+ // We rely on the verification logic implemented in getShapeRefinements to -+ // make sure that: 1) these operands are static, 2) the values of these -+ // operands are consistent with the result types. 
-+ SmallVector newOperands; -+ auto resultIndex = 0; -+ for (auto& operand : op->getOpOperands()) { -+ if (indices.contains(operand.getOperandNumber())) { -+ auto resultType = -+ op->getResult(resultIndex).getType().dyn_cast(); -+ if (!resultType || !resultType.hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, -+ "expected static result types"); -+ ++resultIndex; -+ continue; -+ } -+ newOperands.push_back(operand.get()); -+ } -+ rewriter.replaceOpWithNewOp(op, op.getResultTypes(), -+ newOperands, newAttrs); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicBroadcastInDimOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern discards the output_dimensions operand as well as the -+ // known_expanding_dimensions and known_nonexpanding_dimensions attributes. -+ // We rely on the verifier to make sure that their values are consistent -+ // with the result type. -+ if (!op.getOperand().getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static operand type"); -+ if (!succeeded(hlo::matchInts(op.getOutputDimensions()))) -+ return rewriter.notifyMatchFailure(op, -+ "expected static output_dimensions"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getBroadcastDimensions()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicConvOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicConvOp op, -+ PatternRewriter& rewriter) const override { -+ // ConvolutionOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector padding; -+ if (!succeeded(hlo::matchInts(op.getDPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected static padding"); -+ auto paddingAttr = DenseIntElementsAttr::get( -+ RankedTensorType::get({static_cast(padding.size()) / 2, 2}, -+ rewriter.getI64Type()), -+ padding); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getLhs(), op.getRhs(), op.getWindowStridesAttr(), -+ paddingAttr, op.getLhsDilationAttr(), op.getRhsDilationAttr(), -+ op.getWindowReversalAttr(), op.getDimensionNumbers(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfigAttr()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicGatherOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicGatherOp op, -+ PatternRewriter& rewriter) const override { -+ // GatherOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. 
-+ SmallVector sliceSizes; -+ if (!succeeded(hlo::matchInts(op.getSliceSizes(), sliceSizes))) -+ return rewriter.notifyMatchFailure(op, "expected static slice_sizes"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getStartIndices(), -+ op.getDimensionNumbersAttr(), rewriter.getI64TensorAttr(sliceSizes), -+ op.getIndicesAreSortedAttr()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicIotaOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicIotaOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern discards the output_shape operand. We rely on the verifier -+ // to make sure that its value is consistent with result type. -+ SmallVector outputShape; -+ if (!succeeded(hlo::matchInts(op.getOutputShape(), outputShape))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp(op, op.getType(), -+ op.getIotaDimension()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicPadOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicPadOp op, -+ PatternRewriter& rewriter) const override { -+ // PadOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector edgePaddingLow, edgePaddingHigh, interiorPadding; -+ if (!succeeded(hlo::matchInts(op.getEdgePaddingLow(), edgePaddingLow))) -+ return rewriter.notifyMatchFailure(op, "expected static low"); -+ if (!succeeded(hlo::matchInts(op.getEdgePaddingHigh(), edgePaddingHigh))) -+ return rewriter.notifyMatchFailure(op, "expected static high"); -+ if (!succeeded(hlo::matchInts(op.getInteriorPadding(), interiorPadding))) -+ return rewriter.notifyMatchFailure(op, "expected static interior"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getPaddingValue(), -+ rewriter.getI64TensorAttr(edgePaddingLow), -+ rewriter.getI64TensorAttr(edgePaddingHigh), -+ rewriter.getI64TensorAttr(interiorPadding)); -+ return success(); -+ } -+}; -+ +struct CanonicalizeDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -2196,22 +2045,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + } +}; + -+struct CanonicalizeDynamicReshapeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern ignores and discards the output_shape operand. We rely on -+ // the verifier to make sure that its value is consistent with result type. 
-+ if (!succeeded(hlo::matchInts(op.getOutputShape()))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); -+ return success(); -+ } -+}; -+ +struct CanonicalizeDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -2262,91 +2095,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + } +}; + -+struct CanonicalizeRealDynamicSliceOpToDynamicSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // DynamicSliceOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ -+ // This rewrite only works for unit strides because DynamicSliceOp -+ // doesn't support strides (i.e. it implicitly has unit strides). -+ SmallVector strides; -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides))) -+ return rewriter.notifyMatchFailure(op, "expected static strides"); -+ if (!llvm::all_of(strides, [&](int64_t stride) { return stride == 1; })) -+ return rewriter.notifyMatchFailure(op, "expected unit strides"); -+ -+ // Check that slice sizes are fully static (DynamicSliceOp style). -+ // To detect that, we check whether `limit_indices` is defined as -+ // `start_indices + constant` or `constant + start_indices`. -+ DenseIntElementsAttr sliceSizesAttr; -+ auto m_startIndices = matchers::m_Val(op.getStartIndices()); -+ if (!matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && -+ !matchPattern(op.getLimitIndices(), -+ m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) -+ return rewriter.notifyMatchFailure( -+ op, "expected limit indices equal to start indices plus constant"); -+ -+ // RealDynamicSliceOp can take tensors of integer or index element types. -+ // DynamicSliceOp::slice_sizes only supports i64 element type. -+ // Adapt accordingly in order to be compatible with DynamicSliceOp. -+ SmallVector sliceSizes; -+ for (auto element : sliceSizesAttr.getValues()) { -+ sliceSizes.push_back(element.getSExtValue()); -+ } -+ -+ // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. -+ // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. -+ // Adapt accordingly in order to be compatible with DynamicSliceOp. 
-+ SmallVector startIndices; -+ for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { -+ auto startIndexElementType = -+ op.getStartIndices().getType().getElementType(); -+ auto startIndex1DType = RankedTensorType::get({1}, startIndexElementType); -+ auto startIndex1D = rewriter.create( -+ op.getLoc(), startIndex1DType, op.getStartIndices(), -+ rewriter.getI64TensorAttr(i), rewriter.getI64TensorAttr(i + 1), -+ rewriter.getI64TensorAttr(1)); -+ auto startIndex0DType = RankedTensorType::get({}, startIndexElementType); -+ auto startIndex0D = rewriter.create( -+ op.getLoc(), startIndex0DType, startIndex1D); -+ startIndices.push_back(startIndex0D); -+ } -+ -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), startIndices, -+ rewriter.getI64TensorAttr(sliceSizes)); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeRealDynamicSliceOpToSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // SliceOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector startIndices, limitIndices, strides; -+ if (!succeeded(hlo::matchInts(op.getStartIndices(), startIndices))) -+ return rewriter.notifyMatchFailure(op, "expected static start"); -+ if (!succeeded(hlo::matchInts(op.getLimitIndices(), limitIndices))) -+ return rewriter.notifyMatchFailure(op, "expected static limit"); -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides))) -+ return rewriter.notifyMatchFailure(op, "expected static strides"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), -+ rewriter.getI64TensorAttr(startIndices), -+ rewriter.getI64TensorAttr(limitIndices), -+ rewriter.getI64TensorAttr(strides)); -+ return success(); -+ } -+}; -+ +struct StablehloCanonicalizeDynamismPass + : public impl::StablehloCanonicalizeDynamismPassBase< + StablehloCanonicalizeDynamismPass> { @@ -2362,19 +2110,10 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + patterns.add(&getContext()); -+ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); -+ patterns.add( -+ &getContext()); -+ patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); @@ -2389,7 +2128,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp -@@ -0,0 +1,1308 @@ +@@ -0,0 +1,162 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
@@ -2404,41 +2143,22 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +limitations under the License. +==============================================================================*/ + ++#include "stablehlo/transforms/StablehloRefineShapes.h" ++ +#include -+#include -+#include -+#include + -+#include "llvm/ADT/APInt.h" -+#include "llvm/ADT/APSInt.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/STLFunctionalExtras.h" -+#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" -+#include "llvm/ADT/StringRef.h" -+#include "llvm/Support/ErrorHandling.h" -+#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinOps.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" -+#include "mlir/IR/MLIRContext.h" -+#include "mlir/IR/Matchers.h" -+#include "mlir/IR/OpDefinition.h" -+#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" -+#include "mlir/IR/Types.h" -+#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/Base.h" -+#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/TypeInference.h" +#include "stablehlo/experimental/dialect/StablehloOps.h" +#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { @@ -2449,382 +2169,239 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + +namespace { + -+// DenseElementsAttr can be constructed from ArrayRef but not from -+// ArrayRef. This helper bridges the gap. -+DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { -+ SmallVector supportedValues(values); -+ return DenseIntElementsAttr::get(type, supportedValues); -+} -+ -+APSInt getAPSInt(Type type, uint64_t value) { -+ unsigned numBits; -+ bool isUnsigned; -+ if (auto integerType = type.dyn_cast()) { -+ numBits = integerType.getWidth(); -+ // Signless types are treated as signed, per StableHLO convention. -+ isUnsigned = integerType.isUnsignedInteger(); -+ } else { -+ llvm::report_fatal_error("expected integer type"); -+ } -+ return APSInt({/*numBits=*/numBits, value}, -+ /*isUnsigned=*/isUnsigned); -+} -+ -+// The patterns below implement partial evaluation of shape computations which -+// is a critical part of implementing type refinement for ops like -+// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape -+// depends on the value of their shape operands. 
-+ -+template -+LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, -+ FuncType fn) { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || -+ !resultType.getElementType().template isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ -+ SmallVector result; -+ if constexpr (OpType::template hasTrait()) { -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ for (const auto& operandEl : operand) { -+ result.push_back(fn(operandEl)); -+ } -+ } else if constexpr (OpType::template hasTrait< -+ OpTrait::NOperands<2>::Impl>()) { -+ SmallVector lhs, rhs; -+ if (failed(hlo::matchInts(op.getLhs(), lhs)) || -+ failed(hlo::matchInts(op.getRhs(), rhs))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ for (auto [lhsEl, rhsEl] : llvm::zip(lhs, rhs)) { -+ result.push_back(fn(lhsEl, rhsEl)); -+ } -+ } else if constexpr (OpType::template hasTrait< -+ OpTrait::NOperands<3>::Impl>()) { -+ SmallVector x, y, z; -+ if (failed(hlo::matchInts(op->getOperand(0), x)) || -+ failed(hlo::matchInts(op->getOperand(1), y)) || -+ failed(hlo::matchInts(op->getOperand(2), z))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ for (auto [xEl, yEl, zEl] : llvm::zip(x, y, z)) { -+ result.push_back(fn(xEl, yEl, zEl)); -+ } -+ } else { -+ llvm::report_fatal_error("unsupported number of operands"); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+} -+ -+struct EvalAddOpPattern : public OpRewritePattern { ++struct RefineDynamicReduceWindowOpPattern ++ : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AddOp op, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); -+ } -+}; ++ auto maybeOp = getDynamicReduceWindowOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicReduceWindowOpAdaptor op = *maybeOp; + -+struct EvalAndOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AndOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. 
++ SmallVector windowDimensions, windowStrides, baseDilations, ++ windowDilations, padding; ++ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_dimensions"); ++ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_strides"); ++ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant base_dilations"); ++ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_dilations"); ++ if (failed(hlo::matchInts(op.getPadding(), padding))) ++ return rewriter.notifyMatchFailure(op, "expected constant padding"); + -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); -+ }); ++ SmallVector inferredReturnTypes; ++ if (failed(hlo::inferReduceWindowOp( ++ /*location=*/{}, op.getInputs(), op.getInitValues(), ++ rewriter.getI64TensorAttr(windowDimensions), ++ rewriter.getI64TensorAttr(windowStrides), ++ rewriter.getI64TensorAttr(baseDilations), ++ rewriter.getI64TensorAttr(windowDilations), ++ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) ++ return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); ++ return refineReturnTypes(rewriter, op, inferredReturnTypes); + } +}; + -+struct EvalBroadcastInDimOpPattern : public OpRewritePattern { ++struct RefineDynamicRngBitGeneratorOpPattern ++ : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(BroadcastInDimOp op, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank() || operandType.getRank() != 0) -+ return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); ++ auto maybeOp = getDynamicRngBitGeneratorOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; + -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ auto scalar = operand[0]; ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. ++ auto initialStateType = op.getInitialState().getType().cast(); ++ SmallVector outputShape; ++ if (failed(hlo::matchInts(op.getOutputShape(), outputShape))) ++ return rewriter.notifyMatchFailure(op, "expected constant output_shape"); + -+ rewriter.replaceOpWithNewOp( -+ op, getTensorAttr(op.getType(), scalar)); -+ return success(); ++ // We only need to refine the shape of `output` (the second result). ++ // The shape of `output_state` (the first result) is determined by the shape ++ // of `initial_state`, so we ignore it and provide an empty refinement. 
++ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); + } +}; + -+struct EvalClampOpPattern : public OpRewritePattern { ++struct RefineDynamicTopKOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ClampOp op, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt min, APSInt operand, APSInt max) { -+ if (operand < min) return min; -+ if (max < operand) return max; -+ return operand; -+ }); -+ } -+}; ++ auto maybeOp = getDynamicTopKOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicTopKOpAdaptor op = *maybeOp; + -+struct EvalCompareOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CompareOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ bool result; -+ switch (op.getComparisonDirection()) { -+ case ComparisonDirection::EQ: -+ result = lhs == rhs; -+ break; -+ case ComparisonDirection::NE: -+ result = lhs != rhs; -+ break; -+ case ComparisonDirection::GE: -+ result = lhs >= rhs; -+ break; -+ case ComparisonDirection::GT: -+ result = lhs > rhs; -+ break; -+ case ComparisonDirection::LE: -+ result = lhs <= rhs; -+ break; -+ case ComparisonDirection::LT: -+ result = lhs < rhs; -+ break; -+ } -+ return getAPSInt(resultType.getElementType(), result); -+ }); ++ auto operandType = op.getOperand().getType().cast(); ++ SmallVector outputShape(operandType.getShape()); ++ SmallVector k; ++ if (failed(hlo::matchInts(op.getK(), k))) ++ return rewriter.notifyMatchFailure(op, "expected constant k"); ++ ++ outputShape[operandType.getRank() - 1] = k[0]; ++ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); + } +}; + -+struct EvalConcatenateOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConcatenateOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || op.getDimension() != 0) -+ return rewriter.notifyMatchFailure(op, "expected dimension = 0"); ++struct StablehloRefineShapesPass ++ : public impl::StablehloRefineShapesPassBase { ++ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + -+ SmallVector result; -+ for (Value operand : op->getOperands()) { -+ if (failed(hlo::matchInts(operand, result))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ } ++ void runOnOperation() override { ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+ } -+}; ++ // The algorithm behind this pass consists of a single traversal of the ++ // function. This is sufficient because we only support one function per ++ // program at the moment. ++ // TODO(#1048): Find out why .maxIterations = 1 no longer works. ++ // There have been recent refactors to applyPatternsAndFoldGreedily ++ // upstream, and that might be the reason. 
++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 2; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::AnyOp; + -+struct EvalConvertOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvertOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth(); -+ return evalElementwise(rewriter, op, [&](APSInt operand) { -+ return operand.extOrTrunc(resultBitWidth); -+ }); ++ RewritePatternSet patterns(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); ++ patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ if (failed( ++ applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { ++ return signalPassFailure(); ++ } + } +}; + -+struct EvalDivOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DivOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); -+ } -+}; ++} // namespace ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +@@ -340,6 +340,19 @@ + %1 = stablehlo.constant dense<2> : tensor + %2 = stablehlo.multiply %0, %1 : tensor + func.return %2 : tensor ++} + -+struct EvalGetDimensionSizeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(GetDimensionSizeOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand"); -+ if (operandType.isDynamicDim(op.getDimension())) -+ return rewriter.notifyMatchFailure(op, "expected static dimension"); -+ -+ auto result = operandType.getDimSize(op.getDimension()); -+ rewriter.replaceOpWithNewOp( -+ op, DenseIntElementsAttr::get(op.getType(), result)); -+ return success(); -+ } -+}; ++// ----- + -+struct EvalMaxOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MaxOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ return lhs >= rhs ? lhs : rhs; -+ }); -+ } -+}; -+ -+struct EvalMinOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MinOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ return lhs <= rhs ? 
lhs : rhs; -+ }); -+ } -+}; -+ -+struct EvalMulOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MulOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); -+ } -+}; -+ -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ -+struct EvalRemOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RemOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); -+ } -+}; -+ -+struct EvalReshapeOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ DenseIntElementsAttr attr; -+ if (!matchPattern(op.getOperand(), m_Constant(&attr))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ rewriter.replaceOpWithNewOp(op, attr.reshape(op.getType())); -+ return success(); -+ } -+}; -+ -+struct EvalSelectOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SelectOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector pred, onTrue, onFalse; -+ if (failed(hlo::matchInts(op.getPred(), pred)) || -+ failed(hlo::matchInts(op.getOnTrue(), onTrue)) || -+ failed(hlo::matchInts(op.getOnFalse(), onFalse))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ -+ SmallVector result; -+ for (auto [predEl, onTrueEl, onFalseEl] : -+ llvm::zip(pred, onTrue, onFalse)) { -+ result.push_back(predEl != 0 ? 
onTrueEl : onFalseEl); -+ } -+ -+ rewriter.replaceOpWithNewOp( -+ op, getTensorAttr(op.getType(), result)); -+ return success(); -+ } -+}; -+ -+struct EvalSignOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SignOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ return evalElementwise(rewriter, op, [&](APSInt operand) { -+ int64_t result; -+ if (operand.isNegative()) -+ result = -1; -+ else if (operand.isZero()) -+ result = 0; -+ else -+ result = 1; -+ return getAPSInt(resultType.getElementType(), result); -+ }); -+ } -+}; -+ -+struct EvalSliceOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SliceOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || resultType.getRank() != 1) -+ return rewriter.notifyMatchFailure(op, "expected 1-dimensional type"); -+ -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ -+ int64_t start = op.getStartIndices().getValues()[0]; -+ int64_t limit = op.getLimitIndices().getValues()[0]; -+ int64_t stride = op.getStrides().getValues()[0]; -+ SmallVector result; -+ for (auto i = start; i < limit; i += stride) { -+ result.push_back(operand[i]); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+ } -+}; -+ -+struct EvalSubtractOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SubtractOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); -+ } -+}; -+ -+// The patterns below implement shape refinement of individual ops. -+// In a nutshell, they use the upstream type inference infrastructure and a -+// StableHLO-specific extension to refine return types based on potentially -+// refined operands. ++// CHECK-LABEL: func @eval_or ++func.func @eval_or() -> tensor { ++ // CHECK-NOT: stablehlo.or ++ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor ++ // CHECK: return [[RESULT]] ++ %0 = stablehlo.constant dense : tensor ++ %1 = stablehlo.constant dense : tensor ++ %2 = stablehlo.or %0, %1 : tensor ++ func.return %2 : tensor + } + + // ----- +diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h +--- stablehlo/stablehlo/transforms/Passes.h ++++ stablehlo/stablehlo/transforms/Passes.h +@@ -18,9 +18,12 @@ + + #include + ++#include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/Dialect/Quant/QuantOps.h" + #include "mlir/Dialect/Shape/IR/Shape.h" ++#include "mlir/IR/BuiltinOps.h" + #include "mlir/Pass/Pass.h" ++#include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/DialectConversion.h" + + namespace mlir { +@@ -33,6 +36,14 @@ + #define GEN_PASS_DECL_VHLOTOVERSIONPASS + #define GEN_PASS_REGISTRATION + #include "stablehlo/transforms/Passes.h.inc" ++ ++// Populates --stablehlo-canonicalize-dynamism patterns. ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); ++ ++// Populates --stablehlo-refine-shapes patterns. 
++void populateStablehloRefineShapesPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); + + // Populates StableHLO ops to VHLO ops rewriting patterns. + void populateStablehloToVhloPatterns(RewritePatternSet *patterns, +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -314,16 +314,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add( +- &getContext()); +- patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); +@@ -332,5 +323,19 @@ + }; + + } // namespace ++ ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -11,6 +11,8 @@ + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + -+// Refines the values using the given types. -+// Tricky implementation details: -+// 1) Need to support partial shape refinements, e.g. if just a single -+// dimension size out of an entire tensor type got refined. This is done -+// via inferMostSpecificType. -+// 2) Need to signal propagation of the refined shapes across the -+// StableHLO program. Different callers of this function have different -+// propagation needs, so this function doesn't signal anything on its own -+// and leaves that to the callers. ++#include "stablehlo/transforms/StablehloRefineShapes.h" + + #include + #include +@@ -53,6 +55,193 @@ + #define GEN_PASS_DEF_STABLEHLOREFINESHAPESPASS + #include "stablehlo/transforms/Passes.h.inc" + +LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, + ValueRange values, TypeRange types) { + if (values.size() != types.size()) @@ -2911,10 +2488,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return success(); +} + -+// Refines the return types of the given operation using the given types. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. 
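The populateStablehloCanonicalizeDynamismPatterns and populateStablehloRefineShapesPatterns hooks declared above let other passes reuse these pattern sets instead of copying the hard-coded pattern lists. A hedged sketch of such reuse, modeled on how the experimental pass in this patch consumes the hook; the pass name and the extra pattern are hypothetical:

void MyRefinePass::runOnOperation() {  // hypothetical pass reusing the hook
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true;
  config.enableRegionSimplification = true;
  config.maxIterations = 2;
  config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
  config.strictMode = GreedyRewriteStrictness::AnyOp;

  RewritePatternSet patterns(&getContext());
  // Shared shape-refinement patterns exported by this patch.
  populateStablehloRefineShapesPatterns(&patterns, &getContext());
  // Pass-specific additions (hypothetical).
  patterns.add<MyExtraRefinePattern>(&getContext());

  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                          config)))
    return signalPassFailure();
}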
+LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, + ArrayRef types) { + if (failed(refineValues(rewriter, op, op->getResults(), types))) @@ -2929,19 +2502,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return success(); +} + -+// Refines the return types of the given operation using the given types. -+// Tricky implementation details: -+// 1) `types` can include non-shaped types. If there are tuple types, -+// then they are first flattened into non-tuple types using in-order -+// traversal, and only then we apply the refinements. If there are other -+// types, then the corresponding refinements must be completely empty. -+// 2) Encodings are not supported. In principle, TypeExtensions should be -+// supportable, but this needs careful thinking through. Given that no one -+// asked for support for bounded dynamism in this pass yet, this is left -+// for future work. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. +LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, + ArrayRef refinements) { + SmallVector flattenedTypes; @@ -3028,6 +2588,526 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return failure(); + return refineReturnTypes(rewriter, op, refinedTypes); +} ++ + namespace { + + // DenseElementsAttr can be constructed from ArrayRef but not from +@@ -304,6 +493,20 @@ + } + }; + ++struct EvalOrOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(OrOp op, ++ PatternRewriter& rewriter) const override { ++ auto resultType = op.getType(); ++ if (!resultType.getElementType().isInteger(1)) ++ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ ++ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { ++ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); ++ }); ++ } ++}; ++ + struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RemOp op, +@@ -407,245 +610,6 @@ + // In a nutshell, they use the upstream type inference infrastructure and a + // StableHLO-specific extension to refine return types based on potentially + // refined operands. +- +-// Refines the values using the given types. +-// Tricky implementation details: +-// 1) Need to support partial shape refinements, e.g. if just a single +-// dimension size out of an entire tensor type got refined. This is done +-// via inferMostSpecificType. +-// 2) Need to signal propagation of the refined shapes across the +-// StableHLO program. Different callers of this function have different +-// propagation needs, so this function doesn't signal anything on its own +-// and leaves that to the callers. +-LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, +- ValueRange values, TypeRange types) { +- if (values.size() != types.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineValues failed for " << types << ": expected " +- << values.size() << " types, got " << types.size(); +- }); +- +- // Check whether `types` contain any new information with respect to existing +- // return types. Even if just a single dimension size out of an entire tensor +- // type got updated, using `inferMostSpecificType` ensures that we don't +- // miss that. 
+- bool needsRefinement = false; +- SmallVector refinedTypes; +- for (auto it : llvm::zip(values.getTypes(), types)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto currentType = std::get<0>(it); +- auto refinement = std::get<1>(it); +- auto refinedType = hlo::inferMostSpecificType( +- /*location=*/{}, {currentType, refinement}); +- if (failed(refinedType)) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "inferMostSpecificType failed for " << currentType << " and " +- << refinement; +- }); +- refinedTypes.push_back(*refinedType); +- needsRefinement |= (currentType != *refinedType); +- } +- if (!needsRefinement) +- return rewriter.notifyMatchFailure(op, "doesn't need refinement"); +- +- for (auto it : llvm::zip(values, refinedTypes)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto value = std::get<0>(it); +- auto refinedType = std::get<1>(it); +- if (value.getType() == refinedType) continue; +- +- // Check whether the users of this value are ready for the type of the +- // value to be refined. +- for (Operation* user : value.getUsers()) { +- // CHLO and StableHLO ops are designed to support type refinements of +- // their operands and results. Any operand type in these ops can change +- // within what's supported by `inferMostSpecificType` without breaking +- // verification of the op. +- if (isa(user->getDialect())) +- continue; +- +- // Simply changing operand type of `func.return` won't work because +- // that won't update the FunctionType of the enclosing `func.func`. +- // Nonetheless, we still want to support these ops because they are widely +- // used in StableHLO programs (although the plan of record is to replace +- // `func.return` ops in StableHLO programs with `stablehlo.return`: +- // https://github.com/openxla/stablehlo/issues/425). +- if (isa(user)) continue; +- +- // Unlike in TensorFlow's type inference pass, here we work only with +- // allowlisted ops to focus our support on well-defined semantics of +- // StableHLO programs. +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "unsupported refinement: tried to refine " << value.getType() +- << " to " << refinedType << " for user " << user; +- }); +- } +- +- // Happy path: simply call setType here because most of our users are +- // fine with that. +- auto unrefinedType = value.getType(); +- value.setType(refinedType); +- +- // Special case: for `func.return`, guard the refinement with a cast +- // and leave propagation of the refined return type to a dedicated pattern. +- auto isFuncReturn = [](OpOperand& use) -> bool { +- return isa(use.getOwner()); +- }; +- if (llvm::none_of(value.getUses(), isFuncReturn)) continue; +- rewriter.setInsertionPointAfter(op); +- auto castToUnrefinedType = rewriter.create( +- op->getLoc(), unrefinedType, value); +- value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); +- } +- +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. 
+-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef types) { +- if (failed(refineValues(rewriter, op, op->getResults(), types))) +- return failure(); +- +- // This `replaceOpWithIf` call doesn't actually change the IR, but +- // it does ask the rewriter to visit all the users of this op. There is no +- // upstream API to achieve this directly, but if it's introduced in the +- // future, we could use it here. +- rewriter.replaceOpWithIf(op, op->getResults(), +- [](OpOperand& use) { return false; }); +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// Tricky implementation details: +-// 1) `types` can include non-shaped types. If there are tuple types, +-// then they are first flattened into non-tuple types using in-order +-// traversal, and only then we apply the refinements. If there are other +-// types, then the corresponding refinements must be completely empty. +-// 2) Encodings are not supported. In principle, TypeExtensions should be +-// supportable, but this needs careful thinking through. Given that no one +-// asked for support for bounded dynamism in this pass yet, this is left +-// for future work. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef refinements) { +- SmallVector flattenedTypes; +- hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); +- auto flattenedSize = flattenedTypes.size(); +- if (flattenedSize != refinements.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineReturnTypes failed: expected " << flattenedSize +- << " refinements, got " << refinements.size(); +- }); +- +- SmallVector flattenedRefinedTypes; +- for (auto it : llvm::zip(flattenedTypes, refinements)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- ShapedType currentType = std::get<0>(it).dyn_cast(); +- ShapedTypeComponents refinement = std::get<1>(it); +- auto failWithReason = [&](StringRef reason) { +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineTypes failed: refining " << currentType +- << "with refinement: {"; +- if (refinement.hasRank()) { +- diag << "shape = [" << refinement.getDims() << "]"; +- if (refinement.getAttribute()) +- diag << "attribute = " << refinement.getAttribute(); +- } else { +- diag << "hasRank = false"; +- } +- diag << ", elementType = " << refinement.getElementType(); +- diag << "} failed: " << reason; +- }); +- }; +- +- // If the current type is not a shaped type, then the refinement must +- // be completely empty. +- if (!currentType) { +- if (refinement.hasRank() || refinement.getElementType() || +- refinement.getAttribute()) +- return failWithReason("unsupported refinement"); +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If the refinement has an element type, then it must be the same as +- // the current element type. 
+- Type currentElementType = currentType.getElementType(); +- if (refinement.getElementType() && +- currentElementType != refinement.getElementType()) +- return failWithReason("expected compatible element types"); +- +- // If neither the current type nor the refinement are ranked, then there's +- // nothing to refine, and we return the current type. +- bool hasRank = currentType.hasRank() || refinement.hasRank(); +- if (!hasRank) { +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If either the current type or the refinement have encodings, then +- // we fail. Encodings are left for future work. +- Attribute currentEncoding = nullptr; +- if (auto currentRankedType = currentType.dyn_cast()) { +- currentEncoding = currentRankedType.getEncoding(); +- } +- Attribute refinedEncoding = refinement.getAttribute(); +- if (currentEncoding || refinedEncoding) +- return failWithReason("expected compatible encodings"); +- +- // If both the current type and the refinement have shapes, use the shape +- // from the refinement. Otherwise, pick whatever is available. +- // Make sure that the resulting type is compatible with the current type +- // to avoid creating invalid code. +- auto refinedShape = +- refinement.hasRank() ? refinement.getDims() : currentType.getShape(); +- auto refinedType = RankedTensorType::get(refinedShape, currentElementType); +- if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) +- return failWithReason("expected compatible shapes"); +- flattenedRefinedTypes.push_back(refinedType); +- } +- +- SmallVector refinedTypes; +- if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), +- flattenedRefinedTypes, refinedTypes))) +- return failure(); +- return refineReturnTypes(rewriter, op, refinedTypes); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- ArrayRef shape) { +- return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- Value shapeValue) { +- // At the moment, we only support refining return types using fully static +- // shape values which serves the current use cases well. +- // Support for partially static shape values is left for future work. +- SmallVector shape; +- if (failed(hlo::matchInts(shapeValue, shape))) +- return rewriter.notifyMatchFailure(op, "expected constant output shape"); +- return refineReturnShape(rewriter, op, shape); +-} + + struct RefineAllGatherOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; +@@ -1105,39 +1069,8 @@ + using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + + void runOnOperation() override { +- // Only one function per module is supported at the moment to avoid the need +- // to think about iterative type inference algorithms. +- // Current use cases are served well by inlining multiple functions into +- // a single function, so we leave native support for multiple functions to +- // future work. 
+- // To enable modules that contain CustomCallOp::called_computations, +- // we allow multiple functions, in which case we only refine the main +- // function called "main", assuming that the called computations will have +- // static shapes. Lifting this assumption and expanding refinement to +- // multiple functions is left for future work. +- ModuleOp module = getOperation(); +- auto funcs = llvm::to_vector(module.getOps()); +- if (funcs.empty()) return; +- func::FuncOp func; +- if (funcs.size() == 1) { +- func = funcs[0]; +- } else { +- func = module.lookupSymbol("main"); +- } +- if (!func) { +- module.emitOpError() +- << "must have no more than one function or a `main`" +- << " function to clearly identify which function will be refined"; +- return signalPassFailure(); +- } +- +- // Similarly, only one block per function is supported at the moment. +- // At the StableHLO level, functions are expected to only have one block, +- // so supporting more is out of scope for this pass. +- if (!func.getRegion().hasOneBlock()) { +- func.emitOpError() << "must have exactly one block"; +- return signalPassFailure(); +- } ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + + // The algorithm behind this pass consists of a single traversal of the + // function. This is sufficient because we only support one function per +@@ -1153,43 +1086,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + if (failed( + applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { + return signalPassFailure(); +@@ -1198,5 +1095,86 @@ + }; + + } // namespace ++ ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) { ++ // Only one function per module is supported at the moment to avoid the need ++ // to think about iterative type inference algorithms. ++ // Current use cases are served well by inlining multiple functions into ++ // a single function, so we leave native support for multiple functions to ++ // future work. ++ // To enable modules that contain CustomCallOp::called_computations, ++ // we allow multiple functions, in which case we only refine the main ++ // function called "main", assuming that the called computations will have ++ // static shapes. 
Lifting this assumption and expanding refinement to ++ // multiple functions is left for future work. ++ auto funcs = llvm::to_vector(module.getOps()); ++ if (funcs.empty()) return nullptr; ++ ++ func::FuncOp result; ++ if (funcs.size() == 1) { ++ result = funcs[0]; ++ } else { ++ result = module.lookupSymbol("main"); ++ } ++ if (!result) { ++ module.emitOpError() ++ << "must have no more than one function or a `main`" ++ << " function to clearly identify which function will be refined"; ++ return nullptr; ++ } ++ ++ // Similarly, only one block per function is supported at the moment. ++ // At the StableHLO level, functions are expected to only have one block, ++ // so supporting more is out of scope for this pass. ++ if (!result.getRegion().hasOneBlock()) { ++ result.emitOpError() << "must have exactly one block"; ++ return nullptr; ++ } ++ ++ return result; ++} ++ ++void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h +@@ -0,0 +1,102 @@ ++/* Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ ++#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/BuiltinOps.h" ++#include "mlir/IR/Operation.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/Types.h" ++#include "mlir/IR/Value.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "stablehlo/dialect/Base.h" ++ ++namespace mlir { ++namespace stablehlo { ++ ++// Gets a FuncOp that --stablehlo-refine-shapes will run on. ++// Returns a nullptr and emits appropriate errors if such a function cannot ++// be obtained from the module. ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module); ++ ++// Refines the values using the given types. ++// Tricky implementation details: ++// 1) Need to support partial shape refinements, e.g. if just a single ++// dimension size out of an entire tensor type got refined. This is done ++// via inferMostSpecificType. ++// 2) Need to signal propagation of the refined shapes across the ++// StableHLO program. Different callers of this function have different ++// propagation needs, so this function doesn't signal anything on its own ++// and leaves that to the callers. ++LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, ++ ValueRange values, TypeRange types); ++ ++// Refines the return types of the given operation using the given types. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef types); ++ ++// Refines the return types of the given operation using the given types. ++// Tricky implementation details: ++// 1) `types` can include non-shaped types. If there are tuple types, ++// then they are first flattened into non-tuple types using in-order ++// traversal, and only then we apply the refinements. If there are other ++// types, then the corresponding refinements must be completely empty. ++// 2) Encodings are not supported. In principle, TypeExtensions should be ++// supportable, but this needs careful thinking through. Given that no one ++// asked for support for bounded dynamism in this pass yet, this is left ++// for future work. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef refinements); + +// Refines the return type of the given operation using the given shape. 
+// This function also signals PatternRewriter that it needs to visit all the @@ -3055,702 +3135,8 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return refineReturnShape(rewriter, op, shape); +} + -+struct RefineAllGatherOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AllGatherOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // This represents the cross_replica_and_partition process grouping strategy -+ // that requires num_partitions to compute shardCount. Since we don't know -+ // num_partitions at this point, we error out. -+ if (op.getChannelHandle() && !op.getUseGlobalDeviceIds()) -+ return rewriter.notifyMatchFailure(op, "unsupported strategy"); -+ DenseIntElementsAttr replicaGroups = op.getReplicaGroups(); -+ auto shardCount = replicaGroups.getType().getDimSize(1); -+ -+ SmallVector refinement(operandType.getShape()); -+ if (!operandType.isDynamicDim(op.getAllGatherDim())) -+ refinement[op.getAllGatherDim()] *= shardCount; -+ return refineReturnShape(rewriter, op, refinement); -+ } -+}; -+ -+struct RefineBitcastConvertOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(BitcastConvertOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // If bit widths of the operand and the result are different, then -+ // operand and result shapes have different ranks. -+ // This complicates the logic quite a bit and is not needed to pass the -+ // current tests, so we leave this for future work. 
-+ auto resultType = op.getType(); -+ auto getBitWidthFn = [](ShapedType type) { -+ auto elementType = type.getElementType(); -+ if (auto complexType = elementType.dyn_cast()) -+ return complexType.getElementType().getIntOrFloatBitWidth(); -+ return elementType.getIntOrFloatBitWidth(); -+ }; -+ -+ if (getBitWidthFn(operandType) != getBitWidthFn(resultType)) -+ return rewriter.notifyMatchFailure(op, "unsupported bit width"); -+ -+ return refineReturnShape(rewriter, op, operandType.getShape()); -+ } -+}; -+ -+struct RefineConvertOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvertOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvertOp( -+ /*location=*/{}, op.getOperand(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvertOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; ++} // namespace stablehlo ++} // namespace mlir + -+struct RefineConvolutionOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvolutionOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvolutionOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getWindowStrides(), op.getPadding(), op.getLhsDilation(), -+ op.getRhsDilation(), op.getWindowReversal(), -+ op.getDimensionNumbers().getInputBatchDimension(), -+ op.getDimensionNumbers().getInputFeatureDimension(), -+ op.getDimensionNumbers().getInputSpatialDimensions(), -+ op.getDimensionNumbers().getKernelInputFeatureDimension(), -+ op.getDimensionNumbers().getKernelOutputFeatureDimension(), -+ op.getDimensionNumbers().getKernelSpatialDimensions(), -+ op.getDimensionNumbers().getOutputBatchDimension(), -+ op.getDimensionNumbers().getOutputFeatureDimension(), -+ op.getDimensionNumbers().getOutputSpatialDimensions(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvolutionOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineCustomCallOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector refinements; -+ if (failed(hlo::getShapeRefinements(op.getLoc(), op, refinements))) -+ return rewriter.notifyMatchFailure(op, "expected valid refinements"); -+ return refineReturnTypes(rewriter, op, refinements); -+ } -+}; -+ -+struct RefineDotGeneralOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DotGeneralOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferDotGeneralOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getDotDimensionNumbersAttr().getLhsBatchingDimensions(), -+ op.getDotDimensionNumbersAttr().getRhsBatchingDimensions(), -+ op.getDotDimensionNumbersAttr().getLhsContractingDimensions(), -+ op.getDotDimensionNumbersAttr().getRhsContractingDimensions(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferDotGeneralOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct 
RefineDynamicBroadcastInDimOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputDimensions()); -+ } -+}; -+ -+struct RefineDynamicConvOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicConvOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector padding; -+ if (failed(hlo::matchInts(op.getDPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected constant d_padding"); -+ if (op.getPadding().has_value()) -+ return rewriter.notifyMatchFailure(op, "expected empty padding"); -+ auto paddingType = RankedTensorType::get( -+ op.getDPadding().getType().getShape(), rewriter.getIntegerType(64)); -+ auto paddingAttr = DenseIntElementsAttr::get(paddingType, padding); -+ -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvolutionOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getWindowStrides(), paddingAttr, op.getLhsDilation(), -+ op.getRhsDilation(), op.getWindowReversal(), -+ op.getDimensionNumbers().getInputBatchDimension(), -+ op.getDimensionNumbers().getInputFeatureDimension(), -+ op.getDimensionNumbers().getInputSpatialDimensions(), -+ op.getDimensionNumbers().getKernelInputFeatureDimension(), -+ op.getDimensionNumbers().getKernelOutputFeatureDimension(), -+ op.getDimensionNumbers().getKernelSpatialDimensions(), -+ op.getDimensionNumbers().getOutputBatchDimension(), -+ op.getDimensionNumbers().getOutputFeatureDimension(), -+ op.getDimensionNumbers().getOutputSpatialDimensions(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvolutionOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineDynamicIotaOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicIotaOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ -+struct RefineDynamicPadOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicPadOp op, -+ PatternRewriter& rewriter) const override { -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. 
-+ SmallVector edgePaddingLow, edgePaddingHigh, interiorPadding; -+ if (failed(hlo::matchInts(op.getEdgePaddingLow(), edgePaddingLow))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant edge_padding_low"); -+ if (failed(hlo::matchInts(op.getEdgePaddingHigh(), edgePaddingHigh))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant edge_padding_high"); -+ if (failed(hlo::matchInts(op.getInteriorPadding(), interiorPadding))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant interior_padding"); -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferPadOp( -+ /*location=*/{}, op.getOperand().getType(), -+ op.getPaddingValue().getType(), -+ rewriter.getI64TensorAttr(edgePaddingLow), -+ rewriter.getI64TensorAttr(edgePaddingHigh), -+ rewriter.getI64TensorAttr(interiorPadding), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferPadOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineDynamicReduceWindowOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicReduceWindowOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicReduceWindowOpAdaptor op = *maybeOp; -+ -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ SmallVector windowDimensions, windowStrides, baseDilations, -+ windowDilations, padding; -+ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_dimensions"); -+ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_strides"); -+ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant base_dilations"); -+ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_dilations"); -+ if (failed(hlo::matchInts(op.getPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected constant padding"); -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferReduceWindowOp( -+ /*location=*/{}, op.getInputs(), op.getInitValues(), -+ rewriter.getI64TensorAttr(windowDimensions), -+ rewriter.getI64TensorAttr(windowStrides), -+ rewriter.getI64TensorAttr(baseDilations), -+ rewriter.getI64TensorAttr(windowDilations), -+ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineDynamicReshapeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ -+struct RefineDynamicRngBitGeneratorOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicRngBitGeneratorOp(impl); -+ if (!maybeOp || 
failed(maybeOp->verify())) return failure(); -+ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; -+ -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ auto initialStateType = op.getInitialState().getType().cast(); -+ SmallVector outputShape; -+ if (failed(hlo::matchInts(op.getOutputShape(), outputShape))) -+ return rewriter.notifyMatchFailure(op, "expected constant output_shape"); -+ -+ // We only need to refine the shape of `output` (the second result). -+ // The shape of `output_state` (the first result) is determined by the shape -+ // of `initial_state`, so we ignore it and provide an empty refinement. -+ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); -+ } -+}; -+ -+struct RefineDynamicTopKOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicTopKOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicTopKOpAdaptor op = *maybeOp; -+ -+ auto operandType = op.getOperand().getType().cast(); -+ SmallVector outputShape(operandType.getShape()); -+ SmallVector k; -+ if (failed(hlo::matchInts(op.getK(), k))) -+ return rewriter.notifyMatchFailure(op, "expected constant k"); -+ -+ outputShape[operandType.getRank() - 1] = k[0]; -+ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); -+ } -+}; -+ -+struct RefineInferTypeOpInterfacePattern -+ : public OpInterfaceRewritePattern { -+ explicit RefineInferTypeOpInterfacePattern(MLIRContext* context) -+ : OpInterfaceRewritePattern(context, /*benefit=*/0) {} -+ LogicalResult matchAndRewrite(InferTypeOpInterface op, -+ PatternRewriter& rewriter) const override { -+ // Unlike in TensorFlow's type inference pass, here we work only with -+ // allowlisted ops to focus our support on well-defined semantics of -+ // StableHLO programs. -+ if (!isa(op->getDialect())) -+ return rewriter.notifyMatchFailure(op, "unsupported dialect"); -+ -+ // For the ops that implement InferTypeOpInterface, we reinfer their return -+ // types and see what happens. -+ // Operands of these ops might have been refined elsewhere (e.g. someone -+ // might have updated argument types of a function) or earlier during this -+ // pass, and this might enable refinement opportunities downstream. -+ SmallVector inferredReturnTypes; -+ if (failed(op.inferReturnTypes(getContext(), /*location=*/{}, -+ op->getOperands(), op->getAttrDictionary(), -+ op->getPropertiesStorage(), op->getRegions(), -+ inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferReturnTypes failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineRealDynamicSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // Alternative #1: All attributes are fully static (SliceOp style). 
-+ SmallVector startIndices, limitIndices, strides; -+ if (succeeded(hlo::matchInts(op.getStartIndices(), startIndices)) && -+ succeeded(hlo::matchInts(op.getLimitIndices(), limitIndices)) && -+ succeeded(hlo::matchInts(op.getStrides(), strides))) { -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferSliceOp(/*location=*/{}, op.getOperand().getType(), -+ rewriter.getI64TensorAttr(startIndices), -+ rewriter.getI64TensorAttr(limitIndices), -+ rewriter.getI64TensorAttr(strides), -+ inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferSliceOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+ -+ // Alternative #2: Slice sizes are fully static (DynamicSliceOp style). -+ // To detect that, we check whether `limit_indices` is defined as -+ // `start_indices + constant` or `constant + start_indices`. -+ DenseIntElementsAttr sliceSizesAttr; -+ auto m_startIndices = matchers::m_Val(op.getStartIndices()); -+ if (matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) || -+ matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) { -+ SmallVector strides; -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides)) || -+ !llvm::all_of(strides, [&](int64_t stride) { return stride == 1; })) -+ return rewriter.notifyMatchFailure(op, "expected unit strides"); -+ -+ // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. -+ // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. -+ // Adapt accordingly in order to be compatible with inferDynamicSliceOp. -+ auto startIndicesElementType = -+ op.getStartIndices().getType().getElementType(); -+ SmallVector startIndicesTypes( -+ sliceSizesAttr.size(), -+ RankedTensorType::get({}, startIndicesElementType)); -+ -+ // RealDynamicSliceOp can take tensors of integer or index element types. -+ // DynamicSliceOp::slice_sizes only supports i64 element type. -+ // Adapt accordingly in order to be compatible with inferDynamicSliceOp. -+ SmallVector sliceSizes; -+ for (auto element : sliceSizesAttr.getValues()) { -+ sliceSizes.push_back(element.getSExtValue()); -+ } -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferDynamicSliceOp( -+ op.getLoc(), op.getOperand().getType(), startIndicesTypes, -+ rewriter.getI64TensorAttr(sliceSizes), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferDynamicSliceOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+ -+ return rewriter.notifyMatchFailure( -+ op, -+ "expected either fully static attributes (SliceOp style) " -+ "or static sliceSizes (DynamicSliceOp style)"); -+ } -+}; -+ -+struct RefineReduceScatterOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReduceScatterOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // This represents the cross_replica_and_partition process grouping strategy -+ // that requires num_partitions to compute shardCount. Since we don't know -+ // num_partitions at this point, we error out. 
-+ if (op.getChannelHandle() && !op.getUseGlobalDeviceIds()) -+ return rewriter.notifyMatchFailure(op, "unsupported strategy"); -+ DenseIntElementsAttr replicaGroups = op.getReplicaGroups(); -+ auto shardCount = replicaGroups.getType().getDimSize(1); -+ -+ SmallVector refinement(operandType.getShape()); -+ if (!operandType.isDynamicDim(op.getScatterDimension())) -+ refinement[op.getScatterDimension()] /= shardCount; -+ return refineReturnShape(rewriter, op, refinement); -+ } -+}; -+ -+struct RefineRngOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RngOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getShape()); -+ } -+}; -+ -+struct RefineUniformQuantizeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(UniformQuantizeOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferUniformQuantizeOp( -+ /*location=*/{}, op.getOperand(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvertOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineWhileOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(WhileOp op, -+ PatternRewriter& rewriter) const override { -+ // Push the potentially refined operand types into the nested regions. -+ // This can lead to refinements of the return types of the body (but not -+ // of the cond since it always returns tensor), but the key insight here -+ // is that the enclosing while op doesn't care about these refinements -+ // (because its return types are equal to its operand types). -+ // If we end up with incompatibilities between while's return types and -+ // body's return types, the verifier will tell us about that. This means -+ // that the original program wasn't well-formed. TODO(burmako): Implement -+ // better error reporting for this case. -+ // This serves the current use cases well, so the implementation of more -+ // sophisticated refinement algorithm is left for future work. -+ rewriter.startRootUpdate(op); -+ auto condStatus = refineValues(rewriter, op, op.getCond().getArguments(), -+ op.getOperandTypes()); -+ auto bodyStatus = refineValues(rewriter, op, op.getBody().getArguments(), -+ op.getOperandTypes()); -+ if (succeeded(condStatus) || succeeded(bodyStatus)) { -+ rewriter.finalizeRootUpdate(op); -+ return success(); -+ } else { -+ rewriter.cancelRootUpdate(op); -+ return failure(); -+ } -+ } -+}; -+ -+struct UpdateFunctionTypePattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(func::ReturnOp op, -+ PatternRewriter& rewriter) const override { -+ // Check whether any of the values returned by `func.return` are casts -+ // which convert more specific type to less specific type. -+ // Such ops are produced by the algorithm behind this pass to avoid -+ // bringing the enclosing `func.func` op into an inconsistent state when -+ // refining individual ops. This pattern cleans this up. 
-+ bool needsUpdate = false; -+ SmallVector updatedResultTypes(op.getOperandTypes()); -+ llvm::SmallSet castsToReplace; -+ for (auto [i, operand] : llvm::enumerate(op.getOperands())) { -+ auto cast = -+ dyn_cast_or_null(operand.getDefiningOp()); -+ if (!cast || cast.getInputs().size() != 1 || -+ cast.getOutputs().size() != 1) -+ continue; -+ -+ // Only proceed if the type that we're casting from is more specific -+ // than the type that we're casting to. -+ auto sourceType = cast.getInputs()[0].getType(); -+ auto destType = cast.getOutputs()[0].getType(); -+ auto mostSpecificType = hlo::inferMostSpecificType( -+ /*location=*/{}, {sourceType, destType}); -+ if (failed(mostSpecificType) || destType == *mostSpecificType) continue; -+ -+ // If the source type of the cast is more specific than the target type, -+ // then we conclude that the cast is redundant (i.e. needs to be removed) -+ // and that the return type of the function needs an update. -+ needsUpdate = true; -+ updatedResultTypes[i] = sourceType; -+ -+ // Insert into set and continue iterating. -+ // ReturnOp may point to same value more than once. -+ castsToReplace.insert(cast); -+ } -+ if (!needsUpdate) -+ return rewriter.notifyMatchFailure(op, "doesn't need update"); -+ -+ // Replace CastOps with more specific operands than results. -+ for (auto cast : castsToReplace) -+ rewriter.replaceOp(cast, cast->getOperands()); -+ -+ // If the type of the enclosing `func.func` needs an update, we simply -+ // call setType. We can afford this simplicity because our algorithm -+ // currently supports only one function per module. -+ auto func = cast(op->getParentOp()); -+ func.setType( -+ rewriter.getFunctionType(func.getArgumentTypes(), updatedResultTypes)); -+ return success(); -+ } -+}; -+ -+struct UpdateRegionTypePattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReturnOp op, -+ PatternRewriter& rewriter) const override { -+ if (!isa(op->getParentOp())) -+ return rewriter.notifyMatchFailure(op, "unsupported region"); -+ -+ bool needsUpdate = false; -+ SmallVector updatedResultTypes(op.getOperandTypes()); -+ for (auto [regionType, refinedType] : llvm::zip( -+ op->getParentOp()->getResultTypes(), op->getOperandTypes())) { -+ auto mostSpecificType = hlo::inferMostSpecificType( -+ /*location=*/{}, {regionType, refinedType}); -+ if (failed(mostSpecificType) || regionType == *mostSpecificType) continue; -+ needsUpdate = true; -+ } -+ if (!needsUpdate) -+ return rewriter.notifyMatchFailure(op, "doesn't need update"); -+ -+ rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; }); -+ return success(); -+ } -+}; -+ -+struct StablehloRefineShapesPass -+ : public impl::StablehloRefineShapesPassBase { -+ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; -+ -+ void runOnOperation() override { -+ // Only one function per module is supported at the moment to avoid the need -+ // to think about iterative type inference algorithms. -+ // Current use cases are served well by inlining multiple functions into -+ // a single function, so we leave native support for multiple functions to -+ // future work. -+ // To enable modules that contain CustomCallOp::called_computations, -+ // we allow multiple functions, in which case we only refine the main -+ // function called "main", assuming that the called computations will have -+ // static shapes. Lifting this assumption and expanding refinement to -+ // multiple functions is left for future work. 
-+ ModuleOp module = getOperation(); -+ auto funcs = llvm::to_vector(module.getOps()); -+ if (funcs.empty()) return; -+ func::FuncOp func; -+ if (funcs.size() == 1) { -+ func = funcs[0]; -+ } else { -+ func = module.lookupSymbol("main"); -+ } -+ if (!func) { -+ module.emitOpError() -+ << "must have no more than one function or a `main`" -+ << " function to clearly identify which function will be refined"; -+ return signalPassFailure(); -+ } -+ -+ // Similarly, only one block per function is supported at the moment. -+ // At the StableHLO level, functions are expected to only have one block, -+ // so supporting more is out of scope for this pass. -+ if (!func.getRegion().hasOneBlock()) { -+ func.emitOpError() << "must have exactly one block"; -+ return signalPassFailure(); -+ } -+ -+ // The algorithm behind this pass consists of a single traversal of the -+ // function. This is sufficient because we only support one function per -+ // program at the moment. -+ // TODO(#1048): Find out why .maxIterations = 1 no longer works. -+ // There have been recent refactors to applyPatternsAndFoldGreedily -+ // upstream, and that might be the reason. -+ GreedyRewriteConfig config; -+ config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; -+ config.maxIterations = 2; -+ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; -+ config.strictMode = GreedyRewriteStrictness::AnyOp; -+ -+ RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ if (failed( -+ applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { -+ return signalPassFailure(); -+ } -+ } -+}; -+ -+} // namespace -+} // namespace experimental -+} // namespace stablehlo -+} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -340,6 +340,19 @@ - %1 = stablehlo.constant dense<2> : tensor - %2 = stablehlo.multiply %0, %1 : tensor - func.return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @eval_or -+func.func @eval_or() -> tensor { -+ // CHECK-NOT: stablehlo.or -+ // CHECK: [[RESULT:%.*]] = stablehlo.constant 
dense : tensor -+ // CHECK: return [[RESULT]] -+ %0 = stablehlo.constant dense : tensor -+ %1 = stablehlo.constant dense : tensor -+ %2 = stablehlo.or %0, %1 : tensor -+ func.return %2 : tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -304,6 +304,20 @@ - } - }; - -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ - struct EvalRemOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RemOp op, -@@ -1165,6 +1179,7 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++#endif // STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8abb8b476d8c90..b9fde3ceefd71a 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,3 +1,14 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -889,6 +889,7 @@ + hdrs = [ + "stablehlo/transforms/MapStablehloToVhlo.h", + "stablehlo/transforms/Passes.h", ++ "stablehlo/transforms/StablehloRefineShapes.h", + ], + strip_include_prefix = ".", + deps = [ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -159,7 +170,7 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel -@@ -0,0 +1,113 @@ +@@ -0,0 +1,114 @@ +# Copyright 2023 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -243,6 +254,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/ + "//:chlo_ops", + "//:stablehlo_ops", + "//:stablehlo_ops_inc_gen", ++ "//:stablehlo_passes", + "//:stablehlo_type_inference", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", @@ -1826,7 +1838,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp b/stabl diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt --- stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +++ stablehlo/stablehlo/experimental/transforms/CMakeLists.txt -@@ -0,0 +1,38 @@ +@@ -0,0 +1,39 @@ +# Copyright 2023 The StableHLO Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -1862,8 +1874,9 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stable + MLIRTransformUtils + ExperimentalStablehloOps + StablehloBase -+ StablehloTypeInference + StablehloOps ++ StablehloPasses ++ StablehloTypeInference +) diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/stablehlo/experimental/transforms/Passes.h --- stablehlo/stablehlo/experimental/transforms/Passes.h @@ -1944,7 +1957,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/s diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp -@@ -0,0 +1,441 @@ +@@ -0,0 +1,167 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2023 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); @@ -1960,14 +1973,12 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy +limitations under the License. +==============================================================================*/ + -+#include "llvm/ADT/DenseSet.h" ++#include ++ +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Matchers.h" -+#include "mlir/IR/Value.h" ++#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -1975,6 +1986,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/experimental/dialect/StablehloOps.h" +#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { @@ -1985,169 +1997,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + +namespace { + -+struct CanonicalizeCustomCallOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector refinements; -+ if (failed(hlo::getShapeRefinements(op.getLoc(), op, refinements))) -+ return rewriter.notifyMatchFailure(op, "expected valid refinements"); -+ auto indicesAttr = -+ op->getAttr("indices_of_shape_operands").cast(); -+ DenseSet indices(indicesAttr.value_begin(), -+ indicesAttr.value_end()); -+ -+ // Discard the indices_of_shape_operands attribute. -+ // We rely on the verification logic implemented in getShapeRefinements to -+ // make sure that its value is consistent with the result types. -+ // In the future, when we upgrade indices_of_shape_operands from an -+ // experiment to a full-fledged StableHLO feature, this logic will be moved -+ // to a proper verifier. 
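In effect (hypothetical op, with the exact attribute spelling left aside): a custom_call with operands (%data, %shape) and indices_of_shape_operands = [1] is rewritten to the same custom_call with %data as its only operand, with the indices_of_shape_operands attribute dropped and any operand_layouts entry for %shape removed; the result types are kept as-is, which is why the pattern insists that they are already static.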
-+ SmallVector newAttrs; -+ for (auto attr : op->getAttrs()) { -+ if (attr.getName() == "indices_of_shape_operands") continue; -+ if (attr.getName() == "operand_layouts") { -+ // Drop the operand_layouts that correspond to indices_of_shape_operands -+ ArrayAttr operandLayouts = op.getOperandLayoutsAttr(); -+ SmallVector newOperandLayouts; -+ for (unsigned i = 0; i < operandLayouts.size(); ++i) { -+ if (indices.contains(i)) continue; -+ newOperandLayouts.push_back(operandLayouts[i]); -+ } -+ attr = NamedAttribute(attr.getName(), -+ rewriter.getArrayAttr(newOperandLayouts)); -+ } -+ newAttrs.push_back(attr); -+ } -+ -+ // Discard the operands that correspond to indices_of_shape_operands. -+ // We rely on the verification logic implemented in getShapeRefinements to -+ // make sure that: 1) these operands are static, 2) the values of these -+ // operands are consistent with the result types. -+ SmallVector newOperands; -+ auto resultIndex = 0; -+ for (auto& operand : op->getOpOperands()) { -+ if (indices.contains(operand.getOperandNumber())) { -+ auto resultType = -+ op->getResult(resultIndex).getType().dyn_cast(); -+ if (!resultType || !resultType.hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, -+ "expected static result types"); -+ ++resultIndex; -+ continue; -+ } -+ newOperands.push_back(operand.get()); -+ } -+ rewriter.replaceOpWithNewOp(op, op.getResultTypes(), -+ newOperands, newAttrs); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicBroadcastInDimOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern discards the output_dimensions operand as well as the -+ // known_expanding_dimensions and known_nonexpanding_dimensions attributes. -+ // We rely on the verifier to make sure that their values are consistent -+ // with the result type. -+ if (!op.getOperand().getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static operand type"); -+ if (!succeeded(hlo::matchInts(op.getOutputDimensions()))) -+ return rewriter.notifyMatchFailure(op, -+ "expected static output_dimensions"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getBroadcastDimensions()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicConvOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicConvOp op, -+ PatternRewriter& rewriter) const override { -+ // ConvolutionOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. 
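A small worked example of the reshape performed just below, assuming a 2D spatial convolution so that the folded d_padding operand has four elements: dense<[1, 1, 2, 2]> : tensor<4xi64> becomes the ConvolutionOp padding attribute dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>.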
-+ SmallVector padding; -+ if (!succeeded(hlo::matchInts(op.getDPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected static padding"); -+ auto paddingAttr = DenseIntElementsAttr::get( -+ RankedTensorType::get({static_cast(padding.size()) / 2, 2}, -+ rewriter.getI64Type()), -+ padding); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getLhs(), op.getRhs(), op.getWindowStridesAttr(), -+ paddingAttr, op.getLhsDilationAttr(), op.getRhsDilationAttr(), -+ op.getWindowReversalAttr(), op.getDimensionNumbers(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfigAttr()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicGatherOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicGatherOp op, -+ PatternRewriter& rewriter) const override { -+ // GatherOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector sliceSizes; -+ if (!succeeded(hlo::matchInts(op.getSliceSizes(), sliceSizes))) -+ return rewriter.notifyMatchFailure(op, "expected static slice_sizes"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getStartIndices(), -+ op.getDimensionNumbersAttr(), rewriter.getI64TensorAttr(sliceSizes), -+ op.getIndicesAreSortedAttr()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicIotaOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicIotaOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern discards the output_shape operand. We rely on the verifier -+ // to make sure that its value is consistent with result type. -+ SmallVector outputShape; -+ if (!succeeded(hlo::matchInts(op.getOutputShape(), outputShape))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp(op, op.getType(), -+ op.getIotaDimension()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicPadOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicPadOp op, -+ PatternRewriter& rewriter) const override { -+ // PadOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. 
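A rough sketch of IR this pattern matches (hypothetical, written in MLIR's generic op form to stay neutral on the pretty-printed syntax). Once the three padding operands fold to constants, the op can be rewritten to a plain stablehlo.pad that carries the same amounts as attributes:

  func.func @example(%operand: tensor<4xf32>, %pad_value: tensor<f32>) -> tensor<8xf32> {
    %low = stablehlo.constant dense<2> : tensor<1xi64>
    %high = stablehlo.constant dense<2> : tensor<1xi64>
    %interior = stablehlo.constant dense<0> : tensor<1xi64>
    %0 = "stablehlo.dynamic_pad"(%operand, %pad_value, %low, %high, %interior)
        : (tensor<4xf32>, tensor<f32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<8xf32>
    func.return %0 : tensor<8xf32>
  }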
-+ SmallVector edgePaddingLow, edgePaddingHigh, interiorPadding; -+ if (!succeeded(hlo::matchInts(op.getEdgePaddingLow(), edgePaddingLow))) -+ return rewriter.notifyMatchFailure(op, "expected static low"); -+ if (!succeeded(hlo::matchInts(op.getEdgePaddingHigh(), edgePaddingHigh))) -+ return rewriter.notifyMatchFailure(op, "expected static high"); -+ if (!succeeded(hlo::matchInts(op.getInteriorPadding(), interiorPadding))) -+ return rewriter.notifyMatchFailure(op, "expected static interior"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getPaddingValue(), -+ rewriter.getI64TensorAttr(edgePaddingLow), -+ rewriter.getI64TensorAttr(edgePaddingHigh), -+ rewriter.getI64TensorAttr(interiorPadding)); -+ return success(); -+ } -+}; -+ +struct CanonicalizeDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -2196,22 +2045,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + } +}; + -+struct CanonicalizeDynamicReshapeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern ignores and discards the output_shape operand. We rely on -+ // the verifier to make sure that its value is consistent with result type. -+ if (!succeeded(hlo::matchInts(op.getOutputShape()))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); -+ return success(); -+ } -+}; -+ +struct CanonicalizeDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -2262,91 +2095,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + } +}; + -+struct CanonicalizeRealDynamicSliceOpToDynamicSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // DynamicSliceOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ -+ // This rewrite only works for unit strides because DynamicSliceOp -+ // doesn't support strides (i.e. it implicitly has unit strides). -+ SmallVector strides; -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides))) -+ return rewriter.notifyMatchFailure(op, "expected static strides"); -+ if (!llvm::all_of(strides, [&](int64_t stride) { return stride == 1; })) -+ return rewriter.notifyMatchFailure(op, "expected unit strides"); -+ -+ // Check that slice sizes are fully static (DynamicSliceOp style). -+ // To detect that, we check whether `limit_indices` is defined as -+ // `start_indices + constant` or `constant + start_indices`. -+ DenseIntElementsAttr sliceSizesAttr; -+ auto m_startIndices = matchers::m_Val(op.getStartIndices()); -+ if (!matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && -+ !matchPattern(op.getLimitIndices(), -+ m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) -+ return rewriter.notifyMatchFailure( -+ op, "expected limit indices equal to start indices plus constant"); -+ -+ // RealDynamicSliceOp can take tensors of integer or index element types. 
-+ // DynamicSliceOp::slice_sizes only supports i64 element type. -+ // Adapt accordingly in order to be compatible with DynamicSliceOp. -+ SmallVector sliceSizes; -+ for (auto element : sliceSizesAttr.getValues()) { -+ sliceSizes.push_back(element.getSExtValue()); -+ } -+ -+ // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. -+ // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. -+ // Adapt accordingly in order to be compatible with DynamicSliceOp. -+ SmallVector startIndices; -+ for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { -+ auto startIndexElementType = -+ op.getStartIndices().getType().getElementType(); -+ auto startIndex1DType = RankedTensorType::get({1}, startIndexElementType); -+ auto startIndex1D = rewriter.create( -+ op.getLoc(), startIndex1DType, op.getStartIndices(), -+ rewriter.getI64TensorAttr(i), rewriter.getI64TensorAttr(i + 1), -+ rewriter.getI64TensorAttr(1)); -+ auto startIndex0DType = RankedTensorType::get({}, startIndexElementType); -+ auto startIndex0D = rewriter.create( -+ op.getLoc(), startIndex0DType, startIndex1D); -+ startIndices.push_back(startIndex0D); -+ } -+ -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), startIndices, -+ rewriter.getI64TensorAttr(sliceSizes)); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeRealDynamicSliceOpToSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // SliceOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector startIndices, limitIndices, strides; -+ if (!succeeded(hlo::matchInts(op.getStartIndices(), startIndices))) -+ return rewriter.notifyMatchFailure(op, "expected static start"); -+ if (!succeeded(hlo::matchInts(op.getLimitIndices(), limitIndices))) -+ return rewriter.notifyMatchFailure(op, "expected static limit"); -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides))) -+ return rewriter.notifyMatchFailure(op, "expected static strides"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), -+ rewriter.getI64TensorAttr(startIndices), -+ rewriter.getI64TensorAttr(limitIndices), -+ rewriter.getI64TensorAttr(strides)); -+ return success(); -+ } -+}; -+ +struct StablehloCanonicalizeDynamismPass + : public impl::StablehloCanonicalizeDynamismPassBase< + StablehloCanonicalizeDynamismPass> { @@ -2362,19 +2110,10 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + patterns.add(&getContext()); -+ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); -+ patterns.add( -+ &getContext()); -+ patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); @@ -2389,7 +2128,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp 
b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp -@@ -0,0 +1,1308 @@ +@@ -0,0 +1,162 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. @@ -2404,41 +2143,22 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +limitations under the License. +==============================================================================*/ + ++#include "stablehlo/transforms/StablehloRefineShapes.h" ++ +#include -+#include -+#include -+#include + -+#include "llvm/ADT/APInt.h" -+#include "llvm/ADT/APSInt.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/STLFunctionalExtras.h" -+#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" -+#include "llvm/ADT/StringRef.h" -+#include "llvm/Support/ErrorHandling.h" -+#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinOps.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" -+#include "mlir/IR/MLIRContext.h" -+#include "mlir/IR/Matchers.h" -+#include "mlir/IR/OpDefinition.h" -+#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" -+#include "mlir/IR/Types.h" -+#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/Base.h" -+#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/TypeInference.h" +#include "stablehlo/experimental/dialect/StablehloOps.h" +#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { @@ -2449,382 +2169,239 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + +namespace { + -+// DenseElementsAttr can be constructed from ArrayRef but not from -+// ArrayRef. This helper bridges the gap. -+DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { -+ SmallVector supportedValues(values); -+ return DenseIntElementsAttr::get(type, supportedValues); -+} -+ -+APSInt getAPSInt(Type type, uint64_t value) { -+ unsigned numBits; -+ bool isUnsigned; -+ if (auto integerType = type.dyn_cast()) { -+ numBits = integerType.getWidth(); -+ // Signless types are treated as signed, per StableHLO convention. -+ isUnsigned = integerType.isUnsignedInteger(); -+ } else { -+ llvm::report_fatal_error("expected integer type"); -+ } -+ return APSInt({/*numBits=*/numBits, value}, -+ /*isUnsigned=*/isUnsigned); -+} -+ -+// The patterns below implement partial evaluation of shape computations which -+// is a critical part of implementing type refinement for ops like -+// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape -+// depends on the value of their shape operands. 
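To make the comment above concrete, a hypothetical example in the spirit of the stablehlo_refine_shapes.mlir tests quoted elsewhere in this patch: once the Eval* patterns fold the shape computation to the constant [4], the dynamic_reshape result can be refined from tensor<?xf32> to tensor<4xf32>:

  func.func @example(%arg0: tensor<4xf32>) -> tensor<?xf32> {
    %c2 = stablehlo.constant dense<2> : tensor<1xi64>
    %shape = stablehlo.add %c2, %c2 : tensor<1xi64>
    %0 = stablehlo.dynamic_reshape %arg0, %shape : (tensor<4xf32>, tensor<1xi64>) -> tensor<?xf32>
    func.return %0 : tensor<?xf32>
  }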
-+ -+template -+LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, -+ FuncType fn) { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || -+ !resultType.getElementType().template isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ -+ SmallVector result; -+ if constexpr (OpType::template hasTrait()) { -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ for (const auto& operandEl : operand) { -+ result.push_back(fn(operandEl)); -+ } -+ } else if constexpr (OpType::template hasTrait< -+ OpTrait::NOperands<2>::Impl>()) { -+ SmallVector lhs, rhs; -+ if (failed(hlo::matchInts(op.getLhs(), lhs)) || -+ failed(hlo::matchInts(op.getRhs(), rhs))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ for (auto [lhsEl, rhsEl] : llvm::zip(lhs, rhs)) { -+ result.push_back(fn(lhsEl, rhsEl)); -+ } -+ } else if constexpr (OpType::template hasTrait< -+ OpTrait::NOperands<3>::Impl>()) { -+ SmallVector x, y, z; -+ if (failed(hlo::matchInts(op->getOperand(0), x)) || -+ failed(hlo::matchInts(op->getOperand(1), y)) || -+ failed(hlo::matchInts(op->getOperand(2), z))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ for (auto [xEl, yEl, zEl] : llvm::zip(x, y, z)) { -+ result.push_back(fn(xEl, yEl, zEl)); -+ } -+ } else { -+ llvm::report_fatal_error("unsupported number of operands"); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+} -+ -+struct EvalAddOpPattern : public OpRewritePattern { ++struct RefineDynamicReduceWindowOpPattern ++ : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AddOp op, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); -+ } -+}; ++ auto maybeOp = getDynamicReduceWindowOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicReduceWindowOpAdaptor op = *maybeOp; + -+struct EvalAndOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AndOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. 
++ SmallVector windowDimensions, windowStrides, baseDilations, ++ windowDilations, padding; ++ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_dimensions"); ++ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_strides"); ++ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant base_dilations"); ++ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_dilations"); ++ if (failed(hlo::matchInts(op.getPadding(), padding))) ++ return rewriter.notifyMatchFailure(op, "expected constant padding"); + -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); -+ }); ++ SmallVector inferredReturnTypes; ++ if (failed(hlo::inferReduceWindowOp( ++ /*location=*/{}, op.getInputs(), op.getInitValues(), ++ rewriter.getI64TensorAttr(windowDimensions), ++ rewriter.getI64TensorAttr(windowStrides), ++ rewriter.getI64TensorAttr(baseDilations), ++ rewriter.getI64TensorAttr(windowDilations), ++ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) ++ return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); ++ return refineReturnTypes(rewriter, op, inferredReturnTypes); + } +}; + -+struct EvalBroadcastInDimOpPattern : public OpRewritePattern { ++struct RefineDynamicRngBitGeneratorOpPattern ++ : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(BroadcastInDimOp op, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank() || operandType.getRank() != 0) -+ return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); ++ auto maybeOp = getDynamicRngBitGeneratorOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; + -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ auto scalar = operand[0]; ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. ++ auto initialStateType = op.getInitialState().getType().cast(); ++ SmallVector outputShape; ++ if (failed(hlo::matchInts(op.getOutputShape(), outputShape))) ++ return rewriter.notifyMatchFailure(op, "expected constant output_shape"); + -+ rewriter.replaceOpWithNewOp( -+ op, getTensorAttr(op.getType(), scalar)); -+ return success(); ++ // We only need to refine the shape of `output` (the second result). ++ // The shape of `output_state` (the first result) is determined by the shape ++ // of `initial_state`, so we ignore it and provide an empty refinement. 
++ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); + } +}; + -+struct EvalClampOpPattern : public OpRewritePattern { ++struct RefineDynamicTopKOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ClampOp op, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt min, APSInt operand, APSInt max) { -+ if (operand < min) return min; -+ if (max < operand) return max; -+ return operand; -+ }); -+ } -+}; ++ auto maybeOp = getDynamicTopKOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicTopKOpAdaptor op = *maybeOp; + -+struct EvalCompareOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CompareOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ bool result; -+ switch (op.getComparisonDirection()) { -+ case ComparisonDirection::EQ: -+ result = lhs == rhs; -+ break; -+ case ComparisonDirection::NE: -+ result = lhs != rhs; -+ break; -+ case ComparisonDirection::GE: -+ result = lhs >= rhs; -+ break; -+ case ComparisonDirection::GT: -+ result = lhs > rhs; -+ break; -+ case ComparisonDirection::LE: -+ result = lhs <= rhs; -+ break; -+ case ComparisonDirection::LT: -+ result = lhs < rhs; -+ break; -+ } -+ return getAPSInt(resultType.getElementType(), result); -+ }); ++ auto operandType = op.getOperand().getType().cast(); ++ SmallVector outputShape(operandType.getShape()); ++ SmallVector k; ++ if (failed(hlo::matchInts(op.getK(), k))) ++ return rewriter.notifyMatchFailure(op, "expected constant k"); ++ ++ outputShape[operandType.getRank() - 1] = k[0]; ++ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); + } +}; + -+struct EvalConcatenateOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConcatenateOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || op.getDimension() != 0) -+ return rewriter.notifyMatchFailure(op, "expected dimension = 0"); ++struct StablehloRefineShapesPass ++ : public impl::StablehloRefineShapesPassBase { ++ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + -+ SmallVector result; -+ for (Value operand : op->getOperands()) { -+ if (failed(hlo::matchInts(operand, result))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ } ++ void runOnOperation() override { ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+ } -+}; ++ // The algorithm behind this pass consists of a single traversal of the ++ // function. This is sufficient because we only support one function per ++ // program at the moment. ++ // TODO(#1048): Find out why .maxIterations = 1 no longer works. ++ // There have been recent refactors to applyPatternsAndFoldGreedily ++ // upstream, and that might be the reason. 
++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 2; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::AnyOp; + -+struct EvalConvertOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvertOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth(); -+ return evalElementwise(rewriter, op, [&](APSInt operand) { -+ return operand.extOrTrunc(resultBitWidth); -+ }); ++ RewritePatternSet patterns(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); ++ patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ if (failed( ++ applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { ++ return signalPassFailure(); ++ } + } +}; + -+struct EvalDivOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DivOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); -+ } -+}; ++} // namespace ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +@@ -340,6 +340,19 @@ + %1 = stablehlo.constant dense<2> : tensor + %2 = stablehlo.multiply %0, %1 : tensor + func.return %2 : tensor ++} + -+struct EvalGetDimensionSizeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(GetDimensionSizeOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand"); -+ if (operandType.isDynamicDim(op.getDimension())) -+ return rewriter.notifyMatchFailure(op, "expected static dimension"); -+ -+ auto result = operandType.getDimSize(op.getDimension()); -+ rewriter.replaceOpWithNewOp( -+ op, DenseIntElementsAttr::get(op.getType(), result)); -+ return success(); -+ } -+}; ++// ----- + -+struct EvalMaxOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MaxOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ return lhs >= rhs ? lhs : rhs; -+ }); -+ } -+}; -+ -+struct EvalMinOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MinOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ return lhs <= rhs ? 
lhs : rhs; -+ }); -+ } -+}; -+ -+struct EvalMulOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MulOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); -+ } -+}; -+ -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ -+struct EvalRemOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RemOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); -+ } -+}; -+ -+struct EvalReshapeOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ DenseIntElementsAttr attr; -+ if (!matchPattern(op.getOperand(), m_Constant(&attr))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ rewriter.replaceOpWithNewOp(op, attr.reshape(op.getType())); -+ return success(); -+ } -+}; -+ -+struct EvalSelectOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SelectOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector pred, onTrue, onFalse; -+ if (failed(hlo::matchInts(op.getPred(), pred)) || -+ failed(hlo::matchInts(op.getOnTrue(), onTrue)) || -+ failed(hlo::matchInts(op.getOnFalse(), onFalse))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ -+ SmallVector result; -+ for (auto [predEl, onTrueEl, onFalseEl] : -+ llvm::zip(pred, onTrue, onFalse)) { -+ result.push_back(predEl != 0 ? 
onTrueEl : onFalseEl); -+ } -+ -+ rewriter.replaceOpWithNewOp( -+ op, getTensorAttr(op.getType(), result)); -+ return success(); -+ } -+}; -+ -+struct EvalSignOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SignOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ return evalElementwise(rewriter, op, [&](APSInt operand) { -+ int64_t result; -+ if (operand.isNegative()) -+ result = -1; -+ else if (operand.isZero()) -+ result = 0; -+ else -+ result = 1; -+ return getAPSInt(resultType.getElementType(), result); -+ }); -+ } -+}; -+ -+struct EvalSliceOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SliceOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || resultType.getRank() != 1) -+ return rewriter.notifyMatchFailure(op, "expected 1-dimensional type"); -+ -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ -+ int64_t start = op.getStartIndices().getValues()[0]; -+ int64_t limit = op.getLimitIndices().getValues()[0]; -+ int64_t stride = op.getStrides().getValues()[0]; -+ SmallVector result; -+ for (auto i = start; i < limit; i += stride) { -+ result.push_back(operand[i]); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+ } -+}; -+ -+struct EvalSubtractOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SubtractOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); -+ } -+}; -+ -+// The patterns below implement shape refinement of individual ops. -+// In a nutshell, they use the upstream type inference infrastructure and a -+// StableHLO-specific extension to refine return types based on potentially -+// refined operands. ++// CHECK-LABEL: func @eval_or ++func.func @eval_or() -> tensor { ++ // CHECK-NOT: stablehlo.or ++ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor ++ // CHECK: return [[RESULT]] ++ %0 = stablehlo.constant dense : tensor ++ %1 = stablehlo.constant dense : tensor ++ %2 = stablehlo.or %0, %1 : tensor ++ func.return %2 : tensor + } + + // ----- +diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h +--- stablehlo/stablehlo/transforms/Passes.h ++++ stablehlo/stablehlo/transforms/Passes.h +@@ -18,9 +18,12 @@ + + #include + ++#include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/Dialect/Quant/QuantOps.h" + #include "mlir/Dialect/Shape/IR/Shape.h" ++#include "mlir/IR/BuiltinOps.h" + #include "mlir/Pass/Pass.h" ++#include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/DialectConversion.h" + + namespace mlir { +@@ -33,6 +36,14 @@ + #define GEN_PASS_DECL_VHLOTOVERSIONPASS + #define GEN_PASS_REGISTRATION + #include "stablehlo/transforms/Passes.h.inc" ++ ++// Populates --stablehlo-canonicalize-dynamism patterns. ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); ++ ++// Populates --stablehlo-refine-shapes patterns. 
++void populateStablehloRefineShapesPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); + + // Populates StableHLO ops to VHLO ops rewriting patterns. + void populateStablehloToVhloPatterns(RewritePatternSet *patterns, +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -314,16 +314,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add( +- &getContext()); +- patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); +@@ -332,5 +323,19 @@ + }; + + } // namespace ++ ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -11,6 +11,8 @@ + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + -+// Refines the values using the given types. -+// Tricky implementation details: -+// 1) Need to support partial shape refinements, e.g. if just a single -+// dimension size out of an entire tensor type got refined. This is done -+// via inferMostSpecificType. -+// 2) Need to signal propagation of the refined shapes across the -+// StableHLO program. Different callers of this function have different -+// propagation needs, so this function doesn't signal anything on its own -+// and leaves that to the callers. ++#include "stablehlo/transforms/StablehloRefineShapes.h" + + #include + #include +@@ -53,6 +55,193 @@ + #define GEN_PASS_DEF_STABLEHLOREFINESHAPESPASS + #include "stablehlo/transforms/Passes.h.inc" + +LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, + ValueRange values, TypeRange types) { + if (values.size() != types.size()) @@ -2911,10 +2488,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return success(); +} + -+// Refines the return types of the given operation using the given types. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. 
+LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, + ArrayRef types) { + if (failed(refineValues(rewriter, op, op->getResults(), types))) @@ -2929,19 +2502,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return success(); +} + -+// Refines the return types of the given operation using the given types. -+// Tricky implementation details: -+// 1) `types` can include non-shaped types. If there are tuple types, -+// then they are first flattened into non-tuple types using in-order -+// traversal, and only then we apply the refinements. If there are other -+// types, then the corresponding refinements must be completely empty. -+// 2) Encodings are not supported. In principle, TypeExtensions should be -+// supportable, but this needs careful thinking through. Given that no one -+// asked for support for bounded dynamism in this pass yet, this is left -+// for future work. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. +LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, + ArrayRef refinements) { + SmallVector flattenedTypes; @@ -3028,6 +2588,526 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return failure(); + return refineReturnTypes(rewriter, op, refinedTypes); +} ++ + namespace { + + // DenseElementsAttr can be constructed from ArrayRef but not from +@@ -304,6 +493,20 @@ + } + }; + ++struct EvalOrOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(OrOp op, ++ PatternRewriter& rewriter) const override { ++ auto resultType = op.getType(); ++ if (!resultType.getElementType().isInteger(1)) ++ return rewriter.notifyMatchFailure(op, "expected boolean element type"); ++ ++ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { ++ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); ++ }); ++ } ++}; ++ + struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RemOp op, +@@ -407,245 +610,6 @@ + // In a nutshell, they use the upstream type inference infrastructure and a + // StableHLO-specific extension to refine return types based on potentially + // refined operands. +- +-// Refines the values using the given types. +-// Tricky implementation details: +-// 1) Need to support partial shape refinements, e.g. if just a single +-// dimension size out of an entire tensor type got refined. This is done +-// via inferMostSpecificType. +-// 2) Need to signal propagation of the refined shapes across the +-// StableHLO program. Different callers of this function have different +-// propagation needs, so this function doesn't signal anything on its own +-// and leaves that to the callers. +-LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, +- ValueRange values, TypeRange types) { +- if (values.size() != types.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineValues failed for " << types << ": expected " +- << values.size() << " types, got " << types.size(); +- }); +- +- // Check whether `types` contain any new information with respect to existing +- // return types. Even if just a single dimension size out of an entire tensor +- // type got updated, using `inferMostSpecificType` ensures that we don't +- // miss that. 
+- bool needsRefinement = false; +- SmallVector refinedTypes; +- for (auto it : llvm::zip(values.getTypes(), types)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto currentType = std::get<0>(it); +- auto refinement = std::get<1>(it); +- auto refinedType = hlo::inferMostSpecificType( +- /*location=*/{}, {currentType, refinement}); +- if (failed(refinedType)) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "inferMostSpecificType failed for " << currentType << " and " +- << refinement; +- }); +- refinedTypes.push_back(*refinedType); +- needsRefinement |= (currentType != *refinedType); +- } +- if (!needsRefinement) +- return rewriter.notifyMatchFailure(op, "doesn't need refinement"); +- +- for (auto it : llvm::zip(values, refinedTypes)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto value = std::get<0>(it); +- auto refinedType = std::get<1>(it); +- if (value.getType() == refinedType) continue; +- +- // Check whether the users of this value are ready for the type of the +- // value to be refined. +- for (Operation* user : value.getUsers()) { +- // CHLO and StableHLO ops are designed to support type refinements of +- // their operands and results. Any operand type in these ops can change +- // within what's supported by `inferMostSpecificType` without breaking +- // verification of the op. +- if (isa(user->getDialect())) +- continue; +- +- // Simply changing operand type of `func.return` won't work because +- // that won't update the FunctionType of the enclosing `func.func`. +- // Nonetheless, we still want to support these ops because they are widely +- // used in StableHLO programs (although the plan of record is to replace +- // `func.return` ops in StableHLO programs with `stablehlo.return`: +- // https://github.com/openxla/stablehlo/issues/425). +- if (isa(user)) continue; +- +- // Unlike in TensorFlow's type inference pass, here we work only with +- // allowlisted ops to focus our support on well-defined semantics of +- // StableHLO programs. +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "unsupported refinement: tried to refine " << value.getType() +- << " to " << refinedType << " for user " << user; +- }); +- } +- +- // Happy path: simply call setType here because most of our users are +- // fine with that. +- auto unrefinedType = value.getType(); +- value.setType(refinedType); +- +- // Special case: for `func.return`, guard the refinement with a cast +- // and leave propagation of the refined return type to a dedicated pattern. +- auto isFuncReturn = [](OpOperand& use) -> bool { +- return isa(use.getOwner()); +- }; +- if (llvm::none_of(value.getUses(), isFuncReturn)) continue; +- rewriter.setInsertionPointAfter(op); +- auto castToUnrefinedType = rewriter.create( +- op->getLoc(), unrefinedType, value); +- value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); +- } +- +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. 
+-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef types) { +- if (failed(refineValues(rewriter, op, op->getResults(), types))) +- return failure(); +- +- // This `replaceOpWithIf` call doesn't actually change the IR, but +- // it does ask the rewriter to visit all the users of this op. There is no +- // upstream API to achieve this directly, but if it's introduced in the +- // future, we could use it here. +- rewriter.replaceOpWithIf(op, op->getResults(), +- [](OpOperand& use) { return false; }); +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// Tricky implementation details: +-// 1) `types` can include non-shaped types. If there are tuple types, +-// then they are first flattened into non-tuple types using in-order +-// traversal, and only then we apply the refinements. If there are other +-// types, then the corresponding refinements must be completely empty. +-// 2) Encodings are not supported. In principle, TypeExtensions should be +-// supportable, but this needs careful thinking through. Given that no one +-// asked for support for bounded dynamism in this pass yet, this is left +-// for future work. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef refinements) { +- SmallVector flattenedTypes; +- hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); +- auto flattenedSize = flattenedTypes.size(); +- if (flattenedSize != refinements.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineReturnTypes failed: expected " << flattenedSize +- << " refinements, got " << refinements.size(); +- }); +- +- SmallVector flattenedRefinedTypes; +- for (auto it : llvm::zip(flattenedTypes, refinements)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- ShapedType currentType = std::get<0>(it).dyn_cast(); +- ShapedTypeComponents refinement = std::get<1>(it); +- auto failWithReason = [&](StringRef reason) { +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineTypes failed: refining " << currentType +- << "with refinement: {"; +- if (refinement.hasRank()) { +- diag << "shape = [" << refinement.getDims() << "]"; +- if (refinement.getAttribute()) +- diag << "attribute = " << refinement.getAttribute(); +- } else { +- diag << "hasRank = false"; +- } +- diag << ", elementType = " << refinement.getElementType(); +- diag << "} failed: " << reason; +- }); +- }; +- +- // If the current type is not a shaped type, then the refinement must +- // be completely empty. +- if (!currentType) { +- if (refinement.hasRank() || refinement.getElementType() || +- refinement.getAttribute()) +- return failWithReason("unsupported refinement"); +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If the refinement has an element type, then it must be the same as +- // the current element type. 
+- Type currentElementType = currentType.getElementType(); +- if (refinement.getElementType() && +- currentElementType != refinement.getElementType()) +- return failWithReason("expected compatible element types"); +- +- // If neither the current type nor the refinement are ranked, then there's +- // nothing to refine, and we return the current type. +- bool hasRank = currentType.hasRank() || refinement.hasRank(); +- if (!hasRank) { +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If either the current type or the refinement have encodings, then +- // we fail. Encodings are left for future work. +- Attribute currentEncoding = nullptr; +- if (auto currentRankedType = currentType.dyn_cast()) { +- currentEncoding = currentRankedType.getEncoding(); +- } +- Attribute refinedEncoding = refinement.getAttribute(); +- if (currentEncoding || refinedEncoding) +- return failWithReason("expected compatible encodings"); +- +- // If both the current type and the refinement have shapes, use the shape +- // from the refinement. Otherwise, pick whatever is available. +- // Make sure that the resulting type is compatible with the current type +- // to avoid creating invalid code. +- auto refinedShape = +- refinement.hasRank() ? refinement.getDims() : currentType.getShape(); +- auto refinedType = RankedTensorType::get(refinedShape, currentElementType); +- if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) +- return failWithReason("expected compatible shapes"); +- flattenedRefinedTypes.push_back(refinedType); +- } +- +- SmallVector refinedTypes; +- if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), +- flattenedRefinedTypes, refinedTypes))) +- return failure(); +- return refineReturnTypes(rewriter, op, refinedTypes); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- ArrayRef shape) { +- return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- Value shapeValue) { +- // At the moment, we only support refining return types using fully static +- // shape values which serves the current use cases well. +- // Support for partially static shape values is left for future work. +- SmallVector shape; +- if (failed(hlo::matchInts(shapeValue, shape))) +- return rewriter.notifyMatchFailure(op, "expected constant output shape"); +- return refineReturnShape(rewriter, op, shape); +-} + + struct RefineAllGatherOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; +@@ -1105,39 +1069,8 @@ + using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + + void runOnOperation() override { +- // Only one function per module is supported at the moment to avoid the need +- // to think about iterative type inference algorithms. +- // Current use cases are served well by inlining multiple functions into +- // a single function, so we leave native support for multiple functions to +- // future work. 
+- // To enable modules that contain CustomCallOp::called_computations, +- // we allow multiple functions, in which case we only refine the main +- // function called "main", assuming that the called computations will have +- // static shapes. Lifting this assumption and expanding refinement to +- // multiple functions is left for future work. +- ModuleOp module = getOperation(); +- auto funcs = llvm::to_vector(module.getOps()); +- if (funcs.empty()) return; +- func::FuncOp func; +- if (funcs.size() == 1) { +- func = funcs[0]; +- } else { +- func = module.lookupSymbol("main"); +- } +- if (!func) { +- module.emitOpError() +- << "must have no more than one function or a `main`" +- << " function to clearly identify which function will be refined"; +- return signalPassFailure(); +- } +- +- // Similarly, only one block per function is supported at the moment. +- // At the StableHLO level, functions are expected to only have one block, +- // so supporting more is out of scope for this pass. +- if (!func.getRegion().hasOneBlock()) { +- func.emitOpError() << "must have exactly one block"; +- return signalPassFailure(); +- } ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + + // The algorithm behind this pass consists of a single traversal of the + // function. This is sufficient because we only support one function per +@@ -1153,43 +1086,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + if (failed( + applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { + return signalPassFailure(); +@@ -1198,5 +1095,86 @@ + }; + + } // namespace ++ ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) { ++ // Only one function per module is supported at the moment to avoid the need ++ // to think about iterative type inference algorithms. ++ // Current use cases are served well by inlining multiple functions into ++ // a single function, so we leave native support for multiple functions to ++ // future work. ++ // To enable modules that contain CustomCallOp::called_computations, ++ // we allow multiple functions, in which case we only refine the main ++ // function called "main", assuming that the called computations will have ++ // static shapes. 
Lifting this assumption and expanding refinement to ++ // multiple functions is left for future work. ++ auto funcs = llvm::to_vector(module.getOps()); ++ if (funcs.empty()) return nullptr; ++ ++ func::FuncOp result; ++ if (funcs.size() == 1) { ++ result = funcs[0]; ++ } else { ++ result = module.lookupSymbol("main"); ++ } ++ if (!result) { ++ module.emitOpError() ++ << "must have no more than one function or a `main`" ++ << " function to clearly identify which function will be refined"; ++ return nullptr; ++ } ++ ++ // Similarly, only one block per function is supported at the moment. ++ // At the StableHLO level, functions are expected to only have one block, ++ // so supporting more is out of scope for this pass. ++ if (!result.getRegion().hasOneBlock()) { ++ result.emitOpError() << "must have exactly one block"; ++ return nullptr; ++ } ++ ++ return result; ++} ++ ++void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h +@@ -0,0 +1,102 @@ ++/* Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ ++#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/BuiltinOps.h" ++#include "mlir/IR/Operation.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/Types.h" ++#include "mlir/IR/Value.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "stablehlo/dialect/Base.h" ++ ++namespace mlir { ++namespace stablehlo { ++ ++// Gets a FuncOp that --stablehlo-refine-shapes will run on. ++// Returns a nullptr and emits appropriate errors if such a function cannot ++// be obtained from the module. ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module); ++ ++// Refines the values using the given types. ++// Tricky implementation details: ++// 1) Need to support partial shape refinements, e.g. if just a single ++// dimension size out of an entire tensor type got refined. This is done ++// via inferMostSpecificType. ++// 2) Need to signal propagation of the refined shapes across the ++// StableHLO program. Different callers of this function have different ++// propagation needs, so this function doesn't signal anything on its own ++// and leaves that to the callers. ++LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, ++ ValueRange values, TypeRange types); ++ ++// Refines the return types of the given operation using the given types. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef types); ++ ++// Refines the return types of the given operation using the given types. ++// Tricky implementation details: ++// 1) `types` can include non-shaped types. If there are tuple types, ++// then they are first flattened into non-tuple types using in-order ++// traversal, and only then we apply the refinements. If there are other ++// types, then the corresponding refinements must be completely empty. ++// 2) Encodings are not supported. In principle, TypeExtensions should be ++// supportable, but this needs careful thinking through. Given that no one ++// asked for support for bounded dynamism in this pass yet, this is left ++// for future work. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef refinements); + +// Refines the return type of the given operation using the given shape. 
+// This function also signals PatternRewriter that it needs to visit all the @@ -3055,702 +3135,8 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + return refineReturnShape(rewriter, op, shape); +} + -+struct RefineAllGatherOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AllGatherOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // This represents the cross_replica_and_partition process grouping strategy -+ // that requires num_partitions to compute shardCount. Since we don't know -+ // num_partitions at this point, we error out. -+ if (op.getChannelHandle() && !op.getUseGlobalDeviceIds()) -+ return rewriter.notifyMatchFailure(op, "unsupported strategy"); -+ DenseIntElementsAttr replicaGroups = op.getReplicaGroups(); -+ auto shardCount = replicaGroups.getType().getDimSize(1); -+ -+ SmallVector refinement(operandType.getShape()); -+ if (!operandType.isDynamicDim(op.getAllGatherDim())) -+ refinement[op.getAllGatherDim()] *= shardCount; -+ return refineReturnShape(rewriter, op, refinement); -+ } -+}; -+ -+struct RefineBitcastConvertOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(BitcastConvertOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // If bit widths of the operand and the result are different, then -+ // operand and result shapes have different ranks. -+ // This complicates the logic quite a bit and is not needed to pass the -+ // current tests, so we leave this for future work. 
-+ auto resultType = op.getType(); -+ auto getBitWidthFn = [](ShapedType type) { -+ auto elementType = type.getElementType(); -+ if (auto complexType = elementType.dyn_cast()) -+ return complexType.getElementType().getIntOrFloatBitWidth(); -+ return elementType.getIntOrFloatBitWidth(); -+ }; -+ -+ if (getBitWidthFn(operandType) != getBitWidthFn(resultType)) -+ return rewriter.notifyMatchFailure(op, "unsupported bit width"); -+ -+ return refineReturnShape(rewriter, op, operandType.getShape()); -+ } -+}; -+ -+struct RefineConvertOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvertOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvertOp( -+ /*location=*/{}, op.getOperand(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvertOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; ++} // namespace stablehlo ++} // namespace mlir + -+struct RefineConvolutionOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvolutionOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvolutionOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getWindowStrides(), op.getPadding(), op.getLhsDilation(), -+ op.getRhsDilation(), op.getWindowReversal(), -+ op.getDimensionNumbers().getInputBatchDimension(), -+ op.getDimensionNumbers().getInputFeatureDimension(), -+ op.getDimensionNumbers().getInputSpatialDimensions(), -+ op.getDimensionNumbers().getKernelInputFeatureDimension(), -+ op.getDimensionNumbers().getKernelOutputFeatureDimension(), -+ op.getDimensionNumbers().getKernelSpatialDimensions(), -+ op.getDimensionNumbers().getOutputBatchDimension(), -+ op.getDimensionNumbers().getOutputFeatureDimension(), -+ op.getDimensionNumbers().getOutputSpatialDimensions(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvolutionOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineCustomCallOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector refinements; -+ if (failed(hlo::getShapeRefinements(op.getLoc(), op, refinements))) -+ return rewriter.notifyMatchFailure(op, "expected valid refinements"); -+ return refineReturnTypes(rewriter, op, refinements); -+ } -+}; -+ -+struct RefineDotGeneralOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DotGeneralOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferDotGeneralOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getDotDimensionNumbersAttr().getLhsBatchingDimensions(), -+ op.getDotDimensionNumbersAttr().getRhsBatchingDimensions(), -+ op.getDotDimensionNumbersAttr().getLhsContractingDimensions(), -+ op.getDotDimensionNumbersAttr().getRhsContractingDimensions(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferDotGeneralOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct 
RefineDynamicBroadcastInDimOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputDimensions()); -+ } -+}; -+ -+struct RefineDynamicConvOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicConvOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector padding; -+ if (failed(hlo::matchInts(op.getDPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected constant d_padding"); -+ if (op.getPadding().has_value()) -+ return rewriter.notifyMatchFailure(op, "expected empty padding"); -+ auto paddingType = RankedTensorType::get( -+ op.getDPadding().getType().getShape(), rewriter.getIntegerType(64)); -+ auto paddingAttr = DenseIntElementsAttr::get(paddingType, padding); -+ -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvolutionOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getWindowStrides(), paddingAttr, op.getLhsDilation(), -+ op.getRhsDilation(), op.getWindowReversal(), -+ op.getDimensionNumbers().getInputBatchDimension(), -+ op.getDimensionNumbers().getInputFeatureDimension(), -+ op.getDimensionNumbers().getInputSpatialDimensions(), -+ op.getDimensionNumbers().getKernelInputFeatureDimension(), -+ op.getDimensionNumbers().getKernelOutputFeatureDimension(), -+ op.getDimensionNumbers().getKernelSpatialDimensions(), -+ op.getDimensionNumbers().getOutputBatchDimension(), -+ op.getDimensionNumbers().getOutputFeatureDimension(), -+ op.getDimensionNumbers().getOutputSpatialDimensions(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvolutionOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineDynamicIotaOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicIotaOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ -+struct RefineDynamicPadOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicPadOp op, -+ PatternRewriter& rewriter) const override { -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. 
-+ SmallVector edgePaddingLow, edgePaddingHigh, interiorPadding; -+ if (failed(hlo::matchInts(op.getEdgePaddingLow(), edgePaddingLow))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant edge_padding_low"); -+ if (failed(hlo::matchInts(op.getEdgePaddingHigh(), edgePaddingHigh))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant edge_padding_high"); -+ if (failed(hlo::matchInts(op.getInteriorPadding(), interiorPadding))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant interior_padding"); -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferPadOp( -+ /*location=*/{}, op.getOperand().getType(), -+ op.getPaddingValue().getType(), -+ rewriter.getI64TensorAttr(edgePaddingLow), -+ rewriter.getI64TensorAttr(edgePaddingHigh), -+ rewriter.getI64TensorAttr(interiorPadding), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferPadOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineDynamicReduceWindowOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicReduceWindowOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicReduceWindowOpAdaptor op = *maybeOp; -+ -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ SmallVector windowDimensions, windowStrides, baseDilations, -+ windowDilations, padding; -+ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_dimensions"); -+ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_strides"); -+ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant base_dilations"); -+ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_dilations"); -+ if (failed(hlo::matchInts(op.getPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected constant padding"); -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferReduceWindowOp( -+ /*location=*/{}, op.getInputs(), op.getInitValues(), -+ rewriter.getI64TensorAttr(windowDimensions), -+ rewriter.getI64TensorAttr(windowStrides), -+ rewriter.getI64TensorAttr(baseDilations), -+ rewriter.getI64TensorAttr(windowDilations), -+ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineDynamicReshapeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ -+struct RefineDynamicRngBitGeneratorOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicRngBitGeneratorOp(impl); -+ if (!maybeOp || 
failed(maybeOp->verify())) return failure(); -+ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; -+ -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ auto initialStateType = op.getInitialState().getType().cast(); -+ SmallVector outputShape; -+ if (failed(hlo::matchInts(op.getOutputShape(), outputShape))) -+ return rewriter.notifyMatchFailure(op, "expected constant output_shape"); -+ -+ // We only need to refine the shape of `output` (the second result). -+ // The shape of `output_state` (the first result) is determined by the shape -+ // of `initial_state`, so we ignore it and provide an empty refinement. -+ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); -+ } -+}; -+ -+struct RefineDynamicTopKOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicTopKOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicTopKOpAdaptor op = *maybeOp; -+ -+ auto operandType = op.getOperand().getType().cast(); -+ SmallVector outputShape(operandType.getShape()); -+ SmallVector k; -+ if (failed(hlo::matchInts(op.getK(), k))) -+ return rewriter.notifyMatchFailure(op, "expected constant k"); -+ -+ outputShape[operandType.getRank() - 1] = k[0]; -+ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); -+ } -+}; -+ -+struct RefineInferTypeOpInterfacePattern -+ : public OpInterfaceRewritePattern { -+ explicit RefineInferTypeOpInterfacePattern(MLIRContext* context) -+ : OpInterfaceRewritePattern(context, /*benefit=*/0) {} -+ LogicalResult matchAndRewrite(InferTypeOpInterface op, -+ PatternRewriter& rewriter) const override { -+ // Unlike in TensorFlow's type inference pass, here we work only with -+ // allowlisted ops to focus our support on well-defined semantics of -+ // StableHLO programs. -+ if (!isa(op->getDialect())) -+ return rewriter.notifyMatchFailure(op, "unsupported dialect"); -+ -+ // For the ops that implement InferTypeOpInterface, we reinfer their return -+ // types and see what happens. -+ // Operands of these ops might have been refined elsewhere (e.g. someone -+ // might have updated argument types of a function) or earlier during this -+ // pass, and this might enable refinement opportunities downstream. -+ SmallVector inferredReturnTypes; -+ if (failed(op.inferReturnTypes(getContext(), /*location=*/{}, -+ op->getOperands(), op->getAttrDictionary(), -+ op->getPropertiesStorage(), op->getRegions(), -+ inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferReturnTypes failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineRealDynamicSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // Alternative #1: All attributes are fully static (SliceOp style). 
-+ SmallVector startIndices, limitIndices, strides; -+ if (succeeded(hlo::matchInts(op.getStartIndices(), startIndices)) && -+ succeeded(hlo::matchInts(op.getLimitIndices(), limitIndices)) && -+ succeeded(hlo::matchInts(op.getStrides(), strides))) { -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferSliceOp(/*location=*/{}, op.getOperand().getType(), -+ rewriter.getI64TensorAttr(startIndices), -+ rewriter.getI64TensorAttr(limitIndices), -+ rewriter.getI64TensorAttr(strides), -+ inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferSliceOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+ -+ // Alternative #2: Slice sizes are fully static (DynamicSliceOp style). -+ // To detect that, we check whether `limit_indices` is defined as -+ // `start_indices + constant` or `constant + start_indices`. -+ DenseIntElementsAttr sliceSizesAttr; -+ auto m_startIndices = matchers::m_Val(op.getStartIndices()); -+ if (matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) || -+ matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) { -+ SmallVector strides; -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides)) || -+ !llvm::all_of(strides, [&](int64_t stride) { return stride == 1; })) -+ return rewriter.notifyMatchFailure(op, "expected unit strides"); -+ -+ // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. -+ // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. -+ // Adapt accordingly in order to be compatible with inferDynamicSliceOp. -+ auto startIndicesElementType = -+ op.getStartIndices().getType().getElementType(); -+ SmallVector startIndicesTypes( -+ sliceSizesAttr.size(), -+ RankedTensorType::get({}, startIndicesElementType)); -+ -+ // RealDynamicSliceOp can take tensors of integer or index element types. -+ // DynamicSliceOp::slice_sizes only supports i64 element type. -+ // Adapt accordingly in order to be compatible with inferDynamicSliceOp. -+ SmallVector sliceSizes; -+ for (auto element : sliceSizesAttr.getValues()) { -+ sliceSizes.push_back(element.getSExtValue()); -+ } -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferDynamicSliceOp( -+ op.getLoc(), op.getOperand().getType(), startIndicesTypes, -+ rewriter.getI64TensorAttr(sliceSizes), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferDynamicSliceOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+ -+ return rewriter.notifyMatchFailure( -+ op, -+ "expected either fully static attributes (SliceOp style) " -+ "or static sliceSizes (DynamicSliceOp style)"); -+ } -+}; -+ -+struct RefineReduceScatterOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReduceScatterOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // This represents the cross_replica_and_partition process grouping strategy -+ // that requires num_partitions to compute shardCount. Since we don't know -+ // num_partitions at this point, we error out. 
-+ if (op.getChannelHandle() && !op.getUseGlobalDeviceIds()) -+ return rewriter.notifyMatchFailure(op, "unsupported strategy"); -+ DenseIntElementsAttr replicaGroups = op.getReplicaGroups(); -+ auto shardCount = replicaGroups.getType().getDimSize(1); -+ -+ SmallVector refinement(operandType.getShape()); -+ if (!operandType.isDynamicDim(op.getScatterDimension())) -+ refinement[op.getScatterDimension()] /= shardCount; -+ return refineReturnShape(rewriter, op, refinement); -+ } -+}; -+ -+struct RefineRngOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RngOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getShape()); -+ } -+}; -+ -+struct RefineUniformQuantizeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(UniformQuantizeOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferUniformQuantizeOp( -+ /*location=*/{}, op.getOperand(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvertOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineWhileOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(WhileOp op, -+ PatternRewriter& rewriter) const override { -+ // Push the potentially refined operand types into the nested regions. -+ // This can lead to refinements of the return types of the body (but not -+ // of the cond since it always returns tensor), but the key insight here -+ // is that the enclosing while op doesn't care about these refinements -+ // (because its return types are equal to its operand types). -+ // If we end up with incompatibilities between while's return types and -+ // body's return types, the verifier will tell us about that. This means -+ // that the original program wasn't well-formed. TODO(burmako): Implement -+ // better error reporting for this case. -+ // This serves the current use cases well, so the implementation of more -+ // sophisticated refinement algorithm is left for future work. -+ rewriter.startRootUpdate(op); -+ auto condStatus = refineValues(rewriter, op, op.getCond().getArguments(), -+ op.getOperandTypes()); -+ auto bodyStatus = refineValues(rewriter, op, op.getBody().getArguments(), -+ op.getOperandTypes()); -+ if (succeeded(condStatus) || succeeded(bodyStatus)) { -+ rewriter.finalizeRootUpdate(op); -+ return success(); -+ } else { -+ rewriter.cancelRootUpdate(op); -+ return failure(); -+ } -+ } -+}; -+ -+struct UpdateFunctionTypePattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(func::ReturnOp op, -+ PatternRewriter& rewriter) const override { -+ // Check whether any of the values returned by `func.return` are casts -+ // which convert more specific type to less specific type. -+ // Such ops are produced by the algorithm behind this pass to avoid -+ // bringing the enclosing `func.func` op into an inconsistent state when -+ // refining individual ops. This pattern cleans this up. 
-+ bool needsUpdate = false; -+ SmallVector updatedResultTypes(op.getOperandTypes()); -+ llvm::SmallSet castsToReplace; -+ for (auto [i, operand] : llvm::enumerate(op.getOperands())) { -+ auto cast = -+ dyn_cast_or_null(operand.getDefiningOp()); -+ if (!cast || cast.getInputs().size() != 1 || -+ cast.getOutputs().size() != 1) -+ continue; -+ -+ // Only proceed if the type that we're casting from is more specific -+ // than the type that we're casting to. -+ auto sourceType = cast.getInputs()[0].getType(); -+ auto destType = cast.getOutputs()[0].getType(); -+ auto mostSpecificType = hlo::inferMostSpecificType( -+ /*location=*/{}, {sourceType, destType}); -+ if (failed(mostSpecificType) || destType == *mostSpecificType) continue; -+ -+ // If the source type of the cast is more specific than the target type, -+ // then we conclude that the cast is redundant (i.e. needs to be removed) -+ // and that the return type of the function needs an update. -+ needsUpdate = true; -+ updatedResultTypes[i] = sourceType; -+ -+ // Insert into set and continue iterating. -+ // ReturnOp may point to same value more than once. -+ castsToReplace.insert(cast); -+ } -+ if (!needsUpdate) -+ return rewriter.notifyMatchFailure(op, "doesn't need update"); -+ -+ // Replace CastOps with more specific operands than results. -+ for (auto cast : castsToReplace) -+ rewriter.replaceOp(cast, cast->getOperands()); -+ -+ // If the type of the enclosing `func.func` needs an update, we simply -+ // call setType. We can afford this simplicity because our algorithm -+ // currently supports only one function per module. -+ auto func = cast(op->getParentOp()); -+ func.setType( -+ rewriter.getFunctionType(func.getArgumentTypes(), updatedResultTypes)); -+ return success(); -+ } -+}; -+ -+struct UpdateRegionTypePattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReturnOp op, -+ PatternRewriter& rewriter) const override { -+ if (!isa(op->getParentOp())) -+ return rewriter.notifyMatchFailure(op, "unsupported region"); -+ -+ bool needsUpdate = false; -+ SmallVector updatedResultTypes(op.getOperandTypes()); -+ for (auto [regionType, refinedType] : llvm::zip( -+ op->getParentOp()->getResultTypes(), op->getOperandTypes())) { -+ auto mostSpecificType = hlo::inferMostSpecificType( -+ /*location=*/{}, {regionType, refinedType}); -+ if (failed(mostSpecificType) || regionType == *mostSpecificType) continue; -+ needsUpdate = true; -+ } -+ if (!needsUpdate) -+ return rewriter.notifyMatchFailure(op, "doesn't need update"); -+ -+ rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; }); -+ return success(); -+ } -+}; -+ -+struct StablehloRefineShapesPass -+ : public impl::StablehloRefineShapesPassBase { -+ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; -+ -+ void runOnOperation() override { -+ // Only one function per module is supported at the moment to avoid the need -+ // to think about iterative type inference algorithms. -+ // Current use cases are served well by inlining multiple functions into -+ // a single function, so we leave native support for multiple functions to -+ // future work. -+ // To enable modules that contain CustomCallOp::called_computations, -+ // we allow multiple functions, in which case we only refine the main -+ // function called "main", assuming that the called computations will have -+ // static shapes. Lifting this assumption and expanding refinement to -+ // multiple functions is left for future work. 
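[Editorial sketch, not part of the diff] The runOnOperation boilerplate deleted in this hunk duplicates the target-selection and pattern-registration logic that the non-experimental pass now gets from getStablehloRefineShapesTarget and populateStablehloRefineShapesPatterns, both added earlier in this patch. For reference, a minimal sketch of that call pattern; the pass name is hypothetical and the greedy-rewrite config used by the real pass is omitted — only the two helper calls come from the patch.

// Hypothetical pass body, sketched for illustration.
void MyRefineShapesPass::runOnOperation() {
  // Pick the single refinable function (or "main" when several are present);
  // the helper emits the module-level diagnostics and returns null on failure.
  func::FuncOp func = getStablehloRefineShapesTarget(getOperation());
  if (!func) return signalPassFailure();

  RewritePatternSet patterns(&getContext());
  populateStablehloRefineShapesPatterns(&patterns, &getContext());
  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
    return signalPassFailure();
}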
-+ ModuleOp module = getOperation(); -+ auto funcs = llvm::to_vector(module.getOps()); -+ if (funcs.empty()) return; -+ func::FuncOp func; -+ if (funcs.size() == 1) { -+ func = funcs[0]; -+ } else { -+ func = module.lookupSymbol("main"); -+ } -+ if (!func) { -+ module.emitOpError() -+ << "must have no more than one function or a `main`" -+ << " function to clearly identify which function will be refined"; -+ return signalPassFailure(); -+ } -+ -+ // Similarly, only one block per function is supported at the moment. -+ // At the StableHLO level, functions are expected to only have one block, -+ // so supporting more is out of scope for this pass. -+ if (!func.getRegion().hasOneBlock()) { -+ func.emitOpError() << "must have exactly one block"; -+ return signalPassFailure(); -+ } -+ -+ // The algorithm behind this pass consists of a single traversal of the -+ // function. This is sufficient because we only support one function per -+ // program at the moment. -+ // TODO(#1048): Find out why .maxIterations = 1 no longer works. -+ // There have been recent refactors to applyPatternsAndFoldGreedily -+ // upstream, and that might be the reason. -+ GreedyRewriteConfig config; -+ config.useTopDownTraversal = true; -+ config.enableRegionSimplification = true; -+ config.maxIterations = 2; -+ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; -+ config.strictMode = GreedyRewriteStrictness::AnyOp; -+ -+ RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ if (failed( -+ applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { -+ return signalPassFailure(); -+ } -+ } -+}; -+ -+} // namespace -+} // namespace experimental -+} // namespace stablehlo -+} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -340,6 +340,19 @@ - %1 = stablehlo.constant dense<2> : tensor - %2 = stablehlo.multiply %0, %1 : tensor - func.return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @eval_or -+func.func @eval_or() -> tensor { -+ // CHECK-NOT: stablehlo.or -+ // CHECK: [[RESULT:%.*]] = stablehlo.constant 
dense : tensor -+ // CHECK: return [[RESULT]] -+ %0 = stablehlo.constant dense : tensor -+ %1 = stablehlo.constant dense : tensor -+ %2 = stablehlo.or %0, %1 : tensor -+ func.return %2 : tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -304,6 +304,20 @@ - } - }; - -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ - struct EvalRemOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RemOp op, -@@ -1165,6 +1179,7 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++#endif // STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H From 598681e59231551d15c74a0a8b5875df7fba4bb3 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Fri, 1 Dec 2023 17:43:47 -0800 Subject: [PATCH 313/381] [XLA] Remove hard-coded constants from chi-square test This makes it easier to adjust. PiperOrigin-RevId: 587176721 --- third_party/xla/xla/tests/BUILD | 1 + third_party/xla/xla/tests/prng_test.cc | 84 +++++++++++++++++++------- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 1b32e39e73b18a..dab194689cb5fc 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1890,6 +1890,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/tests/prng_test.cc b/third_party/xla/xla/tests/prng_test.cc index cf9e1dcbae618f..accfa44034ae72 100644 --- a/third_party/xla/xla/tests/prng_test.cc +++ b/third_party/xla/xla/tests/prng_test.cc @@ -13,10 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include #include #include +#include +#include +#include #include "absl/types/span.h" +#include "unsupported/Eigen/SpecialFunctions" // from @eigen_archive #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/literal.h" @@ -43,8 +51,8 @@ class PrngTest : public ClientLibraryTestBase { // of the given range size. `expected_count` is the number of times each // possible value is expected to be generated. Thus, the sample size is // `range_size * expected_count`. 
- double UniformChiSquared(int32_t range_size, int32_t expected_count, - int64_t seed = 42); + void UniformChiSquared(int32_t range_size, int32_t expected_count, + int64_t seed = 42); }; template @@ -141,10 +149,30 @@ template T Square(T x) { return x * x; } + +// Calculates the p-value (probability) of a given chi-square value and degrees +// of freedom. +double ChiSquarePValue(double chi_square, int dof) { + // We are doing a right-tailed test so the p-value is calculated as 1 - CDF. + // + // The CDF can be computed using the regularized lower incomplete gamma + // function like so: + // gammainc(dof/2, chi_square/2). + // + // Seeing as we are interested in 1-CDF, we can compute this using the + // regularized upper incomplete gamma function like so: + // gammaincc(dof/2, chi_square/2). + // + // NIST/SEMATECH e-Handbook of Statistical Methods, 1.3.6.6.6. Chi-Square + // Distribution: Cumulative Distribution Function + // https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm#cdf + return Eigen::numext::igammac(0.5 * dof, 0.5 * chi_square); +} + } // namespace -double PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count, - int64_t seed) { +void PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count, + int64_t seed) { int32_t sample_size = range_size * expected_count; XlaBuilder builder(TestName()); @@ -157,34 +185,48 @@ double PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count, std::vector counts(range_size, 0); actual.EachCell( [&counts](absl::Span, int32_t value) { ++counts[value]; }); + LOG(INFO) << "sample_size = " << sample_size; + LOG(INFO) << "range_size = " << range_size; + LOG(INFO) << "expected_count = " << expected_count; + for (int32_t i = 0; i < range_size; ++i) { + LOG(INFO) << "counts[" << i << "] = " << counts[i]; + } int64_t sum = 0; for (int32_t i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); } - return static_cast(sum) / expected_count; + double chi_square = static_cast(sum) / expected_count; + int64_t dof = range_size - 1; + double p_value = ChiSquarePValue(chi_square, dof); + const double kLevelOfSignificance = 1e-5; + // We have two hypotheses: + // - null hypothesis: the distribution we sampled from cannot be distinguished + // from a uniform random distribution. + // - alternate hypothesis: the distribution we sampled from can be + // distinguished from a uniform random distribution. + // + // The lower our calculated p-value, the less likely we would get this result + // if the null hypothesis were true. If our p-value is greater than or equal + // to `kLevelOfSignificance`, we cannot reject the null hypothesis. + // + // Another way of saying this is that if our p-value is greater than or equal + // to `kLevelOfSignificance` then we can consider our data randomly + // distributed with a confidence of 1-kLevelOfSignificance; otherwise, if our + // p-value is less than `kLevelOfSignificance` then our data is non-random + // with a confidence of 1-kLevelOfSignificance. + EXPECT_GE(p_value, kLevelOfSignificance); } // We only test distribution of uniform discrete PRNG as other types are based // on it. // These range sizes are arbitrary but include prime numbers, powers of 2, and // other composite numbers. -// The level of significance in all these cases is 1/20. // TODO(b/35723038): Use parametrized tests where possible. 
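[Editorial sketch, not part of the change] The hard-coded thresholds removed just below are approximately the chi-square critical values at the 5% significance level mentioned in the deleted comment; for example, 12.5916 corresponds to 7 buckets, i.e. 6 degrees of freedom. The new code inverts that relationship by computing the p-value directly, so evaluating the same Eigen routine at an old threshold should yield roughly 0.05. The standalone snippet below illustrates this; the include path follows the patch's use of Eigen from @eigen_archive and may differ in other builds.

#include <iostream>
#include "unsupported/Eigen/SpecialFunctions"

int main() {
  int dof = 7 - 1;                 // Uniformity7: 7 buckets -> 6 degrees of freedom.
  double old_threshold = 12.5916;  // Old hard-coded bound for this test.
  // Right-tailed p-value, same formula as ChiSquarePValue in the patch.
  double p = Eigen::numext::igammac(0.5 * dof, 0.5 * old_threshold);
  std::cout << p << "\n";          // Expect a value close to 0.05.
}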
-XLA_TEST_F(PrngTest, Uniformity7) { - EXPECT_LT(UniformChiSquared(7, 256), 12.5916); -} -XLA_TEST_F(PrngTest, Uniformity61) { - EXPECT_LT(UniformChiSquared(61, 256), 79.0819); -} -XLA_TEST_F(PrngTest, Uniformity64) { - EXPECT_LT(UniformChiSquared(64, 256), 82.5287); -} -XLA_TEST_F(PrngTest, Uniformity108) { - EXPECT_LT(UniformChiSquared(108, 256), 132.144); -} -XLA_TEST_F(PrngTest, Uniformity256) { - EXPECT_LT(UniformChiSquared(256, 512), 293.248); -} +XLA_TEST_F(PrngTest, Uniformity7) { UniformChiSquared(7, 256); } +XLA_TEST_F(PrngTest, Uniformity61) { UniformChiSquared(61, 256); } +XLA_TEST_F(PrngTest, Uniformity64) { UniformChiSquared(64, 256); } +XLA_TEST_F(PrngTest, Uniformity108) { UniformChiSquared(108, 256); } +XLA_TEST_F(PrngTest, Uniformity256) { UniformChiSquared(256, 256); } // TODO(b/134770669): May remove this test if we decide not to support map // computations with kRng instructions. From c28fa6e366e58c9c368b9330d6f500c4223661db Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 1 Dec 2023 18:40:56 -0800 Subject: [PATCH 314/381] Add keys() to WeakrefLRUCache. This should probably only be used for debugging. PiperOrigin-RevId: 587189745 --- third_party/xla/xla/pjrt/lru_cache.h | 3 +++ .../xla/xla/python/weakref_lru_cache.cc | 19 +++++++++++++++++++ .../xla/xla/python/weakref_lru_cache_test.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/third_party/xla/xla/pjrt/lru_cache.h b/third_party/xla/xla/pjrt/lru_cache.h index ce3d1933f8c5fa..0386817fe2768d 100644 --- a/third_party/xla/xla/pjrt/lru_cache.h +++ b/third_party/xla/xla/pjrt/lru_cache.h @@ -92,6 +92,9 @@ class LRUCache { int Size() const { return entries_.size(); } int Capacity() const { return lru_list_->Capacity(); } + auto begin() const { return entries_.begin(); } + auto end() const { return entries_.end(); } + private: LRUList* lru_list_; diff --git a/third_party/xla/xla/python/weakref_lru_cache.cc b/third_party/xla/xla/python/weakref_lru_cache.cc index 20bde65e9c928d..4de438bfea47c6 100644 --- a/third_party/xla/xla/python/weakref_lru_cache.cc +++ b/third_party/xla/xla/python/weakref_lru_cache.cc @@ -28,6 +28,8 @@ limitations under the License. 
#include "pybind11/cast.h" // from @pybind11 #include "pybind11/gil.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "pybind11/stl.h" // from @pybind11 #include "xla/pjrt/lru_cache.h" namespace jax { @@ -227,6 +229,22 @@ class WeakrefLRUCache : public std::enable_shared_from_this { return fn_(weakref_key, *args, **kwargs); } } + std::vector GetKeys() { + std::vector results; + mu_.Lock(); + for (const auto& wr_key : entries_) { + for (const auto& rest : *wr_key.second) { + pybind11::tuple result(4); + result[0] = wr_key.first.weakref; + result[1] = rest.first.context; + result[2] = rest.first.args; + result[3] = rest.first.kwargs; + results.push_back(std::move(result)); + } + } + mu_.Unlock(); + return results; + } CacheInfo GetCacheInfo() const { CacheInfo result; result.hits = total_queries_ - misses_; @@ -265,6 +283,7 @@ void BuildWeakrefLRUCacheAPI(pybind11::module& m) { py::class_>( m, "WeakrefLRUCache") .def("__call__", &WeakrefLRUCache::Call) + .def("cache_keys", &WeakrefLRUCache::GetKeys) .def("cache_info", &WeakrefLRUCache::GetCacheInfo) .def("cache_clear", &WeakrefLRUCache::Clear); py::class_(weakref_lru_cache, diff --git a/third_party/xla/xla/python/weakref_lru_cache_test.py b/third_party/xla/xla/python/weakref_lru_cache_test.py index 213d8a9c23aedd..ae8808bb62d49a 100644 --- a/third_party/xla/xla/python/weakref_lru_cache_test.py +++ b/third_party/xla/xla/python/weakref_lru_cache_test.py @@ -94,6 +94,23 @@ def CacheFn(obj, kwkey1, kwkey2): self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + if __name__ == "__main__": absltest.main() From 3880d85680a8c9bab3eb221cf7ed5a97ed76979e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 1 Dec 2023 18:51:42 -0800 Subject: [PATCH 315/381] [xla:gpu] Add support for custom fusions/kernels to XLA runtime PiperOrigin-RevId: 587191918 --- .../gpu/transforms/gpu_to_gpu_runtime.cc | 13 +- .../gpu/transforms/lmhlo_to_gpu_launch.cc | 75 +++++++- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 41 +++-- .../xla/xla/service/gpu/ir_emitter_unnested.h | 3 +- .../xla/xla/service/gpu/kernel_thunk.cc | 22 ++- .../xla/xla/service/gpu/kernel_thunk.h | 29 ++- .../xla/service/gpu/kernels/custom_kernel.cc | 3 + .../xla/service/gpu/kernels/custom_kernel.h | 3 + .../kernels/cutlass_gemm_custom_kernel.cu.cc | 15 +- .../gpu/kernels/cutlass_gemm_custom_kernel.h | 3 +- .../cutlass_gemm_custom_kernel_test.cc | 2 +- .../gpu/kernels/cutlass_gemm_fusion.cc | 15 +- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 10 +- third_party/xla/xla/service/gpu/runtime/BUILD | 13 ++ .../xla/service/gpu/runtime/kernel_launch.cc | 168 +++++++++++++++++- third_party/xla/xla/service/gpu/thunk.cc | 1 + third_party/xla/xla/service/gpu/thunk.h | 1 + .../mhlo_to_lhlo_with_xla.cc | 10 ++ 19 files changed, 383 insertions(+), 45 deletions(-) diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc index a016b4802100ed..d7a03fe538c043 100644 --- 
a/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc @@ -25,7 +25,10 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/mlir/backends/gpu/transforms/uid_generator.h" #include "xla/mlir/runtime/utils/custom_calls.h" @@ -179,9 +182,12 @@ class LaunchFuncOpLowering : public OpRewritePattern { // Add kernel arguments. llvm::copy(op.getKernelOperands(), std::back_inserter(args)); + auto computation = op->getAttr("__custom_fusion_computation"); + // Get or create a custom call function declaration. func::FuncOp callee = custom_calls_.GetOrCreate( - b, "xla.gpu.func.launch", TypeRange(ValueRange(args)), TypeRange()); + b, computation ? "xla.gpu.func.custom_launch" : "xla.gpu.func.launch", + TypeRange(ValueRange(args)), TypeRange()); // Create a function launch call operation. auto call = b.create(callee.getName(), TypeRange(), args); @@ -198,6 +204,11 @@ class LaunchFuncOpLowering : public OpRewritePattern { call->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(0)); } + // Copy custom fusion computation. + if (computation) { + call->setAttr("__custom_fusion_computation", computation); + } + // Erase the original gpu launch operation. rewriter.eraseOp(op); diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc index d0df74a362ad08..4f4b7dd2af9e6b 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -203,9 +204,9 @@ static absl::StatusOr> Match( // Check if we know how to lower a Thunk to Gpu operation(s). 
auto is_supported = [](const std::unique_ptr& thunk) -> bool { - Thunk::Kind kinds[] = {Thunk::kKernel, Thunk::kCopy, - Thunk::kMemset32BitValue, Thunk::kMemzero, - Thunk::kSequential}; + Thunk::Kind kinds[] = {Thunk::kKernel, Thunk::kCustomKernel, + Thunk::kCopy, Thunk::kMemset32BitValue, + Thunk::kMemzero, Thunk::kSequential}; return llvm::any_of( kinds, [&](Thunk::Kind kind) { return thunk->kind() == kind; }); }; @@ -281,6 +282,57 @@ static void LowerKernelThunkToGpuOp( kernel_args); } +static void LowerCustomKernelThunkToGpuOp( + Operation* op, OpBuilder& b, GPUModuleOp gpu_module, + const CustomKernelThunk& thunk, const SmallVector& kernel_args, + const SmallVector& kernel_args_written) { + mlir::Location loc = op->getLoc(); + b.setInsertionPointToStart(gpu_module.getBody()); + + auto func_type = + b.getType(TypeRange(ValueRange(kernel_args)), TypeRange()); + + gpu::GPUFuncOp kernel_func = + b.create(loc, thunk.custom_kernel_name(), func_type); + kernel_func->setAttr(GPUDialect::getKernelFuncAttrName(), b.getUnitAttr()); + + for (int i = 0; i < kernel_args.size(); ++i) { + if (kernel_args_written[i]) { + kernel_func.setArgAttr(i, "lmhlo.written", b.getUnitAttr()); + } + } + + b.setInsertionPointToEnd(&kernel_func.getBody().back()); + b.create(loc); + + auto make_const_idx = [&](int64_t value) { + auto attr = b.getIndexAttr(value); + return b.create(loc, attr).getResult(); + }; + + auto make_kernel_dim3 = [&](const auto& dim3) { + return KernelDim3{make_const_idx(dim3.x), make_const_idx(dim3.y), + make_const_idx(dim3.z)}; + }; + + b.setInsertionPoint(op); + auto launch_dims = thunk.launch_dimensions(); + auto grid_size = make_kernel_dim3(launch_dims.block_counts()); + auto block_size = make_kernel_dim3(launch_dims.thread_counts_per_block()); + auto shmem_size = b.create( + loc, b.getI32IntegerAttr(thunk.shmem_bytes())); + + auto launch_func = b.create( + loc, kernel_func, grid_size, block_size, shmem_size, kernel_args); + + if (auto computation = op->getAttr("__custom_fusion_computation")) { + launch_func->setAttr("__custom_fusion_computation", computation); + } else { + launch_func->setAttr("__custom_fusion_computation", + b.getStringAttr("")); + } +} + static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, GPUModuleOp gpu_module, Thunk* thunk) { auto loc = op->getLoc(); @@ -341,6 +393,23 @@ static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, return; } + if (thunk->kind() == Thunk::kCustomKernel) { + const auto* kernel_thunk = static_cast(thunk); + + SmallVector kernel_args; + for (auto kernel_arg : kernel_thunk->values()) + kernel_args.push_back(kernel_arg); + + SmallVector kernel_args_written; + for (auto written : kernel_thunk->written()) { + kernel_args_written.push_back(written); + } + + LowerCustomKernelThunkToGpuOp(op, b, gpu_module, *kernel_thunk, kernel_args, + kernel_args_written); + return; + } + CHECK(false) << "Thunk kind not handled: " << thunk->kind(); } diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2a2004472ba197..7c3c92b97114ce 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -4446,6 +4446,7 @@ test_suite( # copybara:uncomment "//third_party/py/jax/tests:pmap_test_gpu", # copybara:uncomment "//tensorflow/compiler/tests:fft_test_gpu", "//xla/python:xla_client_test_gpu", + # copybara:uncomment "//xla/service/gpu/kernels:cutlass_gemm_fusion_test_gpu", "//xla/service/gpu/tests:add_preds.hlo.test", "//xla/service/gpu/tests:concat.hlo.test", 
"//xla/service/gpu/tests:constant.hlo.test", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 76bb3b0e17e2db..aca82442de5e9c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -329,13 +329,23 @@ StatusOr> BuildKernelThunkForFusion( StatusOr> BuildCustomKernelThunkForFusion( IrEmitterContext& ir_emitter_context, const HloFusionInstruction* fusion, - CustomKernel custom_kernel) { - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - KernelArguments::Create(ir_emitter_context.buffer_assignment(), fusion)); + mlir::lmhlo::FusionOp fusion_op, CustomKernel custom_kernel) { + TF_ASSIGN_OR_RETURN(auto kernel_arguments, + ir_emitter_context.emit_ir_from_hlo() + ? KernelArguments::Create( + ir_emitter_context.buffer_assignment(), fusion) + : KernelArguments::Create( + ir_emitter_context.allocations(), fusion_op)); + + std::variant instr; + if (ir_emitter_context.emit_ir_from_hlo()) { + instr = fusion; + } else { + instr = fusion_op; + } return std::make_unique( - fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); + instr, std::move(custom_kernel), std::move(kernel_arguments.args())); } // Derives the number of warps to use for processing a Triton Softmax fusion. @@ -2203,7 +2213,8 @@ Status IrEmitterUnnested::EmitFusion( instr->backend_config()); TF_ASSIGN_OR_RETURN( emission_result, - EmitCustomFusion(instr, backend_config.custom_fusion_config())); + EmitCustomFusion(instr, nullptr, + backend_config.custom_fusion_config())); break; } default: @@ -2278,8 +2289,13 @@ Status IrEmitterUnnested::EmitFusion( EmitScatter(fusion, fusion_op, fusion_analysis)); break; } - case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: - LOG(FATAL) << "kCustomFusion is not supported by JitRt runtime"; + case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: { + TF_ASSIGN_OR_RETURN( + emission_result, + EmitCustomFusion(fusion, fusion_op, + backend_config.custom_fusion_config())); + break; + } } for (auto& thunk : emission_result.thunks) { @@ -3305,7 +3321,8 @@ StatusOr IrEmitterUnnested::EmitScatter( } StatusOr IrEmitterUnnested::EmitCustomFusion( - const HloFusionInstruction* fusion, const CustomFusionConfig& config) { + const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, + const CustomFusionConfig& config) { VLOG(3) << "Lower HLO fusion to a custom fusion " << config.name(); auto* registry = CustomFusionRegistry::Default(); @@ -3336,9 +3353,9 @@ StatusOr IrEmitterUnnested::EmitCustomFusion( return absl::InternalError("Expected exactly one custom kernel"); } - TF_ASSIGN_OR_RETURN( - auto thunk, BuildCustomKernelThunkForFusion(*ir_emitter_context_, fusion, - std::move(kernels[0]))); + TF_ASSIGN_OR_RETURN(auto thunk, BuildCustomKernelThunkForFusion( + *ir_emitter_context_, fusion, fusion_op, + std::move(kernels[0]))); FusionEmissionResult result; result.thunks.push_back(std::move(thunk)); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index edab1055255e70..5bb76cf301677d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -374,7 +374,8 @@ class IrEmitterUnnested : public IrEmitter { // Emits kernel thunk for a custom fusion implemented with hand written custom // device kernels. 
StatusOr EmitCustomFusion( - const HloFusionInstruction* fusion, const CustomFusionConfig& config); + const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, + const CustomFusionConfig& config); // Builds a kernel thunk for a non-fusion operation, without reuse. // diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.cc b/third_party/xla/xla/service/gpu/kernel_thunk.cc index 58dfab7b11320a..a8947e19f6bc0a 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/kernel_thunk.cc @@ -173,9 +173,15 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { //===----------------------------------------------------------------------===// CustomKernelThunk::CustomKernelThunk( - const HloInstruction* instr, CustomKernel custom_kernel, + std::variant instr, + CustomKernel custom_kernel, absl::Span kernel_arguments) - : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(instr)), + : Thunk(Kind::kCustomKernel, + std::holds_alternative(instr) + ? Thunk::ThunkInfo::WithProfileAnnotation( + std::get(instr)) + : Thunk::ThunkInfo::WithProfileAnnotation( + std::get(instr))), custom_kernel_(std::move(custom_kernel)) { args_.reserve(kernel_arguments.size()); written_.reserve(kernel_arguments.size()); @@ -185,6 +191,18 @@ CustomKernelThunk::CustomKernelThunk( written_.push_back(kernel_argument.written()); } } + + if (std::holds_alternative(instr)) { + // Skip populating MLIR values_ if emitting from HLO. + return; + } + + values_.reserve(kernel_arguments.size()); + for (const auto& kernel_argument : kernel_arguments) { + if (!kernel_argument.first_with_same_slice().has_value()) { + values_.push_back(RemoveTransformingOperations(kernel_argument.value())); + } + } } std::string CustomKernelThunk::ToStringExtra(int indent) const { diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.h b/third_party/xla/xla/service/gpu/kernel_thunk.h index 7f8ff1331324e4..2c16d3cc1659f0 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/kernel_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -136,7 +137,8 @@ class KernelThunk : public Thunk { // compiled by XLA and loaded from an executable source. class CustomKernelThunk : public Thunk { public: - CustomKernelThunk(const HloInstruction* instr, CustomKernel custom_kernel, + CustomKernelThunk(std::variant inst, + CustomKernel custom_kernel, absl::Span kernel_arguments); std::string ToStringExtra(int indent) const override; @@ -145,6 +147,28 @@ class CustomKernelThunk : public Thunk { ExecutableSource src) override; Status ExecuteOnStream(const ExecuteParams& params) override; + // TODO(ezhulenev): All of the APIs below needed only for LMHLO lowering and + // should be removed after we migrate to Thunks runtime. 
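The CustomKernelThunk constructor in the kernel_thunk.cc hunk above takes a std::variant of an MLIR operation pointer and an HloInstruction pointer and picks the profile annotation with std::holds_alternative/std::get. For readers unfamiliar with that pattern, below is a minimal standalone sketch of the same kind of dispatch written with std::visit; MlirOp, HloInstr, and Annotate are hypothetical stand-ins for illustration only, not symbols from this patch.

#include <iostream>
#include <string>
#include <variant>

struct MlirOp {};    // stand-in for mlir::Operation
struct HloInstr {};  // stand-in for const HloInstruction

std::string Annotate(const MlirOp*) { return "annotation from MLIR op"; }
std::string Annotate(const HloInstr*) { return "annotation from HLO instruction"; }

int main() {
  std::variant<MlirOp*, const HloInstr*> instr =
      static_cast<const HloInstr*>(nullptr);
  // std::visit invokes the overload matching whichever alternative is active,
  // collapsing the holds_alternative/get branches into a single expression.
  std::string annotation =
      std::visit([](auto* op) { return Annotate(op); }, instr);
  std::cout << annotation << "\n";
  return 0;
}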
+ + std::string_view custom_kernel_name() const { return custom_kernel_.name(); } + + const std::vector& written() const { return written_; } + absl::Span values() const { return values_; } + + LaunchDimensions launch_dimensions() const { + LaunchDimensions::Dim3D threads; + threads.x = custom_kernel_.thread_dims().x; + threads.y = custom_kernel_.thread_dims().y; + threads.z = custom_kernel_.thread_dims().z; + LaunchDimensions::Dim3D blocks; + blocks.x = custom_kernel_.block_dims().x; + blocks.y = custom_kernel_.block_dims().y; + blocks.z = custom_kernel_.block_dims().z; + return LaunchDimensions(blocks, threads); + } + + int64_t shmem_bytes() const { return custom_kernel_.shared_memory_bytes(); } + private: // Buffer slices passed to the kernel as arguments. std::vector args_; @@ -152,6 +176,9 @@ class CustomKernelThunk : public Thunk { // args_[i] is written iff (written_[i] == true). std::vector written_; + // mlir::Value(s) corresponding to the buffer slice arguments. + std::vector values_; + CustomKernel custom_kernel_; // Loaded kernels for each `StreamExecutor`. diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc index b9451eb6ff0154..f7e9e75e73ae02 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/strings/str_format.h" @@ -36,6 +37,8 @@ CustomKernel::CustomKernel(std::string name, shared_memory_bytes_(shared_memory_bytes) {} +std::string_view CustomKernel::name() const { return name_; } + const se::MultiKernelLoaderSpec& CustomKernel::kernel_spec() const { return kernel_spec_; } diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h index 32f65b76a7a49c..d6ceca389e3f63 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -47,6 +48,8 @@ class CustomKernel { se::BlockDim block_dims, se::ThreadDim thread_dims, size_t shared_memory_bytes); + std::string_view name() const; + const se::MultiKernelLoaderSpec& kernel_spec() const; se::BlockDim block_dims() const; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc index de8868fc497012..96f9d7a17f0b0c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc @@ -30,8 +30,8 @@ namespace xla::gpu::kernel::gemm_universal { template static StatusOr LoadCutlassGemmUniversal( - int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, - const DynamicSliceIndices& slices) { + std::string name, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, const DynamicSliceIndices& slices) { using Kernel = typename Gemm::GemmKernel; cutlass::gemm::GemmCoord problem_size = {m, n, k}; @@ -39,24 +39,25 @@ static StatusOr LoadCutlassGemmUniversal( auto packing = ArgsPacking(problem_size, indices, slices); se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing)); - spec.AddInProcessSymbol(GetKernelSymbol(), "cutlass_gemm"); + spec.AddInProcessSymbol(GetKernelSymbol(), name); - return CustomKernel("cutlass_gemm", std::move(spec), + return CustomKernel(std::move(name), std::move(spec), BlockDim(problem_size), ThreadDim(), sizeof(typename Kernel::SharedStorage)); } -StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, +StatusOr GetCutlassGemmKernel(std::string name, + PrimitiveType dtype, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices) { switch (dtype) { case PrimitiveType::F32: return LoadCutlassGemmUniversal( - m, n, k, indices, slices); + std::move(name), m, n, k, indices, slices); case PrimitiveType::BF16: return LoadCutlassGemmUniversal( - m, n, k, indices, slices); + std::move(name), m, n, k, indices, slices); default: return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type"); } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h index 345d5973c5720b..da971f93e18a8a 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -25,7 +25,8 @@ limitations under the License. namespace xla::gpu::kernel::gemm_universal { -StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, +StatusOr GetCutlassGemmKernel(std::string name, + PrimitiveType dtype, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index f577c2b431c6fc..02f5a80ce74a9c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -43,7 +43,7 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. 
auto custom_kernel = - GetCutlassGemmKernel(PrimitiveType::F32, 4, 4, 4, + GetCutlassGemmKernel("cutlass_gemm", PrimitiveType::F32, 4, 4, 4, /*indices=*/{0, 1, 2}, /*slices=*/{}); TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 9bc6c6d62ee867..bb7787466e4cd4 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { @@ -225,9 +226,10 @@ class CutlassGemmFusion : public CustomFusion { size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - TF_ASSIGN_OR_RETURN(auto kernel, - kernel::gemm_universal::GetCutlassGemmKernel( - dtype, m, n, k, indices, /*slices=*/{})); + TF_ASSIGN_OR_RETURN( + auto kernel, + kernel::gemm_universal::GetCutlassGemmKernel( + "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{})); return std::vector{std::move(kernel)}; } }; @@ -300,9 +302,10 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomFusion { size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - TF_ASSIGN_OR_RETURN(auto kernel, - kernel::gemm_universal::GetCutlassGemmKernel( - dtype, m, n, k, args_indices, slices)); + TF_ASSIGN_OR_RETURN( + auto kernel, kernel::gemm_universal::GetCutlassGemmKernel( + "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, + k, args_indices, slices)); return std::vector{std::move(kernel)}; } }; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index bf26adb5f43303..64aace4fa50ace 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/literal_util.h" #include "xla/service/gpu/custom_fusion_rewriter.h" @@ -31,14 +30,7 @@ limitations under the License. namespace xla::gpu { -class CutlassFusionTest : public HloTestBase { - // Custom fusions are not supported by XLA runtime. 
- DebugOptions GetDebugOptionsForTest() override { - auto debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_gpu_enable_xla_runtime_executable(false); - return debug_options; - } -}; +class CutlassFusionTest : public HloTestBase {}; //===----------------------------------------------------------------------===// // Pattern matching tests diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index ac1ebe27e21435..edd51c1e1600c8 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -641,18 +641,31 @@ cc_library( deps = [ ":concurrent_region", ":support", + "//xla:statusor", "//xla:types", + "//xla/hlo/ir:hlo", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/runtime:state", "//xla/service:executable", + "//xla/service:hlo_proto_cc", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_fusion", + "//xla/service/gpu/kernels:custom_kernel", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_graph", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc index 3b324049ed175f..2702254d5890d4 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc @@ -15,21 +15,40 @@ limitations under the License. #include "xla/service/gpu/runtime/kernel_launch.h" +#include +#include #include #include #include -#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/runtime/custom_call.h" +#include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/runtime/state.h" +#include "xla/service/gpu/kernels/custom_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/runtime/concurrent_region.h" #include "xla/service/gpu/runtime/support.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo.pb.h" #include "xla/service/service_executable_run_options.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/stream_executor/gpu/gpu_graph.h" @@ -133,6 +152,131 @@ static absl::Status LaunchImpl( execution_stream); } +//===----------------------------------------------------------------------===// +// Define the custom kernel (fusion) launch custom call. 
+//===----------------------------------------------------------------------===// + +static StatusOr> CreateCustomKernel( + se::StreamExecutor* executor, std::string_view name, + std::string_view custom_fusion_computation) { + auto* registry = CustomFusionRegistry::Default(); + auto* custom_fusion = registry->Lookup(name); + + // If custom fusion is not found it means that some of the build targets might + // not be statically linked into the binary. + if (custom_fusion == nullptr) { + return absl::InternalError(absl::StrCat( + "Custom fusion ", name, " not found in a default registry.")); + } + + // Parse attached custom fusion computation. + HloComputationProto computation_proto; + if (!computation_proto.ParseFromArray(custom_fusion_computation.data(), + custom_fusion_computation.size())) { + return absl::InternalError("Failed to parse custom fusion computation"); + } + + // Build HloComputation from a proto for passing to custom fusion. + absl::flat_hash_map computation_map; + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto(computation_proto, computation_map)); + + // Load custom kernels that can implement a fusion computation. + TF_ASSIGN_OR_RETURN(std::vector kernels, + custom_fusion->LoadKernels(computation.get())); + + // This should never happen, it means that compilation pipeline created a + // fusion operation that is not supported by a given custom fusion. + if (kernels.empty()) { + return absl::InternalError( + absl::StrCat("Custom fusion ", name, + " returned empty custom kernels for a fused computation")); + } + + auto kernel = std::make_unique(executor); + TF_RETURN_IF_ERROR( + executor->GetKernel(kernels[0].kernel_spec(), kernel.get())); + + return kernel; +} + +static absl::Status CustomLaunchImpl( + const ServiceExecutableRunOptions* run_options, const std::string* ptx, + const std::vector* cubin, se::DeviceMemoryBase* temp_buffer, + ConcurrentRegionStatus* region_status, + State> device_kernel, + int32_t shared_memory_bytes, int32_t grid_size_x, int32_t grid_size_y, + int32_t grid_size_z, int32_t block_size_x, int32_t block_size_y, + int32_t block_size_z, CustomCall::RemainingArgs args, std::string_view name, + int64_t stream_id, std::string_view custom_fusion_computation) { + se::Stream* stream = run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + + LaunchDimensions launch_dimensions( + {grid_size_x, grid_size_y, grid_size_z}, + {block_size_x, block_size_y, block_size_z}); + + // If kernel does not exist load it from a custom fusion computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr * kernel, device_kernel.GetOrCreate([&] { + return ToAbsl( + CreateCustomKernel(executor, name, custom_fusion_computation)); + })); + assert((*kernel)->name() == name && "unexpected loaded kernel"); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (VLOG_IS_ON(3)) { + TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream)); + if (is_capturing) { + if (region_status->IsInConcurrentRegion()) { + LOG(INFO) << "Launching " << (*kernel)->name() + << "in a concurrent region during GPU graph capture"; + } else { + LOG(INFO) << "Launching " << (*kernel)->name() + << "during GPU graph capture"; + } + } else { + LOG(INFO) << "Launching " << (*kernel)->name(); + } + } +#else + VLOG(3) << "Launching " << (*kernel)->name(); +#endif + + absl::InlinedVector buffer_args(args.size()); + + // Add MemRef arguments as buffer arguments. 
+ for (unsigned i = 0; i < args.size(); ++i) { + // We get arguments corresponding to XLA allocations required by the + // compiled device kernel, and not the actual memrefs that device kernel + // writes/reads, so we don't have to pass the size along with the pointer. + if (auto strided = args.get(i); succeeded(strided)) { + buffer_args[i] = se::DeviceMemoryBase(strided->data); + continue; + } + + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported argument #%d type", i)); + } + + // If we are capturing a concurrent region in a GPU graph, then use the + // stream provided by ConcurrentRegionStatus to execute the kernel. + se::Stream* execution_stream = stream; + if (stream_id != 0) { + DCHECK(region_status->IsInConcurrentRegion()); + TF_ASSIGN_OR_RETURN(execution_stream, region_status->GetStream(stream_id)); + } else if (region_status->IsInConcurrentRegion()) { + execution_stream = region_status->GetNextStream(); + } + + se::KernelArgsDeviceMemoryArray kernel_args(buffer_args, shared_memory_bytes); + return executor->Launch( + stream, se::ThreadDim(block_size_x, block_size_y, block_size_z), + se::BlockDim(grid_size_x, grid_size_y, grid_size_z), **kernel, + kernel_args); +} + //===----------------------------------------------------------------------===// XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -155,9 +299,31 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("kernel") .Attr("stream")); +XLA_RUNTIME_DEFINE_CUSTOM_CALL( + CustomLaunch, FunctionWrapper(), checks, + CustomCall::Bind("xla.gpu.func.custom_launch") + .UserData() + .UserData() + .UserData*>() + .UserData() + .UserData() + .State>("uid") + .Arg() // shared_memory_bytes + .Arg() // grid_size_x + .Arg() // grid_size_y + .Arg() // grid_size_z + .Arg() // block_size_x + .Arg() // block_size_y + .Arg() // block_size_x + .RemainingArgs() // args + .Attr("kernel") + .Attr("stream") + .Attr("__custom_fusion_computation")); + void RegisterKernelLaunchCustomCalls( runtime::DirectCustomCallRegistry& registry) { registry.Register("xla.gpu.func.launch", Launch); + registry.Register("xla.gpu.func.custom_launch", CustomLaunch); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/thunk.cc b/third_party/xla/xla/service/gpu/thunk.cc index 6c544a85685138..5f2fa4885a5efd 100644 --- a/third_party/xla/xla/service/gpu/thunk.cc +++ b/third_party/xla/xla/service/gpu/thunk.cc @@ -50,6 +50,7 @@ Thunk::ExecuteParams::ExecuteParams( CASE(kCubSort); CASE(kCublasLtMatmul); CASE(kCustomCall); + CASE(kCustomKernel); CASE(kNcclAllGather); CASE(kNcclAllGatherStart); CASE(kNcclAllGatherDone); diff --git a/third_party/xla/xla/service/gpu/thunk.h b/third_party/xla/xla/service/gpu/thunk.h index a560899dcebf77..047c93d0edac78 100644 --- a/third_party/xla/xla/service/gpu/thunk.h +++ b/third_party/xla/xla/service/gpu/thunk.h @@ -72,6 +72,7 @@ class Thunk { kCubSort, kCublasLtMatmul, kCustomCall, + kCustomKernel, kFft, kFor, kGemm, diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 53289bcf357427..0cbcd44d1e0d21 100644 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -436,6 +436,16 @@ tsl::StatusOr LhloDialectEmitter::EmitFusionOp( HloInstruction::BackendConfigToRawString(backend_config)); fusion.setBackendConfigAttr(builder_.getStringAttr(backend_config_str)); + // For custom fusion backend config we also attach 
serialized version of the + // attached HLO computation. + if (backend_config.kind() == "__custom_fusion") { + std::string computation_str; + fusion_instr->fused_instructions_computation()->ToProto().SerializeToString( + &computation_str); + fusion->setAttr("__custom_fusion_computation", + builder_.getStringAttr(computation_str)); + } + // Fold GTE/Tuple pairs. // // Since the fused region refers to values in its parent region, we can't From 0922d15c80e32f435a1a8df1c4602d5e3865db15 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 1 Dec 2023 19:30:02 -0800 Subject: [PATCH 316/381] [xla:gpu] Add kernel launch context to arguments packing to be able to query kernel occupancy data Query CUTLASS kernel occupancy for custom GEMMs PiperOrigin-RevId: 587199835 --- .../gpu/kernels/cutlass_gemm_kernel.cu.h | 13 +++++--- .../xla/xla/stream_executor/cuda/BUILD | 3 ++ .../cuda/cuda_command_buffer_test.cc | 3 +- .../xla/stream_executor/cuda/cuda_executor.cc | 8 ++++- .../xla/stream_executor/cuda/cuda_kernel.cc | 25 +++++++++++++-- third_party/xla/xla/stream_executor/gpu/BUILD | 2 ++ .../xla/xla/stream_executor/gpu/gpu_kernel.h | 32 +++++++++++++------ third_party/xla/xla/stream_executor/kernel.cc | 20 ++++++++++++ third_party/xla/xla/stream_executor/kernel.h | 28 +++++++++++++++- .../xla/xla/stream_executor/kernel_spec.h | 3 +- .../stream_executor_internal.h | 7 ++++ 11 files changed, 124 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h index 0979e656c0387e..78e0ed36c53b77 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h @@ -152,7 +152,8 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, using PackedArgs = StatusOr>; - return [=](const se::KernelArgs &args) -> PackedArgs { + return [=](const se::KernelLaunchContext &ctx, + const se::KernelArgs &args) -> PackedArgs { auto *mem_args = Cast(&args); cutlass::Status can_implement = Kernel::can_implement(problem_size); @@ -189,11 +190,15 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, lda, ldb, ldc, ldc // strides ); - // TODO(ezhulenev): Get number of SMs from a DeviceDescription and calculate - // correct kernel occupancy using GpuRuntime. + // Query kernel API for SM occupancy for the launch dimensions. + TF_ASSIGN_OR_RETURN(int32_t sm_occupancy, + ctx.kernel()->GetMaxOccupiedBlocksPerCore( + ctx.threads(), args.number_of_shared_bytes())); + + // TODO(ezhulenev): Get number of SMs from DeviceDescription. // Convert CUTLASS operation arguments to a device kernel parameters. - Params params(arguments, /*device_sms=*/128, /*sm_occupancy=*/10); + Params params(arguments, /*device_sms=*/128, sm_occupancy); // Optionally set up dynamic slice parameters to allow kernel adjust buffer // pointers passed via `params`. 
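In the cutlass_gemm_kernel.cu.h hunk above, the hard-coded /*sm_occupancy=*/10 is replaced by a per-launch query through the new KernelLaunchContext: ctx.kernel()->GetMaxOccupiedBlocksPerCore(ctx.threads(), args.number_of_shared_bytes()). For reference, the sketch below shows the same kind of occupancy measurement expressed directly against the CUDA runtime API; it is a standalone illustration with a made-up ScaleKernel and block size, not the code path this patch adds, and the stream_executor plumbing may differ in detail.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void ScaleKernel(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

int main() {
  const int threads_per_block = 256;    // analogous to ctx.threads()
  const size_t dynamic_smem_bytes = 0;  // analogous to args.number_of_shared_bytes()

  // Ask how many blocks of this kernel can be resident on one SM for the
  // given block size and dynamic shared memory usage.
  int sm_occupancy = 0;
  cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &sm_occupancy, ScaleKernel, threads_per_block, dynamic_smem_bytes);
  if (err != cudaSuccess) {
    std::printf("occupancy query failed: %s\n", cudaGetErrorString(err));
    return 1;
  }

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, /*device=*/0);
  std::printf("blocks per SM: %d, SMs: %d, max resident blocks: %d\n",
              sm_occupancy, prop.multiProcessorCount,
              sm_occupancy * prop.multiProcessorCount);
  return 0;
}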
diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index d292bc57ace60b..a75d742f8c4b95 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -441,10 +441,13 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cuda_driver", + "@com_google_absl//absl/log", "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/gpu:gpu_kernel_header", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/platform", + "@local_tsl//tsl/platform:statusor", ]), ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index d3064b4bdde443..2c40db007be26b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -114,7 +114,8 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { // Register a kernel with a custom arguments packing function that packs // device memory arguments into a struct with pointers. - MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelArgs& args) { + MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelLaunchContext&, + const KernelArgs& args) { auto bufs = Cast(&args)->device_memory_args(); auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; return PackKernelArgs(add, internal::Ptrs3{ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 500ac1789e495d..419590c30a08be 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -239,6 +240,10 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, cuda_kernel->gpu_function_ptr())); } + // Update CUDA kernel properties after it was loaded in the CUDA context. + cuda_kernel->set_name(*kernel_name); + cuda_kernel->set_gpu_context(context_); + // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the CUDA API. cuda_kernel->set_arity(spec.arity()); @@ -464,7 +469,8 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, "Kernel is missing a custom arguments packing function for device " "memory arguments array"); - TF_ASSIGN_OR_RETURN(auto packed, pack(*device_mem)); + KernelLaunchContext ctx(&kernel, block_dims, thread_dims); + TF_ASSIGN_OR_RETURN(auto packed, pack(ctx, *device_mem)); return launch(*packed); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc index 2840c0f8165e8f..464f2118128a4b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc @@ -13,7 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_kernel.h" +#include +#include + +#include "absl/log/log.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -30,9 +39,21 @@ CUfunc_cache GpuKernel::GetGpuCacheConfig() const { return CU_FUNC_CACHE_PREFER_EQUAL; default: LOG(FATAL) << "Unknown KernelCacheConfig" - << static_cast(preferred_cache_config_); + << static_cast(preferred_cache_config_); } } +tsl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + int32_t threads_per_block = threads.x * threads.y * threads.z; + VLOG(0) << "Get kernel block occupancy: " << name_ + << "; threads_per_block: " << threads_per_block + << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; + + return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, + threads_per_block, + dynamic_shared_memory_bytes); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index ba8f01430baeb9..302c69579fdf10 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -223,10 +223,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gpu_driver_header", + ":gpu_types_header", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/platform", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index 09443a23259b58..7f8ea596902133 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -22,11 +22,17 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ +#include +#include +#include +#include + #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -35,10 +41,7 @@ namespace gpu { // KernelInterface. class GpuKernel : public internal::KernelInterface { public: - GpuKernel() - : gpu_function_(nullptr), - arity_(0), - preferred_cache_config_(KernelCacheConfig::kNoPreference) {} + GpuKernel() = default; // Note that the function is unloaded when the module is unloaded, and the // module that the function is contained in is owned by the GpuExecutor. @@ -49,6 +52,9 @@ class GpuKernel : public internal::KernelInterface { void set_arity(unsigned arity) { arity_ = arity; } unsigned Arity() const override { return arity_; } + void set_name(std::string name) { name_ = std::move(name); } + void set_gpu_context(GpuContext* gpu_context) { gpu_context_ = gpu_context; } + // Returns the GpuFunctionHandle value for passing to the CUDA API. 
GpuFunctionHandle AsGpuFunctionHandle() const { DCHECK(gpu_function_ != nullptr); @@ -79,12 +85,18 @@ class GpuKernel : public internal::KernelInterface { // CUfunc_cache. GpuFuncCachePreference GetGpuCacheConfig() const; + tsl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + private: - GpuFunctionHandle gpu_function_; // Wrapped CUDA kernel handle. - unsigned arity_; // Number of formal parameters the kernel takes. + GpuContext* gpu_context_ = nullptr; // context where kernel is loaded + std::string name_; // kernel name + + GpuFunctionHandle gpu_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes - // Preferred (but not required) cache configuration for this kernel. - KernelCacheConfig preferred_cache_config_; + // Preferred (but not required) cache configuration for this kernel + KernelCacheConfig preferred_cache_config_ = KernelCacheConfig::kNoPreference; }; // Given a platform-independent kernel datatype, returns the (const) internal diff --git a/third_party/xla/xla/stream_executor/kernel.cc b/third_party/xla/xla/stream_executor/kernel.cc index 3b06a32fc31c5c..4bf8f3987df9b6 100644 --- a/third_party/xla/xla/stream_executor/kernel.cc +++ b/third_party/xla/xla/stream_executor/kernel.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/kernel.h" +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/demangle.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -45,6 +47,18 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) { shared_memory_bytes_ = shared_memory_bytes; } +//===----------------------------------------------------------------------===// +// KernelLaunchContext +//===----------------------------------------------------------------------===// + +KernelLaunchContext::KernelLaunchContext(const Kernel *kernel, BlockDim blocks, + ThreadDim threads) + : kernel_(kernel), blocks_(blocks), threads_(threads) {} + +//===----------------------------------------------------------------------===// +// Kernel +//===----------------------------------------------------------------------===// + Kernel::Kernel(Kernel &&from) : parent_(from.parent_), implementation_(std::move(from.implementation_)), @@ -74,6 +88,12 @@ KernelCacheConfig Kernel::GetPreferredCacheConfig() const { return implementation_->GetPreferredCacheConfig(); } +tsl::StatusOr Kernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + return implementation_->GetMaxOccupiedBlocksPerCore( + threads, dynamic_shared_memory_bytes); +} + void Kernel::set_name(absl::string_view name) { name_ = std::string(name); diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 9077f80955ed8e..1a3ad3b4775cb3 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -88,10 +88,12 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/launch_dim.h" #include "tsl/platform/statusor.h" namespace stream_executor { +class Kernel; class StreamExecutor; namespace internal { @@ -209,6 +211,25 @@ class KernelArgsPackedArrayBase : public KernelArgs { Kind kind() const final { return Kind::kPackedArray; } }; +//===----------------------------------------------------------------------===// +// KernelLaunchContext +//===----------------------------------------------------------------------===// + +// Properties of a kernel launch that might impact kernel arguments packing. +class KernelLaunchContext { + public: + KernelLaunchContext(const Kernel *kernel, BlockDim blocks, ThreadDim threads); + + const Kernel *kernel() const { return kernel_; } + BlockDim blocks() const { return blocks_; } + ThreadDim threads() const { return threads_; } + + private: + const Kernel *kernel_; + BlockDim blocks_; + ThreadDim threads_; +}; + //===----------------------------------------------------------------------===// // Kernel //===----------------------------------------------------------------------===// @@ -226,7 +247,7 @@ class Kernel { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelArgs &args)>; + const KernelLaunchContext &ctx, const KernelArgs &args)>; Kernel(Kernel &&from); @@ -268,6 +289,11 @@ class Kernel { // Gets the preferred cache configuration for a kernel. KernelCacheConfig GetPreferredCacheConfig() const; + // Returns the maximum number of blocks (per multiprocessor) occupied by the + // kernel given the number of threads per block and shared memory size. + tsl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const; + // Sets custom kernels arguments packing function for a kernel. void set_kernel_args_packing(KernelArgsPacking kernel_args_packing) { kernel_args_packing_ = std::move(kernel_args_packing); diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h index 6144944306bef2..e49eae4ec322b4 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.h +++ b/third_party/xla/xla/stream_executor/kernel_spec.h @@ -62,6 +62,7 @@ limitations under the License. namespace stream_executor { class KernelArgs; // defined in kernel.h +class KernelLaunchContext; // defined in kernel.h class KernelArgsPackedArrayBase; // defined in kernel.h // Describes how to load a kernel on a target platform. @@ -262,7 +263,7 @@ class MultiKernelLoaderSpec { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelArgs &args)>; + const KernelLaunchContext &ctx, const KernelArgs &args)>; explicit MultiKernelLoaderSpec( size_t arity, KernelArgsPacking kernel_args_packing = nullptr); diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 8c17dbdc155603..9931e5b0984afd 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -99,6 +99,13 @@ class KernelInterface { // Gets the preferred cache configuration. virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; + // Returns the maximum number of blocks (per multiprocessor) occupied by the + // kernel given the number of threads per block and shared memory size. 
+ virtual tsl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + return absl::UnimplementedError("Not Implemented"); + } + private: KernelInterface(const KernelInterface&) = delete; void operator=(const KernelInterface&) = delete; From 2c537a36e129cd23ab3d04335a58806675036b13 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 1 Dec 2023 22:33:10 -0800 Subject: [PATCH 317/381] [stream_executor] Replace usage of GetSubBuffer with GetSlice PiperOrigin-RevId: 587236905 --- .../xla/service/gpu/gpu_transfer_manager.cc | 3 +- .../service/gpu/runtime/topk_kernel_test.cc | 6 ++-- .../xla/xla/service/transfer_manager.cc | 3 +- .../stream_executor/gpu/redzone_allocator.cc | 30 +++++++++---------- 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc index 5bd1c557e952a3..a695ac712dec59 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc @@ -123,8 +123,7 @@ Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, } auto buffer_8 = se::DeviceMemory(buffer); - auto metadata_buffer = - stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + auto metadata_buffer = buffer_8.GetSlice(offset, metadata_size); copies.push_back(std::make_pair(metadata_buffer, &device_sub_shape)); return OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc b/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc index 32d776574cbae9..ace5363adcd5ef 100644 --- a/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/topk_kernel_test.cc @@ -112,8 +112,7 @@ TEST_P(TopkTest, TopKFloat) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), - executor->GetSubBuffer(&output_values, k * i, k), + stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k), k * sizeof(T)); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); @@ -153,8 +152,7 @@ TEST_P(TopkTest, TopKPackedNegative) { std::vector got(k); ASSERT_TRUE(stream.BlockHostUntilDone().ok()); for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), - executor->GetSubBuffer(&output_values, k * i, k), + stream.ThenMemcpy(got.data(), output_values.GetSlice(k * i, k), k * sizeof(T)); std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); std::sort(slice.begin(), slice.end(), std::greater()); diff --git a/third_party/xla/xla/service/transfer_manager.cc b/third_party/xla/xla/service/transfer_manager.cc index 2f9cd018e375f9..395e4086fea203 100644 --- a/third_party/xla/xla/service/transfer_manager.cc +++ b/third_party/xla/xla/service/transfer_manager.cc @@ -163,8 +163,7 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream, return InvalidArgument("Dynamic shape metadata size should not be 0"); } auto buffer_8 = se::DeviceMemory(buffer); - auto metadata_buffer = - stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + auto metadata_buffer = buffer_8.GetSlice(offset, metadata_size); TF_ASSIGN_OR_RETURN( auto metadata, TransferArrayFromDevice( diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index ceee5503e29b73..c5632ccf403102 100644 
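The hunks in this commit replace executor->GetSubBuffer(&buffer, offset, count) with buffer.GetSlice(offset, count): the sub-buffer is derived from the typed DeviceMemory handle itself, so the call sites no longer need a StreamExecutor pointer. The standalone sketch below only illustrates the element-offset/element-count slicing semantics with a hypothetical Span type; it is not the stream_executor implementation.

#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical view type used only to illustrate GetSlice-style semantics.
template <typename T>
struct Span {
  T* data = nullptr;
  size_t count = 0;

  // Returns a view of `element_count` elements starting at `element_offset`,
  // mirroring how DeviceMemory<T>::GetSlice is used in this commit.
  Span<T> GetSlice(size_t element_offset, size_t element_count) const {
    assert(element_offset + element_count <= count);
    return Span<T>{data + element_offset, element_count};
  }
};

int main() {
  uint8_t storage[64] = {};
  Span<uint8_t> buffer{storage, sizeof(storage)};

  // Carve the buffer into [lhs redzone | user allocation | rhs redzone], the
  // layout the redzone allocator maintains around user buffers.
  const size_t redzone = 16, user = 32;
  Span<uint8_t> lhs_redzone = buffer.GetSlice(0, redzone);
  Span<uint8_t> user_allocation = buffer.GetSlice(redzone, user);
  Span<uint8_t> rhs_redzone = buffer.GetSlice(redzone + user, redzone);

  assert(lhs_redzone.count + user_allocation.count + rhs_redzone.count ==
         sizeof(storage));
  return 0;
}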
--- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -87,23 +87,22 @@ tsl::StatusOr> RedzoneAllocator::AllocateBytes( static_assert(sizeof(uint8_t) == 1, "Unexpected size"); DeviceMemory allocated_buffer_memory(*allocated_buffer); - DeviceMemory lhs_redzone = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, 0, redzone_size_); + DeviceMemory lhs_redzone = + allocated_buffer_memory.GetSlice(0, redzone_size_); - DeviceMemory data_chunk = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, redzone_size_, byte_size); + DeviceMemory data_chunk = + allocated_buffer_memory.GetSlice(redzone_size_, byte_size); // Split up the RHS redzone into two pieces: // - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by // - redzone_size_ bytes. // We do this because Stream::ThenMemset32 requires the buffer address and // size to be aligned to 4 bytes. - DeviceMemory rhs_redzone_slop = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop); + DeviceMemory rhs_redzone_slop = + allocated_buffer_memory.GetSlice(redzone_size_ + byte_size, rhs_slop); - DeviceMemory rhs_redzone_nonslop = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop, - redzone_size_); + DeviceMemory rhs_redzone_nonslop = allocated_buffer_memory.GetSlice( + redzone_size_ + byte_size + rhs_slop, redzone_size_); uint8_t pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_, redzone_pattern_}; @@ -260,7 +259,6 @@ static tsl::StatusOr CheckRedzonesForBuffer( const DeviceMemory& out_param, const ComparisonKernelT& comparison_kernel, int64_t user_allocation_size, uint64_t redzone_size, uint8_t redzone_pattern) { - StreamExecutor* executor = stream->parent(); int64_t rhs_slop = RoundUpToNearest(user_allocation_size, kRhsRedzoneAlign) - user_allocation_size; @@ -268,14 +266,14 @@ static tsl::StatusOr CheckRedzonesForBuffer( DeviceMemory buffer_uint8(memory); DeviceMemory lhs_redzone = - executor->GetSubBuffer(&buffer_uint8, 0, - /*element_count=*/redzone_size); + buffer_uint8.GetSlice(0, + /*element_count=*/redzone_size); DeviceMemory user_allocation = - executor->GetSubBuffer(&buffer_uint8, redzone_size, - /*element_count=*/user_allocation_size); + buffer_uint8.GetSlice(redzone_size, + /*element_count=*/user_allocation_size); DeviceMemory rhs_redzone = - executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size, - /*element_count=*/redzone_size + rhs_slop); + buffer_uint8.GetSlice(redzone_size + user_allocation_size, + /*element_count=*/redzone_size + rhs_slop); TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, out_param, comparison_kernel)); From 47c1593649557d28acfa133a52f2fe5035fe3e9d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 2 Dec 2023 01:02:12 -0800 Subject: [PATCH 318/381] compat: Update forward compatibility horizon to 2023-12-02 PiperOrigin-RevId: 587260332 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 1ee91ed50267a4..7e45ae9aa8593e 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. 
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 1) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 2) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 5d94df816394d23c7183af3cb38eea3f16e4bd31 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 2 Dec 2023 01:02:19 -0800 Subject: [PATCH 319/381] Update GraphDef version to 1698. PiperOrigin-RevId: 587260364 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index bc01dd0998da44..7e322009d2eb9c 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1697 // Updated: 2023/12/1 +#define TF_GRAPH_DEF_VERSION 1698 // Updated: 2023/12/2 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 1fbe954cecc4ff8187dda75c4406689f2d4d909e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 2 Dec 2023 04:30:49 -0800 Subject: [PATCH 320/381] [CPUExecutable] Add helper to get the name of the main entry point function generated cpu code. PiperOrigin-RevId: 587288300 --- third_party/xla/xla/service/cpu/cpu_executable.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index dcf43b56e94231..55dcaa4f161185 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -177,6 +177,8 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } + const std::string& module_name() const { return module_name_; } + static int64_t ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. From 7c91223105dbd7706eb56fda14bdc1d7247ffe48 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 2 Dec 2023 21:16:56 -0800 Subject: [PATCH 321/381] [XLA] Use operator[] instead of at() for InlinedVector and vector. 
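The rationale behind this change (stated here as an assumption, since the commit message does not spell it out) is the usual one: at() bounds-checks every access and throws std::out_of_range on failure, while operator[] performs no check, so an out-of-range index is undefined behavior but valid indices avoid the extra branch in hot paths. A minimal standalone illustration:

#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<int> v = {1, 2, 3};

  // at() checks the index on every call and throws when it is out of range.
  try {
    std::cout << v.at(10) << "\n";
  } catch (const std::out_of_range& e) {
    std::cout << "v.at(10) threw: " << e.what() << "\n";
  }

  // operator[] does not check; an out-of-range index would be undefined
  // behavior, so the caller must already know the index is valid (as the
  // call sites rewritten in this commit do).
  std::cout << "v[2] = " << v[2] << "\n";
  return 0;
}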
PiperOrigin-RevId: 587403877 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 4 +-- third_party/xla/xla/layout.h | 22 +++++++-------- third_party/xla/xla/permutation_util.cc | 2 +- .../cpu/tracked_tfrt_cpu_device_buffer.cc | 2 +- .../service/dynamic_dimension_inference.cc | 4 +-- .../xla/service/hlo_replication_analysis.cc | 4 +-- .../xla/service/logical_buffer_analysis.cc | 2 +- .../xla/service/tuple_points_to_analysis.cc | 6 ++-- third_party/xla/xla/shape.cc | 2 +- third_party/xla/xla/shape.h | 28 +++++++++---------- 10 files changed, 37 insertions(+), 39 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 5370489b0df6e3..a35ee7485ee0f6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2398,12 +2398,12 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { } const HloInstruction* HloInstruction::operand(int64_t i) const { - return operands_.at(i); + return operands_[i]; } HloInstruction* HloInstruction::mutable_operand(int64_t i) { CHECK(operands_[i] != nullptr); - return operands_.at(i); + return operands_[i]; } int64_t HloInstruction::operand_index(const HloInstruction* target) const { diff --git a/third_party/xla/xla/layout.h b/third_party/xla/xla/layout.h index c64806c0b2644d..2a201a420dc1f6 100644 --- a/third_party/xla/xla/layout.h +++ b/third_party/xla/xla/layout.h @@ -58,7 +58,7 @@ class Tile { std::string ToString() const; // Returns the bound of the tile in the given dimension index. - int64_t dimension(int i) const { return dimensions_.at(i); } + int64_t dimension(int i) const { return dimensions_[i]; } // Returns the dimensions of the tile. absl::Span dimensions() const { return dimensions_; } @@ -203,10 +203,10 @@ class Layout { // Methods for accessing the DimLevelType array. int dim_level_types_size() const { return dim_level_types_.size(); } DimLevelType dim_level_type(int index) const { - return dim_level_types_.at(index); + return dim_level_types_[index]; } Layout& set_dim_level_type(int index, DimLevelType dim_level_type) { - dim_level_types_.at(index) = dim_level_type; + dim_level_types_[index] = dim_level_type; return *this; } Layout& add_dim_level_type(DimLevelType dim_level_type) { @@ -224,9 +224,9 @@ class Layout { // Methods for accessing the dim_unique array. int dim_unique_size() const { return dim_unique_.size(); } - bool dim_unique(int index) const { return dim_unique_.at(index); } + bool dim_unique(int index) const { return dim_unique_[index]; } Layout& set_dim_unique(int index, bool unique) { - dim_unique_.at(index) = unique; + dim_unique_[index] = unique; return *this; } Layout& add_dim_unique(bool unique) { @@ -244,9 +244,9 @@ class Layout { // Methods for accessing the dim_ordered array. int dim_ordered_size() const { return dim_ordered_.size(); } - bool dim_ordered(int index) const { return dim_ordered_.at(index); } + bool dim_ordered(int index) const { return dim_ordered_[index]; } Layout& set_dim_ordered(int index, bool ordered) { - dim_ordered_.at(index) = ordered; + dim_ordered_[index] = ordered; return *this; } Layout& add_dim_ordered(bool ordered) { @@ -264,9 +264,9 @@ class Layout { // Methods for accessing the minor-to-major array. 
int minor_to_major_size() const { return minor_to_major_.size(); } - int64_t minor_to_major(int index) const { return minor_to_major_.at(index); } + int64_t minor_to_major(int index) const { return minor_to_major_[index]; } Layout& set_minor_to_major(int index, int64_t value) { - minor_to_major_.at(index) = value; + minor_to_major_[index] = value; return *this; } Layout& add_minor_to_major(int64_t value) { @@ -286,8 +286,8 @@ class Layout { // Methods for accessing the tile field. int64_t tiles_size() const { return tiles_.size(); } - const Tile& tiles(int index) const { return tiles_.at(index); } - Tile* mutable_tiles(int index) { return &tiles_.at(index); } + const Tile& tiles(int index) const { return tiles_[index]; } + Tile* mutable_tiles(int index) { return &tiles_[index]; } Tile* add_tiles() { tiles_.push_back(Tile()); return &tiles_.back(); diff --git a/third_party/xla/xla/permutation_util.cc b/third_party/xla/xla/permutation_util.cc index e28c4bf89fbdd4..8e857fc8975c35 100644 --- a/third_party/xla/xla/permutation_util.cc +++ b/third_party/xla/xla/permutation_util.cc @@ -37,7 +37,7 @@ std::vector InversePermutation( DCHECK(IsPermutation(input_permutation)); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { - output_permutation.at(input_permutation.at(i)) = i; + output_permutation[input_permutation[i]] = i; } return output_permutation; } diff --git a/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc index 5d327e57cb4018..2b5487e6b6850d 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc @@ -133,7 +133,7 @@ void TrackedTfrtCpuDeviceBuffer::AddUsageEvents( if (usage_events_.size() >= 1024) { int i = 0; while (i < usage_events_.size()) { - auto& event = usage_events_.at(i); + auto& event = usage_events_[i]; if (event.IsAvailable()) { using std::swap; swap(event, usage_events_.back()); diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.cc b/third_party/xla/xla/service/dynamic_dimension_inference.cc index 5e76198b462f9f..6a1e36488e34fb 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference.cc @@ -1077,7 +1077,7 @@ Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( const Shape& subshape = ShapeUtil::GetSubshape(hlo->shape(), index); auto* element = dynamic_sizes.mutable_element(index); element->resize(subshape.rank(), nullptr); - element->at(dimension) = dynamic_size; + (*element)[dimension] = dynamic_size; return OkStatus(); })); dynamic_sizes.ForEachElement([&](const ShapeIndex& index, const auto& sizes) { @@ -1655,7 +1655,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow( auto* leaf_dynamic_sizes = dynamic_sizes.mutable_element(reduce_window_result_index); leaf_dynamic_sizes->resize(subshape.rank(), nullptr); - leaf_dynamic_sizes->at(dimension) = dynamic_size; + (*leaf_dynamic_sizes)[dimension] = dynamic_size; }); return OkStatus(); diff --git a/third_party/xla/xla/service/hlo_replication_analysis.cc b/third_party/xla/xla/service/hlo_replication_analysis.cc index 70b73ed0798922..d7505cecc0e303 100644 --- a/third_party/xla/xla/service/hlo_replication_analysis.cc +++ b/third_party/xla/xla/service/hlo_replication_analysis.cc @@ -421,13 +421,13 @@ Status HloReplicationAnalysis::ComputeHloReplication() { if (replication) { // If 
parameter replication status has been set explicitly, use that // instead. - if (!cross_partition_spmd_ && replication->at(leaf_index)) { + if (!cross_partition_spmd_ && (*replication)[leaf_index]) { // Setting parameter replication status for replicas in // non cross-partition spmd mode. *shape_tree.mutable_element(index) = HloReplication::ReplicatedOnAllDevices(); } - if (cross_partition_spmd_ && !replication->at(leaf_index)) { + if (cross_partition_spmd_ && !(*replication)[leaf_index]) { // Setting paramemter replication status for partitions in // cross-partition spmd mode. *shape_tree.mutable_element(index) = diff --git a/third_party/xla/xla/service/logical_buffer_analysis.cc b/third_party/xla/xla/service/logical_buffer_analysis.cc index 15afe61bea624f..db661a4485f0f8 100644 --- a/third_party/xla/xla/service/logical_buffer_analysis.cc +++ b/third_party/xla/xla/service/logical_buffer_analysis.cc @@ -78,7 +78,7 @@ Status LogicalBufferAnalysis::Analyze() { } LogicalBuffer& LogicalBufferAnalysis::GetBuffer(LogicalBuffer::Id id) const { - return *logical_buffers_.at(id); + return *logical_buffers_[id]; } LogicalBuffer& LogicalBufferAnalysis::GetBuffer(HloInstruction* instruction, diff --git a/third_party/xla/xla/service/tuple_points_to_analysis.cc b/third_party/xla/xla/service/tuple_points_to_analysis.cc index 612fb488b26054..1d6e31bfd5584d 100644 --- a/third_party/xla/xla/service/tuple_points_to_analysis.cc +++ b/third_party/xla/xla/service/tuple_points_to_analysis.cc @@ -326,7 +326,7 @@ Status TuplePointsToAnalysis::HandleAsyncStart(HloInstruction* async_start) { [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) { if (target_index.size() >= 2 && target_index.front() == 0) { const PointsToSet& operand_points_to_set = - GetPointsToSet(async_start->operand(target_index.at(1))); + GetPointsToSet(async_start->operand(target_index[1])); ShapeIndex source_index(target_index.begin() + 2, target_index.end()); *buffers = operand_points_to_set.element(source_index); for (HloInstruction* tuple : @@ -645,7 +645,7 @@ StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( const TuplePointsToAnalysis::BufferAliasVector& TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const { - return logical_buffer_aliases_.at(buffer.id()); + return logical_buffer_aliases_[buffer.id()]; } const TuplePointsToAnalysis::BufferDefinitionVector& @@ -719,7 +719,7 @@ std::string TuplePointsToAnalysis::ToString() const { absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); - for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { + for (const BufferAlias& alias : logical_buffer_aliases_[b->id()]) { absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 0ad4897f320e9f..f074915d445f3b 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -159,7 +159,7 @@ void Shape::DeleteDimension(int64_t dim_to_delete) { } const Shape& Shape::tuple_shapes(int index) const { - return tuple_shapes_.at(index); + return tuple_shapes_[index]; } Shape* Shape::add_tuple_shapes() { diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index b2828b429e586c..c641c00f27451a 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -104,18 +104,18 @@ class Shape { // Returns true if the given dimension is unbounded dynamic. 
bool is_unbounded_dynamic_dimension(int dimension) const { - return dimensions_.at(dimension) == kUnboundedSize; + return dimensions_[dimension] == kUnboundedSize; } // Sets a given dimension as unbounded dynamic. void set_unbounded_dynamic_dimension(int dimension) { dynamic_dimensions_[dimension] = true; - dimensions_.at(dimension) = kUnboundedSize; + dimensions_[dimension] = kUnboundedSize; } // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { - return dynamic_dimensions_.at(dimension); + return dynamic_dimensions_[dimension]; } // Sets whether or not the given dimension is dynamically-sized. @@ -149,18 +149,16 @@ class Shape { // Methods for accessing the dimensions array. int dimensions_size() const { return dimensions_.size(); } - int64_t dimensions(int index) const { return dimensions_.at(index); } + int64_t dimensions(int index) const { return dimensions_[index]; } int64_t dimensions_minor(int index) const { CHECK(has_layout()); - return dimensions_.at(layout_->minor_to_major(index)); - } - void set_dimensions(int index, int64_t value) { - dimensions_.at(index) = value; + return dimensions_[layout_->minor_to_major(index)]; } + void set_dimensions(int index, int64_t value) { dimensions_[index] = value; } void set_dimensions_minor(int index, int64_t value) { CHECK(has_layout()); - dimensions_.at(layout_->minor_to_major(index)) = value; + dimensions_[layout_->minor_to_major(index)] = value; } void add_dimensions(int64_t value) { dimensions_.push_back(value); @@ -179,7 +177,7 @@ class Shape { // tuple shapes. int tuple_shapes_size() const { return tuple_shapes_.size(); } const Shape& tuple_shapes(int index) const; - Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } + Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_[index]; } Shape* add_tuple_shapes(); void clear_tuple_shapes() { tuple_shapes_.clear(); } const std::vector& tuple_shapes() const { return tuple_shapes_; } @@ -372,8 +370,8 @@ class ProgramShape { // Methods for accessing and manipulating the Shape of the parameters. int parameters_size() const { return parameters_.size(); } - const Shape& parameters(int index) const { return parameters_.at(index); } - Shape* mutable_parameters(int index) { return ¶meters_.at(index); } + const Shape& parameters(int index) const { return parameters_[index]; } + Shape* mutable_parameters(int index) { return ¶meters_[index]; } Shape* add_parameters() { parameters_.emplace_back(); return ¶meters_.back(); @@ -389,13 +387,13 @@ class ProgramShape { // Methods for accessing and manipulating the names of the parameters. int parameter_names_size() const { return parameter_names_.size(); } const std::string& parameter_names(int index) const { - return parameter_names_.at(index); + return parameter_names_[index]; } void set_parameter_names(int index, const std::string& value) { - parameter_names_.at(index) = value; + parameter_names_[index] = value; } std::string* mutable_parameter_names(int index) { - return ¶meter_names_.at(index); + return ¶meter_names_[index]; } void add_parameter_names(const std::string& value) { parameter_names_.push_back(value); From 9646aa0fed780a7521e4a388bc3b4f5148551932 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Sun, 3 Dec 2023 01:01:52 -0800 Subject: [PATCH 322/381] compat: Update forward compatibility horizon to 2023-12-03 PiperOrigin-RevId: 587433042 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 7e45ae9aa8593e..a9cfdd725e1526 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 2) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 3) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From c93b77f1740066e07ee56db6ad7f41ee42637686 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 3 Dec 2023 01:01:52 -0800 Subject: [PATCH 323/381] Update GraphDef version to 1699. PiperOrigin-RevId: 587433043 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 7e322009d2eb9c..a2beb3655c7ded 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1698 // Updated: 2023/12/2 +#define TF_GRAPH_DEF_VERSION 1699 // Updated: 2023/12/3 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From be3cb548ccdf2f4b63a0e785689920a18487025b Mon Sep 17 00:00:00 2001 From: Inho Seo Date: Sun, 3 Dec 2023 17:20:01 -0800 Subject: [PATCH 324/381] Remove `tpu` tag after quantizing a TPU model. As the output model of TF quantizer is basically for CPU, the tpu tag should be removed. PiperOrigin-RevId: 587555909 --- .../stablehlo/python/pywrap_quantization.cc | 4 ++++ .../tensorflow/python/pywrap_quantize_model.cc | 13 ++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index e8765c2620e141..6622d920f8aef8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -139,6 +139,10 @@ PYBIND11_MODULE(pywrap_quantization, m) { return post_calibrated_exported_model.status(); } + // Remove the `tpu` tag from the quantized saved model as it is for CPU. + // Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. 
+ tags.erase("tpu"); py_function_library.SaveExportedModel( dst_saved_model_path, *post_calibrated_exported_model, *calibrated_saved_model_path, tags, signature_def_map); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index db97a95f0aecd3..26bf067cf34eea 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -79,12 +79,15 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { std::unordered_set tags; tags.insert(quantization_options.tags().begin(), quantization_options.tags().end()); - const absl::StatusOr exported_model = QuantizeQatModel(src_saved_model_path, signature_keys, tags, quantization_options, function_aliases); if (!exported_model.ok()) return exported_model.status(); + // Remove the `tpu` tag from the quantized saved model as it is for CPU. + // Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. + tags.erase("tpu"); py_function_library.SaveExportedModel( dst_saved_model_path, *exported_model, src_saved_model_path, tags, signature_def_map); @@ -133,6 +136,10 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { QuantizePtqDynamicRange(src_saved_model_path, signature_keys, tags, quantization_options, function_aliases); + // Remove the `tpu` tag from the quantized saved model as it is for CPU. + // Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. + tags.erase("tpu"); py_function_library.SaveExportedModel( dst_saved_model_path, *exported_model, src_saved_model_path, tags, signature_def_map); @@ -283,6 +290,10 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { if (!post_calibrated_exported_model.ok()) return post_calibrated_exported_model.status(); + // Remove the `tpu` tag from the quantized saved model as it is for CPU. + // Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. + tags.erase("tpu"); py_function_library.SaveExportedModel( dst_saved_model_path, *post_calibrated_exported_model, *calibrated_saved_model_path, tags, signature_def_map); From b043b4e7ceb8e43f88e4af0627241fd33291630f Mon Sep 17 00:00:00 2001 From: Yishuang Pang Date: Sun, 3 Dec 2023 18:12:12 -0800 Subject: [PATCH 325/381] Adds tests for TransposeOp in `xla_builder_test` and `shape_inference_test`. The tests test shape inference for unbounded dynamism. 
PiperOrigin-RevId: 587562742 --- .../xla/xla/client/xla_builder_test.cc | 16 ++++++++++++ .../xla/xla/service/shape_inference_test.cc | 26 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index 8979a6f2b8e461..f46bf076b6f9a8 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -1780,5 +1780,21 @@ TEST_F(XlaBuilderTest, UnboundedSubUnsupportedImplicitBroadcast) { HasSubstr("Unbounded dynamic shapes not supported")); } +TEST_F(XlaBuilderTest, UnboundedTranspose) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}"); + StatusOr expected = ParseShape("f32[<=2, 1, ?, 2, ?]{0,2,3,4,1}"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Transpose(Parameter(&b, 0, operand.value(), "operand"), + /*permutation=*/{4, 0, 3, 2, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanStringWithLayout(result) + << " expected: " << ShapeUtil::HumanStringWithLayout(expected.value()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 2068bb0ec3661b..8c439dda806cff 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -3779,6 +3779,32 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { } } +TEST_F(ShapeInferenceTest, UnboundedTranspose) { + auto operand = ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}"); + auto expected = ParseShape("f32[<=2, 1, ?, 2, ?]{0,2,3,4,1}"); + ASSERT_IS_OK(operand.status()); + auto inferred_status = ShapeInference::InferTransposeShape( + operand.value(), /*dimensions=*/{4, 0, 3, 2, 1}); + ASSERT_IS_OK(expected.status()); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(ShapeInferenceTest, UnboundedTransposeRank1) { + auto operand = ParseShape("f32[?]"); + auto expected = ParseShape("f32[?]"); + ASSERT_IS_OK(operand.status()); + auto inferred_status = + ShapeInference::InferTransposeShape(operand.value(), /*dimensions=*/{0}); + ASSERT_IS_OK(expected.status()); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { auto lhs = ParseShape(GetParam()[0]); auto rhs = ParseShape(GetParam()[1]); From 567dad70d8d5db78436f921f75e279fa625a8a1b Mon Sep 17 00:00:00 2001 From: Yishuang Pang Date: Sun, 3 Dec 2023 18:51:33 -0800 Subject: [PATCH 326/381] Add unbounded dynamism for SliceOp. This adds the necessary changes in shape inference following StableHLO rules for unbounded dynamism. Also added tests in `shape_inference_test.cc` and `xla_builder_test.cc`. 
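Concretely, the new shape inference rule skips the limit <= dimension_size check for unbounded operand dimensions and marks a result dimension dynamic only when the operand dimension is bounded-dynamic, so slicing f32[1, <=3, ?] with starts {0, 1, 2}, limits {1, 3, 5}, strides {1, 1, 1} infers f32[1, <=2, 3]. A self-contained sketch of the size arithmetic follows; SliceDimSize is an illustrative helper, not an XLA function.

#include <cstdint>
#include <cstdio>

// Illustrative helper (not an XLA function): number of elements a slice keeps
// in one dimension, i.e. ceil((limit - start) / stride).
int64_t SliceDimSize(int64_t start, int64_t limit, int64_t stride) {
  return (limit - start + stride - 1) / stride;
}

int main() {
  // Mirrors the new test: f32[1, <=3, ?], starts {0, 1, 2}, limits {1, 3, 5},
  // strides {1, 1, 1} -> per-dimension sizes {1, 2, 3}. Dimension 1 stays
  // bounded-dynamic (<=2); the unbounded dimension 2 becomes a static 3.
  std::printf("%lld %lld %lld\n",
              static_cast<long long>(SliceDimSize(0, 1, 1)),   // 1
              static_cast<long long>(SliceDimSize(1, 3, 1)),   // 2
              static_cast<long long>(SliceDimSize(2, 5, 1)));  // 3
  return 0;
}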
PiperOrigin-RevId: 587568591 --- third_party/xla/xla/client/xla_builder_test.cc | 17 +++++++++++++++++ third_party/xla/xla/service/shape_inference.cc | 8 +++++--- .../xla/xla/service/shape_inference_test.cc | 13 +++++++++++++ third_party/xla/xla/shape.h | 6 ++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index f46bf076b6f9a8..3ec45849a7afc6 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -1748,6 +1748,23 @@ TEST_F(XlaBuilderTest, UnboundedPowUnsupportedImplicitBroadcast) { HasSubstr("Unbounded dynamic shapes not supported")); } +TEST_F(XlaBuilderTest, UnboundedSlice) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, <=3, ?]"); + StatusOr expected = ParseShape("f32[1, <=2, 3]"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Slice(Parameter(&b, 0, operand.value(), "operand"), + /*start_indices=*/{0, 1, 2}, + /*limit_indices=*/{1, 3, 5}, + /*strides=*/{1, 1, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto result = module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + TEST_F(XlaBuilderTest, UnboundedSub) { XlaBuilder b(TestName()); StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index bf324e880c3963..ca0b2fe6247841 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -2652,11 +2652,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (start_index < 0) { return InvalidArgument("Negative start index to slice: %d.", start_index); } - if (limit_index > arg.dimensions(dimension)) { + int64_t dimension_size = arg.dimensions(dimension); + if (!arg.is_unbounded_dynamic_dimension(dimension) && + limit_index > dimension_size) { return error( StrFormat("limit index (%d) must be less than or equal to dimension " "size (%d)", - limit_index, arg.dimensions(dimension))); + limit_index, dimension_size)); } VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); @@ -2678,7 +2680,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (sizes[i] == 1) { continue; } - is_dynamic[i] = arg.is_dynamic_dimension(i); + is_dynamic[i] = arg.is_bounded_dynamic_dimension(i); } return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic); diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 8c439dda806cff..5e123c4546b513 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -3898,6 +3898,19 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { } } +TEST_F(ShapeInferenceTest, UnboundedSlice) { + StatusOr operand = ParseShape("f32[1, <=3, ?]"); + StatusOr expected = ParseShape("f32[1, <=2, 3]"); + ASSERT_IS_OK(operand.status()); + StatusOr inferred_status = ShapeInference::InferSliceShape( + operand.value(), /*starts=*/{0, 1, 2}, /*limits=*/{1, 3, 5}, + /*strides=*/{1, 1, 1}); + ASSERT_IS_OK(expected.status()); + 
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) + << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { auto lhs = ParseShape(GetParam()[0]); auto rhs = ParseShape(GetParam()[1]); diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index c641c00f27451a..ad72a30bca2886 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -113,6 +113,12 @@ class Shape { dimensions_[dimension] = kUnboundedSize; } + // Returns true if the given dimension is bounded dynamic. + bool is_bounded_dynamic_dimension(int dimension) const { + return is_dynamic_dimension(dimension) && + !is_unbounded_dynamic_dimension(dimension); + } + // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_[dimension]; From 6261f161aba7aade1844013ceb06cda8abc3dd2e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sun, 3 Dec 2023 19:54:02 -0800 Subject: [PATCH 327/381] [xla:gpu] Add se::DeviceDescription to custom fusion and fusion matcher Pass device description to custom fusion patterns and kernels to be able to do device-dependent fusions (i.e. not all CUTLASS layouts have the same support on all devices) PiperOrigin-RevId: 587577073 --- third_party/xla/xla/service/gpu/BUILD | 3 +++ .../xla/service/gpu/custom_fusion_rewriter.cc | 6 ++++-- .../xla/service/gpu/custom_fusion_rewriter.h | 5 ++++- .../gpu/custom_fusion_rewriter_test.cc | 8 ++++++-- .../xla/xla/service/gpu/gpu_compiler.cc | 3 ++- .../xla/service/gpu/ir_emitter_unnested.cc | 3 ++- third_party/xla/xla/service/gpu/kernels/BUILD | 4 ++++ .../xla/service/gpu/kernels/custom_fusion.h | 5 +++-- .../gpu/kernels/custom_fusion_pattern.cc | 5 +++-- .../gpu/kernels/custom_fusion_pattern.h | 7 +++++-- .../kernels/cutlass_gemm_custom_kernel.cu.cc | 19 ++++++++++--------- .../gpu/kernels/cutlass_gemm_custom_kernel.h | 10 +++++----- .../cutlass_gemm_custom_kernel_test.cc | 6 +++--- .../gpu/kernels/cutlass_gemm_fusion.cc | 15 ++++++++++----- .../service/gpu/kernels/cutlass_gemm_fusion.h | 10 +++++++--- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 10 +++++++--- .../gpu/kernels/cutlass_gemm_kernel.cu.h | 7 +++---- .../xla/service/gpu/runtime/kernel_launch.cc | 3 ++- 18 files changed, 83 insertions(+), 46 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 7c3c92b97114ce..4f15bf6f8f9c19 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2854,6 +2854,7 @@ cc_library( "//xla/service:hlo_pass", "//xla/service/gpu/kernels:custom_fusion_library", "//xla/service/gpu/kernels:custom_fusion_pattern", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2873,7 +2874,9 @@ xla_cc_test( deps = [ ":custom_fusion_rewriter", "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/kernels:custom_fusion_pattern", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc index 037446811d52aa..df279f901beb71 100644 --- 
a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc @@ -32,14 +32,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla::gpu { CustomFusionRewriter::CustomFusionRewriter( + const se::DeviceDescription* device, const CustomFusionPatternRegistry* patterns) - : patterns_(patterns) {} + : device_(device), patterns_(patterns) {} // Returns instructions that have to become custom fusion parameters. Returns an // error if matched pattern can't be outlined as a fusion. @@ -144,7 +146,7 @@ StatusOr CustomFusionRewriter::Run( // Collect all potential custom fusion matches in the module. for (HloComputation* computation : module->computations()) { for (HloInstruction* instr : computation->instructions()) { - auto matched = patterns_->Match(instr); + auto matched = patterns_->Match(*device_, instr); matches.insert(matches.end(), matched.begin(), matched.end()); } } diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h index 1e5e643a4a4c13..dd0c641873efb1 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/service/hlo_pass_interface.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -61,7 +62,8 @@ namespace xla::gpu { // class CustomFusionRewriter : public HloModulePass { public: - explicit CustomFusionRewriter(const CustomFusionPatternRegistry* patterns = + explicit CustomFusionRewriter(const se::DeviceDescription* device, + const CustomFusionPatternRegistry* patterns = CustomFusionPatternRegistry::Default()); absl::string_view name() const override { return "custom-fusion-rewriter"; } @@ -71,6 +73,7 @@ class CustomFusionRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: + const se::DeviceDescription* device_; const CustomFusionPatternRegistry* patterns_; }; diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc index 55268442fc48a7..84252a0fa2e3ff 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc @@ -21,7 +21,9 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -33,7 +35,8 @@ namespace xla::gpu { class SimpleGemmPattern : public CustomFusionPattern { public: - std::optional TryMatch(HloInstruction* instr) const override { + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override { if (auto* dot = DynCast(instr)) { CustomFusionConfig config; config.set_name("simple_gemm"); @@ -80,7 +83,8 @@ TEST_F(CustomFusionRewriterTest, SimpleGemm) { CustomFusionPatternRegistry patterns; patterns.Emplace(); - CustomFusionRewriter pass(&patterns); + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomFusionRewriter pass(&device, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index e6699a7c872540..b0955f0a2c238d 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1213,7 +1213,8 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // heuristic, so we can mix and match various Gemm implementations based // on projected (measured) performance. if (debug_options.xla_gpu_enable_custom_fusions()) { - pipeline.AddPass(); + pipeline.AddPass( + &gpu_target_config.device_description); } // Rewrite GEMMs into custom calls. diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index aca82442de5e9c..273137a053f18d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -3338,7 +3338,8 @@ StatusOr IrEmitterUnnested::EmitCustomFusion( // Load custom kernels that can implement a fusion computation. TF_ASSIGN_OR_RETURN( std::vector kernels, - custom_fusion->LoadKernels(fusion->fused_instructions_computation())); + custom_fusion->LoadKernels(ir_emitter_context_->gpu_device_info(), + fusion->fused_instructions_computation())); // This should never happen, it means that compilation pipeline created a // fusion operation that is not supported by a given custom fusion. 
diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 4e96b185ee0301..efef54862054d7 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -24,6 +24,7 @@ cc_library( "//xla:status", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -43,6 +44,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_description", "@com_google_absl//absl/base:core_headers", ], ) @@ -91,6 +93,7 @@ cc_library( # "//xla:xla_data_proto_cc", # "//xla/hlo/ir:hlo", # "//xla/service:pattern_matcher", +# "//xla/stream_executor:device_description", # "@local_tsl//tsl/platform:errors", # "@local_tsl//tsl/platform:logging", # "@local_tsl//tsl/platform:statusor", @@ -113,6 +116,7 @@ cc_library( # "//xla:error_spec", # "//xla:literal_util", # "//xla/service/gpu:custom_fusion_rewriter", +# "//xla/service/gpu:gpu_device_info_for_tests", # "//xla/tests:hlo_test_base", # "@local_tsl//tsl/platform:test", # "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion.h index 9311e43e630402..51425d767f5f5d 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion.h @@ -26,10 +26,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/service/gpu/kernels/custom_fusion.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/status.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/logging.h" namespace xla::gpu { @@ -98,8 +98,9 @@ class CustomFusion { public: virtual ~CustomFusion() = default; - // Loads kernels implementing `hlo_computation`. + // Loads kernels implementing `hlo_computation` optimized for a given device. virtual StatusOr> LoadKernels( + const se::DeviceDescription& device, const HloComputation* computation) const = 0; }; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc index db4d5d21409189..811b29ffb198e7 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -29,10 +30,10 @@ CustomFusionPatternRegistry* CustomFusionPatternRegistry::Default() { } std::vector CustomFusionPatternRegistry::Match( - HloInstruction* instr) const { + const se::DeviceDescription& device, HloInstruction* instr) const { std::vector matches; for (auto& pattern : patterns_) { - if (auto matched = pattern->TryMatch(instr); matched.has_value()) + if (auto matched = pattern->TryMatch(device, instr); matched.has_value()) matches.push_back(std::move(*matched)); } return matches; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h index 02123388e26004..cc90abccfae81c 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -48,7 +49,8 @@ class CustomFusionPattern { // TODO(ezhulenev): Today the last instruction defines custom fusion root // (results), however we need to add support for custom fusion that can return // intermediate result, and custom fusions that require an extra workspace. - virtual std::optional TryMatch(HloInstruction *instr) const = 0; + virtual std::optional TryMatch(const se::DeviceDescription &device, + HloInstruction *instr) const = 0; }; //===----------------------------------------------------------------------===// @@ -61,7 +63,8 @@ class CustomFusionPatternRegistry { // global static registry. 
static CustomFusionPatternRegistry *Default(); - std::vector Match(HloInstruction *instr) const; + std::vector Match( + const se::DeviceDescription &device, HloInstruction *instr) const; void Add(std::unique_ptr pattern); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc index 96f9d7a17f0b0c..d6867d8cf8f53c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cu.cc @@ -31,12 +31,14 @@ namespace xla::gpu::kernel::gemm_universal { template static StatusOr LoadCutlassGemmUniversal( std::string name, int32_t m, int32_t n, int32_t k, - const ArgsIndices& indices, const DynamicSliceIndices& slices) { + const ArgsIndices& indices, const DynamicSliceIndices& slices, + const se::DeviceDescription& device) { using Kernel = typename Gemm::GemmKernel; cutlass::gemm::GemmCoord problem_size = {m, n, k}; - auto packing = ArgsPacking(problem_size, indices, slices); + auto packing = + ArgsPacking(problem_size, indices, slices, device.core_count()); se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing)); spec.AddInProcessSymbol(GetKernelSymbol(), name); @@ -46,18 +48,17 @@ static StatusOr LoadCutlassGemmUniversal( sizeof(typename Kernel::SharedStorage)); } -StatusOr GetCutlassGemmKernel(std::string name, - PrimitiveType dtype, int32_t m, - int32_t n, int32_t k, - const ArgsIndices& indices, - const DynamicSliceIndices& slices) { +StatusOr GetCutlassGemmKernel( + std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, const DynamicSliceIndices& slices, + const se::DeviceDescription& device) { switch (dtype) { case PrimitiveType::F32: return LoadCutlassGemmUniversal( - std::move(name), m, n, k, indices, slices); + std::move(name), m, n, k, indices, slices, device); case PrimitiveType::BF16: return LoadCutlassGemmUniversal( - std::move(name), m, n, k, indices, slices); + std::move(name), m, n, k, indices, slices, device); default: return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type"); } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h index da971f93e18a8a..1a50d50b915fd5 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -21,15 +21,15 @@ limitations under the License. 
#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/cutlass_gemm.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla::gpu::kernel::gemm_universal { -StatusOr GetCutlassGemmKernel(std::string name, - PrimitiveType dtype, int32_t m, - int32_t n, int32_t k, - const ArgsIndices& indices, - const DynamicSliceIndices& slices); +StatusOr GetCutlassGemmKernel( + std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, const DynamicSliceIndices& slices, + const se::DeviceDescription& device); } // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index 02f5a80ce74a9c..efb4e066b95501 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -42,9 +42,9 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { se::Kernel gemm(executor); // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. - auto custom_kernel = - GetCutlassGemmKernel("cutlass_gemm", PrimitiveType::F32, 4, 4, 4, - /*indices=*/{0, 1, 2}, /*slices=*/{}); + auto custom_kernel = GetCutlassGemmKernel( + "cutlass_gemm", PrimitiveType::F32, 4, 4, 4, + /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); int64_t length = 4 * 4; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index bb7787466e4cd4..6fb374742ada09 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/status.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -148,7 +149,7 @@ static StatusOr MatchGemmWithDynamicUpdateSlice( //===----------------------------------------------------------------------===// std::optional CutlassGemmPattern::TryMatch( - HloInstruction* instr) const { + const se::DeviceDescription& device, HloInstruction* instr) const { auto* dot = DynCast(instr); if (!dot) return std::nullopt; @@ -162,7 +163,7 @@ std::optional CutlassGemmPattern::TryMatch( std::optional CutlassGemmWithDynamicUpdateSlicePattern::TryMatch( - HloInstruction* instr) const { + const se::DeviceDescription& device, HloInstruction* instr) const { auto* update_slice = DynCast(instr); if (!update_slice) return std::nullopt; @@ -176,7 +177,8 @@ CutlassGemmWithDynamicUpdateSlicePattern::TryMatch( } std::optional -CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { +CutlassGemmWithUpcastPattern::TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const { auto* dot = DynCast(instr); if (!dot) return std::nullopt; @@ -200,6 +202,7 @@ CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { class CutlassGemmFusion : public CustomFusion { public: StatusOr> LoadKernels( + const se::DeviceDescription& device, const HloComputation* computation) const final { auto* dot = DynCast(computation->root_instruction()); if (dot == nullptr) { @@ -229,7 +232,7 @@ class CutlassGemmFusion : public CustomFusion { TF_ASSIGN_OR_RETURN( auto kernel, kernel::gemm_universal::GetCutlassGemmKernel( - "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{})); + "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device)); return std::vector{std::move(kernel)}; } }; @@ -237,6 +240,7 @@ class CutlassGemmFusion : public CustomFusion { class CutlassGemmWithUpcastFusion : public CustomFusion { public: StatusOr> LoadKernels( + const se::DeviceDescription& device, const HloComputation* computation) const final { auto* dot = DynCast(computation->root_instruction()); if (dot == nullptr) { @@ -264,6 +268,7 @@ class CutlassGemmWithUpcastFusion : public CustomFusion { class CutlassGemmWithDynamicUpdateSliceFusion : public CustomFusion { public: StatusOr> LoadKernels( + const se::DeviceDescription& device, const HloComputation* computation) const final { auto* dus = DynCast( computation->root_instruction()); @@ -305,7 +310,7 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomFusion { TF_ASSIGN_OR_RETURN( auto kernel, kernel::gemm_universal::GetCutlassGemmKernel( "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, - k, args_indices, slices)); + k, args_indices, slices, device)); return std::vector{std::move(kernel)}; } }; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h index 42c92520a2789a..4477efecece0dc 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h @@ -20,26 +20,30 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { // Pattern matches simple row-major gemms to CUTLASS kernels. 
class CutlassGemmPattern : public CustomFusionPattern { public: - std::optional TryMatch(HloInstruction* instr) const override; + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override; }; // Pattern matches simple row-major gemms with dynamic-update-slice. class CutlassGemmWithDynamicUpdateSlicePattern : public CustomFusionPattern { public: - std::optional TryMatch(HloInstruction* instr) const override; + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override; }; // Pattern matches mixed dtype gemms when one of the operands is upcasted to an // accumulator (output) dtype, i.e. BF16 <= BF16 x S8. class CutlassGemmWithUpcastPattern : public CustomFusionPattern { public: - std::optional TryMatch(HloInstruction* instr) const override; + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 64aace4fa50ace..489917eac04d1c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/literal_util.h" #include "xla/service/gpu/custom_fusion_rewriter.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -69,7 +70,8 @@ TEST_F(CutlassFusionTest, RowMajorGemm) { CustomFusionPatternRegistry patterns; patterns.Emplace(); - CustomFusionRewriter pass(&patterns); + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomFusionRewriter pass(&device, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -108,7 +110,8 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { CustomFusionPatternRegistry patterns; patterns.Emplace(); - CustomFusionRewriter pass(&patterns); + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomFusionRewriter pass(&device, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -157,7 +160,8 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { CustomFusionPatternRegistry patterns; patterns.Emplace(); - CustomFusionRewriter pass(&patterns); + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomFusionRewriter pass(&device, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h index 78e0ed36c53b77..dd5cdeecf90d28 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h @@ -140,7 +140,8 @@ int32_t *SlicePtr(const se::KernelArgsDeviceMemoryArray *args, int64_t index) { template KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, const ArgsIndices &indices, - const DynamicSliceIndices &slices) { + const DynamicSliceIndices &slices, + int32_t device_sms) { using Accumulator = typename Gemm::ElementAccumulator; using Arguments = typename Gemm::Arguments; using Kernel = typename Gemm::GemmKernel; @@ -195,10 +196,8 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, 
ctx.kernel()->GetMaxOccupiedBlocksPerCore( ctx.threads(), args.number_of_shared_bytes())); - // TODO(ezhulenev): Get number of SMs from DeviceDescription. - // Convert CUTLASS operation arguments to a device kernel parameters. - Params params(arguments, /*device_sms=*/128, sm_occupancy); + Params params(arguments, device_sms, sm_occupancy); // Optionally set up dynamic slice parameters to allow kernel adjust buffer // pointers passed via `params`. diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc index 2702254d5890d4..8ffca574914c21 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc @@ -184,7 +184,8 @@ static StatusOr> CreateCustomKernel( // Load custom kernels that can implement a fusion computation. TF_ASSIGN_OR_RETURN(std::vector kernels, - custom_fusion->LoadKernels(computation.get())); + custom_fusion->LoadKernels( + executor->GetDeviceDescription(), computation.get())); // This should never happen, it means that compilation pipeline created a // fusion operation that is not supported by a given custom fusion. From b10022255becf6cc807406fc3a4a1b45630942d3 Mon Sep 17 00:00:00 2001 From: shawnwang18 <35983922+shawnwang18@users.noreply.github.com> Date: Sun, 3 Dec 2023 23:27:11 -0800 Subject: [PATCH 328/381] PR #7358: [XLA:GPU] Add command buffer memory Free command Imported from GitHub PR https://github.com/openxla/xla/pull/7358 This PR adds the memory Free command to command buffer. The Free command is used to free previously allocated memory regions through command buffer Allocate command. The Free command is constructed with BufferAllocation object. During runtime, there are two cases that the Free command needs to handle. case 1. The allocation that is going to be freed is allocated in the same command buffer, so the real allocation address is tracked by command buffer runtime, The record parameter (address of memory to be freed) of Free command is specified as LazyAllocationMarker, command buffer runtime will fetch real address from internal tracked allocation address map indexed by current allocation index. case2. Free the allocation that is allocated in other command buffer thunk, for this case, the real allocation address is not tracked by current command buffer, and the record parameter of Free commands needs to be specified as real address. 
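In other words, Free resolves the address to release the same way the rest of the runtime does: when the recorded base pointer is the external-allocation marker, the real address is looked up in the allocation map that a prior Allocate command populated, and the entry is erased once the Free node is recorded; otherwise the already-known device address is used directly. Below is a minimal sketch of such a tracking map; TrackedAllocations and std::map are illustrative stand-ins for the CommandBufferAllocations class in the diff that follows.

#include <cstdint>
#include <map>
#include <optional>

// Illustrative stand-in for CommandBufferAllocations: maps a buffer
// allocation index to the device address produced by an Allocate command.
class TrackedAllocations {
 public:
  // Returns false if the index is already tracked.
  bool Add(int64_t index, void* address) {
    return allocs_.emplace(index, address).second;
  }
  std::optional<void*> Get(int64_t index) const {
    auto it = allocs_.find(index);
    if (it == allocs_.end()) return std::nullopt;
    return it->second;
  }
  // Returns false if the index was not tracked.
  bool Erase(int64_t index) { return allocs_.erase(index) == 1; }

 private:
  std::map<int64_t, void*> allocs_;
};

int main() {
  TrackedAllocations allocs;
  int backing = 0;  // Placeholder for a device buffer in this sketch.
  // Case 1: the buffer was allocated by an Allocate command in the same
  // command buffer, so Free looks up the real address by allocation index,
  // records the free, and then drops the tracked entry.
  allocs.Add(/*index=*/3, /*address=*/&backing);
  if (std::optional<void*> addr = allocs.Get(3)) {
    void* to_free = *addr;  // The real FreeCmd records command_buffer->Free(to_free).
    (void)to_free;
    allocs.Erase(3);
  }
  // Case 2: the buffer was allocated outside this command buffer, so its real
  // address is already known and the free is recorded with it directly.
  return 0;
}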
Copybara import of the project: -- ba9793b0376cc7ebf82335a183a2c66460f0a4b0 by Shawn Wang : Add command buffer free command Merging this change closes #7358 PiperOrigin-RevId: 587611331 --- .../xla/xla/service/gpu/buffer_allocations.cc | 46 +++++++-- .../xla/xla/service/gpu/buffer_allocations.h | 41 +++++--- .../xla/xla/service/gpu/runtime3/BUILD | 1 + .../runtime3/command_buffer_allocations.cc | 11 +-- .../gpu/runtime3/command_buffer_allocations.h | 6 +- .../gpu/runtime3/command_buffer_cmd.cc | 72 ++++++++------ .../service/gpu/runtime3/command_buffer_cmd.h | 27 +++++- .../gpu/runtime3/command_buffer_thunk.cc | 3 +- .../gpu/runtime3/command_buffer_thunk.h | 6 ++ .../gpu/runtime3/command_buffer_thunk_test.cc | 94 ++++++++++++++++++- .../xla/xla/stream_executor/command_buffer.cc | 4 + .../xla/xla/stream_executor/command_buffer.h | 3 + .../xla/stream_executor/cuda/cuda_driver.cc | 9 ++ .../stream_executor/gpu/gpu_command_buffer.cc | 25 +++++ .../stream_executor/gpu/gpu_command_buffer.h | 2 + .../xla/xla/stream_executor/gpu/gpu_driver.h | 7 ++ .../stream_executor_internal.h | 4 + 17 files changed, 291 insertions(+), 70 deletions(-) diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.cc b/third_party/xla/xla/service/gpu/buffer_allocations.cc index 193586bc9a4807..fc55f0951436df 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/buffer_allocations.cc @@ -54,7 +54,23 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( BufferAllocation::Index buffer_index) const { CHECK_GE(buffer_index, 0); CHECK_LT(buffer_index, buffers_.size()); - return buffers_[buffer_index]; + se::DeviceMemoryBase base = buffers_[buffer_index]; + if (reinterpret_cast(base.opaque()) == kExternalAllocationMarker) { + if (!external_allocations_) { + LOG(ERROR) << "Does not have external allocations for buffer " + << buffer_index; + return se::DeviceMemoryBase(); + } + auto external_address = + external_allocations_->GetDeviceAddress(buffer_index); + if (external_address.ok()) { + return external_address.value(); + } + LOG(ERROR) << "External address for allocation" << buffer_index + << " is not allocated yet"; + return se::DeviceMemoryBase(); + } + return base; } se::DeviceMemoryBase& BufferAllocations::GetMutableDeviceAddress( @@ -74,14 +90,26 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( buffer_slice.size()); } -StatusOr BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice, - const ExternalAllocations& external_allocations) const { - // Check if base memory address is an external allocation. - se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index()); - return reinterpret_cast(base.opaque()) == kExternalAllocationMarker - ? 
external_allocations.GetDeviceAddress(buffer_slice) - : GetDeviceAddress(buffer_slice); +Status BufferAllocations::AddExternalAllocation( + BufferAllocation::Index index, se::DeviceMemoryBase memory) const { + if (external_allocations_ == nullptr) { + return InternalError( + "Calling external allocations, but no allocation tracker is provided" + "for allocation %d", + index); + } + return external_allocations_->AddAllocation(index, memory); +} + +Status BufferAllocations::EraseExternalAllocation( + BufferAllocation::Index index) const { + if (external_allocations_ == nullptr) { + return InternalError( + "Calling external allocations, but no allocation tracker is provided" + "for allocation %d", + index); + } + return external_allocations_->EraseAllocation(index); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.h b/third_party/xla/xla/service/gpu/buffer_allocations.h index 37d2eb8c2d2f54..940e111bf82b9f 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.h +++ b/third_party/xla/xla/service/gpu/buffer_allocations.h @@ -54,18 +54,29 @@ class BufferAllocations { public: virtual ~ExternalAllocations() = default; - // Return a device address for a given buffer slice. Returns error if + // Return a device address for a given buffer allocation. Returns error if // corresponding allocation is not yet allocated. virtual StatusOr GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const = 0; + BufferAllocation::Index index) const = 0; + + // Adds an external allocation for a given buffer index. Returns error if + // allocation already exists. + virtual Status AddAllocation(BufferAllocation::Index index, + se::DeviceMemoryBase memory) = 0; + + // Erases an external allocation for a given buffer index. Returns error if + // allocation does not exists. + virtual Status EraseAllocation(BufferAllocation::Index index) = 0; }; BufferAllocations(absl::Span buffers, int device_ordinal, - se::DeviceMemoryAllocator* memory_allocator) + se::DeviceMemoryAllocator* memory_allocator, + ExternalAllocations* external_allocations = nullptr) : buffers_(buffers.begin(), buffers.end()), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + external_allocations_(external_allocations) {} BufferAllocations(BufferAllocations&& other) = default; BufferAllocations& operator=(BufferAllocations&& other) = default; @@ -92,12 +103,12 @@ class BufferAllocations { se::DeviceMemoryBase GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const; - // Finds an allocation for a given buffer slice, and if it happens to be an - // external allocation resolves it using user-provided external allocations. - // Returns error if external allocations do not have an address for a slice. - StatusOr GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice, - const ExternalAllocations& external_allocations) const; + // Add new allocation allocated by external allocator. + Status AddExternalAllocation(BufferAllocation::Index index, + se::DeviceMemoryBase memory) const; + + // Remove allocation freed by external allocator. + Status EraseExternalAllocation(BufferAllocation::Index index) const; // Tears down all buffers allocated by this object that are not in // `live_addresses`. @@ -121,12 +132,16 @@ class BufferAllocations { // indexed by Index. Each element can point to a temporary buffer, an // input buffer, or nullptr if no buffer is needed for that Index. 
- // a nullptr buffer with non-zero size buffer is assumed to be lazily - // allocated buffer, and will be allocated through command buffer Allocate - // command during runtime. + // a special address (se::kExternalAllocationMarker) with non-zero size buffer + // is assumed to be lazily allocated buffer, and will be allocated through + // command buffer Allocate command during runtime. std::vector buffers_; int device_ordinal_; se::DeviceMemoryAllocator* memory_allocator_; + + // For buffer address that marked as ExternalAllocations, tracks its real + // address here. + ExternalAllocations* external_allocations_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index 76a3e04dbc68f6..014ce6606e24bb 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -157,6 +157,7 @@ xla_test( srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), backends = ["gpu"], deps = [ + ":command_buffer_allocations", ":command_buffer_cmd", ":command_buffer_thunk", "//xla:shape_util", diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc index 0aa46ffec10f29..20a5c10e2de58b 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.cc @@ -28,16 +28,13 @@ limitations under the License. namespace xla::gpu { StatusOr CommandBufferAllocations::GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const { - auto base = allocs_.find(buffer_slice.index()); + BufferAllocation::Index index) const { + auto base = allocs_.find(index); if (base == allocs_.end()) { return absl::InternalError(absl::StrCat("Command buffer allocation #", - buffer_slice.index(), - " was not allocated")); + index, " was not allocated")); } - - char* ptr = static_cast(const_cast(base->second.opaque())); - return se::DeviceMemoryBase(ptr + buffer_slice.offset(), buffer_slice.size()); + return allocs_.at(index); } Status CommandBufferAllocations::AddAllocation(BufferAllocation::Index index, diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h index 3435dc0d69434c..d0db712a0a4a40 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_allocations.h @@ -31,16 +31,16 @@ namespace xla::gpu { class CommandBufferAllocations : public BufferAllocations::ExternalAllocations { public: StatusOr GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const override; + BufferAllocation::Index index) const override; // Adds an external allocation for a given buffer index. Returns error if // allocation already exists. Status AddAllocation(BufferAllocation::Index index, - se::DeviceMemoryBase memory); + se::DeviceMemoryBase memory) override; // Erases an external allocation for a given buffer index. Returns error if // allocation does not exists. 
- Status EraseAllocation(BufferAllocation::Index index); + Status EraseAllocation(BufferAllocation::Index index) override; private: absl::flat_hash_map allocs_; diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index c9c6b505971486..4c3253b357456f 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime3/command_buffer_allocations.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/status.h" #include "xla/stream_executor/command_buffer.h" @@ -153,9 +154,7 @@ Status LaunchCmd::Record(const RecordParams& params, absl::InlinedVector buffers; for (const BufferAllocation::Slice& arg : args_) { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buf, - params.buffer_allocations->GetDeviceAddress( - arg, *params.command_buffer_allocations)); + se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); buffers.push_back(buf); } @@ -187,16 +186,14 @@ MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd(BufferAllocation::Slice dst, Status MemcpyDeviceToDeviceCmd::Record(const RecordParams& params, se::CommandBuffer* command_buffer) { - VLOG(5) << "MemcpyDeviceToDeviceCmd: dst=" << dst_ << ", src=" << src_ - << ", num_bytes=" << num_bytes_; - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase dst, - params.buffer_allocations->GetDeviceAddress( - dst_, *params.command_buffer_allocations)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase src, - params.buffer_allocations->GetDeviceAddress( - src_, *params.command_buffer_allocations)); + se::DeviceMemoryBase dst = params.buffer_allocations->GetDeviceAddress(dst_); + se::DeviceMemoryBase src = params.buffer_allocations->GetDeviceAddress(src_); + VLOG(5) << "MemcpyDeviceToDeviceCmd: dst=" << dst_ << "(" + << reinterpret_cast(dst.opaque()) << ")" + << ", src=" << src_ << "(" << reinterpret_cast(src.opaque()) + << ")" + << ", num_bytes=" << num_bytes_; return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_); } @@ -369,25 +366,45 @@ CommandBufferCmd::Slices WhileCmd::slices() { // AllocateCmd //===----------------------------------------------------------------------===// -AllocateCmd::AllocateCmd(BufferAllocation* allocation) +AllocateCmd::AllocateCmd(BufferAllocation allocation) : allocation_(allocation) {} Status AllocateCmd::Record(const RecordParams& params, se::CommandBuffer* command_buffer) { // Memory allocation address is returned on graph creation, and there is no // update operation - VLOG(5) << "AllocationCmd: index=" << allocation_->index(); + VLOG(2) << "AllocationCmd: index=" << allocation_.index(); TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, - command_buffer->Allocate(allocation_->size())); + command_buffer->Allocate(allocation_.size())); + return params.buffer_allocations->AddExternalAllocation(allocation_.index(), + buffer); +} - TF_RETURN_IF_ERROR(params.command_buffer_allocations->AddAllocation( - allocation_->index(), buffer)); +CommandBufferCmd::Slices AllocateCmd::slices() { return {}; } - return OkStatus(); +//===----------------------------------------------------------------------===// +// FreeCmd 
+//===----------------------------------------------------------------------===// + +FreeCmd::FreeCmd(BufferAllocation allocation) : allocation_(allocation) {} + +Status FreeCmd::Record(const RecordParams& params, + se::CommandBuffer* command_buffer) { + VLOG(2) << "FreeCmd: index=" << allocation_.index(); + + se::DeviceMemoryBase address = + params.buffer_allocations->GetDeviceAddress(allocation_.index()); + + // Free is in the same command buffer + TF_RETURN_IF_ERROR(command_buffer->Free(address)); + + // Remove the buffer from external allocations. + return params.buffer_allocations->EraseExternalAllocation( + allocation_.index()); } -CommandBufferCmd::Slices AllocateCmd::slices() { return {}; } +CommandBufferCmd::Slices FreeCmd::slices() { return {}; } //===----------------------------------------------------------------------===// // GemmCmd @@ -419,15 +436,14 @@ Status GemmCmd::Record(const RecordParams& params, se::DeviceMemoryBase workspace(nullptr, 0); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs, - params.buffer_allocations->GetDeviceAddress( - lhs_buffer_, *params.command_buffer_allocations)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs, - params.buffer_allocations->GetDeviceAddress( - rhs_buffer_, *params.command_buffer_allocations)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase out, - params.buffer_allocations->GetDeviceAddress( - output_buffer_, *params.command_buffer_allocations)); + se::DeviceMemoryBase lhs = + params.buffer_allocations->GetDeviceAddress(lhs_buffer_); + + se::DeviceMemoryBase rhs = + params.buffer_allocations->GetDeviceAddress(rhs_buffer_); + + se::DeviceMemoryBase out = + params.buffer_allocations->GetDeviceAddress(output_buffer_); TF_ASSIGN_OR_RETURN( auto nested_buffer, diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index 31be131338ecbe..48dcff6d27f6c5 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -62,7 +62,6 @@ class CommandBufferCmd { struct RecordParams { se::StreamExecutor* executor; const BufferAllocations* buffer_allocations; - CommandBufferAllocations* command_buffer_allocations; }; // Prepares a command for recording on a given executor. We split it into a @@ -310,16 +309,36 @@ class WhileCmd : public CommandBufferCmd { class AllocateCmd : public CommandBufferCmd { public: - explicit AllocateCmd(BufferAllocation* allocation); + AllocateCmd(BufferAllocation allocation); - // After calling this function, the allocated memory address is updated to + // After calling this function, the allocated memory is tracked in + // CommandBuffer object. Status Record(const RecordParams& params, se::CommandBuffer* command_buffer) override; Slices slices() override; private: - BufferAllocation* allocation_; + BufferAllocation allocation_; +}; + +//===----------------------------------------------------------------------===// +// FreeCmd +//===----------------------------------------------------------------------===// + +class FreeCmd : public CommandBufferCmd { + public: + FreeCmd(BufferAllocation allocation); + + // After calling this function, the allocated memory address for dst + // BufferAllocation is freed, no update is required. 
+ Status Record(const RecordParams& params, + se::CommandBuffer* command_buffer) override; + + Slices slices() override; + + private: + BufferAllocation allocation_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc index 3638935f2c0c9e..e3b4ee21a51375 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/runtime3/command_buffer_allocations.h" #include "xla/service/gpu/runtime3/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/status.h" @@ -79,7 +80,7 @@ Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { absl::MutexLock lock(&cmd_buffer->mutex); CommandBufferCmd::RecordParams record_params = { - executor, params.buffer_allocations, &cmd_buffer->allocations}; + executor, const_cast(params.buffer_allocations)}; if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, record_params)) { TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h index e5e2a3ed6abda1..156c1d4bbb8597 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk.h @@ -40,6 +40,12 @@ class CommandBufferThunk : public Thunk { Status Initialize(se::StreamExecutor*, ExecutableSource) override; Status ExecuteOnStream(const ExecuteParams& params) override; + // Return the allocation address that was lazilly allocated inside command + // buffer. This API is required when the buffers are allocated inside command + // buffer but will be consumed by non-command buffer operations. + StatusOr GetCommandBufferAllocationAddress( + const ExecuteParams& params, int64_t index); + private: // Command buffer instantiated on a `se::StreamExecutor` instance, and // auxiliary state required for efficient command buffer updates. diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 94ec906b84441d..247c0fdf8d4817 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime3/command_buffer_allocations.h" #include "xla/service/gpu/runtime3/command_buffer_cmd.h" #include "xla/service/gpu/thunk.h" #include "xla/service/service_executable_run_options.h" @@ -110,9 +112,11 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // 1. Allocates memory region "a" and "c" outside command buffer. // 2. Allocates memory region "b" inside command buffer. // 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer. -// 4. MemCopyDEviceToDevice from "b" to "c" inside command buffer. -// 5. Verify that region "c" has the same content as "a". 
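The new GetCommandBufferAllocationAddress hook is what a non-command-buffer consumer would call to resolve a lazily allocated buffer to its real device address after the thunk has run. A hypothetical caller-side sketch (the variable names, the buffer index, and the se::DeviceMemoryBase payload type are assumptions, not taken from the patch):

// `thunk` is a CommandBufferThunk that has already executed and lazily
// allocated buffer #1 inside its command buffer.
TF_ASSIGN_OR_RETURN(
    se::DeviceMemoryBase real_address,
    thunk->GetCommandBufferAllocationAddress(params, /*index=*/1));
// `real_address` can now be handed to an operation that runs outside the
// command buffer.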
-TEST(CommandBufferThunkTest, AllocateCmd) { + +// 4. MemCopyDeviceToDevice from "b" to "c" inside command buffer. +// 5. Free memory region "b" inside command buffer. +// 6. Verify that region "c" has the same content as "a". +TEST(CommandBufferThunkTest, MemallocFreeCmdSameThunk) { se::StreamExecutor* executor = CudaExecutor(); se::Stream stream(executor); @@ -132,9 +136,10 @@ TEST(CommandBufferThunkTest, AllocateCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(&alloc_b); + commands.Emplace(alloc_b); commands.Emplace(slice_b, slice_a, byte_length); commands.Emplace(slice_c, slice_b, byte_length); + commands.Emplace(alloc_b); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); @@ -142,11 +147,17 @@ TEST(CommandBufferThunkTest, AllocateCmd) { // Prepare arguments: a=42, b=0 se::DeviceMemory a = executor->AllocateArray(length, 0); stream.ThenMemset32(&a, 42, byte_length); + se::DeviceMemory b(se::DeviceMemoryBase( reinterpret_cast(BufferAllocations::kExternalAllocationMarker), byte_length)); se::DeviceMemory c = executor->AllocateArray(length, 0); - BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator()); + + std::unique_ptr external_allocation = + std::make_unique(); + + BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator(), + external_allocation.get()); ServiceExecutableRunOptions run_options; Thunk::ExecuteParams params(run_options, allocations, &stream, {}); @@ -163,6 +174,79 @@ TEST(CommandBufferThunkTest, AllocateCmd) { ASSERT_EQ(dst, std::vector(4, 42)); } +// This test does the following operations: +// 1. Allocates memory region "a" and "c" outside command buffer. +// 2. Allocates memory region "b" inside command buffer thunk 1. +// 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer 1. +// 4. MemCopyDeviceToDevice from "b" to "c" inside command buffer 2. +// 5. Free memory region "b" inside command buffer 2. +// 6. Verify that region "c" has the same content as "a". +TEST(CommandBufferThunkTest, MemallocFreeCmdAcrossThunk) { + se::StreamExecutor* executor = CudaExecutor(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + // Prepare arguments: + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); + + // =================Thunk 1================================= + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands1; + commands1.Emplace(alloc_b); + commands1.Emplace(slice_b, slice_a, byte_length); + + // Construct a thunk with command sequence. 
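Both tests assemble their command sequences through CommandBufferCmdSequence::Emplace. Spelled out with the template arguments (a sketch assuming the templated Emplace<T> helper; the arguments are the ones used above), the sequence from the same-thunk test reads:

CommandBufferCmdSequence commands;
commands.Emplace<AllocateCmd>(alloc_b);                                    // lazily allocate "b"
commands.Emplace<MemcpyDeviceToDeviceCmd>(slice_b, slice_a, byte_length);  // "a" -> "b"
commands.Emplace<MemcpyDeviceToDeviceCmd>(slice_c, slice_b, byte_length);  // "b" -> "c"
commands.Emplace<FreeCmd>(alloc_b);                                        // free "b"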
+ CommandBufferThunk thunk1(std::move(commands1), Thunk::ThunkInfo(nullptr)); + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + stream.ThenMemset32(&a, 42, byte_length); + se::DeviceMemory b(se::DeviceMemoryBase( + reinterpret_cast(BufferAllocations::kExternalAllocationMarker), + byte_length)); + se::DeviceMemory c = executor->AllocateArray(length, 0); + + std::unique_ptr external_allocation = + std::make_unique(); + + BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator(), + external_allocation.get()); + + ServiceExecutableRunOptions run_options; + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk1.ExecuteOnStream(params)); + + // =================Thunk 2================================= + CommandBufferCmdSequence commands2; + commands2.Emplace(slice_c, slice_b, byte_length); + commands2.Emplace(alloc_b); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk2(std::move(commands2), Thunk::ThunkInfo(nullptr)); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk2.ExecuteOnStream(params)); + + // Copy `c` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), allocations.GetMutableDeviceAddress(2), + byte_length); + + ASSERT_EQ(dst, std::vector(4, 42)); +} + TEST(CommandBufferThunkTest, LaunchCmd) { se::StreamExecutor* executor = CudaExecutor(); diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 551a667843d818..8980cb5b7d1faa 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -174,6 +174,10 @@ tsl::Status CommandBuffer::While(StreamExecutor* executor, std::move(body_builder)); } +tsl::Status CommandBuffer::Free(DeviceMemoryBase dst) { + return implementation_->Free(dst); +} + CommandBuffer::Mode CommandBuffer::mode() const { return implementation_->mode(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 2867664854b513..dbb34a2ddc41cd 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -181,6 +181,9 @@ class CommandBuffer { // Adds a device memory allocation command to the command buffer. tsl::StatusOr Allocate(size_t bytes); + // This API free buffer that is allocated by Allocate command + tsl::Status Free(DeviceMemoryBase dst); + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. 
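At the stream_executor level the thunk machinery above reduces to the new Allocate/Free pair on se::CommandBuffer. A minimal usage sketch, assuming `cb` is a command buffer that is still being constructed (the error macros are the ones used throughout this code base):

TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase scratch, cb.Allocate(/*bytes=*/1024));
// ... record commands that read or write `scratch` ...
TF_RETURN_IF_ERROR(cb.Free(scratch));
TF_RETURN_IF_ERROR(cb.Finalize());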
tsl::Status Finalize(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index e60e17d641ce6f..fdf5dd03ef48d1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -979,6 +979,15 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { return std::pair{params.dptr, params.bytesize}; } +/*static*/ tsl::Status GpuDriver::GraphAddMemFreeNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, + CUdeviceptr gpu_dst) { + RETURN_IF_CUDA_RES_ERROR( + cuGraphAddMemFreeNode(node, graph, deps.data(), deps.size(), gpu_dst), + "Failed to add memory free node to a CUDA graph"); + return ::tsl::OkStatus(); +} + /* static */ tsl::Status GpuDriver::GraphAddMemcpyD2DNode( GpuContext* context, CUgraphNode* node, CUgraph graph, absl::Span deps, CUdeviceptr gpu_dst, CUdeviceptr gpu_src, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 643ae4879df369..d1cc3d11438511 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -305,6 +305,7 @@ tsl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, tsl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { TF_RETURN_IF_ERROR(CheckNotFinalized()); + // Adds a new memory allocation node to the graph under construction. if (state_ == State::kCreate) { Dependencies deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); @@ -338,6 +339,30 @@ tsl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { return UnsupportedStateError(state_); } +tsl::Status GpuCommandBuffer::Free(DeviceMemoryBase dst) { + TF_RETURN_IF_ERROR(CheckNotFinalized()); + + // Adds a new memfree node to the graph under construction. + if (state_ == State::kCreate) { + Dependencies deps = GetDependencies(); + GpuGraphNodeHandle* node = &nodes_.emplace_back(); + GpuDevicePtr gpu_dptr = AsDevicePtr(dst); + TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemFreeNode( + node, graph_, absl::MakeSpan(deps), gpu_dptr)); + return tsl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // memfree node implemented through CUDA graph only free buffers that is + // allocated through memory alloc node, so buffer address will not change, + // no update is required. 
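GpuDriver::GraphAddMemFreeNode is a thin wrapper over the CUDA driver call it names. For orientation, a bare driver-API sketch of the alloc/free node pairing (CUresult checking omitted; the size and device ordinal are illustrative):

CUgraph graph;
cuGraphCreate(&graph, /*flags=*/0);

CUDA_MEM_ALLOC_NODE_PARAMS alloc_params = {};
alloc_params.bytesize = 1024;
alloc_params.poolProps.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
alloc_params.poolProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
alloc_params.poolProps.location.id = /*device_ordinal=*/0;

CUgraphNode alloc_node;
cuGraphAddMemAllocNode(&alloc_node, graph, /*deps=*/nullptr, 0, &alloc_params);
CUdeviceptr ptr = alloc_params.dptr;  // the address is fixed at node-creation time

// The free node must depend on the alloc node (and on every node using `ptr`).
CUgraphNode free_node;
cuGraphAddMemFreeNode(&free_node, graph, &alloc_node, /*numDependencies=*/1, ptr);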
+ update_state_.node_idx++; + return tsl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 5223e9d24e7cb5..8938e04c047785 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -65,6 +65,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { tsl::StatusOr Allocate(size_t bytes) override; + tsl::Status Free(DeviceMemoryBase dst) override; + tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, CommandBuffer::Builder then_builder) override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 2dff75b93e1cef..d89edf4797fe3f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -531,6 +531,13 @@ class GpuDriver { static tsl::StatusOr> GraphGetMemAllocNodeParams(GpuGraphNodeHandle node); + // Create a memfree node and adds it to a graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1geb7cdce5d9be2d28d9428e74eb00fa53 + static tsl::Status GraphAddMemFreeNode(GpuGraphNodeHandle* node, + GpuGraphHandle graph, + absl::Span deps, + GpuDevicePtr gpu_dst); + // Creates a memcpy node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g674da6ab54a677f13e0e0e8206ff5073 static tsl::Status GraphAddMemcpyD2DNode(GpuContext* context, diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 9931e5b0984afd..c7bd1736a8314b 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -204,6 +204,10 @@ class CommandBufferInterface { CommandBuffer::Builder cond_builder, CommandBuffer::Builder body_builder) = 0; + // Adds a device memory free command to the command buffer, buffer is + // allocated in other command buffer, free through real address. + virtual tsl::Status Free(DeviceMemoryBase dst) = 0; + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. virtual tsl::Status Finalize() = 0; From ac947e98596f2c1f204bc774fa7261bb4509a6ec Mon Sep 17 00:00:00 2001 From: Xin Zhou Date: Sun, 3 Dec 2023 23:33:46 -0800 Subject: [PATCH 329/381] [XLA:TPU] Fix missing element_size_in_bits coping in Layout Assignment pass. 
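element_size_in_bits matters for sub-byte element types (for example packed 4-bit integers), where it carries information that minor_to_major and tiles alone do not. A minimal sketch of the property being propagated (the surrounding layout-assignment plumbing is omitted):

#include "xla/layout.h"
#include "xla/layout_util.h"

xla::Layout result = xla::LayoutUtil::MakeLayout({1, 0});
result.set_element_size_in_bits(4);  // e.g. packed 4-bit elements

xla::Layout assigned = xla::LayoutUtil::MakeLayout({1, 0});
// The pass already copied the tiles; without also copying the element size,
// `assigned` silently keeps the default full-byte element size. The fix adds:
assigned.set_element_size_in_bits(result.element_size_in_bits());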
PiperOrigin-RevId: 587612159 --- third_party/xla/xla/service/layout_assignment.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 3e5e64a51823e5..a79185682d7591 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -2229,6 +2229,8 @@ Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { result_shape.layout().tiles().begin(), result_shape.layout().tiles().end()); } + subshape->mutable_layout()->set_element_size_in_bits( + result_shape.layout().element_size_in_bits()); } }; xla::ShapeUtil::ForEachMutableSubshape( From ae5b2c4d0e1eb26846e50100ef3967beb54204dd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 01:01:52 -0800 Subject: [PATCH 330/381] compat: Update forward compatibility horizon to 2023-12-04 PiperOrigin-RevId: 587627246 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index a9cfdd725e1526..4df5a7d140a7f6 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 3) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 4) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 1e44b26e3ea31df8dc26da0ea1237281eff0dcfb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 01:01:52 -0800 Subject: [PATCH 331/381] Update GraphDef version to 1700. PiperOrigin-RevId: 587627247 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index a2beb3655c7ded..107f4804f95df9 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1699 // Updated: 2023/12/3 +#define TF_GRAPH_DEF_VERSION 1700 // Updated: 2023/12/4 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 1ed9db08b51f5830cce01444c20177e65d002fd5 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 4 Dec 2023 01:49:21 -0800 Subject: [PATCH 332/381] [XLA] Implement `MHLO` <--> `StableHLO` roundtrip for `TopK` via a custom call. 
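The encoding is the one exercised by the tests below: the op becomes a `stablehlo.custom_call` whose `call_target_name` is "mhlo.topk", with the original attributes preserved under `mhlo.attributes` and an explicit `mhlo.version` (initially 1) so that future changes to the op can still be round-tripped.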
PiperOrigin-RevId: 587636918 --- .../hlo_legalize_to_stablehlo.cc | 10 ++++++++-- .../mhlo/hlo-legalize-to-stablehlo.mlir | 19 ++++++++++--------- .../mhlo/stablehlo-legalize-to-hlo.mlir | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 4774aac93d4b56..835ea7d2900e59 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -52,7 +52,7 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { // Please let us know if we missed something, and we'll recategorize them. if (isa(hloOp.getOperation())) { return true; } @@ -140,6 +140,12 @@ std::optional getPublicFeaturesNotInStablehlo(HloOpTy hloOp) { mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) return 1; } + // StableHLO doesn't support TopK yet. + // Proposal: https://github.com/openxla/stablehlo/pull/1593 + if constexpr (std::is_same::value) { + // Version 1: Initial version for TopK. + return 1; + } return std::nullopt; } @@ -566,7 +572,7 @@ void populateHloToStablehloPatterns(RewritePatternSet* patterns, #include "stablehlo/dialect/StablehloOps.cpp.inc" >(patterns, converter, context, allowExperimentalFeatures); - populateHloToStablehloCustomCallPatterns( + populateHloToStablehloCustomCallPatterns( patterns, converter, context, allowExperimentalFeatures); } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 0d06994182e37c..fa860a960abe29 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1543,7 +1543,16 @@ func.func @op_tanh(%arg0: tensor) -> tensor { func.return %0 : tensor } -// TopKOp aka mhlo.topk is unsupported at the moment (see negative test below). 
+// CHECK-LABEL: "op_topk" +func.func @op_topk(%arg0: tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) { + // CHECK: "stablehlo.custom_call"(%arg0) { + // CHECK-SAME: call_target_name = "mhlo.topk" + // CHECK-SAME{LITERAL}: mhlo.attributes = {k = 8 : i64, largest = true} + // CHECK-SAME{LITERAL}: mhlo.version = 1 : i64 + // CHECK-SAME: } : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<5x10xf32> -> (tensor<5x8xf32>, tensor<5x8xi32>) + func.return %0#0, %0#1 : tensor<5x8xf32>, tensor<5x8xi32> +} // CHECK-LABEL: "op_torch_index_select" func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { @@ -2057,14 +2066,6 @@ func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> ten // ----- -func.func @op_topk(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{failed to legalize operation 'mhlo.topk' that was explicitly marked illegal}} - %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - func.func @op_xla_rng_get_and_update_state() -> tensor<2xui64> { // expected-error@+1 {{failed to legalize operation 'mhlo.xla.rng_get_and_update_state' that was explicitly marked illegal}} %0 = "mhlo.xla.rng_get_and_update_state"() { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index e9ce4e3ce68b2e..ce40b9d6e19d13 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1942,3 +1942,17 @@ func.func @op_custom_call_botched_mhlo_backend_config_version(%arg0: tensor } : (tensor) -> tensor return %0 : tensor } + +// ----- + +// CHECK-LABEL: "op_topk_mhlo_v1" +func.func @op_topk_mhlo_v1(%arg0: tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) { + // CHECK: "mhlo.topk"(%arg0) {k = 8 : i64, largest = true} : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) + %0:2 = "stablehlo.custom_call"(%arg0) { + backend_config = "", + call_target_name = "mhlo.topk", + mhlo.attributes = {k = 8 : i64, largest = true}, + mhlo.version = 1 : i64 + } : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) + func.return %0#0, %0#1 : tensor<5x8xf32>, tensor<5x8xi32> +} From 2fc145e8416520224497ea96bfacd56bf9602f1e Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 4 Dec 2023 02:00:38 -0800 Subject: [PATCH 333/381] Allow memcpy fusions with multiple copies. We horizontally fuse copies, but then revert back to the loop emitter later. That doesn't seem necessary though. 
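Concretely, the kind of fusion this keeps on the memcpy path is one whose roots are all copies of fusion parameters with matching layouts. An illustrative module, written as a C++ raw string the way the tests in this change spell out HLO (the module itself is an example, not taken from the patch):

const char* kMultiCopyFusion = R"(
HloModule m

fused_copies {
  p0 = f32[128,64] parameter(0)
  p1 = f32[128,64] parameter(1)
  copy0 = f32[128,64] copy(p0)
  copy1 = f32[128,64] copy(p1)
  ROOT t = (f32[128,64], f32[128,64]) tuple(copy0, copy1)
}

ENTRY main {
  a = f32[128,64] parameter(0)
  b = f32[128,64] parameter(1)
  ROOT fusion = (f32[128,64], f32[128,64]) fusion(a, b), kind=kLoop, calls=fused_copies
})";

Previously only a single-root copy fusion was lowered this way; with this change each root copy above becomes its own device-to-device memcpy thunk.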
PiperOrigin-RevId: 587639218 --- .../xla/xla/service/gpu/fusions/copy.cc | 24 +++++---- .../xla/xla/service/gpu/fusions/copy.h | 13 +++-- .../xla/xla/service/gpu/fusions/fusions.cc | 51 ++++++++++++++----- .../xla/xla/tests/multioutput_fusion_test.cc | 30 ----------- 4 files changed, 60 insertions(+), 58 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/copy.cc b/third_party/xla/xla/service/gpu/fusions/copy.cc index a04885f9f0a255..4f9a99405768e9 100644 --- a/third_party/xla/xla/service/gpu/fusions/copy.cc +++ b/third_party/xla/xla/service/gpu/fusions/copy.cc @@ -25,17 +25,21 @@ StatusOr MemcpyFusion::Emit( IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, KernelReuseCache& kernel_cache, llvm::IRBuilder<>*) const { - auto src_buffer = *GetAllocationSlice(src_, ir_emitter_context.allocations()); - auto dst_buffer = *GetAllocationSlice(dst_, ir_emitter_context.allocations()); FusionEmissionResult result; - if (src_buffer != dst_buffer) { - result.thunks.emplace_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(fusion_op), - /*source_buffer=*/src_buffer, - /*destination_buffer=*/dst_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(src_)), - /*source_value=*/src_, - /*destination_value=*/dst_)); + for (auto [src, dst] : llvm::zip(srcs_, dsts_)) { + auto src_buffer = + *GetAllocationSlice(src, ir_emitter_context.allocations()); + auto dst_buffer = + *GetAllocationSlice(dst, ir_emitter_context.allocations()); + if (src_buffer != dst_buffer) { + result.thunks.emplace_back(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(fusion_op), + /*source_buffer=*/src_buffer, + /*destination_buffer=*/dst_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(src)), + /*source_value=*/src, + /*destination_value=*/dst)); + } } return result; } diff --git a/third_party/xla/xla/service/gpu/fusions/copy.h b/third_party/xla/xla/service/gpu/fusions/copy.h index 173517d73384c7..a0dcc166ca6451 100644 --- a/third_party/xla/xla/service/gpu/fusions/copy.h +++ b/third_party/xla/xla/service/gpu/fusions/copy.h @@ -15,17 +15,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_COPY_H_ #define XLA_SERVICE_GPU_FUSIONS_COPY_H_ +#include + #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" namespace xla { namespace gpu { -// Special case of a fusion consisting only of a kCopy instruction that can be -// implemented using a memcpy. +// Special case of a fusion consisting only of `kCopy` instructions that can be +// implemented using `memcpy`s. class MemcpyFusion : public FusionInterface { public: - MemcpyFusion(mlir::Value src, mlir::Value dst) : src_(src), dst_(dst) {} + MemcpyFusion(std::vector srcs, std::vector dsts) + : srcs_(std::move(srcs)), dsts_(std::move(dsts)) {} StatusOr Emit(IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, @@ -35,8 +38,8 @@ class MemcpyFusion : public FusionInterface { llvm::IRBuilder<>*) const final; private: - mlir::Value src_; - mlir::Value dst_; + std::vector srcs_; + std::vector dsts_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 4de451b27ea142..8d370cecbdbc2e 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "xla/service/gpu/fusions/fusions.h" +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. namespace xla { namespace gpu { namespace { +namespace { bool IsParameterOrGteOfParameter(const HloInstruction* instr) { if (instr->opcode() == HloOpcode::kParameter) { @@ -66,6 +68,40 @@ bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) { } // namespace +std::optional> GetCopyFusion( + HloFusionAnalysis& analysis, + absl::Span allocations, + mlir::lmhlo::FusionOp fusion_op) { + if (!fusion_op) { + return std::nullopt; + } + + auto params = GetHloOperands(fusion_op); + auto outputs = GetHloOutputs(fusion_op); + std::vector srcs; + srcs.reserve(outputs.size()); + + for (auto* root : analysis.fusion_roots()) { + if (root->opcode() != HloOpcode::kCopy || + root->operand(0)->opcode() != HloOpcode::kParameter || + !LayoutUtil::Equal(root->operand(0)->shape().layout(), + root->shape().layout())) { + return std::nullopt; + } + + mlir::Value src = params[root->operand(0)->parameter_number()]; + if (!GetAllocationSlice(src, allocations).ok()) return std::nullopt; + + srcs.emplace_back(src); + } + + return std::make_unique( + std::move(srcs), + std::vector(outputs.begin(), outputs.end())); +} + +} // namespace + std::optional> GetFusionEmitter( HloFusionAnalysis& analysis, absl::Span allocations, @@ -84,19 +120,8 @@ std::optional> GetFusionEmitter( return std::make_unique(analysis); } } - if (is_single && analysis.fusion_roots().size() == 1 && - analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) { - if (!fusion_op) { - return std::nullopt; - } - mlir::Value operand = GetHloOperands(fusion_op).front(); - mlir::Value output = GetHloOutputs(fusion_op).front(); - Shape operand_shape = GetShape(operand); - Shape output_shape = GetShape(output); - if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && - GetAllocationSlice(operand, allocations).ok()) { - return std::make_unique(operand, output); - } + if (auto copy_fusion = GetCopyFusion(analysis, allocations, fusion_op)) { + return copy_fusion; } return std::make_unique(analysis); } diff --git a/third_party/xla/xla/tests/multioutput_fusion_test.cc b/third_party/xla/xla/tests/multioutput_fusion_test.cc index 0b14df8dcbfb72..58221ea57aa960 100644 --- a/third_party/xla/xla/tests/multioutput_fusion_test.cc +++ b/third_party/xla/xla/tests/multioutput_fusion_test.cc @@ -190,36 +190,6 @@ XLA_TEST_F(MultiOutputFusionTest, DifferentTypesNoFusion) { } XLA_TEST_F(MultiOutputFusionTest, DifferentTypesFusion) { RunTest1D(true, 8); } -XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { - const char* testcase = R"( - HloModule m, is_scheduled=true - - fused_computation { - x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0) - gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0 - gte.2 = (s32[]) get-tuple-element(gte.3), index=0 - gte.4 = s32[] get-tuple-element(gte.2), index=0 - copy = s32[] copy(gte.4) - ROOT tuple = (s32[]) tuple(copy) - } - - ENTRY thing.v3 { - x = (((s32[]), f32[]), (f32[], s32[])) parameter(0) - ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation - } - )"; - auto module = ParseAndReturnVerifiedModule(testcase).value(); - auto param = LiteralUtil::MakeTupleOwned( - LiteralUtil::MakeTupleOwned( - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), - LiteralUtil::CreateR0(1.0)), - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), - 
LiteralUtil::CreateR0(4))); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); -} - XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const char* testcase = R"( HloModule m, is_scheduled=true From 7f62bc67f9bf6aafbc80002423dfcd38d7a6f597 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 4 Dec 2023 02:46:31 -0800 Subject: [PATCH 334/381] Fix constant cast in GPU runtime (MemsetImpl) For example, if the constant is "1.0 : f16", the binary value should be 0x3C00, not 0x0001 (what static_cast yields). PiperOrigin-RevId: 587649352 --- third_party/xla/xla/service/gpu/runtime/BUILD | 2 +- .../xla/xla/service/gpu/runtime/memset.cc | 5 +++-- third_party/xla/xla/tests/reduce_test.cc | 20 +++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index edd51c1e1600c8..a296a45bce2d2f 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -722,7 +722,7 @@ cc_library( "//xla/runtime:custom_call_registry", "//xla/runtime:executable", "//xla/service:executable", - "//xla/service/gpu:io_feed_manager", + "@com_google_absl//absl/base", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/memset.cc b/third_party/xla/xla/service/gpu/runtime/memset.cc index 0e31161604a413..76314e07fc9a3d 100644 --- a/third_party/xla/xla/service/gpu/runtime/memset.cc +++ b/third_party/xla/xla/service/gpu/runtime/memset.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/runtime/memset.h" +#include "absl/base/casts.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/executable.h" #include "xla/service/gpu/runtime/support.h" @@ -95,9 +96,9 @@ static absl::StatusOr ToBitPattern(CustomCall::VariantArg constant) { else if (auto i64 = constant.get(); succeeded(i64)) return truncate(*i64); else if (auto bf16 = constant.get(); succeeded(bf16)) - return extend(static_cast(*bf16)); + return extend(absl::bit_cast(*bf16)); else if (auto f16 = constant.get(); succeeded(f16)) - return extend(static_cast(*f16)); + return extend(absl::bit_cast(*f16)); else if (auto f32 = constant.get(); succeeded(f32)) return truncate(*f32); else if (auto f64 = constant.get(); succeeded(f64)) diff --git a/third_party/xla/xla/tests/reduce_test.cc b/third_party/xla/xla/tests/reduce_test.cc index 843f276d82868c..9885284eb196d3 100644 --- a/third_party/xla/xla/tests/reduce_test.cc +++ b/third_party/xla/xla/tests/reduce_test.cc @@ -1044,6 +1044,26 @@ XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); } +XLA_TEST_F(ReduceHloTest, ReduceAtomicF16) { + absl::string_view hlo_string = R"( +HloModule jit_reduce_axes12 + +region_0.3 { + Arg_0.4 = f16[] parameter(0) + Arg_1.5 = f16[] parameter(1) + ROOT minimum.6 = f16[] minimum(Arg_0.4, Arg_1.5) +} + +ENTRY main.8 { + constant.1 = f16[] constant(1) + Arg_0.1 = f16[2,16385,1]{2,1,0} broadcast(constant.1), dimensions={} + constant.2 = f16[] constant(inf) + ROOT reduce.7 = f16[2]{0} reduce(Arg_0.1, constant.2), dimensions={1,2}, to_apply=region_0.3 +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + class VariadicReduceTest : public HloTestBase {}; XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) { From 5c41c45c964cfaec8f5d583158092448c432990f Mon Sep 17 00:00:00 
2001 From: Alexander Belyaev Date: Mon, 4 Dec 2023 04:38:20 -0800 Subject: [PATCH 335/381] [TileAnalysis] Use flat_hash_map to store indexing maps for operands. PiperOrigin-RevId: 587673987 --- third_party/xla/xla/service/gpu/model/BUILD | 1 + .../xla/service/gpu/model/tile_analysis.cc | 172 +++++----- .../xla/xla/service/gpu/model/tile_analysis.h | 29 +- .../service/gpu/model/tile_analysis_test.cc | 305 ++++++++---------- 4 files changed, 229 insertions(+), 278 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 497d601a7a0db1..88b4257236b856 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -232,6 +232,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 7fe45637bf2cfa..203de7c8618b5c 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -15,10 +15,10 @@ limitations under the License. #include "xla/service/gpu/model/tile_analysis.h" -#include #include #include #include +#include #include #include #include @@ -39,6 +39,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -74,13 +75,12 @@ StatusOr ComputeCwiseOpIndexing( dims.size(), mlir_context), .input_dims_sizes = {}}; - std::vector operand_indexing_maps; + HloInstructionIndexing instr_indexing; int64_t operand_count = instr->operand_count(); - operand_indexing_maps.reserve(operand_count); for (int64_t operand_id = 0; operand_id < operand_count; ++operand_id) { - operand_indexing_maps.push_back({{identity_map}, operand_id}); + instr_indexing.operand_indexing_maps[operand_id].insert(identity_map); } - return HloInstructionIndexing{std::move(operand_indexing_maps)}; + return instr_indexing; } StatusOr ComputeBroadcastOpIndexing( @@ -95,9 +95,7 @@ StatusOr ComputeBroadcastOpIndexing( .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), .input_dims_sizes = {}}; - - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } // Composes affine maps, i.e. consumer_map ∘ producer_map. @@ -165,28 +163,23 @@ IndexingMap ComposeIndexingMaps(const IndexingMap& producer_map, // 2. `consumer_operand_indexing` is the consumer's HloOperandIndexing for the // operand that corresponds to the provided producer. HloInstructionIndexing ComputeFusedProducerConsumerIndexing( - const HloInstructionIndexing& producer_instr_indexing, - const HloOperandIndexing& consumer_operand_indexing) { + const HloInstructionIndexing& producer_indexing, + const absl::flat_hash_set& operand_indexing_maps) { HloInstructionIndexing fused_instr_indexing; // Every operand can be read 1 or more times by the consumer which also can // have 1 or more read accesses to its operands. 
So, to get the composed // indexing maps we have to compute a "cross product" here. - for (const HloOperandIndexing& producer_operand_indexing : - producer_instr_indexing.operand_indexing_maps) { + for (const auto& [producer_operand_id, producer_operand_indexing] : + producer_indexing.operand_indexing_maps) { auto& composed_operand_indexing = - fused_instr_indexing.operand_indexing_maps.emplace_back(); - composed_operand_indexing.operand_id = producer_operand_indexing.operand_id; - for (const IndexingMap& producer_map : - producer_operand_indexing.indexing_maps) { - for (const IndexingMap& consumer_map : - consumer_operand_indexing.indexing_maps) { - composed_operand_indexing.indexing_maps.insert( + fused_instr_indexing.operand_indexing_maps[producer_operand_id]; + for (const IndexingMap& producer_map : producer_operand_indexing) { + for (const IndexingMap& consumer_map : operand_indexing_maps) { + composed_operand_indexing.insert( ComposeIndexingMaps(producer_map, consumer_map)); } } - fused_instr_indexing.operand_indexing_maps.push_back( - std::move(composed_operand_indexing)); } return fused_instr_indexing; } @@ -209,31 +202,27 @@ StatusOr ComputeFusionOpIndexing( parameter_indexing_maps; while (!bfs.empty()) { const auto& [instr, instr_indexing] = bfs.front(); - for (const auto& operand_indexing : instr_indexing.operand_indexing_maps) { - const HloInstruction* producer_instr = - instr->operand(operand_indexing.operand_id); + for (const auto& [operand_id, operand_indexing_maps] : + instr_indexing.operand_indexing_maps) { + const HloInstruction* producer_instr = instr->operand(operand_id); if (producer_instr->IsConstant()) continue; // If the producer is a fusion op parameter, store the result. if (auto parameter = DynCast(producer_instr)) { parameter_indexing_maps[parameter->parameter_number()].insert( - operand_indexing.indexing_maps.begin(), - operand_indexing.indexing_maps.end()); + operand_indexing_maps.begin(), operand_indexing_maps.end()); continue; } TF_ASSIGN_OR_RETURN(auto producer_instr_indexing, ComputeInstructionIndexing( producer_instr, /*output_id=*/0, mlir_context)); - bfs.push(std::make_pair(producer_instr, - ComputeFusedProducerConsumerIndexing( - producer_instr_indexing, operand_indexing))); + bfs.push(std::make_pair( + producer_instr, ComputeFusedProducerConsumerIndexing( + producer_instr_indexing, operand_indexing_maps))); } bfs.pop(); } - HloInstructionIndexing fusion_indexing; - for (const auto& [operand_id, maps] : parameter_indexing_maps) { - fusion_indexing.operand_indexing_maps.push_back({maps, operand_id}); - } - return fusion_indexing; + return HloInstructionIndexing{.operand_indexing_maps = + std::move(parameter_indexing_maps)}; } StatusOr ComputeDotOpIndexing( @@ -311,12 +300,8 @@ StatusOr ComputeDotOpIndexing( .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), rhs_exprs, mlir_context), .input_dims_sizes = input_dim_sizes}; - - return HloInstructionIndexing{ - {HloOperandIndexing{.indexing_maps = {std::move(lhs_indexing_map)}, - .operand_id = 0}, - HloOperandIndexing{.indexing_maps = {std::move(rhs_indexing_map)}, - .operand_id = 1}}}; + return HloInstructionIndexing::FromIndexingMaps( + {lhs_indexing_map, rhs_indexing_map}); } StatusOr ComputeReduceOpIndexing( @@ -348,14 +333,10 @@ StatusOr ComputeReduceOpIndexing( exprs, mlir_context), .input_dims_sizes = std::move(input_dims_sizes)}; - std::vector operand_indexing_maps; - int64_t input_count = reduce->input_count(); - operand_indexing_maps.reserve(input_count); - for (int64_t 
input_id = 0; input_id < input_count; ++input_id) { - operand_indexing_maps.push_back({{indexing_map}, input_id}); - } - return HloInstructionIndexing{std::move(operand_indexing_maps)}; + return HloInstructionIndexing::FromIndexingMaps( + std::vector(reduce->input_count(), indexing_map)); } + // Computes strides for a shape. std::vector ComputeStrides(absl::Span dims) { int rank = static_cast(dims.size()); @@ -515,8 +496,8 @@ StatusOr ComputeReshapeOpIndexing( auto output_dims = reshape->shape().dimensions(); IndexingMap reshape_indexing_map = ComputeReshapeIndexingMap(input_dims, output_dims, mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(reshape_indexing_map)}, .operand_id = 0}}}; + + return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } StatusOr ComputeReverseOpIndexing( @@ -544,8 +525,7 @@ StatusOr ComputeReverseOpIndexing( mlir_context), .input_dims_sizes = {}}; - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } StatusOr ComputeSliceOpIndexing( @@ -568,8 +548,7 @@ StatusOr ComputeSliceOpIndexing( .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), .input_dims_sizes = {}}; - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } IndexingMap ComputeTransposeIndexingMap(absl::Span permutation, @@ -584,10 +563,8 @@ IndexingMap ComputeTransposeIndexingMap(absl::Span permutation, StatusOr ComputeTransposeOpIndexing( const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { - IndexingMap transpose_indexing_map = - ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(transpose_indexing_map)}, .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps( + {ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context)}); } StatusOr ComputeBitcastOpIndexing( @@ -603,18 +580,14 @@ StatusOr ComputeBitcastOpIndexing( input_shape, output_shape); CHECK(permutation.has_value()) << "Failed to deduce permutation for a bitcast."; - IndexingMap transpose_indexing_map = - ComputeTransposeIndexingMap(*permutation, mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(transpose_indexing_map)}, - .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps( + {ComputeTransposeIndexingMap(*permutation, mlir_context)}); } if (std::holds_alternative( decomposed_bitcast)) { IndexingMap reshape_indexing_map = ComputeReshapeIndexingMap( input_shape.dimensions(), output_shape.dimensions(), mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(reshape_indexing_map)}, .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } // `trt` stands for transpose-reshape-transpose decomposition of bitcast. 
auto trt = std::get(decomposed_bitcast); @@ -627,8 +600,7 @@ StatusOr ComputeBitcastOpIndexing( ComputeTransposeIndexingMap(trt.transpose2_dims, mlir_context); IndexingMap composed_map = ComposeIndexingMaps( ComposeIndexingMaps(transpose_map_1, reshape_map), transpose_map_2); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(composed_map)}, .operand_id = 0}}}; + return HloInstructionIndexing::FromIndexingMaps({composed_map}); } template @@ -779,7 +751,7 @@ struct IndexingMapSimplifier { extracted = getAffineBinaryOpExpr( AffineExprKind::Add, extracted, getAffineBinaryOpExpr(AffineExprKind::Mul, - cast(expr).getLHS(), + mlir::cast(expr).getLHS(), getAffineConstantExpr(factor, mlir_context))); // Remove from dividend. return false; @@ -888,31 +860,28 @@ bool IndexingMap::Simplify(absl::Span dimension_sizes) { return true; } -bool HloOperandIndexing::Simplify(absl::Span dimension_sizes) { - std::vector to_remove; - std::vector to_add; - for (auto map : indexing_maps) { - to_remove.push_back(map); - if (map.Simplify(dimension_sizes)) { - to_add.push_back(map); - } else { - to_remove.pop_back(); - } - } - for (auto& map : to_remove) { - indexing_maps.erase(map); - } - for (auto& map : to_add) { - indexing_maps.insert(map); - } - return !to_remove.empty(); -} - bool HloInstructionIndexing::Simplify( absl::Span dimension_sizes) { bool any_simplified = false; for (auto& operand_indexing : operand_indexing_maps) { - any_simplified |= operand_indexing.Simplify(dimension_sizes); + std::vector to_remove; + std::vector to_add; + absl::flat_hash_set& indexing_maps = operand_indexing.second; + for (IndexingMap map : indexing_maps) { + to_remove.push_back(map); + if (map.Simplify(dimension_sizes)) { + to_add.push_back(map); + } else { + to_remove.pop_back(); + } + } + for (auto& map : to_remove) { + indexing_maps.erase(map); + } + for (auto& map : to_add) { + indexing_maps.insert(map); + } + any_simplified |= !to_remove.empty(); } return any_simplified; } @@ -935,26 +904,29 @@ std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { return out; } -std::ostream& operator<<(std::ostream& out, - const HloOperandIndexing& operand_indexing) { - out << "operand id = " << operand_indexing.operand_id << ' '; - for (const auto& map : operand_indexing.indexing_maps) { - out << map; - } - return out; -} - std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing) { - for (const auto& operand_map : instr_indexing.operand_indexing_maps) { - out << operand_map; + for (const auto& [operand_id, indexing_maps] : + instr_indexing.operand_indexing_maps) { + out << "operand id = " << operand_id << ' '; + for (const auto& indexing_map : indexing_maps) { + out << indexing_map; + } } return out; } std::string IndexingMap::ToString() const { return ToStringImpl(*this); } -std::string HloOperandIndexing::ToString() const { return ToStringImpl(*this); } +HloInstructionIndexing HloInstructionIndexing::FromIndexingMaps( + absl::Span indexing_maps) { + HloInstructionIndexing instr_indexing; + instr_indexing.operand_indexing_maps.reserve(indexing_maps.size()); + for (const auto& [index, map] : llvm::enumerate(indexing_maps)) { + instr_indexing.operand_indexing_maps[index].insert(map); + } + return instr_indexing; +} std::string HloInstructionIndexing::ToString() const { return ToStringImpl(*this); diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.h b/third_party/xla/xla/service/gpu/model/tile_analysis.h index 
a8bed459116973..734d8e168f0007 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.h @@ -22,7 +22,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" #include "llvm/ADT/Hashing.h" #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -65,10 +67,12 @@ namespace gpu { // could not be expressed via dimensions of the output. struct IndexingMap { std::string ToString() const; + // Returns true if the map was simplified. bool Simplify(absl::Span dimension_sizes); mlir::AffineMap affine_map; + // Upper iteration bounds for dimensions only present in the input. std::vector input_dims_sizes; }; std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); @@ -80,21 +84,6 @@ H AbslHashValue(H h, const IndexingMap& indexing_map) { return H::combine(std::move(h), static_cast(affine_map_hash)); } -// Contains 1 or more indexing maps for the `operand_id`. There are cases, when -// the same input operand is read multiple times in various ways. Especially, it -// happens a lot in fusion ops. -struct HloOperandIndexing { - std::string ToString() const; - - // Returns true if the indexing was simplified. - bool Simplify(absl::Span dimension_sizes); - - absl::flat_hash_set indexing_maps; - int64_t operand_id; -}; -std::ostream& operator<<(std::ostream& out, - const HloOperandIndexing& operand_indexing); - // Contains indexing maps for all N-dimensional tensor input operands that // correspond to a particular output. struct HloInstructionIndexing { @@ -103,7 +92,15 @@ struct HloInstructionIndexing { // Returns true if the indexing was simplified. bool Simplify(absl::Span dimension_sizes); - std::vector operand_indexing_maps; + // Creates a HloInstructionIndexing from a list of indexing maps for all + // operands and sorted w.r.t. operand index, i.e. indexing_maps[i] corresponds + // to operand[i] of the instruction. + static HloInstructionIndexing FromIndexingMaps( + absl::Span indexing_maps); + + // Maps input operand index to the indexing map for one particular output. 
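To make the `input_dims_sizes` convention concrete, one worked example in the notation used by the tests (a hand-written sketch, not tool output): reducing `f32[10, 20]` over dimension 1 into `f32[10]` gives the single operand the indexing map

  (d0)[s0] -> (d0, s0)   with input_dims_sizes = {20}

i.e. the reduced input dimension cannot be recovered from the output index, so it appears as the symbol `s0` together with its explicit upper bound.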
+ absl::flat_hash_map> + operand_indexing_maps; }; std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing); diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 518c0c9e8f3e2d..377453a3d8e4aa 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -37,6 +37,7 @@ using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::ExplainMatchResult; using ::testing::HasSubstr; +using ::testing::Pair; using ::testing::PrintToString; using ::testing::UnorderedElementsAre; @@ -50,7 +51,7 @@ MATCHER_P2(MatchIndexingMap, affine_map_string, input_dims_sizes, arg.input_dims_sizes, result_listener); } -MATCHER_P2(MatchOperandIndexing, operand_id, indexing_map_matchers, "") { +MATCHER_P2(MatchInstrIndexing, operand_id, indexing_map_matchers, "") { return ExplainMatchResult(Eq(operand_id), arg.operand_id, result_listener) && ExplainMatchResult(indexing_map_matchers, arg.indexing_maps, result_listener); @@ -85,14 +86,12 @@ TEST_F(TileAnalysisTest, ElementwiseOp) { ROOT add0 = f32[10, 20] add(p0, p1) } )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))))); + EXPECT_THAT(input_indexing.operand_indexing_maps, + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))), + Pair(1, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, BitcastIsReshape) { @@ -106,7 +105,7 @@ TEST_F(TileAnalysisTest, BitcastIsReshape) { )")); EXPECT_THAT( input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( + UnorderedElementsAre(Pair( 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0, d1 * 4 + d2)", std::vector{}))))); } @@ -120,11 +119,11 @@ TEST_F(TileAnalysisTest, BitcastIsTranspose) { ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", - std::vector{}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, BitcastIsTransposeReshapeTranspose) { @@ -137,10 +136,10 @@ TEST_F(TileAnalysisTest, BitcastIsTransposeReshapeTranspose) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1) -> (d1, d0 floordiv 3, d0 mod 3)", - std::vector{}))))); + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1) -> (d1, d0 floordiv 3, d0 mod 3)", + std::vector{}))))); } TEST_F(TileAnalysisTest, BroadcastOp) { @@ -153,9 +152,9 @@ TEST_F(TileAnalysisTest, BroadcastOp) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", - std::vector{}))))); + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { @@ -175,11 +174,10 @@ TEST_F(TileAnalysisTest, 
FusionOpWithSingleBinaryOp) { )")); EXPECT_THAT( input_indexing.operand_indexing_maps, - UnorderedElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))), - MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))))); + UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))), + Pair(1, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithDot) { @@ -247,29 +245,24 @@ TEST_F(TileAnalysisTest, FusionOpWithDot) { EXPECT_THAT( input_indexing.operand_indexing_maps, UnorderedElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0] -> " - "(d2 + d3, d0 * 768 + s0, d4, d5)", - std::vector{768}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0)", - std::vector{768}))), - MatchOperandIndexing( - 2, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5) -> (d1)", std::vector{}))), - MatchOperandIndexing( - 3, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0)", - std::vector{768}))), - MatchOperandIndexing( - 4, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0)", - std::vector{768}))), - MatchOperandIndexing( - 5, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5) -> (d2 + d3, d4, d5)", - std::vector{}))))); + Pair(0, + ElementsAre(MatchIndexingMap("(d0, d1, d2, d3, d4, d5)[s0] -> " + "(d2 + d3, d0 * 768 + s0, d4, d5)", + std::vector{768}))), + Pair(1, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0)", + std::vector{768}))), + Pair(2, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5) -> (d1)", std::vector{}))), + Pair(3, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0)", + std::vector{768}))), + Pair(4, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0)", + std::vector{768}))), + Pair(5, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5) -> (d2 + d3, d4, d5)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { @@ -288,7 +281,7 @@ TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { )")); EXPECT_THAT( input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( + UnorderedElementsAre(Pair( 0, UnorderedElementsAre( MatchIndexingMap("(d0, d1) -> (d1, d0)", std::vector{}), @@ -320,17 +313,16 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { EXPECT_THAT( input_indexing.operand_indexing_maps, UnorderedElementsAre( - MatchOperandIndexing( - 0, UnorderedElementsAre( - MatchIndexingMap("(d0) -> (d0)", std::vector{}), - MatchIndexingMap("(d0) -> (d0 + 1)", std::vector{}), - MatchIndexingMap("(d0) -> (d0 + 2)", std::vector{}))), - MatchOperandIndexing( - 1, - UnorderedElementsAre( - MatchIndexingMap("(d0) -> (d0)", std::vector{}), - MatchIndexingMap("(d0) -> (d0 + 1)", std::vector{}), - MatchIndexingMap("(d0) -> (d0 + 2)", std::vector{}))))); + Pair(0, + UnorderedElementsAre( + MatchIndexingMap("(d0) -> (d0)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 1)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 2)", std::vector{}))), + Pair(1, + UnorderedElementsAre( + MatchIndexingMap("(d0) -> (d0)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 1)", std::vector{}), + MatchIndexingMap("(d0) -> (d0 + 2)", std::vector{}))))); } TEST_F(TileAnalysisTest, 
FusionOpWithReduceOfReduce) { @@ -356,11 +348,11 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { ROOT fusion = f32[10] fusion(p0, p0_init), kind=kLoop, calls=f } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0)[s0, s1, s2] -> (s0, s2, d0, s1)", - std::vector{150, 50, 20}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0)[s0, s1, s2] -> (s0, s2, d0, s1)", + std::vector{150, 50, 20}))))); } TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { @@ -387,7 +379,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( + UnorderedElementsAre(Pair( 0, ElementsAre(MatchIndexingMap("(d0, d1)[s0] -> (d0, s0)", std::vector{20}))))); } @@ -418,11 +410,10 @@ TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { ROOT fusion = f32[10, 50, 20] fusion(p0), kind=kLoop, calls=f } )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d2, d0, d1)", - std::vector{}))))); + EXPECT_THAT(input_indexing.operand_indexing_maps, + UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d2, d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { @@ -449,10 +440,10 @@ TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)", - std::vector{16, 128}))))); + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap( + "(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)", + std::vector{16, 128}))))); } TEST_F(TileAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { @@ -470,9 +461,8 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0) -> (d0)", - std::vector{}))))); + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0) -> (d0)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { @@ -490,10 +480,10 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { } )")); EXPECT_TRUE(input_indexing.Simplify({8, 16})); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { @@ -511,11 +501,10 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { } )")); EXPECT_TRUE(input_indexing.Simplify({10, 10, 10})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0, d1, d2)", - std::vector{}))))); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0, d1, d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { @@ -536,10 +525,10 @@ TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { )")); EXPECT_THAT( 
input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65)", - std::vector{}))))); + ElementsAre( + Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { @@ -552,11 +541,10 @@ TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { } )")); EXPECT_FALSE(input_indexing.Simplify({32})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0) -> (d0 floordiv 8, d0 mod 8)", - std::vector{}))))); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0) -> (d0 floordiv 8, d0 mod 8)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { @@ -570,7 +558,7 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { )")); EXPECT_FALSE(input_indexing.Simplify({4, 8})); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( + ElementsAre(Pair( 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0 * 8 + d1)", std::vector{}))))); } @@ -587,10 +575,10 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandAndCollapseShape) { EXPECT_FALSE(input_indexing.Simplify({32, 3, 4})); EXPECT_THAT( input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)", - std::vector{}))))); + ElementsAre( + Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { @@ -603,11 +591,10 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { } )")); EXPECT_FALSE(input_indexing.Simplify({4, 4, 8})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0 * 4 + d1, d2)", - std::vector{}))))); + EXPECT_THAT(input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 * 4 + d1, d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { @@ -622,7 +609,7 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { EXPECT_TRUE(input_indexing.Simplify({2, 4, 4})); // TODO(b/313840171): Simplify `(d1 * 4 + d2) floordiv 8` to `d1 floordiv 2`. EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( + ElementsAre(Pair( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, " "(d1 * 4 + d2) mod 8)", @@ -642,12 +629,12 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { // TODO(b/313840171): Simplify `(d0 * 8 + d1) floordiv 16` to `d0 floordiv 2`. // TODO(b/313840171): Simplify `((d0 * 8 + d1) mod 16) floordiv 4` to // `((d0 * 8 + d1) floordiv 4) mod 4` to `(d0 * 2 + d1 floordiv 4) mod 4`. 
- EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1) -> ((d0 * 8 + d1) floordiv 16, " - "((d0 * 8 + d1) mod 16) floordiv 4, d1 mod 4)", - std::vector{}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1) -> ((d0 * 8 + d1) floordiv 16, " + "((d0 * 8 + d1) mod 16) floordiv 4, d1 mod 4)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReduceOp) { @@ -667,10 +654,9 @@ TEST_F(TileAnalysisTest, ReduceOp) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1)[s0, s1] -> (d0, s0, d1, s1)", - std::vector{20, 50}))))); + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1)[s0, s1] -> (d0, s0, d1, s1)", + std::vector{20, 50}))))); } TEST_F(TileAnalysisTest, VariadicReduceOp) { @@ -703,23 +689,21 @@ TEST_F(TileAnalysisTest, VariadicReduceOp) { ASSERT_IS_OK(input_indexing_0); EXPECT_THAT( input_indexing_0->operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{256}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap( - "(d0)[s0] -> (s0, d0)", std::vector{256}))))); + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", + std::vector{256}))), + Pair(1, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", + std::vector{256}))))); auto input_indexing_1 = ComputeInstructionIndexing(root, 1, &mlir_context_); ASSERT_IS_OK(input_indexing_1); EXPECT_THAT( input_indexing_1->operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{256}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap( - "(d0)[s0] -> (s0, d0)", std::vector{256}))))); + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", + std::vector{256}))), + Pair(1, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", + std::vector{256}))))); } TEST_F(TileAnalysisTest, ReverseOp) { @@ -732,11 +716,11 @@ TEST_F(TileAnalysisTest, ReverseOp) { } )")); EXPECT_FALSE(input_indexing.Simplify({1, 17, 9, 9})); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)", - std::vector{}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReverseReshape) { @@ -756,10 +740,10 @@ TEST_F(TileAnalysisTest, ReverseReshape) { } )")); EXPECT_TRUE(input_indexing.Simplify({10, 11})); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, SliceOp) { @@ -772,11 +756,11 @@ TEST_F(TileAnalysisTest, SliceOp) { slice={[5:10:1], [3:20:7], [0:50:2]} } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)", - std::vector{}))))); + EXPECT_THAT( + input_indexing.operand_indexing_maps, + ElementsAre(Pair(0, 
ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, TransposeOp) { @@ -790,10 +774,9 @@ TEST_F(TileAnalysisTest, TransposeOp) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", - std::vector{}))))); + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, TransposeOp4D) { @@ -806,10 +789,9 @@ TEST_F(TileAnalysisTest, TransposeOp4D) { } )")); EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", - std::vector{}))))); + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", + std::vector{}))))); } TEST_F(TileAnalysisTest, DotOp) { @@ -826,15 +808,14 @@ TEST_F(TileAnalysisTest, DotOp) { )")); EXPECT_THAT( input_indexing.operand_indexing_maps, - ElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " - "(d2, d1, s1, d3, s0, d0)", - std::vector{18, 17}))), - MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " - "(s1, d0, d4, s0, d5, d1)", - std::vector{18, 17}))))); + UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " + "(d2, d1, s1, d3, s0, d0)", + std::vector{18, 17}))), + Pair(1, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " + "(s1, d0, d4, s0, d5, d1)", + std::vector{18, 17}))))); } TEST_F(TileAnalysisTest, UnsupportedOps) { From f85b28136457c42c6841cf4036d58da93cea29b4 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 4 Dec 2023 06:43:16 -0800 Subject: [PATCH 336/381] Priority fusion: fuse bitcasts first. Bitcasts can get in the way of properly analyzing producer/consumer fusions, so just fuse them into their consumers first. PiperOrigin-RevId: 587701582 --- third_party/xla/xla/service/gpu/priority_fusion.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 55205823e2754d..61677ef0b08aed 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -154,6 +154,14 @@ class GpuPriorityFusionQueue : public FusionQueue { continue; } current_consumers_ = current_producer_->users(); + + if (current_producer_->opcode() == HloOpcode::kBitcast) { + // We don't check if bitcasts can be fused with all consumers, so we + // have to do it here. + llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) { + return !CanFuseCached(current_producer_, consumer); + }); + } } auto next_consumer = current_consumers_.back(); @@ -286,6 +294,11 @@ class GpuPriorityFusionQueue : public FusionQueue { // Returns the priority of the producer based on its current operands and // users. Priority CalculateProducerPriority(HloInstruction* producer) { + // Bitcasts should always be fused first, since they are no-ops. + if (producer->opcode() == HloOpcode::kBitcast) { + return std::numeric_limits::max(); + } + // Don't fuse if we can't fuse in all users. 
if (auto fusion_decision = CanFuseWithAllUsers(producer); !fusion_decision) { From 12a14593a7eeeb3f8752a9e0182b83340358fd25 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 4 Dec 2023 06:44:34 -0800 Subject: [PATCH 337/381] Remove `jit_is_disabled` API since it's not used anywhere and is subsumed by jax.config.disable_jit PiperOrigin-RevId: 587701903 --- third_party/xla/xla/python/jax_jit.cc | 1 - third_party/xla/xla/python/xla_extension/jax_jit.pyi | 1 - 2 files changed, 2 deletions(-) diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index cd9407fce89ccf..fa5f666a0e3b20 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -344,7 +344,6 @@ void BuildJaxjitSubmodule(py::module& m) { }, py::return_value_policy::reference); - jitlib.def("jit_is_disabled", &GetDisableJit); jitlib.def("get_enable_x64", &GetEnableX64); jitlib.def("set_thread_local_state_initialization_callback", [](py::object f) { initialize_local_state = f; }); diff --git a/third_party/xla/xla/python/xla_extension/jax_jit.pyi b/third_party/xla/xla/python/xla_extension/jax_jit.pyi index 9bf5e30d6c8907..3c647d5461b72e 100644 --- a/third_party/xla/xla/python/xla_extension/jax_jit.pyi +++ b/third_party/xla/xla/python/xla_extension/jax_jit.pyi @@ -34,7 +34,6 @@ class JitState: def global_state() -> JitState: ... def thread_local_state() -> JitState: ... -def jit_is_disabled() -> bool: ... def get_enable_x64() -> bool: ... def set_thread_local_state_initialization_callback( function: Callable[[], None]): ... From 7d4e8ab7e086478a759e8eea7a73f60117c9f2d0 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Mon, 4 Dec 2023 06:45:10 -0800 Subject: [PATCH 338/381] PR #7387: [ROCM] build brake fix 231129 Imported from GitHub PR https://github.com/openxla/xla/pull/7387 This is a build brake fix caused by these new changes to gpu_driver: https://github.com/openxla/xla/commit/fede5db53eac81a77e33a70094cf889a8a0810dd https://github.com/openxla/xla/commit/c58bc738fe27187eae50c8aa7243b8bb19e5d26f Besides, we have also renamed hipblaslt_plugin to workaround hipblas-lt runtime loading problems for rocm-6.0 @xla-rotation: could you please have a look ? thanks ! 
Copybara import of the project: -- 8c3660a625e253dbb19affc135c52ce69cab3e71 by Pavel Emeliyanenko : yet unfinished fix -- 2a09cdedfbb31b64388d1ff4fbdae413be403a1c by Pavel Emeliyanenko : rocm driver fixes -- 04046dc9ed1cce33a0eae65be313836225070264 by Pavel Emeliyanenko : added hip wrapper functions and renamed hipblaslt plugin -- f6b540ad9990826b0f7051b5227aa91e3490e933 by Pavel Emeliyanenko : fixing whitespace Merging this change closes #7387 PiperOrigin-RevId: 587702108 --- third_party/xla/xla/service/gpu/BUILD | 2 +- .../xla/xla/stream_executor/rocm/BUILD | 4 +- .../xla/stream_executor/rocm/rocm_driver.cc | 247 ++++++++++++++++-- .../rocm/rocm_driver_wrapper.h | 8 +- 4 files changed, 237 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 4f15bf6f8f9c19..4b0f3c02186c61 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1762,7 +1762,7 @@ cc_library( "//xla/stream_executor:host_or_device_scalar", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:hipblas_lt_header", - "//xla/stream_executor/rocm:hipblaslt_plugin", + "//xla/stream_executor/rocm:amdhipblaslt_plugin", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/platform:dso_loader", ]) + if_static([ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 03fbbe60445eb3..7ed801f678d93c 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -444,7 +444,7 @@ cc_library( ) cc_library( - name = "hipblaslt_plugin", + name = "amdhipblaslt_plugin", srcs = if_rocm_is_configured(["hip_blas_lt.cc"]), hdrs = if_rocm_is_configured([ "hip_blas_lt.h", @@ -562,7 +562,7 @@ cc_library( ":rocm_driver", ":rocm_platform", ":rocm_helpers", - ":hipblaslt_plugin", + ":amdhipblaslt_plugin", ]), alwayslink = 1, ) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index bf20011441504d..76af619f6328e6 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -767,30 +767,13 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, "Failed to set shared memory size"); } - RETURN_IF_ROCM_ERROR(hipGraphExecKernelNodeSetParams(exec, node, ¶ms), - "Failed to set HIP graph kernel node params"); + RETURN_IF_ROCM_ERROR( + wrap::hipGraphExecKernelNodeSetParams(exec, node, ¶ms), + "Failed to set HIP graph kernel node params"); return ::tsl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphAddMemcpyD2DNode( - GpuContext* context, hipGraphNode_t* node, hipGraph_t graph, - absl::Span deps, hipDeviceptr_t gpu_dst, - hipDeviceptr_t gpu_src, uint64_t size) { - VLOG(2) << "Add memcpy d2d node to a graph " << graph - << "; dst: " << reinterpret_cast(gpu_dst) - << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size - << "; context: " << context->context() << "; deps: " << deps.size(); - - RETURN_IF_ROCM_ERROR( - wrap::hipGraphAddMemcpyNode1D( - node, graph, deps.data(), deps.size(), - reinterpret_cast(gpu_dst), reinterpret_cast(gpu_src), - static_cast(size), hipMemcpyDeviceToDevice), - "Failed to add memcpy d2d node to a HIP graph"); - return tsl::OkStatus(); -} - /* static */ tsl::Status GpuDriver::GraphAddChildNode( hipGraphNode_t* node, hipGraph_t graph, absl::Span deps, hipGraph_t child) { @@ -816,6 +799,230 @@ 
GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return tsl::OkStatus(); } +static hipMemAccessFlags ToHipMemAccessFlags( + GpuDriver::MemAccessFlags access_flags) { + switch (access_flags) { + case GpuDriver::MemAccessFlags::kNone: + return hipMemAccessFlagsProtNone; + case GpuDriver::MemAccessFlags::kRead: + return hipMemAccessFlagsProtRead; + case GpuDriver::MemAccessFlags::kReadWrite: + return hipMemAccessFlagsProtReadWrite; + } +} + +static hipMemLocationType ToHipLocationType( + GpuDriver::MemLocationType location_type) { + switch (location_type) { + case GpuDriver::MemLocationType::kInvalid: + return hipMemLocationTypeInvalid; + case GpuDriver::MemLocationType::kDevice: + return hipMemLocationTypeDevice; + case GpuDriver::MemLocationType::kHost: + case GpuDriver::MemLocationType::kHostNuma: + case GpuDriver::MemLocationType::kHostNumaCurrent: + return hipMemLocationTypeInvalid; + } +} + +static hipMemAllocationType ToHipAllocationType( + GpuDriver::MemAllocationType allocation_type) { + switch (allocation_type) { + case GpuDriver::MemAllocationType::kInvalid: + return hipMemAllocationTypeInvalid; + case GpuDriver::MemAllocationType::kPinned: + return hipMemAllocationTypePinned; + } +} + +/*static*/ tsl::Status GpuDriver::GraphAddMemAllocNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, MemAccessFlags access_flags, + MemLocationType location_type, int device_id, + MemAllocationType allocation_type, uint64_t size, GpuDevicePtr* d_ptr, + uint64_t max_pool_size) { + hipMemLocation mem_loc = { + .type = ToHipLocationType(location_type), + .id = device_id, + }; + + hipMemPoolProps props{}; + props.allocType = ToHipAllocationType(allocation_type); + props.handleTypes = hipMemHandleTypeNone; + props.location = mem_loc; + + hipMemAccessDesc mem_desc = { + .location = mem_loc, + .flags = ToHipMemAccessFlags(access_flags), + }; + + hipMemAllocNodeParams params{ + .poolProps = props, + .accessDescs = &mem_desc, + .accessDescCount = 1, + .bytesize = size, + .dptr = nullptr, + }; + + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemAllocNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add memory allocation node to a CUDA graph"); + + VLOG(2) << "Add MemAllocNode to a graph " << graph << " size " << size + << " address " << reinterpret_cast(params.dptr); + + *d_ptr = params.dptr; + return ::tsl::OkStatus(); +} + +/*static*/ tsl::StatusOr> +GpuDriver::GraphGetMemAllocNodeParams(GpuGraphNodeHandle node) { + hipMemAllocNodeParams params; + RETURN_IF_ROCM_ERROR(wrap::hipGraphMemAllocNodeGetParams(node, ¶ms), + "Failed to get memory allocation node parameter"); + return std::pair{params.dptr, params.bytesize}; +} + +/* static */ tsl::Status GpuDriver::GraphAddMemcpyD2DNode( + GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst, + GpuDevicePtr gpu_src, uint64_t size) { + VLOG(2) << "Add memcpy d2d node to a graph " << graph + << "; dst: " << reinterpret_cast(gpu_dst) + << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size + << "; context: " << context->context() << "; deps: " << deps.size(); + + hipMemcpy3DParms params{ + .srcArray = {}, + .srcPos = {}, + .srcPtr = {.ptr = gpu_src}, + .dstArray = {}, + .dstPos = {}, + .dstPtr = {.ptr = gpu_dst}, + .extent = hipExtent{.width = size, .height = 1, .depth = 1}, + .kind = hipMemcpyDeviceToDevice}; + + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add memcpy d2d node to a HIP graph"); 
+ + return ::tsl::OkStatus(); +} + +/* static */ tsl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { + VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " + << exec << "; dst: " << reinterpret_cast(gpu_dst) + << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size + << "; context: " << context->context(); + + hipMemcpy3DParms params{ + .srcArray = {}, + .srcPos = {}, + .srcPtr = {.ptr = gpu_src}, + .dstArray = {}, + .dstPos = {}, + .dstPtr = {.ptr = gpu_dst}, + .extent = hipExtent{.width = size, .height = 1, .depth = 1}, + .kind = hipMemcpyDeviceToDevice}; + + RETURN_IF_ROCM_ERROR( + wrap::hipGraphExecMemcpyNodeSetParams(exec, node, ¶ms), + "Failed to set memcpy d2d node params"); + + return ::tsl::OkStatus(); +} + +namespace { + +struct BitPatternToString { + std::string operator()(uint8_t pattern) { + return absl::StrCat("u8:", pattern); + } + std::string operator()(uint16_t pattern) { + return absl::StrCat("u16:", pattern); + } + std::string operator()(uint32_t pattern) { + return absl::StrCat("u32:", pattern); + } +}; + +// Broadcasts a pattern value of 1/2/4 bytes to a 4 byte value. +struct BitPatternToValue { + std::pair operator()(uint8_t pattern) { + unsigned value = pattern; + return {(value << 24) | (value << 16) | (value << 8) | value, + /*element_size=*/1}; + } + std::pair operator()(uint16_t pattern) { + unsigned value = pattern; + return {(value << 16) | value, /*element_size=*/2}; + } + std::pair operator()(uint32_t pattern) { + return {pattern, /*element_size=*/4}; + } +}; + +} // namespace + +/* static */ tsl::Status GpuDriver::GraphAddMemsetNode( + GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr dst, + std::variant bit_pattern, + uint64_t num_elements) { + VLOG(2) << "Add memset node to a graph " << graph + << "; dst: " << reinterpret_cast(dst) + << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) + << "; num_elements: " << num_elements + << "; context: " << context->context() << "; deps: " << deps.size(); + + auto [value, element_size] = std::visit(BitPatternToValue(), bit_pattern); + + hipMemsetParams params{ + .dst = dst, + .elementSize = element_size, + .height = 1, + .pitch = 0, // unused if height is 1 + .value = value, + .width = num_elements, + }; + + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemsetNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add memset node to a CUDA graph"); + + return ::tsl::OkStatus(); +} + +/* static */ tsl::Status GpuDriver::GraphExecMemsetNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr dst, std::variant bit_pattern, + uint64_t num_elements) { + VLOG(2) << "Set memset node params " << node << " in graph executable " + << exec << "; dst: " << reinterpret_cast(dst) + << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) + << "; num_elements: " << num_elements + << "; context: " << context->context(); + + auto [value, element_size] = std::visit(BitPatternToValue(), bit_pattern); + + hipMemsetParams params{ + .dst = dst, + .elementSize = element_size, + .height = 1, + .pitch = 0, // unused if height is 1 + .value = value, + .width = num_elements, + }; + + RETURN_IF_ROCM_ERROR( + wrap::hipGraphExecMemsetNodeSetParams(exec, node, ¶ms), + "Failed to set memset node params"); + + return ::tsl::OkStatus(); +} + /* static */ 
tsl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h index 36bcfdd33e873e..f3356ffd83ddc2 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -104,15 +104,21 @@ namespace wrap { __macro(hipGetErrorString) \ __macro(hipGraphAddKernelNode) \ __macro(hipGraphAddChildGraphNode) \ + __macro(hipGraphAddMemAllocNode) \ __macro(hipGraphAddMemcpyNode) \ __macro(hipGraphAddMemcpyNode1D) \ - __macro(hipGraphExecChildGraphNodeSetParams) \ + __macro(hipGraphAddMemsetNode) \ __macro(hipGraphCreate) \ __macro(hipGraphDebugDotPrint) \ __macro(hipGraphDestroy) \ + __macro(hipGraphExecChildGraphNodeSetParams) \ __macro(hipGraphExecDestroy) \ + __macro(hipGraphExecKernelNodeSetParams) \ + __macro(hipGraphExecMemcpyNodeSetParams) \ + __macro(hipGraphExecMemsetNodeSetParams) \ __macro(hipGraphExecUpdate) \ __macro(hipGraphInstantiate) \ + __macro(hipGraphMemAllocNodeGetParams) \ __macro(hipGraphLaunch) \ __macro(hipGraphNodeGetType) \ __macro(hipHostFree) \ From 2d3b643affc66f09ea0a328a98d8ccedb00d2778 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 4 Dec 2023 07:06:36 -0800 Subject: [PATCH 339/381] [XLA] Fix `Executable::hlo_proto()` to allow a `nullptr` HLO Proto. PiperOrigin-RevId: 587707587 --- third_party/xla/xla/service/executable.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/executable.h b/third_party/xla/xla/service/executable.h index dcda9507fc0e3b..949b78b0d6f971 100644 --- a/third_party/xla/xla/service/executable.h +++ b/third_party/xla/xla/service/executable.h @@ -381,7 +381,7 @@ class Executable { } HloProto const* hlo_proto() const { - if (!hlo_proto_->has_hlo_module()) { + if (hlo_proto_ != nullptr && !hlo_proto_->has_hlo_module()) { *hlo_proto_->mutable_hlo_module() = module().ToProto(); } return hlo_proto_.get(); From 4ed1ceec91fa504ae4f8a56377788f79eb2bfe0e Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 4 Dec 2023 07:31:23 -0800 Subject: [PATCH 340/381] [XLA:GPU] Emit fusions with a single DUS instruction in-place. FusionWrapper was added a few month ago. Before that we wouldn't have a fusion with a single DUS inside. 
PiperOrigin-RevId: 587714250 --- third_party/xla/xla/service/gpu/fusions/fusions.cc | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 8d370cecbdbc2e..c8b04163b99c0f 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -51,12 +51,6 @@ bool IsParameterOrGteOfParameter(const HloInstruction* instr) { return false; } -bool IsSingleInstructionFusion(const HloFusionAnalysis& analysis) { - return analysis.fusion_roots().size() == 1 && - absl::c_all_of(analysis.fusion_roots()[0]->operands(), - IsParameterOrGteOfParameter); -} - bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) { return absl::c_all_of( analysis.fusion_roots(), [](const HloInstruction* root) { @@ -110,8 +104,7 @@ std::optional> GetFusionEmitter( case HloFusionAnalysis::EmitterFusionKind::kInputSlices: return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { - bool is_single = IsSingleInstructionFusion(analysis); - if (!is_single && IsDynamicUpdateSliceFusion(analysis)) { + if (IsDynamicUpdateSliceFusion(analysis)) { if (allocations.empty() || fusion_op == nullptr) { return std::nullopt; } From 6f0d8ac86c682d84d05e87c3ab4a517b7d5efcd1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Dec 2023 08:09:48 -0800 Subject: [PATCH 341/381] Move xla/python/exceptions.h and xla/python/status_casters.h to pjrt/. Change in preparation for adding some code that uses a third-party library (gloo) that uses exceptions to PJRT. PiperOrigin-RevId: 587724478 --- third_party/xla/xla/pjrt/BUILD | 33 +++ .../xla/xla/{python => pjrt}/exceptions.h | 6 +- third_party/xla/xla/pjrt/status_casters.h | 218 ++++++++++++++++++ third_party/xla/xla/python/BUILD | 90 +++----- .../xla/xla/python/custom_call_sharding.cc | 2 +- third_party/xla/xla/python/jax_jit.cc | 2 +- third_party/xla/xla/python/mlir.cc | 2 +- third_party/xla/xla/python/ops.cc | 2 +- .../xla/xla/python/outfeed_receiver_py.cc | 2 +- third_party/xla/xla/python/pjit.cc | 2 +- third_party/xla/xla/python/pmap_lib.cc | 4 +- third_party/xla/xla/python/profiler.cc | 4 +- third_party/xla/xla/python/py_array.cc | 2 +- third_party/xla/xla/python/py_buffer.cc | 2 +- third_party/xla/xla/python/py_client.cc | 2 +- third_party/xla/xla/python/py_client.h | 2 +- third_party/xla/xla/python/py_client_gpu.cc | 2 +- .../xla/xla/python/py_compile_only_client.cc | 2 +- third_party/xla/xla/python/pytree.cc | 2 +- third_party/xla/xla/python/sharding.h | 2 +- third_party/xla/xla/python/status_casters.h | 203 +--------------- .../xla/xla/python/status_casters_ext.cc | 4 +- third_party/xla/xla/python/traceback.cc | 2 +- .../xla/xla/python/transfer_guard_lib.cc | 2 +- third_party/xla/xla/python/types.cc | 2 +- third_party/xla/xla/python/xla.cc | 2 +- third_party/xla/xla/python/xla_compiler.cc | 4 +- 27 files changed, 315 insertions(+), 287 deletions(-) rename third_party/xla/xla/{python => pjrt}/exceptions.h (94%) create mode 100644 third_party/xla/xla/pjrt/status_casters.h diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 4853b646929a25..b65dd7bfdd421a 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -809,3 +809,36 @@ tf_proto_library( ], visibility = ["//visibility:public"], ) + +cc_library( + name = "exceptions", + hdrs = ["exceptions.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + 
features = ["-use_header_modules"], + visibility = ["//visibility:public"], + deps = [ + "//xla:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "status_casters", + hdrs = ["status_casters.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = ["//visibility:public"], + deps = [ + ":exceptions", + "//xla:status", + "//xla:statusor", + "@local_tsl//tsl/platform:macros", + ], +) diff --git a/third_party/xla/xla/python/exceptions.h b/third_party/xla/xla/pjrt/exceptions.h similarity index 94% rename from third_party/xla/xla/python/exceptions.h rename to third_party/xla/xla/pjrt/exceptions.h index c5b7e72e61663e..19911e171edb93 100644 --- a/third_party/xla/xla/python/exceptions.h +++ b/third_party/xla/xla/pjrt/exceptions.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_EXCEPTIONS_H_ -#define XLA_PYTHON_EXCEPTIONS_H_ +#ifndef XLA_PJRT_EXCEPTIONS_H_ +#define XLA_PJRT_EXCEPTIONS_H_ #include #include @@ -64,4 +64,4 @@ class XlaRuntimeError : public std::runtime_error { } // namespace xla -#endif // XLA_PYTHON_EXCEPTIONS_H_ +#endif // XLA_PJRT_EXCEPTIONS_H_ diff --git a/third_party/xla/xla/pjrt/status_casters.h b/third_party/xla/xla/pjrt/status_casters.h new file mode 100644 index 00000000000000..9c39d1cbb5153f --- /dev/null +++ b/third_party/xla/xla/pjrt/status_casters.h @@ -0,0 +1,218 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_STATUS_CASTERS_H_ +#define XLA_PJRT_STATUS_CASTERS_H_ + +#include "xla/pjrt/exceptions.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "tsl/platform/macros.h" + +namespace xla { + +// C++ -> Python caster helpers. +// +// Failing statuses become Python exceptions; OK Status() becomes None. +// +// Given there can be only a single global pybind11 type_caster for the +// `absl::Status` type, and given XLA wants a custom exception being raised, +// we use a dedicated helper to implement this feature without relying on a +// global `type_caster`. 
+// +// For example: +// +// - Functions without arguments: +// m.def("my_func", []() { xla::ThrowIfError(MyFunc()); } +// - Classes with a single argument: +// py_class.def("delete", [](Buffer& self) { +// xla::ThrowIfError(self.Delete()); +// } +// +// For functions with more arguments, you can either inline the arguments, +// or use the `ThrowIfErrorWrapper` wrapper defined below: +// +// m.def("my_func", xla::ThrowIfErrorWrapper(MyFunc)); +// +// Nonstatic member functions can be wrapped by passing a +// pointer-to-member-function: +// xla::ThrowIfErrorWrapper(&MyClass::MyMethod) + +inline void ThrowIfError(xla::Status src) { + if (!src.ok()) { + throw xla::XlaRuntimeError(src); + } +} + +// If one does not want to have to define a lambda specifying the inputs +// arguments, on can use the `ThrowIfErrorWrapper` wrapper. +// +// There are three specializations: +// - For free functions, `Sig` is the function type and `F` is `Sig&`. +// - For callable types, `Sig` is the pointer to member function type +// and `F` is the type of the callable. +// - For a nonstatic member function of a class `C`, `Sig` is the function type +// and `F` is Sig C::*. +// +// In the first two cases, the wrapper returns a callable with signature `Sig`; +// in the third case, the wrapper returns callable with a modified signature +// that takes a C instance as the first argument. +template +struct ThrowIfErrorWrapper; + +// C++17 "deduction guide" that guides class template argument deduction (CTAD) +// For free functions. +template +ThrowIfErrorWrapper(F) -> ThrowIfErrorWrapper; + +// For callable types (with operator()). +template +ThrowIfErrorWrapper(xla::Status (&)(Args...)) + -> ThrowIfErrorWrapper; + +// For unbound nonstatic member functions. +template +ThrowIfErrorWrapper(xla::Status (C::*)(Args...)) + -> ThrowIfErrorWrapper; + +// Template specializations. + +// For free functions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(xla::Status (&f)(Args...)) : func(f) {} + void operator()(Args... args) { + xla::ThrowIfError(func(std::forward(args)...)); + } + xla::Status (&func)(Args...); +}; + +// For callable types (with operator()), non-const and const versions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} + void operator()(Args... args) { + xla::ThrowIfError(func(std::forward(args)...)); + } + F func; +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} + void operator()(Args... args) const { + xla::ThrowIfError(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(xla::Status (C::*ptmf)(Args...)) : ptmf(ptmf) {} + void operator()(C& instance, Args... args) { + xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + xla::Status (C::*ptmf)(Args...); +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(xla::Status (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + void operator()(const C& instance, Args... args) const { + xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + xla::Status (C::*ptmf)(Args...) const; +}; + +// Utilities for `StatusOr`. 
+template +T ValueOrThrow(StatusOr v) { + if (!v.ok()) { + throw xla::XlaRuntimeError(v.status()); + } + return std::move(v).value(); +} + +template +struct ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(F) -> ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(xla::StatusOr (&)(Args...)) + -> ValueOrThrowWrapper(Args...), + xla::StatusOr (&)(Args...)>; + +template +ValueOrThrowWrapper(xla::StatusOr (C::*)(Args...)) + -> ValueOrThrowWrapper(Args...), C>; + +// Deduction guide for const methods. +template +ValueOrThrowWrapper(xla::StatusOr (C::*)(Args...) const) + -> ValueOrThrowWrapper(Args...) const, C>; + +template +struct ValueOrThrowWrapper(Args...), + xla::StatusOr (&)(Args...)> { + explicit ValueOrThrowWrapper(xla::StatusOr (&f)(Args...)) : func(f) {} + R operator()(Args... args) { + return xla::ValueOrThrow(func(std::forward(args)...)); + } + xla::StatusOr (&func)(Args...); +}; +template +struct ValueOrThrowWrapper (C::*)(Args...), F> { + explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} + R operator()(Args... args) { + return xla::ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; +template +struct ValueOrThrowWrapper (C::*)(Args...) const, F> { + explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} + R operator()(Args... args) const { + return xla::ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ValueOrThrowWrapper(Args...), C> { + explicit ValueOrThrowWrapper(xla::StatusOr (C::*ptmf)(Args...)) + : ptmf(ptmf) {} + R operator()(C& instance, Args... args) { + return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + xla::StatusOr (C::*ptmf)(Args...); +}; +template +struct ValueOrThrowWrapper(Args...) const, C> { + explicit ValueOrThrowWrapper(xla::StatusOr (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + R operator()(const C& instance, Args... args) const { + return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + xla::StatusOr (C::*ptmf)(Args...) 
const; +}; + +} // namespace xla + +#endif // XLA_PJRT_STATUS_CASTERS_H_ diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 96d225404bc605..20b348cce7a339 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -141,32 +141,13 @@ py_strict_test( ] + xla_py_test_deps(), ) -cc_library( - name = "status_casters", - hdrs = ["status_casters.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = ["//visibility:public"], - deps = [ - ":exceptions", - "//xla:status", - "//xla:statusor", - "@local_tsl//tsl/platform:macros", - "@pybind11", - ], -) - tsl_pybind_extension( name = "status_casters_ext", srcs = ["status_casters_ext.cc"], visibility = ["//visibility:private"], deps = [ - ":exceptions", - ":status_casters", + "//xla/pjrt:exceptions", + "//xla/pjrt:status_casters", "@pybind11", ], ) @@ -184,22 +165,6 @@ py_strict_test( ] + xla_py_test_deps(), ) -cc_library( - name = "exceptions", - hdrs = ["exceptions.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = ["//visibility:public"], - deps = [ - "//xla:status", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "types", srcs = ["types.cc"], @@ -212,7 +177,6 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", "//xla:literal", "//xla:shape_util", "//xla:status", @@ -220,6 +184,7 @@ cc_library( "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/pjrt:exceptions", "//xla/python/ifrt", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -277,13 +242,13 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", ":python_ref_manager", # placeholder for index annotation deps "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "//xla/pjrt:exceptions", "@local_tsl//tsl/platform:logging", "@pybind11", ], @@ -348,13 +313,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":callback", - ":exceptions", ":pprof_profile_builder", ":py_client_gpu", ":py_host_callback_proto_cc", ":python_ref_manager", ":python_utils", - ":status_casters", ":traceback", ":transfer_guard_lib", ":types", @@ -377,6 +340,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", "//xla/pjrt:lru_cache", "//xla/pjrt:mlir_to_hlo", @@ -385,6 +349,7 @@ cc_library( "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", + "//xla/pjrt:status_casters", "//xla/pjrt:transpose", "//xla/python/ifrt", "//xla/python/pjrt_ifrt", @@ -457,8 +422,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":callback", - ":exceptions", "//xla:comparison_util", + "//xla/pjrt:exceptions", "//xla/service:custom_call_status", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", @@ -513,12 +478,10 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", ":py_client", ":python_ref_manager", ":python_utils", ":pytree", - ":status_casters", ":types", ":util", # placeholder for index annotation deps @@ -534,8 +497,10 @@ cc_library( "//xla:types", 
"//xla:util", "//xla:xla_data_proto_cc", + "//xla/pjrt:exceptions", "//xla/pjrt:lru_cache", "//xla/pjrt:pjrt_client", + "//xla/pjrt:status_casters", "//xla/python/ifrt", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:traceme", @@ -570,11 +535,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":inspect_sharding", - ":status_casters", # placeholder for index annotation deps "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_sharding_util", + "//xla/pjrt:status_casters", "//xla/service:call_inliner", "//xla/service:custom_call_sharding_helper", "//xla/service:hlo_pass_pipeline", @@ -596,7 +561,6 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":status_casters", ":types", # placeholder for index annotation deps "@com_google_absl//absl/types:span", @@ -612,6 +576,7 @@ cc_library( "//xla/client/lib:self_adjoint_eig", "//xla/client/lib:sorting", "//xla/client/lib:svd", + "//xla/pjrt:status_casters", "@pybind11", ], ) @@ -654,12 +619,12 @@ cc_library( ":py_client", ":python_utils", ":pytree", - ":status_casters", ":transfer_guard_lib", ":util", # placeholder for index annotation deps "@com_google_absl//absl/synchronization", "//xla/pjrt:lru_cache", + "//xla/pjrt:status_casters", "//xla/python/ifrt", "//xla/python/pjrt_ifrt", "@local_tsl//tsl/platform:errors", @@ -680,12 +645,10 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", ":jax_jit", ":py_client", ":python_utils", ":pytree", - ":status_casters", ":types", ":util", # placeholder for index annotation deps @@ -696,7 +659,9 @@ cc_library( "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "//xla:xla_data_proto_cc", + "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_client", + "//xla/pjrt:status_casters", "//xla/python/ifrt", "//xla/python/pjrt_ifrt", "@local_tsl//tsl/platform:logging", @@ -740,7 +705,6 @@ cc_library( deps = [ ":outfeed_receiver", ":py_client", - ":status_casters", ":types", # placeholder for index annotation deps "@com_google_absl//absl/algorithm:container", @@ -748,6 +712,7 @@ cc_library( "//xla/client:executable_build_options", "//xla/client:xla_builder", "//xla/pjrt:pjrt_client", + "//xla/pjrt:status_casters", "@pybind11", ], ) @@ -786,7 +751,6 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", ":pytree_proto_cc", # placeholder for index annotation deps "@com_google_absl//absl/algorithm:container", @@ -795,6 +759,7 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "//xla/pjrt:exceptions", "@local_tsl//tsl/platform:logging", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", @@ -814,7 +779,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":refine_polymorphic_shapes", - ":status_casters", ":types", # placeholder for index annotation deps "//xla:status", @@ -823,6 +787,7 @@ cc_library( "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:status_casters", "//xla/service/llvm_ir:llvm_util", "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_tsl//tsl/platform:errors", @@ -877,8 +842,6 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", - ":status_casters", ":types", ":xplane_to_profile_instructions", # placeholder for index annotation deps @@ -888,6 +851,8 @@ cc_library( 
"//xla/backends/profiler/cpu:python_tracer", "//xla/backends/profiler/plugin:plugin_tracer", "//xla/backends/profiler/plugin:profiler_c_api_hdrs", + "//xla/pjrt:exceptions", + "//xla/pjrt:status_casters", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "//xla/python/profiler/internal:traceme_wrapper", @@ -918,11 +883,11 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":status_casters", # placeholder for index annotation deps "@com_google_absl//absl/base:core_headers", "//xla:status", "//xla:util", + "//xla/pjrt:status_casters", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", ], @@ -983,9 +948,7 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:public"], deps = [ - ":exceptions", ":py_client", - ":status_casters", ":types", # placeholder for index annotation deps "@com_google_absl//absl/hash", @@ -1004,6 +967,8 @@ cc_library( "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/pjrt:exceptions", + "//xla/pjrt:status_casters", "//xla/service:call_inliner", "//xla/service:computation_placer", "//xla/service:custom_call_target_registry", @@ -1139,7 +1104,6 @@ cc_library( ":python_ref_manager", ":pytree", ":refine_polymorphic_shapes", - ":status_casters", ":traceback", ":transfer_guard_lib", ":types", @@ -1168,6 +1132,7 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:status_casters", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/cpu:cpu_client", "//xla/pjrt/distributed", @@ -1241,3 +1206,12 @@ xla_cc_test( "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) + +cc_library( + name = "status_casters", + hdrs = ["status_casters.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla/pjrt:status_casters", + ], +) diff --git a/third_party/xla/xla/python/custom_call_sharding.cc b/third_party/xla/xla/python/custom_call_sharding.cc index 9b5e5e1dfef810..68118f57bb1a42 100644 --- a/third_party/xla/xla/python/custom_call_sharding.cc +++ b/third_party/xla/xla/python/custom_call_sharding.cc @@ -30,8 +30,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/inspect_sharding.h" -#include "xla/python/status_casters.h" #include "xla/service/call_inliner.h" #include "xla/service/custom_call_sharding_helper.h" #include "xla/service/hlo_pass_pipeline.h" diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index fa5f666a0e3b20..301b20180099c9 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -46,10 +46,10 @@ limitations under the License. #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/py_values.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" #include "tsl/platform/status.h" #include "tsl/profiler/lib/traceme.h" diff --git a/third_party/xla/xla/python/mlir.cc b/third_party/xla/xla/python/mlir.cc index 77dc23bc7b6409..e01c8b078cae22 100644 --- a/third_party/xla/xla/python/mlir.cc +++ b/third_party/xla/xla/python/mlir.cc @@ -37,8 +37,8 @@ limitations under the License. 
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/refine_polymorphic_shapes.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status.h" diff --git a/third_party/xla/xla/python/ops.cc b/third_party/xla/xla/python/ops.cc index 2d44a5bf41523d..8bf0003b65e996 100644 --- a/third_party/xla/xla/python/ops.cc +++ b/third_party/xla/xla/python/ops.cc @@ -35,7 +35,7 @@ limitations under the License. #include "xla/client/lib/svd.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" -#include "xla/python/status_casters.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/python/outfeed_receiver_py.cc b/third_party/xla/xla/python/outfeed_receiver_py.cc index addd479f7bb6e3..55db82ef9f7187 100644 --- a/third_party/xla/xla/python/outfeed_receiver_py.cc +++ b/third_party/xla/xla/python/outfeed_receiver_py.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/client/executable_build_options.h" #include "xla/client/xla_builder.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/outfeed_receiver.h" #include "xla/python/py_client.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" namespace xla { diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index 76e3dfc5a9e99a..2fe408f24e38cd 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" #include "xla/python/jax_jit.h" #include "xla/python/py_array.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/python/python_utils.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/transfer_guard_lib.h" #include "xla/python/util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/python/pmap_lib.cc b/third_party/xla/xla/python/pmap_lib.cc index aed4ca042331ea..3481bf47dc1d64 100644 --- a/third_party/xla/xla/python/pmap_lib.cc +++ b/third_party/xla/xla/python/pmap_lib.cc @@ -36,7 +36,8 @@ limitations under the License. #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/python/exceptions.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/memory.h" @@ -50,7 +51,6 @@ limitations under the License. #include "xla/python/pytree.h" #include "xla/python/sharded_device_array.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" #include "xla/python/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/python/profiler.cc b/third_party/xla/xla/python/profiler.cc index 52b611bca62391..566c14a59d327e 100644 --- a/third_party/xla/xla/python/profiler.cc +++ b/third_party/xla/xla/python/profiler.cc @@ -28,9 +28,9 @@ limitations under the License. 
#include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" -#include "xla/python/exceptions.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/profiler/internal/traceme_wrapper.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" #include "xla/python/xplane_to_profile_instructions.h" #include "xla/status.h" diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 8986cae185fcba..354eb40051dfda 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -31,6 +31,7 @@ limitations under the License. #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil #include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" @@ -41,7 +42,6 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/python_utils.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/transfer_guard_lib.h" #include "xla/python/util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/python/py_buffer.cc b/third_party/xla/xla/python/py_buffer.cc index ed98c84c81fa20..79f927231313bd 100644 --- a/third_party/xla/xla/python/py_buffer.cc +++ b/third_party/xla/xla/python/py_buffer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "pybind11/pytypes.h" // from @pybind11 #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/status_casters.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/python/py_client.h" #include "xla/python/python_ref_manager.h" #include "xla/python/python_utils.h" -#include "xla/python/status_casters.h" #include "xla/python/transfer_guard_lib.h" #include "xla/python/types.h" #include "xla/python/util.h" diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 2caed4eb9562fc..ad67d0640ce586 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -23,12 +23,12 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/python/callback.h" -#include "xla/python/exceptions.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/executable.h" diff --git a/third_party/xla/xla/python/py_client.h b/third_party/xla/xla/python/py_client.h index 581801831ffec6..3c15b1057dc868 100644 --- a/third_party/xla/xla/python/py_client.h +++ b/third_party/xla/xla/python/py_client.h @@ -26,9 +26,9 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "pybind11/pybind11.h" // from @pybind11 #include "xla/client/xla_builder.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" -#include "xla/python/exceptions.h" #include "xla/python/ifrt/client.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/statusor.h" diff --git a/third_party/xla/xla/python/py_client_gpu.cc b/third_party/xla/xla/python/py_client_gpu.cc index 6a5caf680243d0..36b73f3d65ce6f 100644 --- a/third_party/xla/xla/python/py_client_gpu.cc +++ b/third_party/xla/xla/python/py_client_gpu.cc @@ -25,9 +25,9 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #endif #include "pybind11/pybind11.h" // from @pybind11 +#include "xla/pjrt/exceptions.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" -#include "xla/python/exceptions.h" #if TENSORFLOW_USE_ROCM #define gpuSuccess hipSuccess diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index 78908290c64b04..a2031c40a3b0bb 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -26,8 +26,8 @@ limitations under the License. #include "absl/types/span.h" #include "pybind11/stl.h" // from @pybind11 #include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" -#include "xla/python/status_casters.h" #include "tsl/python/lib/core/numpy.h" //NOLINT namespace xla { diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 1a47247bc5a21e..b5d42ec0f7ced5 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -41,7 +41,7 @@ limitations under the License. #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/python/exceptions.h" +#include "xla/pjrt/exceptions.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/python/sharding.h b/third_party/xla/xla/python/sharding.h index 0e780b10c37835..5fbd70e61fdcd5 100644 --- a/third_party/xla/xla/python/sharding.h +++ b/third_party/xla/xla/python/sharding.h @@ -30,11 +30,11 @@ limitations under the License. #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" #include "xla/python/py_client.h" #include "xla/python/py_device_list.h" #include "xla/python/sharded_device_array.h" -#include "xla/python/status_casters.h" #include "xla/xla_data.pb.h" namespace jax { diff --git a/third_party/xla/xla/python/status_casters.h b/third_party/xla/xla/python/status_casters.h index 1b7020005360da..69e0abd6e742a3 100644 --- a/third_party/xla/xla/python/status_casters.h +++ b/third_party/xla/xla/python/status_casters.h @@ -16,205 +16,8 @@ limitations under the License. #ifndef XLA_PYTHON_STATUS_CASTERS_H_ #define XLA_PYTHON_STATUS_CASTERS_H_ -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/python/exceptions.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "tsl/platform/macros.h" - -namespace xla { - -// C++ -> Python caster helpers. -// -// Failing statuses become Python exceptions; OK Status() becomes None. 
-// -// Given there can be only a single global pybind11 type_caster for the -// `absl::Status` type, and given XLA wants a custom exception being raised, -// we use a dedicated helper to implement this feature without relying on a -// global `type_caster`. -// -// For example: -// -// - Functions without arguments: -// m.def("my_func", []() { xla::ThrowIfError(MyFunc()); } -// - Classes with a single argument: -// py_class.def("delete", [](Buffer& self) { -// xla::ThrowIfError(self.Delete()); -// } -// -// For functions with more arguments, you can either inline the arguments, -// or use the `ThrowIfErrorWrapper` wrapper defined below: -// -// m.def("my_func", xla::ThrowIfErrorWrapper(MyFunc)); -// -// Nonstatic member functions can be wrapped by passing a -// pointer-to-member-function: -// xla::ThrowIfErrorWrapper(&MyClass::MyMethod) - -inline void ThrowIfError(xla::Status src) { - if (!src.ok()) { - throw xla::XlaRuntimeError(src); - } -} - -// If one does not want to have to define a lambda specifying the inputs -// arguments, on can use the `ThrowIfErrorWrapper` wrapper. -// -// There are three specializations: -// - For free functions, `Sig` is the function type and `F` is `Sig&`. -// - For callable types, `Sig` is the pointer to member function type -// and `F` is the type of the callable. -// - For a nonstatic member function of a class `C`, `Sig` is the function type -// and `F` is Sig C::*. -// -// In the first two cases, the wrapper returns a callable with signature `Sig`; -// in the third case, the wrapper returns callable with a modified signature -// that takes a C instance as the first argument. -template -struct ThrowIfErrorWrapper; - -// C++17 "deduction guide" that guides class template argument deduction (CTAD) -// For free functions. -template -ThrowIfErrorWrapper(F) -> ThrowIfErrorWrapper; - -// For callable types (with operator()). -template -ThrowIfErrorWrapper(xla::Status (&)(Args...)) - -> ThrowIfErrorWrapper; - -// For unbound nonstatic member functions. -template -ThrowIfErrorWrapper(xla::Status (C::*)(Args...)) - -> ThrowIfErrorWrapper; - -// Template specializations. - -// For free functions. -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(xla::Status (&f)(Args...)) : func(f) {} - void operator()(Args... args) { - xla::ThrowIfError(func(std::forward(args)...)); - } - xla::Status (&func)(Args...); -}; - -// For callable types (with operator()), non-const and const versions. -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} - void operator()(Args... args) { - xla::ThrowIfError(func(std::forward(args)...)); - } - F func; -}; -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} - void operator()(Args... args) const { - xla::ThrowIfError(func(std::forward(args)...)); - } - F func; -}; - -// For unbound nonstatic member functions, non-const and const versions. -// `ptmf` stands for "pointer to member function". -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(xla::Status (C::*ptmf)(Args...)) : ptmf(ptmf) {} - void operator()(C& instance, Args... args) { - xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); - } - xla::Status (C::*ptmf)(Args...); -}; -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(xla::Status (C::*ptmf)(Args...) const) - : ptmf(ptmf) {} - void operator()(const C& instance, Args... 
args) const { - xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); - } - xla::Status (C::*ptmf)(Args...) const; -}; - -// Utilities for `StatusOr`. -template -T ValueOrThrow(StatusOr v) { - if (!v.ok()) { - throw xla::XlaRuntimeError(v.status()); - } - return std::move(v).value(); -} - -template -struct ValueOrThrowWrapper; - -template -ValueOrThrowWrapper(F) -> ValueOrThrowWrapper; - -template -ValueOrThrowWrapper(xla::StatusOr (&)(Args...)) - -> ValueOrThrowWrapper(Args...), - xla::StatusOr (&)(Args...)>; - -template -ValueOrThrowWrapper(xla::StatusOr (C::*)(Args...)) - -> ValueOrThrowWrapper(Args...), C>; - -// Deduction guide for const methods. -template -ValueOrThrowWrapper(xla::StatusOr (C::*)(Args...) const) - -> ValueOrThrowWrapper(Args...) const, C>; - -template -struct ValueOrThrowWrapper(Args...), - xla::StatusOr (&)(Args...)> { - explicit ValueOrThrowWrapper(xla::StatusOr (&f)(Args...)) : func(f) {} - R operator()(Args... args) { - return xla::ValueOrThrow(func(std::forward(args)...)); - } - xla::StatusOr (&func)(Args...); -}; -template -struct ValueOrThrowWrapper (C::*)(Args...), F> { - explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} - R operator()(Args... args) { - return xla::ValueOrThrow(func(std::forward(args)...)); - } - F func; -}; -template -struct ValueOrThrowWrapper (C::*)(Args...) const, F> { - explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} - R operator()(Args... args) const { - return xla::ValueOrThrow(func(std::forward(args)...)); - } - F func; -}; - -// For unbound nonstatic member functions, non-const and const versions. -// `ptmf` stands for "pointer to member function". -template -struct ValueOrThrowWrapper(Args...), C> { - explicit ValueOrThrowWrapper(xla::StatusOr (C::*ptmf)(Args...)) - : ptmf(ptmf) {} - R operator()(C& instance, Args... args) { - return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); - } - xla::StatusOr (C::*ptmf)(Args...); -}; -template -struct ValueOrThrowWrapper(Args...) const, C> { - explicit ValueOrThrowWrapper(xla::StatusOr (C::*ptmf)(Args...) const) - : ptmf(ptmf) {} - R operator()(const C& instance, Args... args) const { - return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); - } - xla::StatusOr (C::*ptmf)(Args...) const; -}; - -} // namespace xla +// Forwarding header. +// TODO(phawkins): update users to use the new header location. +#include "xla/pjrt/status_casters.h" #endif // XLA_PYTHON_STATUS_CASTERS_H_ diff --git a/third_party/xla/xla/python/status_casters_ext.cc b/third_party/xla/xla/python/status_casters_ext.cc index 8f903da84427e1..df34ca8ce9d511 100644 --- a/third_party/xla/xla/python/status_casters_ext.cc +++ b/third_party/xla/xla/python/status_casters_ext.cc @@ -15,8 +15,8 @@ limitations under the License. #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 -#include "xla/python/exceptions.h" -#include "xla/python/status_casters.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" namespace xla { diff --git a/third_party/xla/xla/python/traceback.cc b/third_party/xla/xla/python/traceback.cc index 88d6965b56e236..93389420ea573a 100644 --- a/third_party/xla/xla/python/traceback.cc +++ b/third_party/xla/xla/python/traceback.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "pybind11/pytypes.h" // from @pybind11 -#include "xla/python/exceptions.h" +#include "xla/pjrt/exceptions.h" #include "xla/python/python_ref_manager.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/python/transfer_guard_lib.cc b/third_party/xla/xla/python/transfer_guard_lib.cc index 94a720cab36317..f00ba07337edad 100644 --- a/third_party/xla/xla/python/transfer_guard_lib.cc +++ b/third_party/xla/xla/python/transfer_guard_lib.cc @@ -26,7 +26,7 @@ limitations under the License. #include "pybind11/cast.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/python/status_casters.h" +#include "xla/pjrt/status_casters.h" #include "xla/status.h" #include "xla/util.h" diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index 288fcf92543fb6..3b86e3056e15c0 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -26,7 +26,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "xla/python/exceptions.h" +#include "xla/pjrt/exceptions.h" #include "xla/python/ifrt/dtype.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 4e02186dfe93b0..a9fc7a31e6b391 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -64,6 +64,7 @@ limitations under the License. #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/custom_call_sharding.h" #include "xla/python/dlpack.h" #include "xla/python/jax_jit.h" @@ -84,7 +85,6 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/traceback.h" #include "xla/python/transfer_guard_lib.h" #include "xla/python/types.h" diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index ac0b0933eae6a1..84f8f22e3a7828 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -45,9 +45,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/layout.h" #include "xla/layout_util.h" -#include "xla/python/exceptions.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/py_client.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" #include "xla/service/call_inliner.h" #include "xla/service/computation_placer.h" From 8d464245b33eaad42c5ea63cc15f182b04290483 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 4 Dec 2023 08:12:05 -0800 Subject: [PATCH 342/381] [TileAnalysis] Add indexing computation based on input dimensions. We will need this for Triton's epilogue fusion. 
PiperOrigin-RevId: 587725007 --- .../xla/service/gpu/model/tile_analysis.cc | 133 ++++++++--- .../xla/xla/service/gpu/model/tile_analysis.h | 10 +- .../service/gpu/model/tile_analysis_test.cc | 223 +++++++++++------- 3 files changed, 250 insertions(+), 116 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 203de7c8618b5c..1cb1fb63e58739 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -68,7 +68,7 @@ using mlir::getAffineConstantExpr; using mlir::getAffineDimExpr; using mlir::MLIRContext; -StatusOr ComputeCwiseOpIndexing( +StatusOr ComputeOutputToInputCwiseOpIndexing( const HloInstruction* instr, MLIRContext* mlir_context) { auto dims = instr->shape().dimensions(); IndexingMap identity_map{.affine_map = AffineMap::getMultiDimIdentityMap( @@ -78,12 +78,21 @@ StatusOr ComputeCwiseOpIndexing( HloInstructionIndexing instr_indexing; int64_t operand_count = instr->operand_count(); for (int64_t operand_id = 0; operand_id < operand_count; ++operand_id) { - instr_indexing.operand_indexing_maps[operand_id].insert(identity_map); + instr_indexing.indexing_maps[operand_id].insert(identity_map); } return instr_indexing; } -StatusOr ComputeBroadcastOpIndexing( +StatusOr ComputeInputToOutputCwiseOpIndexing( + const HloInstruction* instr, MLIRContext* mlir_context) { + auto dims = instr->shape().dimensions(); + IndexingMap identity_map{.affine_map = AffineMap::getMultiDimIdentityMap( + dims.size(), mlir_context), + .input_dims_sizes = {}}; + return HloInstructionIndexing::FromIndexingMaps({identity_map}); +} + +StatusOr ComputeOutputToInputBroadcastOpIndexing( const HloBroadcastInstruction* bcast, MLIRContext* mlir_context) { auto output_dims = bcast->shape().dimensions(); @@ -171,9 +180,9 @@ HloInstructionIndexing ComputeFusedProducerConsumerIndexing( // have 1 or more read accesses to its operands. So, to get the composed // indexing maps we have to compute a "cross product" here. for (const auto& [producer_operand_id, producer_operand_indexing] : - producer_indexing.operand_indexing_maps) { + producer_indexing.indexing_maps) { auto& composed_operand_indexing = - fused_instr_indexing.operand_indexing_maps[producer_operand_id]; + fused_instr_indexing.indexing_maps[producer_operand_id]; for (const IndexingMap& producer_map : producer_operand_indexing) { for (const IndexingMap& consumer_map : operand_indexing_maps) { composed_operand_indexing.insert( @@ -186,7 +195,7 @@ HloInstructionIndexing ComputeFusedProducerConsumerIndexing( // Composes instruction indexing maps starting at the root instruction // until the HloParameterInstruction is found. -StatusOr ComputeFusionOpIndexing( +StatusOr ComputeOutputToInputFusionOpIndexing( const HloFusionInstruction* fusion, int output_id, MLIRContext* mlir_context) { const HloInstruction* root = @@ -194,7 +203,7 @@ StatusOr ComputeFusionOpIndexing( ? 
fusion->fused_expression_root()->operand(output_id) : fusion->fused_expression_root(); std::queue> bfs; - TF_ASSIGN_OR_RETURN(auto root_indexing, ComputeInstructionIndexing( + TF_ASSIGN_OR_RETURN(auto root_indexing, ComputeOutputToInputIndexing( root, output_id, mlir_context)); bfs.push(std::make_pair(root, root_indexing)); @@ -203,7 +212,7 @@ StatusOr ComputeFusionOpIndexing( while (!bfs.empty()) { const auto& [instr, instr_indexing] = bfs.front(); for (const auto& [operand_id, operand_indexing_maps] : - instr_indexing.operand_indexing_maps) { + instr_indexing.indexing_maps) { const HloInstruction* producer_instr = instr->operand(operand_id); if (producer_instr->IsConstant()) continue; // If the producer is a fusion op parameter, store the result. @@ -213,7 +222,7 @@ StatusOr ComputeFusionOpIndexing( continue; } TF_ASSIGN_OR_RETURN(auto producer_instr_indexing, - ComputeInstructionIndexing( + ComputeOutputToInputIndexing( producer_instr, /*output_id=*/0, mlir_context)); bfs.push(std::make_pair( producer_instr, ComputeFusedProducerConsumerIndexing( @@ -221,11 +230,11 @@ StatusOr ComputeFusionOpIndexing( } bfs.pop(); } - return HloInstructionIndexing{.operand_indexing_maps = + return HloInstructionIndexing{.indexing_maps = std::move(parameter_indexing_maps)}; } -StatusOr ComputeDotOpIndexing( +StatusOr ComputeOutputToInputDotOpIndexing( const HloDotInstruction* dot, MLIRContext* mlir_context) { CHECK_NE(dot, nullptr); const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); @@ -304,7 +313,7 @@ StatusOr ComputeDotOpIndexing( {lhs_indexing_map, rhs_indexing_map}); } -StatusOr ComputeReduceOpIndexing( +StatusOr ComputeOutputToInputReduceOpIndexing( const HloReduceInstruction* reduce, int output_id, MLIRContext* mlir_context) { absl::flat_hash_set reduce_dims_ids(reduce->dimensions().begin(), @@ -490,7 +499,7 @@ IndexingMap ComputeReshapeIndexingMap(absl::Span input_dims, .input_dims_sizes = {}}; } -StatusOr ComputeReshapeOpIndexing( +StatusOr ComputeOutputToInputReshapeOpIndexing( const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { auto input_dims = reshape->operand(0)->shape().dimensions(); auto output_dims = reshape->shape().dimensions(); @@ -500,6 +509,16 @@ StatusOr ComputeReshapeOpIndexing( return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } +StatusOr ComputeInputToOutputReshapeOpIndexing( + const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { + auto input_dims = reshape->operand(0)->shape().dimensions(); + auto output_dims = reshape->shape().dimensions(); + IndexingMap reshape_indexing_map = + ComputeReshapeIndexingMap(output_dims, input_dims, mlir_context); + + return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); +} + StatusOr ComputeReverseOpIndexing( const HloReverseInstruction* reverse, MLIRContext* mlir_context) { absl::flat_hash_set reverse_dims(reverse->dimensions().begin(), @@ -528,7 +547,7 @@ StatusOr ComputeReverseOpIndexing( return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } -StatusOr ComputeSliceOpIndexing( +StatusOr ComputeOutputToInputSliceOpIndexing( const HloSliceInstruction* slice, MLIRContext* mlir_context) { auto output_dims = slice->shape().dimensions(); @@ -561,16 +580,25 @@ IndexingMap ComputeTransposeIndexingMap(absl::Span permutation, .input_dims_sizes = {}}; } -StatusOr ComputeTransposeOpIndexing( +StatusOr ComputeOutputToInputTransposeOpIndexing( const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { return 
HloInstructionIndexing::FromIndexingMaps( {ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context)}); } -StatusOr ComputeBitcastOpIndexing( - const HloInstruction* bitcast, MLIRContext* mlir_context) { - const Shape& input_shape = bitcast->operand(0)->shape(); - const Shape& output_shape = bitcast->shape(); +StatusOr ComputeInputToOutputTransposeOpIndexing( + const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { + auto forward_permutation = AffineMap::getPermutationMap( + std::vector(transpose->dimensions().begin(), + transpose->dimensions().end()), + mlir_context); + return HloInstructionIndexing::FromIndexingMaps( + {IndexingMap{.affine_map = forward_permutation, .input_dims_sizes = {}}}); +} + +StatusOr ComputeOutputToInputBitcastOpIndexingImpl( + const Shape& input_shape, const Shape& output_shape, + MLIRContext* mlir_context) { ShapeUtil::BitcastDecomposition decomposed_bitcast = ShapeUtil::DecomposeBitcast(input_shape, output_shape); @@ -603,6 +631,22 @@ StatusOr ComputeBitcastOpIndexing( return HloInstructionIndexing::FromIndexingMaps({composed_map}); } +StatusOr ComputeOutputToInputBitcastOpIndexing( + const HloInstruction* bitcast, MLIRContext* mlir_context) { + const Shape& input_shape = bitcast->operand(0)->shape(); + const Shape& output_shape = bitcast->shape(); + return ComputeOutputToInputBitcastOpIndexingImpl(input_shape, output_shape, + mlir_context); +} + +StatusOr ComputeInputToOutputBitcastOpIndexing( + const HloInstruction* bitcast, MLIRContext* mlir_context) { + const Shape& input_shape = bitcast->operand(0)->shape(); + const Shape& output_shape = bitcast->shape(); + return ComputeOutputToInputBitcastOpIndexingImpl(output_shape, input_shape, + mlir_context); +} + template std::string ToStringImpl(const T& value) { std::string s; @@ -863,7 +907,7 @@ bool IndexingMap::Simplify(absl::Span dimension_sizes) { bool HloInstructionIndexing::Simplify( absl::Span dimension_sizes) { bool any_simplified = false; - for (auto& operand_indexing : operand_indexing_maps) { + for (auto& operand_indexing : indexing_maps) { std::vector to_remove; std::vector to_add; absl::flat_hash_set& indexing_maps = operand_indexing.second; @@ -906,8 +950,7 @@ std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing) { - for (const auto& [operand_id, indexing_maps] : - instr_indexing.operand_indexing_maps) { + for (const auto& [operand_id, indexing_maps] : instr_indexing.indexing_maps) { out << "operand id = " << operand_id << ' '; for (const auto& indexing_map : indexing_maps) { out << indexing_map; @@ -921,9 +964,9 @@ std::string IndexingMap::ToString() const { return ToStringImpl(*this); } HloInstructionIndexing HloInstructionIndexing::FromIndexingMaps( absl::Span indexing_maps) { HloInstructionIndexing instr_indexing; - instr_indexing.operand_indexing_maps.reserve(indexing_maps.size()); + instr_indexing.indexing_maps.reserve(indexing_maps.size()); for (const auto& [index, map] : llvm::enumerate(indexing_maps)) { - instr_indexing.operand_indexing_maps[index].insert(map); + instr_indexing.indexing_maps[index].insert(map); } return instr_indexing; } @@ -932,37 +975,59 @@ std::string HloInstructionIndexing::ToString() const { return ToStringImpl(*this); } -StatusOr ComputeInstructionIndexing( +StatusOr ComputeOutputToInputIndexing( const HloInstruction* instr, int output_id, MLIRContext* mlir_context) { if (HloInstruction::IsOpElementwise(instr->opcode())) { - 
return ComputeCwiseOpIndexing(instr, mlir_context); + return ComputeOutputToInputCwiseOpIndexing(instr, mlir_context); } if (auto bcast = DynCast(instr)) { - return ComputeBroadcastOpIndexing(bcast, mlir_context); + return ComputeOutputToInputBroadcastOpIndexing(bcast, mlir_context); } if (instr->opcode() == HloOpcode::kBitcast) { - return ComputeBitcastOpIndexing(instr, mlir_context); + return ComputeOutputToInputBitcastOpIndexing(instr, mlir_context); } if (auto dot = DynCast(instr)) { - return ComputeDotOpIndexing(dot, mlir_context); + return ComputeOutputToInputDotOpIndexing(dot, mlir_context); } if (auto fusion = DynCast(instr)) { - return ComputeFusionOpIndexing(fusion, output_id, mlir_context); + return ComputeOutputToInputFusionOpIndexing(fusion, output_id, + mlir_context); } if (auto reduce = DynCast(instr)) { - return ComputeReduceOpIndexing(reduce, output_id, mlir_context); + return ComputeOutputToInputReduceOpIndexing(reduce, output_id, + mlir_context); } if (auto reshape = DynCast(instr)) { - return ComputeReshapeOpIndexing(reshape, mlir_context); + return ComputeOutputToInputReshapeOpIndexing(reshape, mlir_context); } if (auto reverse = DynCast(instr)) { return ComputeReverseOpIndexing(reverse, mlir_context); } if (auto slice = DynCast(instr)) { - return ComputeSliceOpIndexing(slice, mlir_context); + return ComputeOutputToInputSliceOpIndexing(slice, mlir_context); + } + if (auto transpose = DynCast(instr)) { + return ComputeOutputToInputTransposeOpIndexing(transpose, mlir_context); + } + return InvalidArgument("Unsupported instruction type"); +} + +StatusOr ComputeInputToOutputIndexing( + const HloInstruction* instr, int input_id, MLIRContext* mlir_context) { + if (HloInstruction::IsOpElementwise(instr->opcode())) { + return ComputeInputToOutputCwiseOpIndexing(instr, mlir_context); + } + if (instr->opcode() == HloOpcode::kBitcast) { + return ComputeInputToOutputBitcastOpIndexing(instr, mlir_context); + } + if (auto reshape = DynCast(instr)) { + return ComputeInputToOutputReshapeOpIndexing(reshape, mlir_context); + } + if (auto reverse = DynCast(instr)) { + return ComputeReverseOpIndexing(reverse, mlir_context); } if (auto transpose = DynCast(instr)) { - return ComputeTransposeOpIndexing(transpose, mlir_context); + return ComputeInputToOutputTransposeOpIndexing(transpose, mlir_context); } return InvalidArgument("Unsupported instruction type"); } diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.h b/third_party/xla/xla/service/gpu/model/tile_analysis.h index 734d8e168f0007..ea356ccb24b7aa 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.h @@ -99,8 +99,7 @@ struct HloInstructionIndexing { absl::Span indexing_maps); // Maps input operand index to the indexing map for one particular output. - absl::flat_hash_map> - operand_indexing_maps; + absl::flat_hash_map> indexing_maps; }; std::ostream& operator<<(std::ostream& out, const HloInstructionIndexing& instr_indexing); @@ -109,10 +108,15 @@ std::string ToString(const mlir::AffineMap& affine_map); // Computes indexing maps for all input operands necessary to compute an element // of the `output_id` instruction output. -StatusOr ComputeInstructionIndexing( +StatusOr ComputeOutputToInputIndexing( const HloInstruction* instr, int output_id, mlir::MLIRContext* mlir_context); +// Computes indexing maps for all output operands that the element of the +// `input_id` instruction input will participate in. 
+StatusOr ComputeInputToOutputIndexing( + const HloInstruction* instr, int input_id, mlir::MLIRContext* mlir_context); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc index 377453a3d8e4aa..c9388b81a83360 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis_test.cc @@ -59,8 +59,8 @@ MATCHER_P2(MatchInstrIndexing, operand_id, indexing_map_matchers, "") { class TileAnalysisTest : public HloTestBase { public: - StatusOr GetIndexingMapsForEntryComputation( - absl::string_view hlo_string, int operand_id = 0) { + StatusOr GetOutputToInputIndexingForEntryComputation( + absl::string_view hlo_string, int output_id = 0) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -70,33 +70,61 @@ class TileAnalysisTest : public HloTestBase { << "If there are multiple instructions, they need to be wrapped in a " "fusion."; } + return ComputeOutputToInputIndexing(root, output_id, &mlir_context_); + } + + StatusOr GetInputToOutputIndexingForEntryComputation( + absl::string_view hlo_string, int input_id = 0) { + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* root = module->entry_computation()->root_instruction(); - return ComputeInstructionIndexing(root, operand_id, &mlir_context_); + for (auto* operand : root->operands()) { + TF_RET_CHECK(operand->opcode() == HloOpcode::kParameter || + operand->opcode() == HloOpcode::kConstant) + << "If there are multiple instructions, they need to be wrapped in a " + "fusion."; + } + return ComputeInputToOutputIndexing(root, input_id, &mlir_context_); } mlir::MLIRContext mlir_context_; }; TEST_F(TileAnalysisTest, ElementwiseOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + auto ir = R"( HloModule m ENTRY e { p0 = f32[10, 20] parameter(0) p1 = f32[10, 20] parameter(1) ROOT add0 = f32[10, 20] add(p0, p1) } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + )"; + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetOutputToInputIndexingForEntryComputation(ir)); + EXPECT_THAT(input_indexing.indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))), Pair(1, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))))); + TF_ASSERT_OK_AND_ASSIGN( + auto output_indexing, + GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/0)); + EXPECT_THAT(output_indexing.indexing_maps, + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); + TF_ASSERT_OK_AND_ASSIGN( + auto output_indexing1, + GetInputToOutputIndexingForEntryComputation(ir, /*input_id=*/1)); + EXPECT_THAT(output_indexing1.indexing_maps, + UnorderedElementsAre( + Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, BitcastIsReshape) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[4, 32] parameter(0) @@ -104,7 +132,7 @@ TEST_F(TileAnalysisTest, BitcastIsReshape) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre(Pair( 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> 
(d0, d1 * 4 + d2)", std::vector{}))))); @@ -112,7 +140,7 @@ TEST_F(TileAnalysisTest, BitcastIsReshape) { TEST_F(TileAnalysisTest, BitcastIsTranspose) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[3, 12288, 6, 128] parameter(0) @@ -120,38 +148,46 @@ TEST_F(TileAnalysisTest, BitcastIsTranspose) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", std::vector{}))))); } TEST_F(TileAnalysisTest, BitcastIsTransposeReshapeTranspose) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + auto ir = R"( HloModule m ENTRY e { p0 = f32[16, 17, 3] parameter(0) ROOT bitcast = f32[51, 16] {0, 1} bitcast(p0) } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + )"; + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetOutputToInputIndexingForEntryComputation(ir)); + EXPECT_THAT(input_indexing.indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1) -> (d1, d0 floordiv 3, d0 mod 3)", std::vector{}))))); + TF_ASSERT_OK_AND_ASSIGN(auto output_indexing, + GetInputToOutputIndexingForEntryComputation(ir)); + EXPECT_THAT( + output_indexing.indexing_maps, + UnorderedElementsAre(Pair( + 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1 * 3 + d2, d0)", + std::vector{}))))); } TEST_F(TileAnalysisTest, BroadcastOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[20] parameter(0) ROOT bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", std::vector{}))))); @@ -159,7 +195,7 @@ TEST_F(TileAnalysisTest, BroadcastOp) { TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[100] parameter(0) @@ -173,7 +209,7 @@ TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0) -> (d0)", std::vector{}))), Pair(1, ElementsAre(MatchIndexingMap( @@ -182,7 +218,7 @@ TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { TEST_F(TileAnalysisTest, FusionOpWithDot) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( f { p0 = s8[3,12288,6,128]{3,2,1,0} parameter(0) bitcast1 = s8[3,6,128,12288]{2,1,3,0} bitcast(p0) @@ -243,7 +279,7 @@ TEST_F(TileAnalysisTest, FusionOpWithDot) { EXPECT_TRUE(input_indexing.Simplify({16, 16, 3, 1, 6, 128})); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap("(d0, d1, d2, d3, d4, d5)[s0] -> " @@ -267,7 +303,7 @@ TEST_F(TileAnalysisTest, FusionOpWithDot) { TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[1000, 1000] 
parameter(0) @@ -280,7 +316,7 @@ TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre(Pair( 0, UnorderedElementsAre( @@ -290,7 +326,7 @@ TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { TEST_F(TileAnalysisTest, FusionExponentialDuplication) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule test_module fused_computation { @@ -311,7 +347,7 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { ROOT fusion = f32[2] fusion(p0, p1), kind=kLoop, calls=fused_computation })")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre( Pair(0, UnorderedElementsAre( @@ -327,7 +363,7 @@ TEST_F(TileAnalysisTest, FusionExponentialDuplication) { TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -349,7 +385,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0)[s0, s1, s2] -> (s0, s2, d0, s1)", std::vector{150, 50, 20}))))); @@ -357,7 +393,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -378,7 +414,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, UnorderedElementsAre(Pair( 0, ElementsAre(MatchIndexingMap("(d0, d1)[s0] -> (d0, s0)", std::vector{20}))))); @@ -386,7 +422,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[20, 10, 50] parameter(0) @@ -410,7 +446,7 @@ TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { ROOT fusion = f32[10, 50, 20] fusion(p0), kind=kLoop, calls=f } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d2, d0, d1)", std::vector{}))))); @@ -418,7 +454,7 @@ TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -439,7 +475,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { ROOT fusion = f32[32] fusion(p0, p0_init), kind=kLoop, calls=f } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap( "(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)", @@ -448,7 +484,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { 
TEST_F(TileAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[128] parameter(0) @@ -460,14 +496,14 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { ROOT fusion = f32[128] fusion(p0), kind=kLoop, calls=f } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0) -> (d0)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[8, 16] parameter(0) @@ -481,14 +517,14 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { )")); EXPECT_TRUE(input_indexing.Simplify({8, 16})); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))))); } TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[10, 10, 10] parameter(0) @@ -501,7 +537,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { } )")); EXPECT_TRUE(input_indexing.Simplify({10, 10, 10})); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0, d1, d2)", std::vector{}))))); @@ -509,7 +545,7 @@ TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m f { p0 = f32[150, 64, 1024] parameter(0) @@ -524,7 +560,7 @@ TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, ElementsAre( Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65)", @@ -533,7 +569,7 @@ TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[4,8] parameter(0) @@ -541,7 +577,7 @@ TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { } )")); EXPECT_FALSE(input_indexing.Simplify({32})); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0) -> (d0 floordiv 8, d0 mod 8)", std::vector{}))))); @@ -549,7 +585,7 @@ TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[32] parameter(0) @@ -557,33 +593,44 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { } )")); EXPECT_FALSE(input_indexing.Simplify({4, 8})); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair( 0, 
ElementsAre(MatchIndexingMap("(d0, d1) -> (d0 * 8 + d1)", std::vector{}))))); } TEST_F(TileAnalysisTest, ReshapeOpExpandAndCollapseShape) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + auto ir = R"( HloModule m ENTRY e { p0 = f32[4, 8, 12] parameter(0) ROOT reshape = f32[32, 3, 4] reshape(p0) } - )")); + )"; + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetOutputToInputIndexingForEntryComputation(ir)); EXPECT_FALSE(input_indexing.Simplify({32, 3, 4})); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, ElementsAre( Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)", std::vector{}))))); + + TF_ASSERT_OK_AND_ASSIGN(auto output_indexing, + GetInputToOutputIndexingForEntryComputation(ir)); + EXPECT_FALSE(output_indexing.Simplify({4, 8, 12})); + EXPECT_THAT( + output_indexing.indexing_maps, + ElementsAre( + Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4)", + std::vector{}))))); } TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[16, 8] parameter(0) @@ -591,7 +638,7 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { } )")); EXPECT_FALSE(input_indexing.Simplify({4, 4, 8})); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 * 4 + d1, d2)", std::vector{}))))); @@ -599,7 +646,7 @@ TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[4,8] parameter(0) @@ -608,7 +655,7 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { )")); EXPECT_TRUE(input_indexing.Simplify({2, 4, 4})); // TODO(b/313840171): Simplify `(d1 * 4 + d2) floordiv 8` to `d1 floordiv 2`. - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair( 0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, " @@ -618,7 +665,7 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[2, 4, 4] parameter(0) @@ -630,7 +677,7 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { // TODO(b/313840171): Simplify `((d0 * 8 + d1) mod 16) floordiv 4` to // `((d0 * 8 + d1) floordiv 4) mod 4` to `(d0 * 2 + d1 floordiv 4) mod 4`. 
EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1) -> ((d0 * 8 + d1) floordiv 16, " "((d0 * 8 + d1) mod 16) floordiv 4, d1 mod 4)", @@ -639,7 +686,7 @@ TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { TEST_F(TileAnalysisTest, ReduceOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m max { p0 = f32[] parameter(0) @@ -653,7 +700,7 @@ TEST_F(TileAnalysisTest, ReduceOp) { dimensions={3, 1}, to_apply=max } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1)[s0, s1] -> (d0, s0, d1, s1)", std::vector{20, 50}))))); @@ -685,20 +732,20 @@ TEST_F(TileAnalysisTest, VariadicReduceOp) { ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); - auto input_indexing_0 = ComputeInstructionIndexing(root, 0, &mlir_context_); + auto input_indexing_0 = ComputeOutputToInputIndexing(root, 0, &mlir_context_); ASSERT_IS_OK(input_indexing_0); EXPECT_THAT( - input_indexing_0->operand_indexing_maps, + input_indexing_0->indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", std::vector{256}))), Pair(1, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", std::vector{256}))))); - auto input_indexing_1 = ComputeInstructionIndexing(root, 1, &mlir_context_); + auto input_indexing_1 = ComputeOutputToInputIndexing(root, 1, &mlir_context_); ASSERT_IS_OK(input_indexing_1); EXPECT_THAT( - input_indexing_1->operand_indexing_maps, + input_indexing_1->indexing_maps, UnorderedElementsAre( Pair(0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", std::vector{256}))), @@ -707,17 +754,27 @@ TEST_F(TileAnalysisTest, VariadicReduceOp) { } TEST_F(TileAnalysisTest, ReverseOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + auto ir = R"( HloModule m ENTRY e { p0 = f32[1, 17, 9, 9] parameter(0) ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} } - )")); + )"; + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetOutputToInputIndexingForEntryComputation(ir)); EXPECT_FALSE(input_indexing.Simplify({1, 17, 9, 9})); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)", + std::vector{}))))); + + TF_ASSERT_OK_AND_ASSIGN(auto output_indexing, + GetInputToOutputIndexingForEntryComputation(ir)); + EXPECT_FALSE(output_indexing.Simplify({1, 17, 9, 9})); + EXPECT_THAT( + output_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)", std::vector{}))))); @@ -725,7 +782,7 @@ TEST_F(TileAnalysisTest, ReverseOp) { TEST_F(TileAnalysisTest, ReverseReshape) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m fused_computation { p0 = f32[10, 11] parameter(0) @@ -741,14 +798,14 @@ TEST_F(TileAnalysisTest, ReverseReshape) { )")); EXPECT_TRUE(input_indexing.Simplify({10, 11})); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))))); } TEST_F(TileAnalysisTest, SliceOp) { 
TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[10, 20, 50] parameter(0) @@ -757,38 +814,46 @@ TEST_F(TileAnalysisTest, SliceOp) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)", std::vector{}))))); } TEST_F(TileAnalysisTest, TransposeOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + auto ir = R"( HloModule m ENTRY e { p0 = f32[3, 12288, 6, 128] parameter(0) ROOT transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1} } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + )"; + TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, + GetOutputToInputIndexingForEntryComputation(ir)); + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", std::vector{}))))); + + TF_ASSERT_OK_AND_ASSIGN(auto output_indexing, + GetInputToOutputIndexingForEntryComputation(ir)); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( + "(d0, d1, d2, d3) -> (d0, d2, d3, d1)", + std::vector{}))))); } TEST_F(TileAnalysisTest, TransposeOp4D) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[3, 12288, 6, 128] parameter(0) ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) } )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, + EXPECT_THAT(input_indexing.indexing_maps, ElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", std::vector{}))))); @@ -796,7 +861,7 @@ TEST_F(TileAnalysisTest, TransposeOp4D) { TEST_F(TileAnalysisTest, DotOp) { TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( + GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[4, 38, 17, 11, 18, 10] parameter(0) @@ -807,7 +872,7 @@ TEST_F(TileAnalysisTest, DotOp) { } )")); EXPECT_THAT( - input_indexing.operand_indexing_maps, + input_indexing.indexing_maps, UnorderedElementsAre(Pair(0, ElementsAre(MatchIndexingMap( "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " "(d2, d1, s1, d3, s0, d0)", @@ -819,7 +884,7 @@ TEST_F(TileAnalysisTest, DotOp) { } TEST_F(TileAnalysisTest, UnsupportedOps) { - ASSERT_IS_NOT_OK(GetIndexingMapsForEntryComputation(R"( + ASSERT_IS_NOT_OK(GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { p0 = f32[1, 17, 9, 9] parameter(0) @@ -827,7 +892,7 @@ TEST_F(TileAnalysisTest, UnsupportedOps) { ROOT concat = f32[6, 17, 9, 9] concatenate(p0, p1) } )")); - ASSERT_IS_NOT_OK(GetIndexingMapsForEntryComputation(R"( + ASSERT_IS_NOT_OK(GetOutputToInputIndexingForEntryComputation(R"( HloModule m ENTRY e { input = s32[1,1,25,1] parameter(0) From 2cadb32cd81811fc0291e8e26ec28f5e27421711 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 4 Dec 2023 08:52:09 -0800 Subject: [PATCH 343/381] Changed Version of Bazel to version 6.4.0 PiperOrigin-RevId: 587735745 --- .bazelversion | 2 +- tensorflow/core/kernels/mlir_generated/build_defs.bzl | 1 - tensorflow/tools/ci_build/release/common.sh | 2 +- third_party/xla/.bazelversion | 2 +- third_party/xla/third_party/tsl/.bazelversion | 2 +- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.bazelversion b/.bazelversion index 204ac7c926e437..b536fbc5061305 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1,2 +1,2 @@ -6.4.0 +6.1.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index a3535c4ac93080..fe071b2375ec0e 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -217,7 +217,6 @@ _gen_kernel_bin_rule = rule( outputs = {"kernel": "%{name}_kernel.o"}, toolchains = use_cpp_toolchain(), implementation = _gen_kernel_bin_impl, - provides = [CcInfo], ) # Returns the shape string (e.g. "4x4" or "16Bx2") as comma-separated integers. diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh index 33cb30c1df1381..e5a490ac177e59 100644 --- a/tensorflow/tools/ci_build/release/common.sh +++ b/tensorflow/tools/ci_build/release/common.sh @@ -17,7 +17,7 @@ # Keeps Bazel versions of the build scripts. # LINT.IfChange -LATEST_BAZEL_VERSION=6.4.0 +LATEST_BAZEL_VERSION=6.1.0 # LINT.ThenChange( # //tensorflow/opensource_only/.bazelversion, # //tensorflow/tools/ci_build/install/install_bazel.sh, diff --git a/third_party/xla/.bazelversion b/third_party/xla/.bazelversion index 204ac7c926e437..b536fbc5061305 100644 --- a/third_party/xla/.bazelversion +++ b/third_party/xla/.bazelversion @@ -1,2 +1,2 @@ -6.4.0 +6.1.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/.bazelversion b/third_party/xla/third_party/tsl/.bazelversion index 204ac7c926e437..b536fbc5061305 100644 --- a/third_party/xla/third_party/tsl/.bazelversion +++ b/third_party/xla/third_party/tsl/.bazelversion @@ -1,2 +1,2 @@ -6.4.0 +6.1.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file From a93f9e96e798b01841d69439eed5889b09638761 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 10:12:30 -0800 Subject: [PATCH 344/381] Deduplicate some code by using the CreateAllStrategiesGroup to generate sharding strategies for iota and broadcast ops. 
PiperOrigin-RevId: 587762424 --- .../auto_sharding/auto_sharding.cc | 15 ++- .../auto_sharding/auto_sharding.h | 3 +- .../auto_sharding/auto_sharding_strategy.cc | 112 +++++++----------- 3 files changed, 55 insertions(+), 75 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index f33cf2954a7c12..7a9a54ae180bc9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1148,7 +1148,8 @@ StatusOr> CreateAllStrategiesGroup( const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, bool only_allow_divisible, - bool create_replicated_strategies) { + bool create_replicated_strategies, + bool create_partially_replicated_strategies) { std::unique_ptr strategy_group; if (shape.IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); @@ -1159,7 +1160,8 @@ StatusOr> CreateAllStrategiesGroup( strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, - create_replicated_strategies) + create_replicated_strategies, + create_partially_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); @@ -1167,9 +1169,12 @@ StatusOr> CreateAllStrategiesGroup( } else if (shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, only_allow_divisible, - "", call_graph); + if (create_partially_replicated_strategies || + cluster_env.IsDeviceMesh1D()) { + EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, + strategy_map, strategy_group, + only_allow_divisible, "", call_graph); + } // Split 2 dims if (cluster_env.IsDeviceMesh2D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 66f12dd37ad95a..f3010e0a7bb8e5 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -247,7 +247,8 @@ StatusOr> CreateAllStrategiesGroup( const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, bool only_allow_divisible, - bool create_replicated_strategies); + bool create_replicated_strategies, + bool create_partially_replicated_strategies); // Enumerates sharding strategies for elementwise operators by following // strategies of an operand of the elementwise op. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 8a2468fd4fb2b9..7dc2c70717afc6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include #include @@ -37,8 +37,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" -#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" @@ -75,7 +75,6 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes) { const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; StrategyMap strategy_map; // This map stores all of the trimmed strategies due to user specified // sharding. The key is the instruction id, the value is the strategies. This @@ -127,11 +126,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kRngBitGenerator: case HloOpcode::kRng: { strategy_group = - CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters) + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + option.allow_replicated_parameters, + /* create_partially_replicated_strategies */ true) .value(); break; } @@ -260,28 +260,18 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, break; } case HloOpcode::kBroadcast: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - - if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, "", call_graph); - } else { - EnumerateAllPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategy_group, - batch_dim_map, only_allow_divisible, call_graph, - /*partitions*/ 2); - if (option.allow_mixed_mesh_shape) { - EnumerateAll1DPartition(ins, ins->shape(), - cluster_env.device_mesh_1d_, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, "1d", call_graph); - } - } - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - + // For an unknown reason, we do not generate partially replicated + // strategies for >1D broadcast ops. This can be changed if we find that + // our search isn't exhaustive enough for certain ops. 
+ strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ + (ins->shape().rank() == 1)) + .value(); break; } case HloOpcode::kReshape: { @@ -561,38 +551,17 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, break; } case HloOpcode::kIota: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - if (cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, "", call_graph); - } - if (cluster_env.IsDeviceMesh2D()) { - // Split 2 dims - EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*parts*/ 2); - } - if (cluster_env.IsDeviceMesh3D()) { - // Split 3 dims - EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*parts*/ 3); - } - if (cluster_env.IsDeviceMesh2D() && option.allow_mixed_mesh_shape) { - // Split 1 dim, but for 1d flattened version of the 2d mesh - // For example, when the mesh shape is (2, 4), we add strategies for - // mesh shape (1, 8) here in addition. - EnumerateAll1DPartition(ins, ins->shape(), device_mesh_1d, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, " 1d", call_graph); - } - - // Replicate - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty * 5); - + // For an unknown reason, we do not generate partially replicated + // strategies for iota ops. This can be changed if we find that our + // search isn't exhaustive enough for certain ops. 
+ strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ false) + .value(); break; } case HloOpcode::kTuple: { @@ -649,7 +618,9 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, true) + batch_dim_map, call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true) .value(); } } else { @@ -664,7 +635,9 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, true) + batch_dim_map, call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true) .value(); } } @@ -723,11 +696,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kInfeed: case HloOpcode::kSort: { strategy_group = - CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - /*create_replicated_strategies*/ true) + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true) .value(); break; } From 4ef66d1427089249cf46a32457effe873c266836 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 4 Dec 2023 10:22:03 -0800 Subject: [PATCH 345/381] [xla:gpu] Add support for CustomFusion matched pattern replacements to replace intermediate instructions PiperOrigin-RevId: 587765924 --- .../xla/service/gpu/custom_fusion_rewriter.cc | 111 ++++++++++++------ third_party/xla/xla/service/gpu/kernels/BUILD | 5 + .../gpu/kernels/custom_fusion_pattern.cc | 36 ++++++ .../gpu/kernels/custom_fusion_pattern.h | 40 ++++++- .../gpu/kernels/cutlass_gemm_fusion.cc | 16 ++- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 59 ++++++++++ 6 files changed, 225 insertions(+), 42 deletions(-) diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc index df279f901beb71..622db91c1abea9 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/custom_fusion_rewriter.h" #include +#include #include #include @@ -23,18 +24,20 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" -#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { @@ -43,35 +46,50 @@ CustomFusionRewriter::CustomFusionRewriter( const CustomFusionPatternRegistry* patterns) : device_(device), patterns_(patterns) {} -// Returns instructions that have to become custom fusion parameters. Returns an -// error if matched pattern can't be outlined as a fusion. -static StatusOr> GetPatternCaptures( - const CustomFusionPattern::Match& match) { - HloInstruction* root = match.instructions.back(); - absl::InlinedVector captures; - - // Instruction that will go into the fusion body. +// Returns a set of instruction that have users outside of a matched pattern +// and have a replacement that must be applied after building a new custom +// fusion instruction. Only root instruction can have external users and does +// not require a replacement, as the fusion itself is a replacement. If +// instruction has external users and does not have a replacement returns empty +// optional. +static std::optional> +GetPatternReplacements(const CustomFusionPattern::Match& match) { + absl::flat_hash_set requires_replacement; absl::flat_hash_set instructions_set( - match.instructions.begin(), match.instructions.end()); + match.instructions().begin(), match.instructions().end()); - // Check that intermediate instructions do not have users outside of the - // matched pattern. Only root instruction can have external users. - for (HloInstruction* instr : match.instructions) { + for (HloInstruction* instr : match.instructions()) { for (HloInstruction* user : instr->users()) { - if (instr != root && !instructions_set.contains(user)) { - return absl::InvalidArgumentError(absl::StrCat( - "Custom fusion intermediate result ", instr->name(), - " has users outside of a matched pattern: ", user->name())); + if (instr == match.root() || instructions_set.contains(user)) continue; + + if (match.HasReplacement(instr)) { + requires_replacement.insert(instr); + continue; } + + VLOG(3) << "Custom fusion intermediate result " << instr->name() + << " has users outside of a matched pattern: " << user->name(); + return std::nullopt; } } - // Collect instructions captured by a matched pattern. - for (HloInstruction* instr : match.instructions) { + return requires_replacement; +} + +// Returns instructions that have to become custom fusion parameters. Returns an +// error if matched pattern can't be outlined as a fusion. 
+static absl::InlinedVector GetPatternCaptures( + const CustomFusionPattern::Match& match) { + absl::InlinedVector captures; + + absl::flat_hash_set instructions_set( + match.instructions().begin(), match.instructions().end()); + + for (HloInstruction* instr : match.instructions()) { for (HloInstruction* operand : instr->operands()) { if (!instructions_set.contains(operand) && absl::c_find(captures, operand) == captures.end()) { - captures.push_back(operand); + captures.emplace_back(operand); } } } @@ -83,7 +101,7 @@ static StatusOr> GetPatternCaptures( static StatusOr CreateFusionBody( HloModule* module, const CustomFusionPattern::Match& match, absl::Span captures) { - HloComputation::Builder builder(match.config.name()); + HloComputation::Builder builder(match.config().name()); // A mapping from original instructions to instructions in the fusion body. absl::flat_hash_map instr_mapping; @@ -107,7 +125,7 @@ static StatusOr CreateFusionBody( // TODO(ezhulenev): Instructions in the pattern must be topologically sorted, // otherwise we'll get a crash! Figure out how to do it! - for (HloInstruction* instr : match.instructions) { + for (HloInstruction* instr : match.instructions()) { instr_mapping[instr] = builder.AddInstruction( instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); } @@ -120,21 +138,20 @@ static StatusOr CreateFusionInstruction( absl::Span captures, HloComputation* body) { // We'll be replacing the root operation of a custom fusion with a fusion // instruction calling fusion computation. - HloInstruction* root = match.instructions.back(); + HloInstruction* root = match.root(); HloComputation* parent = root->parent(); // Add a fusion operation calling outlined fusion computation. HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion( root->shape(), HloInstruction::FusionKind::kCustom, captures, body)); - module->SetAndUniquifyInstrName(fusion, match.config.name()); + module->SetAndUniquifyInstrName(fusion, match.config().name()); // Set backends config to a matched custom fusion config. FusionBackendConfig backend_config; backend_config.set_kind("__custom_fusion"); - *backend_config.mutable_custom_fusion_config() = match.config; + *backend_config.mutable_custom_fusion_config() = match.config(); TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(backend_config))); - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(root, fusion)); return fusion; } @@ -154,25 +171,43 @@ StatusOr CustomFusionRewriter::Run( if (matches.empty()) return false; for (const CustomFusionPattern::Match& match : matches) { - // Check if pattern can be outlined as a fusion and collect captured - // parameters (instructions defined outside of a fusion). 
+ VLOG(2) << "Matched custom fusion " << match.config().name() + << "; root instruction: " << match.instructions().back()->name(); + + auto replacememts = GetPatternReplacements(match); + if (!replacememts.has_value()) continue; + auto captures = GetPatternCaptures(match); - if (!captures.ok()) { - VLOG(2) << "Skip custom fusion " << match.config.name() << ": " - << captures.status(); - continue; - } TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, match, *captures)); - + CreateFusionBody(module, match, captures)); TF_ASSIGN_OR_RETURN( HloInstruction * fusion, - CreateFusionInstruction(module, match, *captures, fusion_body)); + CreateFusionInstruction(module, match, captures, fusion_body)); VLOG(2) << "Added a fusion instruction: " << fusion->name() - << " for custom fusion " << match.config.name() - << " (instruction count = " << match.instructions.size() << ")"; + << " for custom fusion " << match.config().name() + << " (instruction count = " << match.instructions().size() << ")"; + + for (HloInstruction* instr : *replacememts) { + VLOG(2) << "Replace matched instruction: " << instr->name() + << " with a pattern replacement"; + + TF_ASSIGN_OR_RETURN( + HloInstruction * replacement, + match.BuildReplacement(instr, Cast(fusion))); + + TF_RETURN_IF_ERROR( + instr->ReplaceAllUsesWith(replacement, match.config().name())); + + VLOG(2) << "Replaced instruction: " << instr->name() + << " with: " << replacement->name(); + } + + VLOG(2) << "Replace custom fusion root instruction " << match.root()->name() + << "with " << fusion->name(); + HloComputation* parent = match.root()->parent(); + TF_RETURN_IF_ERROR(parent->ReplaceInstruction(match.root(), fusion)); } return true; diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index efef54862054d7..3c26363e03669e 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -42,10 +42,15 @@ cc_library( hdrs = ["custom_fusion_pattern.h"], visibility = ["//visibility:public"], deps = [ + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc index 811b29ffb198e7..6da062c1363b39 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc @@ -19,11 +19,47 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla::gpu { +//===----------------------------------------------------------------------===// +// CustomFusionPattern::Match +//===----------------------------------------------------------------------===// + +CustomFusionPattern::Match::Match(CustomFusionConfig config, + std::vector instructions) + : config_(std::move(config)), instructions_(std::move(instructions)) {} + +void CustomFusionPattern::Match::AddReplacement(HloInstruction* instr, + Replacement replacement) { + replacements_[instr] = std::move(replacement); +} + +bool CustomFusionPattern::Match::HasReplacement(HloInstruction* instr) const { + return replacements_.contains(instr); +} + +StatusOr CustomFusionPattern::Match::BuildReplacement( + HloInstruction* instr, HloFusionInstruction* fusion) const { + if (auto it = replacements_.find(instr); it != replacements_.end()) { + return it->second(fusion); + } + + return absl::InvalidArgumentError( + absl::StrCat("no replacement for instruction: ", instr->name())); +} + +//===----------------------------------------------------------------------===// +// CustomFusionPatternRegistry +//===----------------------------------------------------------------------===// + CustomFusionPatternRegistry* CustomFusionPatternRegistry::Default() { static auto* registry = new CustomFusionPatternRegistry(); return registry; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h index cc90abccfae81c..308960902e823d 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h @@ -16,14 +16,19 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ #define XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ +#include #include #include #include #include #include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -37,9 +42,38 @@ class CustomFusionPattern { public: virtual ~CustomFusionPattern() = default; - struct Match { - CustomFusionConfig config; - std::vector instructions; + // Matched sequence of instructions that can be handled by a custom fusion. + class Match { + public: + Match(CustomFusionConfig config, + std::vector instructions); + + // If some of operations matched by a pattern have users outside of the + // custom fusion, pattern can optionally provide a replacement that can be + // derived from the fusion instruction result, or from other instructions in + // the parent computation. + using Replacement = + std::function(HloFusionInstruction *)>; + + void AddReplacement(HloInstruction *instr, Replacement replacement); + bool HasReplacement(HloInstruction *instr) const; + + // Builds a replacement for `instr` using a `fusion` instruction constructed + // for a pattern match. 
+ StatusOr BuildReplacement( + HloInstruction *instr, HloFusionInstruction *fusion) const; + + const CustomFusionConfig &config() const { return config_; } + absl::Span instructions() const { + return instructions_; + } + + HloInstruction *root() const { return instructions_.back(); } + + private: + CustomFusionConfig config_; + std::vector instructions_; + absl::flat_hash_map replacements_; }; // Returns custom fusion config and a list of instructions that matched to a diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 6fb374742ada09..1e3e794f871c4f 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -173,7 +173,21 @@ CutlassGemmWithDynamicUpdateSlicePattern::TryMatch( CustomFusionConfig config; config.set_name("cutlass_gemm_with_dynamic_update_slice"); - return Match{config, matched->Instrs()}; + Match match(config, matched->Instrs()); + + // Add an optional replacement for intermediate dot instruction as a + // dynamic-slice from the fusion result. + match.AddReplacement(matched->dot, [=](HloFusionInstruction* fusion) { + HloComputation* parent = fusion->parent(); + auto* dus = Cast(matched->update_slice); + auto* slice = parent->AddInstruction(HloInstruction::CreateDynamicSlice( + matched->bitcast->shape(), fusion, dus->index_operands(), + matched->bitcast->shape().dimensions())); + return parent->AddInstruction( + HloInstruction::CreateBitcast(matched->dot->shape(), slice)); + }); + + return match; } std::optional diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 489917eac04d1c..07818d618f9d89 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -165,6 +165,65 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) { + const char* hlo = R"( + HloModule test + + ENTRY %main { + %p0 = f32[2,2,2]{2,1,0} parameter(0) + %p1 = f32[2,2]{1,0} parameter(1) + %i = s32[] parameter(2) + + %dot = f32[2,2]{1,0} dot(%p1, %p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + %add = f32[2,2]{1,0} add(%dot, %dot) + + %cast = f32[1,2,2]{2,1,0} bitcast(%dot) + %dus = f32[2,2,2]{2,1,0} dynamic-update-slice(%p0, %cast, %i, %i, %i) + + ROOT %r = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(%add, %dus) + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2,2]{2,1,0} parameter + ; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter + ; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]]) + ; CHECK-DAG: [[CAST:%[^ ]+]] = f32[1,2,2]{2,1,0} bitcast([[DOT]]) + ; CHECK: ROOT [[DUS:%[^ ]+]] = f32[2,2,2]{2,1,0} dynamic-update-slice( + ; CHECK: [[P1]], [[CAST]], [[P2]], [[P2]], [[P2]] + ; CHECK: ) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: [[OFFSET:%[^ ]+]] = s32[] parameter(2) + ; CHECK: [[FUSION:%[^ ]+]] = f32[2,2,2]{2,1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{ + ; CHECK: 
"name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: } + ; CHECK: } + ; CHECK: [[SLICE:%[^ ]+]] = f32[1,2,2]{2,1,0} dynamic-slice( + ; CHECK: [[FUSION]], [[OFFSET]], [[OFFSET]], [[OFFSET]]), + ; CHECK: dynamic_slice_sizes={1,2,2} + ; CHECK: [[CAST:%[^. ]+]] = f32[2,2]{1,0} bitcast([[SLICE]]) + ; CHECK: [[ADD:%[^. ]+]] = f32[2,2]{1,0} add([[CAST]], [[CAST]]) + ; CHECK: } + )"; + + CustomFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + //===----------------------------------------------------------------------===// // Run And Compare Tests //===----------------------------------------------------------------------===// From 659c9ad359c7b2acc4829abaffb3863767d3cd42 Mon Sep 17 00:00:00 2001 From: Edward Schwartz Date: Mon, 4 Dec 2023 10:22:14 -0800 Subject: [PATCH 346/381] Add PjRtTensorBuffer support to GPUUtil::CopyGPUTensorToCPU A buffer used for GPU-to-CPU copies for the PjRt XLA case now has proper lifetime to avoid a race condition that caused incorrect (stale) numerical results with MultiWorkerMirroredStrategy. PiperOrigin-RevId: 587765995 --- tensorflow/core/common_runtime/gpu/BUILD | 3 +- .../core/common_runtime/gpu/gpu_util.cc | 41 ++++++++++++++++--- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index b99b41729218df..24e9fd15877e9f 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -193,7 +193,8 @@ tf_cuda_library( "@local_xla//xla/stream_executor/gpu:gpu_init_impl", ] + if_google( # TODO(b/282068262): PJRT pulls in TFRT components that are incompatible with ARM platform. - # Clean up so that PJRT can run on ARM. + # Clean up so that PJRT can run on ARM (and remove "#if defined(PLATFORM_GOOGLE) ..." use + # from gpu_util.cc). # Also it won't build with WeightWatcher which tracks OSS build binaries. # TODO(b/290533709): Clean up this build rule. selects.with_or({ diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index b699239fdb979b..a04d8e50c088c6 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -17,6 +17,23 @@ limitations under the License. #include #include +#include +#include + +#include "absl/status/status.h" + +// TODO(b/282059652): Merge google internal and open-source code path once TF +// dependency issue is resolved. 
+#if (defined(PLATFORM_GOOGLE) && defined(TF_PLATFORM_LINUX_X86_64)) +#define TF_GPU_USE_PJRT +#endif // PLATFORM_GOOGLE && TF_PLATFORM_LINUX_X86_64 + +#ifdef TF_GPU_USE_PJRT +#include "tensorflow/compiler/jit/pjrt_tensor_buffer.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_future.h" +#endif // TF_GPU_USE_PJRT #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device.h" @@ -272,12 +289,27 @@ bool NeedStaging(const Tensor* tensor) { } // namespace -// static void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, const DeviceContext* device_context, const Tensor* gpu_tensor, Tensor* cpu_tensor, StatusCallback done) { - VLOG(1) << "CopyGPUTensorToCPU"; +#ifdef TF_GPU_USE_PJRT + const PjRtTensorBuffer* pjrt_tensor_buffer = + dynamic_cast(DMAHelper::buffer(gpu_tensor)); + if (pjrt_tensor_buffer != nullptr) { + VLOG(1) << "CopyGPUTensorToCPU using PjRtTensorBuffer"; + auto literal = std::make_unique(); + auto status = tensorflow::HostTensorToMutableBorrowingLiteral( + cpu_tensor, literal.get()); + xla::PjRtFuture future = + pjrt_tensor_buffer->pjrt_buffer()->ToLiteral(literal.get()); + future.OnReady([literal = std::move(literal), + done](const tensorflow::Status& status) { done(status); }); + return; + } +#endif // TF_GPU_USE_PJRT + + VLOG(1) << "CopyGPUTensorToCPU using AcceleratorDeviceInfo"; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; Status s = PrepareCopy(gpu_device, device_context, *gpu_tensor, cpu_tensor, @@ -291,7 +323,7 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, static_cast(device_context) ->device_to_host_stream(); if (send_device_to_host_stream == nullptr) { - done(errors::Internal("No send gpu copy-out-stream is available.")); + done(absl::InternalError("No send gpu copy-out-stream is available.")); return; } // Wait for the sender's main stream to make sure the data are available. @@ -310,14 +342,13 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, send_device_to_host_stream, [send_device_to_host_stream, done, input_ref]() { if (!send_device_to_host_stream->ok()) { - LOG(FATAL) << "GPU->CPU Memcpy failed"; + LOG(FATAL) << "GPU->CPU Memcpy failed"; // Crash OK } input_ref.Unref(); done(OkStatus()); }); } -/* static */ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, const DeviceContext* device_context, Device* gpu_device, Tensor* gpu_tensor, From 97afb56fc164334ae758e080788f9c95d0dcddab Mon Sep 17 00:00:00 2001 From: Robert David Date: Mon, 4 Dec 2023 10:40:45 -0800 Subject: [PATCH 347/381] Reorganize code and have separate `AlignedAlloc`, `AlignedFree`, and `AlignedRealloc` functions. PiperOrigin-RevId: 587772398 --- tensorflow/lite/simple_memory_arena.cc | 74 ++++++++++++++++---------- tensorflow/lite/simple_memory_arena.h | 12 +++-- 2 files changed, 55 insertions(+), 31 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index d2885f4bc8cd84..d8cd854a608177 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -39,10 +39,47 @@ T AlignTo(size_t alignment, T offset) { : offset + (alignment - offset % alignment); } -size_t RequiredAllocationSize(size_t data_array_size, size_t alignment) { - return data_array_size + alignment - 1; +// Allocates memory and aligns it to the specified size. Returns a pair of the +// allocation pointer and the aligned pointer. 
+tflite::PointerAlignedPointerPair AlignedAlloc(size_t size, size_t alignment) { + const size_t allocation_size = size + alignment - 1; + char* pointer = reinterpret_cast(std::malloc(allocation_size)); +#if defined(__clang__) +#if __has_feature(memory_sanitizer) + std::memset(pointer, 0, allocation_size); +#endif +#endif + char* aligned_ptr = reinterpret_cast( + AlignTo(alignment, reinterpret_cast(pointer))); + return {pointer, aligned_ptr}; +} + +// Frees up aligned memory. +void AlignedFree(const tflite::PointerAlignedPointerPair& buffer) { + std::free(buffer.pointer); } +// Reallocates aligned memory +// +// The function either extends the memory allocation in-place, or if that is not +// possible a new allocation is created, the data is copied, and the old buffer +// is deallocated. It is an error to change the alignment during reallocation. +// If the previous allocation is null, this is equivalent to AlignedAlloc. +// Returns pointers to the new allocation. +tflite::PointerAlignedPointerPair AlignedRealloc( + const tflite::PointerAlignedPointerPair& old_buffer, size_t old_size, + size_t new_size, size_t alignment) { + tflite::PointerAlignedPointerPair new_buffer = + AlignedAlloc(new_size, alignment); + if (new_size > 0 && old_size > 0) { + // Copy data when both old and new buffers are bigger than 0 bytes. + const size_t copy_amount = std::min(new_size, old_size); + std::memcpy(new_buffer.aligned_pointer, old_buffer.aligned_pointer, + copy_amount); + } + AlignedFree(old_buffer); + return new_buffer; +} } // namespace namespace tflite { @@ -56,50 +93,33 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { PauseHeapMonitoring(/*pause=*/true); OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), new_size); -#endif - const size_t new_allocation_size = - RequiredAllocationSize(new_size, alignment_); - char* new_buffer = reinterpret_cast(std::malloc(new_allocation_size)); -#if defined(__clang__) -#if __has_feature(memory_sanitizer) - std::memset(new_buffer, 0, new_allocation_size); -#endif -#endif - char* new_aligned_ptr = reinterpret_cast( - AlignTo(alignment_, reinterpret_cast(new_buffer))); - if (new_size > 0 && data_size_ > 0) { - // Copy data when both old and new buffers are bigger than 0 bytes. 
- const size_t copy_amount = std::min(new_size, data_size_); - std::memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); - } - std::free(buffer_); - buffer_ = new_buffer; - aligned_ptr_ = new_aligned_ptr; -#ifdef TF_LITE_TENSORFLOW_PROFILER if (data_size_ > 0) { OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), data_size_); } #endif + auto new_buffer = AlignedRealloc(buffer_, data_size_, new_size, alignment_); + bool reallocated = (new_buffer.aligned_pointer != buffer_.aligned_pointer); + buffer_ = new_buffer; data_size_ = new_size; #ifdef TF_LITE_TENSORFLOW_PROFILER PauseHeapMonitoring(/*pause=*/false); #endif - return true; + return reallocated; } void ResizableAlignedBuffer::Release() { - if (buffer_ == nullptr) { + if (buffer_.pointer == nullptr) { return; } #ifdef TF_LITE_TENSORFLOW_PROFILER OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), data_size_); #endif - std::free(buffer_); - buffer_ = nullptr; + AlignedFree(buffer_); + buffer_.pointer = nullptr; + buffer_.aligned_pointer = nullptr; data_size_ = 0; - aligned_ptr_ = nullptr; } void SimpleMemoryArena::PurgeAfter(int32_t node) { diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 05a26ccc20d27d..7275b3014f3660 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -53,10 +53,15 @@ struct ArenaAllocWithUsageInterval { } }; +struct PointerAlignedPointerPair { + char* pointer; + char* aligned_pointer; +}; + class ResizableAlignedBuffer { public: ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : buffer_(nullptr), + : buffer_{nullptr, nullptr}, data_size_(0), alignment_(alignment), subgraph_index_(subgraph_index) { @@ -75,7 +80,7 @@ class ResizableAlignedBuffer { void Release(); // Pointer to the data array. - char* GetPtr() const { return aligned_ptr_; } + char* GetPtr() const { return buffer_.aligned_pointer; } // Size of the data array. Note: the allocated memory block might be larger // due to excess alignment requirements. size_t GetSize() const { return data_size_; } @@ -88,10 +93,9 @@ class ResizableAlignedBuffer { ResizableAlignedBuffer(ResizableAlignedBuffer&&) = delete; ResizableAlignedBuffer& operator=(ResizableAlignedBuffer&&) = delete; - char* buffer_; + PointerAlignedPointerPair buffer_; size_t data_size_; size_t alignment_; - char* aligned_ptr_; int subgraph_index_; }; From 5c04562a352c0087dc1270f5ac4c6936a526cd95 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 4 Dec 2023 11:45:05 -0800 Subject: [PATCH 348/381] Enable all tests on windows PiperOrigin-RevId: 587792813 --- .../xla/third_party/tsl/.kokoro/windows/windows_build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh b/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh index 331efa186fb87e..4f4b0a0fdf9d31 100644 --- a/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh +++ b/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh @@ -50,7 +50,7 @@ export PATH="$PATH:/c/Python38" -- //tsl/... \ || { echo "Bazel Build Failed" && exit 1; } -# Test TSL TODO(ddunleavy) enable all tests +# Test TSL /c/tools/bazel.exe test \ --output_filter="" \ --flaky_test_attempts=3 \ @@ -60,7 +60,7 @@ export PATH="$PATH:/c/Python38" --build_tag_filters=$TAGS_FILTER \ --test_tag_filters=$TAGS_FILTER \ --keep_going \ - -- //tsl/... 
-//tsl/platform:subprocess_test -//tsl/platform/cloud:google_auth_provider_test -//tsl/platform/cloud:oauth_client_test \ + -- //tsl/... \ || { echo "Bazel Test Failed" && exit 1; } exit 0 From 147708254e68177bcb431939ce5758f18caf1306 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 4 Dec 2023 11:49:15 -0800 Subject: [PATCH 349/381] Use `str | None` rather than `Optional[str]` FORCE_TEST_ACTIONS PiperOrigin-RevId: 587793944 --- third_party/xla/build_tools/lint/check_contents.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/build_tools/lint/check_contents.py b/third_party/xla/build_tools/lint/check_contents.py index 1649152148d1a4..5d09ec074b3b1e 100644 --- a/third_party/xla/build_tools/lint/check_contents.py +++ b/third_party/xla/build_tools/lint/check_contents.py @@ -22,7 +22,7 @@ import logging # Intended to run on vanilla Github Actions runner import re import sys -from typing import Iterable, Optional, Sequence +from typing import Iterable, Sequence from xla.build_tools.lint import diff_parser @@ -92,7 +92,7 @@ def check_diffs( hunks: Iterable[diff_parser.Hunk], *, prohibited_regex: str, - suppression_regex: Optional[str] = None, # TODO(ddunleavy): CI not on 3.10 + suppression_regex: str | None = None, ) -> list[RegexLocation]: """Checks FileDiffs for prohibited regexes. From 1e6ad9a4fb5dad0229c7e30546d7bff04cdd4bee Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Mon, 4 Dec 2023 12:08:02 -0800 Subject: [PATCH 350/381] Add logging for DTensorDevice::Execute inputs. PiperOrigin-RevId: 587799882 --- tensorflow/dtensor/cc/dtensor_device.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index 81d03d14ace24e..b9ddf25fde6caf 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -2384,6 +2384,9 @@ void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, absl::flat_hash_set input_meshes; std::vector single_device_input_indices; + VLOG(4) << "DTensorOperation: " << dtensor_operation.name + << " num_inputs are " << num_inputs; + typed_inputs.resize(num_inputs); for (int j = 0; j < num_inputs; ++j) { TFE_TensorHandle* input = inputs[j]; @@ -2392,6 +2395,8 @@ void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, if (name_ != input_device) { single_device_input_indices.push_back(j); typed_inputs[j] = nullptr; + VLOG(5) << "Input " << j << ": " + << tensorflow::unwrap(input)->DebugString(); continue; } // Handle input which is on DTensor device already. @@ -2404,10 +2409,15 @@ void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, input_meshes.insert(t->layout().mesh()); } typed_inputs[j] = t; + VLOG(5) << "Input " << j << ": " << typed_inputs[j]->DebugString(); } const std::optional mesh = ChooseBroadcastingMesh(input_meshes, dtypes); + VLOG(4) << "Execution DTensorOperation: " << dtensor_operation.name + << " with broadcast mesh " + << (mesh.has_value() ? mesh->ToString() : "no broadcast mesh"); + // TODO(feyu): This short circuit only allows running unsupported op // via DTensorDevice in eager mode. for tf.function and its graph, we will // need to build single device mesh placement rules in mesh propagation. From 00b81b7ad37984933a2e6ef2c294b67da0c063a5 Mon Sep 17 00:00:00 2001 From: Wilsin Gosti Date: Mon, 4 Dec 2023 12:17:50 -0800 Subject: [PATCH 351/381] #tf-data Annotate IO read statistics to xprof to aid user debugging. 
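The annotations surface the host IO read distributions (round-trip latency in microseconds and bytes per read) alongside the existing tf.data iterator metadata. A rough sketch of how each entry is rendered, using the same Printf format as the iterator change below (the example values are made up):

    // Assumes the tsl::port::GetIOStatistics() helper introduced in this change.
    const auto io_statistics = tsl::port::GetIOStatistics();
    if (io_statistics.roundtrip_latency_usec.count > 0) {
      // Produces e.g. "read_latency_usec: (count: 1200, mean: 850, std dev: 120)".
      std::string value = strings::Printf(
          "(count: %lld, mean: %lld, std dev: %lld)",
          static_cast<long long>(io_statistics.roundtrip_latency_usec.count),
          static_cast<long long>(io_statistics.roundtrip_latency_usec.mean),
          static_cast<long long>(io_statistics.roundtrip_latency_usec.std_dev));
    }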
PiperOrigin-RevId: 587803069 --- tensorflow/core/data/root_dataset.cc | 24 +++++++++++++++++++ tensorflow/core/platform/host_info.h | 1 + .../tsl/tsl/platform/default/port.cc | 6 +++-- .../third_party/tsl/tsl/platform/host_info.h | 18 ++++++++++++++ .../tsl/tsl/platform/windows/port.cc | 5 ++-- 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 55ff2bc8122213..bba8a426366329 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringprintf.h" +#include "tsl/platform/host_info.h" namespace tensorflow { namespace data { @@ -46,6 +47,8 @@ constexpr char kDatasetType[] = "Root"; constexpr char kAlgorithm[] = "algorithm"; constexpr char kCpuBudget[] = "cpu_budget"; constexpr char kExperiments[] = "experiments"; +constexpr char kReadRoundtripLatency[] = "read_latency_usec"; +constexpr char kReadResponseBytes[] = "read_bytes"; constexpr char kIntraOpParallelism[] = "intra_op_parallelism"; constexpr char kMemBandwidth[] = "mem_bw_used_megabytes_per_sec"; constexpr char kPrivateThreadpoolSize[] = "threadpool_size"; @@ -277,6 +280,27 @@ class RootDataset::Iterator : public DatasetIterator { "%lld", static_cast( model_node()->TotalMaximumBufferedBytes() / 1.0e6)))); } + const auto io_statistics = tsl::port::GetIOStatistics(); + if (io_statistics.roundtrip_latency_usec.count > 0) { + traceme_metadata.push_back(std::make_pair( + kReadRoundtripLatency, + strings::Printf( + "(count: %lld, mean: %lld, std dev: %lld)", + static_cast( + io_statistics.roundtrip_latency_usec.count), + static_cast(io_statistics.roundtrip_latency_usec.mean), + static_cast( + io_statistics.roundtrip_latency_usec.std_dev)))); + } + if (io_statistics.response_bytes.count > 0) { + traceme_metadata.push_back(std::make_pair( + kReadResponseBytes, + strings::Printf( + "(count: %lld, mean: %lld, std dev: %lld)", + static_cast(io_statistics.response_bytes.count), + static_cast(io_statistics.response_bytes.mean), + static_cast(io_statistics.response_bytes.std_dev)))); + } return traceme_metadata; } diff --git a/tensorflow/core/platform/host_info.h b/tensorflow/core/platform/host_info.h index 89d495d6e41229..caab7ae380b31b 100644 --- a/tensorflow/core/platform/host_info.h +++ b/tensorflow/core/platform/host_info.h @@ -22,6 +22,7 @@ limitations under the License. namespace tensorflow { namespace port { using tsl::port::Hostname; +using tsl::port::IOStatistics; using tsl::port::JobName; using tsl::port::JobUid; } // namespace port diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc index c2151c78ec5330..868fb35f887dab 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "absl/base/internal/sysinfo.h" #include "tsl/platform/cpu_info.h" +#include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/numa.h" @@ -256,7 +257,6 @@ int NUMAGetThreadNodeAffinity() { return node_index; } - void* NUMAMalloc(int node, size_t size, int minimum_alignment) { #ifdef TENSORFLOW_USE_NUMA if (HaveHWLocTopology()) { @@ -307,7 +307,6 @@ int NUMAGetMemAffinity(const void* addr) { return node; } - bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -447,5 +446,8 @@ MemoryBandwidthInfo GetMemoryBandwidthInfo() { MemoryBandwidthInfo membw_info = {INT64_MAX}; return membw_info; } + +IOStatistics GetIOStatistics() { return IOStatistics(); } + } // namespace port } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/host_info.h b/third_party/xla/third_party/tsl/tsl/platform/host_info.h index 189f3be2934ce3..630f9424525e04 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/host_info.h +++ b/third_party/xla/third_party/tsl/tsl/platform/host_info.h @@ -16,11 +16,26 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ #define TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ +#include + #include "tsl/platform/types.h" namespace tsl { namespace port { +// Statistical data of IO operations performed by the job. +struct IOStatistics { + struct Distribution { + uint64_t count = 0; + double mean = 0.0; + double std_dev = 0.0; + }; + // Distribution of round trip IO latency in microseconds. + Distribution roundtrip_latency_usec; + // Distribution of data received by IO reads in bytes. + Distribution response_bytes; +}; + // Return the hostname of the machine on which this process is running. string Hostname(); @@ -34,6 +49,9 @@ int64_t JobUid(); // Returns the Borg task ID as an int64_t if it exists. Otherwise return -1. int64_t TaskId(); +// Retrieves the host file read statistics. +IOStatistics GetIOStatistics(); + } // namespace port } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc index 9b5692650dbb5c..f8e19503edb305 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc @@ -61,6 +61,8 @@ int64_t JobUid() { return -1; } int64_t TaskId() { return -1; } +IOStatistics GetIOStatistics() { return IOStatistics(); } + int NumSchedulableCPUs() { SYSTEM_INFO system_info; GetSystemInfo(&system_info); @@ -122,7 +124,6 @@ void NUMAFree(void* ptr, size_t size) { tsl::port::Free(ptr); } int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; } - bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -183,7 +184,7 @@ string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { DWORD data; DWORD data_size = sizeof(data); - #pragma comment(lib, "shlwapi.lib") // For SHGetValue(). +#pragma comment(lib, "shlwapi.lib") // For SHGetValue(). if (SUCCEEDED( SHGetValueA(HKEY_LOCAL_MACHINE, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", From 62cc349b3ca3087ab71c82ca0dd6f3af4ef84eeb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 12:19:05 -0800 Subject: [PATCH 352/381] [XLA] Make PropagateLivenessThroughTuple O(N). 
PiperOrigin-RevId: 587803560 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/hlo_liveness_analysis.cc | 45 +++++++++++-------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 18c5db752de9de..90c79e4f532142 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4484,6 +4484,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/service/hlo_liveness_analysis.cc b/third_party/xla/xla/service/hlo_liveness_analysis.cc index 91ec0a6f8d7470..b24dd6678e905f 100644 --- a/third_party/xla/xla/service/hlo_liveness_analysis.cc +++ b/third_party/xla/xla/service/hlo_liveness_analysis.cc @@ -15,18 +15,22 @@ limitations under the License. #include "xla/service/hlo_liveness_analysis.h" +#include +#include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/call_graph.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/types.h" @@ -116,25 +120,28 @@ void PropagateLivenessThroughTuple( HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, Workset* workset) { CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); - for (int64_t operand_index = 0; operand_index < instruction->operand_count(); - ++operand_index) { - const ShapeTree& index_tree = *live_index_map->at(instruction); - ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { - if (shape_index.empty() || shape_index[0] != operand_index) { - return; - } - // Mark top-level index of operand at 'operand_index'. - MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, - worklist, workset); - // Mark sub-shape index of operand at 'operand_index'. - ShapeIndex operand_shape_index; - for (int i = 1; i < shape_index.size(); ++i) { - operand_shape_index.push_back(shape_index[i]); - } - MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, - live_index_map, worklist, workset); - }); - } + const ShapeTree& index_tree = *live_index_map->at(instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + const size_t size = shape_index.size(); + if (size == 0) { + return; + } + const int64_t operand_index = shape_index[0]; + if (operand_index >= instruction->operand_count()) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'. + ShapeIndex operand_shape_index(size - 1); + for (int i = 1; i < size; ++i) { + operand_shape_index[i - 1] = shape_index[i]; + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); } // Propagates liveness through GetTupleElement instructions. 
From ced6245c372463b2101c262c9231349feff23b66 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Mon, 4 Dec 2023 12:31:45 -0800 Subject: [PATCH 353/381] Create new translate package for graphdef import/export PiperOrigin-RevId: 587807107 --- tensorflow/compiler/mlir/tensorflow/BUILD | 359 ++---------------- .../compiler/mlir/tensorflow/translate/BUILD | 339 +++++++++++++++++ 2 files changed, 366 insertions(+), 332 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/translate/BUILD diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index dc2722f5d1096b..77530c113b9be2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -764,133 +764,6 @@ cc_library( ], ) -cc_library( - name = "upgrade_graph", - srcs = ["translate/upgrade_graph.cc"], - hdrs = ["translate/upgrade_graph.h"], - deps = [ - ":attribute_utils", - "//tensorflow/compiler/tf2xla:functionalize_control_flow", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core/common_runtime:device", - "//tensorflow/core/common_runtime:device_factory", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:grappler_item_builder", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - "//tensorflow/core/protobuf:for_core_protos_cc", - "@llvm-project//llvm:Support", - ], -) - -cc_library( - name = "export_graphdef", - srcs = [ - "translate/export_graphdef.cc", - ], - hdrs = [ - "translate/export_graphdef.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":convert_type", - ":error_util", - ":export_tf_dialect_op", - ":export_utils", - ":mlir_roundtrip_flags", - ":tensorflow", - ":translate_utils", - ":verify_suitable_for_graph_export", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/utils:name_utils", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/graph/regularization:util", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "import_model", - srcs = [ - "translate/import_model.cc", - ], - hdrs = [ - "translate/export_graphdef.h", - "translate/import_model.h", - ], - deps = [ - ":attribute_utils", - ":convert_attr", - ":convert_tensor", - ":convert_type", - ":dump_mlir_util", - ":dynamic_shape_utils", - ":error_util", - ":mangling_util", - ":mlir_import_options", - ":mlir_roundtrip_flags", - ":tensorflow", - ":tensorflow_attributes", - ":tensorflow_types", - ":translate_utils", - ":upgrade_graph", - "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/cc/saved_model:constants", - "//tensorflow/cc/saved_model:loader_lite", - "//tensorflow/cc/saved_model:loader_util", - "//tensorflow/compiler/jit:shape_inference_helpers", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", - 
"//tensorflow/compiler/mlir/tensorflow/transforms:initialize_variables_in_session_init", - "//tensorflow/compiler/mlir/tensorflow/transforms:lift_variables_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:mark_initialized_variables_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/utils:transitive_fanin", - "//tensorflow/core/platform:crash_analysis", - "//tensorflow/core/platform:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_xla//xla:status_macros", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_parser", - ], -) - cc_library( name = "parse_text_proto", srcs = ["utils/parse_text_proto.cc"], @@ -916,20 +789,6 @@ cc_library( ], ) -tf_cc_test( - name = "tf_mlir_translate_registration_test", - size = "small", - srcs = ["translate/tf_mlir_translate_registration_test.cc"], - deps = [ - ":translate_registration", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:TranslateLib", - ], -) - cc_library( name = "export_utils", srcs = [ @@ -1006,92 +865,6 @@ cc_library( ], ) -cc_library( - name = "export_tf_dialect_op", - srcs = [ - "translate/export_tf_dialect_op.cc", - ], - hdrs = [ - "translate/export_tf_dialect_op.h", - ], - deps = [ - ":convert_type", - ":export_utils", - ":tensorflow", - "//tensorflow/compiler/mlir/utils:string_container_utils", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:DerivedAttributeOpInterface", - "@llvm-project//mlir:IR", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "translate_tf_dialect_op", - srcs = ["translate/translate_tf_dialect_op.cc"], - deps = [ - ":export_tf_dialect_op", - ":tensorflow", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - ], - alwayslink = 1, -) - -cc_library( - name = "mlir_roundtrip_pass", - srcs = ["translate/mlir_roundtrip_pass.cc"], - hdrs = ["translate/mlir_roundtrip_pass.h"], - deps = [ - ":error_util", - ":export_graphdef", - ":import_model", - ":mlir_roundtrip_flags", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "mlir_roundtrip_pass_registration", - srcs = ["translate/mlir_roundtrip_pass_registration.cc"], - deps = [ - ":mlir_roundtrip_pass", - ], - alwayslink = 1, -) - 
-cc_library( - name = "mlir_roundtrip_flags", - srcs = ["translate/mlir_roundtrip_flags.cc"], - hdrs = ["translate/mlir_roundtrip_flags.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@local_xla//xla:status_macros", - ], -) - cc_library( name = "convert_attr", srcs = ["utils/convert_attr.cc"], @@ -1253,90 +1026,6 @@ cc_library( ], ) -cc_library( - name = "mlir_import_options", - hdrs = ["translate/mlir_import_options.h"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "translate_lib", - srcs = ["translate/tf_mlir_translate.cc"], - hdrs = ["translate/tf_mlir_translate.h"], - visibility = ["//visibility:public"], - deps = [ - ":error_util", - ":import_model", - ":import_utils", - ":mangling_util", - ":mlir_import_options", - ":mlir_roundtrip_flags", - "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/cc/saved_model:loader_lite", - "//tensorflow/cc/saved_model:reader", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/utils:transitive_fanin", - "//tensorflow/core/util/tensor_bundle:byteswaptensor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - ], -) - -cc_library( - name = "translate_cl_options", - srcs = [ - "translate/tf_mlir_translate_cl.cc", - ], - hdrs = [ - "translate/tf_mlir_translate_cl.h", - ], - deps = [ - "@llvm-project//llvm:Support", - ], - alwayslink = 1, -) - -cc_library( - name = "translate_registration", - srcs = [ - "translate/tf_mlir_translate_registration.cc", - ], - deps = [ - ":export_graphdef", - ":mlir_roundtrip_flags", - ":tensorflow", - ":translate_cl_options", - ":translate_lib", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/service/cpu:cpu_compiler", - "@local_xla//xla/service/cpu:cpu_transfer_manager", - "@local_xla//xla/stream_executor", - "@local_xla//xla/stream_executor/host:host_platform", - "@local_xla//xla/stream_executor/host:host_platform_id", - ], - alwayslink = 1, -) - tf_cc_test( name = "error_util_test", srcs = ["utils/error_util_test.cc"], @@ -1943,27 +1632,6 @@ cc_library( ], ) -cc_library( - name = "split_into_island_per_op_pass", - srcs = ["translate/split_into_island_per_op_pass.cc"], - hdrs = [ - "ir/tf_executor.h", - "translate/split_into_island_per_op_pass.h", - ], - deps = [ - ":tensorflow", - ":tensorflow_executor_inc_gen", - ":tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Dialect", - 
"@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:Pass", - ], -) - tf_cc_test( name = "xla_rewrite_util_test", size = "small", @@ -2031,3 +1699,30 @@ build_test( # ) # # copybara:uncomment_end(google-only) + +# Required as we created the transforms subpackage and need to update +# these BUILD targets in a follow up. +aliased_targets = [ + "export_graphdef", + "import_model", + "export_tf_dialect_op", + "translate_tf_dialect_op", + "mlir_roundtrip_pass", + "mlir_roundtrip_pass_registration", + "mlir_roundtrip_flags", + "mlir_import_options", + "translate_lib", + "translate_cl_options", + "translate_registration", + "split_into_island_per_op_pass", + "upgrade_graph", +] + +[ + alias( + name = target, + actual = "//tensorflow/compiler/mlir/tensorflow/translate:%s" % target, + visibility = ["//visibility:public"], + ) + for target in aliased_targets +] diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD new file mode 100644 index 00000000000000..46af8590c8108e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -0,0 +1,339 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +cc_library( + name = "import_model", + srcs = [ + "import_model.cc", + ], + hdrs = [ + "export_graphdef.h", + "import_model.h", + ], + deps = [ + ":mlir_roundtrip_flags", + ":upgrade_graph", + "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc/saved_model:loader_util", + "//tensorflow/compiler/jit:shape_inference_helpers", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:convert_attr", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:initialize_variables_in_session_init", + "//tensorflow/compiler/mlir/tensorflow/transforms:lift_variables_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:mark_initialized_variables_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", + "//tensorflow/core/platform:crash_analysis", + "//tensorflow/core/platform:types", + 
"@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla:status_macros", + "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/service:hlo_parser", + ], +) + +tf_cc_test( + name = "tf_mlir_translate_registration_test", + size = "small", + srcs = ["tf_mlir_translate_registration_test.cc"], + deps = [ + ":translate_registration", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:TranslateLib", + ], +) + +cc_library( + name = "export_tf_dialect_op", + srcs = [ + "export_tf_dialect_op.cc", + ], + hdrs = [ + "export_tf_dialect_op.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/utils:string_container_utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:IR", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "translate_tf_dialect_op", + srcs = ["translate_tf_dialect_op.cc"], + deps = [ + ":export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_roundtrip_pass", + srcs = ["mlir_roundtrip_pass.cc"], + hdrs = ["mlir_roundtrip_pass.h"], + deps = [ + ":export_graphdef", + ":import_model", + ":mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "mlir_roundtrip_pass_registration", + srcs = ["mlir_roundtrip_pass_registration.cc"], + deps = [ + ":mlir_roundtrip_pass", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_roundtrip_flags", + srcs = ["mlir_roundtrip_flags.cc"], + hdrs = ["mlir_roundtrip_flags.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:types", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "mlir_import_options", + hdrs = ["mlir_import_options.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "translate_lib", + srcs = ["tf_mlir_translate.cc"], + hdrs = ["tf_mlir_translate.h"], + visibility = 
["//visibility:public"], + deps = [ + ":import_model", + ":mlir_roundtrip_flags", + "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc/saved_model:reader", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_utils", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", + "//tensorflow/core/util/tensor_bundle:byteswaptensor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], +) + +cc_library( + name = "translate_cl_options", + srcs = [ + "tf_mlir_translate_cl.cc", + ], + hdrs = [ + "tf_mlir_translate_cl.h", + ], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "export_graphdef", + srcs = [ + "export_graphdef.cc", + ], + hdrs = [ + "export_graphdef.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":export_tf_dialect_op", + ":mlir_roundtrip_flags", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tensorflow:verify_suitable_for_graph_export", + "//tensorflow/compiler/mlir/utils:name_utils", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/graph/regularization:util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "translate_registration", + srcs = [ + "tf_mlir_translate_registration.cc", + ], + deps = [ + ":export_graphdef", + ":mlir_roundtrip_flags", + ":translate_cl_options", + ":translate_lib", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/service/cpu:cpu_compiler", + "@local_xla//xla/service/cpu:cpu_transfer_manager", + "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor/host:host_platform", + "@local_xla//xla/stream_executor/host:host_platform_id", + ], 
+ alwayslink = 1, +) + +cc_library( + name = "split_into_island_per_op_pass", + srcs = ["split_into_island_per_op_pass.cc"], + hdrs = [ + "split_into_island_per_op_pass.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_executor_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "upgrade_graph", + srcs = ["upgrade_graph.cc"], + hdrs = ["upgrade_graph.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/tf2xla:functionalize_control_flow", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@llvm-project//llvm:Support", + ], +) From b650fbd014813f63bee01f7330888d5cd85117dc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 12:47:07 -0800 Subject: [PATCH 354/381] Various minor cleanups in the test file * Fixed typos * Removed unused headers * Switched from CHECK(p) to CHECK(p != nullptr) according to best practice * Small code refactoring PiperOrigin-RevId: 587811152 --- tensorflow/core/runtime_fallback/test/BUILD | 4 +- .../batch_function_fallback_benchmark_test.cc | 62 ++++++++++--------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/tensorflow/core/runtime_fallback/test/BUILD b/tensorflow/core/runtime_fallback/test/BUILD index d8fd4850f1357f..58accf3aa46a65 100644 --- a/tensorflow/core/runtime_fallback/test/BUILD +++ b/tensorflow/core/runtime_fallback/test/BUILD @@ -1,5 +1,5 @@ -load("@tf_runtime//tools:mlir_to_bef.bzl", "mlir_to_bef") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_test", "tf_cc_test") +load("@tf_runtime//tools:mlir_to_bef.bzl", "mlir_to_bef") # copybara:uncomment load("//third_party/tf_runtime_google/cpp_tests:gen_tests.bzl", "tfrt_cc_test_and_strict_benchmark") package( @@ -157,7 +157,9 @@ cc_library( # deps = [ # "//base", # "//devtools/build/runtime:get_runfiles_dir", +# "@com_google_absl//absl/log:check", # "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", +# "//tensorflow/core:framework", # "//tensorflow/core/platform:env", # "//tensorflow/core/platform:resource_loader", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", diff --git a/tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc b/tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc index 11bc0b6ecbf4f5..6031aeb4726c67 100644 --- a/tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc +++ b/tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc @@ -12,46 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include #include #include #include "base/logging.h" -#include "devtools/build/runtime/get_runfiles_dir.h" #include "testing/base/public/benchmark.h" -#include #include +#include "absl/log/check.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.h" -#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h" #include "tensorflow/core/runtime_fallback/util/fallback_test_util.h" -#include "tensorflow/core/runtime_fallback/util/tensor_util.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tfrt/bef/bef_buffer.h" // from @tf_runtime #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime -#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime +#include "tfrt/host_context/async_value.h" // from @tf_runtime #include "tfrt/host_context/chain.h" // from @tf_runtime #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime #include "tfrt/host_context/function.h" // from @tf_runtime +#include "tfrt/host_context/host_allocator.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime -#include "tfrt/support/aligned_buffer.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime #include "tfrt/support/rc_array.h" // from @tf_runtime #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime -#include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime namespace tensorflow { namespace tfd { namespace { // Creates a BEF file with a program that runs -// tfrt_fallback_async.batch_function with a empty function forwarding inputs or -// outputs. +// tfrt_fallback_async.batch_function with an "identity" function (i.e. forward +// inputs to outputs). 
std::pair> CreateBefFile( tfrt::HostContext* host) { std::string file_path = GetDataDependencyFilepath( @@ -65,14 +64,14 @@ std::pair> CreateBefFile( auto bef_file = tfrt::BEFFile::Open(bef_buffer, host->GetKernelRegistry(), host->diag_handler(), host->allocator()); - CHECK(bef_file); + CHECK(bef_file != nullptr); return std::make_pair(std::move(bef_buffer), std::move(bef_file)); } std::unique_ptr CreateTestCoreRuntime() { auto corert = tfrt::CoreRuntime::Create( - /*diag_handler=*/[](const tfrt::DecodedDiagnostic& - diag) { LOG(ERROR) << diag.message(); }, + /*diag_handler=*/ + [](const tfrt::DecodedDiagnostic& diag) { LOG(ERROR) << diag.message(); }, tfrt::CreateMallocAllocator(), tfrt::CreateMultiThreadedWorkQueue(16, 16)); CHECK(corert); @@ -83,13 +82,14 @@ std::unique_ptr CreateTestCoreRuntime() { return std::move(corert.get()); } -tfrt::RCArray CreateTestArguments(const tfrt::Function* func, - tfrt::HostContext* host) { - Tensor tensor(DataType::DT_INT32, TensorShape({1})); +tfrt::RCArray CreateTestArguments( + const tfrt::Function* func) { + size_t num_args = func->num_arguments(); std::vector> arguments; - arguments.reserve(func->argument_types().size()); + arguments.reserve(num_args); arguments.push_back(tfrt::GetReadyChain()); - for (int i = 1, e = func->argument_types().size(); i < e; ++i) { + Tensor tensor(DataType::DT_INT32, TensorShape({1})); + for (int i = 1; i < num_args; ++i) { arguments.push_back( tfrt::MakeAvailableAsyncValueRef(tensor)); } @@ -102,19 +102,21 @@ TEST(BatchFunctionTest, Basic) { auto* host = corert->GetHostContext(); auto [bef_buffer, bef_file] = CreateBefFile(host); auto* func = bef_file->GetFunction("main"); - CHECK(func); - CHECK_EQ(func->result_types().size(), 113); - CHECK_EQ(func->argument_types().size(), 113); + CHECK(func != nullptr); + size_t num_args = func->num_arguments(); + size_t num_results = func->num_results(); + CHECK_EQ(num_args, 113); + CHECK_EQ(num_results, 113); - auto arguments = CreateTestArguments(func, host); + auto arguments = CreateTestArguments(func); tfrt::ResourceContext resource_ctx; auto exec_ctx = CreateFallbackTestExecutionContext(host, &resource_ctx); std::vector> results; - results.resize(func->result_types().size()); + results.resize(num_results); std::vector> result_tensors; - result_tensors.resize(func->result_types().size() - 1); + result_tensors.resize(num_results - 1); func->Execute(exec_ctx, arguments.values(), results); host->Await(results); @@ -134,17 +136,19 @@ void BM_BatchFunctionFallbackWithLargeAttributesAndManyInputsOutputs( auto* host = corert->GetHostContext(); auto [bef_buffer, bef_file] = CreateBefFile(host); auto* func = bef_file->GetFunction("main"); - CHECK(func); - CHECK_EQ(func->result_types().size(), 113); - CHECK_EQ(func->argument_types().size(), 113); + CHECK(func != nullptr); + size_t num_args = func->num_arguments(); + size_t num_results = func->num_results(); + CHECK_EQ(num_args, 113); + CHECK_EQ(num_results, 113); - auto arguments = CreateTestArguments(func, host); + auto arguments = CreateTestArguments(func); tfrt::ResourceContext resource_ctx; auto exec_ctx = CreateFallbackTestExecutionContext(host, &resource_ctx); std::vector> results; - results.resize(func->result_types().size()); + results.resize(num_results); for (auto _ : state) { func->Execute(exec_ctx, arguments.values(), results); From 11b18328315eb090915837c43c361212f43ef653 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Mon, 4 Dec 2023 13:05:32 -0800 Subject: [PATCH 355/381] PR 
#7511: [ROCM] build brake fix 23/12/04 Imported from GitHub PR https://github.com/openxla/xla/pull/7511 This is an ongoing build brake fix due to gpu_driver/gpu_kernel changes. @xla-rotation: would you have a look, please ? Copybara import of the project: -- 5b8a3107e18e888a96260d816e05e064011210b7 by Pavel Emeliyanenko : rocm fix Merging this change closes #7511 PiperOrigin-RevId: 587816441 --- .../xla/xla/stream_executor/rocm/rocm_driver.cc | 9 +++++++++ .../xla/stream_executor/rocm/rocm_driver_wrapper.h | 1 + .../xla/xla/stream_executor/rocm/rocm_kernel.cc | 12 ++++++++++++ 3 files changed, 22 insertions(+) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 76af619f6328e6..41d5fe6a95e2e2 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -835,6 +835,15 @@ static hipMemAllocationType ToHipAllocationType( } } +/*static*/ tsl::Status GpuDriver::GraphAddMemFreeNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst) { + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemFreeNode(node, graph, deps.data(), + deps.size(), gpu_dst), + "Failed to add memory free node to a ROCM graph"); + return ::tsl::OkStatus(); +} + /*static*/ tsl::Status GpuDriver::GraphAddMemAllocNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, absl::Span deps, MemAccessFlags access_flags, diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h index f3356ffd83ddc2..020b5d706a03e7 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -108,6 +108,7 @@ namespace wrap { __macro(hipGraphAddMemcpyNode) \ __macro(hipGraphAddMemcpyNode1D) \ __macro(hipGraphAddMemsetNode) \ + __macro(hipGraphAddMemFreeNode) \ __macro(hipGraphCreate) \ __macro(hipGraphDebugDotPrint) \ __macro(hipGraphDestroy) \ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc index 5ebad6db18bb5a..c091fdd5f28ef3 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc @@ -34,5 +34,17 @@ hipFuncCache_t GpuKernel::GetGpuCacheConfig() const { } } +tsl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + int32_t threads_per_block = threads.x * threads.y * threads.z; + VLOG(0) << "Get kernel block occupancy: " << name_ + << "; threads_per_block: " << threads_per_block + << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; + + return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, + threads_per_block, + dynamic_shared_memory_bytes); +} + } // namespace gpu } // namespace stream_executor From fffb030931fdc5484cab121b6ad8d9398a7c04e5 Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Mon, 4 Dec 2023 13:44:46 -0800 Subject: [PATCH 356/381] #tf-data remove the cap and directly use double-typed gauge cell instead of int64-typed PiperOrigin-RevId: 587828209 --- tensorflow/core/framework/metrics.cc | 15 +++++++-------- tensorflow/core/framework/metrics.h | 2 +- tensorflow/core/framework/model.cc | 12 +++--------- .../third_party/tsl/tsl/lib/monitoring/gauge.h | 6 ++++-- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git 
a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index eee1dc5e4ecdb1..e6f94a8e444b4a 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -266,12 +266,11 @@ auto* tf_data_model_gauge = tsl::monitoring::Gauge, 1>::New( "/tensorflow/data/model", "tf.data autotuning model proto.", "id"); -auto* tf_data_pipeline_processing_time = - tsl::monitoring::Gauge::New( - "/tensorflow/data/pipeline_processing_time", - "The total processing time of the slowest stage in the input pipeline " - "in microseconds", - "id"); +auto* tf_data_pipeline_processing_time = tsl::monitoring::Gauge::New( + "/tensorflow/data/pipeline_processing_time", + "The total processing time of the slowest stage in the input pipeline " + "in microseconds", + "id"); auto* tf_data_auto_shard = tsl::monitoring::Gauge::New( "/tensorflow/data/autoshard", "tf.data autoshard statistics.", "id", @@ -474,7 +473,7 @@ tsl::monitoring::GaugeCell>* GetTFDataModelGauge( return tf_data_model_gauge->GetCell(id); } -tsl::monitoring::GaugeCell* GetTFDataPipelineProcessingTimeGauge( +tsl::monitoring::GaugeCell* GetTFDataPipelineProcessingTimeGauge( const string& id) { return tf_data_pipeline_processing_time->GetCell(id); } @@ -835,7 +834,7 @@ void RecordUnusedOutput(const string& op_name) { } void RecordPipelineProcessingTime(const string& id, - int64_t pipeline_processing_time_usec) { + double pipeline_processing_time_usec) { GetTFDataPipelineProcessingTimeGauge(id)->Set(pipeline_processing_time_usec); } diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h index ee42fae8c14fe9..bcc5808cf7f8e8 100644 --- a/tensorflow/core/framework/metrics.h +++ b/tensorflow/core/framework/metrics.h @@ -245,7 +245,7 @@ void RecordUnusedOutput(const string& op_name); // Records the pipeline processing time in microseconds void RecordPipelineProcessingTime(const string& id, - int64_t pipeline_processing_time_usec); + double pipeline_processing_time_usec); // Updates the metrics stored about time spent building graphs. // diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 449520172bbd15..3165ee33231739 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -2363,7 +2363,7 @@ void Model::Optimize(AutotuneAlgorithm algorithm, optimization_params_ = optimization_params; if (snapshot_) { - int64_t pipeline_processing_usec = 0; + double pipeline_processing_usec = 0; ModelTiming model_timing(snapshot_); auto bfs_stage_roots = model_timing.GetStageRoots(); for (const auto& root : bfs_stage_roots) { @@ -2382,14 +2382,8 @@ void Model::Optimize(AutotuneAlgorithm algorithm, root_timing->pipeline_ratio / EnvTime::kMicrosToNanos; - // Cap the computed value to ensure that there is no integer overflow. - if (root_total_time_usec < std::numeric_limits::max()) { - pipeline_processing_usec = - std::max(pipeline_processing_usec, - static_cast(root_total_time_usec)); - } else { - pipeline_processing_usec = std::numeric_limits::max(); - } + pipeline_processing_usec = + std::max(pipeline_processing_usec, root_total_time_usec); } // Only updates the pipeline processing time when it is greater than 0. 
// If it is zero, we assume the pipeline processing time is the same diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h index 2552013f3d7150..0b69383b5f2d13 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h @@ -65,8 +65,10 @@ class Gauge { std::is_same::value || std::is_same >::value || std::is_same >::value || - std::is_same >::value, - "Gauge only allows bool, int64, and string types."); + std::is_same >::value || + std::is_same >::value || + std::is_same::value, + "Gauge only allows bool, int64, double and string types."); return new Gauge(); } From d699fc4be2ae586e777425fecc97843cde0041cd Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Dec 2023 13:46:01 -0800 Subject: [PATCH 357/381] [XLA:CPU] Make an experimental dependency on the MLIR GPU dialects optional. This code path isn't used by default, and linking in the MLIR GPU dialects increases binary size. PiperOrigin-RevId: 587828526 --- .../xla/xla/mlir/runtime/transforms/BUILD | 17 +++++++++-- .../transforms/compilation_pipeline_cpu.cc | 13 +++++++-- third_party/xla/xla/service/cpu/BUILD | 28 +++++++++++++++++-- .../service/cpu/hlo_xla_runtime_pipeline.cc | 12 ++++++-- 4 files changed, 60 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/mlir/runtime/transforms/BUILD b/third_party/xla/xla/mlir/runtime/transforms/BUILD index d9151d90a5c200..b56d4930bfbd59 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/BUILD +++ b/third_party/xla/xla/mlir/runtime/transforms/BUILD @@ -100,6 +100,12 @@ cc_library( srcs = ["compilation_pipeline_cpu.cc"], hdrs = ["compilation_pipeline_cpu.h"], compatible_with = get_compatible_with_portable(), + local_defines = select({ + "//xla/service/cpu:experimental_mlir_gpu_enabled": [ + "EXPERIMENTAL_MLIR_GPU=1", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", @@ -126,8 +132,6 @@ cc_library( "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUToGPURuntimeTransforms", - "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", @@ -142,7 +146,14 @@ cc_library( "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:X86VectorToLLVMIRTranslation", - ], + "@local_tsl//tsl/platform:logging", + ] + select({ + "//xla/service/cpu:experimental_mlir_gpu_enabled": [ + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUTransforms", + ], + "//conditions:default": [], + }), alwayslink = 1, # has pipeline registration ) diff --git a/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 50488cad7f13ff..22920be516e5d4 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project @@ -32,7 +31,6 @@ limitations under the License. #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project @@ -57,6 +55,12 @@ limitations under the License. #include "xla/mlir/runtime/transforms/compiler.h" #include "xla/mlir/runtime/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" +#include "tsl/platform/logging.h" + +#ifdef EXPERIMENTAL_MLIR_GPU +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project +#endif // EXPERIMENTAL_MLIR_GPU namespace xla { namespace runtime { @@ -146,6 +150,7 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, llvm_options.enableAvx2 = opts.math_avx2; pm.addPass(mlir::hlo::createGenericHostToLLVMPass(llvm_options)); const bool gpuCodegen = opts.xla_cpu_sparse_cuda_threads > 0; +#ifdef EXPERIMENTAL_MLIR_GPU if (gpuCodegen) { #ifdef MLIR_GPU_TO_CUBIN_PASS_ENABLE pm.addNestedPass( @@ -154,6 +159,10 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, #endif pm.addPass(mlir::createGpuToLLVMConversionPass()); } +#else // EXPERIMENTAL_MLIR_GPU + CHECK(!gpuCodegen) + << "Experimental MLIR GPU code generation was not enabled at build time"; +#endif // EXPERIMENTAL_MLIR_GPU pm.addPass(mlir::createReconcileUnrealizedCastsPass()); // Prepare module for translation to LLVM. diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 6f782102a843b6..448ecbc8e1bf6c 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -2,6 +2,7 @@ # LLVM-based CPU backend for XLA. 
load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load( "//xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS", @@ -45,6 +46,19 @@ filegroup( visibility = ["//visibility:public"], ) +bool_flag( + name = "experimental_mlir_gpu", + build_setting_default = False, +) + +config_setting( + name = "experimental_mlir_gpu_enabled", + flag_values = { + ":experimental_mlir_gpu": "True", + }, + visibility = ["//visibility:public"], +) + cc_library( name = "test_header_helper", testonly = True, @@ -468,6 +482,10 @@ cc_library( name = "hlo_xla_runtime_pipeline", srcs = ["hlo_xla_runtime_pipeline.cc"], hdrs = ["hlo_xla_runtime_pipeline.h"], + local_defines = select({ + ":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ "//xla:status", @@ -483,8 +501,6 @@ cc_library( "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", @@ -502,7 +518,13 @@ cc_library( "@llvm-project//mlir:VectorTransforms", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - ], + ] + select({ + ":experimental_mlir_gpu_enabled": [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + ], + "//conditions:default": [], + }), alwayslink = 1, # has pipeline registration ) diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index caff8ee13e90f7..0dafc4c9235d3e 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project @@ -33,7 +32,6 @@ limitations under the License. #include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project @@ -56,6 +54,11 @@ limitations under the License. 
#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#ifdef EXPERIMENTAL_MLIR_GPU +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#endif // EXPERIMENTAL_MLIR_GPU + namespace xla { namespace cpu { namespace { @@ -111,6 +114,7 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass( mlir::bufferization::createFinalizingBufferizePass()); +#ifdef EXPERIMENTAL_MLIR_GPU // Sparse GPU acceleration lowers to GPU dialect. if (gpu_codegen) { pm.addPass( @@ -120,6 +124,10 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, pm.addNestedPass( mlir::createConvertGpuOpsToNVVMOps()); } +#else // EXPERIMENTAL_MLIR_GPU + CHECK(!gpu_codegen) + << "Experimental MLIR GPU code generation was not enabled at build time"; +#endif // EXPERIMENTAL_MLIR_GPU } void AddSparsificationPassPipeline(mlir::OpPassManager& pm) { From a0f6253af843e0453d70de8a743da147332d85ca Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 4 Dec 2023 14:48:55 -0800 Subject: [PATCH 358/381] [xla:gpu] Add support for BF16 CUTLASS gemm fusions PiperOrigin-RevId: 587848029 --- third_party/xla/xla/service/gpu/kernels/BUILD | 1 + .../gpu/kernels/cutlass_gemm_fusion.cc | 25 +++++++++++-------- .../gpu/kernels/cutlass_gemm_kernels.cu.h | 5 ++-- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 3c26363e03669e..f3ffcb803211f3 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -92,6 +92,7 @@ cc_library( # ":cutlass_gemm", # ":cutlass_gemm_custom_kernel", # "@com_google_absl//absl/status", +# "@com_google_absl//absl/types:span", # "//xla:shape_util", # "//xla:status", # "//xla:statusor", diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 1e3e794f871c4f..b3bc09a9b148dc 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -94,16 +95,19 @@ static Status MatchRowMajorGemm(HloDotInstruction* dot) { // Return OK if dot instruction is a simple gemm with all operands and result // having the same data type. 
-static Status MatchSimpleGemm(HloDotInstruction* dot, PrimitiveType dtype) { +static Status MatchSimpleGemm(HloDotInstruction* dot, + absl::Span support_dtypes) { TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); - if (dot->operand(0)->shape().element_type() != dtype || - dot->operand(1)->shape().element_type() != dtype || - dot->shape().element_type() != dtype) { - return absl::InternalError("operands and result must have the same type"); + for (PrimitiveType dtype : support_dtypes) { + if (dot->operand(0)->shape().element_type() == dtype && + dot->operand(1)->shape().element_type() == dtype && + dot->shape().element_type() == dtype) { + return OkStatus(); + } } - return OkStatus(); + return absl::InternalError("unsupported operands type"); } // Returns matched GEMM with one of the operands upcasted to the accumulator @@ -153,7 +157,7 @@ std::optional CutlassGemmPattern::TryMatch( auto* dot = DynCast(instr); if (!dot) return std::nullopt; - auto matched = MatchSimpleGemm(dot, PrimitiveType::F32); + auto matched = MatchSimpleGemm(dot, {PrimitiveType::F32}); if (!matched.ok()) return std::nullopt; CustomFusionConfig config; @@ -224,7 +228,7 @@ class CutlassGemmFusion : public CustomFusion { "cutlass_gemm requires ROOT operation to be a dot"); } - TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, PrimitiveType::F32)); + TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, {PrimitiveType::F32})); auto dtype = dot->shape().element_type(); @@ -293,8 +297,9 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomFusion { } TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithDynamicUpdateSlice(dus)); - TF_RETURN_IF_ERROR(MatchSimpleGemm(Cast(matched.dot), - PrimitiveType::F32)); + TF_RETURN_IF_ERROR( + MatchSimpleGemm(Cast(matched.dot), + {PrimitiveType::F32, PrimitiveType::BF16})); auto dtype = matched.dot->shape().element_type(); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h index fbd5715b564a75..3e25873e386358 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernels.cu.h @@ -52,8 +52,9 @@ __global__ void Kernel(typename Gemm::Params params, char* ptr_c = reinterpret_cast(params.ptr_C); char* ptr_d = reinterpret_cast(params.ptr_D); - params.ptr_C = ptr_c + 4 * out_offset * (m * n); - params.ptr_D = ptr_d + 4 * out_offset * (m * n); + using ElementC = typename Gemm::ElementC; + params.ptr_C = ptr_c + sizeof(ElementC) * out_offset * (m * n); + params.ptr_D = ptr_d + sizeof(ElementC) * out_offset * (m * n); } Gemm::invoke(params, *shared_storage); From f837e0d34df9fa19e85157bda8f6fa5a2111f409 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 14:50:15 -0800 Subject: [PATCH 359/381] Get rid of a couple unused fields in AutoShardingOption. 
PiperOrigin-RevId: 587848438 --- .../auto_sharding/auto_sharding.cc | 30 ++++++++----------- .../auto_sharding/auto_sharding_option.cc | 7 ----- .../auto_sharding/auto_sharding_option.h | 7 +---- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 7a9a54ae180bc9..0f0b92b07ca2d8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3592,24 +3592,20 @@ StatusOr AutoShardingImplementation::RunAutoSharding( std::vector s_val; std::vector e_val; double objective = -1.0; - if (!option_.load_solution_vector) { - auto solver_result = - Solve(*module, *hlo_live_range, liveness_node_set, strategy_map, - strategy_groups, cost_graph, alias_set, option_, - sharding_propagation_solution); - if (solver_result.skip_auto_sharding) { - return AutoShardingResult::kModuleUnchangedNoShardingPerfomed; - } else if (!solver_result.status.ok()) { - return AutoShardingResult::kModuleUnchanged; - } else { - TF_ASSIGN_OR_RETURN(auto solution, solver_result.status); - std::tie(s_val, e_val, objective) = solution; - if (mesh_idx == partial_mesh_shapes.size() - 1) { - this->solver_optimal_objective_value_ = objective; - } - } + auto solver_result = + Solve(*module, *hlo_live_range, liveness_node_set, strategy_map, + strategy_groups, cost_graph, alias_set, option_, + sharding_propagation_solution); + if (solver_result.skip_auto_sharding) { + return AutoShardingResult::kModuleUnchangedNoShardingPerfomed; + } else if (!solver_result.status.ok()) { + return AutoShardingResult::kModuleUnchanged; } else { - s_val = option_.strategy_vector; + TF_ASSIGN_OR_RETURN(auto solution, solver_result.status); + std::tie(s_val, e_val, objective) = solution; + if (mesh_idx == partial_mesh_shapes.size() - 1) { + this->solver_optimal_objective_value_ = objective; + } } XLA_VLOG_LINES(5, PrintAutoShardingSolution(sequence, liveness_set, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index 661535c9db027c..1f32867e164a05 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -85,7 +85,6 @@ std::string AutoShardingOption::ToString() const { absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape)); lines.push_back( absl::StrCat("grad_acc_num_micro_batches: ", grad_acc_num_micro_batches)); - lines.push_back(absl::StrCat("load_solution_vector: ", load_solution_vector)); lines.push_back( absl::StrCat("force_simple_heuristic: ", force_simple_heuristic)); lines.push_back(absl::StrCat("force_strategy: ", force_strategy)); @@ -118,12 +117,6 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("device_mesh_beta: [", absl::StrJoin(device_mesh_beta, ","), "]")); - lines.push_back(absl::StrCat("load_strategy: ", load_strategy)); - if (load_strategy) { - lines.push_back(absl::StrCat("strategy_vector: [", - absl::StrJoin(strategy_vector, ","), "]")); - } - return absl::StrJoin(lines, "\n"); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 775c4a59bc301e..fe950b190df245 100644 --- 
a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -116,9 +116,6 @@ struct AutoShardingOption { // is divided by this number. int grad_acc_num_micro_batches = 1; - // If true, load solution vector from PassContext - bool load_solution_vector = false; - // If true, N-D sharding (e.g., N maybe be 2 or 3) will be solved in N // iterations, where one iteration chooses one tensor dimension to shard. If // false, solve N-D sharding directly, i.e., generating all possible sharding @@ -161,8 +158,7 @@ struct AutoShardingOption { // element models the communication performance along each mesh dimension. std::vector device_mesh_alpha; std::vector device_mesh_beta; - // Load the strategy vector instead of solving one. - bool load_strategy = false; + // Explore other mesh shapes with the same number of devices as the provided // one for a potentially better auto-sharding solution. bool try_multiple_mesh_shapes = false; @@ -180,7 +176,6 @@ struct AutoShardingOption { // smaller Mixed ILP). bool allow_alias_to_follower_conversion = true; - std::vector strategy_vector; // If greater than zero, tensors with size smaller than or equal to this limit // will always be replicated if they don't have a different user-specified // sharding. From f65a35d1a2c1a2f0bf5833431071afc878b62830 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 14:52:25 -0800 Subject: [PATCH 360/381] #tf-data Annotate IO read statistics to xprof to aid user debugging. PiperOrigin-RevId: 587849184 --- tensorflow/core/data/root_dataset.cc | 24 ------------------- tensorflow/core/platform/host_info.h | 1 - .../tsl/tsl/platform/default/port.cc | 6 ++--- .../third_party/tsl/tsl/platform/host_info.h | 18 -------------- .../tsl/tsl/platform/windows/port.cc | 5 ++-- 5 files changed, 4 insertions(+), 50 deletions(-) diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index bba8a426366329..55ff2bc8122213 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -36,7 +36,6 @@ limitations under the License. 
#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringprintf.h" -#include "tsl/platform/host_info.h" namespace tensorflow { namespace data { @@ -47,8 +46,6 @@ constexpr char kDatasetType[] = "Root"; constexpr char kAlgorithm[] = "algorithm"; constexpr char kCpuBudget[] = "cpu_budget"; constexpr char kExperiments[] = "experiments"; -constexpr char kReadRoundtripLatency[] = "read_latency_usec"; -constexpr char kReadResponseBytes[] = "read_bytes"; constexpr char kIntraOpParallelism[] = "intra_op_parallelism"; constexpr char kMemBandwidth[] = "mem_bw_used_megabytes_per_sec"; constexpr char kPrivateThreadpoolSize[] = "threadpool_size"; @@ -280,27 +277,6 @@ class RootDataset::Iterator : public DatasetIterator { "%lld", static_cast( model_node()->TotalMaximumBufferedBytes() / 1.0e6)))); } - const auto io_statistics = tsl::port::GetIOStatistics(); - if (io_statistics.roundtrip_latency_usec.count > 0) { - traceme_metadata.push_back(std::make_pair( - kReadRoundtripLatency, - strings::Printf( - "(count: %lld, mean: %lld, std dev: %lld)", - static_cast( - io_statistics.roundtrip_latency_usec.count), - static_cast(io_statistics.roundtrip_latency_usec.mean), - static_cast( - io_statistics.roundtrip_latency_usec.std_dev)))); - } - if (io_statistics.response_bytes.count > 0) { - traceme_metadata.push_back(std::make_pair( - kReadResponseBytes, - strings::Printf( - "(count: %lld, mean: %lld, std dev: %lld)", - static_cast(io_statistics.response_bytes.count), - static_cast(io_statistics.response_bytes.mean), - static_cast(io_statistics.response_bytes.std_dev)))); - } return traceme_metadata; } diff --git a/tensorflow/core/platform/host_info.h b/tensorflow/core/platform/host_info.h index caab7ae380b31b..89d495d6e41229 100644 --- a/tensorflow/core/platform/host_info.h +++ b/tensorflow/core/platform/host_info.h @@ -22,7 +22,6 @@ limitations under the License. namespace tensorflow { namespace port { using tsl::port::Hostname; -using tsl::port::IOStatistics; using tsl::port::JobName; using tsl::port::JobUid; } // namespace port diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc index 868fb35f887dab..c2151c78ec5330 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "absl/base/internal/sysinfo.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/numa.h" @@ -257,6 +256,7 @@ int NUMAGetThreadNodeAffinity() { return node_index; } + void* NUMAMalloc(int node, size_t size, int minimum_alignment) { #ifdef TENSORFLOW_USE_NUMA if (HaveHWLocTopology()) { @@ -307,6 +307,7 @@ int NUMAGetMemAffinity(const void* addr) { return node; } + bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -446,8 +447,5 @@ MemoryBandwidthInfo GetMemoryBandwidthInfo() { MemoryBandwidthInfo membw_info = {INT64_MAX}; return membw_info; } - -IOStatistics GetIOStatistics() { return IOStatistics(); } - } // namespace port } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/host_info.h b/third_party/xla/third_party/tsl/tsl/platform/host_info.h index 630f9424525e04..189f3be2934ce3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/host_info.h +++ b/third_party/xla/third_party/tsl/tsl/platform/host_info.h @@ -16,26 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ #define TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ -#include - #include "tsl/platform/types.h" namespace tsl { namespace port { -// Statistical data of IO operations performed by the job. -struct IOStatistics { - struct Distribution { - uint64_t count = 0; - double mean = 0.0; - double std_dev = 0.0; - }; - // Distribution of round trip IO latency in microseconds. - Distribution roundtrip_latency_usec; - // Distribution of data received by IO reads in bytes. - Distribution response_bytes; -}; - // Return the hostname of the machine on which this process is running. string Hostname(); @@ -49,9 +34,6 @@ int64_t JobUid(); // Returns the Borg task ID as an int64_t if it exists. Otherwise return -1. int64_t TaskId(); -// Retrieves the host file read statistics. -IOStatistics GetIOStatistics(); - } // namespace port } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc index f8e19503edb305..9b5692650dbb5c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc @@ -61,8 +61,6 @@ int64_t JobUid() { return -1; } int64_t TaskId() { return -1; } -IOStatistics GetIOStatistics() { return IOStatistics(); } - int NumSchedulableCPUs() { SYSTEM_INFO system_info; GetSystemInfo(&system_info); @@ -124,6 +122,7 @@ void NUMAFree(void* ptr, size_t size) { tsl::port::Free(ptr); } int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; } + bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -184,7 +183,7 @@ string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { DWORD data; DWORD data_size = sizeof(data); -#pragma comment(lib, "shlwapi.lib") // For SHGetValue(). + #pragma comment(lib, "shlwapi.lib") // For SHGetValue(). 
if (SUCCEEDED( SHGetValueA(HKEY_LOCAL_MACHINE, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", From f2a46836cfad96f9ab8c12e429fe6e0570483f5f Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Mon, 4 Dec 2023 14:54:36 -0800 Subject: [PATCH 361/381] [stream_executor] Remove GetSubBuffer PiperOrigin-RevId: 587849841 --- .../stream_executor/stream_executor.cc | 4 --- .../xla/xla/backends/interpreter/executor.cc | 6 ---- .../xla/xla/backends/interpreter/executor.h | 2 -- .../xla/stream_executor/cuda/cuda_executor.cc | 6 ---- .../xla/stream_executor/gpu/gpu_executor.h | 3 -- .../stream_executor/host/host_gpu_executor.cc | 5 --- .../stream_executor/host/host_gpu_executor.h | 2 -- .../stream_executor/rocm/rocm_gpu_executor.cc | 6 ---- .../stream_executor_internal.h | 2 -- .../stream_executor/stream_executor_pimpl.cc | 5 --- .../stream_executor/stream_executor_pimpl.h | 34 ------------------- .../xla/stream_executor/tpu/tpu_executor.h | 4 --- 12 files changed, 79 deletions(-) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 3fcd255a2248ab..12391143a4d9e0 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -230,10 +230,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface { DeviceMemoryBase Allocate(uint64 size) { return Allocate(size, /*memory_space=*/0); } - void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset, - uint64 size) override { - LOG(FATAL) << "GetSubBuffer is not supported by pluggable device."; - } void Deallocate(DeviceMemoryBase* mem) override { SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem); diff --git a/third_party/xla/xla/backends/interpreter/executor.cc b/third_party/xla/xla/backends/interpreter/executor.cc index 3766f7cb7af82c..1095c71a86b226 100644 --- a/third_party/xla/xla/backends/interpreter/executor.cc +++ b/third_party/xla/xla/backends/interpreter/executor.cc @@ -34,12 +34,6 @@ DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64_t size, return DeviceMemoryBase(new char[size], size); } -void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent, - uint64_t offset_bytes, - uint64_t /*size_bytes*/) { - return parent + offset_bytes; -} - void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { delete[] static_cast(mem->opaque()); } diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 683d83a15d58e0..5d866462950072 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -64,8 +64,6 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void *GetSubBuffer(DeviceMemoryBase *parent, uint64_t offset_bytes, - uint64_t size_bytes) override; void Deallocate(DeviceMemoryBase *mem) override; void *HostMemoryAllocate(uint64_t size) override { return new char[size]; } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 419590c30a08be..e29fd1b1806483 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -573,12 +573,6 @@ DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { return 
DeviceMemoryBase(GpuDriver::DeviceAllocate(context_, size), size); } -void* GpuExecutor::GetSubBuffer(DeviceMemoryBase* mem, uint64_t offset_bytes, - uint64_t size_bytes) { - // offset and size are in bytes, so char* works as the pointer type. - return reinterpret_cast(mem->opaque()) + offset_bytes; -} - void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 1055c167c9d946..5e3f46f901d681 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -140,9 +140,6 @@ class GpuExecutor : public internal::StreamExecutorInterface { DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void* GetSubBuffer(DeviceMemoryBase* mem, uint64_t offset_bytes, - uint64_t size_bytes) override; - void Deallocate(DeviceMemoryBase* mem) override; void* UnifiedMemoryAllocate(uint64_t size) override { diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc b/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc index 254d2ac02c4d39..906cad0c7f1aa6 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc @@ -73,11 +73,6 @@ DeviceMemoryBase HostExecutor::Allocate(uint64_t size, int64_t memory_space) { tsl::port::AlignedMalloc(size, /*minimum_alignment=*/64), size); } -void* HostExecutor::GetSubBuffer(DeviceMemoryBase* parent, - uint64_t offset_bytes, uint64_t size_bytes) { - return reinterpret_cast(parent->opaque()) + offset_bytes; -} - void HostExecutor::Deallocate(DeviceMemoryBase* mem) { tsl::port::AlignedFree(mem->opaque()); } diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h b/third_party/xla/xla/stream_executor/host/host_gpu_executor.h index 6ca6d0bb6594c7..51c97e6b59ddc2 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_gpu_executor.h @@ -60,8 +60,6 @@ class HostExecutor : public internal::StreamExecutorInterface { } DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset_bytes, - uint64_t size_bytes) override; void Deallocate(DeviceMemoryBase* mem) override; void* HostMemoryAllocate(uint64_t size) override { return new char[size]; } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc index 82e1c71186bd71..2a13df20ab5dc6 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -415,12 +415,6 @@ DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { return DeviceMemoryBase(GpuDriver::DeviceAllocate(context_, size), size); } -void* GpuExecutor::GetSubBuffer(DeviceMemoryBase* mem, uint64_t offset_bytes, - uint64_t size_bytes) { - // offset and size are in bytes, so char* works as the pointer type. 
- return reinterpret_cast(mem->opaque()) + offset_bytes; -} - void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index c7bd1736a8314b..70cef69edc7508 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -331,8 +331,6 @@ class StreamExecutorInterface { DeviceMemoryBase Allocate(uint64_t size) { return Allocate(size, /*memory_space=*/0); } - virtual void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset, - uint64_t size) = 0; virtual void Deallocate(DeviceMemoryBase* mem) = 0; // Allocates unified memory space of the given size, if supported. // See diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc index a58beb89c6bab2..aeadfa10f03dff 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc @@ -482,11 +482,6 @@ DeviceMemoryBase StreamExecutor::Allocate(uint64_t size, int64_t memory_space) { return buf; } -void* StreamExecutor::GetUntypedSubBuffer(DeviceMemoryBase* parent, - uint64_t offset, uint64_t size) { - return implementation_->GetSubBuffer(parent, offset, size); -} - tsl::StatusOr StreamExecutor::GetUntypedSymbol( const std::string& symbol_name, ModuleHandle module_handle) { // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h index d7235077d44f43..f6a655c6580b8e 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h @@ -155,18 +155,6 @@ class StreamExecutor { return AllocateOwnedArray(1); } - // Allocate a memory region inside another allocated memory region. - // Offset and size are specified in terms of T elements. - // Warning: Do not free a parent buffer before its sub-buffers; this may cause - // use-after-free issues (the specific behavior is not consistent across - // platforms). - // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a - // sub-buffer after parent deallocation is expected to be safe. This will - // render your code non-platform-portable, however. - template - DeviceMemory GetSubBuffer(DeviceMemory* parent, uint64_t element_offset, - uint64_t element_count); - // An untyped version of GetSymbol. tsl::StatusOr GetUntypedSymbol( const std::string& symbol_name, ModuleHandle module_handle); @@ -493,9 +481,6 @@ class StreamExecutor { // nullptr is returned. DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space); - void* GetUntypedSubBuffer(DeviceMemoryBase* parent, uint64_t offset, - uint64_t size); - // Causes the host code to synchronously wait for operations entrained // onto stream to complete. Effectively a join on the asynchronous device // operations enqueued on the stream before this program point. 
@@ -750,25 +735,6 @@ ScopedDeviceMemory::ScopedDeviceMemory( } } -template -DeviceMemory StreamExecutor::GetSubBuffer(DeviceMemory* parent, - uint64_t element_offset, - uint64_t element_count) { - if (element_offset + element_count > parent->ElementCount()) { - LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater " - << "than parent allocation size: (" << element_offset << " + " - << element_count << ") vs. (" << parent->ElementCount() << ")"; - return DeviceMemory{}; - } - - void* opaque = GetUntypedSubBuffer(parent, sizeof(T) * element_offset, - sizeof(T) * element_count); - if (opaque == nullptr) { - return DeviceMemory{}; - } - return DeviceMemory(DeviceMemoryBase(opaque, sizeof(T) * element_count)); -} - } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index 46ebe431a49fad..c4cd3ed11c1b9c 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -171,10 +171,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { LOG(FATAL) << "Not yet implemented"; } - void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset, - uint64_t size) override { - LOG(FATAL) << "not yet implemented"; - } tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, uint64_t size) override { LOG(FATAL) << "not yet implemented"; From 57c33a2cb286a865994d4a8cacb8808d6f69be8b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 4 Dec 2023 15:01:07 -0800 Subject: [PATCH 362/381] [xla:ffi] Add XLA_FFI_Error_GetMessage API PiperOrigin-RevId: 587851943 --- third_party/xla/xla/ffi/api/c_api.h | 12 ++++++++++++ third_party/xla/xla/ffi/api/ffi.h | 26 ++++++++++++++++++-------- third_party/xla/xla/ffi/ffi_api.cc | 13 +++++++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index cd6eb0fe1e7553..3092b7755531ca 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -118,6 +118,17 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_Create_Args, errc); typedef XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args); +struct XLA_FFI_Error_GetMessage_Args { + size_t struct_size; + void* priv; + XLA_FFI_Error* error; + const char* message; // out +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_GetMessage_Args, message); + +typedef void XLA_FFI_Error_GetMessage(XLA_FFI_Error_GetMessage_Args* args); + struct XLA_FFI_Error_Destroy_Args { size_t struct_size; void* priv; @@ -296,6 +307,7 @@ struct XLA_FFI_Api { XLA_FFI_InternalApi* internal_api; _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create); + _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_GetMessage); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get); diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index f44f6327091b5a..a2c662cabf6d92 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -255,7 +255,7 @@ struct CtxDecoding> { static std::optional Decode(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, - DiagnosticEngine&) { + DiagnosticEngine& diagnostic) { XLA_FFI_Stream_Get_Args args; args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; args.priv = nullptr; @@ -263,6 +263,8 @@ struct 
CtxDecoding> { args.stream = nullptr; if (XLA_FFI_Error* error = api->XLA_FFI_Stream_Get(&args); error) { + diagnostic.Emit("Failed to get platform stream: ") + << GetErrorMessage(api, error); DestroyError(api, error); return std::nullopt; } @@ -270,14 +272,22 @@ struct CtxDecoding> { return reinterpret_cast(args.stream); } - // TODO(ezhulenev): We need to log error message somewhere, currently we - // silently destroy it. + static const char* GetErrorMessage(const XLA_FFI_Api* api, + XLA_FFI_Error* error) { + XLA_FFI_Error_GetMessage_Args args; + args.struct_size = XLA_FFI_Error_GetMessage_Args_STRUCT_SIZE; + args.priv = nullptr; + args.error = error; + api->XLA_FFI_Error_GetMessage(&args); + return args.message; + } + static void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) { - XLA_FFI_Error_Destroy_Args destroy_args; - destroy_args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; - destroy_args.error = error; - api->XLA_FFI_Error_Destroy(&destroy_args); + XLA_FFI_Error_Destroy_Args args; + args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; + args.priv = nullptr; + args.error = error; + api->XLA_FFI_Error_Destroy(&args); } }; diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 303ba5ea6cdf05..e9d2bdf9b0fbb6 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -183,6 +183,18 @@ static XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args) { return new XLA_FFI_Error{Status(ToStatusCode(args->errc), args->message)}; } +static void XLA_FFI_Error_GetMessage(XLA_FFI_Error_GetMessage_Args* args) { + Status struct_size_check = ActualStructSizeIsGreaterOrEqual( + "XLA_FFI_Error_GetMessage", XLA_FFI_Error_GetMessage_Args_STRUCT_SIZE, + args->struct_size); + if (!struct_size_check.ok()) { + LOG(ERROR) << struct_size_check.message(); + } + // absl::Status owns error message in a std::string which guarantees that + // we'll get a null terminated string. + args->message = args->error->status.message().data(); +} + static void XLA_FFI_Error_Destroy(XLA_FFI_Error_Destroy_Args* args) { Status struct_size_check = ActualStructSizeIsGreaterOrEqual( "XLA_FFI_Error_Destroy", XLA_FFI_Error_Destroy_Args_STRUCT_SIZE, @@ -245,6 +257,7 @@ static XLA_FFI_Api api = { &internal_api, XLA_FFI_Error_Create, // creates error + XLA_FFI_Error_GetMessage, // get error message XLA_FFI_Error_Destroy, // frees error XLA_FFI_Handler_Register, // registers handler XLA_FFI_Stream_Get, // returns platform specific stream From cc8ae1a2d742a2f58e1e2319bd21b3840bd89cc8 Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Mon, 4 Dec 2023 15:14:05 -0800 Subject: [PATCH 363/381] Add dialect verification to dialect to verify_input_dialect_to_executor_pass. 
PiperOrigin-RevId: 587856220 --- .../compiler/mlir/tf2xla/internal/passes/BUILD | 1 + .../verify_input_dialect_to_executor_pass.cc | 9 +++++++++ ...rify_input_dialect_to_executor_pass_test.mlir | 16 ++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 367733da134745..1ddc3cbfd08d32 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -229,6 +229,7 @@ cc_library( deps = [ ":dialect_to_executor_passes_inc_gen", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tf2xla/internal/utils:dialect_detection_utils", "//tensorflow/core:framework", "//tensorflow/core/transforms/toposort:Pass", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc index fe08a58b726c00..bfaf67054134c6 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h" namespace tensorflow { namespace tf2xla { @@ -47,6 +48,14 @@ void VerifyInputDialectToExecutorPass::runOnOperation() { Operation* func_op = getOperation(); auto walk_result = func_op->walk([&](Operation* op) { + if (!tensorflow::tf2xla::internal::IsInBridgeAcceptableDialects(op)) { + std::string error = "op is in dialect " + + op->getDialect()->getNamespace().str() + + " which is not an accepted dialect"; + op->emitError() << error; + return WalkResult::interrupt(); + } + if (IsTfDeviceClusterFuncOp(op)) { std::string error = "failed TF functional to executor validation, op " diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir index 121f60b40ec5ee..5a6fda697d23fa 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir @@ -15,4 +15,20 @@ func.func @testClusterFuncOpFails(%arg0: tensor) -> tensor { // expected-error@below {{failed TF functional to executor validation, op tf_device.cluster_func is not allowed}} %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor) -> tensor func.return %cluster : tensor +} + +// ----- + +// CHECK-LABEL: func @testTFDialect +func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> { + %0 = "tf.Identity"(%arg0) : (tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> + func.return %0 : tensor<4x2x!tf_type.string> +} + +// ----- + +func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // expected-error@below {{op is in dialect chlo which is not an accepted dialect}} + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> 
tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> } \ No newline at end of file From 2e289d34ffd1a0aa322db2f18ef8bb98cdd43d87 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Mon, 4 Dec 2023 15:19:35 -0800 Subject: [PATCH 364/381] Refactor the splitting functionality out of XlaSplitNDBaseOp into a standalone utility PiperOrigin-RevId: 587857877 --- tensorflow/core/tpu/kernels/BUILD | 41 ++ .../core/tpu/kernels/sharding_util_ops.cc | 452 ++---------------- tensorflow/core/tpu/kernels/sharding_utils.cc | 237 +++++++++ tensorflow/core/tpu/kernels/sharding_utils.h | 308 ++++++++++++ .../core/tpu/kernels/sharding_utils_test.cc | 281 +++++++++++ 5 files changed, 900 insertions(+), 419 deletions(-) create mode 100644 tensorflow/core/tpu/kernels/sharding_utils.cc create mode 100644 tensorflow/core/tpu/kernels/sharding_utils.h create mode 100644 tensorflow/core/tpu/kernels/sharding_utils_test.cc diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index f2cfb5c75531f5..4f3c8b82373069 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -46,6 +46,7 @@ package_group( packages = [ "//tensorflow/compiler/mlir/quantization/...", "//tensorflow/compiler/mlir/tf2xla/...", + "//tensorflow/core/tfrt/ifrt/...", "//tensorflow/core/tpu/...", "//tensorflow/dtensor/...", "//third_party/py/jax_tpu_embedding/...", @@ -1342,6 +1343,7 @@ cc_library( name = "sharding_util_ops", srcs = ["sharding_util_ops.cc"], deps = [ + ":sharding_utils", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core/framework:op_requires", @@ -1349,6 +1351,7 @@ cc_library( "//tensorflow/core/platform:refcount", "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1360,6 +1363,44 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "sharding_utils", + srcs = ["sharding_utils.cc"], + hdrs = ["sharding_utils.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + ], +) + +tf_cc_test( + name = "sharding_utils_test", + srcs = ["sharding_utils_test.cc"], + deps = [ + ":sharding_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", + ], +) + tf_kernel_library( name = "global_iter_id_op", srcs = ["global_iter_id.cc"], diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops.cc b/tensorflow/core/tpu/kernels/sharding_util_ops.cc index 5513547b6bb67b..beb70bfd8b25b6 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops.cc @@ -15,11 +15,14 @@ limitations under the License. 
#include #include +#include +#include #include #include #define EIGEN_USE_THREADS +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -34,10 +37,12 @@ limitations under the License. #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tpu/kernels/sharding_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/macros.h" @@ -129,454 +134,63 @@ Status CreateResourceInvalidDTypeError(const ResourceHandle& handle, DataTypeString(expected_dtype), ".")); } -// Converts flatten index to start indices (subscript scaled with slice shape) -// for determining where to start a slice in the input tensor. -template -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); - -template -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, - const int index) { - return Eigen::DSizes(); -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[0] = index * slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[1] = (index % num_partitions[1]) * slice_shape[1]; - subscript[0] = (index / num_partitions[1]) * slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[2] = (index % num_partitions[2]) * slice_shape[2]; - subscript[1] = - ((index / num_partitions[2]) % num_partitions[1]) * slice_shape[1]; - subscript[0] = - (index / (num_partitions[2] * num_partitions[1])) * slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - 
const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[3] = (index % num_partitions[3]) * slice_shape[3]; - subscript[2] = - ((index / num_partitions[3]) % num_partitions[2]) * slice_shape[2]; - subscript[1] = - ((index / (num_partitions[3] * num_partitions[2])) % num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[3] * num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[4] = (index % num_partitions[4]) * slice_shape[4]; - subscript[3] = - ((index / num_partitions[4]) % num_partitions[3]) * slice_shape[3]; - subscript[2] = - ((index / (num_partitions[4] * num_partitions[3])) % num_partitions[2]) * - slice_shape[2]; - subscript[1] = - ((index / (num_partitions[4] * num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = (index / (num_partitions[4] * num_partitions[3] * - num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[5] = (index % num_partitions[5]) * slice_shape[5]; - subscript[4] = - ((index / num_partitions[5]) % num_partitions[4]) * slice_shape[4]; - subscript[3] = - ((index / (num_partitions[5] * num_partitions[4])) % num_partitions[3]) * - slice_shape[3]; - subscript[2] = - ((index / (num_partitions[5] * num_partitions[4] * num_partitions[3])) % - num_partitions[2]) * - slice_shape[2]; - subscript[1] = ((index / (num_partitions[5] * num_partitions[4] * - num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[5] * num_partitions[4] * num_partitions[3] * - num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[6] = (index % num_partitions[6]) * slice_shape[6]; - subscript[5] = - ((index / num_partitions[6]) % num_partitions[5]) * slice_shape[5]; - subscript[4] = - ((index / (num_partitions[6] * num_partitions[5])) % num_partitions[4]) * - slice_shape[4]; - subscript[3] = - ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4])) % - num_partitions[3]) * - slice_shape[3]; - subscript[2] = ((index / (num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3])) % - num_partitions[2]) * - slice_shape[2]; - subscript[1] = - ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * - num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * - num_partitions[3] * num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[7] = (index % num_partitions[7]) * slice_shape[7]; - subscript[6] = - ((index / num_partitions[7]) % num_partitions[6]) * slice_shape[6]; - subscript[5] = - ((index / (num_partitions[7] * num_partitions[6])) % num_partitions[5]) * - slice_shape[5]; - 
subscript[4] = - ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5])) % - num_partitions[4]) * - slice_shape[4]; - subscript[3] = ((index / (num_partitions[7] * num_partitions[6] * - num_partitions[5] * num_partitions[4])) % - num_partitions[3]) * - slice_shape[3]; - subscript[2] = - ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3])) % - num_partitions[2]) * - slice_shape[2]; - subscript[1] = - ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3] * num_partitions[2] * - num_partitions[1])) * - slice_shape[0]; - return subscript; -} - constexpr absl::string_view kTensorName = "'input' tensor"; constexpr absl::string_view kResourceName = "'resource' variable tensor"; -template -Eigen::DSizes TF_ATTRIBUTE_NOINLINE -ShapeAsEigenDSizes(const TensorShape& shape); -template -Eigen::DSizes ShapeAsEigenDSizes( - const TensorShape& shape) { - return shape.AsEigenDSizes(); -} - -bool TF_ATTRIBUTE_NOINLINE -ValidateShapesForSlice(OpKernelContext* ctx, bool resource, const Tensor* input, - const std::vector& num_splits, - const std::vector& paddings); - -bool ValidateShapesForSlice(OpKernelContext* ctx, bool resource, - const Tensor* input, - const std::vector& num_splits, - const std::vector& paddings) { - const auto& ishape = input->shape(); - - Status s; - - absl::string_view input_name = resource ? kResourceName : kTensorName; - const int rank = ishape.dims(); - const auto& input_shape = ishape.dim_sizes(); - if (rank <= 0 || rank > 8) { - s = absl::InvalidArgumentError(absl::StrCat( - input_name, " must have rank in range (0, 8], but got ", rank, ".")); - } else if (rank != num_splits.size()) { - s = absl::InvalidArgumentError(absl::StrCat( - input_name, " rank must be the same as 'num_splits' length ", - num_splits.size(), ", but got rank ", rank, ".")); - } else { - for (int dim = 0; dim < rank; ++dim) { - const auto input_shape_dim = input_shape[dim]; - const auto paddings_dim = paddings[dim]; - const auto num_splits_dim = num_splits[dim]; - if ((input_shape_dim + paddings_dim) % num_splits_dim != 0) { - s = absl::InvalidArgumentError(absl::StrCat( - input_name, " shape dimension ", dim, " (", input_shape_dim, - ") with padding ", paddings_dim, - " must be evenly divisible by 'num_splits' ", num_splits_dim, ".")); - break; - } - } - } - if (!s.ok()) { - ctx->CtxFailure(__FILE__, __LINE__, s); - return false; - } - return true; -} - // Shared base class to save code space +template class XlaSplitNDShared : public OpKernel { public: explicit TF_ATTRIBUTE_NOINLINE XlaSplitNDShared(OpKernelConstruction* ctx) - : OpKernel(ctx), num_slices_(1), has_paddings_(false) { - GetAndValidateAttributes(/*split=*/true, ctx, num_splits_, num_slices_, - paddings_, has_paddings_); + : OpKernel(ctx) { + std::vector num_splits; + int num_slices = 1; + std::vector paddings; + bool has_paddings = false; + + GetAndValidateAttributes(/*split=*/true, ctx, num_splits, num_slices, + paddings, has_paddings); + + auto xla_nd_splitter = XlaNDSplitter::Create( + num_splits, num_slices, paddings, has_paddings); + OP_REQUIRES_OK(ctx, xla_nd_splitter.status()); + splitter_ = *std::move(xla_nd_splitter); } protected: - template - class SliceAndMaybePadState { - public: - int num_complete_pad_dims_; - int 
num_partial_pad_dims_; - TensorShape non_padded_slice_shape_; - Eigen::array, Rank> slice_paddings_; - Eigen::DSizes slice_indices_; - Eigen::DSizes output_slice_shape_dsizes_; - Eigen::DSizes non_padded_slice_shape_dsizes_; - - TF_ATTRIBUTE_NOINLINE SliceAndMaybePadState( - absl::Span num_splits, - const absl::Span input_shape, - const TensorShape& output_slice_shape, int slice_index) { - output_slice_shape_dsizes_ = ShapeAsEigenDSizes(output_slice_shape); - num_complete_pad_dims_ = 0; - num_partial_pad_dims_ = 0; - slice_indices_ = GetSliceIndices( - num_splits, output_slice_shape_dsizes_, slice_index); - - // Calculate paddings necessary for slice instead of padding input and - // slicing subsequently to reduce temporary memory allocation. - for (int dim = 0; dim < Rank; ++dim) { - const int64_t dim_size = input_shape[dim]; - const int64_t out_dim = output_slice_shape_dsizes_[dim]; - int64_t non_padded_dim = 0; - if (slice_indices_[dim] >= dim_size) { - // Complete padding. - slice_indices_[dim] = dim_size; - non_padded_dim = 0; - slice_paddings_[dim] = {0, out_dim}; - num_complete_pad_dims_++; - } else if (slice_indices_[dim] + out_dim > dim_size) { - // Partial padding. - non_padded_dim = dim_size - slice_indices_[dim]; - slice_paddings_[dim] = {0, out_dim - non_padded_dim}; - num_partial_pad_dims_++; - } else { - non_padded_dim = out_dim; - } - non_padded_slice_shape_.AddDim(non_padded_dim); - } - non_padded_slice_shape_dsizes_ = - ShapeAsEigenDSizes(non_padded_slice_shape_); - } - }; - static void TF_ATTRIBUTE_NOINLINE GetDtypeHelper(OpKernelConstruction* ctx, const char* attr_name, DataType* dtype_ptr) { OP_REQUIRES_OK(ctx, ctx->GetAttr(attr_name, dtype_ptr)); } - std::vector num_splits_; - int num_slices_; - std::vector paddings_; - bool has_paddings_; + std::optional> splitter_; }; template -class XlaSplitNDBaseOp : public XlaSplitNDShared { +class XlaSplitNDBaseOp : public XlaSplitNDShared { public: explicit XlaSplitNDBaseOp(OpKernelConstruction* ctx) - : XlaSplitNDShared(ctx) {} + : XlaSplitNDShared(ctx) {} protected: void ComputeInternal( bool resource, OpKernelContext* ctx, const std::function& assign_or_copy_value_fn, const Tensor* input) { - const int rank = input->shape().dims(); const auto& input_shape = input->shape().dim_sizes(); - if (!ValidateShapesForSlice(ctx, resource, input, num_splits_, paddings_)) { - return; - } - - TensorShape output_slice_shape; - for (int i = 0; i < rank; ++i) { - output_slice_shape.AddDim((input_shape[i] + paddings_[i]) / - ((num_slices_ == 1) ? 
1 : num_splits_[i])); - } - if (num_slices_ == 1 && !has_paddings_) { - // Handle simple case first - OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(*input)); - } else { - const Device& device = ctx->eigen_device(); - std::vector output_slices(num_slices_); - for (int i = 0; i < num_slices_; i++) { - OP_REQUIRES_OK(ctx, - ctx->allocate_output( - /*index=*/i, output_slice_shape, &output_slices[i])); - } - - if (rank == 1) { - SliceAndMaybePad<1>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 2) { - SliceAndMaybePad<2>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 3) { - SliceAndMaybePad<3>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 4) { - SliceAndMaybePad<4>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 5) { - SliceAndMaybePad<5>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 6) { - SliceAndMaybePad<6>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 7) { - SliceAndMaybePad<7>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 8) { - SliceAndMaybePad<8>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } - return; - } - } - - private: - void TF_ATTRIBUTE_NOINLINE SetToConstant(Tensor* output_slice, - const Device& device) { - auto output_flat = output_slice->flat(); - output_flat.device(device) = output_flat.constant(T()); - } - - template - void TF_ATTRIBUTE_NOINLINE AssignFromInput( - Tensor* output_slice, const Device& device, const Tensor* input, - const Eigen::DSizes& slice_indices, - const Eigen::DSizes& output_slice_shape_dsizes) { - output_slice->tensor().device(device) = - input->tensor().slice(slice_indices, - output_slice_shape_dsizes); - } + absl::string_view input_name = resource ? kResourceName : kTensorName; + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + return ctx->allocate_output( + /*index=*/i, output_slice_shape, tensor); + }; - template - void TF_ATTRIBUTE_NOINLINE SliceAndMaybePad( - OpKernelContext* ctx, const Device& device, const Tensor* input, - const absl::Span input_shape, - const TensorShape& output_slice_shape, - const std::vector& output_slices) { - const auto& input_tensor = input->tensor(); - // Slice shape with optional padding. 
- for (int i = 0; i < num_slices_; ++i) { - Tensor* output_slice = output_slices[i]; - SliceAndMaybePadState r(num_splits_, input_shape, - output_slice_shape, i); - if (r.num_complete_pad_dims_ == Rank || - (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0)) { - // Need to init padding - SetToConstant(output_slice, device); - } - if (r.num_complete_pad_dims_ == Rank) { - // Done - } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { - output_slice->tensor() - .slice(Eigen::DSizes(), - r.non_padded_slice_shape_dsizes_) - .device(device) = input_tensor.slice( - r.slice_indices_, r.non_padded_slice_shape_dsizes_); - } else { - AssignFromInput(output_slice, device, input, r.slice_indices_, - r.output_slice_shape_dsizes_); - } - } + const Device& device = ctx->eigen_device(); + auto status = this->splitter_->Split( + input, input_name, assign_or_copy_value_fn, allocate_output_fn, device); + OP_REQUIRES_OK(ctx, status); } }; @@ -605,7 +219,7 @@ class ReadVariableXlaSplitNDOp : public XlaSplitNDBaseOp { explicit TF_ATTRIBUTE_NOINLINE ReadVariableXlaSplitNDOp( OpKernelConstruction* ctx) : XlaSplitNDBaseOp(ctx) { - XlaSplitNDShared::GetDtypeHelper(ctx, "T", &dtype_); + XlaSplitNDShared::GetDtypeHelper(ctx, "T", &dtype_); } void Compute(OpKernelContext* ctx) override { diff --git a/tensorflow/core/tpu/kernels/sharding_utils.cc b/tensorflow/core/tpu/kernels/sharding_utils.cc new file mode 100644 index 00000000000000..0f4b9620b347f3 --- /dev/null +++ b/tensorflow/core/tpu/kernels/sharding_utils.cc @@ -0,0 +1,237 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/tpu/kernels/sharding_utils.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/macros.h" + +namespace tensorflow { +namespace sharding_internal { +absl::Status ValidateShapesForSlice(absl::string_view input_name, + const Tensor* input, + const std::vector& num_splits, + const std::vector& paddings) { + const auto& ishape = input->shape(); + + Status s; + + const int rank = ishape.dims(); + const auto& input_shape = ishape.dim_sizes(); + if (rank <= 0 || rank > 8) { + s = absl::InvalidArgumentError(absl::StrCat( + input_name, " must have rank in range (0, 8], but got ", rank, ".")); + } else if (rank != num_splits.size()) { + s = absl::InvalidArgumentError(absl::StrCat( + input_name, " rank must be the same as 'num_splits' length ", + num_splits.size(), ", but got rank ", rank, ".")); + } else { + for (int dim = 0; dim < rank; ++dim) { + const auto input_shape_dim = input_shape[dim]; + const auto paddings_dim = paddings[dim]; + const auto num_splits_dim = num_splits[dim]; + if ((input_shape_dim + paddings_dim) % num_splits_dim != 0) { + s = absl::InvalidArgumentError(absl::StrCat( + input_name, " shape dimension ", dim, " (", input_shape_dim, + ") with padding ", paddings_dim, + " must be evenly divisible by 'num_splits' ", num_splits_dim, ".")); + break; + } + } + } + return s; +} + +} // namespace sharding_internal + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[0] = index * slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[1] = (index % num_partitions[1]) * slice_shape[1]; + subscript[0] = (index / num_partitions[1]) * slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[2] = (index % num_partitions[2]) * slice_shape[2]; + subscript[1] = + ((index / num_partitions[2]) % num_partitions[1]) * slice_shape[1]; + subscript[0] = + (index / (num_partitions[2] * num_partitions[1])) * slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[3] = (index % num_partitions[3]) * slice_shape[3]; + subscript[2] = + ((index / num_partitions[3]) % num_partitions[2]) * slice_shape[2]; + subscript[1] = + ((index / (num_partitions[3] * num_partitions[2])) % num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[3] * num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, 
const int index) { + Eigen::DSizes subscript; + subscript[4] = (index % num_partitions[4]) * slice_shape[4]; + subscript[3] = + ((index / num_partitions[4]) % num_partitions[3]) * slice_shape[3]; + subscript[2] = + ((index / (num_partitions[4] * num_partitions[3])) % num_partitions[2]) * + slice_shape[2]; + subscript[1] = + ((index / (num_partitions[4] * num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = (index / (num_partitions[4] * num_partitions[3] * + num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[5] = (index % num_partitions[5]) * slice_shape[5]; + subscript[4] = + ((index / num_partitions[5]) % num_partitions[4]) * slice_shape[4]; + subscript[3] = + ((index / (num_partitions[5] * num_partitions[4])) % num_partitions[3]) * + slice_shape[3]; + subscript[2] = + ((index / (num_partitions[5] * num_partitions[4] * num_partitions[3])) % + num_partitions[2]) * + slice_shape[2]; + subscript[1] = ((index / (num_partitions[5] * num_partitions[4] * + num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[5] * num_partitions[4] * num_partitions[3] * + num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[6] = (index % num_partitions[6]) * slice_shape[6]; + subscript[5] = + ((index / num_partitions[6]) % num_partitions[5]) * slice_shape[5]; + subscript[4] = + ((index / (num_partitions[6] * num_partitions[5])) % num_partitions[4]) * + slice_shape[4]; + subscript[3] = + ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4])) % + num_partitions[3]) * + slice_shape[3]; + subscript[2] = ((index / (num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3])) % + num_partitions[2]) * + slice_shape[2]; + subscript[1] = + ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * + num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * + num_partitions[3] * num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[7] = (index % num_partitions[7]) * slice_shape[7]; + subscript[6] = + ((index / num_partitions[7]) % num_partitions[6]) * slice_shape[6]; + subscript[5] = + ((index / (num_partitions[7] * num_partitions[6])) % num_partitions[5]) * + slice_shape[5]; + subscript[4] = + ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5])) % + num_partitions[4]) * + slice_shape[4]; + subscript[3] = ((index / (num_partitions[7] * num_partitions[6] * + num_partitions[5] * num_partitions[4])) % + num_partitions[3]) * + slice_shape[3]; + subscript[2] = + ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3])) % + num_partitions[2]) * + slice_shape[2]; + subscript[1] = + ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3] 
* num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3] * num_partitions[2] * + num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/sharding_utils.h b/tensorflow/core/tpu/kernels/sharding_utils.h new file mode 100644 index 00000000000000..a7388dc269ec06 --- /dev/null +++ b/tensorflow/core/tpu/kernels/sharding_utils.h @@ -0,0 +1,308 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { +namespace sharding_internal { +absl::Status ValidateShapesForSlice(absl::string_view input_name, + const Tensor* input, + const std::vector& num_splits, + const std::vector& paddings); +template +Eigen::DSizes TF_ATTRIBUTE_NOINLINE +ShapeAsEigenDSizes(const TensorShape& shape); +template +Eigen::DSizes ShapeAsEigenDSizes( + const TensorShape& shape) { + return shape.AsEigenDSizes(); +} + +} // namespace sharding_internal + +// Converts flatten index to start indices (subscript scaled with slice shape) +// for determining where to start a slice in the input tensor. 
+template +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); + +template +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, + const int index) { + return Eigen::DSizes(); +} + +// Shared base class to save code space +template +class XlaNDSplitter { + public: + static absl::StatusOr> Create( + const std::vector& num_splits, int num_slices, + const std::vector& paddings, bool has_paddings) { + if (num_splits.size() != paddings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_splits size ", num_splits.size(), + " mismatch with paddings size ", paddings.size(), ".")); + } + + int splits_cnt = 1; + for (auto split : num_splits) { + splits_cnt *= split; + } + + if (num_slices != splits_cnt) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect num_slices ", splits_cnt, " but got ", num_slices)); + } + + return XlaNDSplitter(num_splits, num_slices, paddings, + has_paddings); + } + + // Split the given input. + // + // The splitted outputs are stored into tensors allocated by + // `allocate_output_fn`. In the simple case of pass through (no split and no + // padding), the output is stored through the fast path by + // `assign_or_copy_value_fn`. + absl::Status Split( + const Tensor* input, absl::string_view input_name, + const std::function& assign_or_copy_value_fn, + const std::function& allocate_output_fn, + const Device& device) { + if (num_splits_.size() != paddings_.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_splits size ", num_splits_.size(), + " mismatch with paddings size ", paddings_.size(), ".")); + } + + const int rank = input->shape().dims(); + const auto& input_shape = input->shape().dim_sizes(); + + TF_RETURN_IF_ERROR(sharding_internal::ValidateShapesForSlice( + input_name, input, num_splits_, paddings_)); + + TensorShape output_slice_shape; + for (int i = 0; i < rank; ++i) { + output_slice_shape.AddDim((input_shape[i] + paddings_[i]) / + ((num_slices_ == 1) ? 
1 : num_splits_[i])); + } + if (num_slices_ == 1 && !has_paddings_) { + // Handle simple case first + TF_RETURN_IF_ERROR(assign_or_copy_value_fn(*input)); + } else { + std::vector output_slices(num_slices_); + for (int i = 0; i < num_slices_; i++) { + TF_RETURN_IF_ERROR(allocate_output_fn( + /*index=*/i, output_slice_shape, &output_slices[i])); + } + + if (rank == 1) { + SliceAndMaybePad<1>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 2) { + SliceAndMaybePad<2>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 3) { + SliceAndMaybePad<3>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 4) { + SliceAndMaybePad<4>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 5) { + SliceAndMaybePad<5>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 6) { + SliceAndMaybePad<6>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 7) { + SliceAndMaybePad<7>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 8) { + SliceAndMaybePad<8>(device, input, input_shape, output_slice_shape, + output_slices); + } + } + return absl::OkStatus(); + } + + private: + template + class SliceAndMaybePadState { + public: + int num_complete_pad_dims_; + int num_partial_pad_dims_; + TensorShape non_padded_slice_shape_; + Eigen::array, Rank> slice_paddings_; + Eigen::DSizes slice_indices_; + Eigen::DSizes output_slice_shape_dsizes_; + Eigen::DSizes non_padded_slice_shape_dsizes_; + + TF_ATTRIBUTE_NOINLINE SliceAndMaybePadState( + absl::Span num_splits, + const absl::Span input_shape, + const TensorShape& output_slice_shape, int slice_index) { + output_slice_shape_dsizes_ = + sharding_internal::ShapeAsEigenDSizes(output_slice_shape); + num_complete_pad_dims_ = 0; + num_partial_pad_dims_ = 0; + slice_indices_ = GetSliceIndices( + num_splits, output_slice_shape_dsizes_, slice_index); + + // Calculate paddings necessary for slice instead of padding input and + // slicing subsequently to reduce temporary memory allocation. + for (int dim = 0; dim < Rank; ++dim) { + const int64_t dim_size = input_shape[dim]; + const int64_t out_dim = output_slice_shape_dsizes_[dim]; + int64_t non_padded_dim = 0; + if (slice_indices_[dim] >= dim_size) { + // Complete padding. + slice_indices_[dim] = dim_size; + non_padded_dim = 0; + slice_paddings_[dim] = {0, out_dim}; + num_complete_pad_dims_++; + } else if (slice_indices_[dim] + out_dim > dim_size) { + // Partial padding. 
+ non_padded_dim = dim_size - slice_indices_[dim]; + slice_paddings_[dim] = {0, out_dim - non_padded_dim}; + num_partial_pad_dims_++; + } else { + non_padded_dim = out_dim; + } + non_padded_slice_shape_.AddDim(non_padded_dim); + } + non_padded_slice_shape_dsizes_ = + sharding_internal::ShapeAsEigenDSizes(non_padded_slice_shape_); + } + }; + + std::vector num_splits_; + int num_slices_; + std::vector paddings_; + bool has_paddings_; + + explicit XlaNDSplitter(const std::vector& num_splits, int num_slices, + const std::vector& paddings, + bool has_paddings) + : num_splits_(num_splits), + num_slices_(num_slices), + paddings_(paddings), + has_paddings_(has_paddings) {} + + void TF_ATTRIBUTE_NOINLINE SetToConstant(Tensor* output_slice, + const Device& device) { + auto output_flat = output_slice->flat(); + output_flat.device(device) = output_flat.constant(T()); + } + + template + void TF_ATTRIBUTE_NOINLINE AssignFromInput( + Tensor* output_slice, const Device& device, const Tensor* input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& output_slice_shape_dsizes) { + output_slice->tensor().device(device) = + input->tensor().slice(slice_indices, + output_slice_shape_dsizes); + } + + template + void TF_ATTRIBUTE_NOINLINE + SliceAndMaybePad(const Device& device, const Tensor* input, + const absl::Span input_shape, + const TensorShape& output_slice_shape, + const std::vector& output_slices) { + const auto& input_tensor = input->tensor(); + // Slice shape with optional padding. + for (int i = 0; i < num_slices_; ++i) { + Tensor* output_slice = output_slices[i]; + SliceAndMaybePadState r(num_splits_, input_shape, + output_slice_shape, i); + if (r.num_complete_pad_dims_ == Rank || + (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0)) { + // Need to init padding + SetToConstant(output_slice, device); + } + if (r.num_complete_pad_dims_ == Rank) { + // Done + } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { + output_slice->tensor() + .slice(Eigen::DSizes(), + r.non_padded_slice_shape_dsizes_) + .device(device) = input_tensor.slice( + r.slice_indices_, r.non_padded_slice_shape_dsizes_); + } else { + AssignFromInput(output_slice, device, input, r.slice_indices_, + r.output_slice_shape_dsizes_); + } + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ diff --git a/tensorflow/core/tpu/kernels/sharding_utils_test.cc b/tensorflow/core/tpu/kernels/sharding_utils_test.cc new file mode 100644 index 00000000000000..df4eacec84455a --- /dev/null +++ b/tensorflow/core/tpu/kernels/sharding_utils_test.cc @@ -0,0 +1,281 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/tpu/kernels/sharding_utils.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace { +Eigen::ThreadPoolDevice CreateThreadPoolDevice() { + constexpr int kMaxParallelism = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); + + Eigen::ThreadPoolDevice device(thread_pool->AsEigenThreadPool(), + kMaxParallelism); + return device; +} + +TEST(XlaNDSplitterTest, NoSplits) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2, 2}); + const std::vector num_splits = {1, 1, 1}; + const std::vector paddings(num_splits.size(), 0); + const int num_outputs = 1; + auto input_tensor = + test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/false))); + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, + TensorShape({2, 2, 2}))); +} + +TEST(XlaNDSplitterTest, NoSplitsWithPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 1, 1}); + const std::vector num_splits = {1, 1, 1}; + const std::vector paddings = {0, 1, 1}; + const int num_outputs = 1; + auto input_tensor = test::AsTensor({0, 1}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, 
+ allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + std::vector expected_values(3 * 3 * 3); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 0, 0, 0, 1, 0, 0, 0}, + TensorShape({2, 2, 2}))); +} + +TEST(XlaNDSplitterTest, SplitNoPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({4, 4}); + const std::vector num_splits = {2, 2}; + const std::vector paddings(num_splits.size(), 0); + const int num_outputs = 4; + auto input_tensor = test::AsTensor( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), num_outputs); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 1, 4, 5}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[1], + test::AsTensor({2, 3, 6, 7}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[2], + test::AsTensor({8, 9, 12, 13}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[3], + test::AsTensor({10, 11, 14, 15}, TensorShape({2, 2}))); +} + +TEST(XlaNDSplitterTest, SplitPartialPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({3, 3}); + const std::vector num_splits = {2, 2}; + const std::vector paddings = {1, 1}; + const int num_outputs = 4; + auto input_tensor = + test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7, 8}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), num_outputs); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 1, 3, 4}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[1], + test::AsTensor({2, 0, 5, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[2], + test::AsTensor({6, 7, 0, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[3], + test::AsTensor({8, 0, 0, 0}, TensorShape({2, 2}))); +} + 
+TEST(XlaNDSplitterTest, SplitCompletePadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 1}); + const std::vector num_splits = {2, 2}; + const std::vector paddings = {2, 3}; + const int num_outputs = 4; + auto input_tensor = test::AsTensor({0, 1}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), num_outputs); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 0, 1, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[1], + test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[2], + test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[3], + test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); +} + +} // namespace +} // namespace tensorflow From 6acc044ef96a44d3adc7fdb4a4a93cfb8af7389f Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Mon, 4 Dec 2023 15:38:27 -0800 Subject: [PATCH 365/381] [stream_executor] Use GetSlice to create sub-buffer PiperOrigin-RevId: 587862987 --- third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 7 ++----- third_party/xla/xla/service/gpu/buffer_allocations.cc | 4 +--- third_party/xla/xla/stream_executor/device_memory.h | 3 ++- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 7070a25f5c67f9..c942e3ea17ab50 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -308,9 +308,7 @@ class AsyncHostToDeviceTransferManager CHECK_LE(offset, buffer_memory.size()); CHECK_LE(transfer_size, buffer_memory.size() - offset); if (transfer_size < buffer_memory.size()) { - sub_buffer = se::DeviceMemoryBase( - reinterpret_cast(buffer_memory.opaque()) + offset, - transfer_size); + sub_buffer = buffer_memory.GetByteSlice(offset, transfer_size); } else { sub_buffer = buffer_memory; } @@ -482,8 +480,7 @@ PjRtFuture StreamExecutorGpuClient::CopyRawSubBufferToHost( std::unique_ptr sub_buffer; if (transfer_size < device_memory.size()) { sub_buffer = std::make_unique( - reinterpret_cast(device_memory.opaque()) + offset, - transfer_size); + device_memory.GetByteSlice(offset, transfer_size)); } else { sub_buffer = std::make_unique(device_memory); } diff --git a/third_party/xla/xla/service/gpu/buffer_allocations.cc b/third_party/xla/xla/service/gpu/buffer_allocations.cc index fc55f0951436df..6d7aff6f74dbaa 100644 --- a/third_party/xla/xla/service/gpu/buffer_allocations.cc +++ b/third_party/xla/xla/service/gpu/buffer_allocations.cc @@ -85,9 +85,7 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( 
   se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index());
   CHECK_LE(buffer_slice.offset(), base.size());
   CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size());
-  return se::DeviceMemoryBase(
-      static_cast(base.opaque()) + buffer_slice.offset(),
-      buffer_slice.size());
+  return base.GetByteSlice(buffer_slice.offset(), buffer_slice.size());
 }
 
 Status BufferAllocations::AddExternalAllocation(
diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h
index e4f3643402f037..f5548dd9ca1fd9 100644
--- a/third_party/xla/xla/stream_executor/device_memory.h
+++ b/third_party/xla/xla/stream_executor/device_memory.h
@@ -92,7 +92,8 @@ class DeviceMemoryBase {
 
   // Creates a memory region (slice) inside another allocated memory region.
   // Offset and size are in bytes.
-  DeviceMemoryBase GetByteSlice(uint64_t offset_bytes, uint64_t size_bytes) {
+  DeviceMemoryBase GetByteSlice(uint64_t offset_bytes,
+                                uint64_t size_bytes) const {
     DCHECK(offset_bytes + size_bytes <= size_)
         << "requested slice allocation (offset + size) is greater "
         << "than parent allocation size: (" << offset_bytes << " + "

From b92cbae291265b2084eaa5b3925e0a8579cbaff6 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 4 Dec 2023 15:39:27 -0800
Subject: [PATCH 366/381] [stream_executor] Roll back KernelLaunchContext for
 args packing

Block and thread dimensions are already available in device kernels, so there
should be no reason to add extra kernel parameters for them.

For CUTLASS gemm args packing, we know thread dimensions statically from an
operation template.

PiperOrigin-RevId: 587863262
---
 .../gpu/kernels/cutlass_gemm_kernel.cu.h      | 17 ++++++++-------
 .../cuda/cuda_command_buffer_test.cc          |  2 +-
 .../xla/stream_executor/cuda/cuda_executor.cc |  3 +--
 third_party/xla/xla/stream_executor/kernel.cc |  8 -------
 third_party/xla/xla/stream_executor/kernel.h  | 21 +------------------
 .../xla/xla/stream_executor/kernel_spec.h     |  4 ++--
 6 files changed, 14 insertions(+), 41 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h
index dd5cdeecf90d28..33186c25caf782 100644
--- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h
@@ -151,10 +151,8 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size,
   static_assert(sizeof(Params) < 512,
                 "Params struct size is unexpectedly large");
 
-  using PackedArgs = StatusOr>;
-
-  return [=](const se::KernelLaunchContext &ctx,
-             const se::KernelArgs &args) -> PackedArgs {
+  return [=](const se::Kernel &kernel, const se::KernelArgs &args)
+             -> StatusOr> {
     auto *mem_args = Cast(&args);
 
     cutlass::Status can_implement = Kernel::can_implement(problem_size);
@@ -191,10 +189,13 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size,
         lda, ldb, ldc, ldc  // strides
     );
 
-    // Query kernel API for SM occupancy for the launch dimensions.
-    TF_ASSIGN_OR_RETURN(int32_t sm_occupancy,
-                        ctx.kernel()->GetMaxOccupiedBlocksPerCore(
-                            ctx.threads(), args.number_of_shared_bytes()));
+    // We keep max_occupancy in a static variable as currently for all practical
+    // purposes all stream executors in the process have identical underlying
+    // devices, and there is no need to repeatedly query this property.
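+    // Concretely, `shared_mem_bytes` below is taken from the CUTLASS
+    // SharedStorage type, and `sm_occupancy` falls back to 1 if the occupancy
+    // query fails (the value_or(1)); both are computed once and then reused
+    // by subsequent launches.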
+ static int32_t shared_mem_bytes = sizeof(typename Kernel::SharedStorage); + static int32_t sm_occupancy = + kernel.GetMaxOccupiedBlocksPerCore(ThreadDim(), shared_mem_bytes) + .value_or(1); // Convert CUTLASS operation arguments to a device kernel parameters. Params params(arguments, device_sms, sm_occupancy); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 2c40db007be26b..1a8a31de2b1e1b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -114,7 +114,7 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { // Register a kernel with a custom arguments packing function that packs // device memory arguments into a struct with pointers. - MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelLaunchContext&, + MultiKernelLoaderSpec spec(/*arity=*/1, [&](const Kernel& kernel, const KernelArgs& args) { auto bufs = Cast(&args)->device_memory_args(); auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index e29fd1b1806483..97912df3bd6d66 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -469,8 +469,7 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, "Kernel is missing a custom arguments packing function for device " "memory arguments array"); - KernelLaunchContext ctx(&kernel, block_dims, thread_dims); - TF_ASSIGN_OR_RETURN(auto packed, pack(ctx, *device_mem)); + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); return launch(*packed); } diff --git a/third_party/xla/xla/stream_executor/kernel.cc b/third_party/xla/xla/stream_executor/kernel.cc index 4bf8f3987df9b6..078f7142c40000 100644 --- a/third_party/xla/xla/stream_executor/kernel.cc +++ b/third_party/xla/xla/stream_executor/kernel.cc @@ -47,14 +47,6 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) { shared_memory_bytes_ = shared_memory_bytes; } -//===----------------------------------------------------------------------===// -// KernelLaunchContext -//===----------------------------------------------------------------------===// - -KernelLaunchContext::KernelLaunchContext(const Kernel *kernel, BlockDim blocks, - ThreadDim threads) - : kernel_(kernel), blocks_(blocks), threads_(threads) {} - //===----------------------------------------------------------------------===// // Kernel //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 1a3ad3b4775cb3..c0cd10454a3280 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -211,25 +211,6 @@ class KernelArgsPackedArrayBase : public KernelArgs { Kind kind() const final { return Kind::kPackedArray; } }; -//===----------------------------------------------------------------------===// -// KernelLaunchContext -//===----------------------------------------------------------------------===// - -// Properties of a kernel launch that might impact kernel arguments packing. 
-class KernelLaunchContext { - public: - KernelLaunchContext(const Kernel *kernel, BlockDim blocks, ThreadDim threads); - - const Kernel *kernel() const { return kernel_; } - BlockDim blocks() const { return blocks_; } - ThreadDim threads() const { return threads_; } - - private: - const Kernel *kernel_; - BlockDim blocks_; - ThreadDim threads_; -}; - //===----------------------------------------------------------------------===// // Kernel //===----------------------------------------------------------------------===// @@ -247,7 +228,7 @@ class Kernel { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelLaunchContext &ctx, const KernelArgs &args)>; + const Kernel &kernel, const KernelArgs &args)>; Kernel(Kernel &&from); diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h index e49eae4ec322b4..742cfa476381be 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.h +++ b/third_party/xla/xla/stream_executor/kernel_spec.h @@ -61,8 +61,8 @@ limitations under the License. namespace stream_executor { +class Kernel; // defined in kernel.h class KernelArgs; // defined in kernel.h -class KernelLaunchContext; // defined in kernel.h class KernelArgsPackedArrayBase; // defined in kernel.h // Describes how to load a kernel on a target platform. @@ -263,7 +263,7 @@ class MultiKernelLoaderSpec { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelLaunchContext &ctx, const KernelArgs &args)>; + const Kernel &kernel, const KernelArgs &args)>; explicit MultiKernelLoaderSpec( size_t arity, KernelArgsPacking kernel_args_packing = nullptr); From 50c4460a9fef149d563de4476e28dc22aaa2229b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 4 Dec 2023 15:48:24 -0800 Subject: [PATCH 367/381] [xla:gpu] Add CUTLASS gemm benchmarks PiperOrigin-RevId: 587865806 --- third_party/xla/xla/service/gpu/kernels/BUILD | 19 +++++ .../cutlass_gemm_custom_kernel_benchmarks.cc | 85 +++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index f3ffcb803211f3..b98e7bb7aec733 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -168,6 +168,25 @@ cc_library( # "@local_tsl//tsl/lib/core:status_test_util", # "@local_tsl//tsl/platform:status", # "@local_tsl//tsl/platform:test", +# "@local_tsl//tsl/platform:test_main", +# ], +# ) +# +# xla_test( +# name = "cutlass_gemm_custom_kernel_benchmarks", +# srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_benchmarks.cc"]), +# backends = ["gpu"], +# deps = [ +# ":cutlass_gemm_custom_kernel", +# "//xla:types", +# "//xla:xla_data_proto_cc", +# "//xla/stream_executor", +# "//xla/stream_executor:multi_platform_manager", +# "//xla/stream_executor:platform", +# "//xla/stream_executor/cuda:cuda_platform", +# "@local_tsl//tsl/lib/core:status_test_util", +# "@local_tsl//tsl/platform:status", +# "@local_tsl//tsl/platform:test", # "@local_tsl//tsl/platform:test_benchmark", # "@local_tsl//tsl/platform:test_main", # ], diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc new file mode 100644 index 00000000000000..027b8abaa9b25c --- /dev/null +++ 
b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -0,0 +1,85 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/status.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::gpu::kernel::gemm_universal { + +static uint32_t BitPattern(float value) { + uint32_t pattern; + std::memcpy(&pattern, &value, sizeof(float)); + return pattern; +} + +static void BM_RowMajorGemm(benchmark::State& state) { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + const se::DeviceDescription& device = executor->GetDeviceDescription(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + se::Kernel gemm(executor); + + // GEMM: 8192x4096 * 4096x16384 -> 8192x16384 + int32_t m = 8192; + int32_t n = 16384; + int32_t k = 4096; + + auto custom_kernel = + GetCutlassGemmKernel("cutlass_gemm", PrimitiveType::BF16, m, n, k, + /*indices=*/{0, 1, 2}, /*slices=*/{}, device); + TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); + + // Prepare arguments: a=1.1, b=1.2, c=0.0 + se::DeviceMemory a = executor->AllocateArray(m * k, 0); + se::DeviceMemory b = executor->AllocateArray(k * n, 0); + se::DeviceMemory c = executor->AllocateArray(m * n, 0); + + stream.ThenMemset32(&a, BitPattern(1.1f), a.size()); + stream.ThenMemset32(&b, BitPattern(1.2f), b.size()); + stream.ThenMemZero(&c, c.size()); + + se::KernelArgsDeviceMemoryArray args( + std::vector({a, b, c}), + custom_kernel->shared_memory_bytes()); + + for (auto s : state) { + TF_CHECK_OK(executor->Launch(&stream, custom_kernel->thread_dims(), + custom_kernel->block_dims(), gemm, args)); + TF_CHECK_OK(stream.BlockHostUntilDone()); + } +} + +BENCHMARK(BM_RowMajorGemm); + +} // namespace xla::gpu::kernel::gemm_universal From 363c14b79a8db6cdde9b06abb0e42387a3ab3926 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 15:59:47 -0800 Subject: [PATCH 368/381] Some more cleanup: 1. Deduplicate the postprocessing code for dots and convs. 2. Combine the InferInputShardingForTopK function with GetInputSharding function, and get rid of an unused parameter in the later. 
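
For item 2, the call shape changes roughly as sketched below (illustrative
only; the authoritative edits are in the diff that follows, and num_devices
stands in for cluster_env.NumDevices() at the call sites):

    // Before: the operand pointer was threaded through but never used.
    GetInputSharding(ins, operand, op_index, output_sharding, call_graph,
                     num_devices);
    // After: the operand parameter is gone and TopK handling is folded into
    // GetInputSharding itself.
    GetInputSharding(ins, op_index, output_sharding, call_graph, num_devices);
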
PiperOrigin-RevId: 587868586 --- .../auto_sharding/auto_sharding.cc | 96 ++++++------------- .../auto_sharding/auto_sharding_util.cc | 16 +++- .../auto_sharding/auto_sharding_util.h | 1 - 3 files changed, 40 insertions(+), 73 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 0f0b92b07ca2d8..da425db30813b9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -143,13 +143,6 @@ std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id) { return strategy_group; } -// ShardingPropagation::GetShardingFromUser does not handle TopK custom -// calls. Mirroring that function's handling of kSort, we handle TopK below. -HloSharding InferInputShardingForTopK(const HloInstruction* ins, - const HloSharding& output_sharding) { - return output_sharding; -} - // Compute the resharding costs as well as input shardings (when missing) for // all operands of a given instruction, and an output sharding for that // instruction. @@ -173,13 +166,12 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( } } else { std::optional cur_input_sharding; + CHECK_EQ(input_shardings.size(), ins->operand_count()); if (input_shardings[k].has_value()) { - CHECK_EQ(input_shardings.size(), ins->operand_count()); cur_input_sharding = input_shardings[k]; } else { - cur_input_sharding = - GetInputSharding(ins, operand, k, output_sharding, call_graph, - cluster_env.NumDevices()); + cur_input_sharding = GetInputSharding( + ins, k, output_sharding, call_graph, cluster_env.NumDevices()); } bool is_sharding_default_replicated = false; if (!cur_input_sharding.has_value()) { @@ -187,8 +179,6 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( (ins->opcode() == HloOpcode::kScatter && k != 0)) { is_sharding_default_replicated = true; cur_input_sharding = HloSharding::Replicate(); - } else if (IsTopKCustomCall(ins)) { - cur_input_sharding = InferInputShardingForTopK(ins, output_sharding); } else if (ins->opcode() == HloOpcode::kCustomCall) { is_sharding_default_replicated = true; cur_input_sharding = HloSharding::Replicate(); @@ -2031,23 +2021,33 @@ Status SetHloShardingPostProcessing( // Here we insert some extra annotated identity instructions to help the // spmd partitioner generate correct code. 
- if (inst->opcode() == HloOpcode::kDot) { + if (inst->opcode() == HloOpcode::kDot || + inst->opcode() == HloOpcode::kConvolution) { const ShardingStrategy& stra = GetShardingStrategy(inst, strategy_map, cost_graph, s_val); const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); const HloSharding& lhs_sharding = lhs->sharding(); const HloSharding& rhs_sharding = rhs->sharding(); - const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); - const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions(); - const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions(); + std::vector lhs_con_dims; + std::vector rhs_con_dims; + if (inst->opcode() == HloOpcode::kDot) { + const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); + lhs_con_dims.push_back(dot_dnums.lhs_contracting_dimensions()[0]); + rhs_con_dims.push_back(dot_dnums.rhs_contracting_dimensions()[0]); + } else { + const ConvolutionDimensionNumbers& conv_dnums = + inst->convolution_dimension_numbers(); + lhs_con_dims.push_back(conv_dnums.input_feature_dimension()); + rhs_con_dims.push_back(conv_dnums.kernel_input_feature_dimension()); + } - const auto& lhs_tensor_dim_to_mesh_dim = + const std::vector& lhs_tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper( lhs->shape(), lhs_sharding, /* consider_reverse_device_meshes */ true, /* crash_at_error */ crash_at_error); - const auto& rhs_tensor_dim_to_mesh_dim = + const std::vector& rhs_tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper( rhs->shape(), rhs_sharding, /* consider_reverse_device_meshes */ true, @@ -2058,10 +2058,17 @@ Status SetHloShardingPostProcessing( return absl::InvalidArgumentError( "Cannot generate tensor dim to mesh dim mapping"); } + if (absl::StrContains(stra.name, "allreduce") && - lhs_tensor_dim_to_mesh_dim[lhs_con_dims[0]] == -1 && - rhs_tensor_dim_to_mesh_dim[rhs_con_dims[0]] == -1) { - // Allow duplicatd dot computation in this case to reduce + std::any_of(lhs_con_dims.begin(), lhs_con_dims.end(), + [&lhs_tensor_dim_to_mesh_dim](int64_t dim) { + return lhs_tensor_dim_to_mesh_dim[dim] == -1; + }) && + std::any_of(rhs_con_dims.begin(), rhs_con_dims.end(), + [&rhs_tensor_dim_to_mesh_dim](int64_t dim) { + return rhs_tensor_dim_to_mesh_dim[dim] == -1; + })) { + // Allow duplicated dot computation in this case to reduce // communication } else { CHECK(stra.input_shardings.size() == 2) @@ -2077,51 +2084,6 @@ Status SetHloShardingPostProcessing( device_mesh, resharding_cache); } } - } else if (inst->opcode() == HloOpcode::kConvolution) { - const ShardingStrategy& stra = - GetShardingStrategy(inst, strategy_map, cost_graph, s_val); - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - const HloSharding& lhs_sharding = lhs->sharding(); - const HloSharding& rhs_sharding = rhs->sharding(); - const ConvolutionDimensionNumbers& conv_dnums = - inst->convolution_dimension_numbers(); - const int lhs_in_channel_dim = conv_dnums.input_feature_dimension(); - const int rhs_in_channel_dim = - conv_dnums.kernel_input_feature_dimension(); - - const auto& lhs_tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper( - lhs->shape(), lhs_sharding, - /* consider_reverse_device_meshes */ true, - /* crash_at_error */ crash_at_error); - const auto& rhs_tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper( - rhs->shape(), rhs_sharding, - /* consider_reverse_device_meshes */ true, - /* crash_at_error */ crash_at_error); - - if 
(lhs_tensor_dim_to_mesh_dim.size() != lhs->shape().rank() || - rhs_tensor_dim_to_mesh_dim.size() != rhs->shape().rank()) { - return absl::InvalidArgumentError( - "Cannot generate tensor dim to mesh dim mapping"); - } - - if (absl::StrContains(stra.name, "allreduce") && - lhs_tensor_dim_to_mesh_dim[lhs_in_channel_dim] == -1 && - rhs_tensor_dim_to_mesh_dim[rhs_in_channel_dim] == -1) { - // Allow duplicatd conv computation in this case to reduce - // communication - } else { - if (stra.input_shardings[0].has_value()) { - FixMixedMeshShapeResharding(inst, 0, stra.input_shardings[0].value(), - device_mesh, resharding_cache); - } - if (stra.input_shardings[1].has_value()) { - FixMixedMeshShapeResharding(inst, 1, stra.input_shardings[1].value(), - device_mesh, resharding_cache); - } - } } else if (inst->opcode() == HloOpcode::kOutfeed) { // Outfeed operand shardings are handled in downstream passes and so we // ignore outfeed ops here. However, we need to ensure that outfeed ops diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index d5965d057576a0..0acb7a2c9868c8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -68,12 +68,11 @@ inline HloInstruction* PassThroughCustomCallMarkerUser( HloInstruction* raw_user, const HloInstruction* inst); std::optional GetInputSharding(const HloInstruction* ins, - const HloInstruction* operand, int64_t op_index, const HloSharding& output_sharding, const CallGraph& call_graph, int64_t num_devices) { - auto ins_clone = ins->Clone(); + std::unique_ptr ins_clone = ins->Clone(); ins_clone->set_sharding(output_sharding); std::vector> operands; @@ -95,9 +94,16 @@ std::optional GetInputSharding(const HloInstruction* ins, operands.push_back(std::move(operand_clone)); } - auto result = ShardingPropagation::GetShardingFromUser( - *ins_clone->operand(op_index), *ins_clone, 10, true, call_graph); - return result; + std::optional inferred_sharding = + ShardingPropagation::GetShardingFromUser( + *ins_clone->operand(op_index), *ins_clone, 10, true, call_graph); + + if (!inferred_sharding.has_value() && IsTopKCustomCall(ins)) { + // ShardingPropagation::GetShardingFromUser does not handle TopK custom + // calls. Mirroring that function's handling of kSort, we handle TopK below. + inferred_sharding = output_sharding; + } + return inferred_sharding; } // Return whether the instruction is an activation from another pipeline stage. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 290d92f45a86a0..b6aca1199eb1ee 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -357,7 +357,6 @@ inline std::vector Argsort(const std::vector& scores) { // Given the sharding for an instruction, invoke the sharding propagation pass // to infer appropriate shardings for its operands. 
std::optional GetInputSharding(const HloInstruction* ins, - const HloInstruction* operand, int64_t op_index, const HloSharding& output_sharding, const xla::CallGraph& call_graph, From c909884fcb756f0ea1d2e392d1829c11d9312848 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Mon, 4 Dec 2023 16:02:17 -0800 Subject: [PATCH 369/381] Refactor the concatenate functionality out of XlaConcatNDBaseOp into a standalone utility PiperOrigin-RevId: 587869254 --- tensorflow/core/tpu/kernels/BUILD | 2 + .../core/tpu/kernels/sharding_util_ops.cc | 123 ++---------- tensorflow/core/tpu/kernels/sharding_utils.h | 148 +++++++++++++++ .../core/tpu/kernels/sharding_utils_test.cc | 175 ++++++++++++++++++ 4 files changed, 339 insertions(+), 109 deletions(-) diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 4f3c8b82373069..48b4711d37cf75 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -1379,6 +1379,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1394,6 +1395,7 @@ tf_cc_test( "//tensorflow/core/platform:status", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops.cc b/tensorflow/core/tpu/kernels/sharding_util_ops.cc index beb70bfd8b25b6..fe726011527165 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops.cc @@ -285,12 +285,18 @@ TF_CALL_uint4(REGISTER_READ_VARIABLE_XLA_SPLIT_ND); #undef REGISTER_READ_VARIABLE_XLA_SPLIT_ND // Shared base class to save code space +template class XlaConcatNDShared : public OpKernel { public: explicit TF_ATTRIBUTE_NOINLINE XlaConcatNDShared(OpKernelConstruction* ctx) : OpKernel(ctx), num_slices_(1), has_paddings_(false) { GetAndValidateAttributes(/*split=*/false, ctx, num_concats_, num_slices_, paddings_, has_paddings_); + + auto xla_nd_concatenator = XlaNDConcatenator::Create( + num_concats_, num_slices_, paddings_, has_paddings_); + OP_REQUIRES_OK(ctx, xla_nd_concatenator.status()); + concatenator_ = *std::move(xla_nd_concatenator); } protected: @@ -328,132 +334,31 @@ class XlaConcatNDShared : public OpKernel { return absl::OkStatus(); } - void ApplyAssignOrCopyShared( - OpKernelContext* ctx, - const std::function& assign_or_copy_value_fn, - const Tensor& input) { - OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(input)); - } - - template - class MaybeUnpadAndAssignState { - public: - int num_complete_pad_dims_; - int num_partial_pad_dims_; - TensorShape non_padded_slice_shape_; - Eigen::DSizes slice_shape_dsizes_; - Eigen::array, Rank> slice_paddings_; - Eigen::DSizes slice_indices_; - Eigen::DSizes output_slice_shape_dsizes_; - Eigen::DSizes non_padded_slice_shape_dsizes_; - - TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssignState( - absl::Span num_concats, const Tensor& input0, - Tensor* output, int slice_index) { - slice_shape_dsizes_ = input0.shape().AsEigenDSizes(); - slice_indices_ = - GetSliceIndices(num_concats, slice_shape_dsizes_, slice_index); - num_complete_pad_dims_ = 0; - num_partial_pad_dims_ = 0; - // Calculate paddings necessary to strip from slice. 
- for (int dim = 0; dim < Rank; ++dim) { - const int64_t dim_size = output->shape().dim_size(dim); - int64_t non_padded_dim = 0; - if (slice_indices_[dim] >= dim_size) { - // Complete padding. - slice_indices_[dim] = dim_size; - non_padded_dim = 0; - num_complete_pad_dims_++; - } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) { - // Partial padding. - non_padded_dim = dim_size - slice_indices_[dim]; - num_partial_pad_dims_++; - } else { - non_padded_dim = slice_shape_dsizes_[dim]; - } - non_padded_slice_shape_.AddDim(non_padded_dim); - } - non_padded_slice_shape_dsizes_ = - non_padded_slice_shape_.AsEigenDSizes(); - } - }; std::vector num_concats_; int num_slices_; std::vector paddings_; bool has_paddings_; + std::optional> concatenator_; }; template -class XlaConcatNDBaseOp : public XlaConcatNDShared { +class XlaConcatNDBaseOp : public XlaConcatNDShared { public: explicit TF_ATTRIBUTE_NOINLINE XlaConcatNDBaseOp(OpKernelConstruction* ctx) - : XlaConcatNDShared(ctx) {} + : XlaConcatNDShared(ctx) {} protected: void ComputeInternal( bool resource, OpKernelContext* ctx, const OpInputList& inputs, const std::function& assign_or_copy_value_fn, const std::function()>& get_output_fn) { - const int rank = inputs[0].shape().dims(); - - OP_REQUIRES(ctx, rank > 0 && rank <= 8, - absl::InvalidArgumentError(absl::StrCat( - "'inputs' tensors must have rank in range (0, 8], but got ", - rank, "."))); - - if (num_slices_ == 1 && !has_paddings_) { - // Simple case - ApplyAssignOrCopyShared(ctx, assign_or_copy_value_fn, inputs[0]); - return; - } - const Device& device = ctx->eigen_device(); - auto status_or_output = get_output_fn(); - OP_REQUIRES_OK(ctx, status_or_output.status()); - Tensor* output = std::move(status_or_output).value(); - - if (rank == 1) { - MaybeUnpadAndAssign<1>(ctx, device, inputs, output); - } else if (rank == 2) { - MaybeUnpadAndAssign<2>(ctx, device, inputs, output); - } else if (rank == 3) { - MaybeUnpadAndAssign<3>(ctx, device, inputs, output); - } else if (rank == 4) { - MaybeUnpadAndAssign<4>(ctx, device, inputs, output); - } else if (rank == 5) { - MaybeUnpadAndAssign<5>(ctx, device, inputs, output); - } else if (rank == 6) { - MaybeUnpadAndAssign<6>(ctx, device, inputs, output); - } else if (rank == 7) { - MaybeUnpadAndAssign<7>(ctx, device, inputs, output); - } else if (rank == 8) { - MaybeUnpadAndAssign<8>(ctx, device, inputs, output); - } - } - - private: - template - void TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssign(OpKernelContext* ctx, - const Device& device, - const OpInputList& inputs, - Tensor* output) { - for (int i = 0; i < num_slices_; ++i) { - MaybeUnpadAndAssignState r(num_concats_, inputs[0], output, i); - if (r.num_complete_pad_dims_ == Rank) { - continue; - } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { - output->tensor() - .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_) - .device(device) = inputs[i].tensor().slice( - Eigen::DSizes(), - r.non_padded_slice_shape_dsizes_); - } else { - output->tensor() - .slice(r.slice_indices_, r.slice_shape_dsizes_) - .device(device) = inputs[i].tensor(); - } - } + std::vector input_tensors(inputs.begin(), inputs.end()); + auto status = this->concatenator_->ComputeInternal( + absl::MakeSpan(input_tensors), assign_or_copy_value_fn, get_output_fn, + device); + OP_REQUIRES_OK(ctx, status); } }; diff --git a/tensorflow/core/tpu/kernels/sharding_utils.h b/tensorflow/core/tpu/kernels/sharding_utils.h index a7388dc269ec06..429e327462ad74 100644 --- 
a/tensorflow/core/tpu/kernels/sharding_utils.h +++ b/tensorflow/core/tpu/kernels/sharding_utils.h @@ -26,11 +26,13 @@ limitations under the License. #include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/status.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace sharding_internal { @@ -303,6 +305,152 @@ class XlaNDSplitter { } }; +// Shared base class to save code space +template +class XlaNDConcatenator { + public: + static absl::StatusOr> Create( + const std::vector& num_concats, int num_slices, + const std::vector& paddings, bool has_paddings) { + if (num_concats.size() != paddings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_concats size ", num_concats.size(), + " mismatch with paddings size ", paddings.size(), ".")); + } + + int concats_cnt = 1; + for (auto concat : num_concats) { + concats_cnt *= concat; + } + + if (num_slices != concats_cnt) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect num_slices ", concats_cnt, " but got ", num_slices)); + } + + return XlaNDConcatenator(num_concats, num_slices, paddings, + has_paddings); + } + absl::Status ComputeInternal( + absl::Span inputs, + const std::function& assign_or_copy_value_fn, + const std::function()>& get_output_fn, + const Device& device) { + const int rank = inputs[0].shape().dims(); + + if (rank < 1 || rank > 8) { + return absl::InvalidArgumentError(absl::StrCat( + "'inputs' tensors must have rank in range (0, 8], but got ", rank, + ".")); + } + + if (num_slices_ == 1 && !has_paddings_) { + // Simple case + return assign_or_copy_value_fn(inputs[0]); + } + + TF_ASSIGN_OR_RETURN(Tensor * output, get_output_fn()); + + if (rank == 1) { + MaybeUnpadAndAssign<1>(device, inputs, output); + } else if (rank == 2) { + MaybeUnpadAndAssign<2>(device, inputs, output); + } else if (rank == 3) { + MaybeUnpadAndAssign<3>(device, inputs, output); + } else if (rank == 4) { + MaybeUnpadAndAssign<4>(device, inputs, output); + } else if (rank == 5) { + MaybeUnpadAndAssign<5>(device, inputs, output); + } else if (rank == 6) { + MaybeUnpadAndAssign<6>(device, inputs, output); + } else if (rank == 7) { + MaybeUnpadAndAssign<7>(device, inputs, output); + } else if (rank == 8) { + MaybeUnpadAndAssign<8>(device, inputs, output); + } + return absl::OkStatus(); + } + + private: + template + class MaybeUnpadAndAssignState { + public: + int num_complete_pad_dims_; + int num_partial_pad_dims_; + TensorShape non_padded_slice_shape_; + Eigen::DSizes slice_shape_dsizes_; + Eigen::array, Rank> slice_paddings_; + Eigen::DSizes slice_indices_; + Eigen::DSizes output_slice_shape_dsizes_; + Eigen::DSizes non_padded_slice_shape_dsizes_; + + TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssignState( + absl::Span num_concats, const Tensor& input0, + Tensor* output, int slice_index) { + slice_shape_dsizes_ = input0.shape().AsEigenDSizes(); + slice_indices_ = + GetSliceIndices(num_concats, slice_shape_dsizes_, slice_index); + num_complete_pad_dims_ = 0; + num_partial_pad_dims_ = 0; + // Calculate paddings necessary to strip from slice. 
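+      // Per dimension there are three cases, mirroring the branches below:
+      // the slice lies entirely past the output extent (pure padding), it
+      // starts in range but runs past the extent (only the leading
+      // `dim_size - slice_index` elements are real data), or it fits entirely
+      // within the output and is all real data.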
+ for (int dim = 0; dim < Rank; ++dim) { + const int64_t dim_size = output->shape().dim_size(dim); + int64_t non_padded_dim = 0; + if (slice_indices_[dim] >= dim_size) { + // Complete padding. + slice_indices_[dim] = dim_size; + non_padded_dim = 0; + num_complete_pad_dims_++; + } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) { + // Partial padding. + non_padded_dim = dim_size - slice_indices_[dim]; + num_partial_pad_dims_++; + } else { + non_padded_dim = slice_shape_dsizes_[dim]; + } + non_padded_slice_shape_.AddDim(non_padded_dim); + } + non_padded_slice_shape_dsizes_ = + non_padded_slice_shape_.AsEigenDSizes(); + } + }; + + std::vector num_concats_; + int num_slices_; + std::vector paddings_; + bool has_paddings_; + + explicit TF_ATTRIBUTE_NOINLINE XlaNDConcatenator( + const std::vector& num_concats, int num_slices, + const std::vector& paddings, bool has_paddings) + : num_concats_(num_concats), + num_slices_(num_slices), + paddings_(paddings), + has_paddings_(has_paddings) {} + + template + void TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssign(const Device& device, + absl::Span inputs, + Tensor* output) { + for (int i = 0; i < num_slices_; ++i) { + MaybeUnpadAndAssignState r(num_concats_, inputs[0], output, i); + if (r.num_complete_pad_dims_ == Rank) { + continue; + } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { + output->tensor() + .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_) + .device(device) = inputs[i].tensor().slice( + Eigen::DSizes(), + r.non_padded_slice_shape_dsizes_); + } else { + output->tensor() + .slice(r.slice_indices_, r.slice_shape_dsizes_) + .device(device) = inputs[i].tensor(); + } + } + } +}; + } // namespace tensorflow #endif // TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ diff --git a/tensorflow/core/tpu/kernels/sharding_utils_test.cc b/tensorflow/core/tpu/kernels/sharding_utils_test.cc index df4eacec84455a..cd583df8a57bef 100644 --- a/tensorflow/core/tpu/kernels/sharding_utils_test.cc +++ b/tensorflow/core/tpu/kernels/sharding_utils_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -277,5 +278,179 @@ TEST(XlaNDSplitterTest, SplitCompletePadding) { test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); } +TEST(XlaNDConcatenatorTest, NoConcats) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2, 2}); + const TensorShape output_shape({2, 2, 2}); + const std::vector num_concats = {1, 1, 1}; + const std::vector paddings(num_concats.size(), 0); + int num_slices = 1; + auto tensor0 = test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, + TensorShape({2, 2, 2}))); +} + +TEST(XlaNDConcatenatorTest, ConcatNoPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2}); + const TensorShape output_shape({4, 4}); + const std::vector num_concats = {2, 2}; + const std::vector paddings(num_concats.size(), 0); + int num_slices = 4; + auto tensor0 = test::AsTensor({0, 1, 2, 3}, input_shape); + auto tensor1 = test::AsTensor({4, 5, 6, 7}, input_shape); + auto tensor2 = test::AsTensor({8, 9, 10, 11}, input_shape); + auto tensor3 = test::AsTensor({12, 13, 14, 15}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + input_tensors.push_back(tensor1); + input_tensors.push_back(tensor2); + input_tensors.push_back(tensor3); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 4, 5, 2, 3, 6, 7, 8, 9, + 12, 13, 10, 11, 14, 15}, + TensorShape({4, 4}))); +} + +TEST(XlaNDConcatenatorTest, ConcatPartialPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2}); + const TensorShape output_shape({3, 3}); + const std::vector num_concats = {2, 2}; + const std::vector paddings = {1, 1}; + int num_slices = 4; + auto tensor0 = test::AsTensor({0, 1, 2, 3}, input_shape); + auto tensor1 = test::AsTensor({4, 5, 6, 7}, input_shape); + auto tensor2 = 
test::AsTensor({8, 9, 10, 11}, input_shape); + auto tensor3 = test::AsTensor({12, 13, 14, 15}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + input_tensors.push_back(tensor1); + input_tensors.push_back(tensor2); + input_tensors.push_back(tensor3); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 4, 2, 3, 6, 8, 9, 12}, + TensorShape({3, 3}))); +} + +TEST(XlaNDConcatenatorTest, ConcatCompletePadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2}); + const TensorShape output_shape({2, 2}); + const std::vector num_concats = {2, 2}; + const std::vector paddings = {2, 2}; + int num_slices = 4; + auto tensor0 = test::AsTensor({0, 1, 2, 3}, input_shape); + auto tensor1 = test::AsTensor({4, 5, 6, 7}, input_shape); + auto tensor2 = test::AsTensor({8, 9, 10, 11}, input_shape); + auto tensor3 = test::AsTensor({12, 13, 14, 15}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + input_tensors.push_back(tensor1); + input_tensors.push_back(tensor2); + input_tensors.push_back(tensor3); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 1, 2, 3}, TensorShape({2, 2}))); +} + } // namespace } // namespace tensorflow From dbf1465fd2bfc2c1c49be3a30128b4389b05e86f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 16:21:19 -0800 Subject: [PATCH 370/381] When printing a sharding group, also print the index of the corresponding HloValue in the producer instruction, if the producer instruction is a tuple. 
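
For example, such a node's dump now includes a line like the following (the
index value is illustrative; only the label comes from the change):

    index in producer inst.: 2
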
PiperOrigin-RevId: 587874472 --- .../hlo/experimental/auto_sharding/auto_sharding_strategy.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 24fa988df57299..a37ceca17e47a5 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -160,6 +160,10 @@ struct StrategyGroup { absl::StrAppend(&str, indent, "node_idx: ", node_idx, "\n"); absl::StrAppend(&str, indent, "instruction id: ", instruction_id, "\n"); absl::StrAppend(&str, indent, "is_tuple: ", is_tuple, "\n"); + if (tuple_element_idx.has_value()) { + absl::StrAppend(&str, indent, + "index in producer inst.: ", *tuple_element_idx, "\n"); + } if (following != nullptr) { absl::StrAppend(&str, indent, "following instruction: ", following->instruction_id, From 1c2e66aca8e6968716c85ee4a3f99933e216262d Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 4 Dec 2023 16:26:49 -0800 Subject: [PATCH 371/381] Integrate StableHLO at openxla/stablehlo@57e5a4a5 PiperOrigin-RevId: 587875856 --- third_party/stablehlo/temporary.patch | 59 +++---------------- third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 59 +++---------------- .../xla/third_party/stablehlo/workspace.bzl | 4 +- .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 35 +++++++---- .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td | 36 +++++++++++ .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td | 5 +- .../mhlo/transforms/map_stablehlo_to_hlo_op.h | 1 + .../mhlo/hlo-legalize-to-stablehlo.mlir | 13 ++++ .../mhlo/stablehlo-legalize-to-hlo.mlir | 13 ++++ .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 5 ++ 11 files changed, 115 insertions(+), 119 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index b9fde3ceefd71a..2aa71c79d13891 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -2291,29 +2291,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -340,6 +340,19 @@ - %1 = stablehlo.constant dense<2> : tensor - %2 = stablehlo.multiply %0, %1 : tensor - func.return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @eval_or -+func.func @eval_or() -> tensor { -+ // CHECK-NOT: stablehlo.or -+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor -+ // CHECK: return [[RESULT]] -+ %0 = stablehlo.constant dense : tensor -+ %1 = stablehlo.constant dense : tensor -+ %2 = stablehlo.or %0, %1 : tensor -+ func.return %2 : tensor - } - - // ----- diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h @@ -2592,32 +2569,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl namespace { // DenseElementsAttr can be constructed from ArrayRef but not from -@@ -304,6 +493,20 @@ - } - }; - -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ 
LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ - struct EvalRemOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RemOp op, -@@ -407,245 +610,6 @@ - // In a nutshell, they use the upstream type inference infrastructure and a +@@ -422,245 +611,6 @@ // StableHLO-specific extension to refine return types based on potentially // refined operands. -- + -// Refines the values using the given types. -// Tricky implementation details: -// 1) Need to support partial shape refinements, e.g. if just a single @@ -2856,10 +2811,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl - return rewriter.notifyMatchFailure(op, "expected constant output shape"); - return refineReturnShape(rewriter, op, shape); -} - +- struct RefineAllGatherOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; -@@ -1105,39 +1069,8 @@ + LogicalResult matchAndRewrite(AllGatherOp op, +@@ -1119,39 +1069,8 @@ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; void runOnOperation() override { @@ -2901,7 +2857,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl // The algorithm behind this pass consists of a single traversal of the // function. This is sufficient because we only support one function per -@@ -1153,43 +1086,7 @@ +@@ -1167,44 +1086,7 @@ config.strictMode = GreedyRewriteStrictness::AnyOp; RewritePatternSet patterns(&getContext()); @@ -2917,6 +2873,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); +- patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); @@ -2946,7 +2903,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl if (failed( applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { return signalPassFailure(); -@@ -1198,5 +1095,86 @@ +@@ -1213,5 +1095,86 @@ }; } // namespace diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index a60b5db8e74b5d..c02e3754069f67 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "83f095e7217c897f1eccac5652600ceb944cb0e0" - STABLEHLO_SHA256 = "00e442f7e9c8a52a1ac774ce997f8b5a99d12450c4dfe1594df816dcbad5126f" + STABLEHLO_COMMIT = "57e5a4a528a7e999f53c3719c1e68587efb9f0e6" + STABLEHLO_SHA256 = "50cd2240766e11f042c508c1155bc6da8afc4fb56070f04d7c11b4a9e344e94e" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index b9fde3ceefd71a..2aa71c79d13891 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -2291,29 +2291,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // 
namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -340,6 +340,19 @@ - %1 = stablehlo.constant dense<2> : tensor - %2 = stablehlo.multiply %0, %1 : tensor - func.return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @eval_or -+func.func @eval_or() -> tensor { -+ // CHECK-NOT: stablehlo.or -+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor -+ // CHECK: return [[RESULT]] -+ %0 = stablehlo.constant dense : tensor -+ %1 = stablehlo.constant dense : tensor -+ %2 = stablehlo.or %0, %1 : tensor -+ func.return %2 : tensor - } - - // ----- diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h @@ -2592,32 +2569,10 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl namespace { // DenseElementsAttr can be constructed from ArrayRef but not from -@@ -304,6 +493,20 @@ - } - }; - -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ - struct EvalRemOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RemOp op, -@@ -407,245 +610,6 @@ - // In a nutshell, they use the upstream type inference infrastructure and a +@@ -422,245 +611,6 @@ // StableHLO-specific extension to refine return types based on potentially // refined operands. -- + -// Refines the values using the given types. -// Tricky implementation details: -// 1) Need to support partial shape refinements, e.g. if just a single @@ -2856,10 +2811,11 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl - return rewriter.notifyMatchFailure(op, "expected constant output shape"); - return refineReturnShape(rewriter, op, shape); -} - +- struct RefineAllGatherOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; -@@ -1105,39 +1069,8 @@ + LogicalResult matchAndRewrite(AllGatherOp op, +@@ -1119,39 +1069,8 @@ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; void runOnOperation() override { @@ -2901,7 +2857,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl // The algorithm behind this pass consists of a single traversal of the // function. 
This is sufficient because we only support one function per -@@ -1153,43 +1086,7 @@ +@@ -1167,44 +1086,7 @@ config.strictMode = GreedyRewriteStrictness::AnyOp; RewritePatternSet patterns(&getContext()); @@ -2917,6 +2873,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); +- patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); @@ -2946,7 +2903,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl if (failed( applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { return signalPassFailure(); -@@ -1198,5 +1095,86 @@ +@@ -1213,5 +1095,86 @@ }; } // namespace diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index a60b5db8e74b5d..c02e3754069f67 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "83f095e7217c897f1eccac5652600ceb944cb0e0" - STABLEHLO_SHA256 = "00e442f7e9c8a52a1ac774ce997f8b5a99d12450c4dfe1594df816dcbad5126f" + STABLEHLO_COMMIT = "57e5a4a528a7e999f53c3719c1e68587efb9f0e6" + STABLEHLO_SHA256 = "50cd2240766e11f042c508c1155bc6da8afc4fb56070f04d7c11b4a9e344e94e" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 50aa7b33341bf7..c518a9025c42ee 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -334,17 +334,6 @@ LogicalResult TypeExtensionsAttr::verifyEncoding( getBounds(), RankedTensorType::get(shape, elementType), emitError); } -//===----------------------------------------------------------------------===// -// CollectivePermuteOp -//===----------------------------------------------------------------------===// - -void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState, - Type resultType, Value operand, - DenseIntElementsAttr sourceTargetPairs) { - CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand, - sourceTargetPairs, /*channel_handle=*/nullptr); -} - //===----------------------------------------------------------------------===// // ReduceScatterOp //===----------------------------------------------------------------------===// @@ -392,6 +381,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectiveBroadcastOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CopyOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp) @@ -1488,10 +1478,33 @@ LogicalResult AbsOp::inferReturnTypes( return hlo::inferAbsOp(location, adaptor.getOperand(), inferredReturnTypes); } +//===----------------------------------------------------------------------===// +// CollectiveBroadcastOp +//===----------------------------------------------------------------------===// + +void CollectiveBroadcastOp::build(OpBuilder& odsBuilder, + OperationState& odsState, Type resultType, + Value operand, + DenseIntElementsAttr 
replicaGroups) { + CollectiveBroadcastOp::build(odsBuilder, odsState, resultType, operand, + replicaGroups, /*channel_handle=*/nullptr); +} + +LogicalResult CollectiveBroadcastOp::verify() { + return hlo::verifyCollectiveBroadcastOp(getLoc(), getReplicaGroups()); +} + //===----------------------------------------------------------------------===// // CollectivePermuteOp //===----------------------------------------------------------------------===// +void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState, + Type resultType, Value operand, + DenseIntElementsAttr sourceTargetPairs) { + CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand, + sourceTargetPairs, /*channel_handle=*/nullptr); +} + LogicalResult CollectivePermuteOp::verify() { return hlo::verifyCollectivePermuteOp(getLoc(), getSourceTargetPairs()); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index e566a7463a64b4..b93368e08446db 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2243,6 +2243,42 @@ def MHLO_ConcatenateOp : MHLO_ShapedInterfaceOp<"concatenate", let hasFolder = 1; } +def MHLO_CollectiveBroadcastOp: MHLO_Op<"collective_broadcast", + [HLO_CompatibleOperandsAndResultType]> { + let summary = "CollectiveBroadcast operation"; + let description = [{ + Within each process group in the process grid, send the value of the + `operand` tensor from the source process to the target processes and produce a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast + + Example: + ```mlir + %result = "mhlo.collective_broadcast"(%operand) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + ``` + }]; + + let arguments = (ins + MHLO_Tensor:$operand, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle + ); + let results = (outs MHLO_Tensor); + let hasCustomHLOConverter = 1; + let hasVerifier = 1; + // channel_handle is only used for the SPMD partitioner, so we add a + // simplified builder method for convenience. + let builders = [ + OpBuilder<(ins + "::mlir::Type":$result_type, "::mlir::Value":$operand, + "::mlir::DenseIntElementsAttr":$replica_groups)>]; +} + def MHLO_CollectivePermuteOp: MHLO_Op<"collective_permute", [Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "CollectivePermute operation"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index 08c4619591525a..70572e8807cf93 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -157,8 +157,9 @@ def MHLO_ArgResultAlias : AttrDef { } // Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. +// optionally for collective instructions (AllToAll, AllReduce, +// CollectiveBroadcast, and CollectivePermute). Non-positive channel_id +// handle is equivalent to no channel id. 
def MHLO_ChannelHandle : AttrDef { let mnemonic = "channel_handle"; let parameters = (ins "int64_t":$handle, "int64_t":$type); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h index 3667563ac078e7..bed720482d0588 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h @@ -64,6 +64,7 @@ MAP_STABLEHLO_TO_HLO(CeilOp) MAP_STABLEHLO_TO_HLO(CholeskyOp) MAP_STABLEHLO_TO_HLO(ClampOp) MAP_STABLEHLO_TO_HLO(ClzOp) +MAP_STABLEHLO_TO_HLO(CollectiveBroadcastOp) MAP_STABLEHLO_TO_HLO(CollectivePermuteOp) MAP_STABLEHLO_TO_HLO(CompareOp) MAP_STABLEHLO_TO_HLO(ComplexOp) diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index fa860a960abe29..f763e6fd89a613 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -559,6 +559,19 @@ func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "op_collective_broadcast" +func.func @op_collective_broadcast(%arg0: tensor<1x2xi64>) -> tensor<1x2xi64> { + // CHECK: "stablehlo.collective_broadcast"(%arg0) { + // CHECK-SAME: channel_handle = #stablehlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + // CHECK-SAME: } : (tensor<1x2xi64>) -> tensor<1x2xi64> + %0 = "mhlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + func.return %0 : tensor<1x2xi64> +} + // CHECK-LABEL: "op_collective_permute" func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { // CHECK: "stablehlo.collective_permute"(%arg0) { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index ce40b9d6e19d13..901e7bbfeaa558 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -545,6 +545,19 @@ func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "op_collective_broadcast" +func.func @op_collective_broadcast(%arg0: tensor<1x2xi64>) -> tensor<1x2xi64> { + // CHECK: "mhlo.collective_broadcast"(%arg0) { + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + // CHECK-SAME: } : (tensor<1x2xi64>) -> tensor<1x2xi64> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + func.return %0 : tensor<1x2xi64> +} + // CHECK-LABEL: "op_collective_permute" func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { // CHECK: "mhlo.collective_permute"(%arg0) { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 4b8b0d25699002..f774b4e38c923e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ 
b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -835,6 +835,11 @@ namespace mlir { namespace mhlo { namespace { +LogicalResult ExportXlaOp(CollectiveBroadcastOp, OpLoweringContext) { + // TODO: b/314330871 - Implement MHLO export for CollectiveBroadcastOp. + return failure(); +} + LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) { // This op should've been removed during PrepareForExport. return failure(); From e97f5e430fda8268dbf7eefa5ab31e108bcc742c Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Mon, 4 Dec 2023 17:05:32 -0800 Subject: [PATCH 372/381] Change seed for a test We got unlucky and hit a seed which happens to fail the KS test. PiperOrigin-RevId: 587885112 --- .../compiler/tests/stateless_random_ops_test.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 01142082ae24f5..166ea0be43b3f5 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -15,7 +15,6 @@ """Tests for stateless random-number generation ops.""" import functools -import os from absl.testing import parameterized import numpy as np @@ -268,10 +267,11 @@ def testDistributionOfStatelessRandomUniform(self, alg, dtype, seed): # maxval != 1. y = y.astype(float) / maxval # Tests that the values are distributed amongst 10 bins with equal - # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with - # p=0.05. This test is probabilistic and would be flaky if the random + # probability. 27.88 is the Chi^2 value for 9 degrees of freedom with + # p=0.001. This test is probabilistic and would be flaky if the random # seed were not fixed. - self.assertLess(random_test_util.chi_squared(y, 10), 16.92) + bins = 10 + self.assertLess(random_test_util.chi_squared(y, bins), 27.88) def testRandomNormalIsFinite(self): with self.session() as sess, self.test_scope(): @@ -308,16 +308,12 @@ def testTruncatedNormal(self, dtype): x = stateless.stateless_truncated_normal( shape=[n], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) - is_megacore = 'megacore' in os.environ.get('TEST_TARGET', '').lower() if dtype == dtypes.float16: - if is_megacore: - mean_atol = 2e-3 - else: - mean_atol = 7e-4 + mean_atol = 2e-3 else: mean_atol = 5e-4 - if dtype == dtypes.float16 and is_megacore: + if dtype == dtypes.float16: median_atol = 2e-3 else: median_atol = 8e-4 From 227a72e54dbab121ed34021364ae040974180184 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Mon, 4 Dec 2023 17:25:41 -0800 Subject: [PATCH 373/381] Import ragged_tensor.py in the ragged __init__.py file. PiperOrigin-RevId: 587889341 --- tensorflow/python/ops/ragged/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/ops/ragged/__init__.py b/tensorflow/python/ops/ragged/__init__.py index c9d9a79dad753f..457e54641c6953 100644 --- a/tensorflow/python/ops/ragged/__init__.py +++ b/tensorflow/python/ops/ragged/__init__.py @@ -25,3 +25,4 @@ API docstring: tensorflow.ragged """ +from tensorflow.python.ops.ragged import ragged_tensor From 7e85121a698d1cded8fd777437a93f1f0185cfc8 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 4 Dec 2023 17:48:20 -0800 Subject: [PATCH 374/381] Add BufferDonor support to cpu_compiler.cc. 
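For reference, a minimal sketch of where the new pass is wired in (illustrative only; the call mirrors the cpu_compiler.cc hunk below, while the constructor argument name is an assumption):

  #include "xla/service/hlo_pass_pipeline.h"
  #include "xla/service/optimize_input_output_buffer_alias.h"

  // Appended to the CPU pipeline that runs after layout assignment, so the
  // aliasing decisions see the final buffer layouts.
  void AddBufferDonorPass(xla::HloPassPipeline& pipeline) {
    // `true` is assumed to mean "only alias explicitly registered buffer
    // donors" rather than aliasing any compatible input/output pair.
    pipeline.AddPass<xla::OptimizeInputOutputBufferAlias>(
        /*registered_buffer_donor_only=*/true);
  }
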
PiperOrigin-RevId: 587893584 --- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/cpu_compiler.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 448ecbc8e1bf6c..427a3a47c08878 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -317,6 +317,7 @@ cc_library( "//xla/service:map_inliner", "//xla/service:operand_upcaster", "//xla/service:optimization_barrier_expander", + "//xla/service:optimize_input_output_buffer_alias", "//xla/service:qr_expander", "//xla/service:reduce_decomposer", "//xla/service:reshape_decomposer", diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 63eae07969e96a..88f1cf9f50f648 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -189,6 +189,7 @@ limitations under the License. #include "xla/service/map_inliner.h" #include "xla/service/operand_upcaster.h" #include "xla/service/optimization_barrier_expander.h" +#include "xla/service/optimize_input_output_buffer_alias.h" #include "xla/service/qr_expander.h" #include "xla/service/reduce_decomposer.h" #include "xla/service/reshape_decomposer.h" @@ -939,6 +940,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(true); return pipeline.Run(module).status(); } From 1308427362536c6a1e6e5a8ad26018dfdfe84801 Mon Sep 17 00:00:00 2001 From: Xinyi Wang Date: Mon, 4 Dec 2023 18:44:57 -0800 Subject: [PATCH 375/381] Let multi process runner re-raise SkipTest from sub-process. PiperOrigin-RevId: 587903563 --- tensorflow/python/distribute/multi_process_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 69b22392903a03..a07df8e337cb0c 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -929,10 +929,13 @@ def shutdown(self): if self._runner is not None: try: self._runner.join() + except unittest.SkipTest: + raise except Exception as e: # pylint: disable=broad-except - logging.error( + logging.exception( 'Ignoring exception when shutting down MultiProcessPoolRunner: %s', - e) + e, + ) self._runner = None def _start(self): From 05c310739b806c3f2ef173428d3764d4599bfb56 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Dec 2023 19:02:29 -0800 Subject: [PATCH 376/381] Clears the contents of 'solve_info' before computing the solution's fingerprint; this contains information (like solver wall time) that can vary between runs. 
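The pattern, as a small self-contained sketch (the calls are the ones used in auto_sharding_solver.cc below; the free function itself is only for illustration):

  #include <cstdint>

  #include "ortools/linear_solver/linear_solver.h"
  #include "tsl/platform/fingerprint.h"

  // Hashing a serialized MPSolutionResponse is only deterministic if the
  // run-dependent fields are dropped first; solve_info holds solver wall time
  // and similar per-run data.
  uint64_t SolutionFingerprint(const operations_research::MPSolver& solver) {
    operations_research::MPSolutionResponse response;
    solver.FillSolutionResponseProto(&response);
    response.clear_solve_info();  // Otherwise identical solutions hash differently.
    return tsl::Fingerprint64(response.SerializeAsString());
  }
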
PiperOrigin-RevId: 587906345 --- .../hlo/experimental/auto_sharding/auto_sharding_solver.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 38865d3454bdeb..ebe4563a320df6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -584,6 +584,7 @@ AutoShardingSolverResult SolveAndExtractSolution( absl::Time end_time = absl::Now(); auto duration = end_time - start_time; LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms"; + LOG(INFO) << "Solver Status: " << status; if (status == operations_research::MPSolver::INFEASIBLE) { LOG(ERROR) << "MPSolver could not find any feasible solution."; @@ -623,7 +624,6 @@ AutoShardingSolverResult SolveAndExtractSolution( "likely a bug and should be reported."; } else if (status != operations_research::MPSolver::OPTIMAL) { auto err_msg = "Solver timed out."; - LOG(WARNING) << err_msg << " Solver status " << status; return AutoShardingSolverResult(absl::InternalError(err_msg), true); } @@ -634,10 +634,10 @@ AutoShardingSolverResult SolveAndExtractSolution( uint64_t model_fprint = tsl::Fingerprint64(model_proto.SerializeAsString()); operations_research::MPSolutionResponse response; solver.FillSolutionResponseProto(&response); + response.clear_solve_info(); // Remove for fingerprint; can vary between runs uint64_t solution_fprint = tsl::Fingerprint64(response.SerializeAsString()); - LOG(INFO) << "Solver Status: " << status - << " Objective value: " << solver.Objective().Value() + LOG(INFO) << "Objective value: " << solver.Objective().Value() << " Model fingerprint: " << model_fprint << " Solution fingerprint: " << solution_fprint; if (solver.Objective().Value() >= kInfinityCost) { From 136624245327f83493daa0e354252b06a042c21c Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Mon, 4 Dec 2023 19:55:47 -0800 Subject: [PATCH 377/381] [XLA] Improve the compile time and memory usage of while_loop_fusible_sinking by doing a prepass to detect whether to construct a fusion. PiperOrigin-RevId: 587914972 --- third_party/xla/xla/service/BUILD | 3 +- .../xla/service/while_loop_fusible_sinking.cc | 142 +++++++++++------- .../xla/service/while_loop_fusible_sinking.h | 10 +- 3 files changed, 91 insertions(+), 64 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 90c79e4f532142..4c7ceeb55f3873 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6124,14 +6124,13 @@ cc_library( hdrs = ["while_loop_fusible_sinking.h"], visibility = ["//visibility:public"], deps = [ - ":call_graph", ":hlo_pass", ":while_util", - "//xla:literal_util", "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index 37d5f4d229fa27..5f39d71e5a208d 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/service/while_loop_fusible_sinking.h" +#include #include #include "absl/algorithm/container.h" @@ -22,8 +23,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal_util.h" -#include "xla/service/call_graph.h" #include "xla/service/while_util.h" #include "xla/statusor.h" #include "xla/util.h" @@ -31,64 +30,87 @@ limitations under the License. namespace xla { -HloInstruction* WhileLoopFusibleSinking::GetSinkableFusion( - HloInstruction* while_operand) { - std::vector worklist; +namespace { +// Constant and Iota have no operands and an output and broadcasts add +// dimensions to the output so we are looking fusions that have much smaller +// operand sizes compared to output sizes to avoid materialization +bool IsPurelyExpanding(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kBroadcast || + (instr->opcode() == HloOpcode::kConstant && + instr->shape().rank() == 0) || + instr->opcode() == HloOpcode::kIota; +} + +bool IsFusionCandidate(const HloInstruction* instr) { + return instr->IsElementwise() || instr->opcode() == HloOpcode::kReshape || + instr->opcode() == HloOpcode::kTranspose; +} +} // namespace + +bool WhileLoopFusibleSinking::IsSinkableFusion(HloInstruction* while_operand) { + absl::InlinedVector worklist; + absl::flat_hash_set visited; worklist.push_back(while_operand); - HloInstruction* fusion = nullptr; - auto fuse = [&](HloInstruction* instr) -> bool { - if (!instr->IsFusible()) { + while (!worklist.empty()) { + HloInstruction* to_process = worklist.back(); + worklist.pop_back(); + if (!to_process->IsFusible()) { return false; } - if (!fusion) { - fusion = instr->AddInstruction(instr->CreateFusion( - instr->shape(), HloInstruction::FusionKind::kLoop, instr)); - return true; + if (!visited.insert(to_process->unique_id()).second) { + // Do not sink extremely large subgraphs as they will be expensive to + // recompute in the loop. + if (visited.size() > 100) { + return false; + } + continue; } - // The instruction has already been visited, just skip it. 
- if (!fusion->IsUserOf(instr)) { - return false; + if (IsPurelyExpanding(to_process)) { + continue; } - fusion->FuseInstruction(instr); - return true; - }; - std::vector new_operands; - while (!worklist.empty()) { - HloInstruction* to_process = worklist.back(); - worklist.pop_back(); - if (to_process->IsElementwise() && fuse(to_process)) { + if (IsFusionCandidate(to_process)) { for (auto* op : to_process->operands()) { worklist.push_back(op); } continue; } - switch (to_process->opcode()) { - case HloOpcode::kBroadcast: { - HloInstruction* op = to_process->mutable_operand(0); - if (fuse(to_process) && (op->opcode() == HloOpcode::kConstant || - op->opcode() == HloOpcode::kIota)) { - fuse(op); - } - break; - } - case HloOpcode::kConstant: - case HloOpcode::kIota: { - fuse(to_process); - break; + return false; + } + return true; +} + +HloInstruction* WhileLoopFusibleSinking::CreateSinkableFusion( + HloInstruction* while_operand) { + HloInstruction* fusion = + while_operand->AddInstruction(while_operand->CreateFusion( + while_operand->shape(), HloInstruction::FusionKind::kLoop, + while_operand)); + bool did_fuse = IsFusionCandidate(while_operand); + // Fuse up to broadcasts, this function expects that IsSinkableFusion is true + // and does not verify that + while (did_fuse) { + did_fuse = false; + for (int64_t i = fusion->operand_count() - 1; i >= 0; --i) { + HloInstruction* op = fusion->mutable_operand(i); + if (IsPurelyExpanding(op)) { + continue; } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: { - HloInstruction* op = to_process->mutable_operand(0); - if (fuse(to_process)) { - worklist.push_back(op); - } + fusion->FuseInstruction(op); + did_fuse = true; + break; + } + } + // Fuse the broadcasts, constants and iota at the terminals. + did_fuse = true; + while (did_fuse) { + did_fuse = false; + for (int64_t i = fusion->operand_count() - 1; i >= 0; --i) { + HloInstruction* op = fusion->mutable_operand(i); + if (IsPurelyExpanding(op)) { + fusion->FuseInstruction(op); + did_fuse = true; break; } - default: - if (fusion) { - fusion->parent()->RemoveInstruction(fusion).IgnoreError(); - } - return nullptr; } } return fusion; @@ -100,8 +122,7 @@ StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( HloComputation* while_body = while_instr->while_body(); // Don't try to mutate unflattened while loop computations. - if (call_graph_->GetNode(while_cond).callers().size() > 1 || - call_graph_->GetNode(while_body).callers().size() > 1) { + if (call_counts_[while_body] > 1 || call_counts_[while_cond] > 1) { return false; } HloInstruction* init_value = while_instr->mutable_operand(0); @@ -116,6 +137,8 @@ StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( WhileUtil::GetGTEsMapForWhileConditional(*while_cond); std::vector invariant_body_gtes = WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + std::vector tuple_indices; + std::vector new_operands; for (HloInstruction* invariant_body_gte : invariant_body_gtes) { int64_t index = invariant_body_gte->tuple_index(); @@ -126,17 +149,19 @@ StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(0, init_value)); } // Original value should be a fusible subgraph. 
- HloInstruction* fusion = GetSinkableFusion(invariant_value); - if (fusion == nullptr) { + if (!IsSinkableFusion(invariant_value)) { continue; } + HloInstruction* fusion = CreateSinkableFusion(invariant_value); changed = true; - auto uses = while_instr->users(); if (fusion->operand_count() > 0 && (while_instr->IsRoot() || - absl::c_any_of(uses, [&](HloInstruction* use) { + absl::c_any_of(while_instr->users(), [&](HloInstruction* use) { return use->opcode() != HloOpcode::kGetTupleElement; }))) { + // This really only occurs in unit tests or toy programs. Copy the current + // users for later replacement. + auto uses = while_instr->users(); std::vector gtes(init_value->operand_count()); for (int64_t i = 0; i < gtes.size(); ++i) { gtes[i] = while_instr->AddInstruction( @@ -161,9 +186,9 @@ StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( HloInstruction* root = while_body->root_instruction(); HloInstruction* parameter = while_body->parameter_instruction(0); - std::vector tuple_indices(fusion->operand_count()); + tuple_indices.resize(fusion->operand_count()); int64_t next_index = init_value->operand_count(); - std::vector new_operands(fusion->operand_count()); + new_operands.resize(fusion->operand_count()); for (int64_t i = 0; i < fusion->operand_count(); ++i) { init_value->AppendOperand(fusion->mutable_operand(i)); parameter->mutable_shape()->mutable_tuple_shapes()->push_back( @@ -191,8 +216,6 @@ StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( StatusOr WhileLoopFusibleSinking::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - auto call_graph = CallGraph::Build(module, execution_threads); - call_graph_ = call_graph.get(); bool changed = false; std::vector while_instrs; for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { @@ -221,6 +244,11 @@ StatusOr WhileLoopFusibleSinking::Run( HloPredicateIsOp); } + for (HloInstruction* while_instr : while_instrs) { + call_counts_[while_instr->while_body()]++; + call_counts_[while_instr->while_condition()]++; + } + for (HloInstruction* while_instr : while_instrs) { TF_ASSIGN_OR_RETURN(bool result, TrySinkingFusiblesIntoWhileLoop(while_instr)); diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.h b/third_party/xla/xla/service/while_loop_fusible_sinking.h index 0f6be3a40d51a7..26caf689b20465 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.h +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.h @@ -16,12 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ #define XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ -#include - +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/call_graph.h" #include "xla/service/hlo_pass_interface.h" #include "xla/statusor.h" @@ -73,9 +72,10 @@ class WhileLoopFusibleSinking : public HloModulePass { // Creates a loop fusion instruction containing the computation to move into // the while loop to avoid conflicts with actual instruction fusion, the loop // fusion will be defused. 
- HloInstruction* GetSinkableFusion(HloInstruction* while_operand); + bool IsSinkableFusion(HloInstruction* while_operand); + HloInstruction* CreateSinkableFusion(HloInstruction* while_operand); - CallGraph* call_graph_; + absl::flat_hash_map call_counts_; }; } // namespace xla From 2d016bee446499a65bfbab04f77d6d9170114eea Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 4 Dec 2023 20:15:36 -0800 Subject: [PATCH 378/381] [Distributed Eager] Register component functions with a private function library. This change fixes a rare issue where two component functions are registered on a remote eager context, and their function libraries contain a function with the same name but a different body. When this happens the second registration fails due to a duplicate function upon adding it to the context-wide `FunctionLibraryDefinition`. To avoid this problem, when registering a component function, we use the `FunctionDefLibrary` shipped to create a private `FunctionLibraryDefinition` for running that function. We can do this relatively easily because the eager `ClusterFunctionLibraryRuntime` ships all reachable functions along with the root component function; and we have long-standing support for instantiating a function with an "overlay" `FunctionLibraryDefinition`. The behavior matches the TF1 `ClusterFunctionLibraryRuntime`, which ships an entire private library as part of the subgraph it registers with a remote worker, and creates a new `FunctionLibraryDefinition` and `ProcessFunctionLibraryRuntime` for that subgraph. Note that support for removing a component function via the `ClusterFunctionLibraryRuntime` was previously unsupported. We rely on this to simplify the ownership of the private `FunctionLibraryDefinition` objects, which are owned by the `EagerContext` and never deleted. Future support for remove would likely require using refcounted or otherwise-shared `FunctionLibraryDefinition` objects in the FLR stack. (In our experience, the issue is the result of an MLIR rewrite that canonicalizes the same source function in two different ways, so e.g. the choice of retained node for common subexpression elimination is different, but the two versions are functionally equivalent. In principle, making that rewrite deterministic, or making it choose a new name for the rewritten function would also solve the problem. However, I prefer this approach because it is robust to less-than-perfect rewrite passes, and we have a lot of rewrite passes.) 
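A sketch of the mechanism described above (the helper name is hypothetical; the types and the `lib_def` instantiation option are the existing TensorFlow APIs used by this change):

  #include <memory>

  #include "tensorflow/core/framework/function.h"
  #include "tensorflow/core/framework/op.h"
  #include "tensorflow/core/platform/status.h"

  // Build a library that holds only the component function plus the functions
  // shipped alongside it, instead of adding them to the context-wide library.
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> MakePrivateLibrary(
      const tensorflow::FunctionDef& fdef,
      const tensorflow::FunctionDefLibrary& shipped_library) {
    auto lib_def = std::make_unique<tensorflow::FunctionLibraryDefinition>(
        tensorflow::OpRegistry::Global(), shipped_library);
    TF_CHECK_OK(lib_def->AddFunctionDef(fdef, {}));
    return lib_def;
  }

  // At instantiation time the private library is supplied as an overlay, so
  // lookups made while running the component function resolve against it:
  //   tensorflow::FunctionLibraryRuntime::InstantiateOptions options;
  //   options.lib_def = private_lib_def.get();  // not owned by the runtime
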
PiperOrigin-RevId: 587920362 --- tensorflow/core/common_runtime/eager/BUILD | 8 +- .../core/common_runtime/eager/context.cc | 38 ++++ .../core/common_runtime/eager/context.h | 21 ++ .../common_runtime/eager/eager_operation.cc | 29 ++- .../common_runtime/eager/eager_operation.h | 18 ++ .../core/common_runtime/eager/execute.cc | 25 +-- .../common_runtime/eager/kernel_and_device.cc | 64 +++--- .../common_runtime/eager/kernel_and_device.h | 35 ++- .../eager/kernel_and_device_test.cc | 4 +- .../core/distributed_runtime/eager/BUILD | 5 +- .../eager/eager_service_impl.cc | 37 +++- .../eager/eager_service_impl_test.cc | 207 +++++++++++++++++- .../eager/remote_copy_node.cc | 3 +- .../eager/remote_execute_node.h | 3 +- 14 files changed, 419 insertions(+), 78 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index edeca472fa9b4a..376b6d81351458 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -1,5 +1,3 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_zendnn", @@ -9,6 +7,8 @@ load( "tf_cuda_library", "tf_mkl_kernel_library", ) +load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//third_party/mkl:build_defs.bzl", "if_mkl", @@ -119,6 +119,8 @@ tf_cuda_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", @@ -301,6 +303,8 @@ tf_cuda_library( "//tensorflow/core/platform:platform_port", "//tensorflow/core/util:managed_stack_trace", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 1ba7291a3e07be..a7306be3b8b431 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -26,6 +26,8 @@ limitations under the License. // clang-format off // Required for IS_MOBILE_PLATFORM +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -1008,6 +1010,42 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef, return OkStatus(); } +Status EagerContext::AddComponentFunction(const FunctionDef& fdef, + const FunctionDefLibrary& library) { + { + mutex_lock l(cache_mu_); + auto iter = component_function_libraries_.find(fdef.signature().name()); + if (iter == component_function_libraries_.end()) { + // TODO(mrry): For any functions in the main function library, consider + // deduplicating them here. + auto component_func_lib_def = std::make_unique( + OpRegistry::Global(), library); + TF_RETURN_IF_ERROR(component_func_lib_def->AddFunctionDef(fdef, {})); + component_function_libraries_.insert( + {fdef.signature().name(), std::move(component_func_lib_def)}); + } else { + // The function has been registered before. 
If the function is different, + // we error out. + const FunctionDef* prev_fdef = + iter->second->Find(fdef.signature().name()); + if (prev_fdef == nullptr) { + return absl::InternalError( + absl::StrCat("Component function: ", fdef.signature().name(), + " is in the cache but not in the library")); + } + if (!FunctionDefsEqual(fdef, *prev_fdef)) { + return absl::InvalidArgumentError(absl::StrCat( + "Attempting to add a duplicate function with name: ", + fdef.signature().name(), " where the previous and current ", + "definitions differ. Previous definition: ", + prev_fdef->DebugString(), + " and current definition: ", fdef.DebugString())); + } + } + } + return OkStatus(); +} + const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) { return func_lib_def_.Find(function_name); } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3aa9a5a3d03890..075849fae3304b 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -251,6 +251,14 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { bool add_to_local_only = false, const StackTracesMap& stack_traces = {}); + // Adds a component function (i.e. containing a subgraph of a multi-process + // function) implemented as `fdef`. + // + // REQUIRES: `library` must contain all functions reachable from `fdef`. It + // should not contain `fdef` itself. + Status AddComponentFunction(const FunctionDef& fdef, + const FunctionDefLibrary& library); + const FunctionDef* GetFunctionDef(const string& function_name); std::vector ListFunctionNames() override; @@ -385,6 +393,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { FunctionLibraryDefinition* FuncLibDef() override { return &func_lib_def_; } + FunctionLibraryDefinition* GetComponentFunctionFunctionLibraryDefinition( + const string& function_name) { + tf_shared_lock lock(cache_mu_); + auto iter = component_function_libraries_.find(function_name); + if (iter != component_function_libraries_.end()) { + return iter->second.get(); + } + return nullptr; + } + #if !defined(IS_MOBILE_PLATFORM) // Assign the EagerClient pointer to `client` based on the given device / task // name, and increment the refcount of the client. The reference ownership is @@ -756,6 +774,9 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { kernel_cache_ TF_GUARDED_BY(cache_mu_); std::unordered_map registered_functions_ TF_GUARDED_BY(cache_mu_); + + std::unordered_map> + component_function_libraries_ TF_GUARDED_BY(cache_mu_); absl::flat_hash_map device_cache_ TF_GUARDED_BY(device_cache_mu_); std::unordered_map>> diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 545585750b6abb..58888afece8bd1 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" @@ -27,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/host_info.h" @@ -333,16 +336,24 @@ Status EagerOperation::Reset( if (!is_function) { const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get(); colocation_exempt_ = exempt_ops.find(op) != exempt_ops.end(); - TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_)); - } else if (!remote && !ctx_.FindFunctionByName(op)) { - return errors::NotFound( - "'", op, - "' is neither a type of a primitive operation nor a name " - "of a function registered in binary running on ", - port::Hostname(), - ". Make sure the operation or function is " - "registered in the binary running in this process."); + } else if (!remote) { + const FunctionLibraryDefinition* func_lib_def; + if (eager_func_params.has_value() && + eager_func_params.value().func_lib_def_override != nullptr) { + func_lib_def = eager_func_params.value().func_lib_def_override; + } else { + func_lib_def = ctx_.FuncLibDef(); + } + if (func_lib_def->Find(op) == nullptr) { + return absl::NotFoundError(absl::StrCat( + "'", op, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + port::Hostname(), + ". Make sure the operation or function is " + "registered in the binary running in this process.")); + } } attrs_.Reset(op); stack_trace_.reset(); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index ccde391e8dc53d..fd34a709ca2cae 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/managed_stack_trace.h" @@ -153,6 +154,23 @@ class EagerOperation : public ImmediateExecutionOperation { tensorflow::EagerContext& EagerContext() const { return ctx_; } + const FunctionLibraryDefinition* FuncLibDef() const { + if (eager_func_params_.has_value() && + eager_func_params_.value().func_lib_def_override) { + return eager_func_params_.value().func_lib_def_override; + } else { + return ctx_.FuncLibDef(); + } + } + + const FunctionDef* FunctionDef() const { + if (is_function_) { + return FuncLibDef()->Find(attrs_.op_name()); + } else { + return nullptr; + } + } + AttrBuilder* MutableAttrs() { return &attrs_; } const AttrBuilder& Attrs() const { return attrs_; } diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index daaab604d2a01d..d6fe6bd12b40c7 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -291,8 +291,7 @@ Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) { const auto& node_def = op->MutableAttrs()->BuildNodeDef(); const OpDef* op_def = nullptr; - const FunctionDef* function_def = - op->EagerContext().FuncLibDef()->Find(op->Name()); + const FunctionDef* function_def = op->FunctionDef(); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -420,8 +419,7 @@ Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, return OkStatus(); } - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name()); + const FunctionDef* function_def = op->FunctionDef(); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", op->Name(), "'"); } @@ -445,8 +443,7 @@ Status HasTPUReplication(const EagerOperation& op, const EagerContext& ctx, return OkStatus(); } - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(op.Name()); + const FunctionDef* function_def = op.FunctionDef(); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", op.Name(), "'"); } @@ -513,11 +510,12 @@ Status HasNestedJitCompile(const EagerOperation& op, const EagerContext& ctx, std::queue function_names; function_names.push(op.Name()); + const FunctionLibraryDefinition* func_lib_def = op.FuncLibDef(); + while (!function_names.empty()) { const string& function_name = function_names.front(); - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(function_name); + const FunctionDef* function_def = func_lib_def->Find(function_name); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", function_name, "'"); } @@ -1537,8 +1535,8 @@ Status GetOrCreateKernelAndDevice( ctx.GetCollectiveExecutorHandle(), ctx.HostCPU())); } - TF_RETURN_IF_ERROR( - kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector)); + TF_RETURN_IF_ERROR(kernel->Init(ctx.LogDevicePlacement(), ndef, + graph_collector, op->eager_func_params())); // Exclude tf.data op kernels from being cached. 
The reason for this is // that tf.data op kernels that accept a user-defined function will have a @@ -1548,8 +1546,7 @@ Status GetOrCreateKernelAndDevice( // programs that build input pipeline graphs in a loop. const OpDef* op_def; if (op->is_function()) { - const FunctionDef* function_def = - op->EagerContext().FuncLibDef()->Find(op->Name()); + const FunctionDef* function_def = op->FunctionDef(); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -1976,8 +1973,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, std::unique_ptr node(new eager::RemoteExecuteNode( &op->EagerContext(), std::move(request), op_device, ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(), - op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), - *inputs, {retvals, num_outputs})); + op->MutableAttrs()->BuildNodeDef(), op->FuncLibDef(), *inputs, + {retvals, num_outputs})); if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) { string msg = strings::StrCat( diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 7b3b383b3ddb44..460fab04252ece 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -103,9 +104,14 @@ KernelAndDeviceFunc::~KernelAndDeviceFunc() { } } -Status KernelAndDeviceOp::Init(const bool log_device_placement, - const NodeDef& ndef, - GraphCollector* graph_collector) { +Status KernelAndDeviceOp::Init( + const bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collecto, + const absl::optional& eager_func_params) { + if (eager_func_params.has_value()) { + return absl::InternalError( + "KernelAndDeviceOp does not support EagerFunctionParams."); + } OpKernel* k = nullptr; if (flr_ == nullptr) { return errors::Internal( @@ -141,22 +147,31 @@ Status KernelAndDeviceOp::Init(const bool log_device_placement, return OkStatus(); } -Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, - const NodeDef& ndef, - GraphCollector* graph_collector) { +Status KernelAndDeviceFunc::InstantiateFunc( + const bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) { const OpDef* op_def = nullptr; - const FunctionDef* function_def; - if (flr_ == nullptr) { - // If function is being executed without an explicit device request, - // lookup the FunctionDef in the CPU's FLR. All FLRs share the same - // library. 
- function_def = pflr_->GetFLR(host_cpu_device_->name()) - ->GetFunctionLibraryDefinition() - ->Find(ndef.op()); + const FunctionLibraryDefinition* func_lib_def; + FunctionLibraryRuntime::InstantiateOptions options; + + if (eager_func_params.has_value() && + eager_func_params.value().func_lib_def_override != nullptr) { + func_lib_def = eager_func_params.value().func_lib_def_override; + options.lib_def = func_lib_def; } else { - function_def = flr_->GetFunctionLibraryDefinition()->Find(ndef.op()); + if (flr_ == nullptr) { + // If function is being executed without an explicit device request, + // lookup the FunctionDef in the CPU's FLR. All FLRs share the same + // library. + func_lib_def = pflr_->GetFLR(host_cpu_device_->name()) + ->GetFunctionLibraryDefinition(); + } else { + func_lib_def = flr_->GetFunctionLibraryDefinition(); + } } + const FunctionDef* function_def = func_lib_def->Find(ndef.op()); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -165,7 +180,6 @@ Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, TF_RETURN_IF_ERROR( InOutTypesForNode(ndef, *op_def, &input_dtypes_, &output_dtypes_)); - FunctionLibraryRuntime::InstantiateOptions options; options.target = device_ == nullptr ? "" : device_->name(); options.is_multi_device_function = true; for (const Device* device : input_devices_) { @@ -174,13 +188,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, options.composite_devices = composite_devices_; options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_; if (outputs_on_op_device_) { - const FunctionLibraryDefinition* lib_def = - pflr_->GetFunctionLibraryDefinition(); - const FunctionDef* fdef = lib_def->Find(ndef.op()); - if (fdef == nullptr) { + if (function_def == nullptr) { return errors::InvalidArgument("Failed to find function ", ndef.op()); } - for (int i = 0; i < fdef->signature().output_arg_size(); ++i) { + for (int i = 0; i < function_def->signature().output_arg_size(); ++i) { options.output_devices.push_back(options.target); } } @@ -248,11 +259,12 @@ Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, return pflr_->IsCrossProcess(handle_, &is_cross_process_); } -Status KernelAndDeviceFunc::Init(const bool log_device_placement, - const NodeDef& ndef, - GraphCollector* graph_collector) { - TF_RETURN_IF_ERROR( - InstantiateFunc(log_device_placement, ndef, graph_collector)); +Status KernelAndDeviceFunc::Init( + const bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) { + TF_RETURN_IF_ERROR(InstantiateFunc(log_device_placement, ndef, + graph_collector, eager_func_params)); return pflr_->GetOutputDevices(handle_, &output_devices_); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index a98427a9e04d27..7a800f9b2a15d1 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -60,14 +60,19 @@ class FunctionLibraryRuntime; const int64_t kInvalidOpId = -1; -// This struc is used for: -// 1. setting op_id and step_id, is_component_function for single-client +// This struct is used for: +// 1. Setting `op_id` and `step_id`, `is_component_function` for single-client // remote function scenario, -// 2. setting step_id for multi-client parallel_device scenario. +// 2. 
Setting `step_id` for multi-client parallel_device scenario. +// 3. Supplying an overriding, private `FunctionLibraryDefinition` for component +// functions. struct EagerFunctionParams { int64_t op_id = kInvalidOpId; bool is_component_function; std::optional step_id = std::nullopt; + FunctionLibraryDefinition* func_lib_def_override = + nullptr; // Not owned (owned by `EagerContext`). If not null, functions + // called by the function will be looked up in this library. }; class EagerKernelArgs : public FunctionArgsInterface { @@ -113,8 +118,10 @@ class KernelAndDevice : public core::RefCounted { // // The provided FunctionLibraryRuntime MUST outlive all calls to // Run() on the returned KernelAndDevice. - virtual Status Init(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector) = 0; + virtual Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) = 0; // Non-multi-device functions are run using regular CallOp and look like // primitive operations from KernelAndDevice perspective. @@ -215,8 +222,10 @@ class KernelAndDeviceOp final : public KernelAndDevice { ~KernelAndDeviceOp() override = default; - Status Init(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector) override; + Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) override; Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, @@ -316,11 +325,15 @@ class KernelAndDeviceFunc : public KernelAndDevice { bool IsCrossProcess() override { return is_cross_process_; } - Status InstantiateFunc(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector); + Status InstantiateFunc( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params); - Status Init(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector) override; + Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) override; Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index 33122bc4c38105..bda3e5f582fc05 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -118,7 +118,7 @@ void BM_KernelAndDeviceInit(::testing::benchmark::State& state) { KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr, nullptr, env.cpu_device()); for (auto s : state) { - TF_CHECK_OK(k.Init({}, ndef, nullptr)); + TF_CHECK_OK(k.Init({}, ndef, nullptr, std::nullopt)); } } BENCHMARK(BM_KernelAndDeviceInit); @@ -138,7 +138,7 @@ void BM_KernelAndDeviceRun(::testing::benchmark::State& state) { TestEnv env; KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr, nullptr, env.cpu_device()); - TF_CHECK_OK(k.Init({}, ndef, nullptr)); + TF_CHECK_OK(k.Init({}, ndef, nullptr, std::nullopt)); const EagerKernelArgs args(std::move(inputs)); for (auto s : state) { TF_CHECK_OK(k.Run(nullptr, args, &outputs, nullptr, std::nullopt, diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 
46e86a42be6734..a3a1c8ae937db7 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_grpc_cc_dependencies") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_grpc_cc_dependencies") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -102,6 +102,7 @@ cc_library( ":remote_tensor_handle", "//tensorflow/c/eager:immediate_execution_distributed_manager", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index f2fd43ca853156..f6f3bf1ee1668c 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/nccl/collective_communicator.h" #include "tensorflow/core/platform/errors.h" @@ -48,6 +49,7 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringprintf.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tsl/distributed_runtime/preemption/preemption_notifier.h" #include "tsl/protobuf/coordination_config.pb.h" @@ -55,13 +57,14 @@ namespace tensorflow { namespace eager { namespace { -Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, +Status GetNumRetvals(FunctionLibraryDefinition* func_lib_def, + const string& op_name, const google::protobuf::Map& attrs, int* num_retvals) { const tensorflow::OpRegistrationData* op_reg_data = nullptr; auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); if (absl::IsNotFound(status)) { - status = context->FindFunctionOpData(op_name, &op_reg_data); + status = func_lib_def->LookUp(op_name, &op_reg_data); } TF_RETURN_IF_ERROR(status); @@ -100,14 +103,27 @@ Status GetEagerOperationAndNumRetvals(const Operation& operation, const char* name = operation.name().c_str(); // Shorthand std::optional remote_func_params = std::nullopt; + FunctionLibraryDefinition* func_lib_def; if (operation.is_function()) { if (operation.is_component_function()) { + func_lib_def = + eager_context->GetComponentFunctionFunctionLibraryDefinition( + operation.name()); + if (func_lib_def == nullptr) { + return absl::InternalError( + absl::StrCat("Could not find function library for registered " + "component function: ", + operation.name())); + } remote_func_params = {operation.id(), /*is_component_function=*/true, - operation.func_step_id()}; + operation.func_step_id(), func_lib_def}; } else { + func_lib_def = eager_context->FuncLibDef(); remote_func_params = {operation.id(), /*is_component_function=*/false, - std::nullopt}; + 
std::nullopt, /*func_lib_def=*/nullptr}; } + } else { + func_lib_def = eager_context->FuncLibDef(); } TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false, eager_executor, remote_func_params)); @@ -143,7 +159,7 @@ Status GetEagerOperationAndNumRetvals(const Operation& operation, } // TODO(nareshmodi): Consider caching this. - return GetNumRetvals(eager_context, operation.name(), operation.attrs(), + return GetNumRetvals(func_lib_def, operation.name(), operation.attrs(), num_retvals); } @@ -770,9 +786,14 @@ Status EagerServiceImpl::RegisterFunction( const RegisterFunctionOp& register_function, EagerContext* eager_context) { // If the function is a component of a multi-device function, we only need to // register it locally. - return eager_context->AddFunctionDef( - register_function.function_def(), register_function.library(), - register_function.is_component_function()); + if (register_function.is_component_function()) { + return eager_context->AddComponentFunction(register_function.function_def(), + register_function.library()); + } else { + return eager_context->AddFunctionDef(register_function.function_def(), + register_function.library(), + /*add_to_local_only=*/false); + } } Status EagerServiceImpl::RemoveFunction(const RemoveFunctionOp& remove_function, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 79f8ccb21d934a..2ab6631de71d9b 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -309,6 +309,46 @@ tensorflow::FunctionDef MatMulFunction() { return def; } +tensorflow::FunctionDef MatMulTransposeFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'MatMulFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'm'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'matmul'" + " op: 'MatMul'" + " input: 'a'" + " input: 'a'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: 'transpose_a'" + " value {" + " b: true" + " }" + " }" + " }" + " ret {" + " key: 'm'" + " value: 'matmul:product'" + " }", + &def)); + return def; +} + tensorflow::FunctionDef MatMulNestedFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( @@ -710,15 +750,178 @@ TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) { TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) { RegisterFunctionOp register_op; *register_op.mutable_function_def() = MatMulFunction(); + register_op.set_is_component_function(true); TestComponentFunction(register_op, "MatMulFunction", false); } TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) { RegisterFunctionOp register_op; *register_op.mutable_function_def() = SingleRecvNodeFunction(); + register_op.set_is_component_function(true); TestComponentFunction(register_op, "SingleRecvNodeFunction", true); } +TEST_F(EagerServiceImplFunctionTest, ComponentNestedFunctionTest) { + RegisterFunctionOp register_op; + *register_op.mutable_function_def() = MatMulNestedFunction(); + *register_op.mutable_library()->add_function() = MatMulFunction(); + register_op.set_is_component_function(true); + TestComponentFunction(register_op, "MatMulNestedFunction", false); +} + +TEST_F(EagerServiceImplFunctionTest, 
ComponentNestedFunctionWithNameClashTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + uint64 context_id = random::New64(); + + // Create context. + CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_context_id(context_id); + CreateContextResponse response; + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + // Register first function. + { + EnqueueRequest enqueue_request; + enqueue_request.set_context_id(context_id); + RegisterFunctionOp* register_op = + enqueue_request.add_queue()->mutable_register_function(); + *register_op->mutable_function_def() = MatMulNestedFunction(); + *register_op->mutable_library()->add_function() = MatMulFunction(); + register_op->set_is_component_function(true); + EnqueueResponse enqueue_response; + TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request, + &enqueue_response)); + } + + // Register second function. + // In the second registration, the library contains a function named + // "MatMulFunction" but a different body. + { + EnqueueRequest enqueue_request; + enqueue_request.set_context_id(context_id); + RegisterFunctionOp* register_op = + enqueue_request.add_queue()->mutable_register_function(); + + *register_op->mutable_function_def() = MatMulNestedFunction(); + register_op->mutable_function_def()->mutable_signature()->set_name( + "MatMulNestedTransposeFunction"); + *register_op->mutable_library()->add_function() = MatMulTransposeFunction(); + register_op->set_is_component_function(true); + EnqueueResponse enqueue_response; + TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request, + &enqueue_response)); + } + + // First run an op to generate input for the functions. + EnqueueRequest remote_enqueue_request; + remote_enqueue_request.set_context_id(context_id); + EnqueueResponse remote_enqueue_response; + + std::unordered_map const_attrs; + AttrValue val; + val.set_type(tensorflow::DataType::DT_FLOAT); + const_attrs.insert({"dtype", val}); + val.Clear(); + SetTensorProto(val.mutable_tensor()); + const_attrs.insert({"value", val}); + AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, + "/job:localhost/replica:0/task:0/device:CPU:0", + &remote_enqueue_request); + TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, + &remote_enqueue_response)); + + { + // Run first function with input from the previous op. + RunComponentFunctionRequest run_comp_func_request; + run_comp_func_request.set_context_id(context_id); + RunComponentFunctionResponse run_comp_func_response; + const int output_num = 5; + AddOperationToRunComponentFunctionRequest( + 2, "MatMulNestedFunction", {std::make_pair(1, 0)}, + std::unordered_map(), + "/job:localhost/replica:0/task:0/device:CPU:0", output_num, + &run_comp_func_request); + + CallOptions call_opts; + Notification n; + Status status; + eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, + &run_comp_func_response, + [&status, &n](const Status& s) { + status.Update(s); + n.Notify(); + }); + n.WaitForNotification(); + + TF_ASSERT_OK(status); + // Retrieve the output. 
+ const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + context_id, RemoteTensorHandleInternal(2, output_num), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + + auto actual = t->flat(); + EXPECT_EQ(4, actual.size()); + + EXPECT_EQ(7, actual(0)); + EXPECT_EQ(10, actual(1)); + EXPECT_EQ(15, actual(2)); + EXPECT_EQ(22, actual(3)); + } + + { + // Run second function with input from the constant op. The result should + // be different, because we are using the transposed implementation of + // MatMulFunction in the second function's library. + RunComponentFunctionRequest run_comp_func_request; + run_comp_func_request.set_context_id(context_id); + RunComponentFunctionResponse run_comp_func_response; + const int output_num = 5; + AddOperationToRunComponentFunctionRequest( + 3, "MatMulNestedTransposeFunction", {std::make_pair(1, 0)}, + std::unordered_map(), + "/job:localhost/replica:0/task:0/device:CPU:0", output_num, + &run_comp_func_request); + + CallOptions call_opts; + Notification n; + Status status; + eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, + &run_comp_func_response, + [&status, &n](const Status& s) { + status.Update(s); + n.Notify(); + }); + n.WaitForNotification(); + + TF_ASSERT_OK(status); + // Retrieve the output. + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + context_id, RemoteTensorHandleInternal(3, output_num), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + + auto actual = t->flat(); + EXPECT_EQ(4, actual.size()); + + EXPECT_EQ(10, actual(0)); + EXPECT_EQ(14, actual(1)); + EXPECT_EQ(14, actual(2)); + EXPECT_EQ(20, actual(3)); + } + + CloseContextRequest close_context_request; + close_context_request.set_context_id(context_id); + close_context_request.set_context_view_id(0); + CloseContextResponse close_context_response; + TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, + &close_context_response)); +} + class FunctionWithRemoteInputsTest : public EagerServiceImplTest { public: FunctionWithRemoteInputsTest() @@ -987,7 +1190,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { // Instantiate MatMulFunction on remote_device. const NodeDef node_def = MatMulFunctionNodeDef(); - TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr)); + TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. gtl::InlinedVector input_tensors = {TensorValue()}; @@ -1042,7 +1245,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { // Instantiate MatMulFunction on remote_device. const NodeDef node_def = MatMulFunctionNodeDef(); - TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr)); + TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. 
gtl::InlinedVector input_tensors = {TensorValue()}; diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index 4cefc9433c2556..bd5bc39622b9d6 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -60,7 +60,8 @@ Status CreateUncachedKernelAndDeviceOp( const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); return kernel->get()->Init(ctx.LogDevicePlacement(), ndef, - /*graph_collector=*/nullptr); + /*graph_collector=*/nullptr, + /*eager_func_params=*/std::nullopt); } // This gets a unique wire ID. We add a random identifier so that if the diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 6aabf3ce209d7d..148e58a5b008c5 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -42,7 +42,8 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { std::unique_ptr request, Device* device, uint64 context_view_id, EagerClient* eager_client, CancellationManager* cancellation_manager, - const NodeDef& ndef, FunctionLibraryDefinition* lib_def, + const NodeDef& ndef, + const FunctionLibraryDefinition* lib_def, const gtl::InlinedVector& inputs, absl::Span retvals) : AsyncRemoteExecuteNode(), From db579439eef970657f5ddbf05dc9b798cb748c51 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 4 Dec 2023 21:16:15 -0800 Subject: [PATCH 379/381] Migrate experimental macOS x86 nightly builds to the new CI folder We need to install Bazelisk and Pyenv manually as these are not present on the x86 Mac VMs. Note that the uploads from these new jobs are disabled as they are not yet ready. However, the old Mac x86 nightly builds will still be running and upload to tf-nightly so there won't be any missing nightly packages while we are doing this migration. 
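For reference, the manual provisioning described above amounts to roughly the following sketch. This is not the authoritative procedure (that lives in ci/official/utilities/setup_macos.sh and is gated by the TFCI_MACOS_* settings introduced below); the install destination and the pyenv step are assumptions, while the Bazelisk URL is the one pinned in the new env files.

    # Fetch the pinned Bazelisk binary and expose it as `bazel` (destination path is illustrative).
    curl -fsSL "https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" -o /usr/local/bin/bazel
    chmod +x /usr/local/bin/bazel
    bazel version  # Bazelisk downloads the Bazel version pinned by the repo on first use.

    # Pyenv is assumed to be provisioned along similar lines, e.g.:
    pyenv install -s "$TFCI_PYTHON_VERSION" && pyenv global "$TFCI_PYTHON_VERSION"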
PiperOrigin-RevId: 587930871 --- .bazelrc | 17 +++++++++++++++++ .../envs/nightly_libtensorflow_macos_x86 | 7 +++++++ ci/official/envs/nightly_macos_x86_py310 | 16 ++++++++++++++++ ci/official/envs/nightly_macos_x86_py311 | 16 ++++++++++++++++ ci/official/envs/nightly_macos_x86_py312 | 16 ++++++++++++++++ ci/official/envs/nightly_macos_x86_py39 | 14 ++++++++++++++ ci/official/utilities/setup_macos.sh | 9 +++++++++ third_party/xla/.bazelrc | 17 +++++++++++++++++ third_party/xla/third_party/tsl/.bazelrc | 17 +++++++++++++++++ 9 files changed, 129 insertions(+) create mode 100644 ci/official/envs/nightly_libtensorflow_macos_x86 create mode 100644 ci/official/envs/nightly_macos_x86_py310 create mode 100644 ci/official/envs/nightly_macos_x86_py311 create mode 100644 ci/official/envs/nightly_macos_x86_py312 create mode 100644 ci/official/envs/nightly_macos_x86_py39 diff --git a/.bazelrc b/.bazelrc index 42330a369c9f6c..9de6b6e0c2bd54 100644 --- a/.bazelrc +++ b/.bazelrc @@ -671,6 +671,15 @@ test:release_cpu_macos --config=release_base build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + # Build configs for macOS Arm64 build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 @@ -685,6 +694,9 @@ test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors test:release_macos_base --build_tests_only --keep_going test:release_macos_base --flaky_test_attempts=3 +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base @@ -746,6 +758,11 @@ test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-os test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... 
-//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/ci/official/envs/nightly_libtensorflow_macos_x86 b/ci/official/envs/nightly_libtensorflow_macos_x86 new file mode 100644 index 00000000000000..98af50bc06691b --- /dev/null +++ b/ci/official/envs/nightly_libtensorflow_macos_x86 @@ -0,0 +1,7 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_DOCKER_ENABLE=0 +TFCI_LIB_SUFFIX="-cpu-darwin-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_macos_x86_py310 b/ci/official/envs/nightly_macos_x86_py310 new file mode 100644 index 00000000000000..9577841dea84ec --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py310 @@ -0,0 +1,16 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_x86_py311 b/ci/official/envs/nightly_macos_x86_py311 new file mode 100644 index 00000000000000..4fe9bad43f89f6 --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py311 @@ -0,0 +1,16 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_x86_py312 b/ci/official/envs/nightly_macos_x86_py312 new file mode 100644 index 00000000000000..a4397de120d90c --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py312 @@ -0,0 +1,16 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu 
--nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" \ No newline at end of file diff --git a/ci/official/envs/nightly_macos_x86_py39 b/ci/official/envs/nightly_macos_x86_py39 new file mode 100644 index 00000000000000..58c570c5d10507 --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py39 @@ -0,0 +1,14 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" \ No newline at end of file diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index 9dcc28406907d6..3957a919ca1903 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -83,3 +83,12 @@ if [[ "$TFCI_PYTHON_VERSION" == "3.12" ]]; then # Once the wheels are added, this should be removed - b/308399490. brew install cmake fi + +# Scheduled nightly and release builds upload build artifacts (Pip packages, +# Libtensorflow archives) to GCS buckets. TFCI Mac VMs need to authenticate as +# a service account that has the right permissions to be able to do so. 
+set +x +if [[ -n "${GOOGLE_APPLICATION_CREDENTIALS:-}" ]]; then + gcloud auth activate-service-account +fi +set -x \ No newline at end of file diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 42330a369c9f6c..9de6b6e0c2bd54 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -671,6 +671,15 @@ test:release_cpu_macos --config=release_base build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + # Build configs for macOS Arm64 build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 @@ -685,6 +694,9 @@ test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors test:release_macos_base --build_tests_only --keep_going test:release_macos_base --flaky_test_attempts=3 +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base @@ -746,6 +758,11 @@ test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-os test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. 
diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 42330a369c9f6c..9de6b6e0c2bd54 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -671,6 +671,15 @@ test:release_cpu_macos --config=release_base build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + # Build configs for macOS Arm64 build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 @@ -685,6 +694,9 @@ test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors test:release_macos_base --build_tests_only --keep_going test:release_macos_base --flaky_test_attempts=3 +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base @@ -746,6 +758,11 @@ test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-os test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. 
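For orientation only, the configs added above would be exercised along these lines. This is an illustrative local invocation under the assumption of a macOS x86 checkout; the real jobs drive Bazel through the TFCI_BAZEL_COMMON_ARGS and TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX settings in the env files, and the build target shown is only an example.

    # Build with the new macOS x86 release config (target is illustrative).
    bazel build --config=release_macos_x86 //tensorflow/tools/pip_package:build_pip_package
    # Run the wheel test suite; the config already carries its tag filters and target list.
    bazel test --config=macos_x86_wheel_test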
From 4d21b572b7488e8f3d3c2d9873e7954045b03810 Mon Sep 17 00:00:00 2001 From: Jason Furmanek Date: Wed, 6 Dec 2023 18:29:12 +0000 Subject: [PATCH 380/381] Initial commit to resolve merge conflicts --- third_party/xla/xla/service/gpu/BUILD | 4 -- .../xla/service/gpu/buffer_comparator_test.cc | 6 --- .../xla/stream_executor/device_description.h | 42 ------------------- .../xla/stream_executor/rocm/hip_blas_lt.cc | 19 --------- .../xla/stream_executor/rocm/hip_blas_lt.h | 5 --- third_party/xla/xla/tests/BUILD | 4 -- 6 files changed, 80 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index cb5e399b6e5937..a010fba0a0c1cc 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -5035,14 +5035,10 @@ xla_cc_test( xla_cc_test( name = "determinism_test", srcs = ["determinism_test.cc"], -<<<<<<< HEAD - tags = tf_cuda_tests_tags() + ["no_rocm"], -======= local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), tags = tf_gpu_tests_tags(), ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 deps = [ ":autotuner_util", "//xla:literal", diff --git a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc index fbd2163c587d0f..77d1b5d783ba44 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator_test.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator_test.cc @@ -39,11 +39,6 @@ namespace { class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() -<<<<<<< HEAD - : platform_( - se::MultiPlatformManager::PlatformWithName(PLATFORM).ValueOrDie()), - stream_exec_(platform_->ExecutorForDevice(0).value()) {} -======= #if GOOGLE_CUDA : platform_(se::MultiPlatformManager::PlatformWithName("CUDA").value()), #elif TENSORFLOW_USE_ROCM @@ -51,7 +46,6 @@ class BufferComparatorTest : public testing::Test { #endif stream_exec_(platform_->ExecutorForDevice(0).value()) { } ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 // Take floats only for convenience. Still uses ElementType internally. 
template diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 2cba18f555360b..85cac5a9ee075d 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -159,25 +159,15 @@ class RocmComputeCapability { return absl::StrJoin(kSupportedGfxVersions, ", "); } -<<<<<<< HEAD - bool has_nhwc_layout_support() const { -======= bool gfx9_mi100_or_later() const { ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"}; return absl::c_count(kList, gfx_version()) != 0; } -<<<<<<< HEAD - bool has_bf16_dtype_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", - "gfx941", "gfx942"}; -======= bool gfx9_mi200_or_later() const { static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941", "gfx942"}; ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 return absl::c_count(kList, gfx_version()) != 0; } @@ -190,25 +180,6 @@ class RocmComputeCapability { bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } bool has_fast_fp16_support() const { -<<<<<<< HEAD - static constexpr absl::string_view kList[] = {"gfx906", "gfx908", "gfx90a", - "gfx940", "gfx941", "gfx942", - "gfx1030", "gfx1100"}; - return absl::c_count(kList, gfx_version()) != 0; - } - - bool has_mfma_instr_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", - "gfx941", "gfx942"}; - return absl::c_count(kList, gfx_version()) != 0; - } - - bool has_fp16_atomics_support() const { - // TODO(rocm): Check. This should be the same as has_fast_fp16_support(). - static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941", - "gfx942"}; - return absl::c_count(kList, gfx_version()) != 0; -======= return gfx9_mi100_or_later() || navi21() || navi31(); } @@ -217,7 +188,6 @@ class RocmComputeCapability { bool has_fp16_atomics_support() const { // TODO(rocm): Check. This should be the same as has_fast_fp16_support(). return gfx9_mi200_or_later(); ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 } bool fence_before_barrier() const { @@ -240,17 +210,6 @@ class RocmComputeCapability { std::string gcn_arch_name_ = "gfx000"; // default to invalid arch. 
static constexpr absl::string_view kSupportedGfxVersions[]{ -<<<<<<< HEAD - "gfx900", // MI25 - "gfx906", // MI50 / MI60 - "gfx908", // MI100 - "gfx90a", // MI200 - "gfx940", // MI300 - "gfx941", // MI300 - "gfx942", // MI300 - "gfx1030", // Navi21 - "gfx1100" // Navi31 -======= "gfx900", // MI25 "gfx906", // MI50 / MI60 "gfx908", // MI100 @@ -258,7 +217,6 @@ class RocmComputeCapability { "gfx940", "gfx941", "gfx942", "gfx1030", // Navi21 "gfx1100" // Navi31 ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 }; }; diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 925c4082caa572..262a3a5c3122f9 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -478,24 +478,6 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( } // Other data types: -<<<<<<< HEAD - TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, - HIP_R_16BF) - TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, - HIP_R_16F) - TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, - HIP_R_32F) - TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_32F, - HIP_R_32F) - TYPED_MATMUL(float, HIP_R_32F, HIP_R_32F, HIP_R_32F, - HIP_R_32F) - TYPED_MATMUL(double, HIP_R_64F, HIP_R_64F, HIP_R_64F, - HIP_R_64F) - TYPED_MATMUL(complex64, HIP_C_32F, HIP_C_32F, HIP_C_32F, - HIP_C_32F) - TYPED_MATMUL(complex128, HIP_C_64F, HIP_C_64F, HIP_C_64F, - HIP_C_64F) -======= TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF) TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F) TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, HIP_R_32F) @@ -504,7 +486,6 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( TYPED_MATMUL(double, HIP_R_64F, HIP_R_64F, HIP_R_64F, HIP_R_64F) TYPED_MATMUL(complex64, HIP_C_32F, HIP_C_32F, HIP_C_32F, HIP_C_32F) TYPED_MATMUL(complex128, HIP_C_64F, HIP_C_64F, HIP_C_64F, HIP_C_64F) ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 #undef TYPED_MATMUL diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index 6f9530720f279f..0e253c7d8062e2 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -72,12 +72,7 @@ class BlasLt : public gpu::BlasLt { hipblasLtMatmulDesc_t get() const { return handle_.get(); } private: -<<<<<<< HEAD - MatmulDesc(hipblasLtMatmulDesc_t handle, - hipblasComputeType_t compute_type, -======= MatmulDesc(hipblasLtMatmulDesc_t handle, hipblasComputeType_t compute_type, ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 hipDataType datatype) : handle_(handle, wrap::hipblasLtMatmulDescDestroy), compute_type_(compute_type), diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 51aa59abd76466..6ee3b50a90fdea 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -777,13 +777,9 @@ xla_test( xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], -<<<<<<< HEAD - tags = ["no_rocm"], -======= local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), ->>>>>>> db579439eef970657f5ddbf05dc9b798cb748c51 shard_count = 25, deps = [ ":client_library_test_base", From d24bcedf96e797c62305750d86db90df7cba648e Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 5 Dec 2023 23:49:06 -0800 Subject: [PATCH 381/381] Rename `EagerOperation::FunctionDef()` to 
`EagerOperation::GetFunctionDef()`. Some compilers do not like using the name of a class as a method, which is fair enough. PiperOrigin-RevId: 588312567 --- tensorflow/core/common_runtime/eager/eager_operation.h | 2 +- tensorflow/core/common_runtime/eager/execute.cc | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index fd34a709ca2cae..3ddf91c5ed5f52 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -163,7 +163,7 @@ class EagerOperation : public ImmediateExecutionOperation { } } - const FunctionDef* FunctionDef() const { + const FunctionDef* GetFunctionDef() const { if (is_function_) { return FuncLibDef()->Find(attrs_.op_name()); } else { diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index d6fe6bd12b40c7..0d68aac0cff554 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -291,7 +291,7 @@ Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) { const auto& node_def = op->MutableAttrs()->BuildNodeDef(); const OpDef* op_def = nullptr; - const FunctionDef* function_def = op->FunctionDef(); + const FunctionDef* function_def = op->GetFunctionDef(); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -419,7 +419,7 @@ Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, return OkStatus(); } - const FunctionDef* function_def = op->FunctionDef(); + const FunctionDef* function_def = op->GetFunctionDef(); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", op->Name(), "'"); } @@ -443,7 +443,7 @@ Status HasTPUReplication(const EagerOperation& op, const EagerContext& ctx, return OkStatus(); } - const FunctionDef* function_def = op.FunctionDef(); + const FunctionDef* function_def = op.GetFunctionDef(); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", op.Name(), "'"); } @@ -1546,7 +1546,7 @@ Status GetOrCreateKernelAndDevice( // programs that build input pipeline graphs in a loop. const OpDef* op_def; if (op->is_function()) { - const FunctionDef* function_def = op->FunctionDef(); + const FunctionDef* function_def = op->GetFunctionDef(); if (function_def != nullptr) { op_def = &(function_def->signature()); } else {