From 37263b6c6741894ffbc0f61979c5c85db515ef2d Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Wed, 4 Sep 2024 10:24:17 +0800 Subject: [PATCH] [mlir][tosa] Add verifier for `tosa.pad` (#106351) This patch adds verifier to `tosa.pad` which fixes a crash. `tosa.pad` expect: - same input and output tensor rank. - 'padding' tensor rank equal to 2. Fix #106168. --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 14 +++++++++++ mlir/test/Dialect/Tosa/invalid.mlir | 25 ++++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0be0f8ef2d7a0c5..1a132e73be86458 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1594,6 +1594,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> { let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 267a875710ed71d..d93db1b237f3164 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -817,6 +817,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( return success(); } +LogicalResult tosa::PadOp::verify() { + RankedTensorType inputType = getInput1().getType(); + RankedTensorType outputType = getOutput().getType(); + TensorType paddingType = getPadding().getType(); + + if (inputType.getRank() != outputType.getRank()) + return emitOpError() << "expect same input and output tensor rank."; + + if (paddingType.hasRank() && paddingType.getRank() != 2) + return emitOpError() << "expect 'padding' tensor rank equal to 2."; + + return success(); +} + static SmallVector convertToMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { return dim == -1 ? ShapedType::kDynamic : dim; diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e72e154f952771c..418f7687b3cce86 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -72,6 +72,31 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor) -> t // ----- +func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) { + // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}} + %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32> + return +} + +// ----- + +func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) { + // expected-error@+1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}} + %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32> + return +} + +// ----- + +func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) { + %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}} + %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<2x2xi32>, tensor<1xf32>) -> tensor<13x21xf32> + return +} + +// ----- + func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> { // expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}} %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>