Skip to content

Commit

Permalink
[XLA] Fix a miscompilation during loop constant sinking if the same b…
Browse files Browse the repository at this point in the history
…ody computation was shared by multiple while loops.

The current code assumes that a while loop's body is not shared with any other while loop. It is incorrect to sink a constant from one while loop into a body shared with multiple loops (or indeed, with other non-while instructions).

It is hard to construct such a while loop with the current XlaBuilder API, which is why we have not noticed this bug until now, but it is certainly possible to do so.

Fix the problem by cloning the body/conditional computations if they need to be modified. To correctly traverse the module in the presence of cloning, also fix a long-standing TODO and traverse the module recursively from the entry computation to the leaves.

PiperOrigin-RevId: 689376769
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 24, 2024
1 parent 086726f commit a6eeef2
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 52 deletions.
11 changes: 10 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4979,8 +4979,15 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -4989,11 +4996,13 @@ xla_cc_test(
srcs = ["while_loop_constant_sinking_test.cc"],
deps = [
":while_loop_constant_sinking",
"//xla:literal_util",
"//xla:test",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
123 changes: 83 additions & 40 deletions xla/service/while_loop_constant_sinking.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,26 @@ limitations under the License.

#include "xla/service/while_loop_constant_sinking.h"

#include <cstdint>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_clone_context.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/while_util.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -65,7 +80,7 @@ HloInstruction* CloneHelper(const HloInstruction* instruction,
} // namespace

absl::StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
HloInstruction* while_instr) {
HloModule* module, HloInstruction* while_instr) {
HloComputation* while_cond = while_instr->while_condition();
HloComputation* while_body = while_instr->while_body();

Expand All @@ -74,14 +89,16 @@ absl::StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
return false;
}

bool changed = false;

absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
conditional_gte_index_to_insts =
WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
std::vector<HloInstruction*> invariant_body_gtes =
WhileUtil::GetInvariantGTEsForWhileBody(*while_body);

HloCloneContext body_clone_context(module);
HloCloneContext cond_clone_context(module);
HloComputation* body_clone = nullptr;
HloComputation* cond_clone = nullptr;
for (HloInstruction* invariant_body_gte : invariant_body_gtes) {
int64_t index = invariant_body_gte->tuple_index();
const HloInstruction& invariant_value = *init_value.operand(index);
Expand All @@ -103,12 +120,18 @@ absl::StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
// Sink into the while_body.
// Should have at least one user that's not while_body_root.
if (invariant_body_gte->user_count() > 1) {
if (!body_clone) {
body_clone = module->AddEmbeddedComputation(
while_body->Clone("sunk", &body_clone_context));
while_instr->set_while_body(body_clone);
}
HloInstruction* constant_instr =
CloneHelper(&invariant_value, while_body);
CloneHelper(&invariant_value, body_clone);
TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance(
invariant_body_gte, constant_instr, while_body->root_instruction(),
body_clone_context.FindInstruction(invariant_body_gte),
constant_instr,
body_clone_context.FindInstruction(while_body->root_instruction()),
index));
changed = true;
}

// Check if there is a corresponding GTE in while_conditional.
Expand All @@ -120,16 +143,22 @@ absl::StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
for (HloInstruction* invariant_cond_gte : it->second) {
// Should have at least one user.
if (invariant_cond_gte->user_count() > 0) {
if (!cond_clone) {
cond_clone = module->AddEmbeddedComputation(
while_cond->Clone("sunk", &cond_clone_context));
while_instr->set_while_condition(cond_clone);
}
HloInstruction* constant_instr =
CloneHelper(&invariant_value, while_cond);
TF_RETURN_IF_ERROR(
invariant_cond_gte->ReplaceAllUsesWith(constant_instr));
changed = true;
CloneHelper(&invariant_value, cond_clone);
HloInstruction* cond_gte =
cond_clone_context.FindInstruction(invariant_cond_gte);
TF_RETURN_IF_ERROR(cond_gte->ReplaceAllUsesWith(constant_instr));
TF_RETURN_IF_ERROR(cond_clone->RemoveInstruction(cond_gte));
}
}
}

return changed;
return body_clone || cond_clone;
}

absl::StatusOr<bool> WhileLoopConstantSinking::Run(
Expand All @@ -140,37 +169,51 @@ absl::StatusOr<bool> WhileLoopConstantSinking::Run(

bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
// Right now we don't particularly care about optimizing while-of-while
// patterns. If/When we do, we'll want to visit the outer while (while_0)
// before we visit the inner while (while_1):
//
// while_1_body(state) {
// val = gte(state, 0) // Loop invariant
// use(val)
// }
//
// while_0_body(state) {
// val = gte(state, 0) // Loop invariant
// while_1 = while(init=tuple(val, ...), body=while_1_body, ...)
// ...
// }
//
// main {
// while_0 = while(init=(constant, ...), body=while_0_body, ...)
// }
//
// This will let us sink the constant into the outer while first and then
// into the inner while in a single run of this pass.
absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
HloPredicateIsOp<HloOpcode::kWhile>);
}

