Skip to content

Commit

Permalink
PR #18065: Pass Ordering Test for GPU Compiler
Browse files Browse the repository at this point in the history
Imported from GitHub PR #18065

Adds a class for testing the order of passes in the GPU compiler. The passes expected to run first and the passes expected to run last are each selected by a regular expression that is matched against the pass names.

Also adds a test for verifying the order of the collective quantizer and collective pipeliner passes.
Copybara import of the project:

--
2e3624d by Philipp Hack <phack@nvidia.com>:

Adds a class for testing the order of passes in the GPU compiler.

--
3243bf2 by Philipp Hack <phack@nvidia.com>:

Adds a class for testing the order of passes in the GPU compiler.

--
1b13ef0 by Philipp Hack <phack@nvidia.com>:

Adds a class for testing the order of passes in the GPU compiler.

Merging this change closes #18065

COPYBARA_INTEGRATE_REVIEW=#18065 from philipphack:u_ordering_test_xla 1b13ef0
PiperOrigin-RevId: 684470066
  • Loading branch information
philipphack authored and Google-ML-Automation committed Oct 10, 2024
1 parent 477eaba commit e1be920
Showing 1 changed file with 68 additions and 76 deletions.
144 changes: 68 additions & 76 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1056,67 +1056,75 @@ ENTRY main {
expect_custom_kernel_fusion_rewriter_has_run);
}

// Smallest and largest run index observed for one pass.
//
// Both fields start at sentinel values (the largest and the smallest `int`,
// respectively) so that the first std::min / std::max update against a real
// run index always replaces them.
struct PassRunIndex {
  int first_run{std::numeric_limits<int>::max()};
  int last_run{std::numeric_limits<int>::min()};
};
class PassOrderTest : public GpuCompilerTest {
public:
void SetDebugOptions(const DebugOptions& options) {
HloModuleConfig config = GetModuleConfigForTest();
config.set_debug_options(options);
CompileModule(config);
}

// Checks that both passes have actually run and that the first run of the
// `after` pass is after the last run of the `before` pass.
//
// `passes` maps a pass name to its PassRunIndex (first and last run index).
// `before` and `after` are looked up in `passes` by exact name.
void VerifyPassOrder(
const absl::flat_hash_map<std::string, PassRunIndex>& passes,
absl::string_view before, absl::string_view after) {
// A missing entry is a fatal failure, so the `at()` lookups below are only
// reached when both names are present in the map.
ASSERT_TRUE(passes.contains(before))
<< "Expected pass did not run: " << before;
ASSERT_TRUE(passes.contains(after)) << "Expected pass did not run: " << after;
// Strict inequality: the last run of `before` must have a smaller run index
// than the first run of `after`. This check is non-fatal (EXPECT).
EXPECT_LT(passes.at(before).last_run, passes.at(after).first_run)
<< "Pass " << before << " ran after " << after;
}
// Fails if any of the passes with names matching the regular expression
// first_pass_regex run after any of the passes matching last_pass_regex or if
// none of the executed passes matches first_pass_regex or last_pass_regex.
void VerifyPassOrder(absl::string_view first_pass_regex,
absl::string_view last_pass_regex) {
if (!optimized_module_) {
CompileModule(GetModuleConfigForTest());
}
int first_pass_latest_run = -1;
int last_pass_earliest_run = std::numeric_limits<int>::max();
int run_index = 0;
for (const HloPassMetadata& pass_metadata :
optimized_module_->metadata()->proto().pass_metadata()) {
if (RE2::FullMatch(pass_metadata.pass_name(), first_pass_regex)) {
VLOG(2) << "Pass " << pass_metadata.pass_name()
<< " matches first_pass_regex." << std::endl;
first_pass_latest_run = std::max(first_pass_latest_run, run_index);
}
if (RE2::FullMatch(pass_metadata.pass_name(), last_pass_regex)) {
VLOG(2) << "Pass " << pass_metadata.pass_name()
<< " matches last_pass_regex." << std::endl;
last_pass_earliest_run = std::min(last_pass_earliest_run, run_index);
}
++run_index;
}

// Traverses the module's pass metadata and gathers, for each pass, its smallest
// and largest run index. If a pass p0's run index is smaller than another pass
// p1's run index, then p0 ran before p1.
absl::flat_hash_map<std::string, PassRunIndex> GatherPassOrderInformation(
const HloModule& module) {
// Maps a pass name to its first and last index.
absl::flat_hash_map<std::string, PassRunIndex> passes;
int run_index = 0;
for (const HloPassMetadata& pass_metadata :
module.metadata().proto().pass_metadata()) {
auto& pass = passes[pass_metadata.pass_name()];
pass.first_run = std::min(pass.first_run, run_index);
pass.last_run = std::max(pass.last_run, run_index);
++run_index;
EXPECT_GT(first_pass_latest_run, -1)
<< "Did not run a pass matching " << first_pass_regex;
EXPECT_LT(last_pass_earliest_run, std::numeric_limits<int>::max())
<< "Did not run a pass matching " << last_pass_regex;
EXPECT_LE(first_pass_latest_run, last_pass_earliest_run)
<< "One or more passes matching " << first_pass_regex
<< " ran after passes matching " << last_pass_regex;
}

return passes;
}

TEST_F(GpuCompilerPassTest, PassesAreRunInCorrectOrder) {
constexpr absl::string_view constant_module = R"(
private:
void CompileModule(const HloModuleConfig& config) {
constexpr absl::string_view constant_module = R"(
ENTRY main {
ROOT constant = f32[] constant(0)
})";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(constant_module, config));
TF_ASSERT_OK_AND_ASSIGN(optimized_module_,
GetOptimizedModule(std::move(module)));
}

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(constant_module));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
GetOptimizedModule(std::move(module)));
std::unique_ptr<HloModule> optimized_module_;
};

