From 18c517325bdd9dc6f4d2c4b6e375820173e2ed14 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 14:01:58 -0700 Subject: [PATCH] Set default attribute values in `Convert1DConvOp` if not set PiperOrigin-RevId: 550665807 --- .../mlir/tensorflow/tests/legalize_hlo.mlir | 29 +++++++++++++++++++ .../tensorflow/transforms/legalize_hlo.cc | 18 +++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 979445dc3a8cce..7c5a8261d5fd12 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1719,6 +1719,35 @@ func.func @convert_conv1d(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256 func.return %0 : tensor<16x32x256xbf16> } +// CHECK-LABEL: func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16> +// CHECK-DAG: %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16> +// CHECK: %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16> +// CHECK: %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK: %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64> +// CHECK: %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16> +// CHECK: return %[[VAL_14]] : tensor<16x32x256xbf16> +// CHECK: } +func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 1 : i64, + padding = dense<0> : tensor<1x2xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> + func.return %0 : tensor<16x32x256xbf16> +} + // CHECK-LABEL: func.func @convert_conv1d_non_canonical_dimension_numbers( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x256xbf16>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 9a15a022947184..9cee8917a8b60f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -311,6 +311,12 @@ class Convert1DConvOp : public OpConversionPattern, RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d_array); // LHS dilation + // Set LHS dilation defaults if not set (1 for each input spatial dimension) + if (!conv_op.getLhsDilation().has_value()) { + conv_op.setLhsDilationAttr(rewriter.getI64TensorAttr( + std::vector(dnums.getInputSpatialDimensions().size(), 1))); + } + SmallVector lhs_dilation_array_2d; for (const auto v : conv_op.getLhsDilation().value().getValues()) { lhs_dilation_array_2d.emplace_back(v); @@ -321,6 +327,13 @@ class Convert1DConvOp : public OpConversionPattern, lhs_dilation_array_2d); // RHS dilation + // Set RHS dilation defaults if not set (1 for each kernel spatial + // dimension) + if (!conv_op.getRhsDilation().has_value()) { + conv_op.setRhsDilationAttr(rewriter.getI64TensorAttr( + std::vector(dnums.getKernelSpatialDimensions().size(), 1))); + } + SmallVector rhs_dilation_array_2d; for (const auto v : conv_op.getRhsDilation().value().getValues()) { rhs_dilation_array_2d.emplace_back(v); @@ -338,9 +351,6 @@ class Convert1DConvOp : public OpConversionPattern, RankedTensorType::get({2}, rewriter.getI64Type()), SmallVector({0, 0})); - // Precision config - if (!conv_op.getPrecisionConfig().has_value()) return failure(); - // Dimension numbers reflect the form of the 2d conv op NWHC * WHIO -> NWHC auto dnums_2d = mhlo::ConvDimensionNumbersAttr::get(rewriter.getContext(), @@ -380,7 +390,7 @@ class Convert1DConvOp : public OpConversionPattern, transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(), window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d, window_reversal_2d, dnums_2d, conv_op.getFeatureGroupCount(), - conv_op.getBatchGroupCount(), *conv_op.getPrecisionConfig()); + conv_op.getBatchGroupCount(), conv_op.getPrecisionConfigAttr()); OpResult conv2d_output = conv2d_op->getResult(0); auto conv2d_output_type = conv2d_output.getType().cast();