Skip to content

Commit

Permalink
Set default attribute values in Convert1DConvOp if not set
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550665807
  • Loading branch information
tensorflower-gardener committed Jul 24, 2023
1 parent c58073a commit 18c5173
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
29 changes: 29 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,35 @@ func.func @convert_conv1d(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256
func.return %0 : tensor<16x32x256xbf16>
}

// CHECK-LABEL: func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
// CHECK-DAG: %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK: %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
// CHECK: %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK: %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
// CHECK: %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
// CHECK: return %[[VAL_14]] : tensor<16x32x256xbf16>
// CHECK: }
func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
feature_group_count = 1 : i64,
padding = dense<0> : tensor<1x2xi64>,
window_strides = dense<1> : tensor<1xi64>
} : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16>
func.return %0 : tensor<16x32x256xbf16>
}

// CHECK-LABEL: func.func @convert_conv1d_non_canonical_dimension_numbers(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x256xbf16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> {
Expand Down
18 changes: 14 additions & 4 deletions tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d_array);

// LHS dilation
// Set LHS dilation defaults if not set (1 for each input spatial dimension)
if (!conv_op.getLhsDilation().has_value()) {
conv_op.setLhsDilationAttr(rewriter.getI64TensorAttr(
std::vector<int64_t>(dnums.getInputSpatialDimensions().size(), 1)));
}

SmallVector<int64_t, 4> lhs_dilation_array_2d;
for (const auto v : conv_op.getLhsDilation().value().getValues<int64_t>()) {
lhs_dilation_array_2d.emplace_back(v);
Expand All @@ -321,6 +327,13 @@ class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
lhs_dilation_array_2d);

// RHS dilation
// Set RHS dilation defaults if not set (1 for each kernel spatial
// dimension)
if (!conv_op.getRhsDilation().has_value()) {
conv_op.setRhsDilationAttr(rewriter.getI64TensorAttr(
std::vector<int64_t>(dnums.getKernelSpatialDimensions().size(), 1)));
}

SmallVector<int64_t, 4> rhs_dilation_array_2d;
for (const auto v : conv_op.getRhsDilation().value().getValues<int64_t>()) {
rhs_dilation_array_2d.emplace_back(v);
Expand All @@ -338,9 +351,6 @@ class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
RankedTensorType::get({2}, rewriter.getI64Type()),
SmallVector<int64_t>({0, 0}));

// Precision config
if (!conv_op.getPrecisionConfig().has_value()) return failure();

// Dimension numbers reflect the form of the 2d conv op NWHC * WHIO -> NWHC
auto dnums_2d =
mhlo::ConvDimensionNumbersAttr::get(rewriter.getContext(),
Expand Down Expand Up @@ -380,7 +390,7 @@ class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(),
window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d,
window_reversal_2d, dnums_2d, conv_op.getFeatureGroupCount(),
conv_op.getBatchGroupCount(), *conv_op.getPrecisionConfig());
conv_op.getBatchGroupCount(), conv_op.getPrecisionConfigAttr());

OpResult conv2d_output = conv2d_op->getResult(0);
auto conv2d_output_type = conv2d_output.getType().cast<ShapedType>();
Expand Down

0 comments on commit 18c5173

Please sign in to comment.