diff --git a/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp b/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp index 46741dab04..0d84bc31cb 100644 --- a/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp +++ b/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp @@ -848,6 +848,34 @@ struct AIEObjectFifoStatefulTransformPass std::vector> &dependencies, Value base, int64_t step, bool inLoop) { std::vector duplicatedOperations; // operations in current + // Recursive function to replace operands, uses recursion to handle nested + // loop structures. + std::function replaceOpsNested = + [&](Operation *op, unsigned &opIndex, + unsigned numDuplications) -> void { + if (auto loopOp = dyn_cast(op)) { + Block *body = loopOp.getBody(); + auto withoutTerminator = --body->end(); + // NOTE(jornt): This only handles the cases where the nested scf::for is + // located at the start of the body. This should be the most common + // case, but is not fully generic. + if (auto nestedLoop = dyn_cast(body->begin())) { + opIndex++; + replaceOperands(builder, nestedLoop, opIndex, base, step, inLoop, + numDuplications, dependencies, duplicatedOperations); + replaceOpsNested(nestedLoop, opIndex, numDuplications); + } else { + for (auto loopBodyOp = body->begin(); loopBodyOp != withoutTerminator; + ++loopBodyOp) { + opIndex++; + replaceOperands(builder, &*loopBodyOp, opIndex, base, step, inLoop, + numDuplications, dependencies, + duplicatedOperations); + } + } + } + }; + // duplication iteration for (int i = 0; i < numDuplications; i++) { duplicatedOperations.clear(); @@ -858,17 +886,7 @@ struct AIEObjectFifoStatefulTransformPass replaceOperands(builder, clone, opIndex, base, step, inLoop, i, dependencies, duplicatedOperations); builder.insert(clone); - - if (auto nestedLoop = dyn_cast(clone)) { - Block *body = nestedLoop.getBody(); - auto withoutTerminator = --body->end(); - for (auto loopOp = body->begin(); loopOp != withoutTerminator; - ++loopOp) { - opIndex++; - replaceOperands(builder, &*loopOp, opIndex, base, step, inLoop, i, - dependencies, duplicatedOperations); - } - } + replaceOpsNested(clone, opIndex, i); } } } diff --git a/test/objectFifo-stateful-transform/nested_loop_test.mlir b/test/objectFifo-stateful-transform/nested_loop_test.mlir new file mode 100644 index 0000000000..12d35fce7e --- /dev/null +++ b/test/objectFifo-stateful-transform/nested_loop_test.mlir @@ -0,0 +1,126 @@ +//===- nested_loop_test.mlir -----------------------------------------*- MLIR -*-===// +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Date: April 3rd 2024 +// +//===----------------------------------------------------------------------===// + +// RUN: aie-opt --aie-objectFifo-stateful-transform %s | FileCheck %s + +// CHECK-LABEL: aie.device(ipu) +// CHECK: scf.for +// CHECK: { +// CHECK: aie.use_lock +// CHECK: memref.reinterpret_cast +// CHECK: aie.use_lock +// CHECK: memref.reinterpret_cast +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: memref.load +// CHECK: memref.load +// CHECK: memref.load +// CHECK: arith.muli +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: aie.use_lock +// CHECK: aie.use_lock +// CHECK: aie.use_lock +// CHECK: memref.reinterpret_cast +// CHECK: aie.use_lock +// CHECK: memref.reinterpret_cast +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: scf.for +// CHECK: { +// CHECK: memref.load +// CHECK: memref.load +// CHECK: memref.load +// CHECK: arith.muli +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: aie.use_lock +// CHECK: aie.use_lock +// CHECK: } + +aie.device(ipu) { + %tile_0_1 = aie.tile(0, 1) + %tile_1_2 = aie.tile(1, 2) + %tile_0_2 = aie.tile(0, 2) + aie.objectfifo @in2(%tile_0_1, {%tile_0_2, %tile_1_2}, 4 : i32) : !aie.objectfifo> + aie.objectfifo @in7(%tile_0_1, {%tile_1_2}, 4 : i32) : !aie.objectfifo> + aie.objectfifo @in8(%tile_1_2, {%tile_0_1}, 4 : i32) : !aie.objectfifo> + %core_1_2 = aie.core(%tile_1_2) { + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c960 = arith.constant 960 : index + %0 = aie.objectfifo.acquire @in8(Produce, 1) : !aie.objectfifosubview> + %1 = aie.objectfifo.subview.access %0[0] : !aie.objectfifosubview> -> memref<32x32xi32, 1> + %reinterpret_cast = memref.reinterpret_cast %1 to offset: [0], sizes: [4, 8, 4, 8], strides: [256, 32, 8, 1] : memref<32x32xi32, 1> to memref<4x8x4x8xi32, 1> + aie.objectfifo.release @in2(Consume, 1) + aie.objectfifo.release @in7(Consume, 1) + scf.for %arg0 = %c64 to %c960 step %c64 { + %10 = aie.objectfifo.acquire @in2(Consume, 1) : !aie.objectfifosubview> + %11 = aie.objectfifo.subview.access %10[0] : !aie.objectfifosubview> -> memref<32x64xi32, 1> + %reinterpret_cast_4 = memref.reinterpret_cast %11 to offset: [0], sizes: [8, 8, 4, 8], strides: [256, 32, 8, 1] : memref<32x64xi32, 1> to memref<8x8x4x8xi32, 1> + %12 = aie.objectfifo.acquire @in7(Consume, 1) : !aie.objectfifosubview> + %13 = aie.objectfifo.subview.access %12[0] : !aie.objectfifosubview> -> memref<64x32xi32, 1> + %reinterpret_cast_5 = memref.reinterpret_cast %13 to offset: [0], sizes: [4, 8, 8, 8], strides: [512, 64, 8, 1] : memref<64x32xi32, 1> to memref<4x8x8x8xi32, 1> + scf.for %arg1 = %c0 to %c8 step %c1 { + scf.for %arg2 = %c0 to %c4 step %c1 { + scf.for %arg3 = %c0 to %c8 step %c1 { + scf.for %arg4 = %c0 to %c4 step %c1 { + scf.for %arg5 = %c0 to %c8 step %c1 { + scf.for %arg6 = %c0 to %c8 step %c1 { + %14 = memref.load %reinterpret_cast_4[%arg3, %arg1, %arg4, %arg6] : memref<8x8x4x8xi32, 1> + %15 = memref.load %reinterpret_cast_5[%arg2, %arg3, %arg6, %arg5] : memref<4x8x8x8xi32, 1> + %16 = memref.load %reinterpret_cast[%arg2, %arg1, %arg4, %arg5] : memref<4x8x4x8xi32, 1> + %17 = arith.muli %14, %15 : i32 + %18 = arith.addi %16, %17 : i32 + memref.store %18, %reinterpret_cast[%arg2, %arg1, %arg4, %arg5] : memref<4x8x4x8xi32, 1> + } + } + } + } + } + } + aie.objectfifo.release @in2(Consume, 1) + aie.objectfifo.release @in7(Consume, 1) + } + aie.end + } +} \ No newline at end of file