diff --git a/docs/spec.md b/docs/spec.md index 0173673f76c..de009fd7fd0 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -778,13 +778,13 @@ defined as follows: Afterwards, within each `process_group`: -* `result@process[result_index] = convert_or_quantize_or_dequantize(exec( - schedule), type(result))` for some binary tree `schedule` where: +* `result@process[result_index] = to_destination_type(exec(schedule), + type(result))` for some binary tree `schedule` where: * `exec(node)` = `computation(exec(node.left), exec(node.right))`. * `exec(leaf)` = `leaf.value`. * `schedule` is an implementation-defined binary tree whose in-order - traversal is `convert_or_quantize_or_dequantize( - operands@process_group...[result_index], type(func_inputs(computation)[0]))`. + traversal is `to_destination_type(operands@process_group...[result_index], + type(func_inputs(computation)[0]))`. #### Inputs @@ -1092,11 +1092,8 @@ grad_output, type(grad_operand), type(grad_scale), type(feature_index))`. #### Constraints * (C1) `0 <= feature_index < rank(operand)`. -* (C2) `baseline_element_type(operand) = baseline_element_type(scale) = - baseline_element_type(mean) = baseline_element_type(variance) = - baseline_element_type(grad_output) = baseline_element_type(grad_operand) - = baseline_element_type(grad_scale) = - baseline_element_type(grad_offset)`. +* (C2) `operand`, `scale`, `mean`, `variance`, `grad_output`, `grad_operand`, + `grad_scale` and `grad_offset` have the same `baseline_element_type`. * (C3) `operand`, `grad_output` and `grad_operand` have the same shape. * (C4) `scale`, `mean`, `variance`, `grad_scale` and `grad_offset` have the same shape. @@ -1183,9 +1180,8 @@ feature_index), operand, scale, offset, mean, variance, type(result))`. #### Constraints * (C1) `0 <= feature_index < rank(operand)`. -* (C2) `baseline_element_type(operand) = baseline_element_type(scale) = - baseline_element_type(offset) = baseline_element_type(mean) = - baseline_element_type(variance) = baseline_element_type(result)`. +* (C2) `operand`, `scale`, `offset`, `mean`, `variance` and `result` have the + same `baseline_element_type`. * (C3) `size(scale) = dim(operand, feature_index)`. * (C4) `size(offset) = dim(operand, feature_index)`. * (C5) `size(mean) = dim(operand, feature_index)`. @@ -1275,9 +1271,8 @@ scale, offset, type(output), type(batch_mean), type(batch_var))`. #### Constraints * (C1) `0 <= feature_index < rank(operand)`. -* (C2) `baseline_element_type(operand) = baseline_element_type(scale) = - baseline_element_type(offset) = baseline_element_type(batch_mean) = - baseline_element_type(batch_var) = baseline_element_type(output)`. +* (C2) `operand`, `scale`, `offset`, `batch_mean`, `batch_var` and `output` have + the same `baseline_element_type`. * (C3) `size(scale) = dim(operand, feature_index)`. * (C4) `size(offset) = dim(operand, feature_index)`. * (C5) `size(batch_mean) = dim(operand, feature_index)`. @@ -4017,13 +4012,15 @@ doesn't hold for many popular reductions. E.g. floating-point addition for `body` and zero for `init_values` don't actually form a monoid because floating-point addition is not associative. -More formally, `results...[j0, ..., jR-1] = convert_or_quantize_or_dequantize( +More formally, `results...[j0, ..., jR-1] = to_destination_type( reduce(input_slices_converted), type(results...)))` where: * `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted at `dimensions`. -* `input_slices_converted = convert_or_quantize_or_dequantize(input_slices..., +* `input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)`. +* `init_values_converted = to_destination_type(init_values..., + type(func_inputs(body)[len(func_inputs(body))//2:])...)`. * `reduce(input_slices_converted) = exec(schedule)` for some binary tree `schedule` where: * `exec(node) = body(exec(node.left), exec(node.right))`. @@ -4034,9 +4031,7 @@ reduce(input_slices_converted), type(results...)))` where: `index_space(input_slices_converted)` in the ascending lexicographic order of `index`. * Interspersed with an implementation-defined amount of - `convert_or_quantize_or_dequantize(init_values..., - type(func_inputs(body)[len(func_inputs(body))//2:])...)` at - implementation-defined positions. + `init_values_converted` at implementation-defined positions. #### Inputs @@ -4759,14 +4754,14 @@ Given that, `results = exec(schedule, inputs)`, where: `index_space(updates[0])`. * `exec([update_index, ...], results) = exec([...], updated_results)` where: * If `result_index` is in bounds for `shape(results...)` - * `updated_values = update_computation( - convert_or_quantize_or_dequantize(results...[result_index], type( - func_inputs(update_computation)[:len(func_inputs(update_computation))//2]) - ... ), - convert_or_quantize_or_dequantize(updates...[update_index], type( - func_inputs(update_computation)[len(func_inputs(update_computation))//2:]) - ... ))` - * `updated_values_converted = convert_or_quantize_or_dequantize( + * `results_converted = to_destination_type( + results...[result_index], type(func_inputs(update_computation) + [:len(func_inputs(update_computation))//2])... )` + * `updates_converted = to_destination_type( + updates...[update_index], type(func_inputs(update_computation) + [len(func_inputs(update_computation))//2:])... )` + * `updated_values = update_computation(result_converted, updates_converted)` + * `updated_values_converted = to_destination_type( updated_values, type(results...))` * `updated_results` is a copy of `results` with `results...[result_index]` set to `updated_values_converted...`. @@ -6336,12 +6331,12 @@ and [slicing](https://docs.python.org/3/reference/expressions.html#slicings) notations from Python are available to index into tensors, quantized tensors and tuples. -* `convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value` -is defined on tensors and returns the converted value of `x` based on the -`type(x)` and `destination_type` as follows: +* `to_destination_type(x: Value, destination_type: Type) -> Value` is defined on +tensors and returns the converted value of `x` based on the `type(x)` and +`destination_type` as follows: ```python -def convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value: +def to_destination_type(x: Value, destination_type: Type) -> Value: if type(x) == destination_type: return x @@ -6358,10 +6353,10 @@ def convert_or_quantize_or_dequantize(x: Value, destination_type: Type) -> Value return convert(x, destination_type) ``` -There is plan to merge `convert`, `uniform_quantize` and `uniform_dequantize` -operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). After -the merge we do not need the above function and can use the operation name for -`convert` instead. +There is early discussion on merging `convert`, `uniform_quantize` and +`uniform_dequantize` operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). +After the merge we do not need the above function and can use the operation name +for `convert` instead. * `is_nan(x: Value) -> Value` is defined on tensors and returns `true` if all elements of `x` are `NaN` or `false` otherwise. If `x` is not a tensor,