Skip to content

Commit

Permalink
Adding CUFCommon.{h,cpp} for CUF utilities (llvm#113740)
Browse files Browse the repository at this point in the history
  • Loading branch information
Renaud-K authored Oct 25, 2024
1 parent 242ccd2 commit 3acf856
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
25 changes: 25 additions & 0 deletions flang/include/flang/Optimizer/Transforms/CUFCommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===-- CUFCommon.h - Shared functions between passes -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
#define FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinOps.h"

/// Symbol name of the GPU module used for CUDA Fortran. This is the name that
/// cuf::getOrCreateGPUModule looks up in the symbol table and gives to the
/// `mlir::gpu::GPUModuleOp` it creates.
///
/// Declared `inline` so that every translation unit including this header
/// refers to a single definition instead of getting its own internal-linkage
/// copy.
inline constexpr llvm::StringRef cudaDeviceModuleName = "cuda_device_mod";

namespace cuf {

/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
///
/// The module is searched for in \p symTab under the name
/// `cudaDeviceModuleName`. If a `mlir::gpu::GPUModuleOp` with that name is
/// found, it is returned unchanged. Otherwise \p mod is tagged with the GPU
/// dialect's container-module attribute, a new GPU module carrying an NVVM
/// target attribute is created, and it is inserted through \p symTab at the
/// end of the body of \p mod.
///
/// \param mod    Module operation that holds (or will hold) the GPU module.
/// \param symTab Symbol table used for the lookup and for the insertion of a
///               newly created GPU module. Presumably built from \p mod;
///               confirm at call sites.
/// \returns The existing or newly created GPU module.
mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
mlir::SymbolTable &symTab);

} // namespace cuf

#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
CompilerGeneratedNames.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
CUFCommon.cpp
CUFAddConstructor.cpp
CUFDeviceGlobal.cpp
CUFOpConversion.cpp
Expand Down
7 changes: 3 additions & 4 deletions flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand All @@ -24,8 +25,6 @@ namespace fir {

namespace {

static constexpr llvm::StringRef cudaModName{"cuda_device_mod"};

static constexpr llvm::StringRef cudaFortranCtorName{
"__cudaFortranConstructor"};

Expand Down Expand Up @@ -60,15 +59,15 @@ struct CUFAddConstructor
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);

// Register kernels
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
if (gpuMod) {
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
if (func.isKernel()) {
auto kernelName = mlir::SymbolRefAttr::get(
builder.getStringAttr(cudaModName),
builder.getStringAttr(cudaDeviceModuleName),
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
}
Expand Down
31 changes: 31 additions & 0 deletions flang/lib/Optimizer/Transforms/CUFCommon.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===-- CUFCommon.cpp - Shared functions between passes ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

/// Return the CUDA Fortran GPU module of \p mod, creating it when \p symTab
/// does not already contain one.
///
/// When a new module has to be created, \p mod is first tagged with the GPU
/// dialect's container-module attribute. The new module is named
/// `cudaDeviceModuleName`, receives a single NVVM target attribute, and is
/// inserted through \p symTab at the end of the body of \p mod.
mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
                                                 mlir::SymbolTable &symTab) {
  // Nothing to do when the device module is already present.
  auto existing = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
  if (existing)
    return existing;

  mlir::MLIRContext *context = mod.getContext();

  // Tag the enclosing module with the GPU container-module unit attribute.
  mod->setAttr(mlir::gpu::GPUDialect::getContainerModuleAttrName(),
               mlir::UnitAttr::get(context));

  // Build the GPU module detached; it is placed in the IR by the symbol table
  // insertion below.
  mlir::OpBuilder builder(context);
  mlir::gpu::GPUModuleOp deviceModule = builder.create<mlir::gpu::GPUModuleOp>(
      mod.getLoc(), cudaDeviceModuleName);

  // The only target attached to the new module is the NVVM one.
  mlir::Attribute nvvmTarget = mlir::NVVM::NVVMTargetAttr::get(context);
  deviceModule.setTargetsAttr(builder.getArrayAttr({nvvmTarget}));

  // Append the new module at the end of the parent module's body and register
  // its symbol in the table.
  mlir::Block &parentBody = mod.getBodyRegion().front();
  symTab.insert(deviceModule, parentBody.end());
  return deviceModule;
}

0 comments on commit 3acf856

Please sign in to comment.