From 212b383255e823cb89bfc1fb7af2a3c4ea0a1898 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 24 Oct 2024 11:16:06 -0700 Subject: [PATCH] [xla:scatterExpander] Extend ScatterToLoop transformation to generate correct code in the presence of explicit batch dimensions. Explicit batch dimensions were recently added to scatter instructions in https://github.com/openxla/stablehlo/pull/2084. This CL extends the pass to consider explicit operand batch dimensions when computing the while-loop init-value shape and the operand indices. PiperOrigin-RevId: 689450009 --- xla/service/BUILD | 5 +- xla/service/scatter_expander.cc | 78 ++++++++++++-------------- xla/service/scatter_expander_test.cc | 83 +++++++++++++++++++++++++++- 3 files changed, 120 insertions(+), 46 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index c1198fc3571af..af652b65ccc3a 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2082,15 +2082,17 @@ cc_library( srcs = ["scatter_expander.cc"], hdrs = ["scatter_expander.h"], deps = [ - ":call_inliner", + ":gather_scatter_utils", ":hlo_creation_utils", ":scatter_utils", ":while_util", "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) @@ -2127,6 +2129,7 @@ xla_cc_test( "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/xla/service/scatter_expander.cc b/xla/service/scatter_expander.cc index 1bd4178afcd3c..01ebae5dd533d 100644 --- a/xla/service/scatter_expander.cc +++ b/xla/service/scatter_expander.cc @@ -15,8 +15,13 @@ limitations under the License. 
 #include "xla/service/scatter_expander.h"
 
+#include <cstdint>
+#include <iterator>
+#include <vector>
+
 #include "absl/algorithm/container.h"
 #include "absl/status/statusor.h"
+#include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -24,53 +29,14 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/literal_util.h"
+#include "xla/service/gather_scatter_utils.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/scatter_utils.h"
 #include "xla/service/while_util.h"
+#include "xla/shape.h"
 
 namespace xla {
 
-// Expands an index vector from the scatter_indices tensor into a vector that
-// can be used to dynamic-update-slice to perform the scatter update.
-static absl::StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
-    HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
-    int64_t operand_rank) {
-  HloComputation* computation = index_vector->parent();
-  const Shape& index_shape = index_vector->shape();
-
-  // Scatter of a scalar. Return a zero-sized vector of indices.
-  if (operand_rank == 0) {
-    return computation->AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
-  }
-
-  HloInstruction* zero =
-      computation->AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
-
-  // We extract out individual components from the smaller index and concatenate
-  // them (interspersing zeros as needed) into the larger index.
-  std::vector<HloInstruction*> expanded_index_components;
-
-  for (int i = 0; i < operand_rank; i++) {
-    int64_t index_vector_dim_index =
-        FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
-    if (index_vector_dim_index !=
-        dim_numbers.scatter_dims_to_operand_dims_size()) {
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * component_to_concat,
-          MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
-                       /*limit_indices=*/{index_vector_dim_index + 1},
-                       /*strides=*/{1}));
-      expanded_index_components.push_back(component_to_concat);
-    } else {
-      expanded_index_components.push_back(zero);
-    }
-  }
-
-  return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
-}
-
 static absl::StatusOr<HloInstruction*> CheckIndexValidity(
     HloComputation* computation, HloInstruction* index,
     absl::Span<const int64_t> operand_dims,
@@ -117,6 +83,23 @@ static absl::StatusOr<HloInstruction*> CheckIndexValidity(
   return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
 }
 
+// Returns the sorted dimensions in a slice that are either collapsed or
+// correspond to an explicit batching dimension.
+std::vector<int64_t> GetDegeneratedSliceDims(
+    const ScatterDimensionNumbers& dim_numbers) {
+  absl::Span<const int64_t> input_batching_dims =
+      dim_numbers.input_batching_dims();
+  absl::Span<const int64_t> inserted_window_dims =
+      dim_numbers.inserted_window_dims();
+  std::vector<int64_t> degenerated_dims;
+  degenerated_dims.reserve(inserted_window_dims.size() +
+                           input_batching_dims.size());
+  absl::c_copy(inserted_window_dims, std::back_inserter(degenerated_dims));
+  absl::c_copy(input_batching_dims, std::back_inserter(degenerated_dims));
+  absl::c_sort(degenerated_dims);
+  return degenerated_dims;
+}
+
 // Body of the while loop that performs the scatter operation using other HLOs.
 static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
     HloScatterInstruction* scatter, HloInstruction* induction_var,
@@ -158,7 +141,12 @@ static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
   TF_ASSIGN_OR_RETURN(
       HloInstruction * scatter_slice_start,
       ExpandIndexVectorIntoOperandSpace(
-          index_vector, dim_numbers, operands[0]->shape().dimensions_size()));
+          scatter->scatter_indices()->shape(),
+          operands[0]->shape().dimensions_size(),
+          dim_numbers.index_vector_dim(),
+          dim_numbers.scatter_dims_to_operand_dims(),
+          dim_numbers.scatter_indices_batching_dims(),
+          dim_numbers.input_batching_dims(), index_vector, induction_var));
 
   // Extract the slice to be used to update from `updates` tensor for the
   // induction_var corresponding to this iteration of the while loop.
