Updates to shape functions enabling reuse from MHLO #1918

Merged 2 commits on Jan 19, 2024


sdasgup3 (Member)

The upstream change #1869 in StableHLO updates various APIs related to shape inference. The MHLO shape inference functions in hlo_ops.cc use those APIs. This PR updates the visibility and signatures of those APIs for a clearer integration.

Specifically, the PR does the following:

  1. Updates getAccumulatorTypes to return an error status when the input region is empty. This function is used in type inference for various reduction-based operations. It enables inferring the type from the reduction block of the operation, as proposed in the RFC. However, type inference can be called with an empty region, in which case we would like to report a meaningful diagnostic message.

  2. Allows hlo::inferAllReduceOp to accept information about multiple operands. In StableHLO, the all_reduce op has a single operand, whereas in MHLO the op can take multiple operands. The hlo::inferAllReduceOp signature is updated to accommodate both cases.

  3. Removes unused arguments from the functions verifyReduceOpInputsAndInferShape and inferReduceOp.
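
The first two items can be sketched in standalone C++. This is only an illustration of the control flow, not the StableHLO code: `TensorType`, `Block`, `Region`, `Diagnostic`, `getAccumulatorTypesSketch`, and `inferAllReduceSketch` are hypothetical stand-ins that use no MLIR types, and the rule that every result reuses the operand shape with the accumulator element type is an assumption made for the example.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR concepts; none of these are StableHLO types.
struct TensorType {
  std::vector<int64_t> shape;
  std::string elementType;
};
struct Block {
  // Element types of the values returned by the block's terminator.
  std::vector<std::string> returnedElementTypes;
};
struct Region {
  std::vector<Block> blocks;
  bool empty() const { return blocks.empty(); }
};
struct Diagnostic {
  std::string message;
};

// Item 1: fail with a diagnostic instead of reading from an empty region.
std::optional<std::vector<std::string>> getAccumulatorTypesSketch(
    const Region& body, Diagnostic& diag) {
  if (body.empty()) {
    diag.message = "Expects non-empty reduction block for type inference";
    return std::nullopt;
  }
  return body.blocks.front().returnedElementTypes;
}

// Item 2: accept a list of operands so that the single-operand case and the
// multi-operand case share one entry point.
std::optional<std::vector<TensorType>> inferAllReduceSketch(
    const std::vector<TensorType>& operands, const Region& computation,
    Diagnostic& diag) {
  assert(!operands.empty() && "expected at least one operand");
  auto accumulatorTypes = getAccumulatorTypesSketch(computation, diag);
  if (!accumulatorTypes) return std::nullopt;  // Propagate the failure.
  if (accumulatorTypes->empty()) {
    diag.message = "reduction block returns no values";
    return std::nullopt;
  }
  // Assumed rule for this sketch: keep each operand's shape and take the
  // element type from the reduction block.
  std::vector<TensorType> results;
  results.reserve(operands.size());
  for (const TensorType& operand : operands)
    results.push_back({operand.shape, accumulatorTypes->front()});
  return results;
}
```

A caller with a single operand passes a one-element vector, and a caller with several operands passes all of them. In both cases an empty `computation` yields `std::nullopt` plus a diagnostic rather than a crash.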

@ghpvnist (Member) left a comment

Overall LGTM with minor comment.

Review thread on stablehlo/dialect/TypeInference.cpp (outdated, resolved)
@sdasgup3 merged commit 3bd2fad into openxla:main on Jan 19, 2024
9 checks passed
stellaraccident pushed a commit to llvm/torch-mlir that referenced this pull request Jan 31, 2024
With the recent LLVM integrate and changes from
llvm/llvm-project#78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
    rewriter.startRootUpdate(op);
    ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
      rewriter.finalizeRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
      rewriter.cancelRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
    rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
    ~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
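
For context, the four missing members correspond to a rename in the MLIR rewriter API. To my understanding of llvm/llvm-project#78260, `startRootUpdate`, `finalizeRootUpdate`, `cancelRootUpdate`, and `updateRootInPlace` became `startOpModification`, `finalizeOpModification`, `cancelOpModification`, and `modifyOpInPlace`. The sketch below shows only the calling pattern under the new names. `MockOp`, `MockRewriter`, and `refineResultType` are hypothetical stand-ins so that the example compiles without MLIR; they are not the real `mlir::Operation` or `mlir::PatternRewriter`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation.
struct MockOp {
  std::string resultType;
};

// Hypothetical stand-in for mlir::PatternRewriter. It only records which
// notifications were sent, using the post-rename method names.
class MockRewriter {
 public:
  void startOpModification(MockOp*) {
    ++pending_;
    log.push_back("start");
  }
  void finalizeOpModification(MockOp*) {
    assert(pending_ > 0 && "finalize without a matching start");
    --pending_;
    log.push_back("finalize");
  }
  void cancelOpModification(MockOp*) {
    assert(pending_ > 0 && "cancel without a matching start");
    --pending_;
    log.push_back("cancel");
  }
  // Convenience wrapper: start, run the callback, then finalize.
  template <typename Callable>
  void modifyOpInPlace(MockOp* op, Callable&& callable) {
    startOpModification(op);
    callable();
    finalizeOpModification(op);
  }

  std::vector<std::string> log;

 private:
  int pending_ = 0;
};

// Refines the result type in place and cancels when nothing changes.
bool refineResultType(MockRewriter& rewriter, MockOp& op,
                      const std::string& refined) {
  rewriter.startOpModification(&op);
  if (op.resultType == refined) {
    rewriter.cancelOpModification(&op);
    return false;
  }
  op.resultType = refined;
  rewriter.finalizeOpModification(&op);
  return true;
}
```

Each call site in the log above would need the matching one-for-one rename; the start/finalize/cancel structure itself stays the same.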

I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to openxla/stablehlo#1918 fixes it.

It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test

...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
  %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
       ^
LLVM ERROR: Failed to infer result type(s).
```

Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228