Skip to content

Commit

Permalink
[CPU] Enable tileDispatchUsingForall for mmt4d and convolution pipeli…
Browse files Browse the repository at this point in the history
…nes. (#18618)

`tileDispatchUsingForall` is enabled for the mmt4d and conv2d pipelines.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Co-authored-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
pashu123 and MaheshRavishankar authored Oct 6, 2024
1 parent 126e334 commit bb5f2f5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_RECONCILETRANSLATIONINFOPASS
Expand Down Expand Up @@ -263,10 +264,6 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
return success();
}

if (!llvm::hasSingleElement(body)) {
return funcOp.emitOpError("unhandled function with multiple blocks");
}

auto forAllOps = body.getOps<scf::ForallOp>();
SmallVector<scf::ForallOp> workgroupForAllOps = llvm::to_vector(
llvm::make_filter_range(forAllOps, [&](scf::ForallOp forAllOp) {
Expand Down Expand Up @@ -295,6 +292,10 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
"scf.forall ops withing the function");
}

if (!llvm::hasSingleElement(body)) {
return funcOp.emitOpError("unhandled function with multiple blocks");
}

scf::ForallOp forallOp = *forAllOps.begin();
if (failed(resolveWorkgroupCount(rewriter, funcOp, forallOp))) {
return failure();
Expand Down Expand Up @@ -359,9 +360,10 @@ void ReconcileTranslationInfoPass::runOnOperation() {
auto innerModuleOp = variantOp.getInnerModule();

auto exportOps = variantOp.getOps<IREE::HAL::ExecutableExportOp>();

// reconciliation for multiple export ops is unsupported.
if (!llvm::hasSingleElement(exportOps)) {
variantOp.emitOpError("reconciliation for multiple export ops unsupported");
return signalPassFailure();
return;
}
auto exportOp = *exportOps.begin();
IRRewriter rewriter(&getContext());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info, canonicalize)))" %s --verify-diagnostics --allow-unregistered-dialect | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>
]>
hal.executable private @err_multiple_entry_point {
// expected-error @+1 {{reconciliation for multiple export ops unsupported}}
hal.executable.variant public @reconcile_workgroup_size target(#hal.executable.target<"", "", {}>) {
hal.executable.export public @entry_point1 layout(#pipeline_layout)
hal.executable.export public @entry_point2 layout(#pipeline_layout)
}
}

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>
]>
Expand Down
61 changes: 44 additions & 17 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ static llvm::cl::opt<bool> clEnableVectorContractCustomKernels(
"LLVMCPUMmt4dVectorLowering pass."),
llvm::cl::init(false));

static llvm::cl::opt<bool> clTileDispatchUsingForall(
"iree-llvmcpu-tile-dispatch-using-forall",
llvm::cl::desc("Enable tile and distribute to workgroups using scf.forall"),
llvm::cl::init(false));

// By default, IREE does not enable the Armv9-A streaming SVE mode in the
// presence of scalable vectors (even when using `+sme`), as currently there's
// no cost model of when it could be beneficial. This flag will effectively make
Expand All @@ -104,11 +109,18 @@ static llvm::cl::opt<bool> clForceArmStreaming(
"than SVE). Requires the +sme feature flag."),
llvm::cl::init(false));

