Skip to content

Commit

Permalink
Temporarily remove stablehlo_current_version from PJRT_GetPluginCAttr…
Browse files Browse the repository at this point in the history
…ibutes.

This can be added back once frameworks release with the fix in:
2f99455

Until then, a plugin that is newer than its framework will error on serialization.

PiperOrigin-RevId: 689524205
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Oct 24, 2024
1 parent e0a385a commit 200efba
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 16 deletions.
18 changes: 10 additions & 8 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,14 +634,16 @@ static PJRT_NamedValue StableHloVersion(absl::string_view name,
}

const std::vector<PJRT_NamedValue>& GetXlaPluginCAttributes() {
static const std::vector<PJRT_NamedValue>* c_values =
new std::vector<PJRT_NamedValue>({
XlaVersion("xla_version"),
StableHloVersion<0>("stablehlo_current_version",
mlir::vhlo::Version::getCurrentVersion()),
StableHloVersion<1>("stablehlo_minimum_version",
mlir::vhlo::Version::getMinimumVersion()),
});
static const std::vector<PJRT_NamedValue>* c_values = new std::vector<
PJRT_NamedValue>({
XlaVersion("xla_version"),
// TODO: (b/375454646) Uncomment once frameworks have bugfix:
// https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243
// StableHloVersion<0>("stablehlo_current_version",
// mlir::vhlo::Version::getCurrentVersion()),
StableHloVersion<1>("stablehlo_minimum_version",
mlir::vhlo::Version::getMinimumVersion()),
});
return *c_values;
}

Expand Down
15 changes: 9 additions & 6 deletions xla/pjrt/c/pjrt_c_api_helpers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,15 @@ TEST(PjRtCApiHelperTest, GetXlaPluginCAttributes) {
EXPECT_TRUE(did_not_exist_yet);
}
EXPECT_TRUE(map.find("xla_version") != map.end());
PJRT_NamedValue *current = map["stablehlo_current_version"];
mlir::vhlo::Version current_version =
mlir::vhlo::Version::getCurrentVersion();
EXPECT_TRUE(current->int64_array_value[0] == current_version.getMajor());
EXPECT_TRUE(current->int64_array_value[1] == current_version.getMinor());
EXPECT_TRUE(current->int64_array_value[2] == current_version.getPatch());
// TODO: (b/375454646) Uncomment once frameworks have bugfix:
// https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243
//
// PJRT_NamedValue *current = map["stablehlo_current_version"];
// mlir::vhlo::Version current_version =
// mlir::vhlo::Version::getCurrentVersion();
// EXPECT_TRUE(current->int64_array_value[0] == current_version.getMajor());
// EXPECT_TRUE(current->int64_array_value[1] == current_version.getMinor());
// EXPECT_TRUE(current->int64_array_value[2] == current_version.getPatch());
PJRT_NamedValue *minimum = map["stablehlo_minimum_version"];
mlir::vhlo::Version minimum_version =
mlir::vhlo::Version::getMinimumVersion();
Expand Down
4 changes: 3 additions & 1 deletion xla/pjrt/c/pjrt_c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,9 @@ TEST_F(PjrtCApiTest, PluginAttributes) {
EXPECT_TRUE(did_not_exist_yet);
}
EXPECT_TRUE(names.find("xla_version") != names.end());
EXPECT_TRUE(names.find("stablehlo_current_version") != names.end());
// TODO: (b/375454646) Uncomment once frameworks have bugfix:
// https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243
// EXPECT_TRUE(names.find("stablehlo_current_version") != names.end());
EXPECT_TRUE(names.find("stablehlo_minimum_version") != names.end());
}

Expand Down
4 changes: 3 additions & 1 deletion xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) {
*literal));
}

TEST(PjRtClientTest, CompileUsesStableHloVersion) {
// TODO: (b/375454646) Eanble once frameworks have bugfix:
// https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243
TEST(PjRtClientTest, DISABLED_CompileUsesStableHloVersion) {
SetUpCpuPjRtApi();
TF_ASSERT_OK_AND_ASSIGN(const PJRT_Api* c_api, pjrt::PjrtApi("cpu"));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
Expand Down

0 comments on commit 200efba

Please sign in to comment.