Skip to content

Commit

Permalink
[xla:scatterExpander] Extend ScatterToLoop transformation to generate…
Browse files Browse the repository at this point in the history
… correct

code in the presence of explicit batch dimensions.

Explicit batch dimensions were recently added to scatter instructions in
openxla/stablehlo#2084.

This CL extends the pass to consider explicit operand batch dimensions when
computing the while-loop init-value shape and the operand indices.

PiperOrigin-RevId: 686266428
  • Loading branch information
bixia1 authored and Google-ML-Automation committed Oct 24, 2024
1 parent 5d58a09 commit d4fbb67
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 46 deletions.
5 changes: 4 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2082,15 +2082,17 @@ cc_library(
srcs = ["scatter_expander.cc"],
hdrs = ["scatter_expander.h"],
deps = [
":call_inliner",
":gather_scatter_utils",
":hlo_creation_utils",
":scatter_utils",
":while_util",
"//xla:literal_util",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)

Expand Down Expand Up @@ -2127,6 +2129,7 @@ xla_cc_test(
"//xla:test",
"//xla:types",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
Expand Down
78 changes: 34 additions & 44 deletions xla/service/scatter_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,62 +15,28 @@ limitations under the License.

#include "xla/service/scatter_expander.h"

#include <cstdint>
#include <iterator>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal_util.h"
#include "xla/service/gather_scatter_utils.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/scatter_utils.h"
#include "xla/service/while_util.h"
#include "xla/shape.h"

namespace xla {

// Expands an index vector from the scatter_indices tensor into a vector that
// can be used to dynamic-update-slice to perform the scatter update.
static absl::StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
    HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
    int64_t operand_rank) {
  HloComputation* parent = index_vector->parent();
  const Shape& index_shape = index_vector->shape();

  // Scatter of a scalar. Return a zero-sized vector of indices.
  if (operand_rank == 0) {
    return parent->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
  }

  // A shared one-element zero vector, used for every operand dimension that
  // has no corresponding entry in scatter_dims_to_operand_dims.
  HloInstruction* zero_component =
      parent->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));

  // Build one single-element component per operand dimension — either a slice
  // of the smaller index vector or the shared zero — then concatenate them
  // into the full operand-space index.
  std::vector<HloInstruction*> components;
  components.reserve(operand_rank);
  for (int64_t dim = 0; dim < operand_rank; ++dim) {
    const int64_t source_pos =
        FindIndex(dim_numbers.scatter_dims_to_operand_dims(), dim);
    if (source_pos == dim_numbers.scatter_dims_to_operand_dims_size()) {
      // Operand dimension `dim` is not scattered into; its index is zero.
      components.push_back(zero_component);
    } else {
      // Extract the matching element out of the smaller index vector.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * component,
          MakeSliceHlo(index_vector, /*start_indices=*/{source_pos},
                       /*limit_indices=*/{source_pos + 1},
                       /*strides=*/{1}));
      components.push_back(component);
    }
  }

  return MakeConcatHlo(components, /*dimension=*/0);
}

