From 2f99455cdf99e844ddad17de9f4714997023d243 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 24 Oct 2024 09:48:55 -0700 Subject: [PATCH] [PJRT] Use the least common StableHLO version between the framework and the plugin. PiperOrigin-RevId: 689416599 --- xla/pjrt/BUILD | 1 + xla/pjrt/mlir_to_hlo.cc | 20 ++++++++++++++++---- xla/pjrt/mlir_to_hlo.h | 16 +++++++++------- xla/pjrt/mlir_to_hlo_test.cc | 19 +++++++++++++++++++ 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index eed9eb0fe53d7..1d00186fbb49c 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -632,6 +632,7 @@ cc_library( "@stablehlo//:register", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", "@stablehlo//:version", "@tsl//tsl/platform:statusor", diff --git a/xla/pjrt/mlir_to_hlo.cc b/xla/pjrt/mlir_to_hlo.cc index cb021b4427cc4..830e10f450209 100644 --- a/xla/pjrt/mlir_to_hlo.cc +++ b/xla/pjrt/mlir_to_hlo.cc @@ -54,6 +54,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/register.h" +#include "stablehlo/api/PortableApi.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/Serialization.h" @@ -204,10 +205,21 @@ absl::StatusOr SerializeUsingNativeBytecode( } absl::StatusOr SerializeUsingVersionedStablehlo( - mlir::ModuleOp mlir_module, absl::string_view target, bool inplace) { + mlir::ModuleOp mlir_module, absl::string_view requested_target, + bool inplace) { mlir::MLIRContext* context = mlir_module->getContext(); mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); + // Usually the plugin is older than the framework, but occasionally a plugin's + // nightly build will use the latest public release of a framework. Serialize + // using the framework's version in these cases. + auto target = mlir::stablehlo::getSmallerVersion( + requested_target, mlir::stablehlo::getCurrentVersion()); + if (mlir::failed(target)) { + return absl::InvalidArgumentError( + "Invalid StableHLO target version requested."); + } + // Legalize CHLO -> [StableHLO+Shape] -> StableHLO // Preserve higher-level ops with XLA support. To be replaced by composites. mlir::PassManager pm(context); @@ -218,7 +230,7 @@ absl::StatusOr SerializeUsingVersionedStablehlo( mlir::stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass( mlir::stablehlo::createStablehloCompatibilityExpanderPass( - {std::string(target)})); + {target.value()})); pm.addNestedPass( mlir::stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass( @@ -243,8 +255,8 @@ absl::StatusOr SerializeUsingVersionedStablehlo( // Serialize portable artifact std::string buffer; llvm::raw_string_ostream os(buffer); - if (failed(mlir::stablehlo::serializePortableArtifact(mlir_module, target, - os))) { + if (mlir::failed(mlir::stablehlo::serializePortableArtifact( + mlir_module, target.value(), os))) { const absl::Status status = diagnostic_handler.ConsumeStatus(); return absl::InvalidArgumentError(absl::StrCat( "Failed to serialize StableHLO;\n\nDetailed error from MLIR: ", diff --git a/xla/pjrt/mlir_to_hlo.h b/xla/pjrt/mlir_to_hlo.h index 90efcf772919d..2413851c386fd 100644 --- a/xla/pjrt/mlir_to_hlo.h +++ b/xla/pjrt/mlir_to_hlo.h @@ -63,16 +63,18 @@ absl::StatusOr Serialize(mlir::ModuleOp mlir_module, // Serializes an MLIR module to a portable artifact with forward and backward // compatibility. Supports modules using StableHLO/MHLO/CHLO/Func dialects. -// Target parameter is a StableHLO version string ("0.9.0") which can be used -// for forward compatibility to specify the target downgrade version. -// Most commonly should use: +// The `requested_target` parameter is a StableHLO version string ("0.9.0") +// which can be used for forward compatibility to specify the target downgrade +// version. Most commonly should use: // `mlir::stablehlo::getCurrentVersion()` for backward compat but not forward. // `mlir::stablehlo::getMinimumVersion()` for maximum forward compatibility. -// Ideally should be the `mlir::stablehlo::getCurrentVersion()` of the plugin. -// If program contains dialects that aren't supposed in StableHLO portable -// artifacts, use SerializeUsingNativeBytecode. +// In PJRT, the `requested_target` should be the current version of the PJRT +// plugin. Serialize will use `min(framework_version, plugin_version)` to +// serialize. If program contains dialects that aren't supported in StableHLO +// portable artifacts, use SerializeUsingNativeBytecode. absl::StatusOr SerializeUsingVersionedStablehlo( - mlir::ModuleOp mlir_module, absl::string_view target, bool inplace = false); + mlir::ModuleOp mlir_module, absl::string_view requested_target, + bool inplace = false); // Given a module that might be a portable artifact, deserialize and upgrade it // back to StableHLO. diff --git a/xla/pjrt/mlir_to_hlo_test.cc b/xla/pjrt/mlir_to_hlo_test.cc index 7f411f962815d..4e7b2610f4bcb 100644 --- a/xla/pjrt/mlir_to_hlo_test.cc +++ b/xla/pjrt/mlir_to_hlo_test.cc @@ -57,6 +57,25 @@ TEST(MlirToHloTest, StablehloTest) { EXPECT_THAT(blob, IsVhloArtifact("1.0.0")); } +TEST(MlirToHloTest, StablehloPluginNewerThanFramework) { + constexpr char kProgram[] = + R"( + func.func @add(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { + %cst = stablehlo.constant dense<1.0> : tensor<1x2xf32> + %0 = stablehlo.add %arg0, %cst : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } + )"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + ParseMlirModuleString(kProgram, context)); + + // Request version v100.99.88, newer than the framework version. + // Serialize uses frameworks version when plugin requests a newer version. + TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "100.99.98")); + EXPECT_THAT(blob, IsVhloArtifact(mlir::stablehlo::getCurrentVersion())); +} + TEST(MlirToHloTest, ChloTest) { constexpr char kProgram[] = R"(