Skip to content

Commit

Permalink
[SnitchDMA] Split transfer legalization out of DMAToLLVM
Browse files Browse the repository at this point in the history
It is up to the target, in this case Snitch, to dictate which DMA transfers are directly legal and which are not. This logic previously lived in `DMAToLLVM` and, because that conversion is part of the "one-shot-to-llvm" pass, it was difficult to test and implement.

This PR therefore splits the logic out into a dedicated legalization pass, which greatly simplifies the lowering to LLVM.
  • Loading branch information
zero9178 committed Sep 2, 2024
1 parent ec8499a commit cbb17ca
Show file tree
Hide file tree
Showing 11 changed files with 500 additions and 349 deletions.
214 changes: 42 additions & 172 deletions codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,24 @@ static bool isContiguous(MemRefType memRefType) {
}

namespace {
struct StartTransferOp1DLowering : ConvertOpToLLVMPattern<StartTransferOp> {
struct StartTransferOpLowering : ConvertOpToLLVMPattern<StartTransferOp> {

LLVM::LLVMFuncOp dmaStart1DFunc;
LLVM::LLVMFuncOp dmaStart2DFunc;

StartTransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter, /*benefit=*/2),
dmaStart1DFunc(dmaStart1DFunc) {}
StartTransferOpLowering(LLVM::LLVMFuncOp dmaStart1DFunc,
LLVM::LLVMFuncOp dmaStart2DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter), dmaStart1DFunc(dmaStart1DFunc),
dmaStart2DFunc(dmaStart2DFunc) {}

LogicalResult match(StartTransferOp op) const override {
return success(isContiguous(op.getSource().getType()) &&
isContiguous(op.getDest().getType()));
}
LogicalResult
matchAndRewrite(StartTransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned rank = op.getSource().getType().getRank();
if (rank != 1 && rank != 2)
return failure();

void rewrite(StartTransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefDescriptor sourceDescriptor(adaptor.getSource());
MemRefDescriptor destDescriptor(adaptor.getDest());

Expand All @@ -87,170 +89,38 @@ struct StartTransferOp1DLowering : ConvertOpToLLVMPattern<StartTransferOp> {
Value dest = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType());

MemRefType sourceMemRef = op.getSource().getType();
SmallVector<Value> dynamicSizes;
for (auto [index, dim] : llvm::enumerate(sourceMemRef.getShape()))
if (ShapedType::isDynamic(dim))
dynamicSizes.push_back(
sourceDescriptor.size(rewriter, op->getLoc(), index));

SmallVector<Value> sizes;
SmallVector<Value> strides;
Value totalSize;
getMemRefDescriptorSizes(
op->getLoc(),
// Offsets are not considered an identity layout.
// Get rid of the layout entirely for the size calculation.
MemRefType::get(sourceMemRef.getShape(), sourceMemRef.getElementType(),
nullptr, sourceMemRef.getMemorySpace()),
dynamicSizes, rewriter, sizes, strides, totalSize);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, dmaStart1DFunc,
ValueRange{
dest,
source,
totalSize,
});
}
};

struct StartTransferOp2DLowering : ConvertOpToLLVMPattern<StartTransferOp> {

LLVM::LLVMFuncOp dmaStart2DFunc;

StartTransferOp2DLowering(LLVM::LLVMFuncOp dmaStart2DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter), dmaStart2DFunc(dmaStart2DFunc) {}

LogicalResult
matchAndRewrite(StartTransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType sourceMemRef = op.getSource().getType();
MemRefType destMemRef = op.getDest().getType();

// Compute the size of the contiguous inner loop common to both MemRefs and
// "shave" it off the ends of the shapes and strides. The remaining shapes
// and strides are considered our outer dimensions.
FailureOr<size_t> sourceNonContiguous =
getNumNonContiguousOuterDims(sourceMemRef);
FailureOr<size_t> destNonContiguous =
getNumNonContiguousOuterDims(destMemRef);
if (failed(sourceNonContiguous) || failed(destNonContiguous))
return failure();
size_t sharedNonContiguous =
std::max(*sourceNonContiguous, *destNonContiguous);
if (sharedNonContiguous == 0)
return failure();

Value elementSize = rewriter.create<LLVM::ConstantOp>(
op->getLoc(),
rewriter.getI32IntegerAttr(llvm::divideCeil(
op.getSource().getType().getElementTypeBitWidth(), 8)));
SmallVector<OpFoldResult> sizes =
memref::getMixedSizes(rewriter, op->getLoc(), op.getSource());

// Build a loop nest that iterates over all outer dimensions except the
// inner-most one and adjusts the source and destination pointers accordingly.
// The inner-most outer dimension is used in the DMA call for the repetition
// count and strides.
SmallVector<Value> lowerBounds;
SmallVector<Value> upperBounds;
SmallVector<Value> steps;
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
for (size_t index : llvm::seq(sharedNonContiguous - 1)) {
lowerBounds.push_back(zeroIndex);
steps.push_back(oneIndex);
upperBounds.push_back(getValueOrCreateConstantIndexOp(
rewriter, op->getLoc(), sizes[index]));
}

Value contiguousSize;
for (auto index :
llvm::seq<size_t>(sharedNonContiguous, sourceMemRef.getRank())) {
Value dim =
getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), sizes[index]);
if (!contiguousSize) {
contiguousSize = dim;
continue;
}
contiguousSize =
rewriter.create<arith::MulIOp>(op->getLoc(), contiguousSize, dim);
Value innerSize = rewriter.create<LLVM::MulOp>(
op->getLoc(), sourceDescriptor.size(rewriter, op->getLoc(), rank - 1),
elementSize);
if (rank == 1) {
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, dmaStart1DFunc,
ValueRange{
dest,
source,
innerSize,
});
} else {
Value sourceStride = rewriter.create<LLVM::MulOp>(
op->getLoc(), sourceDescriptor.stride(rewriter, op->getLoc(), 0),
elementSize);
Value destStride = rewriter.create<LLVM::MulOp>(
op->getLoc(), destDescriptor.stride(rewriter, op->getLoc(), 0),
elementSize);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, dmaStart2DFunc,
ValueRange{
dest,
source,
innerSize,
destStride,
sourceStride,
sourceDescriptor.size(rewriter, op->getLoc(), 0),
});
}
contiguousSize = typeConverter->materializeTargetConversion(
rewriter, op->getLoc(), getIndexType(), contiguousSize);
contiguousSize =
rewriter.create<LLVM::MulOp>(op->getLoc(), contiguousSize, elementSize);

