From 9e091155b8b4147a523352e954e161b655399d80 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Fri, 27 Sep 2024 21:14:49 -0400 Subject: [PATCH] Simplifications around narrow dimensions in encodings. (#18607) * Drop the `kNarrowThreshold` constant, relying instead on the default padding value. * When reading an `encoding` attribute to tell if a `round_dims_to` entry should be considered narrow, rely on the fact that we only ever need one narrowest dimension in a given matmul to be considered narrow, so the smallest `round_dims_to` entry is the narrow one; if all `round_dims_to` entries are equal, the matmul is not narrow. * Introduce a `MatmulNarrowDim` struct to unify helpers and group them in `EncodingOps.{h,cpp}`. * This enforces in the type system that at most one of the M or N dimensions may be narrow, not both. Previously, we had different structs/tuples, none of which enforced that, so we felt compiled to write comments about the unenforced contract, and the concerned code was scattered across different files. * Remove the `getMatmulNarrow{M,N}` getters on `EncodingAttr`. * Generally we are over-relying on TableGen class methods, which only obfuscates things compared to functions declared manually in C++ files, and the new `MatmulNarrowDim` struct allows replacing both these methods by a single `getMatmulNarrowDim`, which also simplifies callers. Signed-off-by: Benoit Jacob --- .../Common/CPU/CPUMaterializeEncodings.cpp | 26 +++++---- .../compiler/Codegen/Common/EncodingUtils.cpp | 4 +- .../Dialect/Encoding/IR/EncodingBase.td | 8 --- .../Dialect/Encoding/IR/EncodingOps.cpp | 54 +++++++++++++------ .../Dialect/Encoding/IR/EncodingOps.h | 42 +++++++++++++-- .../compiler/DispatchCreation/SetEncoding.cpp | 53 ++---------------- 6 files changed, 95 insertions(+), 92 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp index 963fb9c0c5c0..96fdfb903bc5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp @@ -47,8 +47,7 @@ enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims, // codegen.query_tile_sizes op, so we disable dynamic tile shapes for // batch_matmul. Also, they are not set up for narrow M/N matmul, so it is // disabled when it is the case. - if (!cDims.batch.empty() || encoding.getMatmulNarrowM() || - encoding.getMatmulNarrowN()) { + if (!cDims.batch.empty() || getMatmulNarrowDim(encoding)) { hasUkernelSupport = false; } if (hasUkernelSupport) { @@ -294,19 +293,20 @@ enumerateMatmulTileX86_64(TypeRange elementTypes, /// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such /// information to host. For now, they are defined by host. static TileMxNxK -chooseMatmulTile(ArrayRef enumeratedTiles, int64_t matmulNarrowM, - int64_t matmulNarrowN, +chooseMatmulTile(ArrayRef enumeratedTiles, + IREE::Encoding::MatmulNarrowDim narrowDim, ArrayRef hostDefinedUpperBound = {}) { assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) && "expected hostDefinedUpperBound is empty or has upper bound for {M, " "N, K}"); // Handle narrow-N by transposing to reduce to narrow-M. Note: the // enumeratedTiles currently only enumerate narrow-M cases. - if (matmulNarrowN && (!matmulNarrowM || matmulNarrowN < matmulNarrowM)) { + if (narrowDim.isN()) { SmallVector newHostDefinedUpperBound(hostDefinedUpperBound); std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]); - TileMxNxK tile = chooseMatmulTile(enumeratedTiles, matmulNarrowN, 0, - newHostDefinedUpperBound); + narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M; + TileMxNxK tile = + chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound); std::swap(tile.M, tile.N); return tile; } @@ -367,9 +367,9 @@ chooseMatmulTile(ArrayRef enumeratedTiles, int64_t matmulNarrowM, // are OK with the tile that has M==8 even though it requires some padding. // Otherwise, we would be penalizing the tiles with M==8,4,2 and we would // end up selecting the vecmat tile (M==1) for that case! - if (matmulNarrowM) { + if (narrowDim) { ratedTile.paddingPenalty = - std::max(tile.M - llvm::PowerOf2Ceil(matmulNarrowM), 0); + std::max(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0); } ratedTile.productMxNxK = tile.M * tile.N * tile.K; ratedTiles.push_back(ratedTile); @@ -438,13 +438,11 @@ materializeEncodingForTarget(RankedTensorType tensorType, if (enumeratedTileMxNxK.empty()) { return failure(); } - int64_t matmulNarrowM = encoding.getMatmulNarrowM(); - int64_t matmulNarrowN = encoding.getMatmulNarrowN(); + auto narrowDim = IREE::Encoding::getMatmulNarrowDim(encoding); // Choose a final matmul TileMxNxK from the above-enumarated tile shapes, // taking narrow dimensions into account. - TileMxNxK chosenTileMxNxK = - chooseMatmulTile(enumeratedTileMxNxK, matmulNarrowM, matmulNarrowN, - encoding.getRoundDimsToArray()); + TileMxNxK chosenTileMxNxK = chooseMatmulTile(enumeratedTileMxNxK, narrowDim, + encoding.getRoundDimsToArray()); // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand, // based on its operand index in the matmul. diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp index d5e6b5561a91..f3f6cbb2b47b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp @@ -199,9 +199,7 @@ bool isNarrowNResult(EncodingAttr encoding) { return false; } - int64_t narrowM = encoding.getMatmulNarrowM(); - int64_t narrowN = encoding.getMatmulNarrowN(); - return narrowN && (!narrowM || narrowM > narrowN); + return IREE::Encoding::getMatmulNarrowDim(encoding).isN(); } SmallVector diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td index 8ea2bb499422..ebc936a082c0 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingBase.td @@ -110,14 +110,6 @@ def EncodingAttr : /// Clones an encoding with a new bcast_map EncodingAttr clone(AffineMap bcastMap); - - /// Returns the M size from `round_dims_to` if the value is less than - /// kNarrowThreshold. Otherwise, returns zero. - int64_t getMatmulNarrowM(); - - /// Returns the N size from `round_dims_to` if the value is less than - /// kNarrowThreshold. Otherwise, returns zero. - int64_t getMatmulNarrowN(); }]; let genVerifyDecl = 0; diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp index 954eec376357..19c65533cf6b 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp @@ -137,6 +137,33 @@ std::optional EncodingAttr::mapDimToOperandIndex(int64_t dimPos) { getAffineDimExpr(dimPos, getContext())); } +MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, + int narrowThreshold) { + linalg::ContractionDimensions cDims = + linalg::inferContractionDims(linalgOp).value(); + auto map = linalgOp.getIndexingMapsArray().back(); + auto outType = llvm::cast(linalgOp.getDpsInits()[0].getType()); + auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t { + return outType.getDimSize( + map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext())) + .value()); + }; + // M or N can be empty instead of having an explicit dim size of 1 for matvec + // and vecmat, so set to 1 if empty. + int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]); + int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]); + + MatmulNarrowDim narrowM, narrowN; + if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) { + narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize}; + } + if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) { + narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize}; + } + + return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN; +} + ArrayRef EncodingAttr::getRoundDimsToArray() { auto roundDimsTo = getRoundDimsTo(); if (!roundDimsTo) { @@ -151,26 +178,23 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { AffineMapAttr::get(bcastMap), getRoundDimsTo()); } -int64_t EncodingAttr::getMatmulNarrowM() { - if (getOpType().getValue() != EncodingOpType::matmul) { - return 0; +MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) { + if (encoding.getOpType().getValue() != EncodingOpType::matmul) { + return {}; } - ArrayRef roundDimsTo = getRoundDimsToArray(); + ArrayRef roundDimsTo = encoding.getRoundDimsToArray(); if (roundDimsTo.empty()) { - return 0; + return {}; } - return roundDimsTo[0] < kNarrowThreshold ? roundDimsTo[0] : 0; -} - -int64_t EncodingAttr::getMatmulNarrowN() { - if (getOpType().getValue() != EncodingOpType::matmul) { - return 0; + int m = roundDimsTo[0]; + int n = roundDimsTo[1]; + if (m < n) { + return {MatmulNarrowDim::Dim::M, m}; } - ArrayRef roundDimsTo = getRoundDimsToArray(); - if (roundDimsTo.empty()) { - return 0; + if (n < m) { + return {MatmulNarrowDim::Dim::N, n}; } - return roundDimsTo[1] < kNarrowThreshold ? roundDimsTo[1] : 0; + return {}; } //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h index c9f694a83fea..9a0810ed78fe 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h @@ -35,9 +35,6 @@ namespace mlir::iree_compiler::IREE::Encoding { -/// Threadshold that determines if a dimension is considered "narrow" or not. -constexpr int64_t kNarrowThreshold = 32; - /// Returns the encoding attribute from the type if there is an encoding. /// Otherwise, returns null. EncodingAttr getEncodingAttr(RankedTensorType type); @@ -46,13 +43,50 @@ EncodingAttr getEncodingAttr(RankedTensorType type); FailureOr getEncodingContractionDims(EncodingAttr encoding); -// Assign a name to operand indices for clarity +/// Assign a name to operand indices for clarity const int64_t MATMUL_LHS = 0; const int64_t MATMUL_RHS = 1; const int64_t MATMUL_RESULT = 2; + /// Convert operand index to strings for printing std::string stringifyOperandIndex(IntegerAttr); +/// Designates a dimension in a matmul (either the M or the N dimension) as +/// being "narrow", i.e. small enough that we bother lowering the amount of +/// padding along that dimension compared to how padding we apply to +/// sufficiently large dimensions. +struct MatmulNarrowDim { + // Enumerates dimensions of a matmul that may be labelled as narrow. + enum class Dim { + None, + M, + N, + }; + Dim dim = Dim::None; // Which dimension is designated by *this. + int64_t size = 0; // Size of the designated dimension, or kDynamic. + + explicit operator bool() const { return dim != Dim::None; } + bool isM() const { return dim == Dim::M; } + bool isN() const { return dim == Dim::N; } +}; + +/// Returns the narrow dim in a given `linalgOp`, with respect to the given +/// `narrowThreshold` below which a dimension is eligible to be considered +/// narrow. If both M and N are narrow, M is returned. If neither M nor N are +/// narrow, this returns a default-constructed falsish value. +MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, + int narrowThreshold); + +/// Returns the narrow dim in a given `encoding`. This works by inspecting +/// the `round_dims_to` array attribute in the `encoding`. If the +/// `round_dims_to` of one dimension (M or N) is smaller than the other, then +/// that's the narrow dimension, because the only way it would have been set +/// to be smaller in the first place, is if we previously flagged that dimension +/// as narrow. If the `round_dims_to` of the M and N dimensions agree, then +/// neither is a narrow dimension and this returns a default-constructed falsish +/// value. +MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding); + } // namespace mlir::iree_compiler::IREE::Encoding #endif // IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGOPS_H_ diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp index b19165fb4557..0f16986a89d2 100644 --- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp @@ -45,47 +45,6 @@ Value setEncoding(OpBuilder &builder, Location loc, Value source, return builder.create(loc, resultType, source); }; -struct MatmulNarrowSizes { - std::optional M, N; -}; - -// Returns the minimum of static sizes of the M/N-dimensions in the types of the -// Ouput. -static MatmulNarrowSizes getMatmulNarrowSizes(ShapedType outType, - linalg::LinalgOp linalgOp) { - linalg::ContractionDimensions cDims = - linalg::inferContractionDims(linalgOp).value(); - auto map = linalgOp.getIndexingMapsArray().back(); - auto getOutputSizeAtDimPos = [&](unsigned dimPos) -> int64_t { - return outType.getDimSize( - map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext())) - .value()); - }; - // M or N can be empty instead of having an explicit dim size of 1 for matvec - // and vecmat, so set to 1 if empty. - int64_t M = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]); - int64_t N = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]); - - MatmulNarrowSizes narrow; - if (!ShapedType::isDynamic(M) && M < IREE::Encoding::kNarrowThreshold) { - narrow.M = M; - } - if (!ShapedType::isDynamic(N) && N < IREE::Encoding::kNarrowThreshold) { - narrow.N = N; - } - - // Only pick 1 if both are present - if (narrow.M && narrow.N) { - if (*narrow.M <= *narrow.N) { - narrow.N.reset(); - } else { - narrow.M.reset(); - } - } - - return narrow; -} - static Value unsetEncodingAndExtractSlice(OpBuilder &builder, Location loc, Value source, SmallVector sizes) { @@ -247,8 +206,7 @@ class setContractionOpEncoding } SmallVector elemTypes = {lhsElemType, rhsElemType, outElemType}; - MatmulNarrowSizes narrowSizes = - getMatmulNarrowSizes(cast(out.getType()), linalgOp); + auto narrowDim = IREE::Encoding::getMatmulNarrowDim(linalgOp, padFactor); Location loc = linalgOp.getLoc(); SmallVector maps = linalgOp.getIndexingMapsArray(); @@ -256,13 +214,12 @@ class setContractionOpEncoding auto opType = IREE::Encoding::EncodingOpType::matmul; auto setEncodingWrapper = [&](Value src, int64_t operandIndex) -> Value { SmallVector roundDimsTo(3, padFactor); - if (narrowSizes.M) { - roundDimsTo[0] = llvm::PowerOf2Ceil(narrowSizes.M.value()); + if (narrowDim.isM()) { + roundDimsTo[0] = llvm::PowerOf2Ceil(narrowDim.size); } - if (narrowSizes.N) { - roundDimsTo[1] = llvm::PowerOf2Ceil(narrowSizes.N.value()); + if (narrowDim.isN()) { + roundDimsTo[1] = llvm::PowerOf2Ceil(narrowDim.size); } - auto encoding = EncodingAttr::get(linalgOp.getContext(), operandIndex, opType, elemTypes, maps, /*bcastMap=*/std::nullopt, roundDimsTo);