Skip to content

Commit

Permalink
Add StableHLO CAPI and get_version_from_compat_requirement API (#2494)
Browse files Browse the repository at this point in the history
This PR introduces proper CAPI bindings for popular APIs (not all, will
get to that). Testing internally before external PR, wanted to share for
knowledge's sake.

This also introduces a new StableHLO API and python binding for getting
target versions based on compatibility requirements. Currently
supporting the following values:

```
CompatRequirement ::= None | 1mo | 3mo | Max

Version fromCompatibilityRequirement(CompatRequirement);
```

Anything more fine-grained didn't work well because of our integrate
cadence: Merge to StableHLO, then export to openxla/xla / TF /
elsewhere. Potentially 2 different dates with the same meaning depending
on how a project depends on StableHLO.

More compat requirements can be added at any time, on a per-use-case
basis and versions that compat requirements map to can be modified as
needed, as long as the updated version satisfies the requirement
constraint.

Closes #2170
Closes #2350
  • Loading branch information
GleasonK authored Aug 19, 2024
1 parent 591f2e3 commit 691c676
Show file tree
Hide file tree
Showing 17 changed files with 637 additions and 188 deletions.
23 changes: 20 additions & 3 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ package(
exports_files([
"LICENSE",
"stablehlo/integrations/python/ChloModule.cpp",
"stablehlo/integrations/python/PortableApi.cpp",
"stablehlo/integrations/python/PortableApi.h",
"stablehlo/integrations/python/StablehloApi.cpp",
"stablehlo/integrations/python/StablehloApi.h",
"stablehlo/integrations/python/StablehloModule.cpp",
"stablehlo/integrations/python/VhloModule.cpp",
])
Expand Down Expand Up @@ -864,13 +864,15 @@ STABLEHLO_CAPI_SOURCES = [
"stablehlo/integrations/c/StablehloAttributes.cpp",
"stablehlo/integrations/c/StablehloDialect.cpp",
"stablehlo/integrations/c/StablehloPasses.cpp",
"stablehlo/integrations/c/StablehloApi.cpp",
"stablehlo/integrations/c/StablehloTypes.cpp",
]

STABLEHLO_CAPI_HEADERS = [
"stablehlo/integrations/c/StablehloAttributes.h",
"stablehlo/integrations/c/StablehloDialect.h",
"stablehlo/integrations/c/StablehloPasses.h",
"stablehlo/integrations/c/StablehloApi.h",
"stablehlo/integrations/c/StablehloTypes.h",
]

Expand All @@ -880,10 +882,17 @@ cc_library(
hdrs = STABLEHLO_CAPI_HEADERS,
strip_include_prefix = ".",
deps = [
":reference_api",
":reference_configuration",
":stablehlo_ops",
":stablehlo_passes",
":stablehlo_portable_api",
":stablehlo_serialization",
":version",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

Expand All @@ -904,10 +913,17 @@ cc_library(
hdrs = STABLEHLO_CAPI_HEADERS,
strip_include_prefix = ".",
deps = [
":reference_api",
":reference_configuration",
":stablehlo_ops",
":stablehlo_passes",
":stablehlo_portable_api",
":stablehlo_serialization",
":version",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
alwayslink = True,
)
Expand Down Expand Up @@ -1283,6 +1299,7 @@ cc_binary(
"@llvm-project//mlir:AllExtensions",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:TosaDialect",
],
)
Expand Down
40 changes: 38 additions & 2 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@ std::string getCurrentVersion();
// `serializePortableArtifact`.
std::string getMinimumVersion();

// From: #include "stablehlo/dialect/Version.h"

// CompatibilityRequirement is used to get a viable target version to use for
// `serializePortableArtifact` given a compatibility requirement specified as
// a duration.
//
// New enum values can be added per use case.
//
// Values represent a minimum requirement, i.e. WEEK_4 will return a >=4w
// old version, the specific implementation detail can be updated at any time
// by the community as long as it satisfies the requirement.
//
// Given that integration into XLA is not immediate, coarse intervals work
// better than providing a specific date.
enum class CompatibilityRequirement {
NONE = 0, // No compat requirement, use latest version.
WEEK_4 = 1, // 1 month requirement
WEEK_12 = 2, // 3 month requirement
MAX = 3, // Maximum compat, use minimum supported version
};

// Get a viable target version to use for `serializePortableArtifact` for a
// given compatibility requirement. See `CompatibilityRequirement` for
// details.
Version::fromCompatibilityRequirement(CompatibilityRequirement requirement);

// From: #include "stablehlo/dialect/Serialization.h"

// Write a StableHLO program to a portable artifact
Expand Down Expand Up @@ -112,12 +138,22 @@ for example usage of these APIs.
StableHLO also provides Python bindings to the C++ compatibility APIs:
```python
class StablehloCompatibilityRequirement(enum.Enum):
NONE, # No compat, same as get_current_version
WEEK_4, # 1mo compat
WEEK_12, # 3mo compat
MAX # Max compat, same as get_minimum_version
def get_version_from_compatibility_requirement(requirement : StablehloCompatibilityRequirement) -> str: ...
def get_current_version() -> str: ...
def get_minimum_version() -> str: ...
def get_smaller_version(v1 : str, v2 : str) -> str: ...
def get_api_version() -> int: ...
def serialize_portable_artifact(module: ir.Module, target_version: str) -> bytes: ...
def serialize_portable_artifact(module: str, target_version: str) -> bytes: ...
def serialize_portable_artifact_str(module: str, target_version: str) -> bytes: ...
def deserialize_portable_artifact(context: ir.Context, artifact: bytes) -> ir.Module: ...
def deserialize_portable_artifact(artifact: bytes) -> str: ...
def deserialize_portable_artifact_str(artifact: bytes) -> str: ...
def eval_module(module : ir.Module, args : List[ir.Attribute])
```

See [`StablehloModule.cpp`](https://github.com/openxla/stablehlo/blob/main/stablehlo/integrations/python/StablehloModule.cpp)
Expand Down
13 changes: 4 additions & 9 deletions stablehlo/api/PortableApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,13 @@ void loadSerializationDialects(MLIRContext& context) {
}
} // namespace

LogicalResult getSmallerVersion(const std::string& version1,
const std::string& version2,
std::string& result) {
FailureOr<std::string> getSmallerVersion(llvm::StringRef version1,
llvm::StringRef version2) {
auto v1 = mlir::vhlo::Version::fromString(version1);
auto v2 = mlir::vhlo::Version::fromString(version2);
if (failed(v1) || failed(v2)) return failure();

if (*v1 < *v2)
result = (*v1).toString();
else
result = (*v2).toString();
return success();
if (*v1 < *v2) return (*v1).toString();
return (*v2).toString();
}

std::string getCurrentVersion() {
Expand Down
7 changes: 3 additions & 4 deletions stablehlo/api/PortableApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@ namespace stablehlo {

/// Return the current version for portable API.
/// Increments on all meaningful changes to this file.
inline int64_t getApiVersion() { return 8; }
inline int64_t getApiVersion() { return 9; }

// Get the smaller version between version1 and version2.
LogicalResult getSmallerVersion(const std::string& version1,
const std::string& version2,
std::string& result);
FailureOr<std::string> getSmallerVersion(llvm::StringRef version1,
llvm::StringRef version2);

// Get the current StableHLO version.
//
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/dialect/Version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ FailureOr<int64_t> Version::getBytecodeVersion() const {
return failure();
}

Version Version::fromCompatibilityRequirement(
CompatibilityRequirement requirement) {
// Compatibility requirement versions can be updated as needed, as long as the
// version satisifies the requirement.
switch (requirement) {
case CompatibilityRequirement::NONE:
return Version::getCurrentVersion();
case CompatibilityRequirement::WEEK_4:
return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024
case CompatibilityRequirement::WEEK_12:
return Version(1, 0, 0); // v1.0.0 - May 14, 2024
case CompatibilityRequirement::MAX:
return Version::getMinimumVersion();
}
}

mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version) {
return diag << version.toString();
}
Expand Down
25 changes: 25 additions & 0 deletions stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,31 @@ class Version {
/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }

// CompatibilityRequirement is used to get a viable target version to use for
// `serializePortableArtifact` given a compatibility requirement specified as
// a duration.
//
// New enum values can be added per use case.
//
// Values represent a minimum requirement, i.e. WEEK_4 will return a >=4w
// old version, the specific implementation detail can be updated at any time
// by the community as long as it satisfies the requirement.
//
// Given that integration into XLA is not immediate, coarse intervals work
// better than providing a specific date.
enum class CompatibilityRequirement {
NONE = 0, // No compat requirement, use latest version.
WEEK_4 = 1, // 1 month requirement
WEEK_12 = 2, // 3 month requirement
MAX = 3, // Maximum compat, use minimum supported version
};

// Get a viable target version to use for `serializePortableArtifact` for a
// given compatibility requirement. See `CompatibilityRequirement` for
// details.
static Version fromCompatibilityRequirement(
CompatibilityRequirement requirement);

/// Return the MLIR Bytecode Format associated with the version instance.
/// Returns failure if version is not in compatibility window.
FailureOr<int64_t> getBytecodeVersion() const;
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def VHLO_Dialect : Dialect {
0.19.0: Introduce `composite` operation.
0.20.0: Remove `padding` attribute from `dynamic_conv`.
1.0.0: Increase compatibility guarantees to 5 years backward, 2 years forward (no functional changes relative to 0.20.0).
1.1.0: Add gather/scatter batching dimensions.
1.2.0: Introduce `si2` and `ui2` types.
1.3.0: Extend `custom_call` op `backend_config` to support `DictionaryAttr`.
1.4.0: Add `tan` op to StableHLO opset.
Expand Down
10 changes: 10 additions & 0 deletions stablehlo/integrations/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,24 @@ add_mlir_public_c_api_library(ChloCAPI

add_mlir_public_c_api_library(StablehloCAPI
PARTIAL_SOURCES_INTENDED
StablehloApi.cpp
StablehloAttributes.cpp
StablehloDialect.cpp
StablehloPasses.cpp
StablehloTypes.cpp

LINK_LIBS PUBLIC
LLVMSupport
MLIRCAPIIR
MLIRIR
MLIRSupport
StablehloOps
StablehloPasses
StablehloPortableApi
StablehloReferenceApi
StablehloReferenceConfiguration
StablehloSerialization
Version
)

add_mlir_public_c_api_library(VhloCAPI
Expand Down
137 changes: 137 additions & 0 deletions stablehlo/integrations/c/StablehloApi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Copyright 2022 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "stablehlo/integrations/c/StablehloApi.h"

#include <vector>

#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/api/PortableApi.h"
#include "stablehlo/dialect/Serialization.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/reference/Api.h"
#include "stablehlo/reference/Configuration.h"

int stablehloGetApiVersion() { return mlir::stablehlo::getApiVersion(); }

mlir::vhlo::Version::CompatibilityRequirement unwrapCompatibilityRequirement(
MlirStablehloCompatibilityRequirement requirement) {
switch (requirement) {
case MlirStablehloCompatibilityRequirement::NONE:
return mlir::vhlo::Version::CompatibilityRequirement::NONE;
case MlirStablehloCompatibilityRequirement::WEEK_4:
return mlir::vhlo::Version::CompatibilityRequirement::WEEK_4;
case MlirStablehloCompatibilityRequirement::WEEK_12:
return mlir::vhlo::Version::CompatibilityRequirement::WEEK_12;
case MlirStablehloCompatibilityRequirement::MAX:
return mlir::vhlo::Version::CompatibilityRequirement::MAX;
}
llvm::report_fatal_error("unhandled compatibility requirement");
}

void stablehloVersionFromCompatibilityRequirement(
MlirStablehloCompatibilityRequirement requirement,
MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
stream << mlir::vhlo::Version::fromCompatibilityRequirement(
unwrapCompatibilityRequirement(requirement));
}

void stablehloGetCurrentVersion(MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
stream << mlir::stablehlo::getCurrentVersion();
}

void stablehloGetMinimumVersion(MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
stream << mlir::stablehlo::getMinimumVersion();
}

MlirLogicalResult stablehloGetSmallerVersion(MlirStringRef version1,
MlirStringRef version2,
MlirStringCallback callback,
void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
auto result =
mlir::stablehlo::getSmallerVersion(unwrap(version1), unwrap(version2));
if (mlir::failed(result)) return mlirLogicalResultFailure();
stream << result.value();
return mlirLogicalResultSuccess();
}

MlirLogicalResult stablehloSerializePortableArtifact(
MlirModule moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
if (failed(mlir::stablehlo::serializePortableArtifact(
unwrap(moduleStr), unwrap(targetVersion), stream)))
return mlirLogicalResultFailure();
return mlirLogicalResultSuccess();
}

MlirLogicalResult stablehloSerializePortableArtifact(
MlirStringRef moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
if (failed(mlir::stablehlo::serializePortableArtifact(
unwrap(moduleStr), unwrap(targetVersion), stream)))
return mlirLogicalResultFailure();
return mlirLogicalResultSuccess();
}

MlirLogicalResult stablehloDeserializePortableArtifact(
MlirStringRef artifactStr, MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
if (failed(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr),
stream)))
return mlirLogicalResultFailure();
return mlirLogicalResultSuccess();
}

MlirModule stablehloDeserializePortableArtifact(MlirStringRef artifactStr,
MlirContext ctx) {
return wrap(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr),
unwrap(ctx))
.release());
}

MlirAttribute stablehloEvalModule(MlirModule module, int nArgs,
MlirAttribute const *args, int *errorCode) {
std::vector<mlir::DenseElementsAttr> inputs;
inputs.reserve(nArgs);
for (int i = 0; i < nArgs; ++i) {
inputs.push_back(llvm::cast<mlir::DenseElementsAttr>(unwrap(args[i])));
}
mlir::stablehlo::InterpreterConfiguration config;
mlir::FailureOr<llvm::SmallVector<mlir::DenseElementsAttr>> results =
mlir::stablehlo::evalModule(unwrap(module), inputs, config);
if (mlir::failed(results)) {
*errorCode = 1;
return MlirAttribute{nullptr};
}
std::vector<MlirAttribute> resultsVec;
for (const auto &result : results.value()) {
resultsVec.push_back(wrap(result));
}
return mlirArrayAttrGet(mlirModuleGetContext(module), resultsVec.size(),
resultsVec.data());
}
Loading

0 comments on commit 691c676

Please sign in to comment.