-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Avoid folding unit extent dims within pre-formed dispatches. (#14394)
This makes the fold unit dims pattern ignore pre-formed dispatches. Addresses #14337
- Loading branch information
1 parent
e450c61
commit ac0b103
Showing
9 changed files
with
111 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
//===- FoldUnitExtentDims.cpp - Pass to fold unit extent dims of tensors -===// | ||
// | ||
// Light weight wrapper to call the patterns to fold unit extent dims with | ||
// IREE control. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" | ||
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" | ||
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir { | ||
namespace iree_compiler { | ||
namespace IREE { | ||
namespace Flow { | ||
|
||
namespace { | ||
struct FoldUnitExtentDimsPass | ||
: public FoldUnitExtentDimsBase<FoldUnitExtentDimsPass> { | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<affine::AffineDialect, arith::ArithDialect, | ||
linalg::LinalgDialect, tensor::TensorDialect>(); | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
} // namespace | ||
|
||
void FoldUnitExtentDimsPass::runOnOperation() { | ||
Operation *funcOp = getOperation(); | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet foldUnitDimsPatterns(context); | ||
linalg::ControlDropUnitDims options; | ||
auto defaultFn = options.controlFn; | ||
options.controlFn = [&](Operation *op) { | ||
// Ignore operations already in dispatches. | ||
if (!isNonNullAndOutsideDispatch(op)) { | ||
return SmallVector<unsigned>{}; | ||
} | ||
return defaultFn(op); | ||
}; | ||
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); | ||
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); | ||
if (failed(applyPatternsAndFoldGreedily(funcOp, | ||
std::move(foldUnitDimsPatterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> | ||
createFoldUnitExtentDimsPass() { | ||
return std::make_unique<FoldUnitExtentDimsPass>(); | ||
} | ||
|
||
} // namespace Flow | ||
} // namespace IREE | ||
} // namespace iree_compiler | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-flow-fold-unit-extent-dims))" %s | FileCheck %s | ||
|
||
func.func @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> tensor<1x1x10xf32> { | ||
%0 = tensor.empty() : tensor<1x1x10xf32> | ||
%1 = flow.dispatch.region[] -> (tensor<1x1x10xf32>) { | ||
%2 = linalg.generic { | ||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], | ||
iterator_types = ["parallel", "parallel", "parallel"]} | ||
ins(%arg0 : tensor<1x1x10xf32>) outs(%0 : tensor<1x1x10xf32>) { | ||
^bb0(%b0 : f32, %b1 : f32): | ||
%3 = arith.addf %b0, %b0 : f32 | ||
linalg.yield %3 : f32 | ||
} -> tensor<1x1x10xf32> | ||
flow.return %2 : tensor<1x1x10xf32> | ||
} | ||
return %1 : tensor<1x1x10xf32> | ||
} | ||
// CHECK: func @no_fold_unit_dims_in_dispatches(%[[ARG0:.+]]: tensor<1x1x10xf32>) | ||
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region | ||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>) | ||
// CHECK: flow.return %[[GENERIC]] | ||
// CHECK: return %[[DISPATCH]] |