Skip to content

Commit

Permalink
Move conditionals to target model
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenneuendorffer committed Jun 5, 2024
1 parent 1fe2cd3 commit 56d1d6d
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 28 deletions.
22 changes: 21 additions & 1 deletion include/aie/Dialect/AIE/IR/AIETargetModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ class AIETargetModel {
/// Return the size (in bytes) of the local data memory of a core.
virtual uint32_t getLocalMemorySize() const = 0;

/// Return the size (in bits) of the accumulator/cascade.
virtual uint32_t getAccumulatorCascadeSize() const = 0;

/// Return the number of lock objects
virtual uint32_t getNumLocks(int col, int row) const = 0;

Expand Down Expand Up @@ -204,6 +207,15 @@ class AIETargetModel {
// Return true if this is an NPU-based device
// There are several special cases for handling the NPU at the moment.
virtual bool isNPU() const { return false; }

// Return the bit offset of the column within a tile address.
// This is used to compute the control address of a tile from it's column location.
virtual uint32_t getColumnShift() const = 0;

// Return the bit offset of the row within a tile address.
// This is used to compute the control address of a tile from it's row location.
virtual uint32_t getRowShift() const = 0;

};

class AIE1TargetModel : public AIETargetModel {
Expand Down Expand Up @@ -243,6 +255,7 @@ class AIE1TargetModel : public AIETargetModel {
uint32_t getMemNorthBaseAddress() const override { return 0x00030000; }
uint32_t getMemEastBaseAddress() const override { return 0x00038000; }
uint32_t getLocalMemorySize() const override { return 0x00008000; }
uint32_t getAccumulatorCascadeSize() const override { return 384; }
uint32_t getNumLocks(int col, int row) const override { return 16; }
uint32_t getNumBDs(int col, int row) const override { return 16; }
uint32_t getNumMemTileRows() const override { return 0; }
Expand All @@ -268,6 +281,9 @@ class AIE1TargetModel : public AIETargetModel {
return true;
return false;
}

uint32_t getColumnShift() const override { return 23; }
uint32_t getRowShift() const override { return 18; }
};

class AIE2TargetModel : public AIETargetModel {
Expand Down Expand Up @@ -300,7 +316,8 @@ class AIE2TargetModel : public AIETargetModel {
uint32_t getMemNorthBaseAddress() const override { return 0x00060000; }
uint32_t getMemEastBaseAddress() const override { return 0x00070000; }
uint32_t getLocalMemorySize() const override { return 0x00010000; }

uint32_t getAccumulatorCascadeSize() const override { return 512; }

uint32_t getNumLocks(int col, int row) const override {
return isMemTile(col, row) ? 64 : 16;
}
Expand All @@ -322,6 +339,9 @@ class AIE2TargetModel : public AIETargetModel {
bool isLegalMemtileConnection(WireBundle srcBundle, int srcChan,
WireBundle dstBundle,
int dstChan) const override;

uint32_t getColumnShift() const override { return 25; }
uint32_t getRowShift() const override { return 20; }
};

class VC1902TargetModel : public AIE1TargetModel {
Expand Down
13 changes: 4 additions & 9 deletions lib/Dialect/AIE/IR/AIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,15 +966,10 @@ LogicalResult PutCascadeOp::verify() {
Type type = getCascadeValue().getType();
DataLayout dataLayout = DataLayout::closest(*this);
auto bits = dataLayout.getTypeSizeInBits(type);
if (targetModel.getTargetArch() == AIEArch::AIE1) {
if (bits != 384)
return emitOpError("must be a 384-bit type");
} else if (targetModel.getTargetArch() == AIEArch::AIE2) {
if (bits != 512)
return emitOpError("must be a 512-bit type");
} else
return emitOpError("cascade not supported in ")
<< stringifyAIEArch(targetModel.getTargetArch());
auto archbits = targetModel.getAccumulatorCascadeSize();
if (bits != archbits)
return emitOpError("type must match architecture cascade width (") <<
archbits << " bits in " << stringifyAIEArch(targetModel.getTargetArch()) << ")";
return success();
}

Expand Down
20 changes: 14 additions & 6 deletions lib/Targets/AIETargetCDODirect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ static_assert(XAIE_OK == 0);
auto ps = std::filesystem::path::preferred_separator;

#define XAIE_BASE_ADDR 0x40000000
#define XAIE_COL_SHIFT 25
#define XAIE_ROW_SHIFT 20
#define XAIE_SHIM_ROW 0
#define XAIE_MEM_TILE_ROW_START 1
#define XAIE_PARTITION_BASE_ADDR 0x0
Expand Down Expand Up @@ -398,11 +396,21 @@ struct AIEControl {
size_t deviceRows = tm.rows();
size_t deviceCols = tm.columns() + partitionStartCol;

configPtr = XAie_Config{
/*AieGen*/ XAIE_DEV_GEN_AIEML,
// Don't put this in the target model, because it's XAIE specific.
unsigned char devGen;
switch(tm.getTargetArch()) {
case AIEArch::AIE1: // probably unreachable.
devGen = XAIE_DEV_GEN_AIE;
break;
case AIEArch::AIE2:
devGen = XAIE_DEV_GEN_AIEML;
break;
}
configPtr = XAie_Config {
/*AieGen*/ devGen,
/*BaseAddr*/ XAIE_BASE_ADDR,
/*ColShift*/ XAIE_COL_SHIFT,
/*RowShift*/ XAIE_ROW_SHIFT,
/*ColShift*/ (uint8_t)tm.getColumnShift(),
/*RowShift*/ (uint8_t)tm.getRowShift(),
/*NumRows*/ static_cast<uint8_t>(deviceRows),
/*NumCols*/ static_cast<uint8_t>(deviceCols),
/*ShimRowNum*/ XAIE_SHIM_ROW,
Expand Down
12 changes: 2 additions & 10 deletions lib/Targets/AIETargetXAIEV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,26 +317,18 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) {
std::string AIE1_device("XAIE_DEV_GEN_AIE");
std::string AIE2_device("XAIE_DEV_GEN_AIEML");
std::string device;
int col_shift = 0;
int row_shift = 0;
switch (arch) {
case AIEArch::AIE1:
device = AIE1_device;
col_shift = 23;
row_shift = 18;
break;
case AIEArch::AIE2:
device = AIE2_device;
col_shift = 25;
row_shift = 20;
break;
}
assert(col_shift);
assert(row_shift);
output << " ctx->AieConfigPtr.AieGen = " << device << ";\n";
output << " ctx->AieConfigPtr.BaseAddr = 0x20000000000;\n";
output << " ctx->AieConfigPtr.ColShift = " << col_shift << ";\n";
output << " ctx->AieConfigPtr.RowShift = " << row_shift << ";\n";
output << " ctx->AieConfigPtr.ColShift = " << targetModel.getColumnShift() << ";\n";
output << " ctx->AieConfigPtr.RowShift = " << targetModel.getRowShift() << ";\n";
output << " ctx->AieConfigPtr.NumRows = " << targetModel.rows() << ";\n";
output << " ctx->AieConfigPtr.NumCols = " << targetModel.columns() << ";\n";
output << " ctx->AieConfigPtr.ShimRowNum = 0;\n";
Expand Down
2 changes: 1 addition & 1 deletion test/dialect/AIE/badcascade-vc1902.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//===----------------------------------------------------------------------===//

// RUN: not aie-opt %s 2>&1 | FileCheck %s
// CHECK: error{{.*}}'aie.put_cascade' op must be a 384-bit type
// CHECK: error{{.*}}'aie.put_cascade' op type must match architecture cascade width (384 bits in AIE1)

module @test {
aie.device(xcvc1902) {
Expand Down
2 changes: 1 addition & 1 deletion test/dialect/AIE/badcascade-ve2802.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//===----------------------------------------------------------------------===//

// RUN: not aie-opt %s 2>&1 | FileCheck %s
// CHECK: error{{.*}}'aie.put_cascade' op must be a 512-bit type
// CHECK: error{{.*}}'aie.put_cascade' op type must match architecture cascade width (512 bits in AIE2)

module @test {
aie.device(xcve2802) {
Expand Down

0 comments on commit 56d1d6d

Please sign in to comment.