Merge branch 'main' into avarma_truncf_pr_1_insert_cores
Abhishek-Varma authored Oct 10, 2024
2 parents db518b0 + f9f78a9 commit 5d3e2a1
Showing 7 changed files with 262 additions and 237 deletions.
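For orientation: the main change in AMDAIECoreToStandard.cpp below is to drop the OpConversionPattern / applyPartialConversion machinery and instead rewrite ops directly while walking the module with a single IRRewriter. A minimal sketch of that shape, with hypothetical names rather than the pass's own code (the real helpers below are lockToStd, bufferToStd, and coreToStd):

// Sketch only: walk every op of type OpTy and rewrite it in place through a
// shared IRRewriter, instead of routing the rewrite through a conversion
// pattern. `rewriteOp` stands in for lockToStd / bufferToStd / coreToStd.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

template <typename OpTy, typename RewriteFn>
static void lowerAll(mlir::ModuleOp module, RewriteFn rewriteOp) {
  mlir::IRRewriter rewriter(module.getContext());
  module.walk([&](OpTy op) {
    rewriter.setInsertionPoint(op);
    rewriteOp(op, rewriter);
  });
}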
328 changes: 134 additions & 194 deletions compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp
@@ -6,15 +6,10 @@

#include "AIEDialect.h"
#include "Passes.h"
#include "aievec/AIEVecDialect.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
@@ -24,148 +19,109 @@
#define DEBUG_TYPE "amdaie-standard-lowering"

using namespace mlir;
using namespace mlir::vector;
using namespace xilinx;
using namespace xilinx::AIE;

struct AMDAIEUseLockToStdLowering : OpConversionPattern<UseLockOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
UseLockOp useLock, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isa<DeviceOp>(useLock->getParentOp())) {
// Generate the intrinsic name
std::string funcName = "llvm.aie2.";
if (useLock.getAction() == LockAction::Acquire ||
useLock.getAction() == LockAction::AcquireGreaterEqual)
funcName += "acquire";
else if (useLock.getAction() == LockAction::Release)
funcName += "release";
// TODO(max): this can be simplified with
// SymbolTable::lookupNearestSymbolFrom if DeviceOp ceases to be a
// SymbolTable
func::FuncOp useLockFunc =
useLock->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(
funcName);

SmallVector<Value, 2> args;
int lockValue = useLock.getValue().value_or(1);

// AIE2 acquire greater equal is encoded as a negative value.
if (useLock.getAction() == LockAction::AcquireGreaterEqual)
lockValue = -lockValue;
args.push_back(rewriter.create<arith::IndexCastOp>(
useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32),
useLock.getLock()));
args.push_back(rewriter.create<arith::ConstantOp>(
useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32),
rewriter.getI32IntegerAttr(lockValue)));
rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), useLockFunc,
args);
}
static void lockToStd(UseLockOp useLock, IRRewriter &rewriter) {
if (!isa<DeviceOp>(useLock->getParentOp())) {
std::string funcName = [&]() {
switch (useLock.getAction()) {
case LockAction::Acquire:
case LockAction::AcquireGreaterEqual:
return "llvm.aie2.acquire";
case LockAction::Release:
return "llvm.aie2.release";
default:
assert(false && "Unknown lock action");
}
}();

// TODO(max): this can be simplified with
// SymbolTable::lookupNearestSymbolFrom if DeviceOp ceases to be a
// SymbolTable
ModuleOp modOp = useLock->getParentOfType<ModuleOp>();
func::FuncOp func = modOp.lookupSymbol<func::FuncOp>(funcName);

int lockValue = useLock.getValue().value_or(1);

// AIE2 acquire greater equal is encoded as a negative value.
if (useLock.getAction() == LockAction::AcquireGreaterEqual)
lockValue = -lockValue;

rewriter.eraseOp(useLock);
return success();
rewriter.setInsertionPoint(useLock);
IntegerAttr lockAttr = rewriter.getI32IntegerAttr(lockValue);
IntegerType type = IntegerType::get(rewriter.getContext(), 32);
Location loc = useLock.getLoc();

SmallVector<Value, 2> args{
rewriter.create<arith::IndexCastOp>(loc, type, useLock.getLock()),
rewriter.create<arith::ConstantOp>(loc, type, lockAttr)};

rewriter.create<func::CallOp>(loc, func, args);
}
};

struct AMDAIEBufferToStandard : OpConversionPattern<BufferOp> {
using OpConversionPattern::OpConversionPattern;
ModuleOp &module;
// TODO(max): these should be optionals instead of checking against -1
// but the pass itself needs to be updated.
int tileCol = 0;
int tileRow = 0;
AMDAIEBufferToStandard(MLIRContext *context, ModuleOp &m, int tileCol = -1,
int tileRow = -1)
: OpConversionPattern(context),
module(m),
tileCol(tileCol),
tileRow(tileRow) {}
LogicalResult matchAndRewrite(
BufferOp buffer, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.setInsertionPointToStart(module.getBody());
auto t = llvm::cast<MemRefType>(buffer.getType());
StringRef symName = name(buffer).getValue();
// Don't emit initialization for cores that don't "own" the buffer (to
// prevent duplication in the data section of the elf/object file)
rewriter.create<memref::GlobalOp>(
rewriter.getUnknownLoc(), symName, rewriter.getStringAttr("public"),
buffer.getType(), nullptr, /*constant*/ false,
/*alignment*/ nullptr);

for (OpOperand &use : make_early_inc_range(buffer.getResult().getUses())) {
Operation *user = use.getOwner();
rewriter.setInsertionPoint(user);
auto allocated = rewriter.create<memref::GetGlobalOp>(
rewriter.getUnknownLoc(), t, symName);
// Assume that buffers are aligned so they can be vectorized.
rewriter.create<memref::AssumeAlignmentOp>(rewriter.getUnknownLoc(),
allocated, 32);
use.set(allocated.getResult());
}
rewriter.eraseOp(useLock);
}
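To make the sign convention in lockToStd concrete: both acquire variants call llvm.aie2.acquire, and AcquireGreaterEqual is distinguished solely by negating the lock value. A small illustrative helper, not part of the pass, relying on the using-declarations at the top of the file:

// Illustrative only: the lock-value encoding lockToStd builds for the AIE2
// lock intrinsics. AcquireGreaterEqual is signalled by a negative value.
static int encodeLockValue(LockAction action, int value) {
  return action == LockAction::AcquireGreaterEqual ? -value : value;
}
// encodeLockValue(LockAction::AcquireGreaterEqual, 2) == -2
// encodeLockValue(LockAction::Acquire, 2)             ==  2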

rewriter.eraseOp(buffer);
return success();
static void bufferToStd(ModuleOp module, BufferOp buffer,
IRRewriter &rewriter) {
Location loc = buffer.getLoc();
rewriter.setInsertionPointToStart(module.getBody());
StringRef symName = name(buffer).getValue();
MemRefType type = llvm::cast<MemRefType>(buffer.getType());
// Don't emit initialization for cores that don't "own" the buffer (to
// prevent duplication in the data section of the elf/object file)
rewriter.create<memref::GlobalOp>(
loc, symName, rewriter.getStringAttr("public"), type, nullptr,
/*constant*/ false,
/*alignment*/ nullptr);

for (OpOperand &use : make_early_inc_range(buffer.getResult().getUses())) {
Operation *user = use.getOwner();
rewriter.setInsertionPoint(user);

auto allocated = rewriter.create<memref::GetGlobalOp>(loc, type, symName);
// Assume that buffers are aligned so they can be vectorized.
rewriter.create<memref::AssumeAlignmentOp>(loc, allocated, 32);
use.set(allocated.getResult());
}
};

struct AMDAIECoreToStandardFunc : OpConversionPattern<CoreOp> {
using OpConversionPattern::OpConversionPattern;
IRMapping &mapper;
// TODO(max): these should be optionals instead of checking against -1
// but the pass itself needs to be updated.
int tileCol = 0;
int tileRow = 0;

AMDAIECoreToStandardFunc(MLIRContext *context, IRMapping &mapper,
int tileCol = 1, int tileRow = 1)
: OpConversionPattern(context),
mapper(mapper),
tileCol(tileCol),
tileRow(tileRow) {}

LogicalResult matchAndRewrite(
CoreOp coreOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
TileOp t = getTileOp(*coreOp);
int col = t.getCol();
int row = t.getRow();

// Only pull code for the indicated function
if ((tileRow != row && tileRow != -1) ||
(tileCol != col && tileCol != -1)) {
rewriter.eraseOp(coreOp);
return success();
}
rewriter.eraseOp(buffer);
}

// The parent should be an AIE.device op.
rewriter.setInsertionPointAfter(coreOp->getParentOp());

std::string coreName("core_" + std::to_string(col) + "_" +
std::to_string(row));
auto coreFunc = rewriter.create<func::FuncOp>(
rewriter.getUnknownLoc(), coreName,
FunctionType::get(rewriter.getContext(), {}, {}));

rewriter.cloneRegionBefore(coreOp.getBody(), coreFunc.getBody(),
coreFunc.getBody().begin(), mapper);

// Rewrite the AIE.end() op
coreFunc.getBody().walk([&](Operation *childOp) {
rewriter.setInsertionPointAfter(childOp);
if (isa<EndOp>(childOp)) {
rewriter.create<func::ReturnOp>(rewriter.getUnknownLoc(),
ValueRange({}));
rewriter.eraseOp(childOp);
}
});
static void coreToStd(CoreOp coreOp, IRRewriter &rewriter, int tileCol,
int tileRow) {
TileOp t = getTileOp(*coreOp);
int col = t.getCol();
int row = t.getRow();

// Only pull code for the indicated function
if ((tileRow != row && tileRow != -1) || (tileCol != col && tileCol != -1)) {
rewriter.eraseOp(coreOp);
return success();
return;
}
};

// The parent should be an AIE.device op.
rewriter.setInsertionPointAfter(coreOp->getParentOp());

// Name the outlined function after the core's tile coordinates.
std::string fName = "core_" + std::to_string(col) + "_" + std::to_string(row);
auto coreFunc = rewriter.create<func::FuncOp>(
rewriter.getUnknownLoc(), fName,
FunctionType::get(rewriter.getContext(), {}, {}));

IRMapping mapper;
rewriter.cloneRegionBefore(coreOp.getBody(), coreFunc.getBody(),
coreFunc.getBody().begin(), mapper);

// Rewrite the AIE.end op
coreFunc.getBody().walk([&](EndOp endOp) {
rewriter.setInsertionPointAfter(endOp);
rewriter.create<func::ReturnOp>(endOp->getLoc(), ValueRange({}));
rewriter.eraseOp(endOp);
});

rewriter.eraseOp(coreOp);
}

// Move all the ops with OpTy inside device, to just before the device.
template <typename OpTy>
@@ -211,6 +167,30 @@ struct AMDAIECoreToStandardPass : mlir::OperationPass<ModuleOp> {
llvm::cl::desc("Y coordinate of tile to generate code for"),
llvm::cl::init(-1)};
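The body of the templated helper declared just before the collapsed region above (called later as outlineOps<...>) is hidden in this view. Going only by its comment ("Move all the ops with OpTy inside device, to just before the device"), a plausible sketch — an assumption, not the verbatim body — would be:

// Sketch under the assumption that outlineOps does exactly what its comment
// says: collect the matching ops first (moving while iterating would
// invalidate the range), then hoist each one to just before the device op.
template <typename OpTy>
static void outlineOps(DeviceOp device) {
  SmallVector<OpTy> ops;
  for (OpTy op : device.getOps<OpTy>()) ops.push_back(op);
  for (OpTy op : ops) op->moveBefore(device);
}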

// Assert that cores are isolated
static bool coresAreIsolated(ModuleOp m) {
SmallVector<CoreOp> coreOps;
m->walk([&](CoreOp coreOp) { coreOps.push_back(coreOp); });
for (CoreOp coreOp : coreOps) {
auto walkResult = coreOp->walk([&](Operation *childOp) {
if (childOp == coreOp) return WalkResult::advance();
for (Value operand : childOp->getOperands()) {
if (Operation *operandOp = operand.getDefiningOp()) {
if (!coreOp->isAncestor(operandOp)) {
operandOp->emitOpError(
"is not in the core in which it is used. Cores must be "
"`isolated` before this point.");
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) return false;
}
return true;
}

void runOnOperation() override {
ModuleOp m = getOperation();

@@ -219,79 +199,39 @@ struct AMDAIECoreToStandardPass : mlir::OperationPass<ModuleOp> {
return signalPassFailure();
}

OpBuilder builder = OpBuilder::atBlockEnd(m.getBody());
MLIRContext *ctx = &getContext();
IRRewriter rewriter(ctx);
rewriter.setInsertionPointToEnd(m.getBody());

// Ensure that we don't have an incorrect target triple. This may override
// some bogus target triple in the original mlir.
m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
builder.getStringAttr("aie2"));

IRMapping mapper;
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<cf::ControlFlowDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<affine::AffineDialect>();
target.addLegalDialect<VectorDialect>();
target.addLegalDialect<mlir::iree_compiler::aievec::AIEVecDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<math::MathDialect>();
target.addLegalOp<func::FuncOp, ModuleOp>();

RewritePatternSet patterns(&getContext());

StringAttr privateSym = StringAttr::get(&getContext(), "private");
rewriter.getStringAttr("aie2"));

StringAttr privateSym = StringAttr::get(ctx, "private");
auto buildDecl = [&](const std::string &funcName) {
builder.create<func::FuncOp>(
builder.getUnknownLoc(), funcName,
FunctionType::get(builder.getContext(),
{builder.getI32Type(), builder.getI32Type()}, {}),
rewriter.create<func::FuncOp>(
rewriter.getUnknownLoc(), funcName,
FunctionType::get(ctx, {rewriter.getI32Type(), rewriter.getI32Type()},
{}),
privateSym, ArrayAttr{}, ArrayAttr{});
};
buildDecl("llvm.aie2.acquire");
buildDecl("llvm.aie2.release");

patterns.add<AMDAIEUseLockToStdLowering>(m.getContext());
patterns.add<AMDAIEBufferToStandard>(m.getContext(), m, tileCol, tileRow);
if (failed(applyPartialConversion(m, target, std::move(patterns))))
return signalPassFailure();
m.walk([&](UseLockOp useLock) { lockToStd(useLock, rewriter); });

// Assert that cores are isolated
{
SmallVector<CoreOp> coreOps;
m->walk([&](CoreOp coreOp) { coreOps.push_back(coreOp); });
for (CoreOp coreOp : coreOps) {
auto walkResult = coreOp->walk([&](Operation *childOp) {
if (childOp == coreOp) return WalkResult::advance();
for (Value operand : childOp->getOperands()) {
if (Operation *operandOp = operand.getDefiningOp()) {
if (!coreOp->isAncestor(operandOp)) {
operandOp->emitOpError(
"is not in the core in which it is used. Cores must be "
"`isolated` before this point.");
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) return signalPassFailure();
}
}
RewritePatternSet outlinePatterns(&getContext());
outlinePatterns.add<AMDAIECoreToStandardFunc>(m.getContext(), mapper,
tileCol, tileRow);
if (failed(applyPartialConversion(m, target, std::move(outlinePatterns))))
return signalPassFailure();
m.walk([&](BufferOp buffer) { bufferToStd(m, buffer, rewriter); });

if (!coresAreIsolated(m)) return signalPassFailure();

// Move all the func.func ops and memref.globals from the device to the
// module
m.walk(
[&](CoreOp coreOp) { coreToStd(coreOp, rewriter, tileCol, tileRow); });

// Move all the func.func ops and memref.globals from device to module.
DeviceOp device = *m.getOps<DeviceOp>().begin();
outlineOps<memref::GlobalOp>(device);
outlineOps<func::FuncOp>(device);

MLIRContext &context = getContext();
IRRewriter rewriter(&context);
rewriter.eraseOp(device);
}
};