From ff218c149b1066915cbdb9ac1a867c4618c4c27c Mon Sep 17 00:00:00 2001
From: Sai Kiran Yeddlapalli Ganesh
Date: Wed, 10 Jan 2024 11:26:59 +0530
Subject: [PATCH] feat: add verifiers for concat, pad and sigmoid

Add verify() implementations for tosa.concat, tosa.pad and tosa.sigmoid:
- tosa.concat: all ranked inputs must agree on rank, and the axis must
  lie in the range [0, rank - 1].
- tosa.pad: the paddings operand must be a rank-2 tensor of shape
  [rank(input1), 2] whose values, when constant, are non-negative.
- tosa.sigmoid: the input and output shapes must be compatible.

tosa.pad also moves to InferTensorType so that the element type is
inferred together with the shape, and gains isCompatibleReturnTypes().
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  15 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 139 ++++++++++++++-----
 mlir/test/Dialect/Tosa/invalid.mlir          | 120 +++++++++++++++-
 3 files changed, 236 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index df067397558ee7..7aaad3b663ed1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -422,6 +422,8 @@ def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1423,15 +1425,14 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: pad
 //===----------------------------------------------------------------------===//
 def Tosa_PadOp : Tosa_Op<"pad", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
 
   let summary = "Pads a tensor with value specified.";
 
   let description = [{
@@ -1470,6 +1471,14 @@ def Tosa_PadOp : Tosa_Op<"pad", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
+
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 48dc95b3bed496..30aff016c98136 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -100,7 +100,8 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
-template <typename T> static LogicalResult verifyConvOp(T op) {
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
   auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
@@ -503,6 +504,41 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult ConcatOp::verify() {
+  OperandRange inputs = getInput1();
+
+  auto inputRank = ShapedType::kDynamic;
+  // Initialize to false so the ranked checks are skipped when every input
+  // is unranked (reading an uninitialized flag would be undefined behavior).
+  bool hasRankedInputs = false;
+  for (auto input : inputs) {
+    auto inputType = llvm::cast<ShapedType>(input.getType());
+    if (inputType.hasRank()) {
+      hasRankedInputs = true;
+      inputRank = inputType.getRank();
+      break;
+    }
+  }
+
+  if (hasRankedInputs) {
+    int64_t axis = getAxis();
+    if (axis < 0 || axis >= std::max<int64_t>(1, inputRank)) {
+      return emitOpError() << "axis must be in range 0 to " << inputRank - 1;
+    }
+
+    for (auto input : inputs) {
+      auto inputType = llvm::cast<ShapedType>(input.getType());
+      if (!inputType.hasRank()) {
+        continue;
+      }
+      if (inputRank != inputType.getRank()) {
+        return emitOpError()
+               << "rank of input " << inputType
+               << " does not match other input rank(s) (" << inputRank << ")";
+      }
+    }
+  }
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -590,6 +626,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  Type inputType = getElementTypeOrSelf(operands[0]);
   ShapeAdaptor inputShape = operands.getShape(0);
   ShapeAdaptor paddingShape = operands.getShape(1);
   SmallVector<int64_t> outputShape;
@@ -610,7 +647,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     }
 
     outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -618,7 +656,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
   DenseIntElementsAttr paddings;
   // If the paddings value is not a constant, all dimensions must be dynamic.
   if (!matchPattern(operands[1], m_Constant(&paddings))) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -638,7 +677,35 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
                           paddingValues[i * 2 + 1]);
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
+LogicalResult PadOp::verify() {
+  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+  ShapedType paddingType = llvm::cast<ShapedType>(getPadding().getType());
+  if (paddingType.hasRank()) {
+    if (paddingType.getRank() != 2) {
+      return emitOpError() << "paddings must be a tensor of rank 2";
+    }
+    if (inputType.hasRank() && !paddingType.isDynamicDim(0) &&
+        inputType.getRank() != paddingType.getDimSize(0)) {
+      return emitOpError() << "paddings must be a tensor of shape ["
+                           << inputType.getRank() << ", 2]";
+    }
+    if (!paddingType.isDynamicDim(1) && paddingType.getDimSize(1) != 2) {
+      // Only mention the expected leading extent when the input rank is
+      // known; calling getRank() on an unranked input would assert.
+      if (inputType.hasRank()) {
+        return emitOpError() << "paddings must be a tensor of shape ["
+                             << inputType.getRank() << ", 2]";
+      }
+      return emitOpError() << "paddings must be a tensor of shape [?, 2]";
+    }
+
+    DenseIntElementsAttr paddings;
+    if (matchPattern(getPadding(), m_Constant(&paddings))) {
+      if (llvm::any_of(paddings,
+                       [](auto val) { return val.getSExtValue() < 0; })) {
+        return emitOpError() << "padding values must be non-negative";
+      }
+    }
+  }
+  return success();
+}
+
@@ -767,18 +834,18 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
   }
 
   if ((int64_t)getNewShape().size() != outputType.getRank()) {
-    return emitOpError() << "rank of newShape (" << getNewShape().size()
-                         << ") and output ("
-                         << outputType.getRank()
+    return emitOpError() << "rank of newShape (" << getNewShape().size()
+                         << ") and output (" << outputType.getRank()
                          << ") must match";
   }
 
-  for (int64_t dim=0; dim < outputType.getRank(); ++dim) {
-    if (getNewShape()[dim] != -1 && getNewShape()[dim] != outputType.getShape()[dim]) {
-      return emitOpError() << "newShape attribute (" << getNewShape()[dim]
-                           << ") does not match output type ("
-                           << outputType.getShape()[dim]
-                           << ") in dimension " << dim;
+  for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
+    if (getNewShape()[dim] != -1 &&
+        getNewShape()[dim] != outputType.getShape()[dim]) {
+      return emitOpError()
+          << "newShape attribute (" << getNewShape()[dim]
+          << ") does not match output type (" << outputType.getShape()[dim]
+          << ") in dimension " << dim;
     }
   }
 }
@@ -792,38 +859,34 @@ mlir::LogicalResult tosa::SliceOp::verify() {
 
   if (inputType.getRank() != outputType.getRank()) {
     return emitOpError() << "rank of input (" << inputType.getRank()
-                         << ") and output ("
-                         << outputType.getRank()
-                         << ") must match";
+                         << ") and output (" << outputType.getRank()
+                         << ") must match";
   }
 
   if ((int64_t)getSize().size() != outputType.getRank()) {
-    return emitOpError() << "rank of size (" << getSize().size()
-                         << ") and output ("
-                         << outputType.getRank()
-                         << ") must match";
+    return emitOpError() << "rank of size (" << getSize().size()
+                         << ") and output (" << outputType.getRank()
+                         << ") must match";
   }
 
-  for (int64_t dim=0; dim < outputType.getRank(); ++dim) {
-    if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) &&
-      getSize()[dim] != outputType.getShape()[dim]) {
+  for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
+    if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) &&
+        getSize()[dim] != outputType.getShape()[dim]) {
       return emitOpError() << "size attribute (" << getSize()[dim]
                            << ") does not match output type ("
                            << outputType.getShape()[dim]
                            << ") in dimension " << dim;
-      }
+    }
   }
 
   if ((int64_t)getStart().size() != inputType.getRank()) {
-    return emitOpError() << "rank of start (" << getStart().size()
-                         << ") and input ("
-                         << inputType.getRank()
-                         << ") must match";
+    return emitOpError() << "rank of start (" << getStart().size()
+                         << ") and input (" << inputType.getRank()
+                         << ") must match";
   }
 
   if ((int64_t)getSize().size() != inputType.getRank()) {
-    return emitOpError() << "rank of size (" << getSize().size()
-                         << ") and input ("
-                         << inputType.getRank()
-                         << ") must match";
+    return emitOpError() << "rank of size (" << getSize().size()
+                         << ") and input (" << inputType.getRank()
+                         << ") must match";
   }
 
   for (int i = 0; i < outputType.getRank(); ++i) {
@@ -1069,6 +1132,7 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
+COMPATIBLE_RETURN_TYPES(tosa::PadOp)
 #undef COMPATIBLE_RETURN_TYPES
 
 static LogicalResult NAryInferReturnTypes(
@@ -1561,6 +1625,17 @@ LogicalResult WhileOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult SigmoidOp::verify() {
+  auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  auto outputType = llvm::cast<ShapedType>(getOutput().getType());
+  auto result = verifyCompatibleShape(inputType, outputType);
+  if (result.failed()) {
+    return emitOpError() << "input type " << inputType << " and output type "
+                         << outputType << " are not compatible";
+  }
+  return success();
+}
+
 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
   if (auto vt = llvm::dyn_cast<VectorType>(getType()))
     return llvm::to_vector<4>(vt.getShape());
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e285a9de1d66d3..72aebae31679c3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -56,6 +56,48 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te
 
 // -----
 
+func.func @test_concat_output_shape_mismatch(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+2 {{failed to infer returned types}}
+  // expected-error@+1 {{inferred type(s) 'tensor<2x3xf32>' are incompatible with return type(s) of operation 'tensor<2x2xf32>'}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_output_rank_mismatch(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?x?xf32> {
+  // expected-error@+2 {{failed to infer returned types}}
+  // expected-error@+1 {{inferred type(s) 'tensor<2x3xf32>' are incompatible with return type(s) of operation 'tensor<?x?x?xf32>'}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_concat_input_rank_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2x2xf32>) -> tensor<*xf32> {
+  // expected-error@+1 {{'tosa.concat' op rank of input 'tensor<2x2x2xf32>' does not match other input rank(s) (2)}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2x2xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<*xf32> {
+  // expected-error@+1 {{'tosa.concat' op axis must be in range 0 to 1}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = -1 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range_3d(%arg0 : tensor<10x11x12xf32>, %arg1 : tensor<10x11x21xf32>) -> tensor<*xf32> {
+  // expected-error@+1 {{'tosa.concat' op axis must be in range 0 to 2}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<10x11x12xf32>, tensor<10x11x21xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
   // expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
   %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
@@ -64,11 +106,83 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x22x4xi8> {
   %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
   // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
-  %1 = "tosa.pad"(%arg0, %0, %arg1) : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8>
-  return %1 : tensor<13x21x3xi8>
+  %1 = "tosa.pad"(%arg0, %0, %arg1) : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x22x4xi8>
+  return %1 : tensor<13x22x4xi8>
 }
+
+// -----
+
+func.func @test_pad_output_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.const"() {value = dense<[[1, 1], [1, 1], [1, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+  // expected-error@+2 {{'tosa.pad' op failed to infer returned types}}
+  // expected-error@+1 {{'tosa.pad' op inferred type(s) 'tensor<15x23x5xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3xf32>'}}
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_pad_type_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xi32> {
+  %0 = "tosa.const"() {value = dense<[[1, 1], [1, 1], [1, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+  // expected-error@+2 {{'tosa.pad' op failed to infer returned types}}
+  // expected-error@+1 {{'tosa.pad' op inferred type(s) 'tensor<15x23x5xf32>' are incompatible with return type(s) of operation 'tensor<15x23x5xi32>'}}
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<15x23x5xi32>
+  return %1 : tensor<15x23x5xi32>
+}
+
+// -----
+
+func.func @test_pad_incorrect_padding_rank(%arg0: tensor<13x21xf32>) -> tensor<13x21xf32> {
+  %0 = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.pad' op paddings must be a tensor of rank 2}}
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32>
+  return %1 : tensor<13x21xf32>
+}
+
+// -----
+
+func.func @test_pad_incorrect_padding_shape(%arg0: tensor<13x21xf32>) -> tensor<13x21xf32> {
+  %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1], [1, 1]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
+  // expected-error@+1 {{'tosa.pad' op paddings must be a tensor of shape [2, 2]}}
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<13x21xf32>, tensor<4x2xi32>) -> tensor<13x21xf32>
+  return %1 : tensor<13x21xf32>
+}
+
+// -----
+
+func.func @test_pad_incorrect_padding_shape(%arg0: tensor<13x21xf32>) -> tensor<13x21xf32> {
+  %0 = "tosa.const"() {value = dense<[[0, 0, 0, 1], [0, 1, 1, 1]]> : tensor<2x4xi32>} : () -> tensor<2x4xi32>
+  // expected-error@+1 {{'tosa.pad' op paddings must be a tensor of shape [2, 2]}}
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<13x21xf32>, tensor<2x4xi32>) -> tensor<13x21xf32>
+  return %1 : tensor<13x21xf32>
+}
+
+// -----
+
+func.func @test_pad_negative_padding(%arg0: tensor<13x21xf32>) -> tensor<?x?xf32> {
+  %0 = "tosa.const"() {value = dense<[[0, 0], [0, -1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  // expected-error@+1 {{'tosa.pad' op padding values must be non-negative}}
+  %1 = "tosa.pad"(%arg0, %0) : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @test_sigmoid_type_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x4xi8> {
+  // expected-error@+1 {{'tosa.sigmoid' op requires the same element type for all operands and results}}
+  %0 = "tosa.sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x4xi8>
+  return %0 : tensor<13x21x4xi8>
+}
+
+// -----
+
+func.func @test_sigmoid_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x4xf32> {
+  // expected-error@+1 {{'tosa.sigmoid' op input type 'tensor<13x21x3xf32>' and output type 'tensor<13x21x4xf32>' are not compatible}}
+  %0 = "tosa.sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x4xf32>
+  return %0 : tensor<13x21x4xf32>
}
 
 // -----
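
For illustration only (not part of the patch): a minimal sketch of IR that all three new verifiers accept, assuming the generic op syntax used in the tests above; the function name is hypothetical.

// Ranks match and axis 1 lies within [0, rank - 1], so ConcatOp::verify passes.
// The paddings operand is a rank-2, [2, 2], non-negative constant, so
// PadOp::verify passes; sigmoid's input and output shapes are identical.
func.func @concat_pad_sigmoid_ok(%arg0: tensor<2x3xf32>, %arg1: tensor<2x5xf32>) -> tensor<4x10xf32> {
  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<2x3xf32>, tensor<2x5xf32>) -> tensor<2x8xf32>
  %1 = "tosa.const"() {value = dense<[[1, 1], [1, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  %2 = "tosa.pad"(%0, %1) : (tensor<2x8xf32>, tensor<2x2xi32>) -> tensor<4x10xf32>
  %3 = "tosa.sigmoid"(%2) : (tensor<4x10xf32>) -> tensor<4x10xf32>
  return %3 : tensor<4x10xf32>
}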