Skip to content

Commit

Permalink
[xla:WhileLoopUnroller] Fix MatchShapeCoveringDynamicIndexInstruction.
Browse files Browse the repository at this point in the history
For DynamicSlice, ensure that the slice has size 1 on the dimension that the
induction variable indexes into and that it matches the input shape on all
other dimensions.

Add a test.

PiperOrigin-RevId: 688706740
  • Loading branch information
bixia1 authored and Google-ML-Automation committed Oct 22, 2024
1 parent fd5f05f commit 3fc7b97
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 12 deletions.
24 changes: 21 additions & 3 deletions xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,13 +514,31 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
return std::nullopt;
}

// The shape's broadcast_dim must be exactly equal to the loop trip count.
if (operand->shape().dimensions(dynamic_index) != config.trip_count) {
VLOG(3) << "The shape's broadcast_dim must be exactly equal to the loop "
"trip count.";
VLOG(3) << "The dynamic_index dimension size of the operand must be equal "
"to the loop trip count.";
return std::nullopt;
}

if (opcode == HloOpcode::kDynamicSlice) {
const Shape& result_shape = instr->shape();
if (result_shape.dimensions(dynamic_index) != 1) {
VLOG(3) << "The slice size on the dynamic_index dimension must be 1.";
return std::nullopt;
}

const Shape& operand_shape = operand->shape();
CHECK_EQ(result_shape.dimensions_size(), operand_shape.dimensions_size());
for (int64_t i = 0; i < result_shape.dimensions_size(); ++i) {
if (i != dynamic_index &&
result_shape.dimensions(i) != operand_shape.dimensions(i)) {
VLOG(3) << "The slice sizes must match the operand-shape on "
"non-dynamic-index dimensions.";
return std::nullopt;
}
}
}

return dynamic_index;
}

Expand Down
5 changes: 4 additions & 1 deletion xla/service/while_loop_unroller.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ struct UnrollResult {
// 1. All start indices must be constant zero except for a single dimension.
// 2. The start index of that dimension should be equal to the enclosing loop
// induction variable.
// 3. And, the size of that dimension must match the loop trip count.
// 3. The size of that dimension must match the loop trip count.
// 4. For dynamic-slice, the slice size on the induction variable dimension
// must be 1, and the slice size on every other dimension must equal the
// corresponding dimension of the input.
// If so, it returns the dynamic index.
std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode,
Expand Down
59 changes: 51 additions & 8 deletions xla/service/while_loop_unroller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ WhileLoopUnrollerTest::MakeModuleWithNestedLoopBodyIndirectInc(int num_iters) {
constant.3 = s32[] constant(0)
tuple.1 = (s32[], s32[], s32[3]{0}) tuple(constant.3, constant.1, get-tuple-element.22)
inner-while = (s32[], s32[], s32[3]{0}) while(tuple.1), condition=
SimpleLoop.condition, body=SimpleLoop.body
SimpleLoop.condition, body=SimpleLoop.body
get-tuple-element.6 = s32[3]{0} get-tuple-element(inner-while), index=2
inc = s32[] add(get-tuple-element.1, get-tuple-element.2)
ROOT tuple = (s32[], s32[], s32[3]{0}, s32[10]{0}) tuple(inc, get-tuple-element.2, get-tuple-element.6, output)
Expand Down Expand Up @@ -269,22 +269,22 @@ std::unique_ptr<VerifiedHloModule>
WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) {
std::string hlo_string_template = R"(
HloModule SimpleLoop
%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
SimpleLoop.body {
loop_var.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
get-tuple-element.2 = f32[1024, 1024] get-tuple-element(loop_var.1), index=1
get-tuple-element.3 = f32[1024, 1024] get-tuple-element(loop_var.1), index=2
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] get-tuple-element.2), channel_id=1, replica_groups={{0}}, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] get-tuple-element.3)
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
ROOT tuple = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(add, get-tuple-element.2, %accumulation)
Expand All @@ -298,10 +298,10 @@ WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) {
ENTRY SimpleLoop {
%param.1 = f32[1024, 1024] parameter(0)
constant.3 = s32[] constant(0)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
tuple.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(constant.3, %param.1, %accumulation_buffer)
ROOT while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(tuple.1), condition=SimpleLoop.condition, body=SimpleLoop.body
}
Expand Down Expand Up @@ -987,6 +987,49 @@ TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDS) {
.has_value());
}

// Verifies that MatchShapeCoveringDynamicIndexInstruction rejects a
// dynamic-slice whose slice shape does not cover the operand on the dimensions
// that are not indexed by the induction variable. Here the operand is
// s32[3,11] while the slice is s32[1,10]: dimension 0 is indexed by the loop
// counter with slice size 1, but dimension 1 differs (10 vs. 11), so no
// dynamic index may be reported.
TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSShapeMismatch) {
  const std::string hlo_string = R"(
HloModule SimpleLoop
body {
param = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) parameter(0)
idx = s32[]{:T(128)} get-tuple-element(param), index=0
constant1 = s32[]{:T(128)} constant(1)
new-idx = s32[]{:T(128)} add(idx, constant1)
update = s32[3,10]{1,0} get-tuple-element(param), index=1
input = s32[3,11]{1,0} get-tuple-element(param), index=2
zero = s32[] constant(0)
slice = s32[1,10] dynamic-slice(input, idx, zero), dynamic_slice_sizes={1,10}
new-update = s32[3,10]{1,0} dynamic-update-slice(update, slice, idx, zero)
ROOT tuple = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) tuple(new-idx, new-update, input)
}
condition {
param = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) parameter(0)
idx = s32[] get-tuple-element(param), index=0
constant3 = s32[]{:T(128)} constant(3)
ROOT less-than = pred[] compare(idx, constant3), direction=LT
}
ENTRY main {
constant0 = s32[]{:T(128)} constant(0)
init-update = s32[3,10]{1,0} constant({...})
init-input = s32[3,11]{1,0} constant({...})
init-while = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) tuple(constant0, init-update, init-input)
ROOT while = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) while(init-while), condition=
condition, body=body
}
)";

  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
  HloInstruction* loop = module->entry_computation()->root_instruction();
  auto config = WhileLoopUnroller::IsLoopUnrollable(loop);
  // Fatal assertion: config.value() is dereferenced below, so continuing with
  // an empty optional would abort the test binary instead of reporting a
  // regular test failure.
  ASSERT_TRUE(config.has_value());

  // The name lookups are preconditions of the check under test; stop here with
  // a clear failure if the HLO above is ever renamed.
  HloComputation* body = module->GetComputationWithName("body");
  ASSERT_NE(body, nullptr);
  HloInstruction* input = body->GetInstructionWithName("input");
  ASSERT_NE(input, nullptr);
  HloInstruction* instr = body->GetInstructionWithName("slice");
  ASSERT_NE(instr, nullptr);

  EXPECT_FALSE(MatchShapeCoveringDynamicIndexInstruction(
                   instr, input, HloOpcode::kDynamicSlice, config.value())
                   .has_value());
}

TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSNested) {
std::string hlo_string_template = R"(
HloModule SimpleLoop
Expand Down Expand Up @@ -1127,7 +1170,7 @@ TEST_F(WhileLoopUnrollerTest, IsEffectivelyStaticDynamicSlice) {
%dynamic-slice.static = s8[1,128,128] dynamic-slice(s8[6,128,128] %param_0.51117, static.p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128}
ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.static)
}
%fused_computation.slice.2 (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] {
%param_0.51117 = s8[6,128,128] parameter(0)
dynamic.p1 = s32[] parameter(1)
Expand Down

0 comments on commit 3fc7b97

Please sign in to comment.