Skip to content

Commit

Permalink
Add option to CallInliner to preserve composites.
Browse files Browse the repository at this point in the history
This is useful for preserving composite ops that the target hardware can support directly.

PiperOrigin-RevId: 684619347
  • Loading branch information
ghpvnist authored and Google-ML-Automation committed Oct 10, 2024
1 parent b9fa72e commit c1795e7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 5 deletions.
1 change: 0 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,6 @@ xla_cc_test(
srcs = ["call_inliner_test.cc"],
deps = [
":call_inliner",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:test",
Expand Down
11 changes: 10 additions & 1 deletion xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ bool InlineUnderShardy(HloInstruction* instruction) {
sdy::kManualComputationBodyFuncName.str())));
}

// Returns true if `instruction` may be inlined as far as composite
// preservation is concerned.
//
// A call is kept (returns false) only when it is marked as a composite and the
// value of its "composite.name" frontend attribute is contained in
// `composites_to_preserve`. Every other call is eligible for inlining.
bool InlineComposites(
    HloInstruction* instruction,
    const absl::flat_hash_set<std::string>& composites_to_preserve) {
  // With nothing to preserve, or for a non-composite call, there is no need to
  // touch the frontend attributes at all.
  if (composites_to_preserve.empty() || !instruction->is_composite()) {
    return true;
  }
  // Look the name up with find() instead of at(): at() is a checked access
  // that terminates the process when the key is absent, whereas a composite
  // without a name simply cannot match any entry in the preserve set.
  const auto& attributes = instruction->frontend_attributes().map();
  const auto name_it = attributes.find("composite.name");
  if (name_it == attributes.end()) {
    return true;
  }
  return !composites_to_preserve.contains(name_it->second);
}

} // namespace

/* static */ absl::StatusOr<CallInliner::InlinedInstructionMap>
Expand Down Expand Up @@ -204,7 +212,8 @@ bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
return instruction->opcode() == HloOpcode::kCall &&
!instruction->has_backend_config() &&
!instruction->parent()->IsAsyncComputation() &&
InlineUnderShardy(instruction);
InlineUnderShardy(instruction) &&
InlineComposites(instruction, composites_to_preserve_);
}

absl::StatusOr<bool> CallInliner::Run(
Expand Down
12 changes: 9 additions & 3 deletions xla/service/call_inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#ifndef XLA_SERVICE_CALL_INLINER_H_
#define XLA_SERVICE_CALL_INLINER_H_

#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
Expand All @@ -41,9 +43,12 @@ class CallInliner : public HloModulePass {
// inlined.
// If update_domain is true, the exit domains could be updated for calls which
// are being inlined if necessary.
explicit CallInliner(bool single_call_site = false,
bool update_domain = false)
: single_call_site_(single_call_site), update_domain_(update_domain) {}
// `composites_to_preserve` holds composite names; a composite call whose
// "composite.name" frontend attribute matches one of them is not inlined.
// The default (empty set) preserves no composites.
explicit CallInliner(
    bool single_call_site = false, bool update_domain = false,
    absl::flat_hash_set<std::string> composites_to_preserve = {})
    : single_call_site_(single_call_site),
      update_domain_(update_domain),
      // The set is taken by value as a sink parameter, so move it into the
      // member rather than copying every string a second time.
      composites_to_preserve_(std::move(composites_to_preserve)) {}
~CallInliner() override = default;
absl::string_view name() const override { return "call-inliner"; }

Expand All @@ -59,6 +64,7 @@ class CallInliner : public HloModulePass {
private:
bool single_call_site_;
bool update_domain_;
absl::flat_hash_set<std::string> composites_to_preserve_;
};

} // namespace xla
Expand Down
29 changes: 29 additions & 0 deletions xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,35 @@ TEST_F(CallInlinerTest, InlineCompositeCall) {
EXPECT_TRUE((*inst)->frontend_attributes().map().empty());
}

// Checks that a composite call named in `composites_to_preserve` survives the
// pass: the module is reported as unchanged and the call keeps its frontend
// attributes.
TEST_F(CallInlinerTest, PreserveCompositeCall) {
  const absl::string_view hlo_text = R"(
HloModule composite
%add (lhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] constant(2)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %main () -> f32[] {
%lhs = f32[] constant(42)
ROOT %call = f32[] call(f32[] %lhs), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},composite.name="foo.bar",composite.version="1"}
})";

  auto verified_module = ParseAndReturnVerifiedModule(hlo_text).value();

  // Ask the inliner to keep the "foo.bar" composite used by the entry call.
  CallInliner inliner(/*single_call_site=*/true, /*update_domain=*/false,
                      /*composites_to_preserve=*/{"foo.bar"});
  TF_ASSERT_OK_AND_ASSIGN(bool changed, inliner.Run(verified_module.get()));
  ASSERT_FALSE(changed);

  // The entry computation must still be: constant, then the composite call.
  auto it = verified_module->entry_computation()->instructions().begin();
  EXPECT_THAT(*it, op::Constant());
  ++it;
  EXPECT_THAT(*it, op::Call());
  EXPECT_FALSE((*it)->frontend_attributes().map().empty());
}

TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
Expand Down

0 comments on commit c1795e7

Please sign in to comment.