Iter3: Address feedback comments
sdasgup3 committed Oct 13, 2023
1 parent a6aba14 commit 2d4f432
Showing 1 changed file with 50 additions and 45 deletions.
95 changes: 50 additions & 45 deletions docs/spec.md
@@ -778,28 +778,29 @@ defined as follows:

Afterwards, within each `process_group`:

* `result@process[result_index] = convert_or_quantize(exec(schedule), type(result))`
for some binary tree `schedule` where:
* `result@process[result_index] = convert_or_quantize_or_dequantize(exec(
schedule), type(result))` for some binary tree `schedule` where:
* `exec(node)` = `computation(exec(node.left), exec(node.right))`.
* `exec(leaf)` = `leaf.value`.
* `schedule` is an implementation-defined binary tree whose in-order
traversal is `convert_or_quantize(operands@process_group...[result_index], type(func_inputs(computation)[0]))`.
traversal is `convert_or_quantize_or_dequantize(
operands@process_group...[result_index], type(func_inputs(computation)[0]))`.
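
As a rough illustration (not the spec's pseudocode), the sketch below models
`schedule` as nested Python tuples whose leaves are plain values, and builds one
legal left-leaning tree; any other implementation-defined tree with the same
in-order traversal is equally valid. The helper names are illustrative.

```python
# Leaves are plain values; interior nodes are (left, right) tuples.
def exec_schedule(node, computation):
    if isinstance(node, tuple):
        left, right = node
        return computation(exec_schedule(left, computation),
                           exec_schedule(right, computation))
    return node  # leaf.value

def left_leaning_schedule(values):
    # The in-order traversal of the returned tree is exactly `values`.
    tree = values[0]
    for value in values[1:]:
        tree = (tree, value)
    return tree

# One operand element gathered from a process group of 4 processes:
operands_at_index = [1.0, 2.0, 3.0, 4.0]
schedule = left_leaning_schedule(operands_at_index)
print(exec_schedule(schedule, lambda a, b: a + b))  # 10.0
```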

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
| (I1) | `operand` | tensor or quantized tensor | (C5), (C6) |
| (I1) | `operand` | tensor or per-tensor quantized tensor | (C5), (C6) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C3) |
| (I3) | `channel_id` | constant of type `si64` | (C4) |
| (I4) | `use_global_device_ids` | constant of type `i1` | (C4) |
| (I5) | `computation` | function | (C5) |

#### Outputs

| Name | Type | Constraints |
|----------|----------------------------|-------------|
| `result` | tensor or quantized tensor | (C6) |
| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
| `result` | tensor or per-tensor quantized tensor | (C6) |

#### Constraints

@@ -811,7 +812,7 @@ Afterwards, within each `process_group`:
* (C3) `0 <= replica_groups < size(replica_groups)`.
* (C4) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C5) `computation` has type `(tensor<E>, tensor<E>) -> (tensor<E>)` where
`is_convertible_or_quantizable(E, element_type(operand))`.
`is_promotable(E, element_type(operand))`.
* (C6) `baseline_type(result) = baseline_type(operand)`.

#### Examples
@@ -4016,12 +4017,12 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
`body` and zero for `init_values` don't actually form a monoid because
floating-point addition is not associative.

More formally, `results...[j0, ..., jR-1] = convert_or_quantize(reduce(
input_slices_converted), type(results...)))` where:
More formally, `results...[j0, ..., jR-1] = convert_or_quantize_or_dequantize(
reduce(input_slices_converted), type(results...))` where:

* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
* `input_slices_converted = convert_or_quantize(input_slices...,
* `input_slices_converted = convert_or_quantize_or_dequantize(input_slices...,
type(func_inputs(body)[:len(func_inputs(body))//2])...)`.
* `reduce(input_slices_converted) = exec(schedule)` for some binary tree
`schedule` where:
@@ -4033,8 +4034,9 @@ input_slices_converted), type(results...)))` where:
`index_space(input_slices_converted)` in the ascending lexicographic order
of `index`.
* Interspersed with an implementation-defined amount of
`convert_or_quantize(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)`
at implementation-defined positions.
`convert_or_quantize_or_dequantize(init_values...,
type(func_inputs(body)[len(func_inputs(body))//2:])...)` at
implementation-defined positions.
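
As a rough illustration (not the spec's pseudocode), the sketch below performs
this reduction for a single input, a single reduced dimension, and a
left-to-right schedule that places one copy of `init_values` at the front;
element-type conversion via `convert_or_quantize_or_dequantize` is omitted.

```python
# One legal schedule: fold the slice elements left to right, seeded once with
# the initial value. Other tree shapes and extra occurrences of `init_value`
# are equally valid.
def reduce_slice(elements, init_value, body):
    accumulator = init_value
    for element in elements:  # ascending lexicographic order of the slice
        accumulator = body(accumulator, element)
    return accumulator

inputs = [[1, 2, 3],
          [4, 5, 6]]   # shape (2, 3)
# Reducing across dimension 1: results[j0] = reduce over inputs[j0, :].
results = [reduce_slice(row, init_value=0, body=lambda a, b: a + b)
           for row in inputs]
print(results)  # [6, 15]
```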

#### Inputs

@@ -4058,9 +4060,9 @@ input_slices_converted), type(results...)))` where:
* (C3) `0 < size(inputs) = size(init_values) = size(results) = N`.
* (C4) `0 <= dimensions < rank(inputs[0])`.
* (C5) `is_unique(dimensions)`.
* (C6) `body` has type `tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
* (C6) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`.
`is_promotable(element_type(inputs[i]), element_type(E[i]))`.
* (C7) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
* (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
@@ -4174,7 +4176,7 @@ Afterwards, within each `process_group`:

| Label | Name | Type | Constraints |
|-------|-------------------------|----------------------------------------------|------------------------|
| (I1) | `operand` | tensor or quantized tensor | (C1), (C2), (C7), (C8) |
| (I1) | `operand` | tensor or per-tensor quantized tensor | (C1), (C2), (C7), (C8) |
| (I2) | `scatter_dimension` | constant of type `si64` | (C1), (C2), (C8) |
| (I3) | `replica_groups` | 2-dimensional tensor constant of type `si64` | (C3-C5) |
| (I4) | `channel_id` | constant of type `si64` | (C6) |
@@ -4183,9 +4185,9 @@

#### Outputs

| Name | Type | Constraints |
|----------|----------------------------|-------------|
| `result` | tensor or quantized tensor | (C8) |
| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
| `result` | tensor or per-tensor quantized tensor | (C8) |

#### Constraints

@@ -4199,7 +4201,7 @@ Afterwards, within each `process_group`:
* (C5) `0 <= replica_groups < size(replica_groups)`.
* (C6) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C7) `computation` has type `(tensor<E>, tensor<E>) -> (tensor<E>)` where
`is_convertible_or_quantizable(E, element_type(operand))`.
`is_promotable(E, element_type(operand))`.
* (C8) `baseline_type(result) = baseline_type(operand)` except:
* `dim(result, scatter_dimension) = dim(operand, scatter_dimension) /
dim(process_groups, 1)`.
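
For intuition about (C8), here is a small sketch of the shape arithmetic it
implies; the helper name and the even-divisibility assumption are illustrative,
not part of the spec text shown here.

```python
# Only `scatter_dimension` changes size: it shrinks by the number of processes
# in each process group, i.e. dim(process_groups, 1).
def reduce_scatter_result_shape(operand_shape, scatter_dimension, group_size):
    assert operand_shape[scatter_dimension] % group_size == 0  # assumed here
    result_shape = list(operand_shape)
    result_shape[scatter_dimension] //= group_size
    return result_shape

# A 4x16 operand scattered along dimension 1 across groups of 2 processes:
print(reduce_scatter_result_shape([4, 16], scatter_dimension=1, group_size=2))
# [4, 8]
```
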
@@ -4290,7 +4292,7 @@ where:
* (C12) `shape(padding) = [rank(inputs[0]), 2]`.
* (C13) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`.
`is_promotable(element_type(inputs[i]), element_type(E[i]))`.
* (C14) `same(shape(results...))`.
* (C15) `shape(results[0]) = num_windows` where:
* `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
@@ -4758,11 +4760,13 @@ Given that, `results = exec(schedule, inputs)`, where:
* `exec([update_index, ...], results) = exec([...], updated_results)` where:
* If `result_index` is in bounds for `shape(results...)`
* `updated_values = update_computation(
convert_or_quantize(results...[result_index], type(func_inputs(
update_computation)[:len(func_inputs(update_computation))//2])... ),
convert_or_quantize(updates...[update_index], type(func_inputs(
update_computation)[len(func_inputs(update_computation))//2:])... ))`
* `updated_values_converted = convert_or_quantize(
convert_or_quantize_or_dequantize(results...[result_index], type(
func_inputs(update_computation)[:len(func_inputs(update_computation))//2])
... ),
convert_or_quantize_or_dequantize(updates...[update_index], type(
func_inputs(update_computation)[len(func_inputs(update_computation))//2:])
... ))`
* `updated_values_converted = convert_or_quantize_or_dequantize(
updated_values, type(results...))`
* `updated_results` is a copy of `results` with `results...[result_index]`
set to `updated_values_converted...`.
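
To make the update loop concrete, here is a simplified sketch for a single 1-D
input with scalar `update_computation` and no element-type conversion; applying
repeated indices in order, as done here, is only one possible behavior.

```python
def scatter_1d(operand, scatter_indices, updates, update_computation):
    results = list(operand)
    for update_index, result_index in enumerate(scatter_indices):
        if 0 <= result_index < len(results):  # skip out-of-bounds indices
            results[result_index] = update_computation(
                results[result_index], updates[update_index])
    return results

print(scatter_1d(operand=[0, 0, 0, 0],
                 scatter_indices=[1, 3, 1],
                 updates=[10, 20, 30],
                 update_computation=lambda a, b: a + b))
# [0, 40, 0, 20]
```
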
@@ -4832,7 +4836,7 @@ undefined.
* (C14) `0 <= index_vector_dim <= rank(scatter_indices)`.
* (C15) `update_computation` has type `(tensor<E0>, ..., tensor<EN-1>,
tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)`,
where `is_convertible_or_quantizable(Ei, element_type(inputs[i]))`.
where `is_promotable(Ei, element_type(inputs[i]))`.
* (C16) `baseline_type(inputs...) = baseline_type(results...)`.

#### Examples
@@ -4989,7 +4993,7 @@ More formally:
* (C9) `select` has type `(tensor<E>, tensor<E>) -> tensor<i1>` where
`E = element_type(operand)`.
* (C10) `scatter` has type `(tensor<E>, tensor<E>) -> tensor<E>` where
`is_convertible_or_quantizable(element_type(operand), element_type(E))`.
`is_promotable(element_type(operand), element_type(E))`.
* (C11) `baseline_type(operand) = baseline_type(result)`.
<!-- markdownlint-enable line-length -->

@@ -6267,6 +6271,20 @@ def element_type(x: Value | Placeholder | Type):
return element_type(type(x))
```

* `is_promotable(x: Type, y: Type) -> bool` checks if type `x` can be promoted
to type `y`. When `x` and `y` are `QuantizedTensorElementType`s, the promotion
is applied only to the `storage_type`. This specific version of promotion is
currently used in the context of reduction computations (refer to the
[RFC](https://github.com/openxla/stablehlo/pull/1664) for more details).

```python
def is_promotable(x: Type, y: Type) -> bool:
  return ((is_integer(x) and is_integer(y)) or
          (is_float(x) and is_float(y)) or
          (is_complex(x) and is_complex(y)) or
          (is_quantized(x) and is_quantized(y) and
           expressed_type(x) == expressed_type(y)))
```
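
As a usage illustration only (the toy helpers below are stand-ins, not the
spec's functions), this shows the kind of promotions the predicate is meant to
accept and reject; complex types are omitted for brevity:

```python
INTEGER, FLOAT = {'si8', 'si32', 'si64'}, {'bf16', 'f16', 'f32', 'f64'}

def toy_kind(t):
    return 'integer' if t in INTEGER else 'float' if t in FLOAT else None

def toy_is_promotable(x, y):
    # Plain element types: promotion must stay within the same kind.
    if isinstance(x, str) and isinstance(y, str):
        return toy_kind(x) is not None and toy_kind(x) == toy_kind(y)
    # Quantized types, modeled as (storage_type, expressed_type) pairs:
    # only the expressed type must match; the storage type is free to widen.
    if isinstance(x, tuple) and isinstance(y, tuple):
        return x[1] == y[1]
    return False

print(toy_is_promotable('si8', 'si32'))                    # True:  integer -> integer
print(toy_is_promotable('si8', 'f32'))                     # False: kind changes
print(toy_is_promotable(('si8', 'f32'), ('si32', 'f32')))  # True:  storage widens
print(toy_is_promotable(('si8', 'f32'), ('si8', 'bf16')))  # False: expressed differs
```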

* `is_per_axis_quantized(x: Value | Placeholder | Type) -> Value` is a shortcut
for `is_quantized(x) and quantization_dimension(x) is not None`.

@@ -6294,19 +6312,6 @@ If `x` is a value or placeholder, this function is a shortcut for
`member_name(type(x))`. If `x` is not a type that has an appropriate member, or
a value or a placeholder of such a type, returns `None`.

* `is_convertible_or_quantizable(x: Type, y: Type) -> bool` checks for the
equality of `x` and `y`, ignoring the bitwidth, when they are of type
`TensorElementType`. When `x` and `y` are `QuantizedTensorElementType`s,
the function checks for the equality of the `QuantizationExpressedType` component.

```python
def is_convertible_or_quantizable(x: Type, y: Type) -> Value:
return is_integer(x) = is_integer(y) or
is_float(x) = is_float(y) or
is_complex(x) = is_complex(y) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
```

#### Construction of values

* `operation_name(*xs: Value | Type) -> Value`. Available for all operations.
Expand All @@ -6324,12 +6329,12 @@ and [slicing](https://docs.python.org/3/reference/expressions.html#slicings)
notations from Python are available to index into tensors, quantized tensors
and tuples.

* `convert_or_quantize(x: Value, destination_type: Type) -> Value` is defined on
tensors and returns the converted value of `x` based on the `type(x)` and
`destination_type` as follows:
* `convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value`
is defined on tensors and returns the converted value of `x` based on the
`type(x)` and `destination_type` as follows:

```python
def convert_or_quantize(x: Value, destination_type: Type) -> Value:
def convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x

