diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp index a3934d1988e0..59ff74b61a36 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp @@ -200,6 +200,22 @@ static LogicalResult contractOpFilter(Operation *op) { linalgOp.getNumParallelLoops() <= 3); } +// Erases `memref.dealloc` ops whose memref is in the shared memory address +// space. Other deallocs do not match; by default a `dealloc` is converted +// into a call to `free` on the underlying data buffer. +struct DropSharedMemoryDeallocOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DeallocOp op, + PatternRewriter &rewriter) const override { + if (!hasSharedMemoryAddressSpace( + llvm::cast(op.getMemref().getType()))) + return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace void populateVectorTransferToGPUMMAPreparationPatterns( @@ -232,5 +248,9 @@ void populateContractPromotionPatterns(RewritePatternSet &patterns, .addFilter(contractOpFilter)); } +void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h index 1875c5e5cfa4..a85a348d7f76 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h @@ -27,6 +27,8 @@ void populateCombineVectorTransferReadBroadcastPatterns( void populateContractPromotionPatterns(RewritePatternSet &patterns, ArrayRef operandsToPromote); +void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns); + } // namespace iree_compiler } // namespace mlir diff --git 
a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp index 575608ddc444..10c4310feb6a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp @@ -4,10 +4,10 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h" #include "iree/compiler/Codegen/LLVMGPU/PassDetail.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" -#include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -38,22 +38,6 @@ namespace iree_compiler { namespace { -// A `dealloc` is converted into a call to `free` on the underlying data buffer. -// The memref descriptor being an SSA value, there is no need to clean it up -// in any way. -struct DropSharedMemoryDeallocOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(memref::DeallocOp op, - PatternRewriter &rewriter) const override { - if (!hasSharedMemoryAddressSpace( - llvm::cast(op.getMemref().getType()))) - return failure(); - rewriter.eraseOp(op); - return success(); - } -}; - /// A pass that replaces all occurrences of GPU device operations with their /// corresponding NVVM equivalent. /// @@ -99,7 +83,7 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase { // Run Vector -> Vector transformations ahead of conversion to LLVM. 
{ RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); + populateDropSharedMemoryDeallocOpPatterns(patterns); populateScalarizeMathOps(patterns); populateConvertSharedMemoryAllocOps(patterns); vector::populateVectorToVectorCanonicalizationPatterns(patterns); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp index 617523616982..7b996fdb44af 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h" #include "iree/compiler/Codegen/LLVMGPU/PassDetail.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" @@ -69,6 +70,7 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase { // Run Vector -> Vector transformations ahead of conversion to LLVM. 
{ RewritePatternSet patterns(&getContext()); + populateDropSharedMemoryDeallocOpPatterns(patterns); populateScalarizeMathOps(patterns); populateConvertSharedMemoryAllocOps(patterns); vector::populateVectorToVectorCanonicalizationPatterns(patterns); @@ -112,14 +114,10 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase { LLVMConversionTarget target(getContext()); populateFuncToLLVMFuncOpConversionPattern(converter, llvmPatterns); configureGpuToROCDLConversionLegality(target); - target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { - if (isEntryPoint(funcOp)) - return false; - return true; - }); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } + ConvertToDynamicSharedMemory(m); } }; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 6afc6091aece..5f18556dbe59 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -556,8 +556,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) { // debug info well. pm.addPass(createStripDebugInfoPass()); // Cast address spaces of all function arguments to generic - if (!useROCM) - pm.addPass(createLLVMGPUCastAddressSpaceFunction()); + pm.addPass(createLLVMGPUCastAddressSpaceFunction()); if (useROCM) { // convert to ROCDL. pm.addPass(createConvertToROCDLPass());