diff --git a/BUILD.bazel b/BUILD.bazel index 236492df03f..6014550f869 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,8 +22,8 @@ package( exports_files([ "LICENSE", "stablehlo/integrations/python/ChloModule.cpp", - "stablehlo/integrations/python/PortableApi.cpp", - "stablehlo/integrations/python/PortableApi.h", + "stablehlo/integrations/python/StablehloApi.cpp", + "stablehlo/integrations/python/StablehloApi.h", "stablehlo/integrations/python/StablehloModule.cpp", "stablehlo/integrations/python/VhloModule.cpp", ]) @@ -864,6 +864,7 @@ STABLEHLO_CAPI_SOURCES = [ "stablehlo/integrations/c/StablehloAttributes.cpp", "stablehlo/integrations/c/StablehloDialect.cpp", "stablehlo/integrations/c/StablehloPasses.cpp", + "stablehlo/integrations/c/StablehloApi.cpp", "stablehlo/integrations/c/StablehloTypes.cpp", ] @@ -871,6 +872,7 @@ STABLEHLO_CAPI_HEADERS = [ "stablehlo/integrations/c/StablehloAttributes.h", "stablehlo/integrations/c/StablehloDialect.h", "stablehlo/integrations/c/StablehloPasses.h", + "stablehlo/integrations/c/StablehloApi.h", "stablehlo/integrations/c/StablehloTypes.h", ] @@ -880,10 +882,17 @@ cc_library( hdrs = STABLEHLO_CAPI_HEADERS, strip_include_prefix = ".", deps = [ + ":reference_api", + ":reference_configuration", ":stablehlo_ops", ":stablehlo_passes", + ":stablehlo_portable_api", + ":stablehlo_serialization", + ":version", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -904,10 +913,17 @@ cc_library( hdrs = STABLEHLO_CAPI_HEADERS, strip_include_prefix = ".", deps = [ + ":reference_api", + ":reference_configuration", ":stablehlo_ops", ":stablehlo_passes", + ":stablehlo_portable_api", + ":stablehlo_serialization", + ":version", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], alwayslink = True, ) @@ -1283,6 +1299,7 @@ cc_binary( "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:TosaDialect", ], ) diff --git a/docs/compatibility.md b/docs/compatibility.md index 7b8d91f4040..923c6e96381 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -77,6 +77,32 @@ std::string getCurrentVersion(); // `serializePortableArtifact`. std::string getMinimumVersion(); +// From: #include "stablehlo/dialect/Version.h" + +// CompatibilityRequirement is used to get a viable target version to use for +// `serializePortableArtifact` given a compatibility requirement specified as +// a duration. +// +// New enum values can be added per use case. +// +// Values represent a minimum requirement, i.e. WEEK_4 will return a >=4w +// old version, the specific implementation detail can be updated at any time +// by the community as long as it satisfies the requirement. +// +// Given that integration into XLA is not immediate, coarse intervals work +// better than providing a specific date. +enum class CompatibilityRequirement { + NONE = 0, // No compat requirement, use latest version. + WEEK_4 = 1, // 1 month requirement + WEEK_12 = 2, // 3 month requirement + MAX = 3, // Maximum compat, use minimum supported version +}; + +// Get a viable target version to use for `serializePortableArtifact` for a +// given compatibility requirement. See `CompatibilityRequirement` for +// details. +Version::fromCompatibilityRequirement(CompatibilityRequirement requirement); + // From: #include "stablehlo/dialect/Serialization.h" // Write a StableHLO program to a portable artifact @@ -112,12 +138,22 @@ for example usage of these APIs. StableHLO also provides Python bindings to the C++ compatibility APIs: ```python +class StablehloCompatibilityRequirement(enum.Enum): + NONE, # No compat, same as get_current_version + WEEK_4, # 1mo compat + WEEK_12, # 3mo compat + MAX # Max compat, same as get_minimum_version + +def get_version_from_compatibility_requirement(requirement : StablehloCompatibilityRequirement) -> str: ... def get_current_version() -> str: ... def get_minimum_version() -> str: ... +def get_smaller_version(v1 : str, v2 : str) -> str: ... +def get_api_version() -> int: ... def serialize_portable_artifact(module: ir.Module, target_version: str) -> bytes: ... -def serialize_portable_artifact(module: str, target_version: str) -> bytes: ... +def serialize_portable_artifact_str(module: str, target_version: str) -> bytes: ... def deserialize_portable_artifact(context: ir.Context, artifact: bytes) -> ir.Module: ... -def deserialize_portable_artifact(artifact: bytes) -> str: ... +def deserialize_portable_artifact_str(artifact: bytes) -> str: ... +def eval_module(module : ir.Module, args : List[ir.Attribute]) ``` See [`StablehloModule.cpp`](https://github.com/openxla/stablehlo/blob/main/stablehlo/integrations/python/StablehloModule.cpp) diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp index f1c25c8954c..e7005e3f5bc 100644 --- a/stablehlo/api/PortableApi.cpp +++ b/stablehlo/api/PortableApi.cpp @@ -38,18 +38,13 @@ void loadSerializationDialects(MLIRContext& context) { } } // namespace -LogicalResult getSmallerVersion(const std::string& version1, - const std::string& version2, - std::string& result) { +FailureOr getSmallerVersion(llvm::StringRef version1, + llvm::StringRef version2) { auto v1 = mlir::vhlo::Version::fromString(version1); auto v2 = mlir::vhlo::Version::fromString(version2); if (failed(v1) || failed(v2)) return failure(); - - if (*v1 < *v2) - result = (*v1).toString(); - else - result = (*v2).toString(); - return success(); + if (*v1 < *v2) return (*v1).toString(); + return (*v2).toString(); } std::string getCurrentVersion() { diff --git a/stablehlo/api/PortableApi.h b/stablehlo/api/PortableApi.h index 3a4ab580970..7fca3cfd787 100644 --- a/stablehlo/api/PortableApi.h +++ b/stablehlo/api/PortableApi.h @@ -27,12 +27,11 @@ namespace stablehlo { /// Return the current version for portable API. /// Increments on all meaningful changes to this file. -inline int64_t getApiVersion() { return 8; } +inline int64_t getApiVersion() { return 9; } // Get the smaller version between version1 and version2. -LogicalResult getSmallerVersion(const std::string& version1, - const std::string& version2, - std::string& result); +FailureOr getSmallerVersion(llvm::StringRef version1, + llvm::StringRef version2); // Get the current StableHLO version. // diff --git a/stablehlo/dialect/Version.cpp b/stablehlo/dialect/Version.cpp index d572ab969e1..b7f6751e699 100644 --- a/stablehlo/dialect/Version.cpp +++ b/stablehlo/dialect/Version.cpp @@ -72,6 +72,22 @@ FailureOr Version::getBytecodeVersion() const { return failure(); } +Version Version::fromCompatibilityRequirement( + CompatibilityRequirement requirement) { + // Compatibility requirement versions can be updated as needed, as long as the + // version satisifies the requirement. + switch (requirement) { + case CompatibilityRequirement::NONE: + return Version::getCurrentVersion(); + case CompatibilityRequirement::WEEK_4: + return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024 + case CompatibilityRequirement::WEEK_12: + return Version(1, 0, 0); // v1.0.0 - May 14, 2024 + case CompatibilityRequirement::MAX: + return Version::getMinimumVersion(); + } +} + mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version) { return diag << version.toString(); } diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h index 6a1af9fb0b8..e0fda1fcbbe 100644 --- a/stablehlo/dialect/Version.h +++ b/stablehlo/dialect/Version.h @@ -43,6 +43,31 @@ class Version { /// Return a Version representing the minimum supported VHLO dialect version. static Version getMinimumVersion() { return Version(0, 9, 0); } + // CompatibilityRequirement is used to get a viable target version to use for + // `serializePortableArtifact` given a compatibility requirement specified as + // a duration. + // + // New enum values can be added per use case. + // + // Values represent a minimum requirement, i.e. WEEK_4 will return a >=4w + // old version, the specific implementation detail can be updated at any time + // by the community as long as it satisfies the requirement. + // + // Given that integration into XLA is not immediate, coarse intervals work + // better than providing a specific date. + enum class CompatibilityRequirement { + NONE = 0, // No compat requirement, use latest version. + WEEK_4 = 1, // 1 month requirement + WEEK_12 = 2, // 3 month requirement + MAX = 3, // Maximum compat, use minimum supported version + }; + + // Get a viable target version to use for `serializePortableArtifact` for a + // given compatibility requirement. See `CompatibilityRequirement` for + // details. + static Version fromCompatibilityRequirement( + CompatibilityRequirement requirement); + /// Return the MLIR Bytecode Format associated with the version instance. /// Returns failure if version is not in compatibility window. FailureOr getBytecodeVersion() const; diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index 5df2f3ac999..2da478ec764 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -39,6 +39,7 @@ def VHLO_Dialect : Dialect { 0.19.0: Introduce `composite` operation. 0.20.0: Remove `padding` attribute from `dynamic_conv`. 1.0.0: Increase compatibility guarantees to 5 years backward, 2 years forward (no functional changes relative to 0.20.0). + 1.1.0: Add gather/scatter batching dimensions. 1.2.0: Introduce `si2` and `ui2` types. 1.3.0: Extend `custom_call` op `backend_config` to support `DictionaryAttr`. 1.4.0: Add `tan` op to StableHLO opset. diff --git a/stablehlo/integrations/c/CMakeLists.txt b/stablehlo/integrations/c/CMakeLists.txt index 014a1a8de8c..74595438aa3 100644 --- a/stablehlo/integrations/c/CMakeLists.txt +++ b/stablehlo/integrations/c/CMakeLists.txt @@ -33,14 +33,24 @@ add_mlir_public_c_api_library(ChloCAPI add_mlir_public_c_api_library(StablehloCAPI PARTIAL_SOURCES_INTENDED + StablehloApi.cpp StablehloAttributes.cpp StablehloDialect.cpp StablehloPasses.cpp StablehloTypes.cpp LINK_LIBS PUBLIC + LLVMSupport + MLIRCAPIIR + MLIRIR + MLIRSupport StablehloOps StablehloPasses + StablehloPortableApi + StablehloReferenceApi + StablehloReferenceConfiguration + StablehloSerialization + Version ) add_mlir_public_c_api_library(VhloCAPI diff --git a/stablehlo/integrations/c/StablehloApi.cpp b/stablehlo/integrations/c/StablehloApi.cpp new file mode 100644 index 00000000000..8d9221989a9 --- /dev/null +++ b/stablehlo/integrations/c/StablehloApi.cpp @@ -0,0 +1,137 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/integrations/c/StablehloApi.h" + +#include + +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Utils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LLVM.h" +#include "stablehlo/api/PortableApi.h" +#include "stablehlo/dialect/Serialization.h" +#include "stablehlo/dialect/Version.h" +#include "stablehlo/reference/Api.h" +#include "stablehlo/reference/Configuration.h" + +int stablehloGetApiVersion() { return mlir::stablehlo::getApiVersion(); } + +mlir::vhlo::Version::CompatibilityRequirement unwrapCompatibilityRequirement( + MlirStablehloCompatibilityRequirement requirement) { + switch (requirement) { + case MlirStablehloCompatibilityRequirement::NONE: + return mlir::vhlo::Version::CompatibilityRequirement::NONE; + case MlirStablehloCompatibilityRequirement::WEEK_4: + return mlir::vhlo::Version::CompatibilityRequirement::WEEK_4; + case MlirStablehloCompatibilityRequirement::WEEK_12: + return mlir::vhlo::Version::CompatibilityRequirement::WEEK_12; + case MlirStablehloCompatibilityRequirement::MAX: + return mlir::vhlo::Version::CompatibilityRequirement::MAX; + } + llvm::report_fatal_error("unhandled compatibility requirement"); +} + +void stablehloVersionFromCompatibilityRequirement( + MlirStablehloCompatibilityRequirement requirement, + MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + stream << mlir::vhlo::Version::fromCompatibilityRequirement( + unwrapCompatibilityRequirement(requirement)); +} + +void stablehloGetCurrentVersion(MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + stream << mlir::stablehlo::getCurrentVersion(); +} + +void stablehloGetMinimumVersion(MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + stream << mlir::stablehlo::getMinimumVersion(); +} + +MlirLogicalResult stablehloGetSmallerVersion(MlirStringRef version1, + MlirStringRef version2, + MlirStringCallback callback, + void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + auto result = + mlir::stablehlo::getSmallerVersion(unwrap(version1), unwrap(version2)); + if (mlir::failed(result)) return mlirLogicalResultFailure(); + stream << result.value(); + return mlirLogicalResultSuccess(); +} + +MlirLogicalResult stablehloSerializePortableArtifact( + MlirModule moduleStr, MlirStringRef targetVersion, + MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + if (failed(mlir::stablehlo::serializePortableArtifact( + unwrap(moduleStr), unwrap(targetVersion), stream))) + return mlirLogicalResultFailure(); + return mlirLogicalResultSuccess(); +} + +MlirLogicalResult stablehloSerializePortableArtifact( + MlirStringRef moduleStr, MlirStringRef targetVersion, + MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + if (failed(mlir::stablehlo::serializePortableArtifact( + unwrap(moduleStr), unwrap(targetVersion), stream))) + return mlirLogicalResultFailure(); + return mlirLogicalResultSuccess(); +} + +MlirLogicalResult stablehloDeserializePortableArtifact( + MlirStringRef artifactStr, MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + if (failed(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr), + stream))) + return mlirLogicalResultFailure(); + return mlirLogicalResultSuccess(); +} + +MlirModule stablehloDeserializePortableArtifact(MlirStringRef artifactStr, + MlirContext ctx) { + return wrap(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr), + unwrap(ctx)) + .release()); +} + +MlirAttribute stablehloEvalModule(MlirModule module, int nArgs, + MlirAttribute const *args, int *errorCode) { + std::vector inputs; + inputs.reserve(nArgs); + for (int i = 0; i < nArgs; ++i) { + inputs.push_back(llvm::cast(unwrap(args[i]))); + } + mlir::stablehlo::InterpreterConfiguration config; + mlir::FailureOr> results = + mlir::stablehlo::evalModule(unwrap(module), inputs, config); + if (mlir::failed(results)) { + *errorCode = 1; + return MlirAttribute{nullptr}; + } + std::vector resultsVec; + for (const auto &result : results.value()) { + resultsVec.push_back(wrap(result)); + } + return mlirArrayAttrGet(mlirModuleGetContext(module), resultsVec.size(), + resultsVec.data()); +} diff --git a/stablehlo/integrations/c/StablehloApi.h b/stablehlo/integrations/c/StablehloApi.h new file mode 100644 index 00000000000..77864d6b55a --- /dev/null +++ b/stablehlo/integrations/c/StablehloApi.h @@ -0,0 +1,122 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_INTEGRATIONS_C_STABLEHLOAPI_H_ +#define STABLEHLO_INTEGRATIONS_C_STABLEHLOAPI_H_ + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +// Get the current StableHLO API version. +// +// This value is incremented as needed to help integrate API changes. +MLIR_CAPI_EXPORTED int stablehloGetApiVersion(); + +typedef enum MlirStablehloCompatibilityRequirement { + NONE = 0, + WEEK_4 = 1, + WEEK_12 = 2, + MAX = 3 +} MlirStablehloCompatibilityRequirement; + +// Returns a StringAtt with the version of StableHLO that satisfies the +// compatibility requirement, which is owned by ctx. +MLIR_CAPI_EXPORTED void stablehloVersionFromCompatibilityRequirement( + MlirStablehloCompatibilityRequirement requirement, + MlirStringCallback callback, void* userData); + +// Get the current StableHLO version. +// +// This value can be used as the `targetVersion` argument to +// `serializePortableArtifact`. +MLIR_CAPI_EXPORTED void stablehloGetCurrentVersion(MlirStringCallback callback, + void* userData); + +// Get the minimum supported StableHLO version. +// +// This value can be used as the `targetVersion` argument to +// `serializePortableArtifact`. +// +// Each StableHLO version `producer_version` has a compatibility window, +// i.e. range of versions [`consumer_version_min`, `consumer_version_max`], +// where StableHLO portable artifacts serialized by `producer_version` +// can be deserialized by `consumer_version` within the window. +// See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md +// for the exact extent of these compatibility guarantees. +// +// This function returns `consumer_version_min` for the current StableHLO +// version. It can be used maximize forward compatibility, i.e. to maximize how +// far into the past we can go and still have the payloads produced by +// `serializePortableArtifact` compatible with potential consumers from the past +MLIR_CAPI_EXPORTED void stablehloGetMinimumVersion(MlirStringCallback callback, + void* userData); + +// For two given version strings, return the smaller version. +// Returns failure if either version is not a valid version string. +MlirLogicalResult stablehloGetSmallerVersion(MlirStringRef version1, + MlirStringRef version2, + MlirStringCallback callback, + void* userData); + +// Write a StableHLO program expressed as a string (either prettyprinted MLIR +// module or MLIR bytecode) to a portable artifact. +// Can fail if `moduleStr` cannot be parsed, or if it cannot be expressed in the +// `targetVersion` version of StableHLO, e.g. if it's using new or removed +// features, or if it involves unsupported dialects. +// Returns false on failure. +MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact( + MlirStringRef moduleStr, MlirStringRef targetVersion, + MlirStringCallback callback, void* userData); + +// Write a StableHLO program expressed as a string (either prettyprinted MLIR +// module or MLIR bytecode) to a portable artifact. +// Can fail if `moduleStr` cannot be parsed, or if it cannot be expressed in the +// `targetVersion` version of StableHLO, e.g. if it's using new or removed +// features, or if it involves unsupported dialects. +// Returns false on failure. +MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact( + MlirModule moduleStr, MlirStringRef targetVersion, + MlirStringCallback callback, void* userData); + +// Read a StableHLO program from a portable artifact, returning the module as +// MLIR bytecode. Note, this bytecode returned is not a portable artifact, +// and has the stability of returning textual assembly format. Bytecode is +// returned here since it is more compact and faster to read and write. +// Can fail if `artifactStr` cannot be expressed in the current version of +// StableHLO, e.g. if it's using incompatible features. +// Returns false on failure. +MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifact( + MlirStringRef artifactStr, MlirStringCallback callback, void* userData); + +// Read a StableHLO program from a portable artifact, returning the module as +// MLIR bytecode. Note, this bytecode returned is not a portable artifact, +// and has the stability of returning textual assembly format. Bytecode is +// returned here since it is more compact and faster to read and write. +// Can fail if `artifactStr` cannot be expressed in the current version of +// StableHLO, e.g. if it's using incompatible features. +// +// Returns empty module on failure. +MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact( + MlirStringRef artifactStr, MlirContext ctx); + +// Call the Interpreter, returns MlirArrayAttr of dense element +// MlirAttribute results +MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact( + MlirStringRef artifactStr, MlirContext ctx); + +// Entrypoint for calling the StableHLO reference interpreter. +// Returns an array attribute of dense element attributes for results. +// Sets error code to non-zero on failure. +MlirAttribute stablehloEvalModule(MlirModule module, int nArgs, + MlirAttribute const* args, int* errorCode); + +#endif // STABLEHLO_INTEGRATIONS_C_STABLEHLOAPI_H_ diff --git a/stablehlo/integrations/c/StablehloDialect.h b/stablehlo/integrations/c/StablehloDialect.h index ebac82b69d4..ee8d16c6944 100644 --- a/stablehlo/integrations/c/StablehloDialect.h +++ b/stablehlo/integrations/c/StablehloDialect.h @@ -14,6 +14,7 @@ limitations under the License. #ifndef STABLEHLO_INTEGRATIONS_C_STABLEHLO_DIALECT_H #define STABLEHLO_INTEGRATIONS_C_STABLEHLO_DIALECT_H +#include "mlir-c/IR.h" #include "mlir-c/RegisterEverything.h" #ifdef __cplusplus diff --git a/stablehlo/integrations/python/CMakeLists.txt b/stablehlo/integrations/python/CMakeLists.txt index fac75773a16..43dda81f8fb 100644 --- a/stablehlo/integrations/python/CMakeLists.txt +++ b/stablehlo/integrations/python/CMakeLists.txt @@ -110,15 +110,12 @@ declare_mlir_python_extension(StablehloPythonExtensions.Main MODULE_NAME _stablehlo ADD_TO_PARENT StablehloPythonExtensions SOURCES + StablehloApi.cpp StablehloModule.cpp - PortableApi.cpp EMBED_CAPI_LINK_LIBS StablehloCAPI PRIVATE_LINK_LIBS - StablehloPasses - StablehloPortableApi - StablehloReferenceApi - StablehloSerialization + StablehloCAPI LLVMSupport ) diff --git a/stablehlo/integrations/python/PortableApi.cpp b/stablehlo/integrations/python/PortableApi.cpp deleted file mode 100644 index 0ea82f9ea27..00000000000 --- a/stablehlo/integrations/python/PortableApi.cpp +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "stablehlo/integrations/python/PortableApi.h" - -#include - -#include "stablehlo/api/PortableApi.h" - -namespace py = pybind11; - -namespace mlir { -namespace stablehlo { - -void AddPortableApi(py::module& m) { - // - // Utility APIs. - // - - m.def("get_api_version", []() { return getApiVersion(); }); - - m.def( - "get_smaller_version", - [](std::string version1, std::string version2) -> py::str { - std::string result; - if (failed(getSmallerVersion(version1, version2, result))) { - PyErr_SetString(PyExc_ValueError, - "failed to convert version to stablehlo version"); - return ""; - } - return result; - }, - py::arg("version1"), py::arg("version2")); - - // - // Serialization APIs. - // - - m.def("get_current_version", []() { return getCurrentVersion(); }); - - m.def("get_minimum_version", []() { return getMinimumVersion(); }); - - m.def( - "serialize_portable_artifact", - [](std::string moduleStr, std::string targetVersion) -> py::bytes { - std::string buffer; - llvm::raw_string_ostream os(buffer); - if (failed(serializePortableArtifact(moduleStr, targetVersion, os))) { - PyErr_SetString(PyExc_ValueError, "failed to serialize module"); - return ""; - } - - return py::bytes(buffer); - }, - py::arg("module_str"), py::arg("target_version")); - - m.def( - "deserialize_portable_artifact", - [](std::string artifactStr) -> py::bytes { - std::string buffer; - llvm::raw_string_ostream os(buffer); - if (failed(deserializePortableArtifact(artifactStr, os))) { - PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); - return ""; - } - - return py::bytes(buffer); - }, - py::arg("artifact_str")); -} - -} // namespace stablehlo -} // namespace mlir diff --git a/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/integrations/python/StablehloApi.cpp new file mode 100644 index 00000000000..46a640e103a --- /dev/null +++ b/stablehlo/integrations/python/StablehloApi.cpp @@ -0,0 +1,228 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/integrations/python/StablehloApi.h" + +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "stablehlo/integrations/c/StablehloApi.h" + +namespace py = pybind11; + +namespace mlir { +namespace stablehlo { + +// A helper class that implements `MlirStringCallback` by printing parts into a +// C++ string. +class StringWriterHelper { + public: + StringWriterHelper() : ss_(s_) {} + + static MlirStringCallback getMlirStringCallback() { + return [](MlirStringRef string_ref, void *user_data) { + auto *helper = static_cast(user_data); + helper->ss_ << llvm::StringRef(string_ref.data, string_ref.length); + }; + } + + void *getUserData() { return static_cast(this); } + + const std::string &toString() { + ss_.flush(); + return s_; + } + + private: + std::string s_; + llvm::raw_string_ostream ss_; +}; + +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +static MlirStringRef toMlirStringRef(std::string_view s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +void AddStablehloApi(py::module &m) { + // Portable API is a subset of StableHLO API + AddPortableApi(m); + + // + // Utility APIs. + // + py::enum_( + m, "StablehloCompatibilityRequirement") + .value("NONE", MlirStablehloCompatibilityRequirement::NONE) + .value("WEEK_4", MlirStablehloCompatibilityRequirement::WEEK_4) + .value("WEEK_12", MlirStablehloCompatibilityRequirement::WEEK_12) + .value("MAX", MlirStablehloCompatibilityRequirement::MAX); + + m.def( + "get_version_from_compatibility_requirement", + [](MlirStablehloCompatibilityRequirement requirement) -> py::str { + StringWriterHelper accumulator; + stablehloVersionFromCompatibilityRequirement( + requirement, accumulator.getMlirStringCallback(), + accumulator.getUserData()); + return accumulator.toString(); + }, + py::arg("requirement")); + + // + // Serialization APIs. + // + m.def( + "serialize_portable_artifact", + [](MlirModule module, std::string_view target) -> py::bytes { + StringWriterHelper accumulator; + if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact( + module, toMlirStringRef(target), + accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { + PyErr_SetString(PyExc_ValueError, "failed to serialize module"); + return ""; + } + + return py::bytes(accumulator.toString()); + }, + py::arg("module"), py::arg("target")); + + m.def( + "deserialize_portable_artifact", + [](MlirContext context, std::string_view artifact) -> MlirModule { + auto module = stablehloDeserializePortableArtifact( + toMlirStringRef(artifact), context); + if (mlirModuleIsNull(module)) { + PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); + return {}; + } + return module; + }, + py::arg("context"), py::arg("artifact")); + + // + // Reference APIs + // + m.def( + "eval_module", + [](MlirModule module, + std::vector &args) -> std::vector { + for (auto arg : args) { + if (!mlirAttributeIsADenseElements(arg)) { + PyErr_SetString(PyExc_ValueError, + "input args must be DenseElementsAttr"); + return {}; + } + } + + int errorCode(0); + MlirAttribute resultArrayAttr = + stablehloEvalModule(module, args.size(), args.data(), &errorCode); + + if (errorCode != 0) { + PyErr_SetString(PyExc_ValueError, "interpreter failed"); + return {}; + } + + std::vector pyResults; + for (int i = 0; i < mlirArrayAttrGetNumElements(resultArrayAttr); i++) { + pyResults.push_back(mlirArrayAttrGetElement(resultArrayAttr, i)); + } + return pyResults; + }, + py::arg("module"), py::arg("args")); +} + +void AddPortableApi(py::module &m) { + // + // Utility APIs. + // + m.def("get_api_version", []() { return stablehloGetApiVersion(); }); + + m.def( + "get_smaller_version", + [](const std::string &version1, const std::string &version2) -> py::str { + StringWriterHelper accumulator; + if (mlirLogicalResultIsFailure(stablehloGetSmallerVersion( + toMlirStringRef(version1), toMlirStringRef(version2), + accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { + PyErr_SetString(PyExc_ValueError, + "failed to convert version to stablehlo version"); + return ""; + } + return accumulator.toString(); + }, + py::arg("version1"), py::arg("version2")); + + m.def("get_current_version", []() -> py::str { + StringWriterHelper accumulator; + stablehloGetCurrentVersion(accumulator.getMlirStringCallback(), + accumulator.getUserData()); + return accumulator.toString(); + }); + + m.def("get_minimum_version", []() -> py::str { + StringWriterHelper accumulator; + stablehloGetMinimumVersion(accumulator.getMlirStringCallback(), + accumulator.getUserData()); + return accumulator.toString(); + }); + + // + // Serialization APIs. + // + m.def( + "serialize_portable_artifact_str", + [](std::string_view moduleStrOrBytecode, + std::string_view targetVersion) -> py::bytes { + StringWriterHelper accumulator; + if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact( + toMlirStringRef(moduleStrOrBytecode), + toMlirStringRef(targetVersion), + accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { + PyErr_SetString(PyExc_ValueError, "failed to serialize module"); + return ""; + } + return py::bytes(accumulator.toString()); + }, + py::arg("module_str"), py::arg("target_version")); + + m.def( + "deserialize_portable_artifact_str", + [](std::string_view artifact) -> py::bytes { + StringWriterHelper accumulator; + if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact( + toMlirStringRef(artifact), accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { + PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); + return ""; + } + return py::bytes(accumulator.toString()); + }, + py::arg("artifact_str")); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/integrations/python/PortableApi.h b/stablehlo/integrations/python/StablehloApi.h similarity index 58% rename from stablehlo/integrations/python/PortableApi.h rename to stablehlo/integrations/python/StablehloApi.h index 6eb19bcae88..e0a96a122f9 100644 --- a/stablehlo/integrations/python/PortableApi.h +++ b/stablehlo/integrations/python/StablehloApi.h @@ -13,19 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef STABLEHLO_INTEGRATIONS_PYTHON_API_PORTABLEAPI_H -#define STABLEHLO_INTEGRATIONS_PYTHON_API_PORTABLEAPI_H +#ifndef STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H +#define STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H #include "pybind11/pybind11.h" namespace mlir { namespace stablehlo { -// Add portable API to the pybind11 module. -// Signatures of these APIs have no dependency on MLIR. +// Add StableHLO APIs to the pybind11 module. +// Signatures of these APIs have no dependency on C++ MLIR types and all must +// use C API passthrough. +void AddStablehloApi(pybind11::module& m); + +// Adds a subset of the StableHLO API that doesn't use MLIR in any definitions, +// and is methods only, introducing no new objects / enums to avoid potential +// redefinition issues in complex build environments. void AddPortableApi(pybind11::module& m); } // namespace stablehlo } // namespace mlir -#endif // STABLEHLO_INTEGRATIONS_PYTHON_API_PORTABLEAPI_H +#endif // STABLEHLO_INTEGRATIONS_PYTHON_API_STABLEHLOAPI_H diff --git a/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/integrations/python/StablehloModule.cpp index e4180353a63..5fd995dd950 100644 --- a/stablehlo/integrations/python/StablehloModule.cpp +++ b/stablehlo/integrations/python/StablehloModule.cpp @@ -14,16 +14,13 @@ limitations under the License. #include #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include "mlir/CAPI/IR.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "stablehlo/dialect/Serialization.h" #include "stablehlo/integrations/c/StablehloAttributes.h" #include "stablehlo/integrations/c/StablehloDialect.h" #include "stablehlo/integrations/c/StablehloPasses.h" #include "stablehlo/integrations/c/StablehloTypes.h" -#include "stablehlo/integrations/python/PortableApi.h" -#include "stablehlo/reference/Api.h" +#include "stablehlo/integrations/python/StablehloApi.h" namespace py = pybind11; @@ -533,73 +530,7 @@ PYBIND11_MODULE(_stablehlo, m) { }); // - // Portable APIs + // StableHLO APIs // - mlir::stablehlo::AddPortableApi(m); - - // - // Reference APIs - // - m.def( - "eval_module", - [](MlirModule module, - std::vector &args) -> std::vector { - std::vector inputs; - for (auto arg : args) { - auto attr = llvm::dyn_cast(unwrap(arg)); - if (!attr) { - PyErr_SetString(PyExc_ValueError, - "input args must be DenseElementsAttr"); - return {}; - } - inputs.push_back(attr); - } - - mlir::stablehlo::InterpreterConfiguration config; - auto results = - mlir::stablehlo::evalModule(unwrap(module), inputs, config); - if (failed(results)) { - PyErr_SetString(PyExc_ValueError, "interpreter failed"); - return {}; - } - - std::vector pyResults; - for (auto res : *results) pyResults.push_back(wrap(res)); - return pyResults; - }, - py::arg("module"), py::arg("args")); - - // - // Serialization APIs. - // - - m.def( - "serialize_portable_artifact", - [](MlirModule module, std::string target) -> py::bytes { - std::string buffer; - llvm::raw_string_ostream os(buffer); - if (failed(mlir::stablehlo::serializePortableArtifact(unwrap(module), - target, os))) { - PyErr_SetString(PyExc_ValueError, "failed to serialize module"); - return ""; - } - - return py::bytes(buffer); - }, - py::arg("module"), py::arg("target")); - - m.def( - "deserialize_portable_artifact", - [](MlirContext context, std::string artifact) -> MlirModule { - auto module = mlir::stablehlo::deserializePortableArtifact( - artifact, unwrap(context)); - - if (!module) { - PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); - return {}; - } - - return {module.release()}; - }, - py::arg("context"), py::arg("artifact")); + mlir::stablehlo::AddStablehloApi(m); } diff --git a/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/integrations/python/tests/stablehlo.py index 2d6b0c0811d..df80c2a2c11 100644 --- a/stablehlo/integrations/python/tests/stablehlo.py +++ b/stablehlo/integrations/python/tests/stablehlo.py @@ -247,6 +247,18 @@ def test_minimum_version(): assert is_semver_format(curr_version) +@run +def test_version_requirements(): + for req in ( + stablehlo.StablehloCompatibilityRequirement.NONE, + stablehlo.StablehloCompatibilityRequirement.WEEK_4, + stablehlo.StablehloCompatibilityRequirement.WEEK_12, + stablehlo.StablehloCompatibilityRequirement.MAX, + ): + assert is_semver_format( + stablehlo.get_version_from_compatibility_requirement(req)) + + ASM_FORMAT = """ func.func @test(%arg0: tensor<{0}>) -> tensor<{0}> {{ %0 = stablehlo.add %arg0, %arg0 : (tensor<{0}>, tensor<{0}>) -> tensor<{0}> @@ -289,6 +301,7 @@ def test_reference_api(): def test_get_smaller_version(): curr_version = stablehlo.get_current_version() min_version = stablehlo.get_minimum_version() + print(curr_version) assert stablehlo.get_smaller_version(curr_version, min_version) == min_version @@ -321,8 +334,8 @@ def module_to_bytecode(module: ir.Module) -> bytes: assert m is not None module_str = str(m) bytecode = module_to_bytecode(m) - serialized = stablehlo.serialize_portable_artifact(bytecode, curr_version) - deserialized = stablehlo.deserialize_portable_artifact(serialized) + serialized = stablehlo.serialize_portable_artifact_str(bytecode, curr_version) + deserialized = stablehlo.deserialize_portable_artifact_str(serialized) deserialized_module = ir.Module.parse(deserialized) assert module_str == str(deserialized_module)