From 56d1d6d9935019ae2d040baad3fabb7ca0bc11b4 Mon Sep 17 00:00:00 2001 From: Stephen Neuendorffer Date: Mon, 13 May 2024 16:20:23 -0700 Subject: [PATCH] Move conditionals to target model --- include/aie/Dialect/AIE/IR/AIETargetModel.h | 22 ++++++++++++++++++++- lib/Dialect/AIE/IR/AIEDialect.cpp | 13 ++++-------- lib/Targets/AIETargetCDODirect.cpp | 20 +++++++++++++------ lib/Targets/AIETargetXAIEV2.cpp | 12 ++--------- test/dialect/AIE/badcascade-vc1902.mlir | 2 +- test/dialect/AIE/badcascade-ve2802.mlir | 2 +- 6 files changed, 43 insertions(+), 28 deletions(-) diff --git a/include/aie/Dialect/AIE/IR/AIETargetModel.h b/include/aie/Dialect/AIE/IR/AIETargetModel.h index f8a2b8ba34..f994922925 100644 --- a/include/aie/Dialect/AIE/IR/AIETargetModel.h +++ b/include/aie/Dialect/AIE/IR/AIETargetModel.h @@ -165,6 +165,9 @@ class AIETargetModel { /// Return the size (in bytes) of the local data memory of a core. virtual uint32_t getLocalMemorySize() const = 0; + /// Return the size (in bits) of the accumulator/cascade. + virtual uint32_t getAccumulatorCascadeSize() const = 0; + /// Return the number of lock objects virtual uint32_t getNumLocks(int col, int row) const = 0; @@ -204,6 +207,15 @@ class AIETargetModel { // Return true if this is an NPU-based device // There are several special cases for handling the NPU at the moment. virtual bool isNPU() const { return false; } + + // Return the bit offset of the column within a tile address. + // This is used to compute the control address of a tile from it's column location. + virtual uint32_t getColumnShift() const = 0; + + // Return the bit offset of the row within a tile address. + // This is used to compute the control address of a tile from it's row location. + virtual uint32_t getRowShift() const = 0; + }; class AIE1TargetModel : public AIETargetModel { @@ -243,6 +255,7 @@ class AIE1TargetModel : public AIETargetModel { uint32_t getMemNorthBaseAddress() const override { return 0x00030000; } uint32_t getMemEastBaseAddress() const override { return 0x00038000; } uint32_t getLocalMemorySize() const override { return 0x00008000; } + uint32_t getAccumulatorCascadeSize() const override { return 384; } uint32_t getNumLocks(int col, int row) const override { return 16; } uint32_t getNumBDs(int col, int row) const override { return 16; } uint32_t getNumMemTileRows() const override { return 0; } @@ -268,6 +281,9 @@ class AIE1TargetModel : public AIETargetModel { return true; return false; } + + uint32_t getColumnShift() const override { return 23; } + uint32_t getRowShift() const override { return 18; } }; class AIE2TargetModel : public AIETargetModel { @@ -300,7 +316,8 @@ class AIE2TargetModel : public AIETargetModel { uint32_t getMemNorthBaseAddress() const override { return 0x00060000; } uint32_t getMemEastBaseAddress() const override { return 0x00070000; } uint32_t getLocalMemorySize() const override { return 0x00010000; } - + uint32_t getAccumulatorCascadeSize() const override { return 512; } + uint32_t getNumLocks(int col, int row) const override { return isMemTile(col, row) ? 64 : 16; } @@ -322,6 +339,9 @@ class AIE2TargetModel : public AIETargetModel { bool isLegalMemtileConnection(WireBundle srcBundle, int srcChan, WireBundle dstBundle, int dstChan) const override; + + uint32_t getColumnShift() const override { return 25; } + uint32_t getRowShift() const override { return 20; } }; class VC1902TargetModel : public AIE1TargetModel { diff --git a/lib/Dialect/AIE/IR/AIEDialect.cpp b/lib/Dialect/AIE/IR/AIEDialect.cpp index c5eb1fa62e..40fe08c72f 100644 --- a/lib/Dialect/AIE/IR/AIEDialect.cpp +++ b/lib/Dialect/AIE/IR/AIEDialect.cpp @@ -966,15 +966,10 @@ LogicalResult PutCascadeOp::verify() { Type type = getCascadeValue().getType(); DataLayout dataLayout = DataLayout::closest(*this); auto bits = dataLayout.getTypeSizeInBits(type); - if (targetModel.getTargetArch() == AIEArch::AIE1) { - if (bits != 384) - return emitOpError("must be a 384-bit type"); - } else if (targetModel.getTargetArch() == AIEArch::AIE2) { - if (bits != 512) - return emitOpError("must be a 512-bit type"); - } else - return emitOpError("cascade not supported in ") - << stringifyAIEArch(targetModel.getTargetArch()); + auto archbits = targetModel.getAccumulatorCascadeSize(); + if (bits != archbits) + return emitOpError("type must match architecture cascade width (") << + archbits << " bits in " << stringifyAIEArch(targetModel.getTargetArch()) << ")"; return success(); } diff --git a/lib/Targets/AIETargetCDODirect.cpp b/lib/Targets/AIETargetCDODirect.cpp index 6ac60e6473..ed60b70004 100644 --- a/lib/Targets/AIETargetCDODirect.cpp +++ b/lib/Targets/AIETargetCDODirect.cpp @@ -179,8 +179,6 @@ static_assert(XAIE_OK == 0); auto ps = std::filesystem::path::preferred_separator; #define XAIE_BASE_ADDR 0x40000000 -#define XAIE_COL_SHIFT 25 -#define XAIE_ROW_SHIFT 20 #define XAIE_SHIM_ROW 0 #define XAIE_MEM_TILE_ROW_START 1 #define XAIE_PARTITION_BASE_ADDR 0x0 @@ -398,11 +396,21 @@ struct AIEControl { size_t deviceRows = tm.rows(); size_t deviceCols = tm.columns() + partitionStartCol; - configPtr = XAie_Config{ - /*AieGen*/ XAIE_DEV_GEN_AIEML, + // Don't put this in the target model, because it's XAIE specific. + unsigned char devGen; + switch(tm.getTargetArch()) { + case AIEArch::AIE1: // probably unreachable. + devGen = XAIE_DEV_GEN_AIE; + break; + case AIEArch::AIE2: + devGen = XAIE_DEV_GEN_AIEML; + break; + } + configPtr = XAie_Config { + /*AieGen*/ devGen, /*BaseAddr*/ XAIE_BASE_ADDR, - /*ColShift*/ XAIE_COL_SHIFT, - /*RowShift*/ XAIE_ROW_SHIFT, + /*ColShift*/ (uint8_t)tm.getColumnShift(), + /*RowShift*/ (uint8_t)tm.getRowShift(), /*NumRows*/ static_cast(deviceRows), /*NumCols*/ static_cast(deviceCols), /*ShimRowNum*/ XAIE_SHIM_ROW, diff --git a/lib/Targets/AIETargetXAIEV2.cpp b/lib/Targets/AIETargetXAIEV2.cpp index a07dfc69d4..9c7150ac72 100644 --- a/lib/Targets/AIETargetXAIEV2.cpp +++ b/lib/Targets/AIETargetXAIEV2.cpp @@ -317,26 +317,18 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { std::string AIE1_device("XAIE_DEV_GEN_AIE"); std::string AIE2_device("XAIE_DEV_GEN_AIEML"); std::string device; - int col_shift = 0; - int row_shift = 0; switch (arch) { case AIEArch::AIE1: device = AIE1_device; - col_shift = 23; - row_shift = 18; break; case AIEArch::AIE2: device = AIE2_device; - col_shift = 25; - row_shift = 20; break; } - assert(col_shift); - assert(row_shift); output << " ctx->AieConfigPtr.AieGen = " << device << ";\n"; output << " ctx->AieConfigPtr.BaseAddr = 0x20000000000;\n"; - output << " ctx->AieConfigPtr.ColShift = " << col_shift << ";\n"; - output << " ctx->AieConfigPtr.RowShift = " << row_shift << ";\n"; + output << " ctx->AieConfigPtr.ColShift = " << targetModel.getColumnShift() << ";\n"; + output << " ctx->AieConfigPtr.RowShift = " << targetModel.getRowShift() << ";\n"; output << " ctx->AieConfigPtr.NumRows = " << targetModel.rows() << ";\n"; output << " ctx->AieConfigPtr.NumCols = " << targetModel.columns() << ";\n"; output << " ctx->AieConfigPtr.ShimRowNum = 0;\n"; diff --git a/test/dialect/AIE/badcascade-vc1902.mlir b/test/dialect/AIE/badcascade-vc1902.mlir index 0d10150a45..04eab38b17 100644 --- a/test/dialect/AIE/badcascade-vc1902.mlir +++ b/test/dialect/AIE/badcascade-vc1902.mlir @@ -9,7 +9,7 @@ //===----------------------------------------------------------------------===// // RUN: not aie-opt %s 2>&1 | FileCheck %s -// CHECK: error{{.*}}'aie.put_cascade' op must be a 384-bit type +// CHECK: error{{.*}}'aie.put_cascade' op type must match architecture cascade width (384 bits in AIE1) module @test { aie.device(xcvc1902) { diff --git a/test/dialect/AIE/badcascade-ve2802.mlir b/test/dialect/AIE/badcascade-ve2802.mlir index b7e9022104..23f16ad885 100644 --- a/test/dialect/AIE/badcascade-ve2802.mlir +++ b/test/dialect/AIE/badcascade-ve2802.mlir @@ -9,7 +9,7 @@ //===----------------------------------------------------------------------===// // RUN: not aie-opt %s 2>&1 | FileCheck %s -// CHECK: error{{.*}}'aie.put_cascade' op must be a 512-bit type +// CHECK: error{{.*}}'aie.put_cascade' op type must match architecture cascade width (512 bits in AIE2) module @test { aie.device(xcve2802) {