diff --git a/docs/spec.md b/docs/spec.md
index c18409cfeb4..07bb315af27 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -778,18 +778,19 @@ defined as follows:

 Afterwards, within each `process_group`:

-* `result@process[result_index] = convert_or_quantize(exec(schedule), type(result))`
-  for some binary tree `schedule` where:
+* `result@process[result_index] = convert_or_quantize_or_dequantize(exec(
+  schedule), type(result))` for some binary tree `schedule` where:
   * `exec(node)` = `computation(exec(node.left), exec(node.right))`.
   * `exec(leaf)` = `leaf.value`.
   * `schedule` is an implementation-defined binary tree whose in-order
-    traversal is `convert_or_quantize(operands@process_group...[result_index], type(func_inputs(computation)[0]))`.
+    traversal is `convert_or_quantize_or_dequantize(
+    operands@process_group...[result_index], type(func_inputs(computation)[0]))`.

 #### Inputs

 | Label | Name | Type | Constraints |
 |-------|-------------------------|------------------------------------------------------------------|-------------|
-| (I1) | `operand` | tensor or quantized tensor | (C5), (C6) |
+| (I1) | `operand` | tensor or per-tensor quantized tensor | (C5), (C6) |
 | (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C3) |
 | (I3) | `channel_id` | constant of type `si64` | (C4) |
 | (I4) | `use_global_device_ids` | constant of type `i1` | (C4) |
@@ -797,9 +798,9 @@ Afterwards, within each `process_group`:

 #### Outputs

-| Name | Type | Constraints |
-|----------|----------------------------|-------------|
-| `result` | tensor or quantized tensor | (C6) |
+| Name | Type | Constraints |
+|----------|---------------------------------------|-------------|
+| `result` | tensor or per-tensor quantized tensor | (C6) |

 #### Constraints

@@ -811,7 +812,7 @@ Afterwards, within each `process_group`:
 * (C3) `0 <= replica_groups < size(replica_groups)`.
 * (C4) If `use_global_device_ids = true`, then `channel_id > 0`.
 * (C5) `computation` has type `(tensor, tensor) -> (tensor)` where
-  `is_convertible_or_quantizable(E, element_type(operand))`.
+  `is_promotable(E, element_type(operand))`.
 * (C6) `baseline_type(result) = baseline_type(operand)`.

 #### Examples
@@ -4016,12 +4017,12 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
 `body` and zero for `init_values` don't actually form a monoid because
 floating-point addition is not associative.

-More formally, `results...[j0, ..., jR-1] = convert_or_quantize(reduce(
-input_slices_converted), type(results...)))` where:
+More formally, `results...[j0, ..., jR-1] = convert_or_quantize_or_dequantize(
+reduce(input_slices_converted), type(results...)))` where:

 * `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
   at `dimensions`.
-* `input_slices_converted = convert_or_quantize(input_slices...,
+* `input_slices_converted = convert_or_quantize_or_dequantize(input_slices...,
   type(func_inputs(body)[:len(func_inputs(body))//2])...)`.
 * `reduce(input_slices_converted) = exec(schedule)` for some binary tree
   `schedule` where:
@@ -4033,8 +4034,9 @@ input_slices_converted), type(results...)))` where:
     `index_space(input_slices_converted)` in the ascending lexicographic order
     of `index`.
   * Interspersed with an implementation-defined amount of
-    `convert_or_quantize(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)`
-    at implementation-defined positions.
+    `convert_or_quantize_or_dequantize(init_values...,
+    type(func_inputs(body)[len(func_inputs(body))//2:])...)` at
+    implementation-defined positions.

 #### Inputs

@@ -4058,9 +4060,9 @@ input_slices_converted), type(results...)))` where:
 * (C3) `0 < size(inputs) = size(init_values) = size(results) = N`.
 * (C4) `0 <= dimensions < rank(inputs[0])`.
 * (C5) `is_unique(dimensions)`.
-* (C6) `body` has type `tensor, ..., tensor, tensor, ...,`
+* (C6) `body` has type `(tensor, ..., tensor, tensor, ...,`
   `tensor) -> (tensor, ..., tensor)` where
-  `is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`.
+  `is_promotable(element_type(inputs[i]), element_type(E[i]))`.
 * (C7) `shape(results...) = shape(inputs...)` except that the dimension
   sizes of `inputs...` corresponding to `dimensions` are not included.
 * (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
@@ -4174,7 +4176,7 @@ Afterwards, within each `process_group`:

 | Label | Name | Type | Constraints |
 |-------|-------------------------|----------------------------------------------|------------------------|
-| (I1) | `operand` | tensor or quantized tensor | (C1), (C2), (C7), (C8) |
+| (I1) | `operand` | tensor or per-tensor quantized tensor | (C1), (C2), (C7), (C8) |
 | (I2) | `scatter_dimension` | constant of type `si64` | (C1), (C2), (C8) |
 | (I3) | `replica_groups` | 2-dimensional tensor constant of type `si64` | (C3-C5) |
 | (I4) | `channel_id` | constant of type `si64` | (C6) |
@@ -4183,9 +4185,9 @@ Afterwards, within each `process_group`:

 #### Outputs

-| Name | Type | Constraints |
-|----------|----------------------------|-------------|
-| `result` | tensor or quantized tensor | (C8) |
+| Name | Type | Constraints |
+|----------|---------------------------------------|-------------|
+| `result` | tensor or per-tensor quantized tensor | (C8) |

 #### Constraints

@@ -4199,7 +4201,7 @@ Afterwards, within each `process_group`:
 * (C5) `0 <= replica_groups < size(replica_groups)`.
 * (C6) If `use_global_device_ids = true`, then `channel_id > 0`.
 * (C7) `computation` has type `(tensor, tensor) -> (tensor)` where
-  `is_convertible_or_quantizable(E, element_type(operand))`.
+  `is_promotable(E, element_type(operand))`.
 * (C8) `baseline_type(result) = baseline_type(operand)` except:
   * `dim(result, scatter_dimension) =
     dim(operand, scatter_dimension) / dim(process_groups, 1)`.
@@ -4290,7 +4292,7 @@ where:
 * (C12) `shape(padding) = [rank(inputs[0]), 2]`.
 * (C13) `body` has type `(tensor, ..., tensor, tensor, ...,`
   `tensor) -> (tensor, ..., tensor)` where
-  `is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`.
+  `is_promotable(element_type(inputs[i]), element_type(E[i]))`.
 * (C14) `same(shape(results...))`.
 * (C15) `shape(results[0]) = num_windows` where:
   * `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
