Skip to content

Commit

Permalink
[PJRT] Use the least common StableHLO version between the framework a…
Browse files Browse the repository at this point in the history
…nd the plugin.

PiperOrigin-RevId: 689155251
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Oct 24, 2024
1 parent 450b61f commit 859da41
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 11 deletions.
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ cc_library(
"@stablehlo//:register",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@stablehlo//:stablehlo_portable_api",
"@stablehlo//:stablehlo_serialization",
"@stablehlo//:version",
"@tsl//tsl/platform:statusor",
Expand Down
28 changes: 24 additions & 4 deletions xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ limitations under the License.
#include "mlir/Transforms/Passes.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/register.h"
#include "stablehlo/api/PortableApi.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Serialization.h"
Expand Down Expand Up @@ -204,10 +205,21 @@ absl::StatusOr<std::string> SerializeUsingNativeBytecode(
}

absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::ModuleOp mlir_module, absl::string_view target, bool inplace) {
mlir::ModuleOp mlir_module, absl::string_view requested_target,
bool inplace) {
mlir::MLIRContext* context = mlir_module->getContext();
mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);

// Usually the plugin is older than the framework, but occasionally a plugin's
// nightly build will use the latest public release of a framework. Serialize
// using the frameworks version in these cases.
auto target = mlir::stablehlo::getSmallerVersion(
requested_target, mlir::stablehlo::getCurrentVersion());
if (mlir::failed(target)) {
return absl::InvalidArgumentError(
"Invalid StableHLO target version requested.");
}

// Legalize CHLO -> [StableHLO+Shape] -> StableHLO
// Preserve higher-level ops with XLA support. To be replaced by composites.
mlir::PassManager pm(context);
Expand All @@ -218,7 +230,7 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createStablehloCompatibilityExpanderPass(
{std::string(target)}));
{target.value()}));
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
Expand All @@ -243,8 +255,8 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
// Serialize portable artifact
std::string buffer;
llvm::raw_string_ostream os(buffer);
if (failed(mlir::stablehlo::serializePortableArtifact(mlir_module, target,
os))) {
if (mlir::failed(mlir::stablehlo::serializePortableArtifact(
mlir_module, target.value(), os))) {
const absl::Status status = diagnostic_handler.ConsumeStatus();
return absl::InvalidArgumentError(absl::StrCat(
"Failed to serialize StableHLO;\n\nDetailed error from MLIR: ",
Expand All @@ -262,6 +274,14 @@ absl::Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module) {
return absl::OkStatus();
}

std::string GetLeastCommonStablehloVersion(std::vector<int64_t>& plugin_attr) {
auto framework_version = mlir::vhlo::Version::getCurrentVersion();
auto plugin_version =
mlir::vhlo::Version(plugin_attr[0], plugin_attr[1], plugin_attr[2]);
if (plugin_version < framework_version) return plugin_version.toString();
return framework_version.toString();
}

std::string GetDefaultStablehloVersion(std::optional<int64_t> plugin_version) {
// TODO: (b/370803410) Use WEEK_12 in PJRT, some plugins were not up to date,
// so temporarily using 1.0.0 to allow them time for a new release.
Expand Down
16 changes: 9 additions & 7 deletions xla/pjrt/mlir_to_hlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,18 @@ absl::StatusOr<std::string> Serialize(mlir::ModuleOp mlir_module,

// Serializes an MLIR module to a portable artifact with forward and backward
// compatibility. Supports modules using StableHLO/MHLO/CHLO/Func dialects.
// Target parameter is a StableHLO version string ("0.9.0") which can be used
// for forward compatibility to specify the target downgrade version.
// Most commonly should use:
// The `requested_target` parameter is a StableHLO version string ("0.9.0")
// which can be used for forward compatibility to specify the target downgrade
// version. Most commonly should use:
// `mlir::stablehlo::getCurrentVersion()` for backward compat but not forward.
// `mlir::stablehlo::getMinimumVersion()` for maximum forward compatibility.
// Ideally should be the `mlir::stablehlo::getCurrentVersion()` of the plugin.
// If program contains dialects that aren't supposed in StableHLO portable
// artifacts, use SerializeUsingNativeBytecode.
// In PJRT, the `requested_target` should be the current version of the PJRT
// plugin. Serialize will use `min(framework_version, plugin_version)` to
// serialize. If program contains dialects that aren't supposed in StableHLO
// portable artifacts, use SerializeUsingNativeBytecode.
absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::ModuleOp mlir_module, absl::string_view target, bool inplace = false);
mlir::ModuleOp mlir_module, absl::string_view requested_target,
bool inplace = false);

// Given a module that might be a portable artifact, deserialize and upgrade it
// back to StableHLO.
Expand Down
19 changes: 19 additions & 0 deletions xla/pjrt/mlir_to_hlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ TEST(MlirToHloTest, StablehloTest) {
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
}

TEST(MlirToHloTest, StablehloPluginNewerThanFramework) {
constexpr char kProgram[] =
R"(
func.func @add(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<1x2xf32>
%0 = stablehlo.add %arg0, %cst : tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
)";
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));

// Request version v100.99.88, newer than the framework version.
// Serialize uses frameworks version when plugin requests a newer version.
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "100.99.98"));
EXPECT_THAT(blob, IsVhloArtifact(mlir::stablehlo::getCurrentVersion()));
}

TEST(MlirToHloTest, ChloTest) {
constexpr char kProgram[] =
R"(
Expand Down

0 comments on commit 859da41

Please sign in to comment.