Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PJRT] Use the least common StableHLO version between the framework and the plugin. #18686

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ cc_library(
"@stablehlo//:register",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@stablehlo//:stablehlo_portable_api",
"@stablehlo//:stablehlo_serialization",
"@stablehlo//:version",
"@tsl//tsl/platform:statusor",
Expand Down
20 changes: 16 additions & 4 deletions xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ limitations under the License.
#include "mlir/Transforms/Passes.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/register.h"
#include "stablehlo/api/PortableApi.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Serialization.h"
Expand Down Expand Up @@ -204,10 +205,21 @@ absl::StatusOr<std::string> SerializeUsingNativeBytecode(
}

absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::ModuleOp mlir_module, absl::string_view target, bool inplace) {
mlir::ModuleOp mlir_module, absl::string_view requested_target,
bool inplace) {
mlir::MLIRContext* context = mlir_module->getContext();
mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);

// Usually the plugin is older than the framework, but occasionally a plugin's
// nightly build will use the latest public release of a framework. Serialize
// using the framework's version in these cases.
auto target = mlir::stablehlo::getSmallerVersion(
requested_target, mlir::stablehlo::getCurrentVersion());
if (mlir::failed(target)) {
return absl::InvalidArgumentError(
"Invalid StableHLO target version requested.");
}

// Legalize CHLO -> [StableHLO+Shape] -> StableHLO
// Preserve higher-level ops with XLA support. To be replaced by composites.
mlir::PassManager pm(context);
Expand All @@ -218,7 +230,7 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createStablehloCompatibilityExpanderPass(
{std::string(target)}));
{target.value()}));
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
Expand All @@ -243,8 +255,8 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
// Serialize portable artifact
std::string buffer;
llvm::raw_string_ostream os(buffer);
if (failed(mlir::stablehlo::serializePortableArtifact(mlir_module, target,
os))) {
if (mlir::failed(mlir::stablehlo::serializePortableArtifact(
mlir_module, target.value(), os))) {
const absl::Status status = diagnostic_handler.ConsumeStatus();
return absl::InvalidArgumentError(absl::StrCat(
"Failed to serialize StableHLO;\n\nDetailed error from MLIR: ",
Expand Down
16 changes: 9 additions & 7 deletions xla/pjrt/mlir_to_hlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,18 @@ absl::StatusOr<std::string> Serialize(mlir::ModuleOp mlir_module,

// Serializes an MLIR module to a portable artifact with forward and backward
// compatibility. Supports modules using StableHLO/MHLO/CHLO/Func dialects.
// Target parameter is a StableHLO version string ("0.9.0") which can be used
// for forward compatibility to specify the target downgrade version.
// Most commonly should use:
// The `requested_target` parameter is a StableHLO version string ("0.9.0")
// which can be used for forward compatibility to specify the target downgrade
// version. Most commonly should use:
// `mlir::stablehlo::getCurrentVersion()` for backward compat but not forward.
// `mlir::stablehlo::getMinimumVersion()` for maximum forward compatibility.
// Ideally should be the `mlir::stablehlo::getCurrentVersion()` of the plugin.
// If program contains dialects that aren't supposed in StableHLO portable
// artifacts, use SerializeUsingNativeBytecode.
// In PJRT, the `requested_target` should be the current version of the PJRT
// plugin. Serialize will use `min(framework_version, plugin_version)` to
// serialize. If program contains dialects that aren't supported in StableHLO
// portable artifacts, use SerializeUsingNativeBytecode.
absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
mlir::ModuleOp mlir_module, absl::string_view target, bool inplace = false);
mlir::ModuleOp mlir_module, absl::string_view requested_target,
bool inplace = false);

// Given a module that might be a portable artifact, deserialize and upgrade it
// back to StableHLO.
Expand Down
19 changes: 19 additions & 0 deletions xla/pjrt/mlir_to_hlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ TEST(MlirToHloTest, StablehloTest) {
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
}

TEST(MlirToHloTest, StablehloPluginNewerThanFramework) {
constexpr char kProgram[] =
R"(
func.func @add(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<1x2xf32>
%0 = stablehlo.add %arg0, %cst : tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
)";
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));

// Request version v100.99.88, newer than the framework version.
// Serialize uses frameworks version when plugin requests a newer version.
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "100.99.98"));
EXPECT_THAT(blob, IsVhloArtifact(mlir::stablehlo::getCurrentVersion()));
}

TEST(MlirToHloTest, ChloTest) {
constexpr char kProgram[] =
R"(
Expand Down
Loading