Skip to content

Commit

Permalink
Integrate StableHLO at openxla/stablehlo@20255865
Browse files Browse the repository at this point in the history
Other than the usual integration, the CL does two things:

The upstream change openxla/stablehlo#1869 in StableHLO updates various API related to shape inference. MHLO shape inference functions uses those APIs. The CL fixes the invocation of those APIs in MHLO codebase so as to sync the semantics of StableHLO reduction operation with MHLO.

There exists canonicalization passes like group-reduction-dimensions and hlo-canonicalize-reduction which create reduce operation using builder methods that calls type inference of reduce op with empty reduction region example. This is problematic as, with the change, the type inference of reduce op is now dependent on the reduction body. The CL updates all the calls sites of the problematic builder (the one which calls type inference with empty reduction block) with the invocation of a new custom builder method introduced for mhlo::Reduce operation.

Note that at the moment we do not need similar custom builder for other reduction based operations (like scatter, reduce_scatter, all_reduce, select_and_scatter, reduce_window) as they are presently created using a builder version take result type as an input and hence does not call inference from within.

Also, the CL adds verification tests for the operations with promotable semantics.

PiperOrigin-RevId: 599684600
  • Loading branch information
sdasgup3 authored and copybara-github committed Jan 19, 2024
1 parent fd4c7fe commit 9cb9531
Show file tree
Hide file tree
Showing 11 changed files with 334 additions and 407 deletions.
117 changes: 1 addition & 116 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
--- stablehlo/CMakeLists.txt
+++ stablehlo/CMakeLists.txt
@@ -13,135 +13,20 @@
@@ -13,131 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
Expand Down Expand Up @@ -134,37 +134,13 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
-#-------------------------------------------------------------------------------
-
-if(STABLEHLO_ENABLE_BINDINGS_PYTHON)
- if(NOT STABLEHLO_EXTERNAL_PROJECT_BUILD)
- message(WARNING "StableHLO Python bindings are not supported in standalone mode")
- endif()
-
- include(MLIRDetectPythonEnv)
- mlir_configure_python_dev_packages()
-endif()
+set(STABLEHLO_ENABLE_BINDINGS_PYTHON ${MHLO_ENABLE_BINDINGS_PYTHON})

#-------------------------------------------------------------------------------
# Directory setup
diff --ruN a/stablehlo/docs/_toc.yaml b/stablehlo/docs/_toc.yaml
--- stablehlo/docs/_toc.yaml
+++ stablehlo/docs/_toc.yaml
@@ -1,3 +1,16 @@
+# Copyright 2023 The StableHLO Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
toc:
- heading: StableHLO developer guide
- title: Overview
diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt
--- stablehlo/stablehlo/CMakeLists.txt
+++ stablehlo/stablehlo/CMakeLists.txt
Expand All @@ -176,18 +152,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists
add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir b/stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir
--- stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir
+++ stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir
@@ -29,7 +29,7 @@
// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK-PRIMITIVE: linalg.reduce { arith.addi }
+// CHECK-PRIMITIVE: linalg.reduce { arith.addi {overflowFlags = #arith.overflow<none>} }
// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
// CHECK-PRIMITIVE-SAME: dimensions = [1] {someattr}
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
Expand Down Expand Up @@ -2756,83 +2720,4 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -967,16 +967,16 @@
// better error reporting for this case.
// This serves the current use cases well, so the implementation of more
// sophisticated refinement algorithm is left for future work.
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
auto condStatus = refineValues(rewriter, op, op.getCond().getArguments(),
op.getOperandTypes());
auto bodyStatus = refineValues(rewriter, op, op.getBody().getArguments(),
op.getOperandTypes());
if (succeeded(condStatus) || succeeded(bodyStatus)) {
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return success();
} else {
- rewriter.cancelRootUpdate(op);
+ rewriter.cancelOpModification(op);
return failure();
}
}
@@ -1055,7 +1055,7 @@
if (!needsUpdate)
return rewriter.notifyMatchFailure(op, "doesn't need update");

- rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
+ rewriter.modifyOpInPlace(op->getParentOp(), [&]() { return; });
return success();
}
};
diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
--- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp
@@ -426,15 +426,19 @@
return success();
}

-template <typename T, typename Attr>
-SpecialResult convertDenseArray(StringAttr vhloName, Attribute vhloAttr,
- SmallVector<NamedAttribute>& stablehloAttrs) {
+SpecialResult convertDenseI64Array(
+ StringAttr vhloName, Attribute vhloAttr,
+ SmallVector<NamedAttribute>& stablehloAttrs) {
auto tensorAttr = dyn_cast<vhlo::TensorV1Attr>(vhloAttr);
if (!tensorAttr) return specialFailure();

- auto data = SmallVector<T>(
- ArrayRef<T>(reinterpret_cast<const T*>(tensorAttr.getData().data()),
- tensorAttr.getData().size() / sizeof(T)));
+ if (tensorAttr.getData().size() % sizeof(int64_t) != 0)
+ return specialFailure();
+
+ auto data = ArrayRef<int64_t>(
+ reinterpret_cast<const int64_t*>(tensorAttr.getData().data()),
+ tensorAttr.getData().size() / sizeof(int64_t))
+ .vec();

// Handle splats
if (data.size() == 1) {
@@ -445,15 +449,9 @@
data.resize(size, data[0]);
}

- stablehloAttrs.emplace_back(vhloName, Attr::get(vhloAttr.getContext(), data));
+ stablehloAttrs.emplace_back(
+ vhloName, DenseI64ArrayAttr::get(vhloAttr.getContext(), data));
return specialSuccess();
-}
-
-SpecialResult convertDenseI64Array(
- StringAttr vhloName, Attribute vhloAttr,
- SmallVector<NamedAttribute>& stablehloAttrs) {
- return convertDenseArray<int64_t, DenseI64ArrayAttr>(vhloName, vhloAttr,
- stablehloAttrs);
}

template <typename VhloOpTy>

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "9c8a1b7ddd5e4ef5abbdc7ef33b041a56343ae2f"
STABLEHLO_SHA256 = "4f6b6d4e5a96893a5ee6ff53aa017b7e151690f09c8c5eac77bd19b6330faa7a"
STABLEHLO_COMMIT = "20255865ba299ed67bdf6267478c8477aef7a60d"
STABLEHLO_SHA256 = "d703dba8c3f6ed1b5c7ac9772ec24d474644a5f10c0122aaca1401e0b4120471"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down
Loading

0 comments on commit 9cb9531

Please sign in to comment.