Value completedToken = rewriter.create<CompletedTokenOp>(op->getLoc());

scf::LoopNest loopNest = scf::buildLoopNest(
rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken,
[&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iterArgs) -> scf::ValueVector {
SmallVector<OpFoldResult> offsets = ivs;
SmallVector<OpFoldResult> subSizes(sharedNonContiguous - 1,
rewriter.getIndexAttr(1));
for (unsigned i : llvm::seq<unsigned>(sharedNonContiguous - 1,
sourceMemRef.getRank())) {
offsets.push_back(rewriter.getIndexAttr(0));
subSizes.push_back(sizes[i]);
}
SmallVector<OpFoldResult> strides(sourceMemRef.getRank(),
rewriter.getIndexAttr(1));

TypedValue<MemRefType> sourceMemRefSlice =
rewriter.create<memref::SubViewOp>(loc, op.getSource(), offsets,
subSizes, strides);
TypedValue<MemRefType> destMemRefSlice =
rewriter.create<memref::SubViewOp>(loc, op.getDest(), offsets,
subSizes, strides);

auto sourceDescriptor =
MemRefDescriptor(typeConverter->materializeTargetConversion(
rewriter, op->getLoc(),
typeConverter->convertType(sourceMemRefSlice.getType()),
sourceMemRefSlice));
auto destDescriptor =
MemRefDescriptor(typeConverter->materializeTargetConversion(
rewriter, op->getLoc(),
typeConverter->convertType(destMemRefSlice.getType()),
destMemRefSlice));

Value sourceAdjusted = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(),
sourceMemRefSlice.getType());
Value destAdjusted = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(),
destMemRefSlice.getType());

Value sourceStride =
sourceDescriptor.stride(builder, loc, sharedNonContiguous - 1);
sourceStride = rewriter.create<LLVM::MulOp>(
op->getLoc(), sourceStride, elementSize);
Value destStride =
destDescriptor.stride(builder, loc, sharedNonContiguous - 1);
destStride = rewriter.create<LLVM::MulOp>(op->getLoc(), destStride,
elementSize);

Value outerLoopSize =
sourceDescriptor.size(builder, loc, sharedNonContiguous - 1);
return {builder
.create<LLVM::CallOp>(loc, dmaStart2DFunc,
ValueRange{
destAdjusted,
sourceAdjusted,
contiguousSize,
destStride,
sourceStride,
outerLoopSize,
})
.getResult()};
});

Type tokenType = typeConverter->convertType(op.getType());
rewriter.replaceOp(
op, typeConverter->materializeTargetConversion(
rewriter, op->getLoc(), tokenType, loopNest.results.front()));
return success();
}
};
Expand Down Expand Up @@ -502,8 +372,8 @@ void quidditch::populateDMAToLLVMConversionPatterns(
patterns.insert<CompletedTokenOpLowering, WaitForTransferOpLowering,
StartZeroMemTransferOpOpLowering, StatOpLowering,
CombineTokensOpLowering>(typeConverter);
patterns.insert<StartTransferOp1DLowering>(dmaStart1D, typeConverter);
patterns.insert<StartTransferOp2DLowering>(dmaStart2D, typeConverter);
patterns.insert<StartTransferOpLowering>(dmaStart1D, dmaStart2D,
typeConverter);
patterns.insert<StartContiguousZeroMemTransferOpOpLowering>(
dmaStart1D, dmaStart2D, typeConverter);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Tablegen target: processes Passes.td with --gen-pass-decls (pass group name
# "Transforms") and emits Passes.h.inc.
iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
-name=Transforms --gen-pass-decls Passes.h.inc
)

# Library target for the passes: built from LegalizeDMAOperations.cpp, with
# Passes.h and the tablegen-generated Passes.h.inc as its headers.
iree_cc_library(
NAME
Passes
HDRS
"Passes.h"
"Passes.h.inc"
SRCS
"LegalizeDMAOperations.cpp"
DEPS
# Package-local tablegen target (the "::" prefix); Passes.h.inc is listed as
# an output of PassesIncGen in this file.
::PassesIncGen
# Project dialects, followed by the upstream MLIR libraries.
Quidditch::Dialect::DMA::IR::DMADialect
Quidditch::Dialect::Snitch::IR::QuidditchSnitchDialect
MLIRIR
MLIRAffineDialect
MLIRArithDialect
MLIRSCFDialect
)
Loading

0 comments on commit cbb17ca

Please sign in to comment.