From 5f3f863ca26fef5331420ba17df21730b14e4dbe Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Wed, 9 Oct 2024 21:40:25 -0700
Subject: [PATCH] Add `AggregatedOpInterface` to `iree_linalg_ext.custom_op`
 (#18700)

This allows decomposing the op, i.e. inlining the region of the
`custom_op` into the parent region, and removal of the op.

Also add a TD op for testing.

Signed-off-by: MaheshRavishankar
---
 .../AggregatedOpInterfaceImpl.cpp             | 59 ++++++++++++++++++
 .../compiler/Dialect/LinalgExt/IR/BUILD.bazel |  1 +
 .../Dialect/LinalgExt/IR/CMakeLists.txt       |  1 +
 .../Dialect/LinalgExt/IR/LinalgExtOps.cpp     | 12 ++--
 .../Dialect/LinalgExt/IR/LinalgExtOps.td      |  2 +
 .../LinalgExt/IR/TilingInterfaceImpl.cpp      |  6 ++
 .../LinalgExtExtensionsOps.cpp                | 14 +++++
 .../LinalgExtExtensionsOps.td                 | 38 ++++++++++++
 .../Dialect/LinalgExt/Transforms/BUILD.bazel  |  1 -
 .../LinalgExt/Transforms/CMakeLists.txt       |  1 -
 .../LinalgExt/Transforms/test/BUILD.bazel     |  1 +
 .../LinalgExt/Transforms/test/CMakeLists.txt  |  1 +
 .../test/decompose_aggregate_op.mlir          | 62 +++++++++++++++++++
 13 files changed, 192 insertions(+), 7 deletions(-)
 rename compiler/src/iree/compiler/Dialect/LinalgExt/{Transforms => IR}/AggregatedOpInterfaceImpl.cpp (93%)
 create mode 100644 compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
similarity index 93%
rename from compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
rename to compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 5f886c71882d..389eeadb75a9 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -726,4 +726,63 @@ FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>({loopNest.results[0]});
 }

+//===----------------------------------------------------------------------===//
+// CustomOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>> CustomOp::decomposeOperation(OpBuilder &builder) {
+  CustomOp customOp = *this;
+
+  IRRewriter rewriter(builder);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(customOp);
+  // Inline the body of the operation using the ins/outs as the arguments.
+  SmallVector<Value> argReplacements;
+  Location loc = getLoc();
+  Block *body = customOp.getBody();
+  for (auto [operand, argument] :
+       llvm::zip_equal(customOp->getOperands(), body->getArguments())) {
+    if (operand.getType() != argument.getType()) {
+      assert(isa<RankedTensorType>(operand.getType()) &&
+             isa<RankedTensorType>(argument.getType()) &&
+             "expected operand and arguments to be `RankedTensorType`");
+      Value cast =
+          builder.create<tensor::CastOp>(loc, argument.getType(), operand);
+      argReplacements.push_back(cast);
+    } else {
+      argReplacements.push_back(operand);
+    }
+  }
+
+  Block *oldBlock = customOp->getBlock();
+  Block *newBlock = rewriter.splitBlock(oldBlock, Block::iterator(customOp));
+  rewriter.mergeBlocks(body, oldBlock, argReplacements);
+
+  // Get the operands of the `iree_linalg_ext.yield` which is the terminator of
+  // `oldBlock` right now.
+  auto yieldOp = cast<IREE::LinalgExt::YieldOp>(oldBlock->getTerminator());
+  rewriter.setInsertionPointToEnd(oldBlock);
+  SmallVector<Value> customOpReplacements;
+  for (auto [yieldedVal, result] :
+       llvm::zip_equal(yieldOp->getOperands(), customOp->getResults())) {
+    if (yieldedVal.getType() != result.getType()) {
+      assert(isa<RankedTensorType>(yieldedVal.getType()) &&
+             isa<RankedTensorType>(result.getType()) &&
+             "expected yielded value and result to be `RankedTensorType`");
+      Value cast =
+          builder.create<tensor::CastOp>(loc, result.getType(), yieldedVal);
+      customOpReplacements.push_back(cast);
+    } else {
+      customOpReplacements.push_back(yieldedVal);
+    }
+  }
+  // Erase the yield op.
+  rewriter.eraseOp(yieldOp);
+
+  // Merge the block back.
+  rewriter.mergeBlocks(newBlock, oldBlock);
+
+  return customOpReplacements;
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
index e0ebdb2313d8..f9703a46d26d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -43,6 +43,7 @@ iree_td_library(
 iree_compiler_cc_library(
     name = "IR",
     srcs = [
+        "AggregatedOpInterfaceImpl.cpp",
         "LinalgExtAttrs.cpp.inc",
         "LinalgExtDialect.cpp",
         "LinalgExtDialect.cpp.inc",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index 4f1d5231742a..68e90eb2a3fe 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -23,6 +23,7 @@ iree_cc_library(
     "LinalgExtOps.h.inc"
     "LinalgExtTypes.h.inc"
   SRCS
+    "AggregatedOpInterfaceImpl.cpp"
     "LinalgExtAttrs.cpp.inc"
     "LinalgExtDialect.cpp"
     "LinalgExtDialect.cpp.inc"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index dafb17fbca45..bbbc32a3c5e0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1856,6 +1856,8 @@ LogicalResult CustomOp::verify() {
   return success();
 }

+/// Start `LinalgFusionInterface` implementation.
+
 SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
   return llvm::map_to_vector(
       getIndexingMaps().getValue().take_front(getNumDpsInputs()),
@@ -1868,11 +1870,9 @@ SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
       [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
 }

-SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
-  return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
-    return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
-  });
-}
+/// End `LinalgFusionInterface` implementation
+
+/// Start `ReifyRankedShapedTypeOpInterface` implementation

 LogicalResult
 CustomOp::reifyResultShapes(OpBuilder &builder,
@@ -1885,6 +1885,8 @@ CustomOp::reifyResultShapes(OpBuilder &builder,
   return success();
 }

+/// End `ReifyRankedShapedTypeOpInterface` implementation
+
 //===---------------------------------------------------------------------===//
 // IndexOp
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index e6aab962d10d..ef806e5f6b0d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1623,6 +1623,8 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
 //===---------------------------------------------------------------------===//

 def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
+    DeclareOpInterfaceMethods<AggregatedOpInterface,
+        ["decomposeOperation"]>,
     DeclareOpInterfaceMethods<LinalgFusionInterface>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     DeclareOpInterfaceMethods<TilingInterface>,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
+  return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
+    return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
+  });
+}
+
 /// Method similar to `LinalgOp`s that concatenates shapes of all operands.
 static SmallVector<OpFoldResult>
 createFlatListOfOperandDims(OpBuilder &builder, Location loc,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index 923b30a8da73..b350053b394e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -51,6 +51,20 @@ DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }

