Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/fused-ops' into bump_to_…
Browse files Browse the repository at this point in the history
…5855237
  • Loading branch information
mgehre-amd committed Oct 1, 2024
2 parents ed14c0c + 69d08b3 commit 3c7d83a
Show file tree
Hide file tree
Showing 30 changed files with 905 additions and 34 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
Expand Down Expand Up @@ -72,6 +71,7 @@
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToEmitC/UBToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,18 @@ def TosaToTensor : Pass<"tosa-to-tensor"> {
let constructor = "tosa::createTosaToTensor()";
}

//===----------------------------------------------------------------------===//
// UBToEmitC
//===----------------------------------------------------------------------===//

def ConvertUBToEmitC : Pass<"convert-ub-to-emitc"> {
let summary = "Convert UB dialect to EmitC dialect";
let description = [{
This pass converts supported UB ops to EmitC dialect.
}];
let dependentDialects = ["emitc::EmitCDialect"];
}

//===----------------------------------------------------------------------===//
// UBToLLVM
//===----------------------------------------------------------------------===//
Expand Down
25 changes: 25 additions & 0 deletions mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- UBToEmitC.h - UB to EmitC dialect conversion -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H
#define MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DECL_CONVERTUBTOEMITC
#include "mlir/Conversion/Passes.h.inc"

namespace ub {
void populateUBToEmitCConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns);
} // namespace ub
} // namespace mlir

#endif // MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type);
/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

/// Give the name of the EmitC reference attribute.
StringRef getReferenceAttributeName();

} // namespace emitc
} // namespace mlir