for (HloInstruction* while_instr : while_instrs) {
TF_ASSIGN_OR_RETURN(bool result,
TrySinkingConstantsIntoWhileLoop(while_instr));
changed |= result;
// Visit computations in order, from outermost to innermost.
// We want to visit the outer while (while_0) before we visit the inner
// while (while_1):
//
// while_1_body(state) {
// val = gte(state, 0) // Loop invariant
// use(val)
// }
//
// while_0_body(state) {
// val = gte(state, 0) // Loop invariant
// while_1 = while(init=tuple(val, ...), body=while_1_body, ...)
// ...
// }
//
// main {
// while_0 = while(init=(constant, ...), body=while_0_body, ...)
// }
//
// This will let us sink the constant into the outer while first and then
// into the inner while in a single run of this pass.
std::stack<HloComputation*> agenda;
agenda.push(module->entry_computation());
absl::flat_hash_set<HloComputation*> visited;
while (!agenda.empty()) {
HloComputation* comp = agenda.top();
agenda.pop();
if (!visited.insert(comp).second) {
continue;
}
for (auto* instr : comp->instructions()) {
// Sinking constants may change the called computations, so do that first
// if this is a while instruction.
if (instr->opcode() == HloOpcode::kWhile) {
TF_ASSIGN_OR_RETURN(bool result,
TrySinkingConstantsIntoWhileLoop(module, instr));
changed |= result;
}
for (HloComputation* child : instr->called_computations()) {
agenda.push(child);
}
}
}
TF_RETURN_IF_ERROR(module->RemoveUnusedComputations());

if (changed) {
VLOG(2) << "HLO module after WhileLoopConstantSinking:";
Expand Down
2 changes: 1 addition & 1 deletion xla/service/while_loop_constant_sinking.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class WhileLoopConstantSinking : public HloModulePass {

private:
absl::StatusOr<bool> TrySinkingConstantsIntoWhileLoop(
HloInstruction* while_instr);
HloModule* module, HloInstruction* while_instr);

const bool sink_broadcast_of_constants_;
const bool sink_only_scalar_constants_;
Expand Down
76 changes: 66 additions & 10 deletions xla/service/while_loop_constant_sinking_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ limitations under the License.

#include "xla/service/while_loop_constant_sinking.h"

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal_util.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -68,7 +72,7 @@ ENTRY entry {
.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body");
auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::Add(_, op::Constant()), _));
}
Expand Down Expand Up @@ -115,7 +119,7 @@ ENTRY entry {
.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body");
auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::Add(_, op::Broadcast(op::Constant())), _));
}
Expand Down Expand Up @@ -155,7 +159,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body");
auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::Add(op::Constant(), op::Constant()),
op::GetTupleElement(op::Parameter(0)),
Expand Down Expand Up @@ -196,7 +200,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body");
auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::GetTupleElement(op::Constant(), 0),
op::GetTupleElement(op::Parameter(0))));
Expand Down Expand Up @@ -244,7 +248,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body");
auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())),
op::GetTupleElement(op::Parameter(0)),
Expand Down Expand Up @@ -286,7 +290,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body");
auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
op::GetTupleElement()));
Expand Down Expand Up @@ -332,7 +336,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_condition = module->GetComputationWithName("condition");
auto* while_condition = module->GetComputationWithName("condition.sunk");
EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant()));
}

Expand Down Expand Up @@ -372,7 +376,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_condition = module->GetComputationWithName("condition");
auto* while_condition = module->GetComputationWithName("condition.sunk");
EXPECT_THAT(while_condition->root_instruction(),
op::Lt(_, op::GetTupleElement(op::Constant())));
}
Expand Down Expand Up @@ -415,7 +419,7 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_condition = module->GetComputationWithName("condition");
auto* while_condition = module->GetComputationWithName("condition.sunk");
EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant()));
for (const HloInstruction* inst : while_condition->instructions()) {
if (inst->opcode() == HloOpcode::kConstant) {
Expand Down Expand Up @@ -465,9 +469,61 @@ ENTRY entry {
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_condition = module->GetComputationWithName("condition");
auto* while_condition = module->GetComputationWithName("condition.sunk");
EXPECT_THAT(while_condition->root_instruction(),
op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant())));
}

TEST_F(WhileLoopConstantSinkingTest, SinkWithSharedBody) {
const char* const hlo_string = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2],f32[2]) parameter(0)
p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1
add.0 = f32[2] add(p_body.0, p_body.1)
ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1)
}
condition {
p_cond = (f32[2],f32[2]) parameter(0)
ROOT result = pred[] constant(true)
}
ENTRY entry {
const_0 = f32[2] constant({1, 2})
const_1 = f32[2] constant({2, 1})
while_init = (f32[2],f32[2]) tuple(const_0, const_1)
while = (f32[2],f32[2]) while(while_init), condition=condition, body=body
while_init2 = (f32[2],f32[2]) tuple(const_1, const_0)
while2 = (f32[2],f32[2]) while(while_init2), condition=condition, body=body
ROOT tuple = ((f32[2],f32[2]),(f32[2],f32[2])) tuple(while, while2)
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));

TF_ASSERT_OK_AND_ASSIGN(
bool changed,
WhileLoopConstantSinking(/*sink_broadcast_of_constants=*/false,
/*sink_only_scalar_constants=*/false)
.Run(module.get()));
ASSERT_TRUE(changed);

auto* while_body = module->GetComputationWithName("body.sunk");
EXPECT_THAT(
while_body->root_instruction(),
op::Tuple(op::Add(_, op::Constant(LiteralUtil::CreateR1<float>({2, 1}))),
_));
while_body = module->GetComputationWithName("body.sunk.1");
EXPECT_THAT(
while_body->root_instruction(),
op::Tuple(op::Add(_, op::Constant(LiteralUtil::CreateR1<float>({1, 2}))),
_));
}

} // namespace
} // namespace xla

0 comments on commit a6eeef2

Please sign in to comment.