diff --git a/rfcs/20230622-quantized-reduction.md b/rfcs/20230622-quantized-reduction.md index 2152a5cc9ea..248d97c097a 100644 --- a/rfcs/20230622-quantized-reduction.md +++ b/rfcs/20230622-quantized-reduction.md @@ -2,8 +2,8 @@ Status: Review
Initial version: 06/22/2023
-updated: 07/13/2023
: Minor refactoring of the examples. -Last updated: 08/11/2023
: Revision of the proposal to introduce an +updated: 07/13/2023: Minor refactoring of the examples.
+Last updated: 08/11/2023: Revision of the proposal to introduce an attribute to capture accumulation type.
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1664) @@ -11,6 +11,9 @@ Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1664) * 06/22/2023: Initial version. * 07/13/2023: Fixed typo in code blocks, header indentation. +* 08/11/2023: Revision of the proposal to introduce an attribute to capture + accumulation type. +* 08/25/2023: The additional attribute is redundant. ## Introduction @@ -24,44 +27,32 @@ op, for non-quantized types, has constraints like which constrained the signature of reduce op and its associated reducer function `body` to have the same element types for `inputs`, `results` and arguments and -return for `body`. For reducer function performing an accumulative operation like -add, this means that the the result of accumulation can overflow in which case -the result will be implementation defined (e.g., -[saturated](https://en.wikipedia.org/wiki/Saturation_arithmetic) or -[wrap around](https://en.wikipedia.org/wiki/Integer_overflow)). -From the conversation with customers it seems a reasonable behavior for non -quantized data types. However, with quantized data types, such loss in precision -is not acceptable and hence the motivation is to perform the accumulation in -some higher data type. - -The RFC highlights some of the options emerged out of discussion in the +return for `body`. For a reducer function performing an accumulative operation +like add, this means that the result of the accumulation can overflow, in which +case the result will be implementation defined (e.g., + [saturated](https://en.wikipedia.org/wiki/Saturation_arithmetic) or + [wrap around](https://en.wikipedia.org/wiki/Integer_overflow)). From +the conversation with customers, this seems a reasonable behavior for non-quantized +data types. However, with quantized data types, such loss in precision is not +acceptable, and hence the motivation is to perform the accumulation in some +higher data type. + +The RFC introduces the following proposal, which emerged out of the discussion in the [thread](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906) -along with their tradeoffs. The proposal option #1 looks promising at this -point, but we are open to further discussion on this. - -## Option 1: Introduce additional conversion functions - -[The thread](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906) -discuses an option, proposed by @loganchien, on how to achieve the structural -changes as mentioned above. We note that some of the examples/diagrams presented -here are borrowed from an internal doc @loganchien authored. - -The proposed options introduces on-the-fly type conversions, which (1) convert -the input type to the type of the `body` function argument and (2) convert the -result type of the `body` function to the output type. Following is the code -snippet with the proposed syntax of reduce op: +, along with its tradeoffs. + +The proposal allows the reducer block to express the computation in a different +element type (preferably a higher accumulation type) than the one used in the reduce +op's arguments and return type. For illustrative purposes, in the following +example, the operand element type `tensor>` is different from the element type of the + reduction region's block arguments. Similarly, the element type of the + reduce op's result `!quant.uniform>` is + different from that of the block return (`tensor>`).
```mlir %result = "stablehlo.reduce"(%input, %init_value) ({ - ^input_conversion( - %input: tensor>): - %input_rescaled = "stablehlo.uniform_quantize"(%input) - : (tensor>) - -> tensor> - "stablehlo.return"(%input_rescaled) - : (tensor>) -> () - - }, { ^reduce_computation( %lhs: tensor>, %rhs: tensor>): @@ -71,310 +62,63 @@ snippet with the proposed syntax of reduce op: -> tensor> "stablehlo.return"(%add) : (tensor>) -> () - }, { - ^output_conversion( - %intermediate_result: tensor>): - %output_rescaled = "stablehlo.uniform_quantize"(%intermediate_result) - : (tensor>) - -> tensor> - "stablehlo.return"(%output_rescaled) - : (tensor>) -> () }) { - dimensions = dense<...> : tensor<1xi64> - } : (tensor<... x !quant.uniform>, - tensor<... x !quant.uniform>) - -> tensor<... x !quant.uniform> + dimensions = dense<1> : tensor + } : (tensor<5 x 1 x !quant.uniform>, + tensor>) + -> tensor<5 x !quant.uniform> ``` ### Semantics -Here we will informally propose the semantics of the additional functions -`input_conversion` and `output_conversion` introduced. - -```python -+----------+ +--------+ +--------+ +----------+ +--------+ +--------+ -|init_value| |input[0]| |input[1]| |init_value| |input[2]| |input[3]| -+----------+ +--------+ +--------+ +----------+ +--------+ +--------+ - | | | | | | -+----------+ +--------+ +--------+ +----------+ +--------+ +--------+ -|input | |input | |input | |input | |input | |input | -|convert | |convert | |convert | |convert | |convert | |convert | -+----------+ +--------+ +--------+ +----------+ +--------+ +--------+ - \ / / \ / / - +-------+ / +-------+ / - |compute| / |compute| / - +-------+ / +-------+ / - \ / \ / - +-------+ +-------+ - |compute| |compute| - +-------+ +-------+ - \___________ ___________/ - \ / - +-------+ - |compute| - +-------+ - | - +-------+ - |output | - |convert| - +-------+ -``` - -### Semantics of `input_conversion` block - -The `input_conversion` block is applied selectively to the leaf nodes of a -schedule tree as shown in above diagram. Note that the `input_conversion` cannot -be applied to the non-leaf nodes of the schedule tree. - -### Semantics of `output_conversion` block - -The `output_conversion` block is applied just after the `result` for a particular -index is computed as shown in the above diagram. - -Please refer to the [formal spec](#revised-specification-of-reduce-op) of the proposed -reduce op. - -### Implementation details - -From the implementation POV of the proposed spec, we note that -`input_conversion` or `output_conversion` can very well be optional with -default values as identity functions. For example, the following code snippet - -```mlir -%result = "stablehlo.reduce"(%input, %init_value) ({ - ^reduce_computation( - %lhs: tensor>, - %rhs: tensor>): - %add = "stablehlo.add"(%lhs, %rhs) - : (tensor>, - tensor>) - -> tensor> - "stablehlo.return"(%add) - : (tensor>) -> () - }) { - dimensions = dense<...> : tensor<1xi64> - } : (tensor<... x !quant.uniform>, - tensor<... x !quant.uniform>) - -> tensor<... 
x !quant.uniform> -``` - -should be interpreted as - -```mlir -%result = "stablehlo.reduce"(%input, %init_value) ({ - ^input_conversion( - %input: tensor>): - "stablehlo.return"(%input) - : (tensor>) -> () - - }, { - ^reduce_computation( - %lhs: tensor>, - %rhs: tensor>): - %add = "stablehlo.add"(%lhs, %rhs) - : (tensor>, - tensor>) - -> tensor> - "stablehlo.return"(%add) - : (tensor>) -> () - }, { - ^output_conversion( - %intermediate_result: tensor>): - "stablehlo.return"(%intermediate_result) - : (tensor>) -> () - }) { - dimensions = dense<...> : tensor<1xi64> - } : (tensor<... x !quant.uniform>, - tensor<... x !quant.uniform>) - -> tensor<... x !quant.uniform> -``` - -Note that with default values, the input/result type of `reduce` op matches -with the argument or the result type of the `reduce_computation`, including the -quantization parameters. - -Also, note that the relative order of `input_conversion` or `output_conversion` -w.r.t the `reduce_computation` can be used to identify the appropriate -conversion function when any one of `input_conversion` or `output_conversion` is -missing. - -The existing pretty printing is currently producing the following output -`stablehlo.reduce(%input init: %init_value) applies stablehlo.add across -dimensions = [1] : (tensor<1x6xi64>, tensor) -> tensor<1xi64>`. IMO, -modifying the above format, with the default conversion function, will create -clutter. My proposal here is to follow the existing pretty printing when the -conversion functions are "not provided". In the event, the conversion functions -are explicitly provided, then the pretty printers will fall back to default -generic printing, -**even if the explicitly provided conversion functions are identity function**: -To avoid identification of identity functions which could be tricky in general. - -### Tradeoffs - -* (+) Enables programmers to program at (almost) baremetal. If the hardware - can support reduction computation in wider type (e.g. in the SIMD - instruction set, we typically do widening/compute/narrowing within the - kernel to save the memory bandwidth), the programmer can explicitly request - for that. -* (-) The disadvantage of this representation is that the syntax is more - verbose and requires significant changes to the specification. - -## Option 2: re-scale input to accumulation type - -This option is the simplest from the POV for specification of quantized `reduce` -op. This is adding `stablehlo.uniform_quantize`ops before and after reduce op -which operates on the "accumulator" type. - -```mlir -%widen = "stablehlo.uniform_quantize"(%input) - : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform> - -%reduce = "stablehlo.reduce"(%widen) { - ^reduce_computation(%lhs: !quant.uniform, %rhs: !qunat.uniform): - // reduce_computation_block - } - : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform> - -%narrowed = "stablehlo.uniform_quantize"(%reduce) - : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform> -``` - -### Tradeoffs - -* (+) An advantage of this option is that we only need minor changes to the - specification (i.e. to allow quantized types). -* (-) The compiler must pattern match 3 operations and map them into some - internal representation before their compilation or execution. -* (-) The compiler must ensure that the `stablehlo.uniform_quantize` (or - `stablehlo.convert` in the case of `bf16` or `f16`) is not folded before the - backend matches the pattern. 
- [for more information](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906) - -## Option 3: allow accumulator type to be different from input type - -This is another option we considered which does not fly well because of limited -expressibility. Adding it just for completeness purposes. -The idea here is to convey the accumulator type using the `init_value` operand -of `reduce` op. The code snippet for `reduce` looks like: - -```mlir -%result = "stablehlo.reduce"(%input, %init_value) ({ - ^reduce_computation( - %elem: tensor>, - %acc: tensor>): - %elem_rescaled = "stablehlo.uniform_quantize"(%elem) - : (tensor>) - -> tensor> - %add = "stablehlo.add"(%elem_rescaled, %acc) - : (tensor>, - tensor>) - -> tensor> - "stablehlo.return"(%0) - : (tensor>) -> () - }) { - dimensions = dense<1> : tensor<1xi64> - } : (tensor<... x !quant.uniform>, - tensor<... x !quant.uniform>) - -> tensor<... x !quant.uniform> -``` - -In this option, the `init_value` type and the `result` type can be different -from the input type. The first argument of the compute block is fixed for the -traversed element and the second argument is fixed for the intermediate -(accumulation) result. - -### Tradeoffs - -* (+) Make the accumulation type explicit in the IR. -* (-) This representation imposes a limitation on the evaluation order. - Since we can’t express the computation between two intermediate (accumulation) - results, we can not arbitrarily insert `init_value` and start the - computation at an arbitrary location. The following shows the restricted - evaluation order with the method. - -```python -+----------+ +--------+ +--------+ +--------+ +--------+ -|init_value| |input[0]| |input[1]| |input[2]| |input[3]| -+----------+ +--------+ +--------+ +--------+ +--------+ - \ / / / / - +-------+ / / / - |compute| / / / - +-------+ / / / - \ / / / - +-------+ / / - |compute| / / - +-------+ / / - \ / / - +-------+ / - |compute| / - +-------+ / - \ / - +-------+ - |compute| - +-------+ -``` - -## Open Question - -### Should we restrict the proposal #1 to quantized types only? - -The above proposal #1 of introducing the additional functions is theoretically -not limited to quantized `reduce` op, but also can be applied to `reduce` op with -non-quantized types. For example, - -```mlir -%result = "stablehlo.reduce"(%input, %init_value) ({ - ^input_conversion(%arg0: tensor): - %0 = "stablehlo.convert"(%arg0): (tensor) -> (tensor) - "stablehlo.return"(%0) : (tensor) -> (tensor) - }, { - ^bb0(%arg0: tensor, %arg1: tensor): - %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> - tensor - "stablehlo.return"(%0) : (tensor) -> () - }, { - ^output_conversion(%arg0: tensor): - %0 = "stablehlo.convert"(%arg0): (tensor) -> (tensor) - "stablehlo.return"(%0) : (tensor) -> (tensor) - }) { - dimensions = dense<1> : tensor<1xbf16> -} : (tensor<1x6xbf16>, tensor) -> tensor<1xbf16> -``` - -However, it is not clear how such operations will be lowered to other IR -representations, like HLO, which does not support such additional computation -blocks. IMO there is no additional benefit to support such conversion -functions for regular type given that there already exists infrastructure -(backend support, lowering passes) to support regular types w/o conversion -functions. My proposal here would be to restrict the support to only quantized -types. 
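Before the detailed semantics below, here is an informal, non-normative sketch in plain Python of why the reducer block is allowed to work in a wider accumulation type. The quantization parameters (int8 storage with scale 0.5 and zero point 0, int32 accumulation with the same scale) are made up for this illustration and are not taken from the example above: accumulating directly in the narrow storage type loses precision, while accumulating in the wider type and rescaling once at the end does not.

```python
# Non-normative illustration with made-up parameters: int8 storage with
# scale 0.5 / zero point 0, and an int32 accumulation type with the same
# scale, so the implicit conversion to the accumulation type is a widening.

def quantize(x, scale, zero_point, qmin, qmax):
    """Round-to-nearest affine quantization, saturating to [qmin, qmax]."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

scale, zp = 0.5, 0
values = [50.0] * 10                                         # real values to be summed
q_in = [quantize(v, scale, zp, -128, 127) for v in values]   # each becomes 100

# Accumulating in the int8 storage type saturates almost immediately.
acc_i8 = 0
for q in q_in:
    acc_i8 = max(-128, min(127, acc_i8 + q))                 # saturating int8 add
print(dequantize(acc_i8, scale, zp))                         # 63.5, true sum is 500.0

# Accumulating in a wider (int32) type and requantizing once keeps the result.
acc_i32 = 0
for q in q_in:
    acc_i32 += q                                             # no overflow in int32
print(dequantize(acc_i32, scale, zp))                        # 500.0
```

The proposed reduce op syntax expresses exactly this situation: the operand keeps its narrow storage type, the reducer block (and therefore the accumulation) uses the wider quantized type, and the conversions at the boundary are left implicit.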
+Depending on whether (1) the input operand type differs from the reduction block +argument type or (2) the op result type differs from the reduction block +return type, there will be an implicit type conversion defined by either +`stablehlo.convert`, `stablehlo.uniform_quantize`, or +`stablehlo.uniform_dequantize`. For example, + + | Implicit type conversion op | element type of operand or result type | element type of block argument or block return type | + |------------------------------------|-----------------------------------------|------------------------------------------------------| + | (A) `stablehlo.uniform_quantize` | quantized tensor | quantized tensor | + | (B) `stablehlo.uniform_quantize` | floating-point | quantized tensor | + | (C) `stablehlo.uniform_dequantize` | quantized tensor | floating-point | + | (D) `stablehlo.convert` | floating-point | integer | + | (E) `stablehlo.convert` | integer | floating-point | + | (F) `stablehlo.convert` | floating-point | floating-point | + | (G) `stablehlo.convert` | integer | integer | + | (H) `stablehlo.convert` | complex | complex | + +At this point there is no use for cases other than (A), (F), and (G). My +proposal here would be to address (A), (F), and (G) only. Note that (F) + partially addresses [Decide on mixed + precision](https://github.com/openxla/stablehlo/issues/369) for reduce op in + that it allows the input or init value to differ from the corresponding + block arguments w.r.t. the precision of floating-point types. However, the + mixed precision implementation in HLO seems more general in the sense that + it even allows `inputs` and `init_values` to differ in floating-point + precision. My proposal would be to treat the above ticket separately. ## Appendix -To provide an estimate of specification changes needed to implement option #1 -I have attempted to provide the blueprint here. +To provide an estimate of the specification changes needed to implement the +proposal, I have attempted to provide the blueprint here. ### Revised specification of reduce op -#### Semantics +Here we include only the relevant portions of the spec with the proposed update. -Applies a reduction functions `input_conversion`, `body`, and -`output_conversion` to `inputs` and `init_values` along the `dimensions` and -produces `results` tensors. +#### Semantics -The order of reductions is implementation-defined, which means that `body` and -`init_values` must form a monoid to guarantee that the operation produces the -same results for all inputs on all implementations. However, this condition -doesn't hold for many popular reductions. E.g. floating-point addition for -`body` and zero for `init_values` don't actually form a monoid because -floating-point addition is not associative. +... More formally, `results...[j0, ..., jR-1] = -map(output_conversion, reduce(input_slices_converted))` where: +reduce_implicit_convert(reduce(input_slices_converted), + type(func_outputs(body)...), type(results...))` where: * `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted at `dimensions`. -* `input_slices_converted = map(input_conversion, input_slices...)`. +* `input_slices_converted = reduce_implicit_convert(input_slices..., + type(inputs...), type(func_inputs(body)...))`. * `reduce(input_slices_converted) = exec(schedule)` for some binary tree `schedule` where: * `exec(node) = body(exec(node.left), exec(node.right))`.
@@ -384,89 +128,52 @@ map(output_conversion, reduce(input_slices_converted))` where: * `input_slices_converted...[index]` values, for all `index` in `index_space(input_slices_converted)` in the ascending lexicographic order of `index`. - * Interspersed with an implementation-defined amount of `init_values` + * Interspersed with an implementation-defined amount of + `reduce_implicit_convert(init_values..., type(init_values...), type(func_inputs(body)[:len(func_inputs(body))//2])...)` at implementation-defined positions. -#### Inputs - -| Label | Name | Type | Constraints | -|-------|---------------------|----------------------------------------------|-------------| -| (I?) | `inputs` | variadic number of tensors | | -| (I?) | `init_values` | variadic number of 0-dimensional tensors | | -| (I?) | `dimensions` | 1-dimensional tensor constant of type `si64` | | -| (I?) | `input_conversion` | function | | -| (I?) | `body` | function | | -| (I?) | `output_conversion` | function | | - -#### Outputs - -| Name | Type | Constraints | -|-----------|----------------------------|-------------| -| `results` | variadic number of tensors | | - #### Constraints * (C?) `same(shape(inputs...))`. * (C?) `element_type(inputs...) = element_type(init_values...)`. * (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`. -* (C?) `0 < size(inputs) = size(init_values) = size(results) = N`. -* (C?) `0 <= dimensions < rank(inputs[0])`. -* (C?) `is_unique(dimensions)`. -* (C?) `input_conversion` has type `tensor, ..., tensor -> - (tensor, ..., tensor)` where `Ei = element_type(inputs[i])`. * (C?) `body` has type `tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `Ei = element_type(output_types(input_conversion)[i])`. -* (C?) `output_conversion` has type `tensor, ..., tensor -> - (tensor, ..., tensor)` where - `E'i = element_type(results[i])`. -* (C?) `element_type(output_types(input_conversion)...) = - element_type(input_types(output_conversion)...)`. + `is_integer(element_type(inputs[i])) = is_integer(element_type(Ei))` or + `is_float(element_type(inputs[i])) = is_float(element_type(Ei))` or + `is_complex(element_type(inputs[i])) = is_complex(element_type(Ei))` or + `is_quantized(element_type(inputs[i])) = is_quantized(element_type(Ei))`. * (C?) `shape(results...) = shape(inputs...)` except that the dimension sizes of `inputs...` corresponding to `dimensions` are not included. +`reduce_implicit_convert` is defined as + +```python +def reduce_implicit_convert(x: Value, source_type: Type, destination_type: + Type): + if source_type == destination_type: + return x + if is_quantized(source_type) and is_quantized(destination_type): + return quantize(x, destination_type) + return convert(x, destination_type) +``` + The above specification of `reduce` op can be used to define the specification -of other ops as shown below. For brevity, we are only presenting the relevant +of other ops as shown below. As before, we are only presenting the relevant portions of the spec which needs modification. ### Revised specification of reduce_window op -#### Semantics - -Applies a reduction functions `input_conversion`, `body`, and -`output_conversion` to windows of `inputs` and `init_values` and produces -`results`. - -... - -More formally, -`results...[result_index] = reduce(windows, init_values, axes(inputs...), - input_conversion, body, output_conversion)` -where: -.... - -#### Inputs - -| Label | Name | Type | -|-------|---------------------|----------| -| (I?)
| `input_conversion` | function | -| (I8) | `body` | function | -| (I?) | `output_conversion` | function | - #### Constraints * (C?) `element_type(inputs...) = element_type(init_values...)`. * (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`. -* (C?) `input_conversion` has type `tensor, ..., tensor -> - (tensor, ..., tensor)` where `Ei = element_type(inputs[i])`. * (C?) `body` has type `tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `Ei = element_type(output_types(input_conversion)[i])`. -* (C?) `output_conversion` has type `tensor, ..., tensor -> - (tensor, ..., tensor)` where - `E'i = element_type(results[i])`. -* (C?) `element_type(output_types(input_conversion)...) = - element_type(input_types(output_conversion)...)`. + `is_integer(element_type(inputs[i])) = is_integer(element_type(Ei))` or + `is_float(element_type(inputs[i])) = is_float(element_type(Ei))` or + `is_complex(element_type(inputs[i])) = is_complex(element_type(Ei))` or + `is_quantized(element_type(inputs[i])) = is_quantized(element_type(Ei))`. ### Revised specification of select_and_scatter op @@ -476,74 +183,171 @@ not need additional conversion functions associated with `select`. But the `scatter` function needs be accompanied with `input_conversion` and `output_conversion` functions. -#### Semantics - -Scatters the values from the `source` tensor using `scatter` based on the -outcome of `reduce_window` of the `input` tensor using `select` and produces -a `result` tensor. - -More formally: -... - -* `result[result_index] = reduce([source_values], [init_value], [0], - input_conversion, scatter, output_conversion)` - where: - ... - -#### Inputs - -| Label | Name | Type | -|-------|---------------------|----------| -| (I8) | `input_conversion` | function | -| (I8) | `scatter` | function | -| (I8) | `output_conversion` | function | - #### Constraints * (C1) `element_type(operand) = element_type(source)`. * (C3) `element_type(init_value) = element_type(operand)`. * (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`. -* (C?) `input_conversion` has type `tensor -> (tensor)` where - `Ei = element_type(operand)`. * (C10) `scatter` has type `(tensor, tensor) -> tensor` where - `E = element_type(output_types(input_conversion))`. -* (C?) `output_conversion` has type `tensor -> (tensor)` where - `E'i = element_type(result)`. -* (C?) `element_type(output_types(input_conversion)) = - element_type(input_types(output_conversion))`. -* (C11) `shape(operand) = shape(result)`. + `is_integer(element_type(operand)) = is_integer(element_type(E))` or + `is_float(element_type(operand)) = is_float(element_type(E))` or + `is_complex(element_type(operand)) = is_complex(element_type(E))` or + `is_quantized(element_type(operand)) = is_quantized(element_type(E))`. -## [11 Aug'23] Revised proposal +### Action Plan + +I propose to follow the action plan (order matters): + +* Update the specification of ReduceOp, ReduceWindowOp, and SelectAndScatterOp, + taking the accumulation type into account, via [open + pr](https://github.com/openxla/stablehlo/pull/1538). +* Finalize the quantized specification of AllReduceOp, BatchNormTrainingOp, + BatchNormGradOp and ReduceScatterOp, whose semantics depend on ReduceOp, + via [open ticket](https://github.com/openxla/stablehlo/issues/1666). +* Spec the behavior of `precision_config` in DotGeneralOp. [open +issue](https://github.com/openxla/stablehlo/issues/755) +* Consider adding `precision_config` in reduction ops.
`precision_config`, +currently used for `dot_general` and `convolution`, overrides the precision +specified by the input parameters, allowing the choice of low-precision vs high-precision +computation. We should consider adding `precision_config` to all +reduction-based ops as well. [need a ticket for this] +* Consider adding `accumulation_type` to the `dot_general`/`convolution` ops. The +attribute seems beneficial for ops like `dot_general` and `convolution` which +do not have an explicit reduction function. [need a ticket for this item]. + +## Summary of previous proposals + +For completeness of the presentation, let me provide the proposals which were +evaluated previously and helped shape the current proposal. + +### Re-scale input to accumulation type + +This option is the simplest from the POV of the specification of quantized `reduce` +op. This is adding `stablehlo.uniform_quantize` ops before and after the reduce op, +which operates on the "accumulator" type. + +```mlir +%widen = "stablehlo.uniform_quantize"(%input) + : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform> + +%reduce = "stablehlo.reduce"(%widen) { + ^reduce_computation(%lhs: !quant.uniform, %rhs: !quant.uniform): + // reduce_computation_block + } + : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform> + +%narrowed = "stablehlo.uniform_quantize"(%reduce) + : (tensor<... x !quant.uniform>) -> tensor<... x !quant.uniform> +``` + +#### Tradeoffs + +* (+) An advantage of this option is that we only need minor changes to the + specification (i.e. to allow quantized types). +* (-) The compiler must pattern match 3 operations and map them into some + internal representation before their compilation or execution. +* (-) The compiler must ensure that the `stablehlo.uniform_quantize` (or + `stablehlo.convert` in the case of `bf16` or `f16`) is not folded before the + backend matches the pattern. + [for more information](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906) + +This proposal should be avoided because it is hard to control the transformations +which might disrupt the pattern to be matched. + +### Introduce on-the-fly type conversions + +Proposes adding two regions to the reduce op to (1) convert the input type to the +type of the `body` function argument and (2) convert the result type of the +`body` function to the output type. Following is the code snippet with the +proposed syntax of reduce op: + +```mlir +%result = "stablehlo.reduce"(%input, %init_value) ({ + ^input_conversion( + %input: tensor>): + %input_rescaled = "stablehlo.uniform_quantize"(%input) + : (tensor>) + -> tensor> + "stablehlo.return"(%input_rescaled) + : (tensor>) -> () -### Context + }, { + ^reduce_computation( + %lhs: tensor>, + %rhs: tensor>): + %add = "stablehlo.add"(%lhs, %rhs) + : (tensor>, + tensor>) + -> tensor> + "stablehlo.return"(%add) + : (tensor>) -> () + }, { + ^output_conversion( + %intermediate_result: tensor>): + %output_rescaled = "stablehlo.uniform_quantize"(%intermediate_result) + : (tensor>) + -> tensor> + "stablehlo.return"(%output_rescaled) + : (tensor>) -> () + }) { + dimensions = dense<...> : tensor<1xi64> + } : (tensor<... x !quant.uniform>, + tensor<... x !quant.uniform>) + -> tensor<... x !quant.uniform> +``` -Option #2 should be avoided because it is hard to control the transformation -which might disrupt the pattern to be matched. The option #1 sounds good except -that the extra input/output conversion blocks are surplus information.
The -specification would benefit if the intent of the conversion blocks can be -expressed precisely. The conversion blocks provides a way to capture the -accumulation type needed to compute the accumulative operation on. +Here we informally describe the semantics of the additional +`input_conversion` and `output_conversion` functions. -The revised proposal is: +```python ++----------+ +--------+ +--------+ +----------+ +--------+ +--------+ +|init_value| |input[0]| |input[1]| |init_value| |input[2]| |input[3]| ++----------+ +--------+ +--------+ +----------+ +--------+ +--------+ + | | | | | | ++----------+ +--------+ +--------+ +----------+ +--------+ +--------+ +|input | |input | |input | |input | |input | |input | +|convert | |convert | |convert | |convert | |convert | |convert | ++----------+ +--------+ +--------+ +----------+ +--------+ +--------+ + \ / / \ / / + +-------+ / +-------+ / + |compute| / |compute| / + +-------+ / +-------+ / + \ / \ / + +-------+ +-------+ + |compute| |compute| + +-------+ +-------+ + \___________ ___________/ + \ / + +-------+ + |compute| + +-------+ + | + +-------+ + |output | + |convert| + +-------+ +``` -* To capture the accumulation type via an additional StableHLO attribute like - `accumulation_element_type`. -* The attribute seems beneficial for other ops as well like `dot_general` and - `convolution`. -* `precision_config`, currently used for `dot_general` and `convolution`, is - used to override the precision specified by the input parameters, allowing the - choice of low precision vs high precision computation. We should consider - adding `precision_config` to all reduction based op as well. +#### Tradeoffs -### Few implementation details +* (+) Enables programmers to program at (almost) bare metal. If the hardware + can support reduction computation in a wider type (e.g. in the SIMD + instruction set, we typically do widening/compute/narrowing within the + kernel to save the memory bandwidth), the programmer can explicitly request + for that. +* (-) The disadvantage of this representation is that the syntax is more + verbose and requires significant changes to the specification. +* (-) The extra input/output conversion blocks are surplus information. Their +intent is to capture the accumulation type in which the accumulative operation +is computed. The specification would benefit if that +intent could be expressed succinctly. -#### On StableHLO side +### Introduce accumulation type attribute -The reduce syntax to be augmented with a optional [type -attribute](https://github.com/llvm/llvm-project/blob/51a57074bc63842970c4c160b05c1a7e42db7523/mlir/include/mlir/IR/OpBase.td#L1466) -as follows: +Instead of using additional input and output conversion blocks, use a type +attribute `accumulation_type` to capture the accumulation type. As an example, ```mlir %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] { @@ -558,48 +362,11 @@ as follows: // OptionalAttr>:$accumulation_type ``` -Note that the main difference between this option and the option #1 is that the -input and output conversion blocks are no longer used as their intent is -specified via the `accumulation_type` attribute. However, the reducer block -still needs to express the computation in accumulation type only. - -**Why optional attribute?** - -* At times, it might be desirable not to hard-code the accumulation type.
For - example, when we would like to write a generic code and let the downstream - compilation tools to decide the exact accumulation type based on the hardware - of choice. -* It allows the stablehlo, used in various existing pipelines, to remain - largely unaffected by this change. - -Next, the StableHLO specification should be updated with the syntax and -semantics aspects of this attribute. - -#### On StableHLO Consumers side - -The consumers can pattern match the op taking the accumulation type in account -if the targeted hardware supports accumulation at higher type. -There are still to explore things about maintaining StableHLO-HLO parity which -needs to be addresses as well. - -### Action Plan - -I propose to follow the action plan (order matters): +Note that the main difference between this option and the previous option is +that the input and output conversion blocks are no longer used and their intent +is specified via the `accumulation_type` attribute. However, the reducer block +still needs to express the computation in the accumulation type only. -* Update the specification of ReduceOp, ReduceWindowOp, and SelectAndScatterOp - op, taking the accumulation type into account, via [open - pr](https://github.com/openxla/stablehlo/pull/1538). -* Finalize the quantized specification of AllReduceOp, BatchNormTrainingOp, - BatchNormGradOp and ReduceScatterOp, whose semantics depend on ReduceOp, - via [open ticket](https://github.com/openxla/stablehlo/issues/1666). -* Add implementation for additional attribute in the above ops. This includes -updating the tablegen spec/verifiers/type inferencers. [Need a ticket for this]. -* Address the disparity between StableHLO and HLO because of the introduction of -this new attribute in StableHLO: Should/How XLA should consume this additional -attribute? [Need a ticket for this]. -* Spec the behavior of `precision_config` in DotGeneralOp. [open -issue](https://github.com/openxla/stablehlo/issues/755) -* Consider adding `precision_config` in reduction op. [need a ticket for this -* Consider adding `accumulation_type` to `dot_general`/`convolution op`. -[need a ticket for this item]. -item]. +This option is discarded because, for the reduce op, the additional attribute is +redundant: the accumulation type can be inferred from the difference in element type +between the operand and the reduction block arguments (as described in the current +proposal).
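To make the redundancy argument concrete, the following sketch shows how the implicit conversion op of the current proposal is fully determined by the kinds of element types on the two sides of the reduce op boundary, mirroring cases (A), (F), and (G) of the table above. This is an illustration only; the helper name and its string-based type kinds are hypothetical and not a StableHLO API.

```python
# Illustration only: the kind of the operand/result element type together with
# the kind of the block argument/return element type determines the implicit
# conversion op, so no separate accumulation_type attribute is needed.

def implicit_conversion_op(outer_kind: str, block_kind: str) -> str:
    """outer_kind / block_kind are one of: 'quantized', 'float', 'int'."""
    if outer_kind == "quantized" and block_kind == "quantized":
        return "stablehlo.uniform_quantize"   # case (A)
    if outer_kind == block_kind and outer_kind in ("float", "int"):
        return "stablehlo.convert"            # cases (F) and (G)
    raise ValueError("combination not covered by the current proposal")

print(implicit_conversion_op("quantized", "quantized"))  # stablehlo.uniform_quantize
print(implicit_conversion_op("float", "float"))          # stablehlo.convert
```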