Skip to content

Commit

Permalink
aiex.configure_task: Allow referencing shim_dma_allocation (Xilin…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrej authored Aug 16, 2024
1 parent 59153bf commit a9525db
Show file tree
Hide file tree
Showing 18 changed files with 533 additions and 9 deletions.
6 changes: 6 additions & 0 deletions include/aie/Dialect/AIE/IR/AIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def AIE_TileOp: AIE_Op<"tile", [
return coreOp;
return nullptr;
}

static AIE::TileOp getOrCreate(mlir::OpBuilder builder, AIE::DeviceOp device, int col, int row);
}];

let assemblyFormat = [{
Expand Down Expand Up @@ -1570,6 +1572,10 @@ def AIE_ShimDMAAllocationOp : AIE_Op<"shim_dma_allocation", [HasParent<"DeviceOp
let assemblyFormat = [{
$sym_name `(` $channel_dir `,` $channel_index `,` $col `)` attr-dict
}];

let extraClassDeclaration = [{
static ::xilinx::AIE::ShimDMAAllocationOp getForSymbol(::xilinx::AIE::DeviceOp device, ::llvm::StringRef symbol);
}];
}

def AIE_ObjectFifoCreateOp: AIE_Op<"objectfifo", [HasParent<"DeviceOp">, Symbol]> {
Expand Down
33 changes: 32 additions & 1 deletion include/aie/Dialect/AIEX/IR/AIEX.td
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,20 @@ def AIE_DMAConfigureTaskOp : AIEX_Op<"dma_configure_task", [HasParent<"RuntimeSe
let hasCanonicalizeMethod = 1;
}

// `aiex.dma_configure_task_for`: symbolic variant of `aiex.dma_configure_task`.
//
// Instead of naming a tile, channel direction and channel index directly, the
// op names an `aie.shim_dma_allocation` symbol (`$alloc`). The
// `aie-substitute-shim-dma-allocations` pass looks that symbol up in the
// enclosing device and rewrites this op into a `dma_configure_task` that uses
// the allocation's column, channel direction and channel index; the body
// region and the `issue_token` / `repeat_count` attributes are carried over
// unchanged.
def AIE_DMAConfigureTaskForOp : AIEX_Op<"dma_configure_task_for", [HasParent<"RuntimeSequenceOp">]>, Results<(outs Index:$result)> {
let summary = "As dma_configure_task, but specify tile, direction and channel by reference to a Shim DMA allocation op";

// `alloc`:        symbol name of the referenced shim DMA allocation op.
// `issue_token`:  optional, defaults to false.
// `repeat_count`: optional, defaults to 0.
let arguments = (
ins FlatSymbolRefAttr:$alloc,
DefaultValuedOptionalAttr<BoolAttr, "false">:$issue_token,
DefaultValuedOptionalAttr<I32Attr, "0">:$repeat_count
);

// Body region; it is moved as-is into the replacement `dma_configure_task`.
let regions = (region AnyRegion:$body);

let assemblyFormat = [{ $alloc regions attr-dict }];
}

def AIE_DMAFreeTaskOp : AIEX_Op<"dma_free_task", [HasParent<"RuntimeSequenceOp">]> {
let summary = "Free all Buffer Descriptor IDs Associated with the Given Task";
let description = [{
Expand Down Expand Up @@ -944,7 +958,6 @@ def AIE_DMAAwaitTaskOp : AIEX_Op<"dma_await_task", [HasParent<"RuntimeSequenceOp
}];
}


def AIE_DMAStartBdChainOp: AIEX_Op<"dma_start_bd_chain", [HasParent<"RuntimeSequenceOp">, TileElement]>,
Results<(outs Index:$result)>
{
Expand Down Expand Up @@ -984,4 +997,22 @@ def AIE_DMAStartBdChainOp: AIEX_Op<"dma_start_bd_chain", [HasParent<"RuntimeSequ

}

// `aiex.dma_start_bd_chain_for`: symbolic variant of `aiex.dma_start_bd_chain`.
//
// The tile, channel direction and channel index are not given directly but are
// taken from the `aie.shim_dma_allocation` op named by `$alloc`.
// `DMAStartBdChainForOpPattern` (AIEMaterializeBDChains.cpp) rewrites this op
// into a `dma_start_bd_chain`, forwarding `$symbol`, `$args`, `issue_token`
// and `repeat_count` unchanged.
def AIE_DMAStartBdChainForOp: AIEX_Op<"dma_start_bd_chain_for", [HasParent<"RuntimeSequenceOp">]>,
Results<(outs Index:$result)>
{
let summary = "As dma_start_bd_chain, but specify tile, direction and channel by reference to a Shim DMA allocation op";

// `symbol`:       symbol forwarded as the `symbol` operand of the replacement
//                 `dma_start_bd_chain` (presumably the BD chain to start).
// `args`:         values forwarded to the replacement op.
// `alloc`:        symbol name of the referenced shim DMA allocation op.
// `issue_token`:  optional, defaults to false.
// `repeat_count`: optional, defaults to 0.
let arguments = (
ins FlatSymbolRefAttr:$symbol,
Variadic<AnyType>:$args,
FlatSymbolRefAttr:$alloc,
DefaultValuedOptionalAttr<BoolAttr, "false">:$issue_token,
DefaultValuedOptionalAttr<I32Attr, "0">:$repeat_count
);

// Textual form: @symbol(args) : (arg types) for @alloc {attrs}
let assemblyFormat = [{
$symbol `(` $args `)` `:` `(` type($args) `)` ` ` `for` ` ` $alloc attr-dict
}];
}

#endif // AIEX_OPS
2 changes: 2 additions & 0 deletions include/aie/Dialect/AIEX/Transforms/AIEXPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ std::unique_ptr<mlir::OperationPass<AIE::DeviceOp>>
createAIEAssignRuntimeSequenceBDIDsPass();
std::unique_ptr<mlir::OperationPass<AIE::DeviceOp>>
createAIEDMATasksToNPUPass();
/// Create the `aie-substitute-shim-dma-allocations` pass, which operates on
/// `AIE::DeviceOp` and replaces symbolic references to
/// `aie.shim_dma_allocation` ops.
std::unique_ptr<mlir::OperationPass<AIE::DeviceOp>>
createAIESubstituteShimDMAAllocationsPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
10 changes: 10 additions & 0 deletions include/aie/Dialect/AIEX/Transforms/AIEXPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,14 @@ def AIEDMATasksToNPU : Pass<"aie-dma-tasks-to-npu", "AIE::DeviceOp"> {
];
}

// Pass `aie-substitute-shim-dma-allocations`, anchored on `AIE::DeviceOp`.
// Ops that name an `aie.shim_dma_allocation` by symbol are rewritten to use
// that allocation's tile, channel direction and channel index explicitly.
def AIESubstituteShimDMAAllocations : Pass<"aie-substitute-shim-dma-allocations", "AIE::DeviceOp"> {
let summary = "Replace symbolic references to `aie.shim_dma_allocation` ops with their `(tile, direction, channel)` triple";

// Factory function implemented in AIESubstituteShimDMAAllocations.cpp.
let constructor = "xilinx::AIEX::createAIESubstituteShimDMAAllocationsPass()";
let dependentDialects = [
"xilinx::AIE::AIEDialect",
"xilinx::AIEX::AIEXDialect",
];
}

#endif
37 changes: 37 additions & 0 deletions lib/Dialect/AIE/IR/AIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,27 @@ bool isLegalTileConnection(TileOp tile, const AIETargetModel &targetModel,
tile.colIndex(), tile.rowIndex(), srcBundle, srcChan, dstBundle, dstChan);
}

/// Return the `aie.tile` op for coordinates (`col`, `row`) in `device`.
///
/// The tiles declared at the top level of `device` are searched first; if
/// none matches, a new tile op with exactly the requested coordinates is
/// created at the start of the device body (with an unknown location).
///
/// `builder` is taken by value, so the caller's insertion point is not
/// modified by this function.
TileOp TileOp::getOrCreate(mlir::OpBuilder builder, DeviceOp device, int col,
                           int row) {
  // Reuse a matching predefined tile at device top level, ...
  for (auto t : device.getOps<AIE::TileOp>()) {
    if (t.getRow() == row && t.getCol() == col)
      return t;
  }

  // ... or, if there is none, create a new tile op. The new tile must carry
  // the requested row as well as the column; otherwise a later lookup for the
  // same (col, row) would not find it and would create yet another tile.
  OpBuilder::InsertionGuard guard(builder);
  mlir::Block &device_start_block = *device.getBodyRegion().begin();
  builder.setInsertionPointToStart(&device_start_block);
  return builder.create<TileOp>(builder.getUnknownLoc(),
                                builder.getIndexType(), col, row);
}

//===----------------------------------------------------------------------===//
// ShimSwitchboxOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2249,6 +2270,22 @@ LogicalResult BDChainOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ShimDMAAllocationOp
//===----------------------------------------------------------------------===//

/// Look up the `aie.shim_dma_allocation` op named `symbol` among the ops
/// directly contained in `device`.
///
/// Returns the first allocation whose symbol name equals `symbol`, or a null
/// op if there is none.
ShimDMAAllocationOp ShimDMAAllocationOp::getForSymbol(DeviceOp device,
                                                      llvm::StringRef symbol) {
  for (AIE::ShimDMAAllocationOp candidate :
       device.getOps<ShimDMAAllocationOp>()) {
    if (candidate.getSymName() == symbol)
      return candidate;
  }
  return nullptr;
}

// Include implementations for custom attributes
#define GET_ATTRDEF_CLASSES
#include "aie/Dialect/AIE/IR/AIEAttrs.cpp.inc"
5 changes: 5 additions & 0 deletions lib/Dialect/AIEX/Transforms/AIEAssignRuntimeSequenceBDIDs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ struct AIEAssignRuntimeSequenceBDIDsPass
<< "Lower this operation first using the "
"--aie-materialize-bd-chains pass.";
}
if (llvm::isa<DMAConfigureTaskForOp>(task_op)) {
err.attachNote(task_op->getLoc())
<< "Lower this operation first using the "
"--aie-substitute-shim-dma-allocations pass.";
}
return err;
}

Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/AIEX/Transforms/AIEDMATasksToNPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand All @@ -30,6 +31,11 @@ struct DMAStartTaskOpPattern : OpConversionPattern<DMAStartTaskOp> {
matchAndRewrite(DMAStartTaskOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
DMAConfigureTaskOp task_op = op.getTaskOp();
if (!task_op) {
// Cannot rewrite this; the task is probably defined by a
// DMAConfigureTaskForOp, which we will lower once it has been rewritten
// into a DMAConfigureTaskOp.
return failure();
}
AIE::TileOp tile = task_op.getTileOp();
std::optional<uint32_t> first_bd_id = task_op.getFirstBdId();
if (!first_bd_id) {
Expand All @@ -54,6 +60,9 @@ struct DMAAwaitTaskOpPattern : OpConversionPattern<DMAAwaitTaskOp> {
matchAndRewrite(DMAAwaitTaskOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
DMAConfigureTaskOp task_op = op.getTaskOp();
if (!task_op) {
return failure();
}
if (!task_op.getIssueToken()) {
auto err = op.emitOpError(
"Cannot wait on a BD that is not configured to issue a token.");
Expand Down
56 changes: 48 additions & 8 deletions lib/Dialect/AIEX/Transforms/AIEMaterializeBDChains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,39 @@ using namespace mlir;
using namespace xilinx;
using namespace xilinx::AIEX;

/// Rewrites `aiex.dma_start_bd_chain_for` into `aiex.dma_start_bd_chain`.
///
/// The shim DMA allocation named by the op's `alloc` symbol is looked up in
/// the enclosing device; its column, channel direction and channel index are
/// then used to build the replacement op. If no allocation with that name
/// exists, an op error is emitted and the pattern reports failure.
struct DMAStartBdChainForOpPattern : RewritePattern {

  DMAStartBdChainForOpPattern(MLIRContext *ctx)
      : RewritePattern(DMAStartBdChainForOp::getOperationName(),
                       PatternBenefit(1), ctx) {}

  LogicalResult matchAndRewrite(Operation *op_any,
                                PatternRewriter &rewriter) const override {
    auto chain_for_op = llvm::dyn_cast<DMAStartBdChainForOp>(op_any);
    if (!chain_for_op)
      return failure();

    // Resolve the symbolic reference within the enclosing device.
    auto parent_device = chain_for_op->getParentOfType<AIE::DeviceOp>();
    auto allocation = AIE::ShimDMAAllocationOp::getForSymbol(
        parent_device, chain_for_op.getAlloc());
    if (!allocation)
      return chain_for_op.emitOpError(
          "no shim DMA allocation found for symbol");

    // The tile is identified by the allocation's column and row 0; it is
    // created if the device does not declare it yet.
    const int shim_col = allocation.getCol();
    AIE::TileOp shim_tile =
        AIE::TileOp::getOrCreate(rewriter, parent_device, shim_col, 0);

    // Build the explicit op, forwarding everything that does not depend on
    // the allocation unchanged.
    auto explicit_op = rewriter.create<DMAStartBdChainOp>(
        chain_for_op.getLoc(), rewriter.getIndexType(),
        chain_for_op.getSymbol(), chain_for_op.getArgs(),
        shim_tile.getResult(), allocation.getChannelDir(),
        static_cast<int32_t>(allocation.getChannelIndex()),
        chain_for_op.getIssueToken(), chain_for_op.getRepeatCount());

    rewriter.replaceAllUsesWith(chain_for_op.getResult(),
                                explicit_op.getResult());
    rewriter.eraseOp(chain_for_op);
    return success();
  }
};

struct DMAInlineBDChainPattern : RewritePattern {

DMAInlineBDChainPattern(MLIRContext *ctx)
Expand Down Expand Up @@ -85,17 +118,24 @@ struct AIEMaterializeBDChainsPass
void runOnOperation() override {
MLIRContext *ctx = &getContext();
AIE::DeviceOp device = getOperation();

ConversionTarget target(getContext());
target.addLegalDialect<AIEXDialect>();
target.addIllegalOp<DMAStartBdChainOp>();
RewritePatternSet patterns(ctx);
patterns.insert<DMAInlineBDChainPattern>(ctx);
GreedyRewriteConfig rewriter_config = GreedyRewriteConfig();
rewriter_config.enableRegionSimplification =
GreedySimplifyRegionLevel::Disabled;
DMAConfigureTaskOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(device, std::move(patterns),

RewritePatternSet patterns_0(ctx);
patterns_0.insert<DMAStartBdChainForOpPattern>(ctx);
DMAConfigureTaskOp::getCanonicalizationPatterns(patterns_0, ctx);
if (failed(applyPatternsAndFoldGreedily(device, std::move(patterns_0),
rewriter_config))) {
signalPassFailure();
}

RewritePatternSet patterns_1(ctx);
patterns_1.insert<DMAInlineBDChainPattern>(ctx);
rewriter_config.enableRegionSimplification =
GreedySimplifyRegionLevel::Disabled;
DMAConfigureTaskOp::getCanonicalizationPatterns(patterns_1, ctx);
if (failed(applyPatternsAndFoldGreedily(device, std::move(patterns_1),
rewriter_config))) {
signalPassFailure();
}
Expand Down
87 changes: 87 additions & 0 deletions lib/Dialect/AIEX/Transforms/AIESubstituteShimDMAAllocations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//===- AIESubstituteShimDMAAllocations.cpp -----------------------*- C++
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//

#include <algorithm>
#include <iterator>

#include "aie/Dialect/AIE/IR/AIEDialect.h"
#include "aie/Dialect/AIEX/IR/AIEXDialect.h"
#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace xilinx;
using namespace xilinx::AIEX;

/// Rewrites `aiex.dma_configure_task_for` into `aiex.dma_configure_task`.
///
/// The shim DMA allocation named by the op's `alloc` symbol is looked up in
/// the enclosing device; its column, channel direction and channel index are
/// used to build the replacement op, and the original body region is moved
/// into it. If no allocation with that name exists, an op error is emitted
/// and the pattern reports failure.
struct DMAConfigureTaskForOpPattern : RewritePattern {

  DMAConfigureTaskForOpPattern(MLIRContext *ctx)
      : RewritePattern(DMAConfigureTaskForOp::getOperationName(),
                       PatternBenefit(1), ctx) {}

  LogicalResult matchAndRewrite(Operation *op_any,
                                PatternRewriter &rewriter) const override {
    auto task_for_op = llvm::dyn_cast<DMAConfigureTaskForOp>(op_any);
    if (!task_for_op)
      return failure();

    // Resolve the symbolic reference within the enclosing device.
    auto parent_device = task_for_op->getParentOfType<AIE::DeviceOp>();
    auto allocation = AIE::ShimDMAAllocationOp::getForSymbol(
        parent_device, task_for_op.getAlloc());
    if (!allocation)
      return task_for_op.emitOpError(
          "no shim DMA allocation found for symbol");

    // The tile is identified by the allocation's column and row 0; it is
    // created if the device does not declare it yet.
    const int shim_col = allocation.getCol();
    AIE::TileOp shim_tile =
        AIE::TileOp::getOrCreate(rewriter, parent_device, shim_col, 0);

    auto explicit_op = rewriter.create<DMAConfigureTaskOp>(
        task_for_op.getLoc(), rewriter.getIndexType(), shim_tile.getResult(),
        allocation.getChannelDir(),
        static_cast<int32_t>(allocation.getChannelIndex()),
        task_for_op.getIssueToken(), task_for_op.getRepeatCount());

    // Redirect users first, then move the body blocks to the front of the
    // new op's region before dropping the now-empty original op.
    rewriter.replaceAllUsesWith(task_for_op.getResult(),
                                explicit_op.getResult());
    rewriter.inlineRegionBefore(task_for_op.getBody(), explicit_op.getBody(),
                                explicit_op.getBody().begin());
    rewriter.eraseOp(task_for_op);
    return success();
  }
};

/// Pass that replaces every `aiex.dma_configure_task_for` op in a device with
/// an equivalent `aiex.dma_configure_task` op, resolving the referenced
/// `aie.shim_dma_allocation` symbol.
///
/// The pass fails if the rewrite driver fails or if any
/// `dma_configure_task_for` op is left over afterwards.
struct AIESubstituteShimDMAAllocationsPass
    : AIESubstituteShimDMAAllocationsBase<AIESubstituteShimDMAAllocationsPass> {

  void runOnOperation() override {
    AIE::DeviceOp device = getOperation();

    // Convert DMAConfigureTaskForOps that reference shim DMA allocations
    // to regular DMAConfigureTaskOps.
    RewritePatternSet patterns(&getContext());
    patterns.insert<DMAConfigureTaskForOpPattern>(&getContext());

    GreedyRewriteConfig rewriter_config = GreedyRewriteConfig();
    if (failed(applyPatternsAndFoldGreedily(device, std::move(patterns),
                                            rewriter_config))) {
      signalPassFailure();
      return;
    }

    // The greedy driver has no notion of illegal ops: a pattern that fails
    // to match (e.g. because the referenced allocation does not exist) does
    // not make the driver fail. Check explicitly that nothing was left
    // behind so the pass result agrees with the diagnostics already emitted.
    const bool has_unresolved =
        device
            .walk([](DMAConfigureTaskForOp) { return WalkResult::interrupt(); })
            .wasInterrupted();
    if (has_unresolved)
      signalPassFailure();
  }
};

/// Factory for the pass implemented by `AIESubstituteShimDMAAllocationsPass`;
/// returns a freshly allocated instance owned by the caller.
std::unique_ptr<OperationPass<AIE::DeviceOp>>
AIEX::createAIESubstituteShimDMAAllocationsPass() {
  auto pass = std::make_unique<AIESubstituteShimDMAAllocationsPass>();
  return pass;
}
1 change: 1 addition & 0 deletions lib/Dialect/AIEX/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_mlir_dialect_library(AIEXTransforms
AIEMaterializeBDChains.cpp
AIEAssignRuntimeSequenceBDIDs.cpp
AIEDMATasksToNPU.cpp
AIESubstituteShimDMAAllocations.cpp
ADDITIONAL_HEADER_DIRS
${AIE_BINARY_DIR}/include

Expand Down
1 change: 1 addition & 0 deletions python/compiler/aiecc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"aie.device",
Pipeline()
.add_pass("aie-materialize-bd-chains")
.add_pass("aie-substitute-shim-dma-allocations")
.add_pass("aie-assign-runtime-sequence-bd-ids")
.add_pass("aie-dma-tasks-to-npu")
.add_pass("aie-dma-to-npu"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2024 AMD Inc.

// REQUIRES: ryzen_ai
//
// RUN: aie-opt --aie-substitute-shim-dma-allocations --aie-assign-runtime-sequence-bd-ids %s | FileCheck %s

// This test ensures that a `dma_configure_task_for` op referencing a shim DMA
// allocation is substituted and that its buffer descriptor is assigned an ID.

module {
aie.device(npu1_4col) {
%tile_0_0 = aie.tile(0, 0)
%tile_0_2 = aie.tile(0, 2)

// Shim DMA allocation referenced by symbol from the runtime sequence below:
// channel direction MM2S, channel index 0, column 0.
aie.shim_dma_allocation @alloc0 (MM2S, 0, 0)

aiex.runtime_sequence(%arg0: memref<8xi16>) {
// Configure a single task through the symbolic reference to @alloc0; its
// one buffer descriptor is expected to receive the first BD ID (0).
%t1 = aiex.dma_configure_task_for @alloc0 {
// CHECK: aie.dma_bd(%arg0 : memref<8xi16>, 0, 8) {bd_id = 0 : i32}
aie.dma_bd(%arg0 : memref<8xi16>, 0, 8)
aie.end
}
// Start the configured task and wait on it.
aiex.dma_start_task(%t1)
aiex.dma_await_task(%t1)
}
}
}

Loading

0 comments on commit a9525db

Please sign in to comment.