// Maps a pass name to its first and last index.
absl::flat_hash_map<std::string, PassRunIndex> passes =
GatherPassOrderInformation(*optimized_module);

// This test captures known dependencies between passes.
VerifyPassOrder(passes, /*before=*/"layout-assignment",
/*after=*/"priority-fusion");
VerifyPassOrder(passes, /*before=*/"layout-assignment",
/*after=*/"layout_normalization");
VerifyPassOrder(passes, /*before=*/"host-offload-legalize",
/*after=*/"layout_normalization");
// Captures known ordering dependencies between passes. Each row holds a
// {first_pass_regex, last_pass_regex} pair that is forwarded to
// VerifyPassOrder() in the order listed.
TEST_F(PassOrderTest, PassesAreRunInCorrectOrder) {
  constexpr absl::string_view kOrderedPasses[][2] = {
      {"layout-assignment", "priority-fusion"},
      {"layout-assignment", "layout_normalization"},
      {"host-offload-legalize", "layout_normalization"},
  };
  for (const auto& ordered : kOrderedPasses) {
    VerifyPassOrder(/*first_pass_regex=*/ordered[0],
                    /*last_pass_regex=*/ordered[1]);
  }
}

TEST_F(GpuCompilerPassTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) {
TEST_F(PassOrderTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) {
auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
Expand All @@ -1125,38 +1133,22 @@ TEST_F(GpuCompilerPassTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) {
GTEST_SKIP() << "FusionBlockLevelRewriter requires Ampere+ to run.";
}

constexpr absl::string_view constant_module = R"(
ENTRY main {
ROOT constant = f32[] constant(0)
})";

HloModuleConfig config;
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_experimental_enable_fusion_block_level_rewriter(
true);
config.set_debug_options(debug_options);
SetDebugOptions(debug_options);

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(constant_module, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
GetOptimizedModule(std::move(module)));

absl::flat_hash_map<std::string, PassRunIndex> passes =
GatherPassOrderInformation(*optimized_module);
VerifyPassOrder(/*first_pass_regex=*/".*fusion.*",
/*last_pass_regex=*/"fusion-block-level-rewriter");
}

absl::string_view kFusionBlockLevelRewriterName =
"fusion-block-level-rewriter";
TEST_F(PassOrderTest, CollectivePipelinerRunsAfterCollectiveQuantizer) {
DebugOptions options = GetDebugOptionsForTest();
options.set_xla_gpu_enable_pipelined_collectives(true);
SetDebugOptions(options);

for (const auto& [pass_name, _] : passes) {
if (pass_name != kFusionBlockLevelRewriterName &&
absl::StrContains(pass_name, "fusion")) {
VerifyPassOrder(passes, /*before=*/pass_name,
/*after=*/kFusionBlockLevelRewriterName);
VLOG(2) << "Verified pass order: " << pass_name << " -> "
<< kFusionBlockLevelRewriterName;
}
}
VerifyPassOrder(/*first_pass_regex=*/"collective-quantizer",
/*last_pass_regex=*/"collective-pipeliner.*");
}

} // namespace
Expand Down

0 comments on commit e1be920

Please sign in to comment.