static absl::StatusOr<HloInstruction*> CheckIndexValidity(
HloComputation* computation, HloInstruction* index,
absl::Span<const int64_t> operand_dims,
Expand Down Expand Up @@ -117,6 +83,23 @@ static absl::StatusOr<HloInstruction*> CheckIndexValidity(
return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
}

// Returns the sorted dimensions of an operand slice that are degenerate
// (size 1): the collapsed window dimensions (`inserted_window_dims`) together
// with the explicit operand batching dimensions (`input_batching_dims`).
// The result is sorted ascending so it can be passed directly to
// InsertDegenerateDims.
//
// `static`: file-local helper, consistent with the other helpers in this
// translation unit; avoids exporting an external symbol.
static std::vector<int64_t> GetDegeneratedSliceDims(
    const ScatterDimensionNumbers& dim_numbers) {
  absl::Span<const int64_t> input_batching_dims =
      dim_numbers.input_batching_dims();
  absl::Span<const int64_t> inserted_window_dims =
      dim_numbers.inserted_window_dims();
  std::vector<int64_t> degenerated_dims;
  degenerated_dims.reserve(inserted_window_dims.size() +
                           input_batching_dims.size());
  // Merge both dimension lists, then sort; the two lists are disjoint sets of
  // operand dimensions, so a plain concatenate-and-sort is sufficient.
  absl::c_copy(inserted_window_dims, std::back_inserter(degenerated_dims));
  absl::c_copy(input_batching_dims, std::back_inserter(degenerated_dims));
  absl::c_sort(degenerated_dims);
  return degenerated_dims;
}

// Body of the while loop that performs the scatter operation using other HLOs.
static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
HloScatterInstruction* scatter, HloInstruction* induction_var,
Expand Down Expand Up @@ -158,7 +141,12 @@ static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
TF_ASSIGN_OR_RETURN(
HloInstruction * scatter_slice_start,
ExpandIndexVectorIntoOperandSpace(
index_vector, dim_numbers, operands[0]->shape().dimensions_size()));
scatter->scatter_indices()->shape(),
operands[0]->shape().dimensions_size(),
dim_numbers.index_vector_dim(),
dim_numbers.scatter_dims_to_operand_dims(),
dim_numbers.scatter_indices_batching_dims(),
dim_numbers.input_batching_dims(), index_vector, induction_var));

// Extract the slice to be used to update from `updates` tensor for the
// induction_var corresponding to this iteration of the while loop.
Expand All @@ -179,6 +167,9 @@ static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
auto update_slices_with_dims_inserted =
absl::MakeSpan(map_operands).last(updates.size());
absl::Span<const int64_t> actual_update_slice_dims;

std::vector<int64_t> degenerated_dims = GetDegeneratedSliceDims(dim_numbers);

for (int i = 0, n = operands.size(); i < n; ++i) {
HloInstruction* update = updates[i];
TF_ASSIGN_OR_RETURN(
Expand All @@ -188,8 +179,7 @@ static absl::StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
ElideDegenerateDims(update_slice, {0}));
TF_ASSIGN_OR_RETURN(
HloInstruction * update_slice_with_dims_inserted,
InsertDegenerateDims(update_slice_for_scatter,
dim_numbers.inserted_window_dims()));
InsertDegenerateDims(update_slice_for_scatter, degenerated_dims));
update_slices_with_dims_inserted[i] = update_slice_with_dims_inserted;
// Note that the following transformation assumes that both DynamicSlice and
// DynamicUpdateSlice follow the same semantics for OOB indices. For
Expand Down
83 changes: 82 additions & 1 deletion xla/service/scatter_expander_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ limitations under the License.
#include "xla/service/scatter_expander.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -65,7 +68,6 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) {
ParseAndReturnVerifiedModule(kModuleStr));

ClearInstructionLayout(module.get(), "operand");

ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
TF_ASSERT_OK_AND_ASSIGN(bool result,
RunHloPass(&scatter_expander, module.get()));
Expand Down Expand Up @@ -140,6 +142,85 @@ TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) {
EXPECT_FALSE(result);
}

// Checks that ScatterExpander produces correct while-loop code for a scatter
// with explicit batch dimensions (input_batching_dims /
// scatter_indices_batching_dims): the batch components of the operand index
// must be derived from the induction variable, and the batching dimensions
// must be treated as degenerate slice dimensions.
TEST_F(ScatterExpanderTest, ScatterToLoopWithBatchDims) {
  // Scatter with two explicit batch dims: operand dims {0,3} pair with
  // indices dims {2,0}; window dim 1 is collapsed.
  const char* kModuleStr = R"(
HloModule TensorFlowScatter
func {
x = s32[] parameter(0)
y = s32[] parameter(1)
ROOT s = s32[] add(x,y)
}
ENTRY main {
indices = s32[2,3,5]{2,1,0} parameter(0)
update = s32[2,3,2,5]{3,2,1,0} parameter(1)
z = s32[] constant(0)
input = s32[5,3,2,2]{3,2,1,0} broadcast(z), dimensions={}
ROOT s = s32[5,3,2,2]{3,2,1,0} scatter(input, indices, update),
update_window_dims={2},
inserted_window_dims={1},
scatter_dims_to_operand_dims={1},
index_vector_dim=3,
input_batching_dims={0,3},
scatter_indices_batching_dims={2,0},
to_apply=func
})";

  // Verify the code that indexes into the operand.
  // The batch index components are recovered from the flattened induction
  // variable i: dim 2 of the indices as i % 5, dim 0 as (i / 5) / 3.
  const std::string expected = R"(
