Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev #209

Draft
wants to merge 9 commits into
base: feature/fused-ops
Choose a base branch
from
Draft

dev #209

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {

def EmitC_CastOp : EmitC_Op<"cast",
[CExpression,
DeclareOpInterfaceMethods<CastOpInterface>,
SameOperandsAndResultShape]> {
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Cast operation";
let description = [{
The `cast` operation performs an explicit type conversion and is emitted
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@ class CallOpConversion final : public OpConversionPattern<func::CallOp> {
callOp, "only functions with zero or one result can be converted");

// Convert the original function results.
Type resultTy = nullptr;
SmallVector<Type> types;
if (callOp.getNumResults()) {
resultTy = typeConverter->convertType(callOp.getResult(0).getType());
auto resultTy = typeConverter->convertType(callOp.getResult(0).getType());
if (!resultTy)
return rewriter.notifyMatchFailure(
callOp, "function return type conversion failed");
types.push_back(resultTy);
}

rewriter.replaceOpWithNewOp<emitc::CallOp>(
callOp, resultTy, adaptor.getOperands(), callOp->getAttrs());
callOp, types, adaptor.getOperands(), callOp->getAttrs());

return success();
}
Expand Down
50 changes: 48 additions & 2 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,51 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};

struct ConvertCollapseShape final
: public OpConversionPattern<memref::CollapseShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, resultTy, operands.getSrc());
return success();
}
};

struct ConvertExpandShape final
: public OpConversionPattern<memref::ExpandShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, resultTy, operands.getSrc());
return success();
}
};

} // namespace

void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
Expand All @@ -186,6 +231,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {

void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
patterns.add<ConvertAlloca, ConvertCollapseShape, ConvertExpandShape,
ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
converter, patterns.getContext());
}
13 changes: 6 additions & 7 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,12 @@ LogicalResult emitc::AssignOp::verify() {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();

return (
(llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType, emitc::SignedSizeTType, emitc::SizeTType>(
input)) &&
(llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType, emitc::SignedSizeTType, emitc::SizeTType>(
output)));
return ((llvm::isa<ArrayType, IntegerType, FloatType, IndexType,
emitc::OpaqueType, emitc::PointerType,
emitc::SignedSizeTType, emitc::SizeTType>(input)) &&
(llvm::isa<ArrayType, IntegerType, FloatType, IndexType,
emitc::OpaqueType, emitc::PointerType,
emitc::SignedSizeTType, emitc::SizeTType>(output)));
}

OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
Expand Down
15 changes: 10 additions & 5 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2327,11 +2327,16 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
insertOp.isSameAs(extractOp, isSame))
return insertOp.getSource();

while (insertOp) {
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (insertOp.getSource().getType() == extractOp.getType() &&
insertOp.isSameAs(extractOp, isSame))
return insertOp.getSource();
// TODO: Need to stop at the first insert_slice that has some overlap with
// the extracted range to avoid returning an early insert_slice that was
// (partially) overwritten by later ones.
insertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
}
return {};
}

Expand Down
73 changes: 55 additions & 18 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ struct CppEmitter {
LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);

/// Emits an assignment for a variable which has been declared previously.
LogicalResult emitVariableAssignment(OpResult result);
LogicalResult emitVariableAssignment(OpResult result, StringRef prefix = "v");

/// Emits a variable declaration for a result of an operation.
LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon);
LogicalResult emitVariableDeclaration(OpResult result, bool trailingSemicolon,
StringRef prefix = "v");

/// Emits a declaration of a variable with the given type and name.
LogicalResult emitVariableDeclaration(Location loc, Type type,
Expand All @@ -152,7 +152,7 @@ struct CppEmitter {
/// - emits nothing if no value produced by op;
/// Emits final '=' operator where a type is produced. Returns failure if
/// any result type could not be converted.
LogicalResult emitAssignPrefix(Operation &op);
LogicalResult emitAssignPrefix(Operation &op, StringRef prefix = "v");

/// Emits a global variable declaration or definition.
LogicalResult emitGlobalVariable(GlobalOp op);
Expand All @@ -175,7 +175,7 @@ struct CppEmitter {
LogicalResult emitExpression(ExpressionOp expressionOp);

/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);
StringRef getOrCreateName(Value val, StringRef prefix = "v");

// Returns the textual representation of a subscript operation.
std::string getSubscriptName(emitc::SubscriptOp op);
Expand Down Expand Up @@ -303,6 +303,17 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
Attribute value) {
OpResult result = operation->getResult(0);

std::string prefix = "v";
if (auto c = dyn_cast<emitc::ConstantOp>(operation)) {
Attribute val = c.getValue();
if (auto ia = dyn_cast<IntegerAttr>(val)) {
if (ia.getInt() > 0)
prefix = "c_" + std::to_string(ia.getInt()) + "_";
else
prefix = "c_n" + std::to_string(-ia.getInt()) + "_";
}
}

// Only emit an assignment as the variable was already declared when printing
// the FuncOp.
if (emitter.shouldDeclareVariablesAtTop()) {
Expand All @@ -312,7 +323,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
return success();
}

if (failed(emitter.emitVariableAssignment(result)))
if (failed(emitter.emitVariableAssignment(result, prefix)))
return failure();
return emitter.emitAttribute(operation->getLoc(), value);
}
Expand All @@ -326,7 +337,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
}