static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createConvertToDestinationPassingStylePass());
funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
// TODO: Enable `TileDispatchUsingForall` for every pipeline.
static void addTileAndDistributePasses(OpPassManager &funcPassManager,
bool enableTileDispatchUsingForall) {
if (enableTileDispatchUsingForall || clTileDispatchUsingForall) {
funcPassManager.addPass(
createTileAndDistributeToWorkgroupsUsingForallOpPass());
} else {
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createConvertToDestinationPassingStylePass());
funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
}
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
Expand Down Expand Up @@ -333,7 +345,8 @@ void buildLLVMCPUVectorLoweringPipeline(
void addCPUBufferOpsTileAndVectorizePipeline(
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
LLVMCPUPipelineOptions &pipelineOpt) {
addTileAndDistributePasses(funcPassManager);
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/true);

// Skip tiling reduction loops because this is expected to apply on copy ops
// only.
Expand Down Expand Up @@ -370,7 +383,8 @@ void addCPUBufferOpsTileAndVectorizePipeline(
void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
TilingConfig &tilingConfig,
LLVMCPUPipelineOptions &pipelineOpt) {
addTileAndDistributePasses(funcPassManager);
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/false);

SmallVector<int64_t> allFusableLevels(tilingConfig.getFusableLevels());
// Apply tile and fuse to all the non-distribution fusable levels. Skip
Expand Down Expand Up @@ -449,7 +463,8 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
void addConvTileAndDecomposeExpertPassPipeline(
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
LLVMCPUPipelineOptions &pipelineOpt) {
addTileAndDistributePasses(funcPassManager);
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/true);

// Run LLVMTileAndFuse firstly in case that we have fill + conv + generic
// ops. At this stage, we do not apply vectorization. The reduction dim won't
Expand Down Expand Up @@ -512,7 +527,8 @@ void addConvTileAndDecomposeExpertPassPipeline(
void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
TilingConfig &tilingConfig,
LLVMCPUPipelineOptions &pipelineOpt) {
addTileAndDistributePasses(funcPassManager);
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/true);

funcPassManager.addPass(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
Expand Down Expand Up @@ -560,7 +576,8 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
TilingConfig &tilingConfig,
LLVMCPUPipelineOptions &pipelineOpt) {
addTileAndDistributePasses(funcPassManager);
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/true);

// The below two passes are nop if pack/unpack is not specified in ukernels
// attribute. By default, they are disabled.
Expand Down Expand Up @@ -603,7 +620,8 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
void addCPULinalgExtTileAndVectorizePipeline(
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
LLVMCPUPipelineOptions &pipelineOpt) {
addTileAndDistributePasses(funcPassManager);
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/false);
funcPassManager.addPass(
createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
// TODO: Remove the pass once we have PartialReductionOpInterface implemented
Expand Down Expand Up @@ -642,7 +660,8 @@ void addCPULinalgExtTileAndVectorizePipeline(
}

// Fallback CPU lowering pipeline used when no specialized tiling/vectorization
// strategy applies. Distributes the dispatch to workgroups via the scf.for
// based path (enableTileDispatchUsingForall=false; the scf.forall path can
// still be forced globally through the clTileDispatchUsingForall flag), then
// runs the CPU bufferization passes. No tiling levels beyond workgroup
// distribution and no vectorization are applied here.
void addCPUDefaultPassPipeline(OpPassManager &funcPassManager) {
addTileAndDistributePasses(funcPassManager,
/*enableTileDispatchUsingForall=*/false);
addCPUBufferizePasses(funcPassManager);
}

Expand Down Expand Up @@ -790,13 +809,21 @@ void buildLLVMCPUCodegenConfigurationPassPipeline(

void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager,
bool enableAArch64SME) {
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
FunctionLikeNest(modulePassManager)
.addPass(createLLVMCPULowerExecutableTargetPass);

{
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
FunctionLikeNest(modulePassManager)
.addPass(createLLVMCPULowerExecutableTargetPass);
}

variantPassManager.addPass(createReconcileTranslationInfoPass());

// Run conversion to LLVM at `ModuleOp` granularity.
addLowerToLLVMPasses(modulePassManager, enableAArch64SME);
{
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
addLowerToLLVMPasses(modulePassManager, enableAArch64SME);
}
LLVM_DEBUG({
llvm::dbgs() << "LLVMCPU codegen pass pipeline:\n";
variantPassManager.printAsTextualPipeline(llvm::dbgs());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ func.func @ukernel_dispatch() attributes {hal.executable.target = #executable_ta
}
// CHECK-LABEL: func @ukernel_dispatch()
// Checks scf.for for distribution loops.
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.forall
// Checks scf.for for outer and inner parallel loops.
// CHECK: scf.for
// CHECK: scf.for
Expand Down

0 comments on commit bb5f2f5

Please sign in to comment.