[LinalgExt] Moving encoding utils to EncodingAttr builtin or LinalgExt/IR (#17053)

We cannot move the utils to LinalgExt/Utils because of a circular
dependency issue: LinalgExt/IR depends on LinalgExt/Utils, which needs
to be fixed. For now, we put them in LinalgExt/IR, which is similar to
the tensor dialect; e.g., upstream puts `getOrCreateRanges` in
`Tensor/IR/Tensor.h`.
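
As a rough usage sketch (not part of the change itself), a caller can now pick
these helpers up from LinalgExt/IR; the wrapper function below is hypothetical
and assumes only the `getEncodingAttr` and `getEncodingContractionDims`
declarations added to `LinalgExtOps.h` in this patch:

```cpp
// Hypothetical caller: includes LinalgExt/IR rather than LinalgExt/Utils.
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Support/LogicalResult.h"

namespace LinalgExt = mlir::iree_compiler::IREE::LinalgExt;

// Returns true when `type` carries a contraction encoding with no batch dims.
static bool isSimpleMatmulEncoding(mlir::RankedTensorType type) {
  // Null when the tensor type has no EncodingAttr.
  LinalgExt::EncodingAttr encoding = LinalgExt::getEncodingAttr(type);
  if (!encoding)
    return false;
  // Infers M/N/K/batch dimensions from the encoding's user_indexing_maps.
  auto cDims = LinalgExt::getEncodingContractionDims(encoding);
  return mlir::succeeded(cDims) && cDims->batch.empty();
}
```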

It also switches the file away from using-directives (e.g., `using
namespace foo`) because a linking issue was encountered with them. There
are many reasons to avoid using-directives; see
https://google.github.io/styleguide/cppguide.html#Namespaces for more
details.
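
A minimal sketch of the resulting style, with a placeholder helper name purely
for illustration: definitions live inside the namespace they belong to, and
other files use targeted using-declarations instead of a using-directive.

```cpp
// Avoided: a using-directive injects every name from the namespace.
//   using namespace mlir::iree_compiler::IREE::LinalgExt;

// Preferred: define/declare symbols inside their own namespace ...
namespace mlir::iree_compiler::IREE::LinalgExt {
void somePlaceholderHelper(); // illustrative declaration only
} // namespace mlir::iree_compiler::IREE::LinalgExt

// ... and name exactly what a consumer needs.
using mlir::iree_compiler::IREE::LinalgExt::somePlaceholderHelper;
```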
hanhanW authored Apr 17, 2024
1 parent c83f9ba commit fdfe344
Showing 9 changed files with 88 additions and 67 deletions.
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -62,7 +62,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
@@ -83,7 +83,6 @@ iree_cc_library(
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::LinalgExt::Utils
PUBLIC
)

@@ -478,7 +478,7 @@ getFlagForUserAndOperandTypes(IREE::LinalgExt::EncodingAttr encoding,
ArrayRef<Attribute> operandTypes) {
// There are currently no batch_mmt4d ukernels, so check for no batch
// dimension.
auto cDims = getEncodingContractionDims(encoding);
auto cDims = IREE::LinalgExt::getEncodingContractionDims(encoding);
if (failed(cDims) || !cDims->batch.empty() || operandTypes.size() != 3) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_NONE;
}
53 changes: 10 additions & 43 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -13,6 +14,8 @@ namespace mlir::iree_compiler {

using IREE::LinalgExt::EncodingAttr;
using IREE::LinalgExt::EncodingRole;
using IREE::LinalgExt::getEncodingAttr;
using IREE::LinalgExt::getEncodingContractionDims;

/// For a given tensor type with an encoding, return the materialized
/// type to use for it. If no encoding is set, then return the tensor type
@@ -67,42 +70,6 @@ MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
});
}

EncodingAttr getEncodingAttr(RankedTensorType type) {
return type.getEncoding().dyn_cast_or_null<EncodingAttr>();
}

static AffineMap getMapForRole(EncodingAttr encoding) {
EncodingRole role = encoding.getRole().getValue();
if (role == EncodingRole::LHS)
return cast<AffineMapAttr>(encoding.getUserIndexingMaps()[0])
.getAffineMap();
else if (role == EncodingRole::RHS)
return cast<AffineMapAttr>(encoding.getUserIndexingMaps()[1])
.getAffineMap();
else
return cast<AffineMapAttr>(encoding.getUserIndexingMaps()[2])
.getAffineMap();
}

FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding) {
auto indexingMapsAttr = encoding.getUserIndexingMaps();
SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap {
return cast<AffineMapAttr>(m).getAffineMap();
});
return linalg::inferContractionDims(indexingMaps);
}

/// Given the dim position of the encoding `user_indexing_maps`, return the
/// matching index of the given encoding's tensor
static unsigned mapDimToRoleIndex(int64_t dimPos, EncodingAttr encoding) {
AffineMap map = getMapForRole(encoding);
auto idx = map.getResultPosition(getAffineDimExpr(dimPos, map.getContext()));
assert(idx.has_value());
return idx.value();
}

RankedTensorType getOriginalTypeWithEncoding(RankedTensorType type) {
auto encoding = getEncodingAttr(type);
if (!encoding) {
@@ -136,27 +103,27 @@ MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding,
"Expected at most one M, N, K, and Batch dimension");
if (!cDims->batch.empty()) {
encodingInfo.outerDimsPerm.push_back(
mapDimToRoleIndex(cDims->batch[0], encoding));
encoding.mapDimToRoleIndex(cDims->batch[0]));
}
if (role != EncodingRole::RHS && !cDims->m.empty()) {
encodingInfo.outerDimsPerm.push_back(
mapDimToRoleIndex(cDims->m[0], encoding));
encoding.mapDimToRoleIndex(cDims->m[0]));
encodingInfo.innerDimsPos.push_back(
mapDimToRoleIndex(cDims->m[0], encoding));
encoding.mapDimToRoleIndex(cDims->m[0]));
encodingInfo.innerTileSizes.push_back(tileMxNxK.M);
}
if (role != EncodingRole::LHS && !cDims->n.empty()) {
encodingInfo.outerDimsPerm.push_back(
mapDimToRoleIndex(cDims->n[0], encoding));
encoding.mapDimToRoleIndex(cDims->n[0]));
encodingInfo.innerDimsPos.push_back(
mapDimToRoleIndex(cDims->n[0], encoding));
encoding.mapDimToRoleIndex(cDims->n[0]));
encodingInfo.innerTileSizes.push_back(tileMxNxK.N);
}
if (role != EncodingRole::RESULT) {
encodingInfo.outerDimsPerm.push_back(
mapDimToRoleIndex(cDims->k[0], encoding));
encoding.mapDimToRoleIndex(cDims->k[0]));
encodingInfo.innerDimsPos.push_back(
mapDimToRoleIndex(cDims->k[0], encoding));
encoding.mapDimToRoleIndex(cDims->k[0]));
encodingInfo.innerTileSizes.push_back(tileMxNxK.K);
}
return encodingInfo;
9 changes: 0 additions & 9 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -9,7 +9,6 @@

#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::iree_compiler {
@@ -73,14 +72,6 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
// Utility methods about Encoding.
//===---------------------------------------------------------------------===//

/// Returns the encoding attribute from the type if there is an encoding.
/// Otherwise, returns null.
IREE::LinalgExt::EncodingAttr getEncodingAttr(RankedTensorType type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(IREE::LinalgExt::EncodingAttr encoding);

/// Returns the original type that carried by encoding.
RankedTensorType getOriginalTypeWithEncoding(RankedTensorType type);

@@ -30,6 +30,7 @@ namespace mlir::iree_compiler {

using namespace IREE::LinalgExt;
using IREE::HAL::ExecutableTargetAttr;
using IREE::LinalgExt::getEncodingAttr;

//===---------------------------------------------------------------------===//
// Utility methods
@@ -350,7 +351,7 @@ lowerContractionOpWithEncoding(RewriterBase &rewriter,

Type newResultType = newResult.getType();

auto cDims = getEncodingContractionDims(lhsEncoding);
auto cDims = IREE::LinalgExt::getEncodingContractionDims(lhsEncoding);
if (cDims->batch.empty()) {
result = rewriter.create<linalg::Mmt4DOp>(
linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
@@ -104,6 +104,15 @@ def EncodingAttr :
CArg<"ArrayRef<AffineMap>", "{}">:$maps)>
];

let extraClassDeclaration = [{
/// Returns the indexing map used by the role in the encoding.
AffineMap getMapForRole();

/// Given the dim position of the encoding `user_indexing_maps`, returns the
/// matching index of the given encoding's tensor.
unsigned mapDimToRoleIndex(int64_t dimPos);
}];

let genVerifyDecl = 0;
}

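A hedged usage sketch for the two methods declared above. It assumes a plain
matmul encoding whose `user_indexing_maps` are `(d0, d2)`, `(d2, d1)`,
`(d0, d1)` (so `d0` is M and appears in the LHS map); `mapDimToRoleIndex`
asserts if the dimension does not occur in the role's map.

```cpp
// Illustrative only: querying the new EncodingAttr helpers on an LHS operand.
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/IR/AffineMap.h"

using mlir::iree_compiler::IREE::LinalgExt::EncodingAttr;

static void inspectLhsEncoding(EncodingAttr lhsEncoding) {
  // Indexing map selected by the encoding's role (LHS/RHS/RESULT).
  mlir::AffineMap lhsMap = lhsEncoding.getMapForRole();
  // For the assumed matmul maps, d0 (M) sits at result position 0 of the LHS map.
  unsigned mPosInLhs = lhsEncoding.mapDimToRoleIndex(/*dimPos=*/0);
  (void)lhsMap;
  (void)mPosInLhs;
}
```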
61 changes: 50 additions & 11 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -33,9 +33,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
namespace IREE = mlir::iree_compiler::IREE;
namespace mlir::iree_compiler::IREE::LinalgExt {

//===----------------------------------------------------------------------===//
// Utils.
@@ -1743,11 +1741,11 @@ LogicalResult PackOp::verify() {
}

SmallVector<OpFoldResult> PackOp::getMixedTiles() {
return ::getMixedTiles(*this);
return LinalgExt::getMixedTiles(*this);
}

SmallVector<int64_t> PackOp::getStaticTiles() {
return ::getStaticTiles(*this);
return LinalgExt::getStaticTiles(*this);
}

// Helper for PackOp::{getResultShape,getPackedType}. Returns the shape of the
@@ -1841,11 +1839,11 @@ ShapedType PackOp::getPackedType(ShapedType sourceType,
}

DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
return ::getDimAndTileMapping(*this);
return LinalgExt::getDimAndTileMapping(*this);
}

SmallVector<Range> PackOp::getIterationDomain(OpBuilder &builder) {
return ::getIterationDomain(*this, builder);
return LinalgExt::getIterationDomain(*this, builder);
}

/// Generate the body of the innermost loop of the scalar implementation
@@ -2011,15 +2009,15 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
}

SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
return ::getMixedTiles(*this);
return LinalgExt::getMixedTiles(*this);
}

SmallVector<int64_t> UnPackOp::getStaticTiles() {
return ::getStaticTiles(*this);
return LinalgExt::getStaticTiles(*this);
}

DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
return ::getDimAndTileMapping(*this);
return LinalgExt::getDimAndTileMapping(*this);
}

LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder,
@@ -2089,7 +2087,7 @@ UnPackOp::reifyResultShapes(OpBuilder &builder,
}

SmallVector<Range> UnPackOp::getIterationDomain(OpBuilder &builder) {
return ::getIterationDomain(*this, builder);
return LinalgExt::getIterationDomain(*this, builder);
}

LogicalResult UnPackOp::verify() {
@@ -2797,6 +2795,47 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, EncodingRole role,
b.getAffineMapArrayAttr(maps));
}

AffineMap EncodingAttr::getMapForRole() {
EncodingRole role = getRole().getValue();
switch (role) {
case EncodingRole::LHS:
return getUserIndexingMaps()[0].cast<AffineMapAttr>().getAffineMap();
case EncodingRole::RHS:
return getUserIndexingMaps()[1].cast<AffineMapAttr>().getAffineMap();
case EncodingRole::RESULT:
return getUserIndexingMaps()[2].cast<AffineMapAttr>().getAffineMap();
default:
return AffineMap();
}
}

unsigned EncodingAttr::mapDimToRoleIndex(int64_t dimPos) {
AffineMap map = getMapForRole();
auto idx = map.getResultPosition(getAffineDimExpr(dimPos, getContext()));
assert(idx.has_value());
return idx.value();
}

//===---------------------------------------------------------------------===//
// LinalgExt Dialect Helpers
//===---------------------------------------------------------------------===//

EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding) {
auto indexingMapsAttr = encoding.getUserIndexingMaps();
SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap {
return cast<AffineMapAttr>(m).getAffineMap();
});
return linalg::inferContractionDims(indexingMaps);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt

// clang-format off
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" // IWYU pragma: keep
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -29,4 +30,19 @@

// clang-format on

//===---------------------------------------------------------------------===//
// LinalgExt Dialect Helpers
//===---------------------------------------------------------------------===//

namespace mlir::iree_compiler::IREE::LinalgExt {

/// Returns the encoding attribute from the type if there is an encoding.
/// Otherwise, returns null.
EncodingAttr getEncodingAttr(RankedTensorType type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding);
} // namespace mlir::iree_compiler::IREE::LinalgExt

#endif // IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
