Skip to content

Commit

Permalink
Support default configs for BATCH_MATMUL_* in MaterializeEncoding pass (
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu authored Aug 24, 2023
1 parent b0c77fa commit 8314b5c
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 76 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtEncodingUtils",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_cc_library(
::PassHeaders
::PassesIncGen
IREELinalgExtDialect
IREELinalgExtEncodingUtils
IREELinalgExtTransforms
IREELinalgExtUtils
LLVMSupport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Codegen/Common/CPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
Expand Down Expand Up @@ -171,7 +172,8 @@ materializeEncodingForTarget(RankedTensorType tensorType,
auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr);
auto encodingInfo = chooseEncodingInfoForMatmul(user, role, tileParams);
auto encodingInfo =
IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, tileParams);
auto originalTypeAttr = encoding.getOriginalType();
auto originalType = originalTypeAttr
? originalTypeAttr.getValue().cast<RankedTensorType>()
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,10 @@
namespace mlir {
namespace iree_compiler {

struct MatmulTileParams {
int64_t M = 1;
int64_t K = 1;
int64_t N = 1;
};

void adjustTileSizesToNarrowStaticShape(
IREE::LinalgExt::MaterializeEncodingInfo &encodingInfo,
ArrayRef<int64_t> shape);

IREE::LinalgExt::MaterializeEncodingInfo
chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingUser user,
IREE::LinalgExt::EncodingRole role,
MatmulTileParams tileParams);

IREE::LinalgExt::MaterializeEncodingValueFn
getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,52 +269,6 @@ struct MaterializeFlowDispatchTensorStoreOp

} // namespace

IREE::LinalgExt::MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = 0;
switch (user) {
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
case EncodingUser::BATCH_MATMUL_F16F16F16:
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_I8I8I32:
matmulDimBase = 1;
break;
default:
break;
}

MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
switch (role) {
case (EncodingRole::LHS): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
break;
}
case (EncodingRole::RHS): {
encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
encodingInfo.outerDimsPerm =
llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1);
encodingInfo.outerDimsPerm.push_back(matmulDimBase);
break;
}
case (EncodingRole::RESULT): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
break;
}
default: {
assert(false);
return {};
}
}
return encodingInfo;
}

void adjustTileSizesToNarrowStaticShape(MaterializeEncodingInfo &encodingInfo,
ArrayRef<int64_t> shape) {
for (size_t i = 0; i < encodingInfo.innerDimsPos.size(); i++) {
Expand Down
29 changes: 23 additions & 6 deletions llvm-external-projects/iree-dialects/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,13 @@ gentbl_cc_library(

cc_library(
name = "IREELinalgExtUtils",
srcs = glob([
"lib/Dialect/LinalgExt/Utils/*.cpp",
]),
hdrs = glob([
"include/iree-dialects/Dialect/LinalgExt/Utils/*.h",
]),
srcs = [
"lib/Dialect/LinalgExt/Utils/Utils.cpp",
],
hdrs = [
"include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h",
"include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h",
],
includes = ["include"],
deps = [
"@llvm-project//llvm:Support",
Expand All @@ -331,6 +332,21 @@ cc_library(
],
)

cc_library(
name = "IREELinalgExtEncodingUtils",
srcs = [
"lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp",
],
hdrs = [
"include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h",
],
includes = ["include"],
deps = [
":IREELinalgExtDialect",
":IREELinalgExtUtils",
],
)

cc_library(
name = "IREELinalgExtDialect",
srcs = glob([
Expand Down Expand Up @@ -437,6 +453,7 @@ cc_library(
deps = [
":IREEInputDialect",
":IREELinalgExtDialect",
":IREELinalgExtEncodingUtils",
":IREELinalgExtPassIncGen",
":IREELinalgExtUtils",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

struct MatmulTileParams {
int64_t M = 1;
int64_t K = 1;
int64_t N = 1;
};

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams);

} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_mlir_library(IREELinalgExtPasses
LINK_LIBS PUBLIC
IREEInputDialect
IREELinalgExtDialect
IREELinalgExtEncodingUtils
IREELinalgExtUtils
MLIRAffineDialect
MLIRIR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -89,19 +90,23 @@ chooseEncodingInfo(RankedTensorType tensorType) {
auto encoding = getEncodingAttr(tensorType);
if (!encoding)
return failure();

auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
switch (role) {
case EncodingRole::LHS:
return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
break;
case EncodingRole::RHS:
return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
break;
case EncodingRole::RESULT:
return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
break;
default:
return failure();
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
case EncodingUser::MATMUL_I8I8I32:
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
case EncodingUser::BATCH_MATMUL_F16F16F16:
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_I8I8I32:
return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,11 @@ add_mlir_library(IREELinalgExtUtils
MLIRTensorDialect
MLIRMemRefDialect
)

add_mlir_library(IREELinalgExtEncodingUtils
EncodingUtils.cpp

LINK_LIBS PUBLIC
IREELinalgExtDialect
IREELinalgExtUtils
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = 0;
switch (user) {
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
case EncodingUser::BATCH_MATMUL_F16F16F16:
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_I8I8I32:
matmulDimBase = 1;
break;
default:
break;
}

MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
switch (role) {
case (EncodingRole::LHS): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
break;
}
case (EncodingRole::RHS): {
encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
encodingInfo.outerDimsPerm =
llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1);
encodingInfo.outerDimsPerm.push_back(matmulDimBase);
break;
}
case (EncodingRole::RESULT): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
break;
}
default: {
assert(false);
return {};
}
}
return encodingInfo;
}

} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
Loading

0 comments on commit 8314b5c

Please sign in to comment.