Expand Down
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {

def EmitC_CastOp : EmitC_Op<"cast",
[CExpression,
DeclareOpInterfaceMethods<CastOpInterface>,
SameOperandsAndResultShape]> {
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Cast operation";
let description = [{
The `emitc.cast` operation performs an explicit type conversion and is emitted
Expand All @@ -285,9 +284,11 @@ def EmitC_CastOp : EmitC_Op<"cast",
```
}];

let arguments = (ins EmitCType:$source);
let arguments = (ins EmitCType:$source,
UnitAttr:$reference);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest) (`ref` $reference^)?";
let hasVerifier = 1;
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
Expand Down Expand Up @@ -1092,14 +1093,16 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
UnitAttr:$extern_specifier,
UnitAttr:$static_specifier,
UnitAttr:$const_specifier);
UnitAttr:$const_specifier,
UnitAttr:$reference);

let assemblyFormat = [{
(`extern` $extern_specifier^)?
(`static` $static_specifier^)?
(`const` $const_specifier^)?
$sym_name
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
(`ref` $reference^)?
attr-dict
}];

Expand Down
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===---------- FunctionOpAssembly.h - Parser for `emitc.func` op ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H
#define MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H

#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/LogicalResult.h"

#include "mlir/IR/Builders.h"

namespace mlir::emitc {

class FuncOp;

ParseResult
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments,
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<DictionaryAttr> &resultAttrs);

ParseResult
parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic,
StringAttr typeAttrName,
function_interface_impl::FuncTypeBuilder funcTypeBuilder,
StringAttr argAttrsName, StringAttr resAttrsName);

void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef<Type> argTypes,
bool isVariadic, ArrayRef<Type> resultTypes);

void printFunctionOp(OpAsmPrinter &p, FuncOp op, bool isVariadic,
StringRef typeAttrName, StringAttr argAttrsName,
StringAttr resAttrsName);

} // namespace mlir::emitc

#endif // MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToMLProgram)
add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToEmitC)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
add_subdirectory(VectorToArmSME)
Expand Down
67 changes: 66 additions & 1 deletion mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -166,6 +168,68 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};

struct ConvertCollapseShape final
: public OpConversionPattern<memref::CollapseShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

// Do not generate casts between arrays with dynamic shapes
if (!arrayValue.getType().hasStaticShape())
return rewriter.notifyMatchFailure(op.getLoc(),
"dynamic shapes not supported");
auto newCastOp = rewriter.create<emitc::CastOp>(op->getLoc(), resultTy,
operands.getSrc());
newCastOp.setReference(true);
rewriter.replaceOp(op, newCastOp);
return success();
}
};

struct ConvertExpandShape final
: public OpConversionPattern<memref::ExpandShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

// Do not generate casts between arrays with dynamic shapes
if (!arrayValue.getType().hasStaticShape())
return rewriter.notifyMatchFailure(op.getLoc(),
"dynamic shapes not supported");

auto newCastOp = rewriter.create<emitc::CastOp>(op->getLoc(), resultTy,
operands.getSrc());
newCastOp.setReference(true);
rewriter.replaceOp(op, newCastOp);
return success();
}
};

} // namespace

void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
Expand All @@ -187,5 +251,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
ConvertStore, ConvertCollapseShape, ConvertExpandShape>(
converter, patterns.getContext());
}
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/UBToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRUBToEmitC
UBToEmitC.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/UBToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRLLVMCommonConversion
MLIREmitCDialect
MLIRUBDialect
)
79 changes: 79 additions & 0 deletions mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===- UBToEmitC.cpp - UB to EmitC dialect conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/UBToEmitC/UBToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTUBTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
struct PoisonOpLowering : public OpConversionPattern<ub::PoisonOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const TypeConverter *converter = getTypeConverter();
Type convertedType = converter->convertType(op.getType());

if (!convertedType)
return rewriter.notifyMatchFailure(op.getLoc(), "type conversion failed");

if (!(emitc::isIntegerIndexOrOpaqueType(convertedType) ||
emitc::isSupportedFloatType(convertedType))) {
return rewriter.notifyMatchFailure(
op.getLoc(), "only scalar poison values can be lowered");
}

// Any constant will be fine to lower a poison op
rewriter.replaceOpWithNewOp<emitc::VariableOp>(
op, convertedType, emitc::OpaqueAttr::get(op->getContext(), ""));
return success();
}
};
} // namespace

void ub::populateUBToEmitCConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns.add<PoisonOpLowering>(converter, ctx);
}

struct ConvertUBToEmitC : public impl::ConvertUBToEmitCBase<ConvertUBToEmitC> {
using Base::Base;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
TypeConverter converter;
converter.addConversion([](Type t) { return t; });
populateEmitCSizeTTypeConversions(converter);

ConversionTarget target(getContext());
target.addLegalDialect<emitc::EmitCDialect>();
target.addIllegalDialect<ub::UBDialect>();

mlir::ub::populateUBToEmitCConversionPatterns(converter, patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
9 changes: 6 additions & 3 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,13 @@ ParseResult mlir::affine::parseDimAndSymbolList(
template <typename OpTy>
static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
unsigned numDims) {
unsigned numDims,
bool allowNonAffineDimOperands = false) {
unsigned opIt = 0;
for (auto operand : operands) {
if (opIt++ < numDims) {
if (!isValidDim(operand, getAffineScope(op)))
if (!isValidDim(operand, getAffineScope(op)) &&
!(allowNonAffineDimOperands && operand.getType().isIndex()))
return op.emitOpError("operand cannot be used as a dimension id");
} else if (!isValidSymbol(operand, getAffineScope(op))) {
return op.emitOpError("operand cannot be used as a symbol");
Expand Down Expand Up @@ -2804,7 +2806,8 @@ LogicalResult AffineIfOp::verify() {

// Verify that the operands are valid dimension/symbols.
if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
condition.getNumDims())))
condition.getNumDims(),
/*allowNonAffineDimOperands=*/true)))
return failure();

return success();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIREmitCDialect
EmitC.cpp
FunctionOpAssembly.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC
Expand Down
Loading

0 comments on commit 3c7d83a

Please sign in to comment.