From dcb508360e3a7a5fcca66de20e6a24612277f05e Mon Sep 17 00:00:00 2001 From: Andra Bisca Date: Wed, 20 Sep 2023 12:11:08 +0200 Subject: [PATCH] Compute memref sizes by multiplying all shape sizes. (#641) Co-authored-by: abisca --- lib/Dialect/AIE/IR/AIEDialect.cpp | 18 ++++++++++++++---- .../AIEObjectFifoStatefulTransform.cpp | 4 ++-- .../link_test_distribute.mlir | 10 +++++----- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/AIE/IR/AIEDialect.cpp b/lib/Dialect/AIE/IR/AIEDialect.cpp index 3b66b8d2d1..f15c678720 100644 --- a/lib/Dialect/AIE/IR/AIEDialect.cpp +++ b/lib/Dialect/AIE/IR/AIEDialect.cpp @@ -452,13 +452,18 @@ LogicalResult xilinx::AIE::ObjectFifoLinkOp::verify() { AIEObjectFifoType fifoType = fifoOut.getElemType().cast(); MemRefType elemType = fifoType.getElementType().cast(); - int outputSize = (int)elemType.getShape()[0]; + int64_t outputSize = 1; + for (auto dim : elemType.getShape()) + outputSize *= dim; int inputSize = 0; for (auto fifoIn : getInputObjectFifos()) { AIEObjectFifoType fifo = fifoIn.getElemType().cast(); MemRefType elemType = fifo.getElementType().cast(); - inputSize += (int)elemType.getShape()[0]; + int64_t nextInputSize = 1; + for (auto dim : elemType.getShape()) + nextInputSize *= dim; + inputSize += nextInputSize; } if (inputSize != outputSize) return emitError("Total size of input objFifos in ObjectFifoLinkOp must " @@ -468,13 +473,18 @@ LogicalResult xilinx::AIE::ObjectFifoLinkOp::verify() { ObjectFifoCreateOp fifoIn = getInputObjectFifos()[0]; AIEObjectFifoType fifoType = fifoIn.getElemType().cast(); MemRefType elemType = fifoType.getElementType().cast(); - int inputSize = (int)elemType.getShape()[0]; + int64_t inputSize = 1; + for (auto dim : elemType.getShape()) + inputSize *= dim; int outputSize = 0; for (auto fifoOut : getOutputObjectFifos()) { AIEObjectFifoType fifo = fifoOut.getElemType().cast(); MemRefType elemType = fifo.getElementType().cast(); - outputSize += (int)elemType.getShape()[0]; + int64_t nextOutputSize = 1; + for (auto dim : elemType.getShape()) + nextOutputSize *= dim; + outputSize += nextOutputSize; } if (outputSize != inputSize) return emitError("Total size of output objFifos in ObjectFifoLinkOp must " diff --git a/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp b/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp index b5586d2709..cafbc83763 100644 --- a/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp +++ b/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp @@ -647,7 +647,7 @@ struct AIEObjectFifoStatefulTransformPass if (fifoIn.name() == op.name()) break; else - extraOffset += (int)elemType.getShape()[0]; + extraOffset += (int)getMemrefTypeSize(elemType); } } } else if (linkOp->isDistribute()) { @@ -665,7 +665,7 @@ struct AIEObjectFifoStatefulTransformPass if (fifoOut.name() == op.name()) break; else - extraOffset += (int)elemType.getShape()[0]; + extraOffset += (int)getMemrefTypeSize(elemType); } } } else { diff --git a/test/objectFifo-stateful-transform/link_test_distribute.mlir b/test/objectFifo-stateful-transform/link_test_distribute.mlir index 4ab9d6905e..e6cdffde0e 100644 --- a/test/objectFifo-stateful-transform/link_test_distribute.mlir +++ b/test/objectFifo-stateful-transform/link_test_distribute.mlir @@ -27,8 +27,8 @@ // CHECK: %9 = AIE.lock(%1, 0) {init = 6 : i32, sym_name = "link1_cons_prod_lock"} // CHECK: %10 = AIE.lock(%1, 1) {init = 0 : i32, sym_name = "link1_cons_cons_lock"} // CHECK: AIE.flow(%1, DMA : 0, %2, DMA : 0) -// CHECK: %11 = AIE.buffer(%2) {sym_name = "link2_cons_buff_0"} : memref<16xi32> -// CHECK: %12 = AIE.buffer(%2) {sym_name = "link2_cons_buff_1"} : memref<16xi32> +// CHECK: %11 = AIE.buffer(%2) {sym_name = "link2_cons_buff_0"} : memref<4x4xi32> +// CHECK: %12 = AIE.buffer(%2) {sym_name = "link2_cons_buff_1"} : memref<4x4xi32> // CHECK: %13 = AIE.lock(%2, 0) {init = 2 : i32, sym_name = "link2_cons_prod_lock"} // CHECK: %14 = AIE.lock(%2, 1) {init = 0 : i32, sym_name = "link2_cons_cons_lock"} // CHECK: AIE.flow(%1, DMA : 1, %3, DMA : 0) @@ -108,12 +108,12 @@ // CHECK: %29 = AIE.dmaStart(S2MM, 0, ^bb1, ^bb3) // CHECK: ^bb1: // 2 preds: ^bb0, ^bb2 // CHECK: AIE.useLock(%13, AcquireGreaterEqual, 1) -// CHECK: AIE.dmaBd(<%11 : memref<16xi32>, 0, 16>, 0) +// CHECK: AIE.dmaBd(<%11 : memref<4x4xi32>, 0, 16>, 0) // CHECK: AIE.useLock(%14, Release, 1) // CHECK: AIE.nextBd ^bb2 // CHECK: ^bb2: // pred: ^bb1 // CHECK: AIE.useLock(%13, AcquireGreaterEqual, 1) -// CHECK: AIE.dmaBd(<%12 : memref<16xi32>, 0, 16>, 0) +// CHECK: AIE.dmaBd(<%12 : memref<4x4xi32>, 0, 16>, 0) // CHECK: AIE.useLock(%14, Release, 1) // CHECK: AIE.nextBd ^bb1 // CHECK: ^bb3: // pred: ^bb0 @@ -161,7 +161,7 @@ module @link_distribute { %tile33 = AIE.tile(3, 3) AIE.objectFifo @link1 (%tile20, {%tile21}, 2 : i32) : !AIE.objectFifo> - AIE.objectFifo @link2 (%tile21, {%tile22}, 2 : i32) : !AIE.objectFifo> + AIE.objectFifo @link2 (%tile21, {%tile22}, 2 : i32) : !AIE.objectFifo> AIE.objectFifo @link3 (%tile21, {%tile23}, 2 : i32) : !AIE.objectFifo> AIE.objectFifo @link4 (%tile21, {%tile33}, 2 : i32) : !AIE.objectFifo>