@@ -179,6 +167,9 @@ static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
   auto update_slices_with_dims_inserted =
       absl::MakeSpan(map_operands).last(updates.size());
   absl::Span<const int64_t> actual_update_slice_dims;
+
+  std::vector<int64_t> degenerated_dims = GetDegeneratedSliceDims(dim_numbers);
+
   for (int i = 0, n = operands.size(); i < n; ++i) {
     HloInstruction* update = updates[i];
     TF_ASSIGN_OR_RETURN(
@@ -188,8 +179,7 @@ static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
                         ElideDegenerateDims(update_slice, {0}));
     TF_ASSIGN_OR_RETURN(
         HloInstruction * update_slice_with_dims_inserted,
-        InsertDegenerateDims(update_slice_for_scatter,
-                             dim_numbers.inserted_window_dims()));
+        InsertDegenerateDims(update_slice_for_scatter, degenerated_dims));
     update_slices_with_dims_inserted[i] = update_slice_with_dims_inserted;
     // Note that the following transformation assumes that both DynamicSlice and
     // DynamicUpdateSlice follow the same semantics for OOB indices. For
diff --git a/xla/service/scatter_expander_test.cc b/xla/service/scatter_expander_test.cc
index 4d135d3bb26da..664f0112068fb 100644
--- a/xla/service/scatter_expander_test.cc
+++ b/xla/service/scatter_expander_test.cc
@@ -16,11 +16,14 @@ limitations under the License.
 #include "xla/service/scatter_expander.h"
 
 #include <memory>
+#include <string>
 #include <utility>
+#include <vector>
 
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/testlib/filecheck.h"
 #include "xla/hlo/utils/hlo_matchers.h"
 #include "xla/literal.h"
 #include "xla/shape_util.h"
@@ -65,7 +68,6 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) {
                           ParseAndReturnVerifiedModule(kModuleStr));
   ClearInstructionLayout(module.get(), "operand");
-
   ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
   TF_ASSERT_OK_AND_ASSIGN(bool result,
                           RunHloPass(&scatter_expander, module.get()));
@@ -140,6 +142,85 @@ TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) {
   EXPECT_FALSE(result);
 }
 
