Skip to content

Commit

Permalink
[MLIR] Add initial convert-memref-to-emitc pass (#30)
Browse files Browse the repository at this point in the history
This translates memref types in func.func, func.call and func.return.

Reviewers: TinaAMD

Reviewed By: TinaAMD

Pull Request: #113
  • Loading branch information
mgehre-amd authored Feb 28, 2024
1 parent 995b14c commit d3c90a4
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 0 deletions.
27 changes: 27 additions & 0 deletions mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H

#include "mlir/Pass/Pass.h"

namespace mlir {
class RewritePatternSet;
class TypeConverter;

#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"

void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);

std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();

} // namespace mlir

#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
];
}

//===----------------------------------------------------------------------===//
// MemRefToEmitC
//===----------------------------------------------------------------------===//

def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
let summary = "Convert MemRef dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"];
}

//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(NVGPUToNVVM)
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRMemRefToEmitC
MemRefToEmitC.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIREmitCDialect
MLIRMemRefDialect
MLIRTransforms
)
113 changes: 113 additions & 0 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert memref ops into emitc ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Disallow all memrefs even though we only have conversions
/// for memrefs with static shape right now to have good diagnostics.
bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }

template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
!std::is_convertible<RangeT, Operation *>::value,
bool>
isLegal(RangeT &&range) {
return llvm::all_of(range, [](Type type) { return isLegal(type); });
}

bool isLegal(Operation *op) {
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
}

bool isSignatureLegal(FunctionType ty) {
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
}

struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
TypeConverter converter;
// Pass through for all other types.
converter.addConversion([](Type type) { return type; });

converter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
if (memRefType.hasStaticShape()) {
return emitc::ArrayType::get(memRefType.getShape(),
memRefType.getElementType());
}
return {};
});

converter.addConversion(
[&converter](FunctionType ty) -> std::optional<Type> {
SmallVector<Type> inputs;
if (failed(converter.convertTypes(ty.getInputs(), inputs)))
return std::nullopt;

SmallVector<Type> results;
if (failed(converter.convertTypes(ty.getResults(), results)))
return std::nullopt;

return FunctionType::get(ty.getContext(), inputs, results);
});

RewritePatternSet patterns(&getContext());
populateMemRefToEmitCConversionPatterns(patterns, converter);

ConversionTarget target(getContext());
target.addDynamicallyLegalOp<func::FuncOp>(
[](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
target.addDynamicallyLegalDialect<func::FuncDialect>(
[](Operation *op) { return isLegal(op); });
target.addIllegalDialect<memref::MemRefDialect>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
}

std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
return std::make_unique<ConvertMemRefToEmitCPass>();
}
24 changes: 24 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics

// Unranked memrefs are not converted
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
func.func @memref_unranked(%arg0 : memref<*xf32>) {
return
}

// -----

// Memrefs with dynamic shapes are not converted
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
return
}

// -----

// Memrefs with dynamic shapes are not converted
func.func @memref_op(%arg0 : memref<2x4xf32>) {
// expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
return
}
24 changes: 24 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s

// CHECK-LABEL: memref_arg
// CHECK-SAME: !emitc.array<32xf32>)
func.func @memref_arg(%arg0 : memref<32xf32>) {
func.return
}

// -----

// CHECK-LABEL: memref_return
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
// CHECK: return %[[arg0]] : !emitc.array<32xf32>
func.return %arg0 : memref<32xf32>
}

// CHECK-LABEL: memref_call
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>)
func.func @memref_call(%arg0 : memref<32xf32>) {
// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
func.return
}

0 comments on commit d3c90a4

Please sign in to comment.