Status: Review<br/>
Initial version: 06/22/2023<br/>
Updated: 07/13/2023: Minor refactoring of the examples.<br/>
Last updated: 08/11/2023: Revision of the proposal to introduce an
attribute to capture the accumulation type.<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1664)


## [11 Aug'23] Revised proposal

### Context

Option #2 should be avoided because the transformation is hard to control and
might disrupt the patterns that consumers match on. Option #1 sounds good,
except that the extra input/output conversion blocks carry surplus information.
The specification would benefit if the intent of the conversion blocks could be
expressed precisely: the conversion blocks are a way to capture the
accumulation type in which the reduction computation is carried out.

The revised proposal is:

* To capture the accumulation type via an additional StableHLO attribute like
`accumulation_element_type`.
* The attribute seems beneficial for other ops as well, such as `dot_general`
and `convolution`.
* `precision_config`, currently used for `dot_general` and `convolution`,
overrides the precision specified by the input parameters, allowing the choice
between low-precision and high-precision computation. We should consider adding
`precision_config` to all reduction-based ops as well (see the sketch after
this list).
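
For illustration, here is a purely hypothetical sketch of what
`precision_config` on a reduction op could look like, mirroring how the
attribute is attached to `dot_general` today; neither the placement nor the
exact syntax is settled by this proposal:

```mlir
// Hypothetical syntax only: `precision_config` on a reduction, overriding the
// operand precision the same way it does for `dot_general`/`convolution`.
%0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] {
    precision_config = [#stablehlo<precision HIGHEST>]
  } : (tensor<16xf32>, tensor<f32>) -> tensor<f32>
  reducer(%arg2: tensor<f32>, %arg3: tensor<f32>) {
    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
    stablehlo.return %1 : tensor<f32>
  }
```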

### A few implementation details

#### On the StableHLO side

The `reduce` syntax is to be augmented with an optional [type
attribute](https://github.com/llvm/llvm-project/blob/51a57074bc63842970c4c160b05c1a7e42db7523/mlir/include/mlir/IR/OpBase.td#L1466)
as follows:

```mlir
%0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] {
    accumulation_type = tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  } : (tensor<16x!quant.uniform<i8:f32, 3.400000e+01:16>>, tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
  reducer(%arg2: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>, %arg3: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>) {
    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
    stablehlo.return %1 : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  }
// using tablegen specification like
// OptionalAttr<TypeAttrOf<HLO_Tensor>>:$accumulation_type
```
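
For reference, a sketch of the same op in MLIR's generic form, where
`accumulation_type` is just an ordinary entry in the attribute dictionary
(illustrative; the exact spelling follows the tablegen comment above):

```mlir
%0 = "stablehlo.reduce"(%arg0, %arg1) ({
^bb0(%arg2: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>,
     %arg3: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>):
  %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  stablehlo.return %1 : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
}) {
  // The proposed optional attribute; when absent, no override is requested.
  accumulation_type = tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>,
  dimensions = dense<0> : tensor<1xi64>
} : (tensor<16x!quant.uniform<i8:f32, 3.400000e+01:16>>, tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
```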

Note that the main difference between this option and option #1 is that the
input and output conversion blocks are no longer used: their intent is
conveyed by the `accumulation_type` attribute instead. However, the reducer
block still needs to express the computation in the accumulation type.
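
Intuitively, the attribute stands in for the conversions that option #1 spelled
out explicitly. Below is a hedged sketch of the implied element-wise step, made
explicit as a standalone helper; `stablehlo.uniform_quantize` is used for the
requantization, and the function name and shapes are illustrative, not
normative:

```mlir
// One implied reduction step: requantize an i8-quantized element to the
// i32-quantized accumulation type, then accumulate in that type. The final
// result is analogously requantized back to the result element type.
func.func @implied_step(
    %elem: tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>,
    %acc: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>)
    -> tensor<!quant.uniform<i32:f32, 3.400000e+01:16>> {
  // Implied input conversion.
  %wide = stablehlo.uniform_quantize %elem : (tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  // Reducer body, entirely in the accumulation type.
  %sum = stablehlo.add %wide, %acc : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  return %sum : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
}
```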

**Why an optional attribute?**

* At times it might be desirable not to hard-code the accumulation type, for
example when writing generic code and letting the downstream compilation tools
decide the exact accumulation type based on the hardware of choice (see the
sketch after this list).
* It allows StableHLO, which is used in various existing pipelines, to remain
largely unaffected by this change.
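
As a sketch of the first point, here is the same reduction with the attribute
omitted; the reducer is then written in the operand's element type, and the
accumulation width is left to downstream tools:

```mlir
%0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<16x!quant.uniform<i8:f32, 3.400000e+01:16>>, tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
  reducer(%arg2: tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>, %arg3: tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) {
    // No `accumulation_type`: the reducer stays in the operand's element type;
    // a backend may still choose a wider accumulator internally.
    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
    stablehlo.return %1 : tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
  }
```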

Next, the StableHLO specification should be updated with the syntax and
semantics of this attribute.

#### On the StableHLO consumers' side

Consumers can pattern-match the op, taking the accumulation type into account,
if the targeted hardware supports accumulation at a higher precision type.
There are still open questions about maintaining StableHLO-HLO parity that
need to be addressed as well.
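
For instance, a consumer whose hardware cannot accumulate at the wider type
might legalize the attribute away by materializing the conversions. The rewrite
below is a hedged sketch of one possible lowering, not a prescribed one:

```mlir
// Step 1: requantize the operand and the init value to the accumulation type.
%wide_in = stablehlo.uniform_quantize %arg0 : (tensor<16x!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<16x!quant.uniform<i32:f32, 3.400000e+01:16>>
%wide_init = stablehlo.uniform_quantize %arg1 : (tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
// Step 2: reduce entirely in the accumulation type; the attribute is gone.
%wide_out = stablehlo.reduce(%wide_in init: %wide_init) across dimensions = [0] : (tensor<16x!quant.uniform<i32:f32, 3.400000e+01:16>>, tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  reducer(%arg2: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>, %arg3: tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>) {
    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
    stablehlo.return %1 : tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>
  }
// Step 3: requantize the result back to the original element type.
%res = stablehlo.uniform_quantize %wide_out : (tensor<!quant.uniform<i32:f32, 3.400000e+01:16>>) -> tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
```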

### Action Plan

I propose the following action plan (order matters):

* Update the specification of ReduceOp, ReduceWindowOp, and SelectAndScatterOp,
taking the accumulation type into account, via the [open
PR](https://github.com/openxla/stablehlo/pull/1538).
* Finalize the quantized specification of AllReduceOp, BatchNormTrainingOp,
BatchNormGradOp, and ReduceScatterOp, whose semantics depend on ReduceOp,
via the [open ticket](https://github.com/openxla/stablehlo/issues/1666).
* Add the implementation of the additional attribute in the above ops. This
includes updating the tablegen spec, verifiers, and type inference. [Need a
ticket for this item.]
* Address the disparity between StableHLO and HLO introduced by this new
attribute in StableHLO: should XLA consume this additional attribute, and if
so, how? [Need a ticket for this item.]
* Spec the behavior of `precision_config` in DotGeneralOp. [Open
issue](https://github.com/openxla/stablehlo/issues/755).
* Consider adding `precision_config` to reduction ops. [Need a ticket for this
item.]
* Consider adding `accumulation_type` to `dot_general`/`convolution`. [Need a
ticket for this item.]
