Skip to content

Commit

Permalink
[GlobalOpt] Prevent fusing transposed extend in RaiseSpecialOps (#18901)
Browse files Browse the repository at this point in the history
NamedImplicitCastOpConversion pattern is incorrectly fusing transposed
element-wise extend into Linalg op.

---------

Signed-off-by: Cullen Rhodes <cullen.rhodes@arm.com>
  • Loading branch information
c-rhodes authored Oct 30, 2024
1 parent 5fc340d commit 15ea0dc
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern<OpTy> {
return false;
}

if (!llvm::all_of(producer.getIndexingMapsArray(),
[](AffineMap map) { return map.isIdentity(); }))
return false;

std::optional<CastOpInterface> castOp =
getDefiningNonI1ExtendingCastOp(operand.get());
if (!castOp) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,33 @@ util.func public @matmul_extsi(%arg0 : tensor<10x20xi32>,
// CHECK: util.return %[[RESULT]]
// -----

// Regression test. extsi is transposed, dont't fuse into matmul.
util.func public @matmul_extsi_transposed(%arg0 : tensor<10x20xi32>,
%arg1 : tensor<40x20xi16>) -> tensor<10x40xi32> {
%0 = tensor.empty() : tensor<20x40xi32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg1 : tensor<40x20xi16>) outs(%0 : tensor<20x40xi32>) {
^bb0(%b0 : i16, %b1 : i32):
%e = arith.extsi %b0 : i16 to i32
linalg.yield %e : i32
} -> tensor<20x40xi32>
%2 = tensor.empty() : tensor<10x40xi32>
%3 = arith.constant 0 : i32
%4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32>
%5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xi32>, tensor<20x40xi32>)
outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32>
util.return %5 : tensor<10x40xi32>
}
// CHECK-LABEL: util.func public @matmul_extsi_transposed
// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<40x20xi16>
// CHECK: %[[GEN:.+]] = linalg.generic
// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[GEN]]
// CHECK: util.return %[[RESULT]]
// -----

util.func public @matmul_extsi_a(%arg0 : tensor<10x20xi16>,
%arg1 : tensor<20x40xi32>) -> tensor<10x40xi32> {
%0 = tensor.empty() : tensor<10x20xi32>
Expand Down

0 comments on commit 15ea0dc

Please sign in to comment.