//CHECK: (s32[], s32[5,3,2,2], s32[30], s32[30,2])) -> (s32[], s32[5,3,2,2], s32[30], s32[30,2]) {
//CHECK: %[[PARAM:.*]] = (s32[], s32[5,3,2,2], s32[30], s32[30,2]) parameter(0)
//CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=0
//CHECK: %[[CONSTANT1:.*]] = s32[] constant(1)
//CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]])
//CHECK: %[[OPERAND:.*]] = s32[5,3,2,2] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=1
//CHECK: %[[CONSTANT0:.*]] = s32[] constant(0)
//CHECK: %[[OPERAND_INDICES_LOWER_BOUND:.*]] = s32[4] broadcast(s32[] %[[CONSTANT0]])
//CHECK: %[[CONSTANT5:.*]] = s32[] constant(5)
//CHECK: %[[REMAINDER:.*]] = s32[] remainder(s32[] %[[I]], s32[] %[[CONSTANT5]])
//CHECK: %[[BD2:.*]] = s32[1] broadcast(s32[] %[[REMAINDER]])
//CHECK: %[[START_INDICES:.*]] = s32[30] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=2
//CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]])
//CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(s32[1] %[[I_1D_1]])
//CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_RAW]])
//CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(s32[30] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX]])
//CHECK: %[[SCATTER_INDEX:.*]] = s32[1] slice(s32[1] %[[INDEX_VECTOR]])
//CHECK: %[[CONSTANT0_2:.*]] = s32[1] constant({0})
//CHECK: %[[BD_0_1:.*]] = s32[] divide(s32[] %[[I]], s32[] %[[CONSTANT5]])
//CHECK: %[[CONSTANT3:.*]] = s32[] constant(3)
//CHECK: %[[BD0_RAW:.*]] = s32[] divide(s32[] %[[BD_0_1]], s32[] %[[CONSTANT3]])
//CHECK: %[[BD0:.*]] = s32[1] broadcast(s32[] %[[BD0_RAW]])
//CHECK: %[[OPERAND_INDICES:.*]] = s32[4] concatenate(s32[1] %[[BD2]], s32[1] %[[SCATTER_INDEX]], s32[1] %[[CONSTANT0_2]], s32[1] %[[BD0]])
//CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[0:1]}
//CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]])
//CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[1:2]}
//CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]])
//CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[2:3]}
//CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D2_RAW]])
//CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[3:4]}
//CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D3_RAW]])
//CHECK: %{{.*}} = s32[1,1,2,1] dynamic-slice(s32[5,3,2,2] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]], s32[] %[[OPERAND_INDEX_D2]], s32[] %[[OPERAND_INDEX_D3]])
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  // kEliminateAllScatters forces expansion even for non-trivial scatters.
  ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
  TF_ASSERT_OK_AND_ASSIGN(bool result,
                          RunHloPass(&scatter_expander, module.get()));
  EXPECT_TRUE(result);

  // The expander replaces the scatter with exactly one while loop; run the
  // FileCheck pattern against that loop's body. Layouts are omitted from the
  // printed shapes so the CHECK lines stay readable.
  std::vector<HloInstruction*> while_instructions =
      FindInstructions(module.get(), HloOpcode::kWhile);
  EXPECT_EQ(while_instructions.size(), 1);
  HloComputation* while_body = while_instructions[0]->while_body();
  EXPECT_TRUE(
      *RunFileCheck(while_body->ToString(
                        HloPrintOptions{}.set_include_layout_in_shapes(false)),
                    expected));
}

TEST_F(ScatterExpanderTest,
EliminateSimpleMultioutpuScattersSkipsNontrivialScatter) {
const char* kModuleStr = R"(
Expand Down

0 comments on commit d4fbb67

Please sign in to comment.