diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 725532e893b38..0883e5a0431ca 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -1056,67 +1056,75 @@ ENTRY main { expect_custom_kernel_fusion_rewriter_has_run); } -struct PassRunIndex { - int first_run = std::numeric_limits::max(); - int last_run = std::numeric_limits::min(); -}; +class PassOrderTest : public GpuCompilerTest { + public: + void SetDebugOptions(const DebugOptions& options) { + HloModuleConfig config = GetModuleConfigForTest(); + config.set_debug_options(options); + CompileModule(config); + } -// Checks that both passes have actually run and that the first run of the -// `after` pass is after the last run of the `before` pass. -void VerifyPassOrder( - const absl::flat_hash_map& passes, - absl::string_view before, absl::string_view after) { - ASSERT_TRUE(passes.contains(before)) - << "Expected pass did not run: " << before; - ASSERT_TRUE(passes.contains(after)) << "Expected pass did not run: " << after; - EXPECT_LT(passes.at(before).last_run, passes.at(after).first_run) - << "Pass " << before << " ran after " << after; -} + // Fails if any of the passes with names matching the regular expression + // first_pass_regex run after any of the passes matching last_pass_regex or if + // none of the executed passes matches first_pass_regex or last_pass_regex. + void VerifyPassOrder(absl::string_view first_pass_regex, + absl::string_view last_pass_regex) { + if (!optimized_module_) { + CompileModule(GetModuleConfigForTest()); + } + int first_pass_latest_run = -1; + int last_pass_earliest_run = std::numeric_limits::max(); + int run_index = 0; + for (const HloPassMetadata& pass_metadata : + optimized_module_->metadata()->proto().pass_metadata()) { + if (RE2::FullMatch(pass_metadata.pass_name(), first_pass_regex)) { + VLOG(2) << "Pass " << pass_metadata.pass_name() + << " matches first_pass_regex." << std::endl; + first_pass_latest_run = std::max(first_pass_latest_run, run_index); + } + if (RE2::FullMatch(pass_metadata.pass_name(), last_pass_regex)) { + VLOG(2) << "Pass " << pass_metadata.pass_name() + << " matches last_pass_regex." << std::endl; + last_pass_earliest_run = std::min(last_pass_earliest_run, run_index); + } + ++run_index; + } -// Traverses the module's pass metadata and gathers, for each pass, its smallest -// and largest run index. If a pass p0's run index is smaller than another pass -// p1's run index, then p0 ran before p1. -absl::flat_hash_map GatherPassOrderInformation( - const HloModule& module) { - // Maps a pass name to its first and last index. - absl::flat_hash_map passes; - int run_index = 0; - for (const HloPassMetadata& pass_metadata : - module.metadata().proto().pass_metadata()) { - auto& pass = passes[pass_metadata.pass_name()]; - pass.first_run = std::min(pass.first_run, run_index); - pass.last_run = std::max(pass.last_run, run_index); - ++run_index; + EXPECT_GT(first_pass_latest_run, -1) + << "Did not run a pass matching " << first_pass_regex; + EXPECT_LT(last_pass_earliest_run, std::numeric_limits::max()) + << "Did not run a pass matching " << last_pass_regex; + EXPECT_LE(first_pass_latest_run, last_pass_earliest_run) + << "One or more passes matching " << first_pass_regex + << " ran after passes matching " << last_pass_regex; } - return passes; -} - -TEST_F(GpuCompilerPassTest, PassesAreRunInCorrectOrder) { - constexpr absl::string_view constant_module = R"( + private: + void CompileModule(const HloModuleConfig& config) { + constexpr absl::string_view constant_module = R"( ENTRY main { ROOT constant = f32[] constant(0) })"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(constant_module, config)); + TF_ASSERT_OK_AND_ASSIGN(optimized_module_, + GetOptimizedModule(std::move(module))); + } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(constant_module)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(std::move(module))); + std::unique_ptr optimized_module_; +}; - // Maps a pass name to its first and last index. - absl::flat_hash_map passes = - GatherPassOrderInformation(*optimized_module); - - // This test captures known dependencies between passes. - VerifyPassOrder(passes, /*before=*/"layout-assignment", - /*after=*/"priority-fusion"); - VerifyPassOrder(passes, /*before=*/"layout-assignment", - /*after=*/"layout_normalization"); - VerifyPassOrder(passes, /*before=*/"host-offload-legalize", - /*after=*/"layout_normalization"); +TEST_F(PassOrderTest, PassesAreRunInCorrectOrder) { + VerifyPassOrder(/*first_pass_regex=*/"layout-assignment", + /*last_pass_regex=*/"priority-fusion"); + VerifyPassOrder(/*first_pass_regex=*/"layout-assignment", + /*last_pass_regex=*/"layout_normalization"); + VerifyPassOrder(/*first_pass_regex=*/"host-offload-legalize", + /*last_pass_regex=*/"layout_normalization"); } -TEST_F(GpuCompilerPassTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) { +TEST_F(PassOrderTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) { auto cc = backend() .default_stream_executor() ->GetDeviceDescription() @@ -1125,38 +1133,22 @@ TEST_F(GpuCompilerPassTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) { GTEST_SKIP() << "FusionBlockLevelRewriter requires Ampere+ to run."; } - constexpr absl::string_view constant_module = R"( -ENTRY main { - ROOT constant = f32[] constant(0) -})"; - - HloModuleConfig config; DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_experimental_enable_fusion_block_level_rewriter( true); - config.set_debug_options(debug_options); + SetDebugOptions(debug_options); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnVerifiedModule(constant_module, config)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(std::move(module))); - - absl::flat_hash_map passes = - GatherPassOrderInformation(*optimized_module); + VerifyPassOrder(/*first_pass_regex=*/".*fusion.*", + /*last_pass_regex=*/"fusion-block-level-rewriter"); +} - absl::string_view kFusionBlockLevelRewriterName = - "fusion-block-level-rewriter"; +TEST_F(PassOrderTest, CollectivePipelinerRunsAfterCollectiveQuantizer) { + DebugOptions options = GetDebugOptionsForTest(); + options.set_xla_gpu_enable_pipelined_collectives(true); + SetDebugOptions(options); - for (const auto& [pass_name, _] : passes) { - if (pass_name != kFusionBlockLevelRewriterName && - absl::StrContains(pass_name, "fusion")) { - VerifyPassOrder(passes, /*before=*/pass_name, - /*after=*/kFusionBlockLevelRewriterName); - VLOG(2) << "Verified pass order: " << pass_name << " -> " - << kFusionBlockLevelRewriterName; - } - } + VerifyPassOrder(/*first_pass_regex=*/"collective-quantizer", + /*last_pass_regex=*/"collective-pipeliner.*"); } } // namespace