From c1795e7aad14c7a8f246c2015cfe87587f9dc119 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 10 Oct 2024 16:38:00 -0700 Subject: [PATCH] Add option to `CallInliner` to preserve composites. This is useful for preserving composite ops that hardwares can support. PiperOrigin-RevId: 684619347 --- xla/service/BUILD | 1 - xla/service/call_inliner.cc | 11 ++++++++++- xla/service/call_inliner.h | 12 +++++++++--- xla/service/call_inliner_test.cc | 29 +++++++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 5 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 644604d5b062cc..d98dec548c8566 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1240,7 +1240,6 @@ xla_cc_test( srcs = ["call_inliner_test.cc"], deps = [ ":call_inliner", - "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla:test", diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index 6c5550aa34014b..9732be8286c2e4 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -152,6 +152,14 @@ bool InlineUnderShardy(HloInstruction* instruction) { sdy::kManualComputationBodyFuncName.str()))); } +bool InlineComposites( + HloInstruction* instruction, + const absl::flat_hash_set& composites_to_preserve) { + return !instruction->is_composite() || + !composites_to_preserve.contains( + instruction->frontend_attributes().map().at("composite.name")); +} + } // namespace /* static */ absl::StatusOr @@ -204,7 +212,8 @@ bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return instruction->opcode() == HloOpcode::kCall && !instruction->has_backend_config() && !instruction->parent()->IsAsyncComputation() && - InlineUnderShardy(instruction); + InlineUnderShardy(instruction) && + InlineComposites(instruction, composites_to_preserve_); } absl::StatusOr CallInliner::Run( diff --git a/xla/service/call_inliner.h b/xla/service/call_inliner.h index 7fd584ad5eeba8..5f9a20b92e8e1c 100644 --- a/xla/service/call_inliner.h +++ b/xla/service/call_inliner.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_CALL_INLINER_H_ #define XLA_SERVICE_CALL_INLINER_H_ +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -41,9 +43,12 @@ class CallInliner : public HloModulePass { // inlined. // If update_domain is true, the exit domains could be updated for calls which // are being inlined if necessary. - explicit CallInliner(bool single_call_site = false, - bool update_domain = false) - : single_call_site_(single_call_site), update_domain_(update_domain) {} + explicit CallInliner( + bool single_call_site = false, bool update_domain = false, + absl::flat_hash_set composites_to_preserve = {}) + : single_call_site_(single_call_site), + update_domain_(update_domain), + composites_to_preserve_(composites_to_preserve) {} ~CallInliner() override = default; absl::string_view name() const override { return "call-inliner"; } @@ -59,6 +64,7 @@ class CallInliner : public HloModulePass { private: bool single_call_site_; bool update_domain_; + absl::flat_hash_set composites_to_preserve_; }; } // namespace xla diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index d56dd643de910a..b41606d1a93e75 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -376,6 +376,35 @@ TEST_F(CallInlinerTest, InlineCompositeCall) { EXPECT_TRUE((*inst)->frontend_attributes().map().empty()); } +TEST_F(CallInlinerTest, PreserveCompositeCall) { + const absl::string_view hlo_string = R"( + HloModule composite + + %add (lhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] constant(2) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) + } + + ENTRY %main () -> f32[] { + %lhs = f32[] constant(42) + ROOT %call = f32[] call(f32[] %lhs), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + })"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + CallInliner call_inliner( + /*single_call_site=*/true, /*update_domain=*/false, + /*composites_to_preserve=*/{"foo.bar"}); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_FALSE(mutated); + + auto inst = module->entry_computation()->instructions().begin(); + EXPECT_THAT(*inst, op::Constant()); + ++inst; + EXPECT_THAT(*inst, op::Call()); + EXPECT_FALSE((*inst)->frontend_attributes().map().empty()); +} + TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) { const char* const hloString = R"( HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}