From a6eeef2f0361bcad287b473c8f109d3a9fc6cc68 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 24 Oct 2024 07:41:57 -0700 Subject: [PATCH] [XLA] Fix a miscompilation during loop constant sinking if the same body computation was shared by multiple while loops. The current code assumes that a while loop's body is not shared with any other while loop. It is incorrect to sink a constant from one while loop into a body shared with multiple loops (or indeed, with other non-while instructions). It is hard to construct such a while loop with the current XlaBuilder API, which is why we have not noticed this bug until now, but it is certainly possible to do so. Fix the problem by cloning the body/conditional computations if they need to be modified. To correctly traverse the module in the presence of cloning, also fix a long-standing TODO and traverse the module recursively from the entry computation to the leaves. PiperOrigin-RevId: 689376769 --- xla/service/BUILD | 11 +- xla/service/while_loop_constant_sinking.cc | 123 ++++++++++++------ xla/service/while_loop_constant_sinking.h | 2 +- .../while_loop_constant_sinking_test.cc | 76 +++++++++-- 4 files changed, 160 insertions(+), 52 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index cf02b14f2080c..c1198fc3571af 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4979,8 +4979,15 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -4989,11 +4996,13 @@ xla_cc_test( srcs = ["while_loop_constant_sinking_test.cc"], deps = [ ":while_loop_constant_sinking", + "//xla:literal_util", "//xla:test", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/while_loop_constant_sinking.cc b/xla/service/while_loop_constant_sinking.cc index 83bd7f056ae6a..49dfe2a5f7e4d 100644 --- a/xla/service/while_loop_constant_sinking.cc +++ b/xla/service/while_loop_constant_sinking.cc @@ -15,11 +15,26 @@ limitations under the License. #include "xla/service/while_loop_constant_sinking.h" +#include +#include +#include +#include + #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/while_util.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -65,7 +80,7 @@ HloInstruction* CloneHelper(const HloInstruction* instruction, } // namespace absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( - HloInstruction* while_instr) { + HloModule* module, HloInstruction* while_instr) { HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); @@ -74,14 +89,16 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( return false; } - bool changed = false; - absl::flat_hash_map> conditional_gte_index_to_insts = WhileUtil::GetGTEsMapForWhileConditional(*while_cond); std::vector invariant_body_gtes = WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + HloCloneContext body_clone_context(module); + HloCloneContext cond_clone_context(module); + HloComputation* body_clone = nullptr; + HloComputation* cond_clone = nullptr; for (HloInstruction* invariant_body_gte : invariant_body_gtes) { int64_t index = invariant_body_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); @@ -103,12 +120,18 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( // Sink into the while_body. // Should have at least one user that's not while_body_root. if (invariant_body_gte->user_count() > 1) { + if (!body_clone) { + body_clone = module->AddEmbeddedComputation( + while_body->Clone("sunk", &body_clone_context)); + while_instr->set_while_body(body_clone); + } HloInstruction* constant_instr = - CloneHelper(&invariant_value, while_body); + CloneHelper(&invariant_value, body_clone); TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( - invariant_body_gte, constant_instr, while_body->root_instruction(), + body_clone_context.FindInstruction(invariant_body_gte), + constant_instr, + body_clone_context.FindInstruction(while_body->root_instruction()), index)); - changed = true; } // Check if there is a corresponding GTE in while_conditional. @@ -120,16 +143,22 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( for (HloInstruction* invariant_cond_gte : it->second) { // Should have at least one user. if (invariant_cond_gte->user_count() > 0) { + if (!cond_clone) { + cond_clone = module->AddEmbeddedComputation( + while_cond->Clone("sunk", &cond_clone_context)); + while_instr->set_while_condition(cond_clone); + } HloInstruction* constant_instr = - CloneHelper(&invariant_value, while_cond); - TF_RETURN_IF_ERROR( - invariant_cond_gte->ReplaceAllUsesWith(constant_instr)); - changed = true; + CloneHelper(&invariant_value, cond_clone); + HloInstruction* cond_gte = + cond_clone_context.FindInstruction(invariant_cond_gte); + TF_RETURN_IF_ERROR(cond_gte->ReplaceAllUsesWith(constant_instr)); + TF_RETURN_IF_ERROR(cond_clone->RemoveInstruction(cond_gte)); } } } - return changed; + return body_clone || cond_clone; } absl::StatusOr WhileLoopConstantSinking::Run( @@ -140,37 +169,51 @@ absl::StatusOr WhileLoopConstantSinking::Run( bool changed = false; std::vector while_instrs; - for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { - // Right now we don't particularly care about optimizing while-of-while - // patterns. If/When we do, we'll want to visit the outer while (while_0) - // before we visit the inner while (while_1): - // - // while_1_body(state) { - // val = gte(state, 0) // Loop invariant - // use(val) - // } - // - // while_0_body(state) { - // val = gte(state, 0) // Loop invariant - // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) - // ... - // } - // - // main { - // while_0 = while(init=(constant, ...), body=while_0_body, ...) - // } - // - // This will let us sink the constant into the outer while first and then - // into the inner while in a single run of this pass. - absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - HloPredicateIsOp); - } - for (HloInstruction* while_instr : while_instrs) { - TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileLoop(while_instr)); - changed |= result; + // Visit computations in order, from outermost to innermost. + // We want to visit the outer while (while_0) before we visit the inner + // while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(constant, ...), body=while_0_body, ...) + // } + // + // This will let us sink the constant into the outer while first and then + // into the inner while in a single run of this pass. + std::stack agenda; + agenda.push(module->entry_computation()); + absl::flat_hash_set visited; + while (!agenda.empty()) { + HloComputation* comp = agenda.top(); + agenda.pop(); + if (!visited.insert(comp).second) { + continue; + } + for (auto* instr : comp->instructions()) { + // Sinking constants may change the called computations, so do that first + // if this is a while instruction. + if (instr->opcode() == HloOpcode::kWhile) { + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileLoop(module, instr)); + changed |= result; + } + for (HloComputation* child : instr->called_computations()) { + agenda.push(child); + } + } } + TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); if (changed) { VLOG(2) << "HLO module after WhileLoopConstantSinking:"; diff --git a/xla/service/while_loop_constant_sinking.h b/xla/service/while_loop_constant_sinking.h index 8d1402ff72d29..1ea8e4db0f1b1 100644 --- a/xla/service/while_loop_constant_sinking.h +++ b/xla/service/while_loop_constant_sinking.h @@ -66,7 +66,7 @@ class WhileLoopConstantSinking : public HloModulePass { private: absl::StatusOr TrySinkingConstantsIntoWhileLoop( - HloInstruction* while_instr); + HloModule* module, HloInstruction* while_instr); const bool sink_broadcast_of_constants_; const bool sink_only_scalar_constants_; diff --git a/xla/service/while_loop_constant_sinking_test.cc b/xla/service/while_loop_constant_sinking_test.cc index 3597686e9b9cc..2cfd69a9254e8 100644 --- a/xla/service/while_loop_constant_sinking_test.cc +++ b/xla/service/while_loop_constant_sinking_test.cc @@ -15,9 +15,13 @@ limitations under the License. #include "xla/service/while_loop_constant_sinking.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -68,7 +72,7 @@ ENTRY entry { .Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(_, op::Constant()), _)); } @@ -115,7 +119,7 @@ ENTRY entry { .Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(_, op::Broadcast(op::Constant())), _)); } @@ -155,7 +159,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(op::Constant(), op::Constant()), op::GetTupleElement(op::Parameter(0)), @@ -196,7 +200,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::GetTupleElement(op::Constant(), 0), op::GetTupleElement(op::Parameter(0)))); @@ -244,7 +248,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())), op::GetTupleElement(op::Parameter(0)), @@ -286,7 +290,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::GetTupleElement(), op::GetTupleElement(), op::GetTupleElement())); @@ -332,7 +336,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); } @@ -372,7 +376,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::GetTupleElement(op::Constant()))); } @@ -415,7 +419,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); for (const HloInstruction* inst : while_condition->instructions()) { if (inst->opcode() == HloOpcode::kConstant) { @@ -465,9 +469,61 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant()))); } + +TEST_F(WhileLoopConstantSinkingTest, SinkWithSharedBody) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + while = (f32[2],f32[2]) while(while_init), condition=condition, body=body + while_init2 = (f32[2],f32[2]) tuple(const_1, const_0) + while2 = (f32[2],f32[2]) while(while_init2), condition=condition, body=body + ROOT tuple = ((f32[2],f32[2]),(f32[2],f32[2])) tuple(while, while2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + WhileLoopConstantSinking(/*sink_broadcast_of_constants=*/false, + /*sink_only_scalar_constants=*/false) + .Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body.sunk"); + EXPECT_THAT( + while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant(LiteralUtil::CreateR1({2, 1}))), + _)); + while_body = module->GetComputationWithName("body.sunk.1"); + EXPECT_THAT( + while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant(LiteralUtil::CreateR1({1, 2}))), + _)); +} + } // namespace } // namespace xla