From b24a6e2f8d114aae22eb1202ea4cfaa96b6c07e8 Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Thu, 24 Aug 2023 17:54:48 -0700
Subject: [PATCH] [Flow] Raise special `linalg.generic` ops to `linalg.fill` ops (#14773)

---
 .../Flow/Transforms/RaiseSpecialOps.cpp    | 42 +++++++++++++++++++
 .../Transforms/test/raise_special_ops.mlir | 26 ++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index 6fcddd660c31..8d4c4d9cf681 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -76,6 +76,34 @@ std::optional<Value> matchATransposeBMatmul(linalg::LinalgOp matmulOp) {
   return std::nullopt;
 }
 
+// Method to match a linalg.generic op representing a linalg.fill op. Returns
+// the fill value (input operand to linalg.fill) on success.
+std::optional<Value> matchGenericFill(linalg::LinalgOp linalgOp) {
+  if (isa<linalg::GenericOp>(linalgOp.getOperation()) &&
+      linalgOp.getNumDpsInputs() == 0 && linalgOp.getNumDpsInits() == 1 &&
+      linalgOp.getNumParallelLoops() == linalgOp.getNumLoops() &&
+      linalgOp.getIndexingMapsArray()[0].isIdentity()) {
+    // Check that the op body is only a linalg.yield op.
+    Value yieldOperand;
+    for (Operation &bodyOp : linalgOp.getBlock()->getOperations()) {
+      if (isa<linalg::YieldOp>(bodyOp)) {
+        yieldOperand = bodyOp.getOperand(0);
+      } else {
+        return std::nullopt;
+      }
+    }
+    // Check that the operand of the linalg.yield op is not an argument of the
+    // linalg.generic basic block.
+    for (Value blockArg : linalgOp.getBlock()->getArguments()) {
+      if (yieldOperand == blockArg) {
+        return std::nullopt;
+      }
+    }
+    return yieldOperand;
+  }
+  return std::nullopt;
+}
+
 /// Matches a linalg.generic operation reading data from a tensor `source` using
 /// tensor.extract, and raises the `source` tensor to an input of the linalg
 /// operation.
@@ -333,6 +361,7 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
     SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
     SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
+    SmallVector<std::pair<linalg::GenericOp, Value>> genericFills;
     getOperation()->walk([&](linalg::LinalgOp op) {
       {
         transform_ext::MatcherContext matcherContext;
@@ -347,6 +376,10 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
           transposeMatmulRoots.push_back(std::make_pair(
               cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
         }
+        if (std::optional<Value> fillInput = matchGenericFill(op)) {
+          genericFills.push_back(
+              std::make_pair(cast<linalg::GenericOp>(op), fillInput.value()));
+        }
       }
     });
@@ -369,6 +402,15 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
       rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
           matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
     }
+    for (std::pair<linalg::GenericOp, Value> genericFill : genericFills) {
+      auto genericOp = genericFill.first;
+      Value fillInput = genericFill.second;
+      Value init = genericOp.getDpsInitOperand(0)->get();
+      rewriter.setInsertionPoint(genericOp);
+      SmallVector<NamedAttribute> attrs = getPrunedAttributeList(genericOp);
+      rewriter.replaceOpWithNewOp<linalg::FillOp>(
+          genericOp, ValueRange{fillInput}, ValueRange{init}, attrs);
+    }
   }
 };
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index 49389b61258a..54835d52f745 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -187,6 +187,32 @@ func.func @aTransposeBMatmul(%arg0 : tensor<10x20xf32>,
 // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] :
 // CHECK:        return %[[RESULT]]
 
+func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %0 = tensor.empty(%dim, %dim_0) : tensor<1x1x?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      outs(%0 : tensor<1x1x?x?xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<1x1x?x?xf32>
+  return %1 : tensor<1x1x?x?xf32>
+}
+// CHECK-LABEL: func @generic_fill
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty
+// CHECK-SAME:      : tensor<1x1x?x?xf32>
+// CHECK:         %[[RESULT:.+]] = linalg.fill
+// CHECK-SAME:      ins(%[[CST]] : f32)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x1x?x?xf32>)
+// CHECK:         return %[[RESULT]]
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>
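
For reference, this is roughly the IR the new `@generic_fill` test expects after raising, reconstructed here from the CHECK lines above; the SSA value names are illustrative and the `tensor<?x?xf32>` argument type mirrors the input function rather than actual pass output:

  func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %empty = tensor.empty(%dim, %dim_0) : tensor<1x1x?x?xf32>
    // The zero-input, all-parallel linalg.generic that only yields %cst is
    // rewritten into a linalg.fill over the same init tensor.
    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
    return %fill : tensor<1x1x?x?xf32>
  }

To try the pattern locally, an invocation along the lines of `iree-opt --split-input-file --iree-flow-raise-special-ops raise_special_ops.mlir` should apply it, assuming the pass is still exposed under the `iree-flow-raise-special-ops` flag; the test's RUN line is outside this hunk, so the exact command is an assumption.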