+TEST_F(ScatterExpanderTest, ScatterToLoopWithBatchDims) {
+  const char* kModuleStr = R"(
+HloModule TensorFlowScatter
+  func {
+    x = s32[] parameter(0)
+    y = s32[] parameter(1)
+    ROOT s = s32[] add(x,y)
+  }
+
+  ENTRY main {
+    indices = s32[2,3,5]{2,1,0} parameter(0)
+    update = s32[2,3,2,5]{3,2,1,0} parameter(1)
+    z = s32[] constant(0)
+    input = s32[5,3,2,2]{3,2,1,0} broadcast(z), dimensions={}
+    ROOT s = s32[5,3,2,2]{3,2,1,0} scatter(input, indices, update),
+      update_window_dims={2},
+      inserted_window_dims={1},
+      scatter_dims_to_operand_dims={1},
+      index_vector_dim=3,
+      input_batching_dims={0,3},
+      scatter_indices_batching_dims={2,0},
+      to_apply=func
+  })";
+
+  // Verify the code that indexes into the operand.
+ const std::string expected = R"( + //CHECK: (s32[], s32[5,3,2,2], s32[30], s32[30,2])) -> (s32[], s32[5,3,2,2], s32[30], s32[30,2]) { + //CHECK: %[[PARAM:.*]] = (s32[], s32[5,3,2,2], s32[30], s32[30,2]) parameter(0) + //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=0 + //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[5,3,2,2] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=1 + + //CHECK: %[[CONSTANT0:.*]] = s32[] constant(0) + //CHECK: %[[OPERAND_INDICES_LOWER_BOUND:.*]] = s32[4] broadcast(s32[] %[[CONSTANT0]]) + //CHECK: %[[CONSTANT5:.*]] = s32[] constant(5) + //CHECK: %[[REMAINDER:.*]] = s32[] remainder(s32[] %[[I]], s32[] %[[CONSTANT5]]) + //CHECK: %[[BD2:.*]] = s32[1] broadcast(s32[] %[[REMAINDER]]) + //CHECK: %[[START_INDICES:.*]] = s32[30] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=2 + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) + //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(s32[1] %[[I_1D_1]]) + //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_RAW]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(s32[30] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX]]) + + //CHECK: %[[SCATTER_INDEX:.*]] = s32[1] slice(s32[1] %[[INDEX_VECTOR]]) + //CHECK: %[[CONSTANT0_2:.*]] = s32[1] constant({0}) + //CHECK: %[[BD_0_1:.*]] = s32[] divide(s32[] %[[I]], s32[] %[[CONSTANT5]]) + //CHECK: %[[CONSTANT3:.*]] = s32[] constant(3) + //CHECK: %[[BD0_RAW:.*]] = s32[] divide(s32[] %[[BD_0_1]], s32[] %[[CONSTANT3]]) + //CHECK: %[[BD0:.*]] = s32[1] broadcast(s32[] %[[BD0_RAW]]) + //CHECK: %[[OPERAND_INDICES:.*]] = s32[4] concatenate(s32[1] %[[BD2]], s32[1] %[[SCATTER_INDEX]], s32[1] %[[CONSTANT0_2]], s32[1] %[[BD0]]) + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[4] 
%[[OPERAND_INDICES]]), slice={[0:1]}
+  //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]])
+  //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[1:2]}
+  //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]])
+  //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[2:3]}
+  //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D2_RAW]])
+  //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[3:4]}
+  //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D3_RAW]])
+  //CHECK: %{{.*}} = s32[1,1,2,1] dynamic-slice(s32[5,3,2,2] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]], s32[] %[[OPERAND_INDEX_D2]], s32[] %[[OPERAND_INDEX_D3]])
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+  ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
+  TF_ASSERT_OK_AND_ASSIGN(bool result,
+                          RunHloPass(&scatter_expander, module.get()));
+  EXPECT_TRUE(result);
+
+  std::vector<HloInstruction*> while_instructions =
+      FindInstructions(module.get(), HloOpcode::kWhile);
+  EXPECT_EQ(while_instructions.size(), 1);
+  HloComputation* while_body = while_instructions[0]->while_body();
+  EXPECT_TRUE(
+      *RunFileCheck(while_body->ToString(
+                        HloPrintOptions{}.set_include_layout_in_shapes(false)),
+                    expected));
+}
+
 TEST_F(ScatterExpanderTest,
        EliminateSimpleMultioutpuScattersSkipsNontrivialScatter) {
   const char* kModuleStr = R"(