From bb5f2f5c32fac1c8b95c00f119f24e8866421096 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Sun, 6 Oct 2024 11:55:59 +0530
Subject: [PATCH] [CPU] Enable tileDispatchUsingForall for mmt4d and
 convolution pipelines. (#18618)

`tileDispatchUsingForall` is enabled for the mmt4d and conv2d pipelines.

---------

Signed-off-by: MaheshRavishankar
Co-authored-by: MaheshRavishankar
---
 .../Common/ReconcileTranslationInfo.cpp       | 14 +++--
 .../test/reconcile_translation_info.mlir      | 13 ----
 .../iree/compiler/Codegen/LLVMCPU/Passes.cpp  | 61 +++++++++++++------
 .../Codegen/LLVMCPU/test/pipeline_tests.mlir  |  3 +-
 4 files changed, 53 insertions(+), 38 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index f09d3a693b24..0fc04f80baff 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -20,6 +20,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+
 namespace mlir::iree_compiler {

 #define GEN_PASS_DEF_RECONCILETRANSLATIONINFOPASS
@@ -263,10 +264,6 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
     return success();
   }

-  if (!llvm::hasSingleElement(body)) {
-    return funcOp.emitOpError("unhandled function with multiple blocks");
-  }
-
   auto forAllOps = body.getOps<scf::ForallOp>();
   SmallVector<scf::ForallOp> workgroupForAllOps = llvm::to_vector(
       llvm::make_filter_range(forAllOps, [&](scf::ForallOp forAllOp) {
@@ -295,6 +292,10 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
                               "scf.forall ops within the function");
   }

+  if (!llvm::hasSingleElement(body)) {
+    return funcOp.emitOpError("unhandled function with multiple blocks");
+  }
+
   scf::ForallOp forallOp = *forAllOps.begin();
   if (failed(resolveWorkgroupCount(rewriter, funcOp, forallOp))) {
     return failure();
@@ -359,9 +360,10 @@ void ReconcileTranslationInfoPass::runOnOperation() {
   auto innerModuleOp = variantOp.getInnerModule();

   auto exportOps = variantOp.getOps<IREE::HAL::ExecutableExportOp>();
+
+  // Reconciliation for multiple export ops is unsupported.
   if (!llvm::hasSingleElement(exportOps)) {
-    variantOp.emitOpError("reconciliation for multiple export ops unsupported");
-    return signalPassFailure();
+    return;
   }
   auto exportOp = *exportOps.begin();
   IRRewriter rewriter(&getContext());
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
index fa56c6d55c94..c7e95db23cd2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
@@ -1,18 +1,5 @@
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info, canonicalize)))" %s --verify-diagnostics --allow-unregistered-dialect | FileCheck %s

-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>
-]>
-hal.executable private @err_multiple_entry_point {
-  // expected-error @+1 {{reconciliation for multiple export ops unsupported}}
-  hal.executable.variant public @reconcile_workgroup_size target(#hal.executable.target<"", "", {}>) {
-    hal.executable.export public @entry_point1 layout(#pipeline_layout)
-    hal.executable.export public @entry_point2 layout(#pipeline_layout)
-  }
-}
-
-// -----
-
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>
 ]>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 91f81569863d..cac496f92805 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -91,6 +91,11 @@ static llvm::cl::opt<bool> clEnableVectorContractCustomKernels(
                    "LLVMCPUMmt4dVectorLowering pass."),
     llvm::cl::init(false));

+static llvm::cl::opt<bool> clTileDispatchUsingForall(
+    "iree-llvmcpu-tile-dispatch-using-forall",
+    llvm::cl::desc("Enable tile and distribute to workgroups using scf.forall"),
+    llvm::cl::init(false));
+
 // By default, IREE does not enable the Armv9-A streaming SVE mode in the
 // presence of scalable vectors (even when using `+sme`), as currently there's
 // no cost model of when it could be beneficial. This flag will effectively make
@@ -104,11 +109,18 @@ static llvm::cl::opt<bool> clForceArmStreaming(
                    "than SVE). Requires the +sme feature flag."),
     llvm::cl::init(false));

-static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
-  funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
-  funcPassManager.addPass(createCSEPass());
-  funcPassManager.addPass(createConvertToDestinationPassingStylePass());
-  funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
+// TODO: Enable `TileDispatchUsingForall` for every pipeline.
+static void addTileAndDistributePasses(OpPassManager &funcPassManager,
+                                       bool enableTileDispatchUsingForall) {
+  if (enableTileDispatchUsingForall || clTileDispatchUsingForall) {
+    funcPassManager.addPass(
+        createTileAndDistributeToWorkgroupsUsingForallOpPass());
+  } else {
+    funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
+    funcPassManager.addPass(createCSEPass());
+    funcPassManager.addPass(createConvertToDestinationPassingStylePass());
+    funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
+  }
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
@@ -333,7 +345,8 @@ void buildLLVMCPUVectorLoweringPipeline(
 void addCPUBufferOpsTileAndVectorizePipeline(
     OpPassManager &funcPassManager, TilingConfig &tilingConfig,
     LLVMCPUPipelineOptions &pipelineOpt) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/true);

   // Skip tiling reduction loops because this is expected to apply on copy ops
   // only.
@@ -370,7 +383,8 @@ void addCPUBufferOpsTileAndVectorizePipeline(
 void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
                                       TilingConfig &tilingConfig,
                                       LLVMCPUPipelineOptions &pipelineOpt) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/false);

   SmallVector<int64_t> allFusableLevels(tilingConfig.getFusableLevels());
   // Apply tile and fuse to all the non-distribution fusable levels. Skip
@@ -449,7 +463,8 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
 void addConvTileAndDecomposeExpertPassPipeline(
     OpPassManager &funcPassManager, TilingConfig &tilingConfig,
     LLVMCPUPipelineOptions &pipelineOpt) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/true);

   // Run LLVMTileAndFuse firstly in case that we have fill + conv + generic
   // ops. At this stage, we do not apply vectorization. The reduction dim won't
@@ -512,7 +527,8 @@ void addConvTileAndDecomposeExpertPassPipeline(
 void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
                                       TilingConfig &tilingConfig,
                                       LLVMCPUPipelineOptions &pipelineOpt) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/true);

   funcPassManager.addPass(createLLVMCPUTileAndFusePass(
       static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
@@ -560,7 +576,8 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
 void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
                               TilingConfig &tilingConfig,
                               LLVMCPUPipelineOptions &pipelineOpt) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/true);

   // The below two passes are nop if pack/unpack is not specified in ukernels
   // attribute. By default, they are disabled.
@@ -603,7 +620,8 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
 void addCPULinalgExtTileAndVectorizePipeline(
     OpPassManager &funcPassManager, TilingConfig &tilingConfig,
     LLVMCPUPipelineOptions &pipelineOpt) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/false);
   funcPassManager.addPass(
       createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
   // TODO: Remove the pass once we have PartialReductionOpInterface implemented
@@ -642,7 +660,8 @@ void addCPULinalgExtTileAndVectorizePipeline(
 }

 void addCPUDefaultPassPipeline(OpPassManager &funcPassManager) {
-  addTileAndDistributePasses(funcPassManager);
+  addTileAndDistributePasses(funcPassManager,
+                             /*enableTileDispatchUsingForall=*/false);
   addCPUBufferizePasses(funcPassManager);
 }

@@ -790,13 +809,21 @@ void buildLLVMCPUCodegenConfigurationPassPipeline(

 void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager,
                                      bool enableAArch64SME) {
-  OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
-  modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
-  FunctionLikeNest(modulePassManager)
-      .addPass(createLLVMCPULowerExecutableTargetPass);
+
+  {
+    OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+    modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
+    FunctionLikeNest(modulePassManager)
+        .addPass(createLLVMCPULowerExecutableTargetPass);
+  }
+
+  variantPassManager.addPass(createReconcileTranslationInfoPass());

   // Run conversion to LLVM at `ModuleOp` granularity.
-  addLowerToLLVMPasses(modulePassManager, enableAArch64SME);
+  {
+    OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+    addLowerToLLVMPasses(modulePassManager, enableAArch64SME);
+  }

   LLVM_DEBUG({
     llvm::dbgs() << "LLVMCPU codegen pass pipeline:\n";
     variantPassManager.printAsTextualPipeline(llvm::dbgs());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
index 4d91d677d45f..fde534292a3d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
@@ -276,8 +276,7 @@ func.func @ukernel_dispatch() attributes {hal.executable.target = #executable_ta
 }
 // CHECK-LABEL: func @ukernel_dispatch()
-// Checks scf.for for distribution loops.
-// CHECK: scf.for
-// CHECK: scf.for
+// Checks scf.forall for workgroup distribution.
+// CHECK: scf.forall
 // Checks scf.for for outer and inner parallel loops.
 // CHECK: scf.for
 // CHECK: scf.for
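
For context, the sketch below illustrates the IR shape that `createTileAndDistributeToWorkgroupsUsingForallOpPass` produces for the pipelines opted in above (buffer-ops, conv, mmt4d, data-tiling), and that `ReconcileTranslationInfo` later resolves into the workgroup grid. It is illustrative only: the shapes, tile sizes, SSA names, the exact `#iree_codegen.workgroup_mapping` spelling, and the use of a plain `linalg.matmul` (rather than `linalg.mmt4d` or a convolution) are assumptions made for readability, not output of this patch.

// Hypothetical dispatch body after tile-and-distribute via scf.forall. The
// single mapped scf.forall stands in for what the legacy path materialized
// as nested scf.for loops over hal.interface.workgroup.id/count.
func.func @matmul_dispatch(%lhs: tensor<128x256xf32>, %rhs: tensor<256x512xf32>,
                           %init: tensor<128x512xf32>) -> tensor<128x512xf32> {
  %0 = scf.forall (%i, %j) = (0, 0) to (128, 512) step (64, 64)
      shared_outs(%out = %init) -> (tensor<128x512xf32>) {
    %lhs_slice = tensor.extract_slice %lhs[%i, 0] [64, 256] [1, 1]
        : tensor<128x256xf32> to tensor<64x256xf32>
    %rhs_slice = tensor.extract_slice %rhs[0, %j] [256, 64] [1, 1]
        : tensor<256x512xf32> to tensor<256x64xf32>
    %out_slice = tensor.extract_slice %out[%i, %j] [64, 64] [1, 1]
        : tensor<128x512xf32> to tensor<64x64xf32>
    // Per-workgroup tile of the computation (mmt4d/conv in the real pipelines).
    %tile = linalg.matmul
        ins(%lhs_slice, %rhs_slice : tensor<64x256xf32>, tensor<256x64xf32>)
        outs(%out_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %tile into %out[%i, %j] [64, 64] [1, 1]
          : tensor<64x64xf32> into tensor<128x512xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  return %0 : tensor<128x512xf32>
}

This is the shape the pipeline_tests.mlir update checks for: workgroup distribution becomes a single scf.forall, while the inner parallel and reduction tiling levels still materialize scf.for loops.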
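
Pipelines that still pass /*enableTileDispatchUsingForall=*/false (the multi-tiling expert, LinalgExt, and default pipelines) can be opted in globally through the new option until the TODO above is resolved, since it is a regular llvm::cl flag registered with the IREE tools. A hypothetical lit RUN line showing the opt-in (the pass-pipeline nesting is abbreviated and assumed, not copied from the repository):

// RUN: iree-opt --iree-llvmcpu-tile-dispatch-using-forall \
// RUN:   --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-executable-target))" \
// RUN:   %s | FileCheck %s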