From d3c90a4f7b9d928acb71690df3a9bc2f51762770 Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:49:44 +0100 Subject: [PATCH] [MLIR] Add initial convert-memref-to-emitc pass (#30) This translates memref types in func.func, func.call and func.return. Reviewers: TinaAMD Reviewed By: TinaAMD Pull Request: https://github.com/xilinx/llvm-project/pull/113 --- .../Conversion/MemRefToEmitC/MemRefToEmitC.h | 27 +++++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 9 ++ mlir/lib/Conversion/CMakeLists.txt | 1 + .../Conversion/MemRefToEmitC/CMakeLists.txt | 17 +++ .../MemRefToEmitC/MemRefToEmitC.cpp | 113 ++++++++++++++++++ .../MemRefToEmitC/memref-to-emit-failed.mlir | 24 ++++ .../MemRefToEmitC/memref-to-emitc.mlir | 24 ++++ 8 files changed, 216 insertions(+) create mode 100644 mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h create mode 100644 mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt create mode 100644 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h new file mode 100644 index 00000000000000..3d16e986ba44e9 --- /dev/null +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -0,0 +1,27 @@ +//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H +#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +class RewritePatternSet; +class TypeConverter; + +#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC +#include "mlir/Conversion/Passes.h.inc" + +void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); + +std::unique_ptr> createConvertMemRefToEmitCPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index a25fd17ea923fb..985d6b8bfcb0a1 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -42,6 +42,7 @@ #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bb1341faafcf5a..5f81415e705549 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -724,6 +724,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> { ]; } +//===----------------------------------------------------------------------===// +// MemRefToEmitC +//===----------------------------------------------------------------------===// + +def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> { + let summary = "Convert MemRef dialect to EmitC dialect"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + //===----------------------------------------------------------------------===// // MemRefToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index c3a2481975040c..e465ecaf37f92c 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -32,6 +32,7 @@ add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToSPIRV) +add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) add_subdirectory(NVGPUToNVVM) diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt new file mode 100644 index 00000000000000..ee2552d1821f2f --- /dev/null +++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRMemRefToEmitC + MemRefToEmitC.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRMemRefDialect + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp new file mode 100644 index 00000000000000..9d7cfa7f840b29 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -0,0 +1,113 @@ +//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert memref ops into emitc ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +/// Disallow all memrefs even though we only have conversions +/// for memrefs with static shape right now to have good diagnostics. +bool isLegal(Type t) { return !isa(t); } + +template +std::enable_if_t::value && + !std::is_convertible::value, + bool> +isLegal(RangeT &&range) { + return llvm::all_of(range, [](Type type) { return isLegal(type); }); +} + +bool isLegal(Operation *op) { + return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); +} + +bool isSignatureLegal(FunctionType ty) { + return isLegal(llvm::concat(ty.getInputs(), ty.getResults())); +} + +struct ConvertMemRefToEmitCPass + : public impl::ConvertMemRefToEmitCBase { + void runOnOperation() override { + TypeConverter converter; + // Pass through for all other types. + converter.addConversion([](Type type) { return type; }); + + converter.addConversion([](MemRefType memRefType) -> std::optional { + if (memRefType.hasStaticShape()) { + return emitc::ArrayType::get(memRefType.getShape(), + memRefType.getElementType()); + } + return {}; + }); + + converter.addConversion( + [&converter](FunctionType ty) -> std::optional { + SmallVector inputs; + if (failed(converter.convertTypes(ty.getInputs(), inputs))) + return std::nullopt; + + SmallVector results; + if (failed(converter.convertTypes(ty.getResults(), results))) + return std::nullopt; + + return FunctionType::get(ty.getContext(), inputs, results); + }); + + RewritePatternSet patterns(&getContext()); + populateMemRefToEmitCConversionPatterns(patterns, converter); + + ConversionTarget target(getContext()); + target.addDynamicallyLegalOp( + [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); }); + target.addDynamicallyLegalDialect( + [](Operation *op) { return isLegal(op); }); + target.addIllegalDialect(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &converter) { + + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); +} + +std::unique_ptr> mlir::createConvertMemRefToEmitCPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir new file mode 100644 index 00000000000000..56d48bef3e1b34 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics + +// Unranked memrefs are not converted +// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}} +func.func @memref_unranked(%arg0 : memref<*xf32>) { + return +} + +// ----- + +// Memrefs with dynamic shapes are not converted +// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}} +func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) { + return +} + +// ----- + +// Memrefs with dynamic shapes are not converted +func.func @memref_op(%arg0 : memref<2x4xf32>) { + // expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}} + memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> + return +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir new file mode 100644 index 00000000000000..5114d6d91e5046 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s + +// CHECK-LABEL: memref_arg +// CHECK-SAME: !emitc.array<32xf32>) +func.func @memref_arg(%arg0 : memref<32xf32>) { + func.return +} + +// ----- + +// CHECK-LABEL: memref_return +// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32> +func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> { +// CHECK: return %[[arg0]] : !emitc.array<32xf32> + func.return %arg0 : memref<32xf32> +} + +// CHECK-LABEL: memref_call +// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) +func.func @memref_call(%arg0 : memref<32xf32>) { +// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32> + func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32> + func.return +}