Skip to content

Commit

Permalink
[ROCm] Add pieces from the CUDA codgen lowering path (#14769)
Browse files Browse the repository at this point in the history
Enable iree-llvmgpu-cast-address-space-function.
Add DropSharedMemoryDeallocOp rewrite pattern in ConvertToROCDLPass.
Remove dynamic legalization for FuncOp.
Add ConvertToDynamicSharedMemory ConvertToROCDLPass.

I am not at all confident that these need to be the same on the ROCm
path as are on the CUDA path. If there is no one to express confidence I
will dig deeper to make sure it makes sense.

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
  • Loading branch information
sogartar and sogartar authored Sep 8, 2023
1 parent 26528b9 commit 25c2ab3
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 25 deletions.
20 changes: 20 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,22 @@ static LogicalResult contractOpFilter(Operation *op) {
linalgOp.getNumParallelLoops() <= 3);
}

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DropSharedMemoryDeallocOp : public OpRewritePattern<memref::DeallocOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(memref::DeallocOp op,
PatternRewriter &rewriter) const override {
if (!hasSharedMemoryAddressSpace(
llvm::cast<MemRefType>(op.getMemref().getType())))
return failure();
rewriter.eraseOp(op);
return success();
}
};

} // namespace

void populateVectorTransferToGPUMMAPreparationPatterns(
Expand Down Expand Up @@ -232,5 +248,9 @@ void populateContractPromotionPatterns(RewritePatternSet &patterns,
.addFilter(contractOpFilter));
}

void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns) {
patterns.add<DropSharedMemoryDeallocOp>(patterns.getContext());
}

} // namespace iree_compiler
} // namespace mlir
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ void populateCombineVectorTransferReadBroadcastPatterns(
void populateContractPromotionPatterns(RewritePatternSet &patterns,
ArrayRef<int64_t> operandsToPromote);

void populateDropSharedMemoryDeallocOpPatterns(RewritePatternSet &patterns);

} // namespace iree_compiler
} // namespace mlir

Expand Down
20 changes: 2 additions & 18 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h"
#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
Expand Down Expand Up @@ -38,22 +38,6 @@ namespace iree_compiler {

namespace {

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DropSharedMemoryDeallocOp : public OpRewritePattern<memref::DeallocOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(memref::DeallocOp op,
PatternRewriter &rewriter) const override {
if (!hasSharedMemoryAddressSpace(
llvm::cast<MemRefType>(op.getMemref().getType())))
return failure();
rewriter.eraseOp(op);
return success();
}
};

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
Expand Down Expand Up @@ -99,7 +83,7 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
// Run Vector -> Vector transformations ahead of conversion to LLVM.
{
RewritePatternSet patterns(&getContext());
patterns.insert<DropSharedMemoryDeallocOp>(&getContext());
populateDropSharedMemoryDeallocOpPatterns(patterns);
populateScalarizeMathOps(patterns);
populateConvertSharedMemoryAllocOps(patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
Expand Down
8 changes: 3 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h"
#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
// Run Vector -> Vector transformations ahead of conversion to LLVM.
{
RewritePatternSet patterns(&getContext());
populateDropSharedMemoryDeallocOpPatterns(patterns);
populateScalarizeMathOps(patterns);
populateConvertSharedMemoryAllocOps(patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
Expand Down Expand Up @@ -112,14 +114,10 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
LLVMConversionTarget target(getContext());
populateFuncToLLVMFuncOpConversionPattern(converter, llvmPatterns);
configureGpuToROCDLConversionLegality(target);
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
if (isEntryPoint(funcOp))
return false;
return true;
});
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
}
ConvertToDynamicSharedMemory(m);
}
};

Expand Down
3 changes: 1 addition & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) {
// debug info well.
pm.addPass(createStripDebugInfoPass());
// Cast address spaces of all function arguments to generic
if (!useROCM)
pm.addPass(createLLVMGPUCastAddressSpaceFunction());
pm.addPass(createLLVMGPUCastAddressSpaceFunction());
if (useROCM) {
// convert to ROCDL.
pm.addPass(createConvertToROCDLPass());
Expand Down

0 comments on commit 25c2ab3

Please sign in to comment.