@@ -4758,11 +4760,13 @@ Given that, `results = exec(schedule, inputs)`, where:
 * `exec([update_index, ...], results) = exec([...], updated_results)` where:
   * If `result_index` is in bounds for `shape(results...)`
     * `updated_values = update_computation(
-       convert_or_quantize(results...[result_index], type(func_inputs(
-       update_computation)[:len(func_inputs(update_computation))//2])... ),
-       convert_or_quantize(updates...[update_index], type(func_inputs(
-       update_computation)[len(func_inputs(update_computation))//2:])... ))`
-    * `updated_values_converted = convert_or_quantize(
+       convert_or_quantize_or_dequantize(results...[result_index], type(
+       func_inputs(update_computation)[:len(func_inputs(update_computation))//2])
+       ... ),
+       convert_or_quantize_or_dequantize(updates...[update_index], type(
+       func_inputs(update_computation)[len(func_inputs(update_computation))//2:])
+       ... ))`
+    * `updated_values_converted = convert_or_quantize_or_dequantize(
       updated_values, type(results...))`
     * `updated_results` is a copy of `results` with `results...[result_index]`
       set to `updated_values_converted...`.
@@ -4832,7 +4836,7 @@ undefined.
 * (C14) `0 <= index_vector_dim <= rank(scatter_indices)`.
 * (C15) `update_computation` has type `(tensor, ..., tensor,
   tensor, ..., tensor) -> (tensor, ..., tensor)`,
-  where `is_convertible_or_quantizable(Ei, element_type(inputs[i]))`.
+  where `is_promotable(Ei, element_type(inputs[i]))`.
 * (C16) `baseline_type(inputs...) = baseline_type(results...)`.

 #### Examples
@@ -4989,7 +4993,7 @@ More formally:
 * (C9) `select` has type `(tensor, tensor) -> tensor` where
   `E = element_type(operand)`.
 * (C10) `scatter` has type `(tensor, tensor) -> tensor` where
-  `is_convertible_or_quantizable(element_type(operand), element_type(E))`.
+  `is_promotable(element_type(operand), element_type(E))`.
 * (C11) `baseline_type(operand) = baseline_type(result)`.

@@ -6267,6 +6271,20 @@ def element_type(x: Value | Placeholder | Type):
   return element_type(type(x))
 ```

+* `is_promotable(x: Type, y: Type) -> bool` checks if type `x` can be promoted
+to type `y`. When `x` and `y` are `QuantizedTensorElementType`s, the promotion
+is applied only to the `storage_type`. This specific version of promotion is
+currently used in context of reduction computation (refer to
+[RFC](https://github.com/openxla/stablehlo/pull/1664) for more details).
+
+```python
+def is_promotable(x: Type, y: Type) -> Value:
+  return (is_integer(x) and is_integer(y)) or
+    (is_float(x) and is_float(y)) or
+    (is_complex(x) and is_complex(y)) or
+    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
+```
+
 * `is_per_axis_quantized(x: Value | Placeholder | Type) -> Value` is a shortcut
 for `is_quantized(x) and quantization_dimension(x) is not None`.

@@ -6294,19 +6312,6 @@ If `x` is a value or placeholder, this function is a shortcut for
 `member_name(type(x))`. If `x` is not a type that has an appropriate member, or
 a value or a placeholder of such a type, returns `None`.

-* `is_convertible_or_quantizable(x: Type, y: Type) -> bool` checks for the
-equality of `x` and `y`, ignoring the bitwidth, when they are of type
-`TensorElementType`.
When `x` and `y` are `QuantizedTensorElementType`s,
-the function checks for the equality of `QuantizationExpressedType` component.
-
-```python
-def is_convertible_or_quantizable(x: Type, y: Type) -> Value:
-  return is_integer(x) = is_integer(y) or
-    is_float(x) = is_float(y) or
-    is_complex(x) = is_complex(y) or
-    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
-```
-
 #### Construction of values

 * `operation_name(*xs: Value | Type) -> Value`. Available for all operations.
@@ -6324,12 +6329,12 @@ and [slicing](https://docs.python.org/3/reference/expressions.html#slicings)
 notations from Python are available to index into tensors, quantized tensors
 and tuples.

-* `convert_or_quantize(x: Value, destination_type: Type) -> Value` is defined on
-tensors and returns the converted value of `x` based on the `type(x)` and
-`destination_type` as follows:
+* `convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value`
+is defined on tensors and returns the converted value of `x` based on the
+`type(x)` and `destination_type` as follows:

 ```python
-def convert_or_quantize(x: Value, destination_type: Type) -> Value:
+def convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value:
   if type(x) == destination_type:
     return x