diff --git a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h index e7158ee3852e18..1d651e394b897d 100644 --- a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h +++ b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h @@ -16,6 +16,7 @@ #include "mlir/Pass/Pass.h" namespace mlir { +class TypeConverter; #define GEN_PASS_DECL_TOSATOARITH #include "mlir/Conversion/Passes.h.inc" @@ -25,7 +26,8 @@ namespace tosa { std::unique_ptr createTosaToArith(bool includeApplyRescale = false, bool use32BitApplyRescale = false); -void populateTosaToArithConversionPatterns(RewritePatternSet *patterns); +void populateTosaToArithConversionPatterns(TypeConverter &converter, + RewritePatternSet *patterns); void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit = false); diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index c864a8aeacbd5b..d3ea818da6c296 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -303,6 +303,29 @@ class CmpIOpConversion : public OpConversionPattern { } }; +class NegFOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto adaptedOp = adaptor.getOperand(); + auto adaptedOpType = adaptedOp.getType(); + + if (!isa(adaptedOpType)) { + return rewriter.notifyMatchFailure(op.getLoc(), + "negf currently only supported on " + "floats, not tensors/vectors thereof"); + } + + rewriter.replaceOpWithNewOp(op, adaptedOpType, + adaptedOp); + return success(); + } +}; + template class CastConversion : public OpConversionPattern { public: @@ -716,6 +739,7 @@ void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns, UnsignedShiftOpConversion, CmpFOpConversion, CmpIOpConversion, + NegFOpConversion, SelectOpConversion, // Truncation is guaranteed for unsigned types. UnsignedCastConversion, diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 50e57682a2dc8d..fd42d1d444420a 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -22,13 +23,23 @@ using namespace tosa; namespace { -class ConstOpConverter : public OpRewritePattern { +class ConstOpConverter : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(tosa::ConstOp op, - PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, op.getValue()); + LogicalResult matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + auto elements = dyn_cast(adaptor.getValue()); + if (!elements) { + return rewriter.notifyMatchFailure(op, "expected dense elements attr"); + } + + auto convertedElTy = getTypeConverter()->convertType(elements.getElementType()); + if (!convertedElTy) { + return rewriter.notifyMatchFailure(op, "type conversion failed"); + } + rewriter.replaceOpWithNewOp(op, elements.bitcast(convertedElTy)); return success(); } }; @@ -238,9 +249,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { } // namespace -void mlir::tosa::populateTosaToArithConversionPatterns( +void mlir::tosa::populateTosaToArithConversionPatterns(TypeConverter &converter, RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); + patterns->add(converter, patterns->getContext()); } void mlir::tosa::populateTosaRescaleToArithConversionPatterns( diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp index de82c0335c985d..ff3f923b71fbfd 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp @@ -19,6 +19,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" namespace mlir { #define GEN_PASS_DEF_TOSATOARITH @@ -34,12 +35,15 @@ struct TosaToArith : public impl::TosaToArithBase { TosaToArith(TosaToArithOptions &options) : TosaToArithBase(options) {} void runOnOperation() override { + TypeConverter converter; + mlir::tosa::populateTosaToLinalgTypeConversion(converter); + RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addIllegalOp(); target.addLegalDialect(); - mlir::tosa::populateTosaToArithConversionPatterns(&patterns); + mlir::tosa::populateTosaToArithConversionPatterns(converter, &patterns); if (this->includeApplyRescale) { mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 38abad1b229854..0a7855b4b888b9 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -81,6 +81,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { // ----- +func.func @arith_negf_tensor(%arg0: tensor<5xf32>) -> tensor<5xf32> { + // expected-error @+1 {{failed to legalize operation 'arith.negf'}} + %n = arith.negf %arg0 : tensor<5xf32> + return %n: tensor<5xf32> +} + +// ----- + +func.func @arith_negf_vector(%arg0: vector<5xf32>) -> vector<5xf32> { + // expected-error @+1 {{failed to legalize operation 'arith.negf'}} + %n = arith.negf %arg0 : vector<5xf32> + return %n: vector<5xf32> +} + +// ----- + func.func @arith_shli_i1(%arg0: i1, %arg1: i1) { // expected-error @+1 {{failed to legalize operation 'arith.shli'}} %shli = arith.shli %arg0, %arg1 : i1 diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 7a6dc7bfc5809c..dd24d4d81ca8fe 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -560,6 +560,16 @@ func.func @arith_cmpi_index(%arg0: i32, %arg1: i32) -> i1 { return %slt: i1 } +// ----- + +func.func @arith_negf(%arg0: f32) -> f32 { + // CHECK-LABEL: arith_negf + // CHECK-SAME: %[[Arg0:[^ ]*]]: f32 + // CHECK: %[[N:[^ ]*]] = emitc.unary_minus %[[Arg0]] : (f32) -> f32 + %n = arith.negf %arg0 : f32 + // CHECK: return %[[N]] + return %n: f32 +} // ----- diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir index c4f82d53af9822..63d1423ea3ad6d 100644 --- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir +++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir @@ -2,12 +2,16 @@ // RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s // CHECK-LABEL: func @const_test -func.func @const_test() -> (tensor) { - // CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor - %result = "tosa.const"() {value = dense<3> : tensor} : () -> tensor +func.func @const_test() -> (tensor, tensor) { + // CHECK: %[[CI32:.+]] = arith.constant dense<3> : tensor + %i32 = "tosa.const"() {value = dense<3> : tensor} : () -> tensor - // CHECK: return [[C3]] - return %result : tensor + // CHECK: %[[CUI32:.+]] = arith.constant dense<3> : tensor + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CUI32]] : tensor to tensor + %ui32 = "tosa.const"() {value = dense<3> : tensor} : () -> tensor + + // CHECK: return %[[CI32]], %[[CAST]] + return %i32, %ui32 : tensor, tensor } // -----