diff --git a/rfcs/20230622-quantized-reduction.md b/rfcs/20230622-quantized-reduction.md
index 2152a5cc9ea..248d97c097a 100644
--- a/rfcs/20230622-quantized-reduction.md
+++ b/rfcs/20230622-quantized-reduction.md
@@ -2,8 +2,8 @@
Status: Review
Initial version: 06/22/2023
-updated: 07/13/2023
: Minor refactoring of the examples.
-Last updated: 08/11/2023
: Revision of the proposal to introduce an
+updated: 07/13/2023: Minor refactoring of the examples.
+Last updated: 08/11/2023: Revision of the proposal to introduce an
attribute to capture accumulation type.
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1664)
@@ -11,6 +11,9 @@ Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1664)
* 06/22/2023: Initial version.
* 07/13/2023: Fixed typo in code blocks, header indentation.
+* 08/11/2023: Revision of the proposal to introduce an attribute to capture
+ accumulation type.
+* 08/25/2023: The additional attribute is redundant.
## Introduction
@@ -24,44 +27,32 @@ op, for non-quantized types, has constraints like
which constrained the signature of reduce op and its associated reducer function
`body` to have the same element types for `inputs`, `results` and arguments and
-return for `body`. For reducer function performing an accumulative operation like
-add, this means that the the result of accumulation can overflow in which case
-the result will be implementation defined (e.g.,
-[saturated](https://en.wikipedia.org/wiki/Saturation_arithmetic) or
-[wrap around](https://en.wikipedia.org/wiki/Integer_overflow)).
-From the conversation with customers it seems a reasonable behavior for non
-quantized data types. However, with quantized data types, such loss in precision
-is not acceptable and hence the motivation is to perform the accumulation in
-some higher data type.
-
-The RFC highlights some of the options emerged out of discussion in the
+return for `body`. For a reducer function performing an accumulative operation
+like add, this means that the result of accumulation can overflow, in which
+case the result is implementation-defined (e.g.,
+[saturated](https://en.wikipedia.org/wiki/Saturation_arithmetic) or
+[wrapped around](https://en.wikipedia.org/wiki/Integer_overflow)). From
+conversations with customers, this seems to be reasonable behavior for
+non-quantized data types. However, with quantized data types, such loss in
+precision is not acceptable, hence the motivation to perform the accumulation
+in some higher-precision data type.
+
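To make the overflow concern concrete, here is a minimal Python sketch, not part of the StableHLO specification. It sums three 8-bit values under the two implementation-defined behaviors mentioned above and under a wider accumulator. The helper names `wrap_i8` and `saturate_i8` are hypothetical and exist only for this illustration.

```python
def wrap_i8(x: int) -> int:
    """Two's-complement wraparound into the signed 8-bit range [-128, 127]."""
    return (x + 128) % 256 - 128


def saturate_i8(x: int) -> int:
    """Clamp into the signed 8-bit range [-128, 127]."""
    return max(-128, min(127, x))


values = [100, 100, 100]

# Accumulating in i8: every partial sum is forced back into the 8-bit range.
wrapped = 0
saturated = 0
for v in values:
    wrapped = wrap_i8(wrapped + v)
    saturated = saturate_i8(saturated + v)

# Accumulating in a wider type (Python ints are unbounded), narrowing never
# happens in the middle of the reduction.
wide = sum(values)

print(wrapped, saturated, wide)
```

Neither 8-bit accumulator ends up near the true sum of 300, which is the loss the proposal aims to avoid by letting the reducer block use a wider element type.
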
+The RFC introduces the following proposal, which emerged from the discussion in the
[thread](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906)
-along with their tradeoffs. The proposal option #1 looks promising at this
-point, but we are open to further discussion on this.
-
-## Option 1: Introduce additional conversion functions
-
-[The thread](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906)
-discuses an option, proposed by @loganchien, on how to achieve the structural
-changes as mentioned above. We note that some of the examples/diagrams presented
-here are borrowed from an internal doc @loganchien authored.
-
-The proposed options introduces on-the-fly type conversions, which (1) convert
-the input type to the type of the `body` function argument and (2) convert the
-result type of the `body` function to the output type. Following is the code
-snippet with the proposed syntax of reduce op:
+, along with its tradeoffs.
+
+The proposal allows the reducer block to express the computation in a different
+element type (preferably a higher-precision accumulation type) than the one used
+in the reduce op's operands and results. For illustrative purposes, in the
+following example, the operand element type `tensor>` is different from the
+element type of the reduction region's block arguments. Similarly, the element
+type of the reduce op's result `!quant.uniform>` is different from that of the
+block return (`tensor>`).
```mlir
%result = "stablehlo.reduce"(%input, %init_value) ({
- ^input_conversion(
- %input: tensor>):
- %input_rescaled = "stablehlo.uniform_quantize"(%input)
- : (tensor>)
- -> tensor>
- "stablehlo.return"(%input_rescaled)
- : (tensor>) -> ()
-
- }, {
^reduce_computation(
%lhs: tensor>,
%rhs: tensor>):
@@ -71,310 +62,63 @@ snippet with the proposed syntax of reduce op:
-> tensor>
"stablehlo.return"(%add)
: (tensor>) -> ()
- }, {
- ^output_conversion(
- %intermediate_result: tensor>):
- %output_rescaled = "stablehlo.uniform_quantize"(%intermediate_result)
- : (tensor>)
- -> tensor>
- "stablehlo.return"(%output_rescaled)
- : (tensor>) -> ()
}) {
- dimensions = dense<...> : tensor<1xi64>
- } : (tensor<... x !quant.uniform>,
- tensor<... x !quant.uniform>)
- -> tensor<... x !quant.uniform>
+ dimensions = dense<1> : tensor
+ } : (tensor<5 x 1 x !quant.uniform>,
+ tensor>)
+ -> tensor<5 x !quant.uniform>
```
### Semantics
-Here we will informally propose the semantics of the additional functions
-`input_conversion` and `output_conversion` introduced.
-
-```python
-+----------+ +--------+ +--------+ +----------+ +--------+ +--------+
-|init_value| |input[0]| |input[1]| |init_value| |input[2]| |input[3]|
-+----------+ +--------+ +--------+ +----------+ +--------+ +--------+
- | | | | | |
-+----------+ +--------+ +--------+ +----------+ +--------+ +--------+
-|input | |input | |input | |input | |input | |input |
-|convert | |convert | |convert | |convert | |convert | |convert |
-+----------+ +--------+ +--------+ +----------+ +--------+ +--------+
- \ / / \ / /
- +-------+ / +-------+ /
- |compute| / |compute| /
- +-------+ / +-------+ /
- \ / \ /
- +-------+ +-------+
- |compute| |compute|
- +-------+ +-------+
- \___________ ___________/
- \ /
- +-------+
- |compute|
- +-------+
- |
- +-------+
- |output |
- |convert|
- +-------+
-```
-
-### Semantics of `input_conversion` block
-
-The `input_conversion` block is applied selectively to the leaf nodes of a
-schedule tree as shown in above diagram. Note that the `input_conversion` cannot
-be applied to the non-leaf nodes of the schedule tree.
-
-### Semantics of `output_conversion` block
-
-The `output_conversion` block is applied just after the `result` for a particular
-index is computed as shown in the above diagram.
-
-Please refer to the [formal spec](#revised-specification-of-reduce-op) of the proposed
-reduce op.
-
-### Implementation details
-
-From the implementation POV of the proposed spec, we note that
-`input_conversion` or `output_conversion` can very well be optional with
-default values as identity functions. For example, the following code snippet
-
-```mlir
-%result = "stablehlo.reduce"(%input, %init_value) ({
- ^reduce_computation(
- %lhs: tensor>,
- %rhs: tensor>):
- %add = "stablehlo.add"(%lhs, %rhs)
- : (tensor>,
- tensor>)
- -> tensor>
- "stablehlo.return"(%add)
- : (tensor>) -> ()
- }) {
- dimensions = dense<...> : tensor<1xi64>
- } : (tensor<... x !quant.uniform>,
- tensor<... x !quant.uniform>)
- -> tensor<... x !quant.uniform>
-```
-
-should be interpreted as
-
-```mlir
-%result = "stablehlo.reduce"(%input, %init_value) ({
- ^input_conversion(
- %input: tensor>):
- "stablehlo.return"(%input)
- : (tensor>) -> ()
-
- }, {
- ^reduce_computation(
- %lhs: tensor>,
- %rhs: tensor>):
- %add = "stablehlo.add"(%lhs, %rhs)
- : (tensor>,
- tensor>)
- -> tensor>
- "stablehlo.return"(%add)
- : (tensor>) -> ()
- }, {
- ^output_conversion(
- %intermediate_result: tensor>):
- "stablehlo.return"(%intermediate_result)
- : (tensor>) -> ()
- }) {
- dimensions = dense<...> : tensor<1xi64>
- } : (tensor<... x !quant.uniform>,
- tensor<... x !quant.uniform>)
- -> tensor<... x !quant.uniform>
-```
-
-Note that with default values, the input/result type of `reduce` op matches
-with the argument or the result type of the `reduce_computation`, including the
-quantization parameters.
-
-Also, note that the relative order of `input_conversion` or `output_conversion`
-w.r.t the `reduce_computation` can be used to identify the appropriate
-conversion function when any one of `input_conversion` or `output_conversion` is
-missing.
-
-The existing pretty printing is currently producing the following output
-`stablehlo.reduce(%input init: %init_value) applies stablehlo.add across
-dimensions = [1] : (tensor<1x6xi64>, tensor) -> tensor<1xi64>`. IMO,
-modifying the above format, with the default conversion function, will create
-clutter. My proposal here is to follow the existing pretty printing when the
-conversion functions are "not provided". In the event, the conversion functions
-are explicitly provided, then the pretty printers will fall back to default
-generic printing,
-**even if the explicitly provided conversion functions are identity function**:
-To avoid identification of identity functions which could be tricky in general.
-
-### Tradeoffs
-
-* (+) Enables programmers to program at (almost) baremetal. If the hardware
- can support reduction computation in wider type (e.g. in the SIMD
- instruction set, we typically do widening/compute/narrowing within the
- kernel to save the memory bandwidth), the programmer can explicitly request
- for that.
-* (-) The disadvantage of this representation is that the syntax is more
- verbose and requires significant changes to the specification.
-
-## Option 2: re-scale input to accumulation type
-
-This option is the simplest from the POV for specification of quantized `reduce`
-op. This is adding `stablehlo.uniform_quantize`ops before and after reduce op
-which operates on the "accumulator" type.
-
-```mlir
-%widen = "stablehlo.uniform_quantize"(%input)
- : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform>
-
-%reduce = "stablehlo.reduce"(%widen) {
- ^reduce_computation(%lhs: !quant.uniform, %rhs: !qunat.uniform):
- // reduce_computation_block
- }
- : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform>
-
-%narrowed = "stablehlo.uniform_quantize"(%reduce)
- : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform>
-```
-
-### Tradeoffs
-
-* (+) An advantage of this option is that we only need minor changes to the
- specification (i.e. to allow quantized types).
-* (-) The compiler must pattern match 3 operations and map them into some
- internal representation before their compilation or execution.
-* (-) The compiler must ensure that the `stablehlo.uniform_quantize` (or
- `stablehlo.convert` in the case of `bf16` or `f16`) is not folded before the
- backend matches the pattern.
- [for more information](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906)
-
-## Option 3: allow accumulator type to be different from input type
-
-This is another option we considered which does not fly well because of limited
-expressibility. Adding it just for completeness purposes.
-The idea here is to convey the accumulator type using the `init_value` operand
-of `reduce` op. The code snippet for `reduce` looks like:
-
-```mlir
-%result = "stablehlo.reduce"(%input, %init_value) ({
- ^reduce_computation(
- %elem: tensor>,
- %acc: tensor>):
- %elem_rescaled = "stablehlo.uniform_quantize"(%elem)
- : (tensor>)
- -> tensor>
- %add = "stablehlo.add"(%elem_rescaled, %acc)
- : (tensor>,
- tensor>)
- -> tensor>
- "stablehlo.return"(%0)
- : (tensor>) -> ()
- }) {
- dimensions = dense<1> : tensor<1xi64>
- } : (tensor<... x !quant.uniform>,
- tensor<... x !quant.uniform>)
- -> tensor<... x !quant.uniform>
-```
-
-In this option, the `init_value` type and the `result` type can be different
-from the input type. The first argument of the compute block is fixed for the
-traversed element and the second argument is fixed for the intermediate
-(accumulation) result.
-
-### Tradeoffs
-
-* (+) Make the accumulation type explicit in the IR.
-* (-) This representation imposes a limitation on the evaluation order.
- Since we can’t express the computation between two intermediate (accumulation)
- results, we can not arbitrarily insert `init_value` and start the
- computation at an arbitrary location. The following shows the restricted
- evaluation order with the method.
-
-```python
-+----------+ +--------+ +--------+ +--------+ +--------+
-|init_value| |input[0]| |input[1]| |input[2]| |input[3]|
-+----------+ +--------+ +--------+ +--------+ +--------+
- \ / / / /
- +-------+ / / /
- |compute| / / /
- +-------+ / / /
- \ / / /
- +-------+ / /
- |compute| / /
- +-------+ / /
- \ / /
- +-------+ /
- |compute| /
- +-------+ /
- \ /
- +-------+
- |compute|
- +-------+
-```
-
-## Open Question
-
-### Should we restrict the proposal #1 to quantized types only?
-
-The above proposal #1 of introducing the additional functions is theoretically
-not limited to quantized `reduce` op, but also can be applied to `reduce` op with
-non-quantized types. For example,
-
-```mlir
-%result = "stablehlo.reduce"(%input, %init_value) ({
- ^input_conversion(%arg0: tensor):
- %0 = "stablehlo.convert"(%arg0): (tensor) -> (tensor)
- "stablehlo.return"(%0) : (tensor) -> (tensor)
- }, {
- ^bb0(%arg0: tensor, %arg1: tensor):
- %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) ->
- tensor
- "stablehlo.return"(%0) : (tensor) -> ()
- }, {
- ^output_conversion(%arg0: tensor):
- %0 = "stablehlo.convert"(%arg0): (tensor) -> (tensor)
- "stablehlo.return"(%0) : (tensor) -> (tensor)
- }) {
- dimensions = dense<1> : tensor<1xbf16>
-} : (tensor<1x6xbf16>, tensor) -> tensor<1xbf16>
-```
-
-However, it is not clear how such operations will be lowered to other IR
-representations, like HLO, which does not support such additional computation
-blocks. IMO there is no additional benefit to support such conversion
-functions for regular type given that there already exists infrastructure
-(backend support, lowering passes) to support regular types w/o conversion
-functions. My proposal here would be to restrict the support to only quantized
-types.
+Depending on whether (1) the input operand type differs from the reduction
+block argument type or (2) the op result type differs from the reduction block
+return type, an implicit type conversion is performed, defined by one of
+`stablehlo.convert`, `stablehlo.uniform_quantize`, or
+`stablehlo.uniform_dequantize`. For example,
+
+ | Implicit type conversion op | element type of operand or result type | element type of block argument or block return type |
+ |-----------------------------------|----------------------------------------|-----------------------------------------------------|
+ | (A) `stablehlo.uniform_quantize` | quantized tensor | quantized tensor |
+ | (B) `stablehlo.uniform_quantize` | floating-point | quantized tensor |
+ | (C) `stablehlo.uniform_dequantize` | quantized tensor | floating-point |
+ | (D) `stablehlo.convert` | floating-point | integer |
+ | (E) `stablehlo.convert` | integer | floating-point |
+ | (F) `stablehlo.convert` | floating-point | floating-point |
+ | (G) `stablehlo.convert` | integer | integer |
+ | (H) `stablehlo.convert` | complex | complex |
+
+At this point there is no use for cases other than (A), (F), and (G). My
+proposal here would be to address (A), (F), and (G) only. Note that (F)
+partially addresses [Decide on mixed
+precision](https://github.com/openxla/stablehlo/issues/369) for the reduce op,
+in that it allows the input or init value to differ from the corresponding
+block arguments w.r.t. the precision of floating-point types. However, the
+mixed precision implementation in HLO seems more detailed, in the sense that
+it even allows `inputs` and `init_values` to differ in floating-point
+precision. My proposal would be to treat the above ticket separately.
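
As an illustration of case (A), the following Python sketch models a quantized add-reduction in which operands are converted to an accumulation type before the reducer body runs, and the block return value is converted to the result type afterwards. Everything here is an assumption made for the example: the `QType` class, the helper names, the scales and zero points, the use of Python's `round` (which rounds halves to even), and the left-to-right evaluation order, which is only one of the schedules an implementation may choose.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QType:
    """Toy per-tensor uniform quantized type (hypothetical, for illustration)."""
    storage_min: int
    storage_max: int
    scale: float
    zero_point: int


def dequantize(q: int, t: QType) -> float:
    return (q - t.zero_point) * t.scale


def quantize(x: float, t: QType) -> int:
    # Round to nearest, then clamp to the storage range of the target type.
    q = round(x / t.scale) + t.zero_point
    return max(t.storage_min, min(t.storage_max, q))


def requantize(q: int, src: QType, dst: QType) -> int:
    """Stand-in for the implicit quantized-to-quantized conversion, case (A)."""
    return q if src == dst else quantize(dequantize(q, src), dst)


def reduce_add(inputs, init, in_t: QType, acc_t: QType, out_t: QType) -> int:
    # Convert operands to the block-argument (accumulation) type.
    acc = requantize(init, in_t, acc_t)
    for q in inputs:
        rhs = requantize(q, in_t, acc_t)
        # Reducer body: an add expressed entirely in the accumulation type.
        acc = quantize(dequantize(acc, acc_t) + dequantize(rhs, acc_t), acc_t)
    # Convert the block return value to the op's result type.
    return requantize(acc, acc_t, out_t)


I8_IN = QType(-128, 127, scale=0.5, zero_point=0)
I32_ACC = QType(-2**31, 2**31 - 1, scale=0.5, zero_point=0)
I8_OUT = QType(-128, 127, scale=2.0, zero_point=0)

wide = reduce_add([100, 100, 100], 0, I8_IN, I32_ACC, I8_OUT)
narrow = reduce_add([100, 100, 100], 0, I8_IN, I8_IN, I8_OUT)
print(wide, narrow)
```

With these made-up parameters, the wide accumulator produces 75 (a real value of 150), whereas accumulating in the 8-bit input type saturates midway and produces 32 (a real value of 64).
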
## Appendix
-To provide an estimate of specification changes needed to implement option #1
-I have attempted to provide the blueprint here.
+To provide an estimate of specification changes needed to implement the
+proposal, I have attempted to provide the blueprint here.
### Revised specification of reduce op
-#### Semantics
+Here we include only the relevant portions of the spec with the proposed update.
-Applies a reduction functions `input_conversion`, `body`, and
-`output_conversion` to `inputs` and `init_values` along the `dimensions` and
-produces `results` tensors.
+#### Semantics
-The order of reductions is implementation-defined, which means that `body` and
-`init_values` must form a monoid to guarantee that the operation produces the
-same results for all inputs on all implementations. However, this condition
-doesn't hold for many popular reductions. E.g. floating-point addition for
-`body` and zero for `init_values` don't actually form a monoid because
-floating-point addition is not associative.
+...
More formally, `results...[j0, ..., jR-1] =
-map(output_conversion, reduce(input_slices_converted))` where:
+reduce_implicit_convert(reduce(input_slices_converted),
+  type(func_outputs(body)...), type(results...))` where:
* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
-* `input_slices_converted = map(input_conversion, input_slices...)`.
+* `input_slices_converted = reduce_implicit_convert(input_slices...,
+  type(inputs...), type(func_inputs(body)...))`.
* `reduce(input_slices_converted) = exec(schedule)` for some binary tree
`schedule` where:
* `exec(node) = body(exec(node.left), exec(node.right))`.
@@ -384,89 +128,52 @@ map(output_conversion, reduce(input_slices_converted))` where:
* `input_slices_converted...[index]` values, for all `index` in
`index_space(input_slices_converted)` in the ascending lexicographic order
of `index`.
- * Interspersed with an implementation-defined amount of `init_values`
+ * Interspersed with an implementation-defined amount of
+    `reduce_implicit_convert(init_values..., type(init_values...), type(func_inputs(body)[:len(func_inputs(body))//2])...)`
at implementation-defined positions.
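
The `exec(schedule)` formulation above can be sketched in Python as follows. This is an illustrative model only: `Leaf`, `Node`, and `exec_schedule` are hypothetical names, the leaves are assumed to hold values that have already gone through `reduce_implicit_convert`, and plain Python integers stand in for the accumulation type.

```python
from dataclasses import dataclass
from typing import Callable, Union


@dataclass(frozen=True)
class Leaf:
    value: int  # an already-converted input or init value


@dataclass(frozen=True)
class Node:
    left: "Schedule"
    right: "Schedule"


Schedule = Union[Leaf, Node]


def exec_schedule(node: Schedule, body: Callable[[int, int], int]) -> int:
    """Leaves yield their value; inner nodes apply `body` to both subtrees."""
    if isinstance(node, Leaf):
        return node.value
    return body(exec_schedule(node.left, body), exec_schedule(node.right, body))


def add(lhs: int, rhs: int) -> int:
    return lhs + rhs


init = Leaf(0)
x = [Leaf(v) for v in (100, 100, 100, 100)]

# Schedule 1: a left-leaning chain seeded with a single init value.
chain = Node(Node(Node(Node(init, x[0]), x[1]), x[2]), x[3])

# Schedule 2: two halves, each seeded with its own init value, then combined.
split = Node(Node(Node(init, x[0]), x[1]), Node(Node(init, x[2]), x[3]))

chain_result = exec_schedule(chain, add)
split_result = exec_schedule(split, add)
print(chain_result, split_result)
```

Both schedules visit the inputs in ascending index order and differ only in where init values are interspersed and how the tree is shaped. Because integer addition with a zero init value forms a monoid, they agree on the result.
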
-#### Inputs
-
-| Label | Name | Type | Constraints |
-|-------|---------------------|----------------------------------------------|-------------|
-| (I?) | `inputs` | variadic number of tensors | |
-| (I?) | `init_values` | variadic number of 0-dimensional tensors | |
-| (I?) | `dimensions` | 1-dimensional tensor constant of type `si64` | |
-| (I?) | `input_conversion` | function | |
-| (I?) | `body` | function | |
-| (I?) | `output_conversion` | function | |
-
-#### Outputs
-
-| Name | Type | Constraints |
-|-----------|----------------------------|-------------|
-| `results` | variadic number of tensors | |
-
#### Constraints
* (C?) `same(shape(inputs...))`.
* (C?) `element_type(inputs...) = element_type(init_values...)`.
* (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
-* (C?) `0 < size(inputs) = size(init_values) = size(results) = N`.
-* (C?) `0 <= dimensions < rank(inputs[0])`.
-* (C?) `is_unique(dimensions)`.
-* (C?) `input_conversion` has type `tensor, ..., tensor ->
- (tensor, ..., tensor)` where `Ei = element_type(inputs[i])`.
* (C?) `body` has type `tensor, ..., tensor, tensor, ...,`
`tensor) -> (tensor, ..., tensor)` where
- `Ei = element_type(output_types(input_conversion)[i])`.
-* (C?) `output_conversion` has type `tensor, ..., tensor ->
- (tensor, ..., tensor)` where
- `E'i = element_type(results[i])`.
-* (C?) `element_type(output_types(input_conversion)...) =
- element_type(input_types(output_conversion)...)`.
+  `is_integer(element_type(inputs[i])) = is_integer(Ei)` and
+  `is_float(element_type(inputs[i])) = is_float(Ei)` and
+  `is_complex(element_type(inputs[i])) = is_complex(Ei)` and
+  `is_quantized(element_type(inputs[i])) = is_quantized(Ei)`.
* (C?) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
+`reduce_implicit_convert` is defined as
+
+```python
+def reduce_implicit_convert(x: Value, source_type: Type, destination_type:
+ Type):
+ if source_type == destination_type:
+ return x
+ if is_quantized(source_type) and is_quantized(destination_type):
+ return quantize(x, destination_type)
+ return convert(x, destination_type)
+```
+
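As a rough, runnable companion to the pseudocode above, the sketch below combines the category constraint on `body` with the conversion choice made by `reduce_implicit_convert`, and reports which StableHLO op would perform the implicit conversion. The `ElementType` class, its string categories, and the type labels are hypothetical simplifications, not spec constructs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ElementType:
    category: str  # "integer", "float", "complex", or "quantized"
    name: str      # free-form label used only in this sketch


def implicit_conversion_op(src: ElementType, dst: ElementType) -> Optional[str]:
    """Name of the op implied between two element types, or None if identical."""
    if src == dst:
        return None
    # Modeled constraint: source and destination share the same category.
    if src.category != dst.category:
        raise ValueError(f"unsupported implicit conversion: {src.name} -> {dst.name}")
    if src.category == "quantized":
        return "stablehlo.uniform_quantize"
    return "stablehlo.convert"


i8 = ElementType("integer", "i8")
i32 = ElementType("integer", "i32")
bf16 = ElementType("float", "bf16")
f32 = ElementType("float", "f32")
q8 = ElementType("quantized", "q8")    # hypothetical 8-bit quantized type
q32 = ElementType("quantized", "q32")  # hypothetical 32-bit quantized type

for src, dst in [(q8, q32), (bf16, f32), (i8, i32), (f32, f32)]:
    print(src.name, "->", dst.name, ":", implicit_conversion_op(src, dst))
```

Cross-category pairs such as integer to floating-point raise an error here, mirroring the decision to leave those cases out of scope for now.
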
The above specification of `reduce` op can be used to define the specification
-of other ops as shown below. For brevity, we are only presenting the relevant
+of other ops as shown below. As before, we are only presenting the relevant
portions of the spec which needs modification.
### Revised specification of reduce_window op
-#### Semantics
-
-Applies a reduction functions `input_conversion`, `body`, and
-`output_conversion` to windows of `inputs` and `init_values` and produces
-`results`.
-
-...
-
-More formally,
-`results...[result_index] = reduce(windows, init_values, axes(inputs...),
- input_conversion, body, output_conversion)`
-where:
-....
-
-#### Inputs
-
-| Label | Name | Type |
-|-------|---------------------|----------|
-| (I?) | `input_conversion` | function |
-| (I8) | `body` | function |
-| (I?) | `output_conversion` | function |
-
#### Constraints
* (C?) `element_type(inputs...) = element_type(init_values...)`.
* (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
-* (C?) `input_conversion` has type `tensor, ..., tensor ->
- (tensor, ..., tensor)` where `Ei = element_type(inputs[i])`.
* (C?) `body` has type `tensor, ..., tensor, tensor, ...,`
`tensor) -> (tensor, ..., tensor)` where
- `Ei = element_type(output_types(input_conversion)[i])`.
-* (C?) `output_conversion` has type `tensor, ..., tensor ->
- (tensor, ..., tensor)` where
- `E'i = element_type(results[i])`.
-* (C?) `element_type(output_types(input_conversion)...) =
- element_type(input_types(output_conversion)...)`.
+  `is_integer(element_type(inputs[i])) = is_integer(Ei)` and
+  `is_float(element_type(inputs[i])) = is_float(Ei)` and
+  `is_complex(element_type(inputs[i])) = is_complex(Ei)` and
+  `is_quantized(element_type(inputs[i])) = is_quantized(Ei)`.
### Revised specification of select_and_scatter op
@@ -476,74 +183,171 @@ not need additional conversion functions associated with `select`. But the
`scatter` function needs be accompanied with `input_conversion` and
`output_conversion` functions.
-#### Semantics
-
-Scatters the values from the `source` tensor using `scatter` based on the
-outcome of `reduce_window` of the `input` tensor using `select` and produces
-a `result` tensor.
-
-More formally:
-...
-
-* `result[result_index] = reduce([source_values], [init_value], [0],
- input_conversion, scatter, output_conversion)`
- where:
- ...
-
-#### Inputs
-
-| Label | Name | Type |
-|-------|---------------------|----------|
-| (I8) | `input_conversion` | function |
-| (I8) | `scatter` | function |
-| (I8) | `output_conversion` | function |
-
#### Constraints
* (C1) `element_type(operand) = element_type(source)`.
* (C3) `element_type(init_value) = element_type(operand)`.
* (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
-* (C?) `input_conversion` has type `tensor -> (tensor)` where
- `Ei = element_type(operand)`.
* (C10) `scatter` has type `(tensor, tensor) -> tensor` where
- `E = element_type(output_types(input_conversion))`.
-* (C?) `output_conversion` has type `tensor -> (tensor)` where
- `E'i = element_type(result)`.
-* (C?) `element_type(output_types(input_conversion)) =
- element_type(input_types(output_conversion))`.
-* (C11) `shape(operand) = shape(result)`.
+  `is_integer(element_type(operand)) = is_integer(E)` and
+  `is_float(element_type(operand)) = is_float(E)` and
+  `is_complex(element_type(operand)) = is_complex(E)` and
+  `is_quantized(element_type(operand)) = is_quantized(E)`.
-## [11 Aug'23] Revised proposal
+### Action Plan
+
+I propose the following action plan (order matters):
+
+* Update the specification of ReduceOp, ReduceWindowOp, and SelectAndScatterOp,
+  taking the accumulation type into account, via [open
+  PR](https://github.com/openxla/stablehlo/pull/1538).
+* Finalize the quantized specification of AllReduceOp, BatchNormTrainingOp,
+ BatchNormGradOp and ReduceScatterOp, whose semantics depend on ReduceOp,
+ via [open ticket](https://github.com/openxla/stablehlo/issues/1666).
+* Spec the behavior of `precision_config` in DotGeneralOp. [open
+  issue](https://github.com/openxla/stablehlo/issues/755)
+* Consider adding `precision_config` to reduction ops. `precision_config`,
+  currently used for `dot_general` and `convolution`, overrides the precision
+  specified by the input parameters, allowing the choice of low precision vs
+  high precision computation. We should consider adding `precision_config` to
+  all reduction-based ops as well. [need a ticket for this]
+* Consider adding `accumulation_type` to `dot_general`/`convolution`. The
+  attribute seems beneficial for ops like `dot_general` and `convolution`, which
+  do not have an explicit reduction function. [need a ticket for this item]
+
+## Summary of previous proposals
+
+For completeness, here are the proposals that were evaluated previously and
+helped shape the current proposal.
+
+### Re-scale input to accumulation type
+
+This option is the simplest from the POV of the specification of the quantized
+`reduce` op. It adds `stablehlo.uniform_quantize` ops before and after the
+reduce op, which operates on the "accumulator" type.
+
+```mlir
+%widen = "stablehlo.uniform_quantize"(%input)
+ : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform>
+
+%reduce = "stablehlo.reduce"(%widen) {
+  ^reduce_computation(%lhs: !quant.uniform, %rhs: !quant.uniform):
+ // reduce_computation_block
+ }
+ : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform>
+
+%narrowed = "stablehlo.uniform_quantize"(%reduce)
+ : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform>
+```
+
+#### Tradeoffs
+
+* (+) An advantage of this option is that we only need minor changes to the
+ specification (i.e. to allow quantized types).
+* (-) The compiler must pattern match 3 operations and map them into some
+ internal representation before their compilation or execution.
+* (-) The compiler must ensure that the `stablehlo.uniform_quantize` (or
+ `stablehlo.convert` in the case of `bf16` or `f16`) is not folded before the
+ backend matches the pattern.
+ [for more information](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906)
+
+This proposal should be avoided because it is hard to control the
+transformations that might disrupt the pattern to be matched.
+
+### Introduce on-the-fly type conversions
+
+This option proposes adding two regions to the reduce op to (1) convert the
+input type to the type of the `body` function argument and (2) convert the
+result type of the `body` function to the output type. Following is a code
+snippet with the proposed syntax of the reduce op:
+
+```mlir
+%result = "stablehlo.reduce"(%input, %init_value) ({
+ ^input_conversion(
+ %input: tensor>):
+ %input_rescaled = "stablehlo.uniform_quantize"(%input)
+ : (tensor>)
+ -> tensor>
+ "stablehlo.return"(%input_rescaled)
+ : (tensor>) -> ()
-### Context
+ }, {
+ ^reduce_computation(
+ %lhs: tensor>,
+ %rhs: tensor>):
+ %add = "stablehlo.add"(%lhs, %rhs)
+ : (tensor>,
+ tensor>)
+ -> tensor>
+ "stablehlo.return"(%add)
+ : (tensor>) -> ()
+ }, {
+ ^output_conversion(
+ %intermediate_result: tensor>):
+ %output_rescaled = "stablehlo.uniform_quantize"(%intermediate_result)
+ : (tensor>)
+ -> tensor>
+ "stablehlo.return"(%output_rescaled)
+ : (tensor>) -> ()
+ }) {
+ dimensions = dense<...> : tensor<1xi64>
+ } : (tensor<... x !quant.uniform>,
+ tensor<... x !quant.uniform>)
+ -> tensor<... x !quant.uniform>
+```
-Option #2 should be avoided because it is hard to control the transformation
-which might disrupt the pattern to be matched. The option #1 sounds good except
-that the extra input/output conversion blocks are surplus information. The
-specification would benefit if the intent of the conversion blocks can be
-expressed precisely. The conversion blocks provides a way to capture the
-accumulation type needed to compute the accumulative operation on.
+The following diagram informally describes the semantics of the additional
+functions `input_conversion` and `output_conversion`.
-The revised proposal is:
+```python
++----------+ +--------+ +--------+ +----------+ +--------+ +--------+
+|init_value| |input[0]| |input[1]| |init_value| |input[2]| |input[3]|
++----------+ +--------+ +--------+ +----------+ +--------+ +--------+
+ | | | | | |
++----------+ +--------+ +--------+ +----------+ +--------+ +--------+
+|input | |input | |input | |input | |input | |input |
+|convert | |convert | |convert | |convert | |convert | |convert |
++----------+ +--------+ +--------+ +----------+ +--------+ +--------+
+ \ / / \ / /
+ +-------+ / +-------+ /
+ |compute| / |compute| /
+ +-------+ / +-------+ /
+ \ / \ /
+ +-------+ +-------+
+ |compute| |compute|
+ +-------+ +-------+
+ \___________ ___________/
+ \ /
+ +-------+
+ |compute|
+ +-------+
+ |
+ +-------+
+ |output |
+ |convert|
+ +-------+
+```
-* To capture the accumulation type via an additional StableHLO attribute like
- `accumulation_element_type`.
-* The attribute seems beneficial for other ops as well like `dot_general` and
- `convolution`.
-* `precision_config`, currently used for `dot_general` and `convolution`, is
- used to override the precision specified by the input parameters, allowing the
- choice of low precision vs high precision computation. We should consider
- adding `precision_config` to all reduction based op as well.
+#### Tradeoffs
-### Few implementation details
+* (+) Enables programmers to program at (almost) bare metal. If the hardware
+  can support reduction computation in a wider type (e.g. in the SIMD
+  instruction set, we typically do widening/compute/narrowing within the
+  kernel to save memory bandwidth), the programmer can explicitly request
+  that.
+* (-) The disadvantage of this representation is that the syntax is more
+  verbose and requires significant changes to the specification.
+* (-) The extra input/output conversion blocks are surplus information. The
+  intent of the conversion blocks is to capture the accumulation type in which
+  the accumulative operation is computed. The specification would benefit if
+  this intent could be expressed succinctly.
-#### On StableHLO side
+### Introduce accumulation type attribute
-The reduce syntax to be augmented with a optional [type
-attribute](https://github.com/llvm/llvm-project/blob/51a57074bc63842970c4c160b05c1a7e42db7523/mlir/include/mlir/IR/OpBase.td#L1466)
-as follows:
+Instead of using additional input and output conversion blocks, use a type
+attribute `accumulation_type` to capture the accumulation type. As an example,
```mlir
%0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] {
@@ -558,48 +362,11 @@ as follows:
// OptionalAttr>:$accumulation_type
```
-Note that the main difference between this option and the option #1 is that the
-input and output conversion blocks are no longer used as their intent is
-specified via the `accumulation_type` attribute. However, the reducer block
-still needs to express the computation in accumulation type only.
-
-**Why optional attribute?**
-
-* At times, it might be desirable not to hard-code the accumulation type. For
- example, when we would like to write a generic code and let the downstream
- compilation tools to decide the exact accumulation type based on the hardware
- of choice.
-* It allows the stablehlo, used in various existing pipelines, to remain
- largely unaffected by this change.
-
-Next, the StableHLO specification should be updated with the syntax and
-semantics aspects of this attribute.
-
-#### On StableHLO Consumers side
-
-The consumers can pattern match the op taking the accumulation type in account
-if the targeted hardware supports accumulation at higher type.
-There are still to explore things about maintaining StableHLO-HLO parity which
-needs to be addresses as well.
-
-### Action Plan
-
-I propose to follow the action plan (order matters):
+Note that the main difference between this option and the previous option is
+that the input and output conversion blocks are no longer used and their intent
+is specified via the `accumulation_type` attribute. However, the reducer block
+needs to express the computation in the accumulation type only.
-* Update the specification of ReduceOp, ReduceWindowOp, and SelectAndScatterOp
- op, taking the accumulation type into account, via [open
- pr](https://github.com/openxla/stablehlo/pull/1538).
-* Finalize the quantized specification of AllReduceOp, BatchNormTrainingOp,
- BatchNormGradOp and ReduceScatterOp, whose semantics depend on ReduceOp,
- via [open ticket](https://github.com/openxla/stablehlo/issues/1666).
-* Add implementation for additional attribute in the above ops. This includes
-updating the tablegen spec/verifiers/type inferencers. [Need a ticket for this].
-* Address the disparity between StableHLO and HLO because of the introduction of
-this new attribute in StableHLO: Should/How XLA should consume this additional
-attribute? [Need a ticket for this].
-* Spec the behavior of `precision_config` in DotGeneralOp. [open
-issue](https://github.com/openxla/stablehlo/issues/755)
-* Consider adding `precision_config` in reduction op. [need a ticket for this
-* Consider adding `accumulation_type` to `dot_general`/`convolution op`.
-[need a ticket for this item].
-item].
+This option was discarded because, for the reduce op, the additional attribute
+seems redundant and can be inferred from the differences between the element
+types of the operand and the reduction block arguments (as described in the
+current proposal).