Add interpreter for ConvolutionOp #1314

Closed
wants to merge 6 commits

Conversation


@ghpvnist ghpvnist commented Mar 11, 2023

We have the following constraints in the spec (excluding quantization-related constraints C28-C33):

(I1) `lhs` tensor.
(I2) `rhs` tensor.
(I3) `window_strides` 1-dimensional tensor constant of type `si64`.
(I4) `padding` 2-dimensional tensor constant of type `si64`.
(I5) `lhs_dilation` 1-dimensional tensor constant of type `si64`.
(I6) `rhs_dilation` 1-dimensional tensor constant of type `si64`.
(I7) `window_reversal` 1-dimensional tensor constant of type `i1`.
(I8) `input_batch_dimension` constant of type `si64`.
(I9) `input_feature_dimension` constant of type `si64`.
(I10) `input_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I11) `kernel_input_feature_dimension` constant of type `si64`.
(I12) `kernel_output_feature_dimension` constant of type `si64`.
(I13) `kernel_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I14) `output_batch_dimension` constant of type `si64`.
(I15) `output_feature_dimension` constant of type `si64`.
(I16) `output_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I17) `feature_group_count` constant of type `si64`.
(I18) `batch_group_count` constant of type `si64`.
(I19) `precision_config` variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`.
(C1) `N = rank(lhs) = rank(rhs)`.
(C2) `size(window_strides) = N - 2`.
(C3) `0 < window_strides`.
(C4) `shape(padding) = [N - 2, 2]`.
(C5) `size(lhs_dilation) = N - 2`.
(C6) `0 < lhs_dilation`.
(C7) `size(rhs_dilation) = N - 2`.
(C8) `0 < rhs_dilation`.
(C9) `size(window_reversal) = N - 2`.
(C10) `dim(lhs, input_batch_dimension) % batch_group_count = 0`.
(C11) `dim(lhs, input_feature_dimension) % feature_group_count = 0`.
(C12) `size(input_spatial_dimensions) = N - 2`.
(C13) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
* `is_unique(input_dimensions)`.
* `0 <= input_dimensions < N`.
(C14) `dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count`.
(C15) `dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0`.
(C16) `dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0`.
(C17) `size(kernel_spatial_dimensions) = N - 2`.
(C18) Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
* `is_unique(kernel_dimensions)`.
* `0 <= kernel_dimensions < N`.
(C19) `size(output_spatial_dimensions) = N - 2`.
(C20) Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
* `is_unique(output_dimensions)`.
* `0 <= output_dimensions < N`.
(C21) `0 < feature_group_count`.
(C22) `0 < batch_group_count`.
(C23) `feature_group_count = 1 or batch_group_count = 1`.
(C24) `size(precision_config) = 2`.
(C25) `dim(result, result_dim)` is defined as:
* `dim(lhs, input_batch_dimension) / batch_group_count` if `result_dim = output_batch_dimension`.
* `dim(rhs, kernel_output_feature_dimension)` if `result_dim = output_feature_dimension`.
* `num_windows` otherwise, where:
  * `output_spatial_dimensions[spatial_dim] = result_dim`.
  * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
  * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
  * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
  * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
  * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
  * `is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]`.
  * `num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
(C26) `rank(result) = N`.
(C27) `element_type(lhs) = element_type(rhs) = element_type(result)`.
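
To make the output-shape rule in (C25) concrete, here is a small, self-contained sketch (plain C++ with illustrative names; not the TypeInference or interpreter code) that evaluates `num_windows` for a single spatial dimension:

```
// Illustration of the num_windows formula from (C25); not the actual
// StableHLO implementation.
#include <cstdint>
#include <iostream>

int64_t numWindows(int64_t inputDim, int64_t windowDim, int64_t stride,
                   int64_t padLow, int64_t padHigh, int64_t lhsDilation,
                   int64_t rhsDilation) {
  // dilated_input_shape and padded_input_shape from (C25).
  int64_t dilatedInput = inputDim == 0 ? 0 : (inputDim - 1) * lhsDilation + 1;
  int64_t paddedInput = padLow + dilatedInput + padHigh;
  // dilated_window_shape and the is_empty_window check from (C25).
  int64_t dilatedWindow = windowDim == 0 ? 0 : (windowDim - 1) * rhsDilation + 1;
  bool isEmptyWindow = paddedInput == 0 || dilatedWindow > paddedInput;
  // Integer division of non-negative values matches floor(...) in the spec.
  return isEmptyWindow ? 0 : (paddedInput - dilatedWindow) / stride + 1;
}

int main() {
  // An 8x8 spatial input with a 3x3 kernel, stride 1, padding [1, 1], and no
  // dilation keeps the spatial size: prints 8.
  std::cout << numWindows(/*inputDim=*/8, /*windowDim=*/3, /*stride=*/1,
                          /*padLow=*/1, /*padHigh=*/1,
                          /*lhsDilation=*/1, /*rhsDilation=*/1)
            << "\n";
  return 0;
}
```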

These constraints will be comprehensively covered by the following tests:

I1: a) `lhs` tensor. (Covered by ODS).
I2: a) `rhs` tensor. (Covered by ODS).
I3: a) `window_strides` is not a 1-dimensional tensor.
    b) element_type(`window_strides`) != `si64`. (Covered by ODS).
I4: a) `padding` is not a 2-dimensional tensor.
    b) element_type(`padding`) != `si64`. (Covered by ODS).
I5: a) `lhs_dilation` is not a 1-dimensional tensor.
    b)  element_type(`lhs_dilation`) != `si64`. (Covered by ODS).
I6: a) `rhs_dilation` is not a 1-dimensional tensor.
    b) element_type(`rhs_dilation`) != `si64`. (Covered by ODS).
I7: a) `window_reversal` is not a 1-dimensional tensor.
    b) element_type(`window_reversal`) != `i1`. (Covered by ODS).
I8: a) element_type(`input_batch_dimension`) != `si64`. (Covered by ODS).
I9: a) element_type(`input_feature_dimension`) != `si64`. (Covered by ODS).
I10: a) `input_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`input_spatial_dimensions`) != `si64`. (Covered by ODS).
I11: a) element_type(`kernel_input_feature_dimension`) != `si64`. (Covered by ODS).
I12: a) element_type(`kernel_output_feature_dimension`) != `si64`. (Covered by ODS).
I13: a) `kernel_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`kernel_spatial_dimensions`) != `si64`. (Covered by ODS).
I14: a) element_type(`output_batch_dimension`) != `si64`. (Covered by ODS).
I15: a) element_type(`output_feature_dimension`) != `si64`. (Covered by ODS).
I16: a) `output_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`output_spatial_dimensions`) != `si64`. (Covered by ODS).
I17: a) element_type(`feature_group_count`) != `si64`. (Covered by ODS).
I18: a) element_type(`batch_group_count`) != `si64`. (Covered by ODS).
I19: a) `precision_config` does not have variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`. (Covered by ODS).
C1: a) N = rank(`lhs`) != rank(`rhs`).
C2: a) size(`window_strides`) != N - 2.
C3: a) `window_strides[i]` <= 0 for any i in [0, size(`window_strides`)).
C4: a) dim(`padding`, 0) != N - 2.
    b) dim(`padding`, 1) != 2.
C5: a) size(`lhs_dilation`) != N - 2.
C6: a) `lhs_dilation[i]` <= 0 for any i in [0, size(`lhs_dilation`)).
C7: a) size(`rhs_dilation`) != N - 2.
C8: a) `rhs_dilation[i]` <= 0 for any i in [0, size(`rhs_dilation`)).
C9: a) size(`window_reversal`) != N - 2.
C10: a) `dim(lhs, input_batch_dimension) % batch_group_count != 0`.
C11: a) `dim(lhs, input_feature_dimension) % feature_group_count != 0`.
C12: a) size(`input_spatial_dimensions`) != N - 2.
C13: a) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
     * Any dimensions in `input_dimensions` are not unique.
     b) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
     * For any i in `input_dimensions`, i < 0.
     c) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
     * For any i in `input_dimensions`, i >= N.
C14: a) `dim(rhs, kernel_input_feature_dimension) != dim(lhs, input_feature_dimension) / feature_group_count`.
C15: a) `dim(rhs, kernel_output_feature_dimension) % batch_group_count != 0`.
C16: a) `dim(rhs, kernel_output_feature_dimension) % feature_group_count != 0`.
C17: a) size(`kernel_spatial_dimensions`) != N - 2.
C18: a) Given `kernel_dimensions = kernel_spatial_dimensions +
     [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
     * Any dimensions in `kernel_dimensions` are not unique.
     b) Given `kernel_dimensions = kernel_spatial_dimensions +
     [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
     * For any i in `kernel_dimensions`, i < 0.
     c) Given `kernel_dimensions = kernel_spatial_dimensions +
     [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
     * For any i in `kernel_dimensions`, i >= N.
C19: a) size(`output_spatial_dimensions`) != N - 2.
C20: a) Given `output_dimensions = [output_batch_dimension] +
     output_spatial_dimensions + [output_feature_dimension]`:
     * Any dimensions in `output_dimensions` are not unique.
     b) Given `output_dimensions = [output_batch_dimension] +
     output_spatial_dimensions + [output_feature_dimension]`:
     * For any i in `output_dimensions`, i < 0.
     c) Given `output_dimensions = [output_batch_dimension] +
     output_spatial_dimensions + [output_feature_dimension]`:
     * For any i in `output_dimensions`, i >= N.
C21: a) `feature_group_count <= 0`.
C22: a) `batch_group_count <= 0`.
C23: a) `feature_group_count` != 1 and `batch_group_count` != 1.
C24: a) size(`precision_config`) != 2.
C25: a) For result_dim in [0, N):
        `dim(result, result_dim)` != `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
     b) For result_dim in [0, N):
        `dim(result, result_dim)` != `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
     c) For result_dim in [0, N):
        `dim(result, result_dim)` != `num_windows` otherwise, where:
       * `output_spatial_dimensions[spatial_dim] = result_dim`.
       * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
       * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
       * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
       * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
       * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
       * `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
C26: a) rank(result) != N.
C27: a) element_type(`lhs`) != element_type(`rhs`).

If we drop the "Covered by ODS" pieces, this will leave us with the following test cases:

I3a: `window_strides` is not a 1-dimensional tensor.
I4a: `padding` is not a 2-dimensional tensor.
I5a: `lhs_dilation` is not a 1-dimensional tensor.
I6a: `rhs_dilation` is not a 1-dimensional tensor.
I7a: `window_reversal` is not a 1-dimensional tensor.
C1a: rank(`lhs`) != rank(`rhs`) != N.
C2a: size(`window_strides`) != N - 2.
C3a: `window_strides[i]` <= 0 for any i in [0, size(`window_strides`)).
C4a: dim(`padding`, 0) != N - 2.
C4b: dim(`padding`, 1) != 2.
C5a: size(`lhs_dilation`) != N - 2.
C6a: `lhs_dilation[i]` <= 0 for any i in [0, size(`lhs_dilation`)).
C7a: size(`rhs_dilation`) != N - 2.
C8a: `rhs_dilation[i]` <= 0 for any i in [0, size(`rhs_dilation`)).
C9a: size(`window_reversal`) != N - 2.
C10a: `dim(lhs, input_batch_dimension) % batch_group_count != 0`.
C11a: `dim(lhs, input_feature_dimension) % feature_group_count != 0`.
C12a: size(`input_spatial_dimensions`) != N - 2.
C13a: Given `input_dimensions = [input_batch_dimension] +
      input_spatial_dimensions + [input_feature_dimension]`:
      * Any dimensions in `input_dimensions` are not unique.
C13b: Given `input_dimensions = [input_batch_dimension] +
      input_spatial_dimensions + [input_feature_dimension]`:
      * For any i in `input_dimensions`, i < 0.
C13c: Given `input_dimensions = [input_batch_dimension] +
      input_spatial_dimensions + [input_feature_dimension]`:
      * For any i in `input_dimensions`, i >= N.
C14a: `dim(rhs, kernel_input_feature_dimension) != dim(lhs, input_feature_dimension) / feature_group_count`.
C15a: `dim(rhs, kernel_output_feature_dimension) % batch_group_count != 0`.
C16a: `dim(rhs, kernel_output_feature_dimension) % feature_group_count != 0`.
C17a: size(`kernel_spatial_dimensions`) != N - 2.
C18a: Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * Any dimensions in `kernel_dimensions` are not unique.
C18b: Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * For any i in `kernel_dimensions`, i < 0.
C18c: Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * For any i in `kernel_dimensions`, i >= N.
C19a: size(`output_spatial_dimensions`) != N - 2.
C20a: Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
      * Any dimensions in `output_dimensions` are not unique.
C20b: Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
      * For any i in `output_dimensions`, i < 0.
C20c: Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
      * For any i in `output_dimensions`, i >= N.
C21a: `feature_group_count <= 0`.
C22a: `batch_group_count <= 0`.
C23a: `feature_group_count` != 1 and `batch_group_count` != 1.
C24a: size(`precision_config`) != 2.
C25a: For result_dim in [0, N):
      `dim(result, result_dim)` != `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
C25b: For result_dim in [0, N):
      `dim(result, result_dim)` != `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
C25c: For result_dim in [0, N):
      `dim(result, result_dim)` != `num_windows` otherwise, where:
        * `output_spatial_dimensions[spatial_dim] = result_dim`.
        * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
        * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
        * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
        * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
        * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
        * `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
C26a: rank(result) != N.
C27a: element_type(`lhs`) != element_type(`rhs`).

Notes:

closes #970

@ghpvnist ghpvnist added Spec Interpreter Migrate to MHLO PR that needs to be migrated to MLIR-HLO labels Mar 11, 2023
@ghpvnist ghpvnist requested a review from sdasgup3 March 11, 2023 00:47

ghpvnist commented Mar 11, 2023

  1. Currently there is no constraint to check whether rhs dimension size is zero, but we do have a check for it in TypeInference L337. Should we add one? In particular, this test in verify_conv.mlir:
func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
    %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
  // expected-error@+1 {{expects window to have positive value for 0-th window dimension, but got 0.}}
  %0 = stablehlo.convolution(%arg0, %arg1)
         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
         {
           batch_group_count = 1 : i64,
           feature_group_count = 1 : i64,
           precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
       (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32>
  func.return %0 : tensor<1x8x8x16xf32>
}
  2. You may ignore the DotGeneralOp implementation while reviewing this PR.
  3. Also, tests related to preferred element type and some floating-point tests are still disabled.

stablehlo/reference/Ops.cpp
SmallVector<Tensor> results;
for (auto [left, right] : llvm::zip(lhses, rhses)) {
  SmallVector<ShapedTypeComponents> inferredConvolutionType;
  auto convolutionStatus = hlo::inferConvolutionOp(
@sdasgup3 sdasgup3 Mar 14, 2023


With the use of inferConvolutionOp and its interface (`LogicalResult inferConvolutionOp(...)`), I am thinking about using the same interface for evalConvolutionOp.

The main motivation: we are using a lot of boilerplate code to convert attributes from one type to another just to call the infer*Ops. All of this boilerplate can be removed if evalConvolutionOp shares the same interface with inferConv*Op. This will improve readability.

Benefits:

  1. We can get rid of all this wrap/unwrap code for various attributes. No flattenPad.
  2. The call site of evalConvolutionOp in eval, where we have the loop over the ops, would be simplified as well.
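
As a toy sketch of this shared-interface idea (the names and types below are made up for illustration and are not the actual StableHLO interfaces), both entry points could take the window arguments in the same form:

```
// Toy illustration only: made-up names/types, not the real StableHLO
// interfaces. The "infer" and "eval" entry points share one argument bundle,
// so the caller needs no attribute conversion (no flattenPad-style helpers).
#include <cstdint>
#include <iostream>
#include <vector>

struct WindowArgs {  // hypothetical shared argument bundle
  std::vector<int64_t> strides, padLow, padHigh;
};

// "Infer" side: computes output spatial sizes (no-dilation case of (C25)).
std::vector<int64_t> inferSpatialDims(const std::vector<int64_t>& lhsSpatial,
                                      const std::vector<int64_t>& rhsSpatial,
                                      const WindowArgs& w) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < lhsSpatial.size(); ++i)
    out.push_back((w.padLow[i] + lhsSpatial[i] + w.padHigh[i] -
                   rhsSpatial[i]) / w.strides[i] + 1);
  return out;
}

// "Eval" side: takes the window arguments in exactly the same form.
void evalConvolution(const WindowArgs& w) { (void)w; /* body elided */ }

int main() {
  WindowArgs w{{1, 1}, {1, 1}, {1, 1}};
  auto dims = inferSpatialDims({8, 8}, {3, 3}, w);  // same `w` for both calls
  evalConvolution(w);
  std::cout << dims[0] << "x" << dims[1] << "\n";   // prints 8x8
  return 0;
}
```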


What do you think?

Member Author

I think your overall idea makes sense. But the counterargument would be that it comes at the cost of adding boilerplate code to create default parameters under eval, and it would be an exception to the related issue #1031 of unwrapping MLIR-based classes out of inferFooOps. Let's hear from @burmako before I continue with this change.

Member

it would be an exception to the related issue #1031 of unwrapping MLIR-based classes out of inferFooOps

All we need here is a common interface for inferOp* and evalOp*, so as to remove the conversion code. I am perfectly fine with modifying the inferOp* interfaces as we did earlier as well.

Sure, let's hear from Eugene first.

@ghpvnist ghpvnist mentioned this pull request Mar 14, 2023
@ghpvnist ghpvnist requested a review from sdasgup3 March 15, 2023 01:15
@ghpvnist ghpvnist force-pushed the convolution branch 3 times, most recently from 49537c5 to e1fcf8d Compare March 16, 2023 00:11
@ghpvnist ghpvnist requested a review from sdasgup3 March 16, 2023 00:16
stablehlo/reference/Ops.cpp
}
return evalConcatenateOp(results, outputFeatureDimension, result.getType());
}
auto lhsWindowDimensions = concatAndPermute(
Member

As per the spec

* `lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])`.
* `result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])`.

All instances of concatAndPermute are using only the permutation orders. We can define them once, like the following, and pass them onto concatAndPermute:

auto lhsPermutation = inputBatchDimension + inputSpatialDimensions + inputFeatureDimension;

Member Author

auto lhsPermutation = inputBatchDimension + inputSpatialDimensions + inputFeatureDimension;

Since we can't quite use the syntactic sugar mentioned above, I moved the computation of the permutation out of the helper function. This saves some compute and is also closer to the spec.

Member

Yes, for that we would need to add an overloaded operator+ in Axes? Let us get @burmako's opinion on this. Keeping this unresolved for now.

Member Author

Right. The only problem is that the way we would overload operator+ would be different in Axes.h and Sizes.h. One option is to write a concat function, but this does not necessarily make the code as simple as using operator+.
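
For illustration, a minimal sketch of the concat-helper alternative, assuming Axes behaves like a plain vector of int64_t (names here are illustrative, not the actual Axes.h/Sizes.h code):

```
// Illustrative only: one way the spec-like expression
//   lhs_permutation = [input_batch_dimension] + input_spatial_dimensions +
//                     [input_feature_dimension]
// could be written with a small concat helper. `Axes` here is a stand-in
// vector type, not the one from Axes.h.
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <vector>

using Axes = std::vector<int64_t>;

Axes concat(std::initializer_list<Axes> parts) {
  Axes result;
  for (const Axes& part : parts)
    result.insert(result.end(), part.begin(), part.end());
  return result;
}

int main() {
  int64_t inputBatchDimension = 0, inputFeatureDimension = 3;
  Axes inputSpatialDimensions = {1, 2};
  // Mirrors the spec: [b] + spatial + [f].
  Axes lhsPermutation = concat(
      {{inputBatchDimension}, inputSpatialDimensions, {inputFeatureDimension}});
  for (int64_t d : lhsPermutation) std::cout << d << ' ';  // prints: 0 1 2 3
  std::cout << '\n';
  return 0;
}
```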

@sdasgup3

Finished adding my remaining set of comments! Sorry for taking a bit longer with the review. It took me some time to get the overall perspective. Also, I can see how much attention and hard work you have put in to get the implementation working. Thanks @ghpvnist!

@sdasgup3 sdasgup3 assigned ghpvnist and unassigned sdasgup3 Mar 25, 2023
@ghpvnist ghpvnist requested a review from burmako March 27, 2023 17:59
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist Mar 27, 2023
@ghpvnist

Given that this change is too large to review in one go, I've split the implementation out of this PR into #1964. The remaining changes from this PR will be addressed once the implementation is merged, with #2092 tracking this.

@ghpvnist ghpvnist closed this Mar 14, 2024
ghpvnist added a commit that referenced this pull request Apr 4, 2024
Since the interpreter implementation is already a mouthful, I've split
#1314 to separate the implementation from the remaining checklist items.
Once this is merged, I'll work on the remaining items to update the rest of
the docs.

closes #970
ghpvnist added a commit that referenced this pull request Apr 18, 2024
This is part 2 of #1964 to implement parts of #1314.

Also fixes a bug in generating the invalid padding shape (it used to generate
`N-2` dims of size 2 instead of shape `[N-2, 2]`).
ghpvnist added a commit that referenced this pull request Apr 23, 2024
This is part 3 of #1964 to implement the remaining parts of #1314.

One notable change in TypeInference.cpp is (C27), whose verification
differs depending on whether the element type is quantized.

We have the following constraints in the spec (excluding
quantization-related constraints C28-C33):

```
(I1) `lhs` tensor.
(I2) `rhs` tensor.
(I3) `window_strides` 1-dimensional tensor constant of type `si64`.
(I4) `padding` 2-dimensional tensor constant of type `si64`.
(I5) `lhs_dilation` 1-dimensional tensor constant of type `si64`.
(I6) `rhs_dilation` 1-dimensional tensor constant of type `si64`.
(I7) `window_reversal` 1-dimensional tensor constant of type `i1`.
(I8) `input_batch_dimension` constant of type `si64`.
(I9) `input_feature_dimension` constant of type `si64`.
(I10) `input_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I11) `kernel_input_feature_dimension` constant of type `si64`.
(I12) `kernel_output_feature_dimension` constant of type `si64`.
(I13) `kernel_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I14) `output_batch_dimension` constant of type `si64`.
(I15) `output_feature_dimension` constant of type `si64`.
(I16) `output_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I17) `feature_group_count` constant of type `si64`.
(I18) `batch_group_count` constant of type `si64`.
(I19) `precision_config` variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`.
(C1) `N = rank(lhs) = rank(rhs)`.
(C2) `size(window_strides) = N - 2`.
(C3) `0 < window_strides`.
(C4) `shape(padding) = [N - 2, 2]`.
(C5) `size(lhs_dilation) = N - 2`.
(C6) `0 < lhs_dilation`.
(C7) `size(rhs_dilation) = N - 2`.
(C8) `0 < rhs_dilation`.
(C9) `size(window_reversal) = N - 2`.
(C10) `dim(lhs, input_batch_dimension) % batch_group_count = 0`.
(C11) `dim(lhs, input_feature_dimension) % feature_group_count = 0`.
(C12) `size(input_spatial_dimensions) = N - 2`.
(C13) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
* `is_unique(input_dimensions)`.
* `0 <= input_dimensions < N`.
(C14) `dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count`.
(C15) `dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0`.
(C16) `dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0`.
(C17) `size(kernel_spatial_dimensions) = N - 2`.
(C18) Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
* `is_unique(kernel_dimensions)`.
* `0 <= kernel_dimensions < N`.
(C19) `size(output_spatial_dimensions) = N - 2`.
(C20) Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
* `is_unique(output_dimensions)`.
* `0 <= output_dimensions < N`.
(C21) `0 < feature_group_count`.
(C22) `0 < batch_group_count`.
(C23) `feature_group_count = 1 or batch_group_count = 1`.
(C24) `size(precision_config) = 2`.
(C25) `dim(result, result_dim)` is defined as:
* `dim(lhs, input_batch_dimension) / batch_group_count` if `result_dim = output_batch_dimension`.
* `dim(rhs, kernel_output_feature_dimension)` if `result_dim = output_feature_dimension`.
* `num_windows` otherwise, where:
  * `output_spatial_dimensions[spatial_dim] = result_dim`.
  * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
  * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
  * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
  * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
  * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
  * `is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]`.
  * `num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
(C26) `rank(result) = N`.
(C27) `element_type(lhs) = element_type(rhs) = element_type(result)`.
```

These constraints will be comprehensively covered by the following
tests:

```
I1: a) `lhs` tensor. (Covered by ODS).
I2: a) `rhs` tensor. (Covered by ODS).
I3: a) `window_strides` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`window_strides`) != `si64`. (Covered by ODS).
I4: a) `padding` is not a 2-dimensional tensor.
    b) element_type(`padding`) != `si64`. (Covered by ODS).
I5: a) `lhs_dilation` is not a 1-dimensional tensor. (Covered by ODS).
    b)  element_type(`lhs_dilation`) != `si64`. (Covered by ODS).
I6: a) `rhs_dilation` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`rhs_dilation`) != `si64`. (Covered by ODS).
I7: a) `window_reversal` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`window_reversal`) != `i1`. (Covered by ODS).
I8: a) element_type(`input_batch_dimension`) != `si64`. (Covered by ODS).
I9: a) element_type(`input_feature_dimension`) != `si64`. (Covered by ODS).
I10: a) `input_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`input_spatial_dimensions`) != `si64`. (Covered by ODS).
I11: a) element_type(`kernel_input_feature_dimension`) != `si64`. (Covered by ODS).
I12: a) element_type(`kernel_output_feature_dimension`) != `si64`. (Covered by ODS).
I13: a) `kernel_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`kernel_spatial_dimensions`) != `si64`. (Covered by ODS).
I14: a) element_type(`output_batch_dimension`) != `si64`. (Covered by ODS).
I15: a) element_type(`output_feature_dimension`) != `si64`. (Covered by ODS).
I16: a) `output_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`output_spatial_dimensions`) != `si64`. (Covered by ODS).
I17: a) element_type(`feature_group_count`) != `si64`. (Covered by ODS).
I18: a) element_type(`batch_group_count`) != `si64`. (Covered by ODS).
I19: a) `precision_config` does not have variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`. (Covered by ODS).
C1: a) N = rank(`lhs`) != rank(`rhs`).
C2: a) size(`window_strides`) != N - 2.
C3: a) `window_strides[i]` <= 0 for any i in [0, size(`window_strides`)).
C4: a) dim(`padding`, 0) != N - 2.
    b) dim(`padding`, 1) != 2.
C5: a) size(`lhs_dilation`) != N - 2.
C6: a) `lhs_dilation[i]` <= 0 for any i in [0, size(`lhs_dilation`)).
C7: a) size(`rhs_dilation`) != N - 2.
C8: a) `rhs_dilation[i]` <= 0 for any i in [0, size(`rhs_dilation`)).
C9: a) size(`window_reversal`) != N - 2.
C10: a) `dim(lhs, input_batch_dimension) % batch_group_count != 0`.
C11: a) `dim(lhs, input_feature_dimension) % feature_group_count != 0`.
C12: a) size(`input_spatial_dimensions`) != N - 2.
C13: a) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
     * Any dimensions in `input_dimensions` are not unique.
     b) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
     * For any i in `input_dimensions`, i < 0.
     c) Given `input_dimensions = [input_batch_dimension] +
     input_spatial_dimensions + [input_feature_dimension]`:
     * For any i in `input_dimensions`, i >= N.
C14: a) `dim(rhs, kernel_input_feature_dimension) != dim(lhs, input_feature_dimension) / feature_group_count`.
C15: a) `dim(rhs, kernel_output_feature_dimension) % batch_group_count != 0`.
C16: a) `dim(rhs, kernel_output_feature_dimension) % feature_group_count != 0`.
C17: a) size(`kernel_spatial_dimensions`) != N - 2.
C18: a) Given `kernel_dimensions = kernel_spatial_dimensions +
     [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
     * Any dimensions in `kernel_dimensions` are not unique.
     b) Given `kernel_dimensions = kernel_spatial_dimensions +
     [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
     * For any i in `kernel_dimensions`, i < 0.
     c) Given `kernel_dimensions = kernel_spatial_dimensions +
     [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
     * For any i in `kernel_dimensions`, i >= N.
C19: a) size(`output_spatial_dimensions`) != N - 2.
C20: a) Given `output_dimensions = [output_batch_dimension] +
     output_spatial_dimensions + [output_feature_dimension]`:
     * Any dimensions in `output_dimensions` are not unique.
     b) Given `output_dimensions = [output_batch_dimension] +
     output_spatial_dimensions + [output_feature_dimension]`:
     * For any i in `output_dimensions`, i < 0.
     c) Given `output_dimensions = [output_batch_dimension] +
     output_spatial_dimensions + [output_feature_dimension]`:
     * For any i in `output_dimensions`, i >= N.
C21: a) `feature_group_count <= 0`.
C22: a) `batch_group_count <= 0`.
C23: a) `feature_group_count` != 1 and `batch_group_count` != 1.
C24: a) size(`precision_config`) != 2.
C25: a) For result_dim in [0, N):
        `dim(result, result_dim)` != `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
     b) For result_dim in [0, N):
        `dim(result, result_dim)` != `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
     c) For result_dim in [0, N):
        `dim(result, result_dim)` != `num_windows` otherwise, where:
       * `output_spatial_dimensions[spatial_dim] = result_dim`.
       * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
       * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
       * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
       * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
       * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
       * `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
C26: a) rank(result) != N.
C27: a) element_type(`lhs`) != element_type(`rhs`).
```

If we drop the "Covered by ODS" pieces, this will leave us with the
following test cases:

```
I4a: `padding` is not a 2-dimensional tensor.
C1a: rank(`lhs`) != rank(`rhs`) != N.
C2a: size(`window_strides`) != N - 2.
C3a: `window_strides[i]` <= 0 for any i in [0, size(`window_strides`)).
C4a: dim(`padding`, 0) != N - 2.
C4b: dim(`padding`, 1) != 2.
C5a: size(`lhs_dilation`) != N - 2.
C6a: `lhs_dilation[i]` <= 0 for any i in [0, size(`lhs_dilation`)).
C7a: size(`rhs_dilation`) != N - 2.
C8a: `rhs_dilation[i]` <= 0 for any i in [0, size(`rhs_dilation`)).
C9a: size(`window_reversal`) != N - 2.
C10a: `dim(lhs, input_batch_dimension) % batch_group_count != 0`.
C11a: `dim(lhs, input_feature_dimension) % feature_group_count != 0`.
C12a: size(`input_spatial_dimensions`) != N - 2.
C13a: Given `input_dimensions = [input_batch_dimension] +
      input_spatial_dimensions + [input_feature_dimension]`:
      * Any dimensions in `input_dimensions` are not unique.
C13b: Given `input_dimensions = [input_batch_dimension] +
      input_spatial_dimensions + [input_feature_dimension]`:
      * For any i in `input_dimensions`, i < 0.
C13c: Given `input_dimensions = [input_batch_dimension] +
      input_spatial_dimensions + [input_feature_dimension]`:
      * For any i in `input_dimensions`, i >= N.
C14a: `dim(rhs, kernel_input_feature_dimension) != dim(lhs, input_feature_dimension) / feature_group_count`.
C15a: `dim(rhs, kernel_output_feature_dimension) % batch_group_count != 0`.
C16a: `dim(rhs, kernel_output_feature_dimension) % feature_group_count != 0`.
C17a: size(`kernel_spatial_dimensions`) != N - 2.
C18a: Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * Any dimensions in `kernel_dimensions` are not unique.
C18b: Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * For any i in `kernel_dimensions`, i < 0.
C18c: Given `kernel_dimensions = kernel_spatial_dimensions +
      [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * For any i in `kernel_dimensions`, i >= N.
C19a: size(`output_spatial_dimensions`) != N - 2.
C20a: Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
      * Any dimensions in `output_dimensions` are not unique.
C20b: Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
      * For any i in `output_dimensions`, i < 0.
C20c: Given `output_dimensions = [output_batch_dimension] +
      output_spatial_dimensions + [output_feature_dimension]`:
      * For any i in `output_dimensions`, i >= N.
C21a: `feature_group_count <= 0`.
C22a: `batch_group_count <= 0`.
C23a: `feature_group_count` != 1 and `batch_group_count` != 1.
C24a: size(`precision_config`) != 2.
C25a: For result_dim in [0, N):
      `dim(result, result_dim)` != `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
C25b: For result_dim in [0, N):
      `dim(result, result_dim)` != `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
C25c: For result_dim in [0, N):
      `dim(result, result_dim)` != `num_windows` otherwise, where:
        * `output_spatial_dimensions[spatial_dim] = result_dim`.
        * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
        * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
        * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
        * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
        * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
        * `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
C26a: rank(result) != N.
C27a: element_type(`lhs`) != element_type(`rhs`).
```

Notes:
* (new C24) is left untouched as there is still a pending action item
regarding the number of precision config values allowed in #879.

closes #2092
@ghpvnist ghpvnist deleted the convolution branch May 14, 2024 02:06