Skip to content

Commit

Permalink
Run aievec lowering passes in 'main' pipeline (#833)
Browse files Browse the repository at this point in the history
This PR puts the aievec passes in the main 'IREE' pipeline, rather than
running them in aie2xclbin.

I found the trickiest was figuring out why it didn't "just work" without
adding Affine and AIEVec as legal dialects in the pattern target in
AMDAIECoreToStandard, So I'd like to refactor that pass soon -- there
are too many dialects listed there. I'd also like to throw the final
lowering to LLVM over the wall from aie2xclbin too, but that was causing
other issues which I'd rather address in another PR.

---------

Co-authored-by: Jorn Tuyls <jtuyls@users.noreply.github.com>
  • Loading branch information
newling and jtuyls authored Oct 8, 2024
1 parent 941596b commit be42d19
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
8 changes: 4 additions & 4 deletions compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

#include "AIEDialect.h"
#include "Passes.h"
#include "aievec/AIEVecDialect.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -147,8 +149,6 @@ struct AMDAIECoreToStandardFunc : OpConversionPattern<CoreOp> {
rewriter.getUnknownLoc(), coreName,
FunctionType::get(rewriter.getContext(), {}, {}));



rewriter.cloneRegionBefore(coreOp.getBody(), coreFunc.getBody(),
coreFunc.getBody().begin(), mapper);

Expand Down Expand Up @@ -214,15 +214,13 @@ struct AMDAIECoreToStandardPass : mlir::OperationPass<ModuleOp> {
void runOnOperation() override {
ModuleOp m = getOperation();


if (m.getOps<DeviceOp>().empty()) {
m.emitOpError("expected AIE.device operation at toplevel");
return signalPassFailure();
}

OpBuilder builder = OpBuilder::atBlockEnd(m.getBody());


// Ensure that we don't have an incorrect target triple. This may override
// some bogus target triple in the original mlir.
m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
Expand All @@ -233,7 +231,9 @@ struct AMDAIECoreToStandardPass : mlir::OperationPass<ModuleOp> {
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<cf::ControlFlowDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<affine::AffineDialect>();
target.addLegalDialect<VectorDialect>();
target.addLegalDialect<mlir::iree_compiler::aievec::AIEVecDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<math::MathDialect>();
target.addLegalOp<func::FuncOp, ModuleOp>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@
#include "XCLBinGen.h"

#include <filesystem>
#include <fstream>
#include <functional>
#include <random>
#include <regex>
#include <sstream>
// ReSharper disable once CppUnusedIncludeDirective
#include <fstream>
#include <unordered_map>

#include "AMDAIETargets.h"
#include "aievec/Passes.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree/compiler/Utils/ToolUtils.h"
#include "llvm/ADT/SmallString.h"
Expand All @@ -28,9 +24,12 @@
#include "llvm/Support/Path.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Export.h"
Expand Down Expand Up @@ -982,6 +981,7 @@ struct RemoveAlignment2FromLLVMLoadPass
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
RemoveAlignment2FromLLVMLoadPass);
};

} // namespace

static LogicalResult generateUnifiedObject(
Expand All @@ -996,10 +996,8 @@ static LogicalResult generateUnifiedObject(
PassManager pm(context, ModuleOp::getOperationName());
applyConfigToPassManager(pm, printIRBeforeAll, printIRAfterAll,
printIRModuleScope, timing);

pm.addPass(mlir::iree_compiler::AMDAIE::createAMDAIECoreToStandardPass());
// Convert specific vector dialect ops (like vector.contract) to the AIEVec
// dialect
mlir::iree_compiler::aievec::buildConvertVectorToAIEVec(pm);
mlir::iree_compiler::AMDAIE::addLowerToLLVMPasses(pm);
pm.addPass(std::make_unique<RemoveAlignment2FromLLVMLoadPass>());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ func.func @matmul_i8_i32(%lhs: tensor<32x16xi8>, %rhs: tensor<16x32xi8>) -> tens

// -----

func.func @matmul_bf16(%lhs: tensor<16x32xbf16>, %rhs: tensor<32x16xbf16>) -> tensor<16x16xbf16>
func.func @matmul_bf16(%lhs: tensor<16x32xbf16>, %rhs: tensor<32x16xbf16>) -> tensor<16x16xf32>
{
%cst = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<16x16xbf16>
%1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<16x16xbf16>) -> tensor<16x16xbf16>
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<16x16xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x16xf32>) -> tensor<16x16xf32>
%res = linalg.matmul ins(%lhs, %rhs: tensor<16x32xbf16>, tensor<32x16xbf16>)
outs(%1: tensor<16x16xbf16>) -> tensor<16x16xbf16>
return %res : tensor<16x16xbf16>
outs(%1: tensor<16x16xf32>) -> tensor<16x16xf32>
return %res : tensor<16x16xf32>
}

// CHECK-LABEL: hal.executable.export public @matmul_bf16_dispatch_0_matmul_16x16x32_bf16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,8 @@ void addMLIRAIELoweringPasses(OpPassManager &passManager) {
devicePM.addPass(createAMDAIENormalizeAddressSpacesPass());
devicePM.addPass(createCanonicalizerPass());
}

mlir::iree_compiler::aievec::buildConvertVectorToAIEVec(passManager);
}

// NOTE: this runs on the top-level program module containing all hal.executable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ func.func @mm_in_bf16_out_f32(%lhs: tensor<64x64xbf16>,
return %res : tensor<64x64xf32>
}

// CHECK-DISABLED-NOT: vector.contract
// CHECK-ENABLED: vector.contract
// CHECK-DEFAULT: vector.contract
// CHECK-DISABLED-NOT: aievec.matmul
// CHECK-ENABLED: aievec.matmul
// CHECK-DEFAULT: aievec.matmul

0 comments on commit be42d19

Please sign in to comment.