+DiagnosedSilenceableFailure LinalgExt::DecomposeAggregateOp::applyToOne(
+    transform::TransformRewriter &rewriter,
+    linalg::AggregatedOpInterface aggregateOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  FailureOr<SmallVector<Value>> replacements =
+      aggregateOp.decomposeOperation(rewriter);
+  if (failed(replacements)) {
+    return emitDefiniteFailure() << "failed to decompose operation";
+  }
+  rewriter.replaceOp(aggregateOp, replacements.value());
+  return DiagnosedSilenceableFailure::success();
+}
+
 DiagnosedSilenceableFailure LinalgExt::ConvertToOnlineAttention::applyToOne(
     transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
     transform::ApplyToEachResultList &results,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
index 84e588af9621..9e761c0843f5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
@@ -86,6 +86,44 @@ def DecomposeTiledAttentionOp : Op<
   }];
 }

+def DecomposeAggregateOp : Op<Transform_Dialect, "iree.decompose_aggregate_op",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait]> {
+  let description = [{
+    Target ops that implement `AggregatedOpInterface` and decompose them.
+    This transform consumes the target handle and produces a result handle.
+ }]; + + let arguments = ( + ins TransformHandleTypeInterface:$target, + OptionalAttr:$tile_size + ); + let results = (outs Variadic:$result); + + let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)"; + let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let assemblyFormat = [{ + $target attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::AggregatedOpInterface target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + def ConvertToOnlineAttention : Op, + %rhs1 : tensor, %rhs2 : tensor, %scalar : f32, + %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>) + -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) { + %0:2 = iree_linalg_ext.custom_op { + indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>, + affine_map<(d0, d1)[s0, s1] -> (s0, s1)>, + affine_map<(d0, d1)[s0, s1] -> (s1, d1)>, + affine_map<(d0, d1)[s0, s1] -> ()>, + affine_map<(d0, d1)[s0, s1] -> (d0, s1)>, + affine_map<(d0, d1)[s0, s1] -> (d0, d1)>], + iterator_types = [#iree_linalg_ext.iterator_type, + #iree_linalg_ext.iterator_type]} + ins(%lhs1, %rhs1, %rhs2, %scalar + : tensor<1000000x?xf32>, tensor, tensor, f32) + outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) { + ^bb0(%t0 : tensor, %t1 : tensor, %t2 : tensor, + %s : f32, %t3 : tensor, %t4 : tensor) : + %0 = linalg.matmul ins(%t0, %t1 : tensor, tensor) + outs(%t3 : tensor) -> tensor + %1 = linalg.matmul ins(%0, %t2 : tensor, tensor) + outs(%t4 : tensor) -> tensor + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%1, %s : tensor, f32) outs(%1 : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 :f32): + %3 = arith.addf %b0, %b2 : f32 + linalg.yield %3 : f32 + } -> tensor + iree_linalg_ext.yield %0, %2 : tensor, tensor + } -> tensor<1000000x?xf32>, tensor<1000000x?xf32> + return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32> +} +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} +// CHECK-LABEL: func @custom_op_decomposition( +// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK: %[[MATMUL1:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] : +// CHECK-SAME: outs(%[[INIT1]] : +// CHECK: %[[MATMUL2:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] : +// CHECK-SAME: outs(%[[INIT2]] : +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] : +// CHECK-SAME: outs(%[[MATMUL2]] : +// CHECK: return %[[MATMUL1]], %[[GENERIC]]