
Commit

Make pass more restricted for correctness
Groverkss committed Aug 22, 2023
1 parent 0e075cf commit 663466b
Showing 2 changed files with 57 additions and 18 deletions.
@@ -87,17 +87,27 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
   if (!isElementwise(linalgOp)) {
     return failure();
   }
+  if (!llvm::hasSingleElement(linalgOp.getResults())) {
+    return failure();
+  }
 
   // Find a tensor.extract op in the linalgOp body.
   auto extractOps = linalgOp.getBody()->getOps<tensor::ExtractOp>();
   if (!llvm::hasSingleElement(extractOps)) {
     return failure();
   }
   tensor::ExtractOp extractOp = *extractOps.begin();
+  auto resultType = dyn_cast<TensorType>(linalgOp.getResult(0).getType());
+  if (!resultType) {
+    return failure();
+  }
+
+  ArrayRef<int64_t> sourceShape = extractOp.getTensor().getType().getShape();
+  ArrayRef<int64_t> resultShape = resultType.getShape();
 
   // Raise the tensor.extract op to an input.
   SmallVector<AffineExpr> exprs;
-  for (Value indexValue : extractOp.getIndices()) {
+  for (auto [idx, indexValue] : llvm::enumerate(extractOp.getIndices())) {
     // For raising, the indexing value must be one of the following:
     // 1. A constant value.
     // 2. A linalg.index.
@@ -111,6 +121,18 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
     }
     // 2. The indexing value is a linalg.index.
    if (auto indexOp = indexValue.getDefiningOp<linalg::IndexOp>()) {
+      // Make sure that for this index, the sizes of the input and output
+      // match and are not dynamic. We need this to keep the op
+      // elementwise.
+      // TODO: This restriction can be relaxed by adding an extract_slice
+      // op on the `source` tensor. This is not the same as raising the
+      // whole operation to an extract_slice, as there can be permutations
+      // and projections involved.
+      if (sourceShape[idx] == ShapedType::kDynamic ||
+          resultShape[indexOp.getDim()] == ShapedType::kDynamic ||
+          sourceShape[idx] != resultShape[indexOp.getDim()]) {
+        return failure();
+      }
       exprs.push_back(
           getAffineDimExpr(indexOp.getDim(), rewriter.getContext()));
       continue;
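For intuition, the raising this pass performs looks roughly like the sketch below (function and value names are illustrative, assuming static, matching sizes): a tensor.extract whose indices are linalg.index values becomes an explicit input with a matching indexing map, and the body reduces to a plain yield. This is the same shape of rewrite the FileCheck tests below verify.

// Before raising (illustrative):
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @before(%A : tensor<64x64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %e = tensor.extract %A[%i, %j] : tensor<64x64xf32>
    linalg.yield %e : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// After raising (illustrative): the extract is now an identity-mapped input.
func.func @after(%A : tensor<64x64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%A : tensor<64x64xf32>) outs(%B : tensor<64x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}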
@@ -186,23 +186,6 @@ func.func @aTransposeBMatmul(%arg0 : tensor<10x20xf32>,
 // CHECK: %[[RESULT:.+]] = linalg.matmul_transpose_b
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK: return %[[RESULT]]
-// -----
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @test(%arg0: tensor<1x1x?x?xf32>, %arg1: tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
-  %c0 = arith.constant 0 : index
-  // CHECK: linalg.generic
-  // CHECK: (%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-  // CHECK: linalg.yield %[[IN]] : f32
-  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x?x?xf32>) {
-  ^bb0(%out: f32):
-    %1 = linalg.index 2 : index
-    %2 = linalg.index 3 : index
-    %extracted = tensor.extract %arg0[%c0, %c0, %1, %2] : tensor<1x1x?x?xf32>
-    linalg.yield %extracted : f32
-  } -> tensor<1x1x?x?xf32>
-  return %0 : tensor<1x1x?x?xf32>
-}
 
 // -----

@@ -218,3 +201,37 @@ func.func @test(%A : tensor<1x1x5120xf32>, %B : tensor<5120xf32>) -> tensor<5120xf32> {
   } -> tensor<5120xf32>
   return %0 : tensor<5120xf32>
 }
+
+// -----
+
+// This currently should not be raised, as the operation does not remain
+// elementwise after raising the tensor.extract to an input.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @test(%A : tensor<128x128x128xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: linalg.generic
+  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    %i1 = linalg.index 0 : index
+    %i2 = linalg.index 1 : index
+    %extracted = tensor.extract %A[%i1, %c0, %i2] : tensor<128x128x128xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
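To see why, consider the indexing map a raised input would hypothetically need here (an assumption for illustration; the pass simply returns failure rather than producing this):

// Hypothetical map for ins(%A : tensor<128x128x128xf32>) -- not produced:
#in = affine_map<(d0, d1) -> (d0, 0, d1)>
// d0 and d1 would address dimensions of size 128 over a 64x64 iteration
// space, so the input and output sizes along each dimension no longer
// match and the op would stop being elementwise.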

+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @test(%A : tensor<64x64x64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    %i1 = linalg.index 0 : index
+    %i2 = linalg.index 1 : index
+    // CHECK: tensor.extract_slice
+    %extracted = tensor.extract %A[%i1, %c0, %i2] : tensor<64x64x64xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
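The CHECK line above only pins the tensor.extract_slice. The raised form should look roughly like the following sketch (offsets, sizes, maps, and SSA names are assumptions, not verbatim pass output): the constant index %c0 folds into a rank-reducing extract_slice, and the remaining linalg.index dimensions become an identity-mapped input.

// Sketch of the expected raised form (illustrative):
%slice = tensor.extract_slice %A[0, 0, 0] [64, 1, 64] [1, 1, 1]
    : tensor<64x64x64xf32> to tensor<64x64xf32>
%raised = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
    ins(%slice : tensor<64x64xf32>) outs(%B : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32
} -> tensor<64x64xf32>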
