Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move conditionals to target model #1537

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/aie/Dialect/AIE/IR/AIETargetModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ class AIETargetModel {
/// Return the size (in bytes) of the local data memory of a core.
virtual uint32_t getLocalMemorySize() const = 0;

/// Return the size (in bits) of the accumulator/cascade.
virtual uint32_t getAccumulatorCascadeSize() const = 0;

/// Return the number of lock objects
virtual uint32_t getNumLocks(int col, int row) const = 0;

Expand Down Expand Up @@ -204,6 +207,16 @@ class AIETargetModel {
// Return true if this is an NPU-based device
// There are several special cases for handling the NPU at the moment.
virtual bool isNPU() const { return false; }

// Return the bit offset of the column within a tile address.
// This is used to compute the control address of a tile from it's column
// location.
virtual uint32_t getColumnShift() const = 0;

// Return the bit offset of the row within a tile address.
// This is used to compute the control address of a tile from it's row
// location.
virtual uint32_t getRowShift() const = 0;
};

class AIE1TargetModel : public AIETargetModel {
Expand Down Expand Up @@ -243,6 +256,7 @@ class AIE1TargetModel : public AIETargetModel {
uint32_t getMemNorthBaseAddress() const override { return 0x00030000; }
uint32_t getMemEastBaseAddress() const override { return 0x00038000; }
uint32_t getLocalMemorySize() const override { return 0x00008000; }
uint32_t getAccumulatorCascadeSize() const override { return 384; }
uint32_t getNumLocks(int col, int row) const override { return 16; }
uint32_t getNumBDs(int col, int row) const override { return 16; }
uint32_t getNumMemTileRows() const override { return 0; }
Expand All @@ -268,6 +282,9 @@ class AIE1TargetModel : public AIETargetModel {
return true;
return false;
}

uint32_t getColumnShift() const override { return 23; }
uint32_t getRowShift() const override { return 18; }
};

class AIE2TargetModel : public AIETargetModel {
Expand Down Expand Up @@ -300,6 +317,7 @@ class AIE2TargetModel : public AIETargetModel {
uint32_t getMemNorthBaseAddress() const override { return 0x00060000; }
uint32_t getMemEastBaseAddress() const override { return 0x00070000; }
uint32_t getLocalMemorySize() const override { return 0x00010000; }
uint32_t getAccumulatorCascadeSize() const override { return 512; }

uint32_t getNumLocks(int col, int row) const override {
return isMemTile(col, row) ? 64 : 16;
Expand All @@ -322,6 +340,9 @@ class AIE2TargetModel : public AIETargetModel {
bool isLegalMemtileConnection(WireBundle srcBundle, int srcChan,
WireBundle dstBundle,
int dstChan) const override;

uint32_t getColumnShift() const override { return 25; }
uint32_t getRowShift() const override { return 20; }
};

class VC1902TargetModel : public AIE1TargetModel {
Expand Down
14 changes: 5 additions & 9 deletions lib/Dialect/AIE/IR/AIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,15 +966,11 @@ LogicalResult PutCascadeOp::verify() {
Type type = getCascadeValue().getType();
DataLayout dataLayout = DataLayout::closest(*this);
auto bits = dataLayout.getTypeSizeInBits(type);
if (targetModel.getTargetArch() == AIEArch::AIE1) {
if (bits != 384)
return emitOpError("must be a 384-bit type");
} else if (targetModel.getTargetArch() == AIEArch::AIE2) {
if (bits != 512)
return emitOpError("must be a 512-bit type");
} else
return emitOpError("cascade not supported in ")
<< stringifyAIEArch(targetModel.getTargetArch());
auto archbits = targetModel.getAccumulatorCascadeSize();
if (bits != archbits)
return emitOpError("type must match architecture cascade width (")
<< archbits << " bits in "
<< stringifyAIEArch(targetModel.getTargetArch()) << ")";
return success();
}

Expand Down
18 changes: 13 additions & 5 deletions lib/Targets/AIETargetCDODirect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ static_assert(XAIE_OK == 0);
auto ps = std::filesystem::path::preferred_separator;

#define XAIE_BASE_ADDR 0x40000000
#define XAIE_COL_SHIFT 25
#define XAIE_ROW_SHIFT 20
#define XAIE_SHIM_ROW 0
#define XAIE_MEM_TILE_ROW_START 1
#define XAIE_PARTITION_BASE_ADDR 0x0
Expand Down Expand Up @@ -398,11 +396,21 @@ struct AIEControl {
size_t deviceRows = tm.rows();
size_t deviceCols = tm.columns() + partitionStartCol;

// Don't put this in the target model, because it's XAIE specific.
unsigned char devGen;
switch (tm.getTargetArch()) {
case AIEArch::AIE1: // probably unreachable.
devGen = XAIE_DEV_GEN_AIE;
break;
case AIEArch::AIE2:
devGen = XAIE_DEV_GEN_AIEML;
break;
}
configPtr = XAie_Config{
/*AieGen*/ XAIE_DEV_GEN_AIEML,
/*AieGen*/ devGen,
/*BaseAddr*/ XAIE_BASE_ADDR,
/*ColShift*/ XAIE_COL_SHIFT,
/*RowShift*/ XAIE_ROW_SHIFT,
/*ColShift*/ (uint8_t)tm.getColumnShift(),
/*RowShift*/ (uint8_t)tm.getRowShift(),
/*NumRows*/ static_cast<uint8_t>(deviceRows),
/*NumCols*/ static_cast<uint8_t>(deviceCols),
/*ShimRowNum*/ XAIE_SHIM_ROW,
Expand Down
14 changes: 4 additions & 10 deletions lib/Targets/AIETargetXAIEV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,26 +317,20 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) {
std::string AIE1_device("XAIE_DEV_GEN_AIE");
std::string AIE2_device("XAIE_DEV_GEN_AIEML");
std::string device;
int col_shift = 0;
int row_shift = 0;
switch (arch) {
case AIEArch::AIE1:
device = AIE1_device;
col_shift = 23;
row_shift = 18;
break;
case AIEArch::AIE2:
device = AIE2_device;
col_shift = 25;
row_shift = 20;
break;
}
assert(col_shift);
assert(row_shift);
output << " ctx->AieConfigPtr.AieGen = " << device << ";\n";
output << " ctx->AieConfigPtr.BaseAddr = 0x20000000000;\n";
output << " ctx->AieConfigPtr.ColShift = " << col_shift << ";\n";
output << " ctx->AieConfigPtr.RowShift = " << row_shift << ";\n";
output << " ctx->AieConfigPtr.ColShift = " << targetModel.getColumnShift()
<< ";\n";
output << " ctx->AieConfigPtr.RowShift = " << targetModel.getRowShift()
<< ";\n";
output << " ctx->AieConfigPtr.NumRows = " << targetModel.rows() << ";\n";
output << " ctx->AieConfigPtr.NumCols = " << targetModel.columns() << ";\n";
output << " ctx->AieConfigPtr.ShimRowNum = 0;\n";
Expand Down
2 changes: 1 addition & 1 deletion test/dialect/AIE/badcascade-vc1902.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//===----------------------------------------------------------------------===//

// RUN: not aie-opt %s 2>&1 | FileCheck %s
// CHECK: error{{.*}}'aie.put_cascade' op must be a 384-bit type
// CHECK: error{{.*}}'aie.put_cascade' op type must match architecture cascade width (384 bits in AIE1)

module @test {
aie.device(xcvc1902) {
Expand Down
2 changes: 1 addition & 1 deletion test/dialect/AIE/badcascade-ve2802.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//===----------------------------------------------------------------------===//

// RUN: not aie-opt %s 2>&1 | FileCheck %s
// CHECK: error{{.*}}'aie.put_cascade' op must be a 512-bit type
// CHECK: error{{.*}}'aie.put_cascade' op type must match architecture cascade width (512 bits in AIE2)

module @test {
aie.device(xcve2802) {
Expand Down
Loading