forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR] Add initial convert-memref-to-emitc pass (#30)
This translates memref types in func.func, func.call and func.return. Reviewers: TinaAMD Reviewed By: TinaAMD Pull Request: #113
- Loading branch information
1 parent
995b14c
commit d3c90a4
Showing
8 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
27 changes: 27 additions & 0 deletions
27
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H | ||
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
class RewritePatternSet; | ||
class TypeConverter; | ||
|
||
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC | ||
#include "mlir/Conversion/Passes.h.inc" | ||
|
||
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, | ||
TypeConverter &typeConverter); | ||
|
||
std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass(); | ||
|
||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
add_mlir_conversion_library(MLIRMemRefToEmitC | ||
MemRefToEmitC.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
MLIREmitCDialect | ||
MLIRMemRefDialect | ||
MLIRTransforms | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements a pass to convert memref ops into emitc ops. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/Passes.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
/// Disallow all memrefs even though we only have conversions | ||
/// for memrefs with static shape right now to have good diagnostics. | ||
bool isLegal(Type t) { return !isa<BaseMemRefType>(t); } | ||
|
||
template <typename RangeT> | ||
std::enable_if_t<!std::is_convertible<RangeT, Type>::value && | ||
!std::is_convertible<RangeT, Operation *>::value, | ||
bool> | ||
isLegal(RangeT &&range) { | ||
return llvm::all_of(range, [](Type type) { return isLegal(type); }); | ||
} | ||
|
||
bool isLegal(Operation *op) { | ||
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); | ||
} | ||
|
||
bool isSignatureLegal(FunctionType ty) { | ||
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); | ||
} | ||
|
||
struct ConvertMemRefToEmitCPass | ||
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { | ||
void runOnOperation() override { | ||
TypeConverter converter; | ||
// Pass through for all other types. | ||
converter.addConversion([](Type type) { return type; }); | ||
|
||
converter.addConversion([](MemRefType memRefType) -> std::optional<Type> { | ||
if (memRefType.hasStaticShape()) { | ||
return emitc::ArrayType::get(memRefType.getShape(), | ||
memRefType.getElementType()); | ||
} | ||
return {}; | ||
}); | ||
|
||
converter.addConversion( | ||
[&converter](FunctionType ty) -> std::optional<Type> { | ||
SmallVector<Type> inputs; | ||
if (failed(converter.convertTypes(ty.getInputs(), inputs))) | ||
return std::nullopt; | ||
|
||
SmallVector<Type> results; | ||
if (failed(converter.convertTypes(ty.getResults(), results))) | ||
return std::nullopt; | ||
|
||
return FunctionType::get(ty.getContext(), inputs, results); | ||
}); | ||
|
||
RewritePatternSet patterns(&getContext()); | ||
populateMemRefToEmitCConversionPatterns(patterns, converter); | ||
|
||
ConversionTarget target(getContext()); | ||
target.addDynamicallyLegalOp<func::FuncOp>( | ||
[](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); }); | ||
target.addDynamicallyLegalDialect<func::FuncDialect>( | ||
[](Operation *op) { return isLegal(op); }); | ||
target.addIllegalDialect<memref::MemRefDialect>(); | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, | ||
TypeConverter &converter) { | ||
|
||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, | ||
converter); | ||
populateCallOpTypeConversionPattern(patterns, converter); | ||
populateReturnOpTypeConversionPattern(patterns, converter); | ||
} | ||
|
||
std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() { | ||
return std::make_unique<ConvertMemRefToEmitCPass>(); | ||
} |
24 changes: 24 additions & 0 deletions
24
mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics | ||
|
||
// Unranked memrefs are not converted | ||
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}} | ||
func.func @memref_unranked(%arg0 : memref<*xf32>) { | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// Memrefs with dynamic shapes are not converted | ||
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}} | ||
func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) { | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// Memrefs with dynamic shapes are not converted | ||
func.func @memref_op(%arg0 : memref<2x4xf32>) { | ||
// expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}} | ||
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s | ||
|
||
// CHECK-LABEL: memref_arg | ||
// CHECK-SAME: !emitc.array<32xf32>) | ||
func.func @memref_arg(%arg0 : memref<32xf32>) { | ||
func.return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: memref_return | ||
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32> | ||
func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> { | ||
// CHECK: return %[[arg0]] : !emitc.array<32xf32> | ||
func.return %arg0 : memref<32xf32> | ||
} | ||
|
||
// CHECK-LABEL: memref_call | ||
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) | ||
func.func @memref_call(%arg0 : memref<32xf32>) { | ||
// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32> | ||
func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32> | ||
func.return | ||
} |