[RFC] Specification of quantized reduction operation #1664
Conversation
sgtm
Per the discussion on OpenXLA Discuss this RFC is now approved! Thanks to all who provided feedback on this design!
From the discussion that emerged in openxla#1538 about the specification of the quantized `reduce` op, we zeroed in on a proposal which needs some structural changes in the op (adding additional blocks). This RFC explores the proposal in detail. Please let me know your feedback. See also the [OpenXLA Discuss post](https://groups.google.com/a/openxla.org/g/openxla-discuss/c/iwE9is49SS4).
The PR implements the approved [RFC for reduce op](#1664) by proposing the specification changes for the `reduce`, `reduce_window`, and `select_and_scatter` ops. In #1647, we discussed some of the other ops that will depend on the quantized specification of the reduce op. Initially, I considered creating separate PRs for them, but in the interest of time, and because their handling is going to be very similar to how the `reduce` op is handled, I propose to include them in the current PR. Here are the additional ops (other than `reduce`, `reduce_window`, `select_and_scatter`) whose specification is added:

- Ops with explicit computation regions: `all_reduce`, `reduce_scatter`, `scatter`. They are handled similarly to the `reduce` op.
- Ops without an explicit computation region: `batch_norm_grad`, `batch_norm_training`. For these ops, the semantics of the operation implicitly perform a reduction with a custom computation function. As there is no explicit computation function in the IR, the proposal in the RFC cannot be applied. I propose (A) handling these ops similarly to how `batch_norm_inference` is handled, with the `dequant-op-quant` strategy, and (B) revisiting them later if there is a use case for doing that implicit reduction using a higher accumulation type.

One implementation detail: the fact that `batch_norm_grad` and `batch_norm_training` return multiple outputs, while the current `dequantize_op_quantize` returns a single output, is handled using a dedicated meta function for these two ops.

Next steps: once the PR is approved, the plan is to propose the corresponding changes in the verifier/shape functions for these ops.

**Only support promotion of reduction operands to reduction-block arguments**

The specific question brought up in the discussion is: should we support demotion as well? It seems that there aren't any use cases for that. For example, consider the following cases of implicit conversion for reduction operations.
The examples are shown using integer types but can be generalized to other types as well.

| input element type | accumulator element type | result element type | use cases |
|:--:|:--:|:--:|:--|
| i8 | i8 | i8 | Reduction using min/max |
| i8 | i32 | i32 | Reduction using average/sum |
| i8 | i32 | i8 | Demotion from accumulation type to result type. (No use case) |
| i8 | i32 | i64 | Extra promotion from accumulation type to result type. (No use case) |
| i32 | i8 | i8 | Demotion from input value type to accumulation type. (No use case) |

With that in mind, we propose to develop the specification based on the scenario we have a use case for: just promotion. _None of the cases of demotion or extra promotion will be supported in the specification._
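To make the promotion rows above concrete, here is a minimal pure-Python sketch (not StableHLO code) of why a sum-like reduction over i8 inputs wants an i32 accumulator. The `wrap_i8` helper is a hypothetical stand-in for signed 8-bit wraparound semantics.

```python
def wrap_i8(x: int) -> int:
    """Wrap an integer into the signed 8-bit range [-128, 127]."""
    return (x + 128) % 256 - 128

def reduce_sum_i8(values):
    """Sum reduction with an i8 accumulator: wraps around on large sums."""
    acc = 0
    for v in values:
        acc = wrap_i8(acc + v)
    return acc

def reduce_sum_i32(values):
    """Sum reduction with an i32 accumulator: i8 inputs are promoted."""
    acc = 0  # i32 accumulator; sums of i8 inputs fit without wraparound
    for v in values:
        acc = acc + v
    return acc

inputs = [100] * 4  # four in-range i8 values; the true sum 400 exceeds i8
print(reduce_sum_i8(inputs))   # → -112 (wrapped)
print(reduce_sum_i32(inputs))  # → 400
```

This is exactly the "i8 input, i32 accumulator, i32 result" row: the reduction block computes in the wider type, and the result keeps that type.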
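The comment above also proposes handling `batch_norm_grad` and `batch_norm_training` with the `dequant-op-quant` strategy. Below is a minimal numeric sketch of that strategy under uniform quantization; the function and parameter names are illustrative, not StableHLO APIs.

```python
def dequantize(q, scale, zero_point):
    """Map quantized integers back to real values: (q - zp) * scale."""
    return [(v - zero_point) * scale for v in q]

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map real values to i8: scale, round, and clamp to the i8 range."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]

def dequantize_op_quantize(op, q, scale, zero_point):
    """Run `op` in real arithmetic, then re-quantize its result."""
    return quantize(op(dequantize(q, scale, zero_point)), scale, zero_point)

# Example: a mean-subtracting computation, standing in for an op whose
# semantics implicitly reduce over its input.
def center(xs):
    m = sum(xs) / len(xs)
    return [x - m for x in xs]

print(dequantize_op_quantize(center, [10, 20, 30], scale=0.5, zero_point=0))
# → [-10, 0, 10]
```

The point of the strategy is that the op's computation runs in real arithmetic, so no explicit reduction block (and hence no promotion machinery) is needed in the IR.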
…le semantics [RFC](openxla/stablehlo#1664), to MHLO. The CL does two things:

1. The upstream change openxla/stablehlo#1869 in StableHLO updates various APIs related to shape inference. MHLO shape inference functions use those APIs. The CL fixes the invocation of those APIs in the MHLO codebase.
2. There exist canonicalization passes like `group-reduction-dimensions` and `hlo-canonicalize-reduction` which create a reduce operation using builder methods that call the type inference of the reduce op with an empty reduction region ([example](https://github.com/openxla/xla/blob/a91877b9c9aa1edf307c5927782111b1a81cd81d/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc#L228)). This is problematic because, with the [change](openxla/stablehlo#1869), the type inference of the reduce op now depends on the reduction body. The CL updates all the call sites of the problematic builder (the one which calls type inference with an empty reduction block) to invoke a new custom builder method introduced for the mhlo::Reduce operation. Note that at the moment we do not need a similar custom builder for other reduction-based operations (like scatter, reduce_scatter, all_reduce, select_and_scatter, reduce_window), as they are presently created using a builder version that takes the result type as an input and hence does not call inference from within.

Also, the CL adds verification tests for the operations with promotable semantics.

PiperOrigin-RevId: 597407271
The upstream change #1869 in StableHLO updates various APIs related to shape inference. MHLO shape inference functions in [hlo_ops.cc](https://github.com/openxla/xla/blob/main/xla/mlir_hlo/mhlo/IR/hlo_ops.cc) use those APIs. The PR updates the visibility and signatures of those APIs for a clearer integration. Specifically, the PR does the following:

1. **Updates `getAccumulatorTypes` to return an error status when the input region is empty**: This function is used in the type inference of various reduction-based operations ([e.g.](https://github.com/openxla/stablehlo/blob/d5b464925371092095ac934b46ba93ebd4284223/stablehlo/dialect/TypeInference.cpp#L2589)). It enables inferring types based on the reduction block of the operation, which is what the [RFC](#1664) proposes. However, there could be [instances](https://github.com/openxla/xla/blob/a91877b9c9aa1edf307c5927782111b1a81cd81d/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc#L228) where type inference is called with an empty region, in which case we would like to report a meaningful diagnostic message.
2. **Allows `hlo::inferAllReduceOp` to accept information for multiple operands**: In StableHLO, the `all_reduce` op has a single operand ([e.g.](https://github.com/openxla/stablehlo/blob/d5b464925371092095ac934b46ba93ebd4284223/stablehlo/dialect/StablehloOps.td#L1355)), whereas in MHLO the op can take multiple operands ([e.g.](https://github.com/openxla/xla/blob/79aba0801ef75c1c2dffbb4ecc506a0d8144c9ac/xla/mlir_hlo/mhlo/IR/hlo_ops.td#L1528)). The `hlo::inferAllReduceOp` signature is updated to accommodate both cases.
3. **Removes unused arguments from `verifyReduceOpInputsAndInferShape` and `inferReduceOp`**.
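Point 1 above can be modeled in a few lines. The sketch below is a conceptual Python model of inferring a reduce op's result element type from its reduction block and reporting an error for an empty region; the names are illustrative and do not mirror the actual MHLO C++ API.

```python
def infer_reduce_result_type(block_return_types):
    """Infer the result element type from the reduction block's returns.

    Models the RFC behavior where the reduction block's return type (the
    accumulator type) drives the inferred result type, and the updated
    `getAccumulatorTypes`-style behavior of failing on an empty region.
    """
    if not block_return_types:
        # Mirrors returning an error status instead of crashing when a
        # builder calls type inference with an empty reduction region.
        raise ValueError("cannot infer result type: reduction region is empty")
    return block_return_types[0]

# i8 operands reduced with an i32-typed block yield an i32 result.
print(infer_reduce_result_type(["i32"]))  # → i32
```

This is why builders that construct a reduce op with an empty region (as in the `group-reduction-dimensions` example linked above) must be routed to a builder that does not call inference from within.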