Skip to content

Commit

Permalink
Address feedback: fix header indentation and code block language
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Aug 11, 2023
1 parent f319357 commit 543639a
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions rfcs/20230622-quantized-reduction.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

Status: Review<br/>
Initial version: 06/22/2023<br/>
Last updated: 07/02/2023<br/>
Last updated: 07/13/2023<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1664)

## Version log

* 06/22/2023: Initial version.
* 07/13/2023: Fixed typo in code blocks, header indentation.

## Introduction

Expand Down Expand Up @@ -128,7 +129,7 @@ be applied to the non-leaf nodes of the schedule tree.
The `output_conversion` block is applied just after the `result` for a particular
index is computed as shown in the above diagram.

Please refer to the [formal spec](#specification-of-reduce-op) of the proposed
Please refer to the [formal spec](#revised-specification-of-reduce-op) of the proposed
reduce op.

### Implementation details
Expand Down Expand Up @@ -206,7 +207,7 @@ generic printing,
**even if the explicitly provided conversion functions are identity function**:
To avoid identification of identity functions which could be tricky in general.

#### Tradeoffs
### Tradeoffs

* (+) Enables programmers to program at (almost) baremetal. If the hardware
can support reduction computation in wider type (e.g. in the SIMD
Expand Down Expand Up @@ -248,13 +249,10 @@ type.
backend matches the pattern.
[for more information](https://github.com/openxla/stablehlo/pull/1538#issuecomment-1599476906)

## Other options considered

There is another option considered which did not fly well because of limited
extensibility. Adding it just for completeness purposes.

### Option 3: allow accumulator type to be different from input type
## Option 3: allow accumulator type to be different from input type

This is another option we considered which does not fly well because of limited
expressibility. Adding it just for completeness purposes.
The idea here is to convey the accumulator type using the `init_value` operand
of `reduce` op. The code snippet for `reduce` looks like:

Expand Down Expand Up @@ -284,7 +282,7 @@ from the input type. The first argument of the compute block is fixed for the
traversed element and the second argument is fixed for the intermediate
(accumulation) result.

#### Tradeoffs
### Tradeoffs

* (+) Make the accumulation type explicit in the IR.
* (-) This representation imposes a limitation on the evaluation order.
Expand Down Expand Up @@ -355,9 +353,8 @@ types.
To provide an estimate of specification changes needed to implement option #1
I have attempted to provide the blueprint here.

### Specification of reduce op
### Revised specification of reduce op

```python
#### Semantics

Applies a reduction functions `input_conversion`, `body`, and
Expand Down Expand Up @@ -426,15 +423,13 @@ map(output_conversion, reduce(input_slices_converted))` where:
element_type(input_types(output_conversion)...)`.
* (C?) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
```

The above specification of `reduce` op can be used to define the specification
of other ops as shown below. For brevity, we are only presenting the relevant
portions of the spec which needs modification.

### reduce_window
### Revised specification of reduce_window op

```python
#### Semantics

Applies a reduction functions `input_conversion`, `body`, and
Expand Down Expand Up @@ -471,17 +466,15 @@ where:
`E'i = element_type(results[i])`.
* (C?) `element_type(output_types(input_conversion)...) =
element_type(input_types(output_conversion)...)`.
```

### select_and_scatter
### Revised specification of select_and_scatter op

This op originally takes two function arguments `select` and `scatter`. As the
`select` function is supposed to perform a non-accumulative operation, we may
not need additional conversion functions associated with `select`. But the
`scatter` function needs be accompanied with `input_conversion` and
`output_conversion` functions.

```python
#### Semantics

Scatters the values from the `source` tensor using `scatter` based on the
Expand All @@ -490,6 +483,7 @@ a `result` tensor.

More formally:
...

* `result[result_index] = reduce([source_values], [init_value], [0],
input_conversion, scatter, output_conversion)`
where:
Expand Down Expand Up @@ -519,4 +513,3 @@ More formally:
element_type(input_types(output_conversion))`.
* (C11) `shape(operand) = shape(result)`.
<!-- markdownlint-enable line-length -->
```

0 comments on commit 543639a

Please sign in to comment.