diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h index 700dce50e025..889e0a7ce5fc 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h @@ -15,6 +15,12 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { +// Check if encoding user is one of matmul encodings. +bool isMatmulEncodingUser(EncodingUser user); + +// Check if encoding user is one of batch matmul encodings. +bool isBatchMatmulEncodingUser(EncodingUser user); + struct MatmulTileParams { int64_t M = 1; int64_t K = 1; diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp index 769541bc300d..342a48397b7c 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp @@ -248,7 +248,10 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp, if (!lhsEncoding || !rhsEncoding || !resultEncoding) { return failure(); } - if (lhsEncoding.getRole().getValue() != + if (!isMatmulEncodingUser(lhsEncoding.getUser().getValue()) || + !isMatmulEncodingUser(rhsEncoding.getUser().getValue()) || + !isMatmulEncodingUser(resultEncoding.getUser().getValue()) || + lhsEncoding.getRole().getValue() != mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS || rhsEncoding.getRole().getValue() != mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS || @@ -262,8 +265,46 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp, return mmt4DOp; } -/// Utility method to convert from `linalg.fill` on `tensor` type with encoding -/// to fill of the materialized type +/// Utility method to convert from `linalg.batch_matmul` with +/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS +/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS +/// - result encoding with user=BATCH_MATMUL_*, role=RESULT +/// to linalg.batch_mmt4d op. +static FailureOr +lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, + ValueRange convertedInputOperands, + ValueRange convertedOutputOperands, MaterializeEncodingFn, + MaterializeEncodingValueFn) { + if (!batchMatmulOp.hasTensorSemantics()) + return failure(); + auto inputs = batchMatmulOp.getDpsInputOperands(); + auto outputs = batchMatmulOp.getDpsInitOperands(); + auto lhsEncoding = + getEncodingAttr(inputs[0]->get().getType().cast()); + auto rhsEncoding = + getEncodingAttr(inputs[1]->get().getType().cast()); + auto resultEncoding = + getEncodingAttr(outputs[0]->get().getType().cast()); + if (!lhsEncoding || !rhsEncoding || !resultEncoding) { + return failure(); + } + + if (!isBatchMatmulEncodingUser(lhsEncoding.getUser().getValue()) || + !isBatchMatmulEncodingUser(rhsEncoding.getUser().getValue()) || + !isBatchMatmulEncodingUser(resultEncoding.getUser().getValue()) || + lhsEncoding.getRole().getValue() != EncodingRole::LHS || + rhsEncoding.getRole().getValue() != EncodingRole::RHS || + resultEncoding.getRole().getValue() != EncodingRole::RESULT) { + return failure(); + } + Operation *batchMmt4DOp = rewriter.create( + batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(), + convertedInputOperands, convertedOutputOperands); + return batchMmt4DOp; +} + +/// Utility method to convert from `linalg.fill` on `tensor` type with +/// encoding to fill of the materialized type static FailureOr lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp, ValueRange convertedInputOperands, @@ -515,9 +556,11 @@ void populateMaterializeEncodingPatterns( MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn) { - // Add all patterns for converting from encoded type to the materialized type + // Add all patterns for converting from encoded type to the materialized + // type patterns.insert, MaterializeDPSOperation, + MaterializeDPSOperation, MaterializeOperation, SetEncodingOpToPackOpConversion, UnsetEncodingOpToPackOpConversion>( diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp index cae238d65fdc..73c141d884d3 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp @@ -11,11 +11,21 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { -MaterializeEncodingInfo -chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, - MatmulTileParams tileParams) { - // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. - int64_t matmulDimBase = 0; +bool isMatmulEncodingUser(EncodingUser user) { + switch (user) { + case EncodingUser::MATMUL_F32F32F32: + case EncodingUser::MATMUL_F16F16F32: + case EncodingUser::MATMUL_F16F16F16: + case EncodingUser::MATMUL_BF16BF16F32: + case EncodingUser::MATMUL_BF16BF16BF16: + case EncodingUser::MATMUL_I8I8I32: + return true; + default: + return false; + } +} + +bool isBatchMatmulEncodingUser(EncodingUser user) { switch (user) { case EncodingUser::BATCH_MATMUL_F32F32F32: case EncodingUser::BATCH_MATMUL_F16F16F32: @@ -23,11 +33,17 @@ chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, case EncodingUser::BATCH_MATMUL_BF16BF16F32: case EncodingUser::BATCH_MATMUL_BF16BF16BF16: case EncodingUser::BATCH_MATMUL_I8I8I32: - matmulDimBase = 1; - break; + return true; default: - break; + return false; } +} + +MaterializeEncodingInfo +chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, + MatmulTileParams tileParams) { + // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. + int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0; MaterializeEncodingInfo encodingInfo; encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1}; diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir index 7f757aa1e829..d60028735f6e 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir @@ -240,3 +240,101 @@ func.func @pack_unpack_batch_matmul_result(%arg0 : tensor) -> tensor< // CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]] // CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_batch_matmul(%arg0 : tensor<128x80x32xf32>, %arg1 : tensor<128x32x320xf32>, %arg2 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor<128x80x32xf32> -> tensor<128x80x32xf32, #iree_linalg_ext.encoding> + %1 = iree_linalg_ext.set_encoding %arg1 : tensor<128x32x320xf32> -> tensor<128x32x320xf32, #iree_linalg_ext.encoding> + %2 = iree_linalg_ext.set_encoding %arg2 : tensor<128x80x320xf32> -> tensor<128x80x320xf32, #iree_linalg_ext.encoding> + %3 = linalg.batch_matmul ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding>, tensor<128x32x320xf32, #iree_linalg_ext.encoding>) + outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding>) -> tensor<128x80x320xf32, #iree_linalg_ext.encoding> + %4 = iree_linalg_ext.unset_encoding %3 : tensor<128x80x320xf32, #iree_linalg_ext.encoding> -> tensor<128x80x320xf32> + return %4 : tensor<128x80x320xf32> +} +// CHECK: func @pack_batch_matmul( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x80x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<128x32x320xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<128x80x320xf32> +// CHECK: %[[PACK_LHS:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] +// CHECK: %[[PACK_RHS:.+]] = tensor.pack +// CHECK-SAME: %[[ARG1]] +// CHECK: %[[PACK_RESULT:.+]] = tensor.pack +// CHECK-SAME: %[[ARG2]] +// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d +// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] : +// CHECK-SAME: outs(%[[PACK_RESULT]] : +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]] +// CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_batch_matmul_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> + %1 = iree_linalg_ext.set_encoding %arg1 : tensor -> tensor> + %2 = iree_linalg_ext.set_encoding %arg2 : tensor -> tensor> + %3 = linalg.batch_matmul ins(%0, %1 : tensor>, tensor>) + outs(%2 : tensor>) -> tensor> + %4 = iree_linalg_ext.unset_encoding %3 : tensor> -> tensor + return %4 : tensor +} +// CHECK: func @pack_batch_matmul_dynamic( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK: %[[PACK_LHS:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] +// CHECK: %[[PACK_RHS:.+]] = tensor.pack +// CHECK-SAME: %[[ARG1]] +// CHECK: %[[PACK_RESULT:.+]] = tensor.pack +// CHECK-SAME: %[[ARG2]] +// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d +// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] : +// CHECK-SAME: outs(%[[PACK_RESULT]] : +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]] +// CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %d2 = tensor.dim %arg1, %c2 : tensor + %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> + %1 = iree_linalg_ext.set_encoding %arg1 : tensor -> tensor> + %2 = tensor.empty(%d0, %d1, %d2) : tensor> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor>) + -> tensor> + %4 = linalg.batch_matmul ins(%0, %1 : tensor>, tensor>) + outs(%3 : tensor>) -> tensor> + %5 = iree_linalg_ext.unset_encoding %4 : tensor> -> tensor + return %5 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK: func @pack_batch_matmul_fill_dynamic( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]] +// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK-DAG: %[[OUT_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]] +// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack %[[ARG0]] +// CHECK-DAG: %[[PACK_RHS:.+]] = tensor.pack %[[ARG1]] +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[OUT_D1]], %[[OUT_D2]]) : tensor +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK-SAME: outs(%[[EMPTY]] : tensor) +// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d +// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] : +// CHECK-SAME: outs(%[[FILL]] : +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]] +// CHECK: return %[[UNPACK]]