Skip to content

Commit

Permalink
Materialize batch_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 19, 2023
1 parent cccd6ee commit 66857bb
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,43 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
return mmt4DOp;
}

/// Lowers an encoded `linalg.batch_matmul` to a `linalg.batch_mmt4d` op.
///
/// The rewrite applies only when the op has tensor semantics and its operands
/// carry encodings with the expected roles:
///   - first input:  role=LHS
///   - second input: role=RHS
///   - first init:   role=RESULT
/// Any other configuration yields `failure()`. On success the new op is
/// created at the original op's location from `convertedInputOperands` and
/// `convertedOutputOperands`, using the type of the first converted output as
/// its result type. The two trailing (unnamed) callback parameters are unused.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
                    ValueRange convertedInputOperands,
                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
                    MaterializeEncodingValueFn) {
  using EncodingRole = mlir::iree_compiler::IREE::LinalgExt::EncodingRole;

  if (!batchMatmulOp.hasTensorSemantics()) {
    return failure();
  }

  // Reads the encoding attribute off the ranked tensor type of an operand.
  auto encodingOf = [](OpOperand *operand) {
    return getEncodingAttr(
        operand->get().getType().cast<RankedTensorType>());
  };

  auto dpsInputs = batchMatmulOp.getDpsInputOperands();
  auto dpsInits = batchMatmulOp.getDpsInitOperands();
  auto lhsEncoding = encodingOf(dpsInputs[0]);
  auto rhsEncoding = encodingOf(dpsInputs[1]);
  auto resultEncoding = encodingOf(dpsInits[0]);

  // All three operands must be encoded before their roles can be inspected.
  const bool allEncoded = lhsEncoding && rhsEncoding && resultEncoding;
  if (!allEncoded) {
    return failure();
  }

  const bool rolesMatch =
      lhsEncoding.getRole().getValue() == EncodingRole::LHS &&
      rhsEncoding.getRole().getValue() == EncodingRole::RHS &&
      resultEncoding.getRole().getValue() == EncodingRole::RESULT;
  if (!rolesMatch) {
    return failure();
  }

  Type resultType = convertedOutputOperands[0].getType();
  Operation *loweredOp = rewriter.create<linalg::BatchMmt4DOp>(
      batchMatmulOp.getLoc(), resultType, convertedInputOperands,
      convertedOutputOperands);
  return loweredOp;
}

/// Utility method to convert a `linalg.fill` on a `tensor` type with an
/// encoding to a `linalg.fill` of the materialized type.
static FailureOr<Operation *>
Expand Down Expand Up @@ -518,6 +555,7 @@ void populateMaterializeEncodingPatterns(
// Add all patterns for converting from encoded type to the materialized type
patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::MatmulOp>,
MaterializeDPSOperation<linalg::BatchMatmulOp>,
MaterializeOperation<tensor::EmptyOp>,
SetEncodingOpToPackOpConversion,
UnsetEncodingOpToPackOpConversion>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,101 @@ func.func @pack_unpack_batch_matmul_result(%arg0 : tensor<?x?x?xf32>) -> tensor<
// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]]
// CHECK: return %[[UNPACK]]

// -----

// Static-shape test input: a linalg.batch_matmul whose two inputs and its init
// operand are each wrapped in set_encoding (roles LHS, RHS and RESULT, user
// BATCH_MATMUL_F32F32F32), with the result passed through unset_encoding
// before being returned. The FileCheck directives that follow this function
// expect three tensor.pack ops feeding a linalg.batch_mmt4d, then a
// tensor.unpack.
func.func @pack_batch_matmul(%arg0 : tensor<128x80x32xf32>, %arg1 : tensor<128x32x320xf32>, %arg2 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> {
// Attach the LHS / RHS / RESULT encodings to the three plain tensors.
%0 = iree_linalg_ext.set_encoding %arg0 : tensor<128x80x32xf32> -> tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
%1 = iree_linalg_ext.set_encoding %arg1 : tensor<128x32x320xf32> -> tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
%2 = iree_linalg_ext.set_encoding %arg2 : tensor<128x80x320xf32> -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
// The op under test: every operand and the result carry an encoding.
%3 = linalg.batch_matmul ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
// Drop the encoding again so the function returns a plain tensor.
%4 = iree_linalg_ext.unset_encoding %3 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<128x80x320xf32>
return %4 : tensor<128x80x320xf32>
}
// CHECK: func @pack_batch_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x80x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<128x32x320xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<128x80x320xf32>
// CHECK: %[[PACK_LHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]]
// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]]
// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
// CHECK-SAME: %[[ARG2]]
// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[PACK_RESULT]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]

// -----

// Dynamic-shape variant of the batch_matmul test input: all three operands are
// tensor<?x?x?xf32>, each given an encoding (roles LHS, RHS and RESULT) via
// set_encoding, and the batch_matmul result is passed through unset_encoding
// before being returned.
func.func @pack_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// Attach the LHS / RHS / RESULT encodings to the three plain tensors.
%0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
%1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
%2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
// The op under test: every operand and the result carry an encoding.
%3 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
// Drop the encoding again so the function returns a plain tensor.
%4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
// CHECK: func @pack_batch_matmul_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK: %[[PACK_LHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]]
// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]]
// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
// CHECK-SAME: %[[ARG2]]
// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[PACK_RESULT]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]

// -----

// Dynamic-shape test input where the accumulator is not a function argument:
// an encoded tensor.empty is zero-filled with linalg.fill and used as the init
// operand of an encoded linalg.batch_matmul. Only the two inputs go through
// set_encoding; the result is passed through unset_encoding before being
// returned.
func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%cst = arith.constant 0.0 : f32
// Dims 0 and 1 of %arg0 and dim 2 of %arg1 size the result tensor below.
%d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%d2 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
%1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
// The accumulator is created directly with the RESULT encoding and filled
// with %cst (0.0).
%2 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>)
-> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
// The op under test, accumulating into the filled tensor.
%4 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
outs(%3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
// Drop the encoding again so the function returns a plain tensor.
%5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
return %5 : tensor<?x?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
// CHECK: func @pack_batch_matmul_fill_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
// CHECK-DAG: %[[OUT_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack %[[ARG0]]
// CHECK-DAG: %[[PACK_RHS:.+]] = tensor.pack %[[ARG1]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[OUT_D1]], %[[OUT_D2]]) : tensor<?x?x?x8x8xf32>
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x?x8x8xf32>)
// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]

0 comments on commit 66857bb

Please sign in to comment.