Add AggregatedOpInterface to iree_linalg_ext.custom_op (#18700)
This allows decomposing the op, i.e. inlining the region of the
`custom_op` into the parent region and then removing the op.

Also add a TD op for testing.
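
For illustration, a minimal before/after sketch of what the decomposition
does (hypothetical IR, simpler than the lit test added in this commit):

  %0 = iree_linalg_ext.custom_op {
      indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
                       affine_map<(d0, d1)[s0] -> (s0, d1)>,
                       affine_map<(d0, d1)[s0] -> (d0, d1)>],
      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
                        #iree_linalg_ext.iterator_type<parallel>]}
      ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>):
      %1 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
                         outs(%t2 : tensor<?x?xf32>) -> tensor<?x?xf32>
      iree_linalg_ext.yield %1 : tensor<?x?xf32>
  } -> tensor<?x?xf32>

becomes, after decomposition, the inlined body with the op removed:

  %0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>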

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar authored Oct 10, 2024
1 parent 598a60e commit 5f3f863
Showing 13 changed files with 192 additions and 7 deletions.
@@ -726,4 +726,63 @@ FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>({loopNest.results[0]});
}

//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//

FailureOr<SmallVector<Value>> CustomOp::decomposeOperation(OpBuilder &builder) {
CustomOp customOp = *this;

IRRewriter rewriter(builder);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(customOp);
// Inline the body of the operation using the ins/outs as the arguments.
SmallVector<Value> argReplacements;
Location loc = getLoc();
Block *body = customOp.getBody();
for (auto [operand, argument] :
llvm::zip_equal(customOp->getOperands(), body->getArguments())) {
if (operand.getType() != argument.getType()) {
assert(isa<RankedTensorType>(operand.getType()) &&
isa<RankedTensorType>(argument.getType()) &&
"expected operand and arguments to be `RankedTensorType`");
Value cast =
builder.create<tensor::CastOp>(loc, argument.getType(), operand);
argReplacements.push_back(cast);
} else {
argReplacements.push_back(operand);
}
}

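// Split the parent block at the custom op so that everything from the op
// onward moves to `newBlock`, then splice the op's body in at the split point.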
Block *oldBlock = customOp->getBlock();
Block *newBlock = rewriter.splitBlock(oldBlock, Block::iterator(customOp));
rewriter.mergeBlocks(body, oldBlock, argReplacements);

// Get the operands of the `iree_linalg_ext.yield` which is the terminator of
// `oldBlock` right now.
auto yieldOp = cast<IREE::LinalgExt::YieldOp>(oldBlock->getTerminator());
rewriter.setInsertionPointToEnd(oldBlock);
SmallVector<Value> customOpReplacements;
for (auto [yieldedVal, result] :
llvm::zip_equal(yieldOp->getOperands(), customOp->getResults())) {
if (yieldedVal.getType() != result.getType()) {
assert(isa<RankedTensorType>(yieldedVal.getType()) &&
isa<RankedTensorType>(result.getType()) &&
"expected yielded value and result to be `RankedTensorType`");
Value cast =
builder.create<tensor::CastOp>(loc, result.getType(), yieldedVal);
customOpReplacements.push_back(cast);
} else {
customOpReplacements.push_back(yieldedVal);
}
}
// Erase the yield op.
rewriter.eraseOp(yieldOp);

// Merge the block back.
rewriter.mergeBlocks(newBlock, oldBlock);

return customOpReplacements;
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
@@ -43,6 +43,7 @@ iree_td_library(
iree_compiler_cc_library(
name = "IR",
srcs = [
"AggregatedOpInterfaceImpl.cpp",
"LinalgExtAttrs.cpp.inc",
"LinalgExtDialect.cpp",
"LinalgExtDialect.cpp.inc",
@@ -23,6 +23,7 @@ iree_cc_library(
"LinalgExtOps.h.inc"
"LinalgExtTypes.h.inc"
SRCS
"AggregatedOpInterfaceImpl.cpp"
"LinalgExtAttrs.cpp.inc"
"LinalgExtDialect.cpp"
"LinalgExtDialect.cpp.inc"
12 changes: 7 additions & 5 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1856,6 +1856,8 @@ LogicalResult CustomOp::verify() {
return success();
}

/// Start `LinalgFusionInterface` implementation.

SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
return llvm::map_to_vector(
getIndexingMaps().getValue().take_front(getNumDpsInputs()),
@@ -1868,11 +1870,9 @@ SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
[](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
}

SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
});
}
/// End `LinalgFusionInterface` implementation

/// Start `ReifyRankedShapedTypeOpInterface` implementation

LogicalResult
CustomOp::reifyResultShapes(OpBuilder &builder,
@@ -1885,6 +1885,8 @@ CustomOp::reifyResultShapes(OpBuilder &builder,
return success();
}

/// End `ReifyRankedShapedTypeOpInterface` implementation

//===---------------------------------------------------------------------===//
// IndexOp
//===---------------------------------------------------------------------===//
@@ -1623,6 +1623,8 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
//===---------------------------------------------------------------------===//

def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
DeclareOpInterfaceMethods<AggregatedOpInterface, [
"decomposeOperation"]>,
DeclareOpInterfaceMethods<LinalgFusionInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
@@ -2179,6 +2179,12 @@ LogicalResult OnlineAttentionOp::getResultTilePosition(
/// These methods copied/modified from `TilingInterface` implementation of
/// `getIterationDomain` of `LinalgOp`s.

SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
});
}

/// Method similar to `LinalgOp`s that concatenates shapes of all operands.
static SmallVector<OpFoldResult>
createFlatListOfOperandDims(OpBuilder &builder, Location loc,
@@ -51,6 +51,20 @@ DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure LinalgExt::DecomposeAggregateOp::applyToOne(
transform::TransformRewriter &rewriter,
linalg::AggregatedOpInterface aggregateOp,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
FailureOr<SmallVector<Value>> replacements =
aggregateOp.decomposeOperation(rewriter);
if (failed(replacements)) {
return emitDefiniteFailure() << "failed to decompose operation";
}
rewriter.replaceOp(aggregateOp, replacements.value());
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure LinalgExt::ConvertToOnlineAttention::applyToOne(
transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
transform::ApplyToEachResultList &results,
@@ -86,6 +86,44 @@ def DecomposeTiledAttentionOp : Op<Transform_Dialect, "iree.decompose_tiled_atte
}];
}


def DecomposeAggregateOp : Op<Transform_Dialect, "iree.decompose_aggregate_op",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Target ops implementing the `AggregatedOpInterface` and decompose them.
This transform consumes the target handle and produces a result handle.
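
Example usage (copied from the lit test added in this commit):

```mlir
%0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op
  : (!transform.any_op) -> !transform.any_op
transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
```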
}];

let arguments = (
ins TransformHandleTypeInterface:$target,
OptionalAttr<I64Attr>:$tile_size
);
let results = (outs Variadic<TransformHandleTypeInterface>:$result);

let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::AggregatedOpInterface target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

def ConvertToOnlineAttention : Op<Transform_Dialect, "iree.convert_to_online_attention",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
@@ -30,7 +30,6 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "Transforms",
srcs = [
"AggregatedOpInterfaceImpl.cpp",
"ConvertConv2DToIm2ColOp.cpp",
"ConvertConv2DToWinograd.cpp",
"ConvertToLoops.cpp",
@@ -27,7 +27,6 @@ iree_cc_library(
"Passes.h.inc"
"Transforms.h"
SRCS
"AggregatedOpInterfaceImpl.cpp"
"ConvertConv2DToIm2ColOp.cpp"
"ConvertConv2DToWinograd.cpp"
"ConvertToLoops.cpp"
@@ -20,6 +20,7 @@ iree_lit_test_suite(
"conv2d_to_winograd.mlir",
"convert_to_loops.mlir",
"convert_to_online_attention.mlir",
"decompose_aggregate_op.mlir",
"decompose_attention.mlir",
"decompose_im2col.mlir",
"decompose_online_attention.mlir",
@@ -18,6 +18,7 @@ iree_lit_test_suite(
"conv2d_to_winograd.mlir"
"convert_to_loops.mlir"
"convert_to_online_attention.mlir"
"decompose_aggregate_op.mlir"
"decompose_attention.mlir"
"decompose_im2col.mlir"
"decompose_online_attention.mlir"
@@ -0,0 +1,62 @@
// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s

func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
%rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
%outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
-> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
%0:2 = iree_linalg_ext.custom_op {
indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
affine_map<(d0, d1)[s0, s1] -> ()>,
affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
#iree_linalg_ext.iterator_type<parallel>]}
ins(%lhs1, %rhs1, %rhs2, %scalar
: tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
%s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
%0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%3 = arith.addf %b0, %b2 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
} -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
}
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
transform.yield
}
}
// CHECK-LABEL: func @custom_op_decomposition(
// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
// CHECK: %[[MATMUL1:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] :
// CHECK-SAME: outs(%[[INIT1]] :
// CHECK: %[[MATMUL2:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] :
// CHECK-SAME: outs(%[[INIT2]] :
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] :
// CHECK-SAME: outs(%[[MATMUL2]] :
// CHECK: return %[[MATMUL1]], %[[GENERIC]]
