Iter6: Address feedback comments
sdasgup3 committed Nov 16, 2023
1 parent 1715790 commit a8b8a99
Showing 1 changed file with 31 additions and 36 deletions.
67 changes: 31 additions & 36 deletions docs/spec.md
@@ -778,13 +778,13 @@ defined as follows:

Afterwards, within each `process_group`:

- * `result@process[result_index] = convert_or_quantize_or_dequantize(exec(
-   schedule), type(result))` for some binary tree `schedule` where:
+ * `result@process[result_index] = to_destination_type(exec(schedule),
+   type(result))` for some binary tree `schedule` where:
* `exec(node)` = `computation(exec(node.left), exec(node.right))`.
* `exec(leaf)` = `leaf.value`.
* `schedule` is an implementation-defined binary tree whose in-order
-   traversal is `convert_or_quantize_or_dequantize(
-   operands@process_group...[result_index], type(func_inputs(computation)[0]))`.
+   traversal is `to_destination_type(operands@process_group...[result_index],
+   type(func_inputs(computation)[0]))`.
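The bullets above define the per-process result purely through a binary tree `schedule`. Below is a minimal sketch of that evaluation, outside the spec itself: `Leaf`, `Node`, and the addition stand-in for `computation` are illustrative assumptions, not StableHLO definitions.

```python
from dataclasses import dataclass
from typing import Callable, Union

@dataclass
class Leaf:
    value: float        # leaf.value: an already-converted operand element

@dataclass
class Node:
    left: "Tree"
    right: "Tree"

Tree = Union[Leaf, Node]

def exec_schedule(node: Tree, computation: Callable[[float, float], float]) -> float:
    # exec(leaf) = leaf.value
    # exec(node) = computation(exec(node.left), exec(node.right))
    if isinstance(node, Leaf):
        return node.value
    return computation(exec_schedule(node.left, computation),
                       exec_schedule(node.right, computation))

# One possible schedule whose in-order leaf traversal is [1.0, 2.0, 3.0].
schedule = Node(Leaf(1.0), Node(Leaf(2.0), Leaf(3.0)))
print(exec_schedule(schedule, lambda a, b: a + b))  # 6.0
```

Because the tree shape is implementation-defined, different implementations may associate `computation` differently; the result is only guaranteed to be identical across them when `computation` is associative.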

#### Inputs

@@ -1092,11 +1092,8 @@ grad_output, type(grad_operand), type(grad_scale), type(feature_index))`.

#### Constraints

* (C1) `0 <= feature_index < rank(operand)`.
- * (C2) `baseline_element_type(operand) = baseline_element_type(scale) =
-   baseline_element_type(mean) = baseline_element_type(variance) =
-   baseline_element_type(grad_output) = baseline_element_type(grad_operand)
-   = baseline_element_type(grad_scale) =
-   baseline_element_type(grad_offset)`.
+ * (C2) `operand`, `scale`, `mean`, `variance`, `grad_output`, `grad_operand`,
+   `grad_scale` and `grad_offset` have the same `baseline_element_type`.
* (C3) `operand`, `grad_output` and `grad_operand` have the same shape.
* (C4) `scale`, `mean`, `variance`, `grad_scale` and `grad_offset` have the
same shape.
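As a quick illustration of the rewritten (C2), a minimal check is sketched below, with `baseline_element_type` modeled as a plain attribute on a hypothetical tensor descriptor rather than the spec's actual type machinery.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TensorDesc:
    # Hypothetical descriptor; baseline_element_type is modeled as a string tag.
    shape: Tuple[int, ...]
    baseline_element_type: str

def check_c2(*tensors: TensorDesc) -> bool:
    # (C2): every listed tensor shares one baseline_element_type.
    return len({t.baseline_element_type for t in tensors}) == 1

operand = TensorDesc((8, 4), "f32")
scale = TensorDesc((4,), "f32")
grad_output = TensorDesc((8, 4), "f32")
print(check_c2(operand, scale, grad_output))  # True
```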
@@ -1183,9 +1180,8 @@ feature_index), operand, scale, offset, mean, variance, type(result))`.

#### Constraints

* (C1) `0 <= feature_index < rank(operand)`.
- * (C2) `baseline_element_type(operand) = baseline_element_type(scale) =
-   baseline_element_type(offset) = baseline_element_type(mean) =
-   baseline_element_type(variance) = baseline_element_type(result)`.
+ * (C2) `operand`, `scale`, `offset`, `mean`, `variance` and `result` have the
+   same `baseline_element_type`.
* (C3) `size(scale) = dim(operand, feature_index)`.
* (C4) `size(offset) = dim(operand, feature_index)`.
* (C5) `size(mean) = dim(operand, feature_index)`.
@@ -1275,9 +1271,8 @@ scale, offset, type(output), type(batch_mean), type(batch_var))`.

#### Constraints

* (C1) `0 <= feature_index < rank(operand)`.
- * (C2) `baseline_element_type(operand) = baseline_element_type(scale) =
-   baseline_element_type(offset) = baseline_element_type(batch_mean) =
-   baseline_element_type(batch_var) = baseline_element_type(output)`.
+ * (C2) `operand`, `scale`, `offset`, `batch_mean`, `batch_var` and `output` have
+   the same `baseline_element_type`.
* (C3) `size(scale) = dim(operand, feature_index)`.
* (C4) `size(offset) = dim(operand, feature_index)`.
* (C5) `size(batch_mean) = dim(operand, feature_index)`.
@@ -4017,13 +4012,15 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
`body` and zero for `init_values` don't actually form a monoid because
floating-point addition is not associative.

- More formally, `results...[j0, ..., jR-1] = convert_or_quantize_or_dequantize(
+ More formally, `results...[j0, ..., jR-1] = to_destination_type(
reduce(input_slices_converted), type(results...))` where:

* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
- * `input_slices_converted = convert_or_quantize_or_dequantize(input_slices...,
+ * `input_slices_converted = to_destination_type(input_slices...,
type(func_inputs(body)[:len(func_inputs(body))//2])...)`.
+ * `init_values_converted = to_destination_type(init_values...,
+   type(func_inputs(body)[len(func_inputs(body))//2:])...)`.
* `reduce(input_slices_converted) = exec(schedule)` for some binary tree
`schedule` where:
* `exec(node) = body(exec(node.left), exec(node.right))`.
@@ -4034,9 +4031,7 @@ reduce(input_slices_converted), type(results...))` where:
`index_space(input_slices_converted)` in the ascending lexicographic order
of `index`.
* Interspersed with an implementation-defined amount of
- `convert_or_quantize_or_dequantize(init_values...,
- type(func_inputs(body)[len(func_inputs(body))//2:])...)` at
- implementation-defined positions.
+ `init_values_converted` at implementation-defined positions.
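Putting the pieces above together for a single output index, here is a sketch of one valid schedule: a left-leaning tree over the converted slice elements, with the converted init value interspersed once at the front. The helper below is an illustration, not the spec's definition; any tree shape and any placement of `init_values_converted` is allowed, which is why a non-associative `body` yields implementation-defined results.

```python
from functools import reduce as fold
from typing import Callable, List

def one_valid_reduce(slice_elements: List[float],
                     init_value: float,
                     body: Callable[[float, float], float]) -> float:
    # Left-leaning tree: body(...body(body(init, x0), x1)..., x_{n-1}).
    return fold(body, slice_elements, init_value)

print(one_valid_reduce([1.0, 2.0, 3.0], 0.0, lambda a, b: a + b))  # 6.0
```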

#### Inputs

@@ -4759,14 +4754,14 @@ Given that, `results = exec(schedule, inputs)`, where:
`index_space(updates[0])`.
* `exec([update_index, ...], results) = exec([...], updated_results)` where:
* If `result_index` is in bounds for `shape(results...)`
- * `updated_values = update_computation(
-   convert_or_quantize_or_dequantize(results...[result_index], type(
-   func_inputs(update_computation)[:len(func_inputs(update_computation))//2])
-   ... ),
-   convert_or_quantize_or_dequantize(updates...[update_index], type(
-   func_inputs(update_computation)[len(func_inputs(update_computation))//2:])
-   ... ))`
- * `updated_values_converted = convert_or_quantize_or_dequantize(
+ * `results_converted = to_destination_type(
+   results...[result_index], type(func_inputs(update_computation)
+   [:len(func_inputs(update_computation))//2])... )`
+ * `updates_converted = to_destination_type(
+   updates...[update_index], type(func_inputs(update_computation)
+   [len(func_inputs(update_computation))//2:])... )`
+ * `updated_values = update_computation(results_converted, updates_converted)`
+ * `updated_values_converted = to_destination_type(
updated_values, type(results...))`
* `updated_results` is a copy of `results` with `results...[result_index]`
set to `updated_values_converted...`.
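A minimal sketch of one update step from the list above, with Python floats standing in for tensor elements and `float(...)` standing in for the `to_destination_type` conversions; none of the names below are part of the spec.

```python
from typing import Callable, List

def apply_one_update(results: List[float],
                     result_index: int,
                     update_value: float,
                     update_computation: Callable[[float, float], float]) -> List[float]:
    results_converted = float(results[result_index])   # to_destination_type(results[result_index], ...)
    updates_converted = float(update_value)            # to_destination_type(updates[update_index], ...)
    updated_values = update_computation(results_converted, updates_converted)
    updated_values_converted = float(updated_values)   # to_destination_type(updated_values, type(results))
    updated_results = list(results)
    updated_results[result_index] = updated_values_converted
    return updated_results

print(apply_one_update([1.0, 2.0, 3.0], 1, 10.0, lambda a, b: a + b))  # [1.0, 12.0, 3.0]
```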
@@ -6336,12 +6331,12 @@ and [slicing](https://docs.python.org/3/reference/expressions.html#slicings)
notations from Python are available to index into tensors, quantized tensors
and tuples.

- * `convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value`
-   is defined on tensors and returns the converted value of `x` based on the
-   `type(x)` and `destination_type` as follows:
+ * `to_destination_type(x: Value, destination_type: Type) -> Value` is defined on
+   tensors and returns the converted value of `x` based on the `type(x)` and
+   `destination_type` as follows:

```python
- def convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value:
+ def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x

@@ -6358,10 +6353,10 @@ def convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value
return convert(x, destination_type)
```

- There is plan to merge `convert`, `uniform_quantize` and `uniform_dequantize`
- operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). After
- the merge we do not need the above function and can use the operation name for
- `convert` instead.
+ There is early discussion on merging `convert`, `uniform_quantize` and
+ `uniform_dequantize` operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)).
+ After the merge we do not need the above function and can use the operation name
+ for `convert` instead.
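Since the paragraph above frames `to_destination_type` as a stand-in for `convert`, `uniform_quantize` and `uniform_dequantize`, here is a rough sketch of that dispatch with booleans standing in for StableHLO type queries; the branch structure is an assumption based on the visible code and the surrounding prose, not the collapsed body of the spec's function.

```python
def conversion_op(source_is_quantized: bool, destination_is_quantized: bool) -> str:
    # Returns the name of the operation that would perform the conversion
    # when the source and destination types differ (identity case omitted).
    if destination_is_quantized:
        return "uniform_quantize"    # to a quantized type
    if source_is_quantized:
        return "uniform_dequantize"  # quantized -> expressed type
    return "convert"                 # ordinary element-type conversion

print(conversion_op(False, True))   # uniform_quantize
print(conversion_op(True, False))   # uniform_dequantize
print(conversion_op(False, False))  # convert
```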

* `is_nan(x: Value) -> Value` is defined on tensors and returns `true` if
all elements of `x` are `NaN` or `false` otherwise. If `x` is not a tensor,
