Skip to content

Commit

Permalink
Avoid folding unit extent dims within pre-formed dispatches. (#14394)
Browse files Browse the repository at this point in the history
This makes the fold unit dims pattern ignore pre-formed dispatches.

Addresses #14337
  • Loading branch information
MaheshRavishankar authored Aug 3, 2023
1 parent e450c61 commit ac0b103
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ iree_compiler_cc_library(
"EraseUnusedLinalgOperands.cpp",
"ExpandTensorShapes.cpp",
"ExportBenchmarkFuncs.cpp",
"FoldUnitExtentDims.cpp",
"FormDispatchRegions.cpp",
"FormDispatchWorkgroups.cpp",
"FormScalarDispatches.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ iree_cc_library(
"EraseUnusedLinalgOperands.cpp"
"ExpandTensorShapes.cpp"
"ExportBenchmarkFuncs.cpp"
"FoldUnitExtentDims.cpp"
"FormDispatchRegions.cpp"
"FormDispatchWorkgroups.cpp"
"FormScalarDispatches.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- FoldUnitExtentDims.cpp - Pass to fold unit extent dims of tensors -===//
//
// Light weight wrapper to call the patterns to fold unit extent dims with
// IREE control.
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {

namespace {
struct FoldUnitExtentDimsPass
: public FoldUnitExtentDimsBase<FoldUnitExtentDimsPass> {

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, arith::ArithDialect,
linalg::LinalgDialect, tensor::TensorDialect>();
}

void runOnOperation() override;
};
} // namespace

void FoldUnitExtentDimsPass::runOnOperation() {
Operation *funcOp = getOperation();
MLIRContext *context = &getContext();
RewritePatternSet foldUnitDimsPatterns(context);
linalg::ControlDropUnitDims options;
auto defaultFn = options.controlFn;
options.controlFn = [&](Operation *op) {
// Ignore operations already in dispatches.
if (!isNonNullAndOutsideDispatch(op)) {
return SmallVector<unsigned>{};
}
return defaultFn(op);
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(foldUnitDimsPatterns)))) {
return signalPassFailure();
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFoldUnitExtentDimsPass() {
return std::make_unique<FoldUnitExtentDimsPass>();
}

} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
// - Remove unit-extent dimensions.
.addPass(mlir::createConvertElementwiseToLinalgPass)
.addPass(createGeneralizeLinalgNamedOpsPass)
.addPass(mlir::createLinalgFoldUnitExtentDimsPass)
.addPass(createFoldUnitExtentDimsPass)
.addPass(createRaiseSpecialOps)
.addPass(createInterchangeGenericOpsPass)
.addPass(createCollapseDimsPass)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ createTensorPadToTensorInsertSlicePass(bool skipSingleLinalgOpUses = false);
// Create a pass to detach elementwise ops from named Linalg ops.
std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass();

// Create a pass that imports upstream patterns to fold unit extent dims
// but with IREE control.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFoldUnitExtentDimsPass();

// Creates a pass to fuse Linalg operations on tensors.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFusionOfTensorOpsPass(bool fuseMultiUse = false,
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def DetachElementwiseFromNamedOps :
let constructor = "mlir::iree_compiler::IREE::Flow::createDetachElementwiseFromNamedOpsPass()";
}

def FoldUnitExtentDims :
InterfacePass<"iree-flow-fold-unit-extent-dims", "mlir::FunctionOpInterface"> {
let summary = "Fold unit extent dimension of operations";
let constructor = "mlir::iree_compiler::IREE::Flow::createFoldUnitExtentDimsPass()";
}

def FormDispatchRegions :
InterfacePass<"iree-flow-form-dispatch-regions", "mlir::FunctionOpInterface"> {
let summary = "Form Dispatch Region Ops from Linalg operations on tensors to form dispatch.regions";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_lit_test_suite(
"dispatch_linalg_transform_dialect.mlir",
"expand_tensor_shapes.mlir",
"export_benchmark_funcs.mlir",
"fold_unit_dims.mlir",
"form_dispatch_regions.mlir",
"form_dispatch_workgroups.mlir",
"form_scalar_dispatches.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ iree_lit_test_suite(
"dispatch_linalg_transform_dialect.mlir"
"expand_tensor_shapes.mlir"
"export_benchmark_funcs.mlir"
"fold_unit_dims.mlir"
"form_dispatch_regions.mlir"
"form_dispatch_workgroups.mlir"
"form_scalar_dispatches.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-flow-fold-unit-extent-dims))" %s | FileCheck %s

func.func @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> tensor<1x1x10xf32> {
%0 = tensor.empty() : tensor<1x1x10xf32>
%1 = flow.dispatch.region[] -> (tensor<1x1x10xf32>) {
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<1x1x10xf32>) outs(%0 : tensor<1x1x10xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%3 = arith.addf %b0, %b0 : f32
linalg.yield %3 : f32
} -> tensor<1x1x10xf32>
flow.return %2 : tensor<1x1x10xf32>
}
return %1 : tensor<1x1x10xf32>
}
// CHECK: func @no_fold_unit_dims_in_dispatches(%[[ARG0:.+]]: tensor<1x1x10xf32>)
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>)
// CHECK: flow.return %[[GENERIC]]
// CHECK: return %[[DISPATCH]]

0 comments on commit ac0b103

Please sign in to comment.