Skip to content

Commit

Permalink
Merge remote-tracking branch 'xlnx/feature/fused-ops' into bump_to_52…
Browse files Browse the repository at this point in the history
…050f3f
  • Loading branch information
mgehre-amd committed Sep 11, 2024
2 parents 7f73835 + 18808c7 commit f320c79
Show file tree
Hide file tree
Showing 20 changed files with 727 additions and 28 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type);
/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

/// Give the name of the EmitC reference attribute.
StringRef getReferenceAttributeName();

} // namespace emitc
} // namespace mlir

Expand Down
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {

def EmitC_CastOp : EmitC_Op<"cast",
[CExpression,
DeclareOpInterfaceMethods<CastOpInterface>,
SameOperandsAndResultShape]> {
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Cast operation";
let description = [{
The `cast` operation performs an explicit type conversion and is emitted
Expand All @@ -285,9 +284,11 @@ def EmitC_CastOp : EmitC_Op<"cast",
```
}];

let arguments = (ins EmitCType:$source);
let arguments = (ins EmitCType:$source,
UnitAttr:$reference);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest) (`ref` $reference^)?";
let hasVerifier = 1;
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
Expand Down Expand Up @@ -1050,14 +1051,16 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
UnitAttr:$extern_specifier,
UnitAttr:$static_specifier,
UnitAttr:$const_specifier);
UnitAttr:$const_specifier,
UnitAttr:$reference);

let assemblyFormat = [{
(`extern` $extern_specifier^)?
(`static` $static_specifier^)?
(`const` $const_specifier^)?
$sym_name
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
(`ref` $reference^)?
attr-dict
}];

Expand Down
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===---------- FunctionOpAssembly.h - Parser for `emitc.func` op ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H
#define MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H

#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/LogicalResult.h"

#include "mlir/IR/Builders.h"

namespace mlir::emitc {

class FuncOp;

ParseResult
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments,
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<DictionaryAttr> &resultAttrs);

ParseResult
parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic,
StringAttr typeAttrName,
function_interface_impl::FuncTypeBuilder funcTypeBuilder,
StringAttr argAttrsName, StringAttr resAttrsName);

void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef<Type> argTypes,
bool isVariadic, ArrayRef<Type> resultTypes);

void printFunctionOp(OpAsmPrinter &p, FuncOp op, bool isVariadic,
StringRef typeAttrName, StringAttr argAttrsName,
StringAttr resAttrsName);

} // namespace mlir::emitc

#endif // MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H
67 changes: 66 additions & 1 deletion mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -166,6 +168,68 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};

struct ConvertCollapseShape final
: public OpConversionPattern<memref::CollapseShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

// Do not generate casts between arrays with dynamic shapes
if (!arrayValue.getType().hasStaticShape())
return rewriter.notifyMatchFailure(op.getLoc(),
"dynamic shapes not supported");
auto newCastOp = rewriter.create<emitc::CastOp>(op->getLoc(), resultTy,
operands.getSrc());
newCastOp.setReference(true);
rewriter.replaceOp(op, newCastOp);
return success();
}
};

struct ConvertExpandShape final
: public OpConversionPattern<memref::ExpandShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

// Do not generate casts between arrays with dynamic shapes
if (!arrayValue.getType().hasStaticShape())
return rewriter.notifyMatchFailure(op.getLoc(),
"dynamic shapes not supported");

auto newCastOp = rewriter.create<emitc::CastOp>(op->getLoc(), resultTy,
operands.getSrc());
newCastOp.setReference(true);
rewriter.replaceOp(op, newCastOp);
return success();
}
};

} // namespace

void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
Expand All @@ -187,5 +251,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
ConvertStore, ConvertCollapseShape, ConvertExpandShape>(
converter, patterns.getContext());
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIREmitCDialect
EmitC.cpp
FunctionOpAssembly.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC
Expand Down
50 changes: 43 additions & 7 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
#include "mlir/Dialect/EmitC/IR/FunctionOpAssembly.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -122,6 +123,8 @@ bool mlir::emitc::isPointerWideType(Type type) {
type);
}

StringRef mlir::emitc::getReferenceAttributeName() { return "emitc.reference"; }

/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
Expand Down Expand Up @@ -225,13 +228,37 @@ LogicalResult emitc::AssignOp::verify() {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();

// Cast to array is only possible from an array
if (isa<emitc::ArrayType>(input) != isa<emitc::ArrayType>(output))
return false;

// Arrays can be casted to arrays by reference.
if (isa<emitc::ArrayType>(input) && isa<emitc::ArrayType>(output))
return true;

// Scalars
return (
(emitc::isIntegerIndexOrOpaqueType(input) ||
emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
(emitc::isIntegerIndexOrOpaqueType(output) ||
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
}

LogicalResult CastOp::verify() {
bool isReference = getReference();

if (isa<emitc::ArrayType>(getDest().getType())) {
if (!isReference)
return emitOpError("cast of array must bear a reference");
return success();
}

if (isReference)
return emitOpError("cast of value type must not bear a reference");

return success();
}

//===----------------------------------------------------------------------===//
// CallOpaqueOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -518,16 +545,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };

return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
return parseFunctionOp(parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name),
getResAttrsAttrName(result.name));
}

void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
printFunctionOp(p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}

LogicalResult FuncOp::verify() {
Expand Down Expand Up @@ -945,6 +971,8 @@ LogicalResult emitc::ArrayType::verify(
for (int64_t dim : shape) {
if (dim < 0)
return emitError() << "dimensions must have non-negative size";
if (dim == ShapedType::kDynamic)
return emitError() << "dimensions must have static size";
}

if (!elementType)
Expand Down Expand Up @@ -1029,6 +1057,12 @@ LogicalResult GlobalOp::verify() {
}
if (getInitialValue().has_value()) {
Attribute initValue = getInitialValue().value();
if (getReference() && !isa<emitc::OpaqueAttr>(initValue)) {
return emitOpError("global reference initial value must be an opaque "
"attribute, got ")
<< initValue;
}

// Check that the type of the initial value is compatible with the type of
// the global variable.
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
Expand Down Expand Up @@ -1057,6 +1091,8 @@ LogicalResult GlobalOp::verify() {
"or opaque attribute, but got ")
<< initValue;
}
} else if (getReference()) {
return emitOpError("global reference must be initialized");
}
if (getStaticSpecifier() && getExternSpecifier()) {
return emitOpError("cannot have both static and extern specifiers");
Expand Down
Loading

0 comments on commit f320c79

Please sign in to comment.