diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
index 47e2571e2bb13d..b5fb0bb0f75ea7 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
@@ -1071,3 +1071,67 @@ func.func @mirrorpad_reflect(%arg0: tensor<13x21x3xf32>) -> tensor<14x22x4xf32>
   %1 = "tf.Identity"(%0) {device = ""} : (tensor<14x22x4xf32>) -> tensor<14x22x4xf32>
   return %0 : tensor<14x22x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_f32
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<3x3x13x7xf32>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xf32>)
+// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf32>
+func.func @test_broadcast_to_f32(%arg0: tensor<13x1xf32>) -> (tensor<3x3x13x7xf32>) {
+  %shape = "tf.Const"() {value = dense<[3, 3, 1, 7]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xf32>, tensor<4xi32>) -> tensor<3x3x13x7xf32>
+  return %1 : tensor<3x3x13x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_i32
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32>
+// CHECK: return %[[VAL_2]] : tensor<7x7x13x3xi32>
+func.func @test_broadcast_to_i32(%arg0: tensor<13x1xi32>) -> (tensor<3x3x13x3xi32>) {
+  %shape = "tf.Const"() {value = dense<[7, 7, 13, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xi32>, tensor<4xi32>) -> tensor<3x3x13x3xi32>
+  return %1 : tensor<3x3x13x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_i1
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<false> : tensor<7x7x13x7xi1>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xi1>
+// CHECK: %[[VAL_2:.*]] = "tosa.logical_or"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi1>, tensor<7x7x13x7xi1>) -> tensor<7x7x13x7xi1>
+// CHECK: return %[[VAL_2]] : tensor<7x7x13x7xi1>
+func.func @test_broadcast_to_i1(%arg0: tensor<13x1xi1>) -> (tensor<7x7x13x7xi1>) {
+  %shape = "tf.Const"() {value = dense<[7, 7, 13, 7]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xi1>, tensor<4xi32>) -> tensor<7x7x13x7xi1>
+  return %1 : tensor<7x7x13x7xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_i16
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi16>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}
+// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi16>, tensor<7x7x13x3xi16>) -> tensor<7x7x13x3xi16>
+// CHECK: return %[[VAL_3]] : tensor<7x7x13x3xi16>
+func.func @test_broadcast_to_i16(%arg0: tensor<13x1xi16>) -> (tensor<7x7x13x3xi16>) {
+  %shape = "tf.Const"() {value = dense<[7, 7, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xi16>, tensor<4xi32>) -> tensor<7x7x13x3xi16>
+  return %1 : tensor<7x7x13x3xi16>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_smaller_rank
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[13, 7]> : tensor<2xi32>}
+// CHECK: %[[VAL_1:.*]] = "tf.BroadcastTo"(%arg0, %[[VAL_0]]) : (tensor<2x3x13x1xi32>, tensor<2xi32>) -> tensor<13x7xi32>
+// CHECK: return %[[VAL_1]] : tensor<13x7xi32>
+func.func @test_broadcast_to_smaller_rank(%arg0: tensor<2x3x13x1xi32>) -> (tensor<13x7xi32>) {
+  %s = "tf.Const"() {value = dense<[13, 7]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %1 = "tf.BroadcastTo"(%arg0, %s) : (tensor<2x3x13x1xi32>, tensor<2xi32>) -> tensor<13x7xi32>
+  return %1 : tensor<13x7xi32>
+}
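Note on the strategy these tests pin down: instead of materializing a tile, the legalization reshapes the input up to the target rank and then applies an identity elementwise op (add-with-zero, or logical_or-with-false for i1) whose second operand is a zero constant of the full broadcast shape, so TOSA's implicit broadcasting performs the expansion. A minimal end-to-end sketch of the f32 case, transcribed from the CHECK lines above (the function name is hypothetical):

func.func @broadcast_to_sketch(%arg0: tensor<13x1xf32>) -> tensor<3x3x13x7xf32> {
  // Zero constant carrying the full broadcast shape.
  %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<3x3x13x7xf32>}> : () -> tensor<3x3x13x7xf32>
  // Rank-equalize the 13x1 input to rank 4 (1x1x13x1).
  %1 = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xf32>) -> tensor<1x1x13x1xf32>
  // Adding zero broadcasts 1x1x13x1 against 3x3x13x7.
  %2 = "tosa.add"(%1, %0) : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32>
  return %2 : tensor<3x3x13x7xf32>
}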
diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index 68e68345ce7436..6105416b325063 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -2804,3 +2804,94 @@ func.func @test_squared_difference_f32(%arg0: tensor<1x197x768xf32>, %arg1: tensor<1x197x1xf32>) -> tensor<1x197x768xf32> {
   %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x197x768xf32>, tensor<1x197x1xf32>) -> tensor<1x197x768xf32>
   func.return %0 : tensor<1x197x768xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_f32
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<3x3x13x7xf32>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xf32>)
+// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf32>
+func.func @test_broadcast_to_f32(%arg0: tensor<13x1xf32>) -> (tensor<3x3x13x7xf32>) {
+  %shape = arith.constant dense<[3, 3, 1, 7]> : tensor<4xi32>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xf32>, tensor<4xi32>) -> tensor<3x3x13x7xf32>
+  return %1 : tensor<3x3x13x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_f16
+// CHECK: %[[VAL_0:.*]] = "tosa.const"()
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xf16>)
+// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xf16>, tensor<3x3x13x7xf16>) -> tensor<3x3x13x7xf16>
+// CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf16>
+func.func @test_broadcast_to_f16(%arg0: tensor<13x1xf16>) -> (tensor<3x3x13x7xf16>) {
+  %shape = arith.constant dense<[3, 3, 1, 7]> : tensor<4xi32>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xf16>, tensor<4xi32>) -> tensor<3x3x13x7xf16>
+  return %1 : tensor<3x3x13x7xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_i32
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32>
+// CHECK: return %[[VAL_2]] : tensor<7x7x13x3xi32>
+func.func @test_broadcast_to_i32(%arg0: tensor<13x1xi32>) -> (tensor<3x3x13x3xi32>) {
+  %shape = arith.constant dense<[7, 7, 13, 3]> : tensor<4xi64>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xi32>, tensor<4xi64>) -> tensor<3x3x13x3xi32>
+  return %1 : tensor<3x3x13x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_i1
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<false> : tensor<7x7x13x7xi1>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}> : (tensor<13x1xi1>
+// CHECK: %[[VAL_2:.*]] = "tosa.logical_or"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi1>, tensor<7x7x13x7xi1>) -> tensor<7x7x13x7xi1>
+// CHECK: return %[[VAL_2]] : tensor<7x7x13x7xi1>
+func.func @test_broadcast_to_i1(%arg0: tensor<13x1xi1>) -> (tensor<7x7x13x7xi1>) {
+  %shape = arith.constant dense<[7, 7, 13, 7]> : tensor<4xi64>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xi1>, tensor<4xi64>) -> tensor<7x7x13x7xi1>
+  return %1 : tensor<7x7x13x7xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_qi8
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>}
+// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 1, 13, 1>}
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x1x13x1x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<1x1x13x1xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32>
+// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<7x7x13x3xi32>) -> tensor<7x7x13x3x!quant.uniform<i8:f32, 1.000000e-01>>
+// CHECK: return %[[VAL_4]] : tensor<7x7x13x3x!quant.uniform<i8:f32, 1.000000e-01>>
+func.func @test_broadcast_to_qi8(%arg0: tensor<13x1x!quant.uniform<i8:f32, 1.000000e-01>>) -> (tensor<7x7x13x3x!quant.uniform<i8:f32, 1.000000e-01>>) {
+  %shape = arith.constant dense<[7, 7, 1, 3]> : tensor<4xi64>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1x!quant.uniform<i8:f32, 1.000000e-01>>, tensor<4xi64>) -> tensor<7x7x13x3x!quant.uniform<i8:f32, 1.000000e-01>>
+  return %1 : tensor<7x7x13x3x!quant.uniform<i8:f32, 1.000000e-01>>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_smaller_rank
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[13, 7]> : tensor<2xi48>}
+// CHECK: %[[VAL_1:.*]] = "tfl.broadcast_to"(%arg0, %[[VAL_0]]) : (tensor<2x3x13x1xi32>, tensor<2xi48>) -> tensor<13x7xi32>
+// CHECK: return %[[VAL_1]] : tensor<13x7xi32>
+func.func @test_broadcast_to_smaller_rank(%arg0: tensor<2x3x13x1xi32>) -> (tensor<13x7xi32>) {
+  %shape = arith.constant dense<[13, 7]> : tensor<2xi64>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<2x3x13x1xi32>, tensor<2xi64>) -> tensor<13x7xi32>
+  return %1 : tensor<13x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_broadcast_to_i48
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[7, 7, 1, 7]> : tensor<4xi48>}
+// CHECK: %[[VAL_1:.*]] = "tfl.broadcast_to"(%arg0, %[[VAL_0]]) : (tensor<1x1x13x1xi48>, tensor<4xi48>) -> tensor<7x7x13x7xi48>
+// CHECK: return %[[VAL_1]] : tensor<7x7x13x7xi48>
+func.func @test_broadcast_to_i48(%arg0: tensor<1x1x13x1xi48>) -> (tensor<7x7x13x7xi48>) {
+  %shape = arith.constant dense<[7, 7, 1, 7]> : tensor<4xi64>
+  %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<1x1x13x1xi48>, tensor<4xi64>) -> tensor<7x7x13x7xi48>
+  return %1 : tensor<7x7x13x7xi48>
+}
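Worth spelling out the shape arithmetic the qi8 test above relies on: the requested shape [7, 7, 1, 3] contains a 1 where the input contributes 13, and the lowering resolves each dimension to the maximum of the two, provided the smaller of any differing pair is 1 (the NumPy broadcast rule applied symmetrically). A worked trace for that test, values taken from the test itself:

// input:              tensor<13x1x...>
// rank-equalized to:   1 x  1 x 13 x 1
// requested shape:     7 x  7 x  1 x 3
// per-dim max:        [max(1,7), max(1,7), max(13,1), max(1,3)]
// result shape:        7 x  7 x 13 x 3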
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
index d8785910105ea2..848fd4bd40de0a 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
@@ -39,6 +39,7 @@ limitations under the License.
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
@@ -4610,5 +4611,158 @@ std::optional<Value> convertSignOp(PatternRewriter& rewriter, Operation* op,
       .getResult();
 }
+
+// Lowers BroadcastTo operator to a sequence of TOSA ops.
+std::optional<Value> convertBroadcastToOp(PatternRewriter& rewriter,
+                                          Operation* op, Value input,
+                                          Value shape) {
+  RankedTensorType input_type = dyn_cast<RankedTensorType>(input.getType());
+  if (!input_type) {
+    (void)rewriter.notifyMatchFailure(op, "input type not ranked tensor");
+    return std::nullopt;
+  }
+
+  Type element_type = input_type.getElementType();
+  if (element_type.isa<ComplexType>()) {
+    (void)rewriter.notifyMatchFailure(op, "input element type is complex");
+    return std::nullopt;
+  }
+
+  if (isa<IntegerType>(element_type)) {
+    auto bitwidth = element_type.getIntOrFloatBitWidth();
+    if (bitwidth > 32) {
+      (void)rewriter.notifyMatchFailure(
+          op, "input element type has greater than 32 bits");
+      return std::nullopt;
+    }
+  }
+
+  ElementsAttr shape_elems;
+  if (!matchPattern(shape, m_Constant(&shape_elems))) {
+    (void)rewriter.notifyMatchFailure(op, "shape is not constant");
+    return std::nullopt;
+  }
+  int input_rank = input_type.getRank();
+  int shape_rank = shape_elems.getNumElements();
+
+  if (auto shape_type = dyn_cast<RankedTensorType>(shape.getType())) {
+    if (shape_type.hasStaticShape()) {
+      assert(shape_type.getRank() == 1);
+      if (!shape_type.isDynamicDim(0) &&
+          shape_rank != shape_type.getDimSize(0)) {
+        // shape_elems and the static dimension of shape's type disagree;
+        // this is not supported for now
+        (void)rewriter.notifyMatchFailure(
+            op,
+            "shape's constant value has different elements than its static "
+            "dimension");
+        return std::nullopt;
+      }
+    }
+  }
+
+  if (input_rank > shape_rank) {
+    // not clear what to do in this case, bail for now
+    (void)rewriter.notifyMatchFailure(op, "shape has less rank than input");
+    return std::nullopt;
+  }
+
+  // equalize shape_rank and input_rank
+  if (input_rank < shape_rank) {
+    // reshape input up to shape_rank by prepending size-1 dimensions
+    SmallVector<int64_t> reshaped_shape((shape_rank - input_rank), 1);
+    for (auto dim : input_type.getShape()) {
+      reshaped_shape.push_back(dim);
+    }
+    input_type =
+        tensorflow::GetTypeFromTFTensorShape(reshaped_shape, element_type);
+    input = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, op->getLoc(), input_type, input,
+        rewriter.getDenseI64ArrayAttr(
+            tensorflow::ConvertMlirShapeToTF(reshaped_shape)));
+  }
+
+  auto input_shape = input_type.getShape();
+  assert(input_shape.size() == shape_rank);  // should be equal ranks by now
+
+  // construct new_shape as the broadcast of input_shape and shape_elems
+  int32_t num_elements = 1;
+  SmallVector<int64_t> new_shape;
+  for (int i = 0; i < shape_rank; i++) {
+    auto shape_dim = shape_elems.getValues<IntegerAttr>()[i].getInt();
+    auto input_dim = input_shape[i];
+    if (shape_dim != input_dim && std::min(shape_dim, input_dim) != 1) {
+      // shape_dim and input_dim differ, but the lower value is not 1;
+      // this is not broadcastable
+      (void)rewriter.notifyMatchFailure(
+          op, "input and shape are not broadcastable");
+      return std::nullopt;
+    }
+    auto dim = std::max(shape_dim, input_dim);
+    new_shape.push_back(dim);
+    num_elements *= dim;
+  }
+
+  RankedTensorType output_type =
+      tensorflow::GetTypeFromTFTensorShape(new_shape, element_type);
+
+  if (element_type.isa<FloatType>()) {
+    // Float: legalize to broadcastable Add with (0.f)
+    auto const_attr = DenseElementsAttr::get(
+        output_type, rewriter.getZeroAttr(element_type));
+    Value f32_const_zero =
+        rewriter.create<tosa::ConstOp>(op->getLoc(), output_type, const_attr);
+    return CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
+                                         input, f32_const_zero)
+        .getResult();
+  }
+
+  if (element_type.isInteger(1)) {
+    // I1: legalize to broadcastable LogicalOr with false
+    auto const_attr = DenseElementsAttr::get(
+        output_type, rewriter.getZeroAttr(element_type));
+    Value i1_const_zero =
+        rewriter.create<tosa::ConstOp>(op->getLoc(), output_type, const_attr);
+    return CreateOpAndInfer<tosa::LogicalOrOp>(
+               rewriter, op->getLoc(), output_type, input, i1_const_zero)
+        .getResult();
+  }
+
+  if (isa<IntegerType>(element_type)) {
+    // Integer (<= 32 bits): legalize to broadcastable Add with 0
+    RankedTensorType cast_shaped_type = output_type.clone(element_type);
+    auto const_attr = DenseElementsAttr::get(
+        cast_shaped_type, rewriter.getZeroAttr(element_type));
+    Value const_zero = rewriter.create<tosa::ConstOp>(
+        op->getLoc(), cast_shaped_type, const_attr);
+    return CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
+                                         input, const_zero)
+        .getResult();
+  }
+
+  if (auto quant_ty = dyn_cast<quant::UniformQuantizedType>(element_type)) {
+    auto cast_type = rewriter.getI32Type();
+    RankedTensorType cast_shaped_type = output_type.clone(cast_type);
+    auto const_attr = DenseElementsAttr::get(cast_shaped_type,
+                                             rewriter.getZeroAttr(cast_type));
+    Value const_zero = rewriter.create<tosa::ConstOp>(
+        op->getLoc(), cast_shaped_type, const_attr);
+
+    // Quantized element types: cast the input to i32, perform a
+    // broadcastable Add with 0, then cast back to the quantized type.
+    Value input_cast = CreateOpAndInfer<tosa::CastOp>(
+        rewriter, op->getLoc(),
+        /* I32 input type */ input_type.clone(cast_type), input);
+    Value add_const = CreateOpAndInfer<tosa::AddOp>(
+        rewriter, op->getLoc(), output_type.clone(cast_type), input_cast,
+        const_zero);
+    return CreateOpAndInfer<tosa::CastOp>(rewriter, op->getLoc(), output_type,
+                                          add_const)
+        .getResult();
+  }
+
+  (void)rewriter.notifyMatchFailure(op, "Unsupported element type");
+  return std::nullopt;
+}
+
 };  // namespace tosa
 };  // namespace mlir
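Note on the failure paths above: every bail-out returns std::nullopt, so the calling matchAndRewrite (below) reports failure and the original op simply survives legalization; nothing is rewritten in place. That is exactly what the smaller-rank and i48 tests check. A sketch transcribed from the i48 test, assuming the tfl-to-tosa pipeline (only the shape constant gets legalized):

// Before: i48 has a bitwidth > 32, so convertBroadcastToOp refuses it.
%shape = arith.constant dense<[7, 7, 1, 7]> : tensor<4xi64>
%0 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<1x1x13x1xi48>, tensor<4xi64>) -> tensor<7x7x13x7xi48>

// After: the constant becomes a tosa.const (narrowed to i48), but
// tfl.broadcast_to itself remains for a later pass to handle.
%shape_tosa = "tosa.const"() <{value = dense<[7, 7, 1, 7]> : tensor<4xi48>}> : () -> tensor<4xi48>
%1 = "tfl.broadcast_to"(%arg0, %shape_tosa) : (tensor<1x1x13x1xi48>, tensor<4xi48>) -> tensor<7x7x13x7xi48>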
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
index 3dc87952753583..52073400b079f2 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h
@@ -306,6 +306,11 @@ std::optional<Value> convertSinOp(PatternRewriter& rewriter, Operation* op,
 std::optional<Value> convertSignOp(PatternRewriter& rewriter, Operation* op,
                                    Value input, RankedTensorType output_type);
 
+// Lowers BroadcastTo operator to a sequence of TOSA ops.
+std::optional<Value> convertBroadcastToOp(PatternRewriter& rewriter,
+                                          Operation* op, Value input,
+                                          Value shape);
+
 };  // namespace tosa
 };  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 8b20350e5f5bef..313f06961140ce 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -151,6 +151,7 @@ DECL_CONVERT_OP(LeftShift);
 DECL_CONVERT_OP(RightShift);
 DECL_CONVERT_OP(OneHot);
 DECL_CONVERT_OP(BatchMatMulV2);
+DECL_CONVERT_OP(BroadcastTo);
 #undef DECL_CONVERT_OP
 
 LogicalResult ConvertTFReluOp::matchAndRewrite(
@@ -2421,6 +2422,21 @@ LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
+
+  std::optional<Value> result =
+      convertBroadcastToOp(rewriter, op, tf_broadcast_to_op.getInput(),
+                           tf_broadcast_to_op.getShape());
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.value()});
+
+  return success();
+}
+
 void LegalizeTF::runOnOperation() {
   auto* ctx = &getContext();
   RewritePatternSet patterns(ctx);
@@ -2523,6 +2539,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) {
   patterns.add<ConvertTFRightShiftOp>(ctx);
   patterns.add<ConvertTFOneHotOp>(ctx);
   patterns.add<ConvertTFBatchMatMulV2Op>(ctx);
+  patterns.add<ConvertTFBroadcastToOp>(ctx);
 }
 
 // Creates an instance of the TensorFlow dialect LegalizeTF pass.
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index 8863a3d0df4031..bd53b5a3acd974 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -193,6 +193,7 @@ DECL_CONVERT_OP(While);
 DECL_CONVERT_OP(Real);
 DECL_CONVERT_OP(Imag);
 DECL_CONVERT_OP(RFFT2d);
+DECL_CONVERT_OP(BroadcastTo);
 
 #undef DECL_CONVERT_OP
 
@@ -4478,6 +4479,21 @@ LogicalResult ConvertTFLRFFT2dOp::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertTFLBroadcastToOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tfl_broadcast_to_op = cast<TFL::BroadcastToOp>(op);
+
+  std::optional<Value> result =
+      convertBroadcastToOp(rewriter, op, tfl_broadcast_to_op.getInput(),
+                           tfl_broadcast_to_op.getShape());
+
+  if (!result) return failure();
+
+  rewriter.replaceOp(op, {result.value()});
+
+  return success();
+}
+
 LogicalResult LegalizeTFL::initialize(MLIRContext* context) {
   RewritePatternSet patterns(context);
   mlir::tosa::populateLegalizeTFLPatterns(context, patterns);
@@ -4615,6 +4631,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx,
   DEF_PATTERN_INSERT(TFLReal);
   DEF_PATTERN_INSERT(TFLImag);
   DEF_PATTERN_INSERT(TFLRFFT2d);
+  DEF_PATTERN_INSERT(TFLBroadcastTo);
 }
 
 // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
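To exercise the new patterns, the two test files run under the RUN lines already at the top of each file; the invocations below are a sketch assuming the usual drivers for this directory (flag spellings not re-verified here):

// RUN: tf-opt --tf-to-tosa-pipeline tf-to-tosa-pipeline.mlir | FileCheck tf-to-tosa-pipeline.mlir
// RUN: tf-opt --tfl-to-tosa-pipeline tfl-to-tosa-pipeline.mlir | FileCheck tfl-to-tosa-pipeline.mlir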