Skip to content

Commit

Permalink
Check encoding user
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 25, 2023
1 parent 065435c commit 7acb230
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

// Check if encoding user is one of batch matmul encodings.
bool isBatchMatmulEncodingUser(EncodingUser user);

struct MatmulTileParams {
int64_t M = 1;
int64_t K = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,9 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
}

/// Utility method to convert from `linalg.batch_matmul` with
/// - lhs encoding with role=LHS
/// - rhs encoding with role=RHS
/// - result encoding with role=RESULT
/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS
/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
/// - result encoding with user=BATCH_MATMUL_*, role=RESULT
/// to linalg.batch_mmt4d op.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
Expand All @@ -285,12 +285,13 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}
if (lhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS ||
rhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS ||
resultEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RESULT) {

if (!isBatchMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
!isBatchMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
!isBatchMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
lhsEncoding.getRole().getValue() != EncodingRole::LHS ||
rhsEncoding.getRole().getValue() != EncodingRole::RHS ||
resultEncoding.getRole().getValue() != EncodingRole::RESULT) {
return failure();
}
Operation *batchMmt4DOp = rewriter.create<linalg::BatchMmt4DOp>(
Expand All @@ -299,8 +300,8 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
return batchMmt4DOp;
}

/// Utility method to convert from `linalg.fill` on `tensor` type with encoding
/// to fill of the materialized type
/// Utility method to convert from `linalg.fill` on `tensor` type with
/// encoding to fill of the materialized type
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp,
ValueRange convertedInputOperands,
Expand Down Expand Up @@ -552,7 +553,8 @@ void populateMaterializeEncodingPatterns(
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {

// Add all patterns for converting from encoded type to the materialized type
// Add all patterns for converting from encoded type to the materialized
// type
patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::MatmulOp>,
MaterializeDPSOperation<linalg::BatchMatmulOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,25 @@ namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = 0;
bool isBatchMatmulEncodingUser(EncodingUser user) {
switch (user) {
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
case EncodingUser::BATCH_MATMUL_F16F16F16:
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_I8I8I32:
matmulDimBase = 1;
break;
return true;
default:
break;
return false;
}
}

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;

MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
Expand Down

0 comments on commit 7acb230

Please sign in to comment.