Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 10, 2023
1 parent 08a90e6 commit 94ddb87
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ struct ConvertBatchMmt4DtoMmt4DPattern

struct DecomposeBatchMmt4DOpsPass
: public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, func::FuncDialect, arith::ArithDialect,
scf::SCFDialect, tensor::TensorDialect>();
}

void runOnOperation() override;
};

Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ iree_lit_test_suite(
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
"decompose_affine_ops.mlir",
"decompose_batch_mmt4d_ops.mlir",
"decompose_linalg_generic.mlir",
"decompose_pack_unpack_ops.mlir",
"eliminate_empty_tensors.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"convert_to_destination_passing_style.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
"decompose_batch_mmt4d_ops.mlir"
"decompose_linalg_generic.mlir"
"decompose_pack_unpack_ops.mlir"
"eliminate_empty_tensors.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// RUN: iree-opt --iree-codegen-decompose-batch-mmt4d-ops --split-input-file %s | FileCheck %s

func.func @batch_mmt4d_with_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%0 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
return %1 : tensor<128x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32>
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]])
// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32>
// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32>

// -----

func.func @batch_mmt4d_with_no_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
return %1 : tensor<128x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_no_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32>
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]])
// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32>
// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32>

// -----

func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> {
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
return %1 : tensor<1x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_unit_batch
// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32>
// CHECK-DAG: %[[EXT_OUT:.+]] = tensor.extract_slice %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<1x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<1x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<1x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32>
// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32>

0 comments on commit 94ddb87

Please sign in to comment.