diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h index c293497814c6..889e0a7ce5fc 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h @@ -15,6 +15,9 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { +// Check if encoding user is one of matmul encodings. +bool isMatmulEncodingUser(EncodingUser user); + // Check if encoding user is one of batch matmul encodings. bool isBatchMatmulEncodingUser(EncodingUser user); diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp index 5baa8213db13..342a48397b7c 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp @@ -248,7 +248,10 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp, if (!lhsEncoding || !rhsEncoding || !resultEncoding) { return failure(); } - if (lhsEncoding.getRole().getValue() != + if (!isMatmulEncodingUser(lhsEncoding.getUser().getValue()) || + !isMatmulEncodingUser(rhsEncoding.getUser().getValue()) || + !isMatmulEncodingUser(resultEncoding.getUser().getValue()) || + lhsEncoding.getRole().getValue() != mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS || rhsEncoding.getRole().getValue() != mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS || diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp index 1163aa3c2912..73c141d884d3 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp @@ -11,6 +11,20 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { +bool isMatmulEncodingUser(EncodingUser user) { + switch (user) { + case EncodingUser::MATMUL_F32F32F32: + case EncodingUser::MATMUL_F16F16F32: + case EncodingUser::MATMUL_F16F16F16: + case EncodingUser::MATMUL_BF16BF16F32: + case EncodingUser::MATMUL_BF16BF16BF16: + case EncodingUser::MATMUL_I8I8I32: + return true; + default: + return false; + } +} + bool isBatchMatmulEncodingUser(EncodingUser user) { switch (user) { case EncodingUser::BATCH_MATMUL_F32F32F32: