Skip to content

Commit

Permalink
[canonicalization] Add AMDAIEDialectFoldInterface to keeps ops in amd…
Browse files Browse the repository at this point in the history
…aie.cores (#716)
  • Loading branch information
newling authored Aug 28, 2024
1 parent 9095210 commit a5fbf9d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 18 deletions.
5 changes: 4 additions & 1 deletion compiler/plugins/target/AMD-AIE/aie/AIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,16 @@ void AIEDialect::printType(Type type, DialectAsmPrinter &printer) const {

/// without this, canonicalize/cse/etc will lift eg constants out of core ops
/// causing eg lower-to-aie to fail to converge
///
/// There's no way to do this is tablegen, so unfortunately it must be hidden
/// away here
struct AIEDialectFoldInterface : DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;

/// Registered hook to check if the given region, which is attached to an
/// operation that is *not* isolated from above, should be used when
/// materializing constants.
bool shouldMaterializeInto(Region *region) const final override {
bool shouldMaterializeInto(Region *region) const final {
// If this is an AIE::CoreOp region, then insert into it.
return isa<CoreOp>(region->getParentOp());
}
Expand Down
23 changes: 21 additions & 2 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

#include "iree-amd-aie/IR/AMDAIEAttrs.h"
#include "iree-amd-aie/IR/AMDAIEDialect.cpp.inc"
#include "iree-amd-aie/IR/AMDAIETypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"

namespace mlir::iree_compiler::AMDAIE {

Expand All @@ -24,11 +25,29 @@ struct AMDAIEDialectOpAsmInterface : public OpAsmDialectInterface {
}
};

/// without this, canonicalize/cse/etc will lift eg constants out of core ops
/// at every opportunity, causing problems when lowering to AIE.
///
/// There's no way to do this is tablegen, so unfortunately it must be hidden
/// away here
struct AMDAIEDialectFoldInterface : DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;

/// Registered hook to check if the given region, which is attached to an
/// operation that is *not* isolated from above, should be used when
/// materializing constants.
bool shouldMaterializeInto(Region *region) const final {
// If this is an AMDAIE::CoreOp region, then insert into it.
return isa<AMDAIE::CoreOp>(region->getParentOp());
}
};

void AMDAIEDialect::initialize() {
initializeAMDAIEAttrs();
initializeAMDAIEOps();
initializeAMDAIETypes();
addInterfaces<AMDAIEDialectOpAsmInterface>();
addInterfaces<AMDAIEDialectFoldInterface>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,34 @@ func.func @logicalobjectfifo_from_memref(%arg0: memref<1x1x8x16xi32, 1>) {
%1 = amdaie.dma_cpy_nd(%0[][][], %0[][][]) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>)
return
}

// -----


// A test of AMDAIEDialectFoldInterface. Don't move ops out of cores.

// CHECK-LABEL: func @isolated_cores
// CHECK-NOT: arith.constant 3
// CHECK: amdaie.core
// CHECK: arith.constant 3
// CHECK: amdaie.core
// CHECK: arith.constant 3
func.func @isolated_cores() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%tile_0 = amdaie.tile(%c0, %c0)
%tile_1 = amdaie.tile(%c0, %c1)
%0 = amdaie.core(%tile_0, in : [], out : []) {
%c3 = arith.constant 3 : index
%alloc = memref.alloc() : memref<2x2xindex>
linalg.fill ins(%c3 : index) outs(%alloc : memref<2x2xindex>)
amdaie.end
}
%1 = amdaie.core(%tile_1, in : [], out : []) {
%c3 = arith.constant 3 : index
%alloc = memref.alloc() : memref<2x2xindex>
linalg.fill ins(%c3 : index) outs(%alloc : memref<2x2xindex>)
amdaie.end
}
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: amdaie.core
// CHECK-DAG: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-DAG: amdaie.core
// CHECK: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-DAG: %[[REM:.+]] = arith.remsi %[[ARG0]], %[[C2]] : index
// CHECK-DAG: %[[DIV:.+]] = arith.divsi %[[ARG0]], %[[C2]] : index
// CHECK-DAG: func.call @callee(%[[DIV]], %[[REM]]) : (index, index) -> ()
module @test_single {
func.func private @callee(%i: index, %j: index)
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%tile_0_2 = amdaie.tile(%c0, %c2)
%core_0_2 = amdaie.core(%tile_0_2, in : [], out : []) {
%c7 = arith.constant 7 : index
%c11 = arith.constant 11 : index
%tile_7_11 = amdaie.tile(%c7, %c11)
%core_7_11 = amdaie.core(%tile_7_11, in : [], out : []) {
scf.forall (%i, %j) in (2, 2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
Expand All @@ -31,8 +33,8 @@ module @test_single {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK: amdaie.core
// CHECK-DAG: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-DAG: amdaie.core
// CHECK: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-DAG: %[[REM:.+]] = arith.remsi %[[ARG0]], %[[C2]] : index
// CHECK-DAG: %[[DIV:.+]] = arith.divsi %[[ARG0]], %[[C2]] : index
// CHECK-DAG: func.call @callee(%[[DIV]], %[[REM]]) : (index, index) -> ()
Expand All @@ -44,8 +46,10 @@ module @test_multi {
func.func private @callee(%i: index, %j: index)
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%tile_0_2 = amdaie.tile(%c0, %c2)
%core_0_2 = amdaie.core(%tile_0_2, in : [], out : []) {
%c7 = arith.constant 7 : index
%c11 = arith.constant 11 : index
%tile_7_11 = amdaie.tile(%c7, %c11)
%core_7_11 = amdaie.core(%tile_7_11, in : [], out : []) {
scf.forall (%i, %j) in (2, 2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
Expand All @@ -64,8 +68,8 @@ module @test_multi {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK: amdaie.core
// CHECK-DAG: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C16]] step %[[C1]] {
// CHECK-DAG: amdaie.core
// CHECK: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C16]] step %[[C1]] {
// CHECK-DAG: %[[REM0:.+]] = arith.remsi %[[ARG0]], %[[C4]] : index
// CHECK-DAG: %[[DIV0:.+]] = arith.divsi %[[ARG0]], %[[C4]] : index
// CHECK-DAG: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C4]] step %[[C1]] {
Expand All @@ -76,8 +80,10 @@ module @test_nested {
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%tile_0_2 = amdaie.tile(%c0, %c2)
%core_0_2 = amdaie.core(%tile_0_2, in : [], out : []) {
%c7 = arith.constant 7 : index
%c11 = arith.constant 11 : index
%tile_7_11 = amdaie.tile(%c7, %c11)
%core_7_11 = amdaie.core(%tile_7_11, in : [], out : []) {
scf.forall (%i, %j) in (4, 4) {
scf.forall (%k, %l) in (2, 2) {
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: amdaie.core
// CHECK-DAG: amdaie.core
// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK: amdaie.logicalobjectfifo.acquire
// CHECK: amdaie.logicalobjectfifo.access
Expand Down Expand Up @@ -54,7 +54,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK: amdaie.core
// CHECK-DAG: amdaie.core
// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C8]] step %[[C4]] {
// CHECK: amdaie.logicalobjectfifo.acquire
// CHECK: amdaie.logicalobjectfifo.access
Expand Down Expand Up @@ -111,7 +111,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C17:.+]] = arith.constant 17 : index
// CHECK-DAG: %[[C21:.+]] = arith.constant 21 : index
// CHECK: amdaie.core
// CHECK-DAG: amdaie.core
// CHECK: scf.for %[[ARG0:.+]] = %[[C1]] to %[[C17]] step %[[C8]] {
// CHECK: amdaie.logicalobjectfifo.acquire
// CHECK: amdaie.logicalobjectfifo.access
Expand Down

0 comments on commit a5fbf9d

Please sign in to comment.