// Emit a variable declaration.
if (failed(emitter.emitAssignPrefix(*operation)))
if (failed(emitter.emitAssignPrefix(*operation, prefix)))
return failure();
return emitter.emitAttribute(operation->getLoc(), value);
}
Expand Down Expand Up @@ -716,6 +727,21 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *castOp.getOperation();

if (auto arrType = dyn_cast<emitc::ArrayType>(castOp.getType())) {
std::string shapeStr = "";
for (auto i : arrType.getShape()) {
shapeStr += "[";
shapeStr += std::to_string(i);
shapeStr += "]";
}
os << "float (&" << emitter.getOrCreateName(castOp.getResult()) << ")"
<< shapeStr << " = *reinterpret_cast<float (*)" << shapeStr << ">(";
if (failed(emitter.emitOperand(castOp.getOperand())))
return failure();
os << ")";
return success();
}

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << "(";
Expand Down Expand Up @@ -1128,7 +1154,7 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
}

/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
StringRef CppEmitter::getOrCreateName(Value val, StringRef prefix) {
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
return literal.getValue();
if (!valueMapper.count(val)) {
Expand All @@ -1139,7 +1165,8 @@ StringRef CppEmitter::getOrCreateName(Value val) {
val.getDefiningOp())) {
valueMapper.insert(val, getGlobal.getName().str());
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
valueMapper.insert(val,
formatv("{0}{1}", prefix, ++valueInScopeCount.top()));
}
}
return *valueMapper.begin(val);
Expand Down Expand Up @@ -1377,17 +1404,19 @@ CppEmitter::emitOperandsAndAttributes(Operation &op,
return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
}

LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
LogicalResult CppEmitter::emitVariableAssignment(OpResult result,
StringRef prefix) {
if (!hasValueInScope(result)) {
return result.getDefiningOp()->emitOpError(
"result variable for the operation has not been declared");
}
os << getOrCreateName(result) << " = ";
os << getOrCreateName(result, prefix) << " = ";
return success();
}

LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
bool trailingSemicolon) {
bool trailingSemicolon,
StringRef prefix) {
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
return success();
if (hasValueInScope(result)) {
Expand All @@ -1396,7 +1425,7 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
}
if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
result.getType(),
getOrCreateName(result))))
getOrCreateName(result, prefix))))
return failure();
if (trailingSemicolon)
os << ";\n";
Expand Down Expand Up @@ -1427,7 +1456,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
return success();
}

LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
LogicalResult CppEmitter::emitAssignPrefix(Operation &op, StringRef prefix) {
// If op is being emitted as part of an expression, bail out.
if (getEmittedExpression())
return success();
Expand All @@ -1438,10 +1467,11 @@ LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
case 1: {
OpResult result = op.getResult(0);
if (shouldDeclareVariablesAtTop()) {
if (failed(emitVariableAssignment(result)))
if (failed(emitVariableAssignment(result, prefix)))
return failure();
} else {
if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false,
prefix)))
return failure();
os << " = ";
}
Expand All @@ -1450,7 +1480,8 @@ LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
default:
if (!shouldDeclareVariablesAtTop()) {
for (OpResult result : op.getResults()) {
if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true,
prefix)))
return failure();
}
}
Expand Down Expand Up @@ -1512,7 +1543,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
return success();

os << (trailingSemicolon ? ";\n" : "\n");
os << (trailingSemicolon ? ";" : "");

if (!isa<UnknownLoc>(op.getLoc())) {
os << " // ";
op.getLoc().print(os);
}
os << "\n";

return success();
}
Expand Down
8 changes: 8 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ declare_mlir_dialect_python_bindings(
dialects/func.py
DIALECT_NAME func)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/EmitCOps.td
SOURCES
dialects/emitc.py
DIALECT_NAME emitc)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
19 changes: 19 additions & 0 deletions mlir/python/mlir/dialects/EmitCOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- EmitCOps.td - Entry point for Func bind -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main file from which the Python bindings for the Func dialect
// are generated.
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_FUNC
#define PYTHON_BINDINGS_FUNC

include "mlir/Dialect/EmitC/IR/EmitC.td"

#endif
Loading
Loading