diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h index 700dce50e025..c293497814c6 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h @@ -15,6 +15,9 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { +// Check if encoding user is one of batch matmul encodings. +bool isBatchMatmulEncodingUser(EncodingUser user); + struct MatmulTileParams { int64_t M = 1; int64_t K = 1; diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp index dc10837ff1cd..5baa8213db13 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp @@ -263,9 +263,9 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp, } /// Utility method to convert from `linalg.batch_matmul` with -/// - lhs encoding with role=LHS -/// - rhs encoding with role=RHS -/// - result encoding with role=RESULT +/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS +/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS +/// - result encoding with user=BATCH_MATMUL_*, role=RESULT /// to linalg.batch_mmt4d op. static FailureOr lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, @@ -285,12 +285,13 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, if (!lhsEncoding || !rhsEncoding || !resultEncoding) { return failure(); } - if (lhsEncoding.getRole().getValue() != - mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS || - rhsEncoding.getRole().getValue() != - mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS || - resultEncoding.getRole().getValue() != - mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RESULT) { + + if (!isBatchMatmulEncodingUser(lhsEncoding.getUser().getValue()) || + !isBatchMatmulEncodingUser(rhsEncoding.getUser().getValue()) || + !isBatchMatmulEncodingUser(resultEncoding.getUser().getValue()) || + lhsEncoding.getRole().getValue() != EncodingRole::LHS || + rhsEncoding.getRole().getValue() != EncodingRole::RHS || + resultEncoding.getRole().getValue() != EncodingRole::RESULT) { return failure(); } Operation *batchMmt4DOp = rewriter.create( @@ -299,8 +300,8 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, return batchMmt4DOp; } -/// Utility method to convert from `linalg.fill` on `tensor` type with encoding -/// to fill of the materialized type +/// Utility method to convert from `linalg.fill` on `tensor` type with +/// encoding to fill of the materialized type static FailureOr lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp, ValueRange convertedInputOperands, @@ -552,7 +553,8 @@ void populateMaterializeEncodingPatterns( MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn) { - // Add all patterns for converting from encoded type to the materialized type + // Add all patterns for converting from encoded type to the materialized + // type patterns.insert, MaterializeDPSOperation, MaterializeDPSOperation, diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp index cae238d65fdc..1163aa3c2912 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp @@ -11,11 +11,7 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { -MaterializeEncodingInfo -chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, - MatmulTileParams tileParams) { - // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. - int64_t matmulDimBase = 0; +bool isBatchMatmulEncodingUser(EncodingUser user) { switch (user) { case EncodingUser::BATCH_MATMUL_F32F32F32: case EncodingUser::BATCH_MATMUL_F16F16F32: @@ -23,11 +19,17 @@ chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, case EncodingUser::BATCH_MATMUL_BF16BF16F32: case EncodingUser::BATCH_MATMUL_BF16BF16BF16: case EncodingUser::BATCH_MATMUL_I8I8I32: - matmulDimBase = 1; - break; + return true; default: - break; + return false; } +} + +MaterializeEncodingInfo +chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, + MatmulTileParams tileParams) { + // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. + int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0; MaterializeEncodingInfo encodingInfo; encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};