Commit: review iteration: various typo fixes

sdasgup3 committed Sep 11, 2023
1 parent 1d016b2 commit df49179
Showing 1 changed file with 48 additions and 48 deletions: rfcs/20230622-quantized-reduction.md
The RFC introduces the following proposal, which emerged out of discussion in
the , along with its tradeoffs.

The proposal allows the reducer block to express the computation in a different
element type (preferably a wider accumulation type) than the one used in the
reduce op's arguments and return type. For illustrative purposes, in the
following example, the operand element type
`tensor<!quant.uniform<ui8:f32, input_scale:input_zp>>` is different from the
element type of the reduction region's block arguments. Similarly, the element
type of the reduce op's result
`tensor<!quant.uniform<ui8:f32, output_scale:output_zp>>` is different from
that of the block return (`tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>`).

```mlir
%result = "stablehlo.reduce"(%input, %init_value) ({
  // (rest of the listing collapsed in the diff view)
```
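
Since the diff view collapses the rest of the listing, the following is a
minimal sketch consistent with the description above, not the RFC's exact
listing. The shapes, the `dimensions` attribute syntax (which varies across
StableHLO versions), and the symbolic quantization parameters
(`input_scale:input_zp`, etc.) are assumptions.

```mlir
// Sketch only: operand, init value, and result use ui8 quantized types while
// the region computes in an i32 "accumulation" quantized type; the boundary
// conversions are implicit under this proposal.
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>,
       %arg1: tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>):
    %0 = "stablehlo.add"(%arg0, %arg1)
        : (tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>,
           tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>)
        -> tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>
    "stablehlo.return"(%0)
        : (tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
    : (tensor<2x3x!quant.uniform<ui8:f32, input_scale:input_zp>>,
       tensor<!quant.uniform<ui8:f32, input_scale:input_zp>>)
    -> tensor<2x!quant.uniform<ui8:f32, output_scale:output_zp>>
```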

### Semantics

If (1) the input operand type is different from the reduction block
argument type or (2) the op result type is different from the reduction block
return type, an implicit type conversion applies, defined by either
`stablehlo.convert`, `stablehlo.uniform_quantize`, or
`stablehlo.uniform_dequantize`. For example:

| Implicit type conversion op        | element type of operand or block return | element type of block argument or op return |
|------------------------------------|------------------------------------------|----------------------------------------------|
| (A) `stablehlo.uniform_quantize`   | quantized tensor                         | quantized tensor                             |
| (B) `stablehlo.uniform_quantize`   | floating point                           | quantized tensor                             |
| (C) `stablehlo.uniform_dequantize` | quantized tensor                         | floating point                               |
| (D) `stablehlo.convert`            | floating-point                           | integer                                      |
| (E) `stablehlo.convert`            | integer                                  | floating-point                               |
| (F) `stablehlo.convert`            | floating-point                           | floating-point                               |
| (G) `stablehlo.convert`            | integer                                  | integer                                      |
| (H) `stablehlo.convert`            | complex                                  | complex                                      |
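
To make the table concrete, the following hypothetical expansion materializes
the case (A) conversions explicitly. The value names and quantization
parameters are placeholders, not part of the RFC.

```mlir
// Operand element type -> block argument type: requantize into the (wider)
// accumulation type before the reduction body runs.
%to_block = "stablehlo.uniform_quantize"(%operand_elem)
    : (tensor<!quant.uniform<ui8:f32, input_scale:input_zp>>)
    -> tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>
// Block return type -> op result type: requantize the accumulated value back
// down to the result's element type.
%to_result = "stablehlo.uniform_quantize"(%accum)
    : (tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>)
    -> tensor<!quant.uniform<ui8:f32, output_scale:output_zp>>
```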

At this point there is no use for cases other than (A), (F), and (G). My
proposal here would be to address (A), (F), and (G) only. Note that (F)
partially addresses
[Decide on mixed precision](https://github.com/openxla/stablehlo/issues/369)
for reduce op in that it allows the input or init value to differ from the
corresponding block arguments w.r.t. the precision of floating-point types.
However, the mixed precision implementation in HLO seems more detailed in
that it even allows `inputs` and `init_values` to differ in floating-point
precision. My proposal would be to treat the above ticket separately.
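
For case (F), the same mechanism would allow, for example, `f16` inputs with
an `f32` accumulator. A sketch under assumed shapes and attribute syntax; the
boundary conversions implied here are `stablehlo.convert`:

```mlir
// Sketch of case (F): f16 operand/init/result, f32 block arguments and block
// return; stablehlo.convert is implied at the region boundary.
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
    %0 = "stablehlo.add"(%arg0, %arg1)
        : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "stablehlo.return"(%0) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>}
    : (tensor<16xf16>, tensor<f16>) -> tensor<f16>
```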

## Appendix

reduce_implicit_convert(reduce(input_slices_converted),
* (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
* (C?) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
  `tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
  `is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or
  `is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or
  `is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or
  `is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i]))`.
* (C?) `shape(results...) = shape(inputs...)` except that the dimension
  sizes of `inputs...` corresponding to `dimensions` are not included (see
  the example after this list).
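
As an assumed illustration of the shape constraint (not taken from the RFC):
reducing over `dimensions = [1]` removes that dimension from the result shape.
The quantization parameters `s:z` are placeholders.

```mlir
// The reduced dimension is dropped; element types follow the constraints above.
// inputs[0]  : tensor<3x4x!quant.uniform<ui8:f32, s:z>>, dimensions = [1]
// results[0] : tensor<3x!quant.uniform<ui8:f32, s:z>>
```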

portions of the spec which need modification.
* (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
* (C?) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
  `tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
  `is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or
  `is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or
  `is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or
  `is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i]))`.

### Revised specification of select_and_scatter op

not need additional conversion functions associated with `select`. But the
* (C3) `element_type(init_value) = element_type(operand)`.
* (C?) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
* (C10) `scatter` has type `(tensor<E>, tensor<E>) -> tensor<E>` where
  `is_integer(element_type(operand)) = is_integer(element_type(E))` or
  `is_float(element_type(operand)) = is_float(element_type(E))` or
  `is_complex(element_type(operand)) = is_complex(element_type(E))` or
  `is_quantized(element_type(operand)) = is_quantized(element_type(E))`
  (a sketch follows below).
<!-- markdownlint-enable line-length -->
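
Under (C10), a quantized `select_and_scatter` could, for instance, run its
`scatter` region in a wider quantized accumulation type. A hypothetical region
sketch follows; the names and quantization parameters are assumptions.

```mlir
// Hypothetical scatter region with E = !quant.uniform<i32:f32, accum_scale:accum_zp>;
// both element_type(operand) and E are quantized, so (C10) holds.
^bb0(%arg0: tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>,
     %arg1: tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>):
  %0 = "stablehlo.add"(%arg0, %arg1)
      : (tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>,
         tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>)
      -> tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>
  "stablehlo.return"(%0)
      : (tensor<!quant.uniform<i32:f32, accum_scale:accum_zp>>) -> ()
```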

### Action Plan
I propose to follow the action plan (order matters):
op, taking the accumulation type into account, via an
[open PR](https://github.com/openxla/stablehlo/pull/1538).
* Finalize the quantized specification of AllReduceOp, BatchNormTrainingOp,
  BatchNormGradOp and ReduceScatterOp, whose semantics depend on ReduceOp,
  via [open ticket](https://github.com/openxla/stablehlo/issues/1666).
* Spec the behavior of `precision_config` in DotGeneralOp.
  [open issue](https://github.com/openxla/stablehlo/issues/755)
* Consider adding `precision_config` to reduction ops. `precision_config`,
  currently used for `dot_general` and `convolution`, overrides the precision
  specified by the input parameters, allowing the choice of low-precision vs.
  high-precision computation. We should consider adding `precision_config` to
  all reduction-based ops as well. [need a ticket for this]
* Consider adding `accumulation_type` to `dot_general`/`convolution op`. The
  attribute seems beneficial for ops like `dot_general` and `convolution`,
  which do not have an explicit reduction function; see the sketch after this
  list. [need a ticket for this item]

## Summary of previous proposals

Here we will informally propose the semantics of the additional functions
* (-) The disadvantage of this representation is that the syntax is more
verbose and requires significant changes to the specification.
* (-) The extra input/output conversion blocks are surplus information. The
  intent of the conversion blocks is to capture the accumulation type in which
  the accumulative operation is computed. The specification would benefit if
  that intent could be expressed succinctly.

### Introduce accumulation type attribute

