Skip to content

Commit

Permalink
Encapsulate NewHloTestBase test/reference runners.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688760248
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Oct 23, 2024
1 parent 002e8d9 commit 464d1fd
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 27 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ xla_test(
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)

Expand Down
9 changes: 4 additions & 5 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,20 +593,19 @@ TEST_P(AsyncCollectiveOps, MatmulReplicated) {
true /*run_hlo_passes*/, true /*use-threads*/));
ASSERT_EQ(results.size(), kNumReplicas);

auto& ref_runner = HloTestBase::reference_runner_;
TF_ASSERT_OK_AND_ASSIGN(
auto ref_module, ParseAndReturnVerifiedModule(kModuleSingleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true));
TF_ASSERT_OK_AND_ASSIGN(auto ref_exec, reference_runner().CreateExecutable(
std::move(ref_module), true));

ErrorSpec error_spec{1e-5, 1e-5};
fake_ptrs.push_back(nullptr);
for (int i = 0; i < kNumReplicas; i++) {
auto replica_id =
LiteralUtil::CreateFullWithDescendingLayout<uint32_t>({}, i);
fake_ptrs.back() = &replica_id;
TF_ASSERT_OK_AND_ASSIGN(
auto res, ref_runner.ExecuteWithExecutable(ref_exec.get(), fake_ptrs));
TF_ASSERT_OK_AND_ASSIGN(auto res, reference_runner().ExecuteWithExecutable(
ref_exec.get(), fake_ptrs));
EXPECT_TRUE(LiteralTestUtil::Near(res, results[i], error_spec));
}
}
Expand Down
7 changes: 3 additions & 4 deletions xla/tests/collective_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,9 @@ XLA_TEST_F(CollectiveOpsTest,

HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/2);
auto executable =
test_runner_
.CreateExecutable(MakeCrsModule(input_literal.shape(),
/*replica_groups=*/{}, config),
/*run_hlo_passes=*/true)
CreateExecutable(MakeCrsModule(input_literal.shape(),
/*replica_groups=*/{}, config),
/*run_hlo_passes=*/true)
.value();
std::vector<int64_t> devices = {0, 1};
auto device_assn = MakeDeviceAssn(devices);
Expand Down
2 changes: 0 additions & 2 deletions xla/tests/hlo_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
/*reference_runner=*/
GetHloRunnerForReference(reference_platform).value(),
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier),
test_runner_(test_runner()),
reference_runner_(reference_runner()),
test_platform_(test_platform) {}

/*static*/ se::Platform* HloTestBase::GetReferencePlatform() {
Expand Down
11 changes: 0 additions & 11 deletions xla/tests/hlo_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,6 @@ class [[deprecated("Use NewHloTestBase instead.")]] HloTestBase

ErrorSpec error_spec_{0.0001};

// DO NOT USE: These are temporary fields to help migrate to NewHloTestBase's
// accessors.
[[deprecated(
"Use test_runner() instead. This is a temporary field to help migrate to "
"the accessors in NewHloTestBase. Please do not introduce new "
"uses.")]] HloRunnerInterface& test_runner_;
[[deprecated(
"Use reference_runner() instead. This is a temporary field to help "
"migrate to the accessors in NewHloTestBase. Please do not introduce new "
"uses.")]] HloRunnerInterface& reference_runner_;

private:
se::Platform* test_platform_;
std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
Expand Down
5 changes: 2 additions & 3 deletions xla/tests/multithreaded_compilation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ XLA_TEST_F(MultithreadedCompilation, EightModuleCompilation) {
absl::Mutex mu;
std::vector<std::unique_ptr<Executable>> executables;
auto do_compilation = [&](int iteration) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
test_runner_.CreateExecutable(std::move(modules[iteration]), true));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
CreateExecutable(std::move(modules[iteration]), true));
absl::MutexLock lock(&mu);
executables.push_back(std::move(executable));
VLOG(2) << "Adding executable obtained from thread: " << iteration;
Expand Down
3 changes: 1 addition & 2 deletions xla/tests/replicated_io_feed_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) {
std::unique_ptr<HloModule> module =
ParseAndReturnVerifiedModule(hlo_text, config).value();
auto executable =
test_runner_.CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
.value();
CreateExecutable(std::move(module), /*run_hlo_passes=*/true).value();

auto device_assn = MakeDeviceAssn(kNumReplicas);

Expand Down

0 comments on commit 464d1fd

Please sign in to comment.