From ec89417d49bd0c0e234d8897d1bcb1947ed0a357 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 14 Jul 2023 04:53:57 +0000 Subject: [PATCH] [CPU] Improve vector size computation in Vectorization pass This PR should be mostly RFC for existing working cases. It simplifies the retrieval of the vector sizes by removing the `getCanonicalVectorShape` and only looking at the lowering config of the operation to be vectorized and no longer looking at the lowering config of another (root) operation. It also gives priority to the lowering config found in the operation to be vectorized and falls back to retrieving it from the IR when the lowering config is not found. --- .../Codegen/Common/GenericVectorization.cpp | 104 ++++-------------- 1 file changed, 24 insertions(+), 80 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index 4a6359a46b2f..1d1683edb42d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Common/PassDetail.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/LLVMCPU/TileSizeSelection.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" @@ -71,46 +72,6 @@ static FailureOr getRootOp(Operation *op) { return rootOp; } -/// Computes the canonical shape used to vectorize this dispatch. Retrieves -/// the vectorization tile sizes (parallel and reduction levels) out of the -/// lowering config and adjusts them to the format expected by the Linalg -/// vectorizer. 
-static SmallVector getCanonicalVectorShape(func::FuncOp funcOp) { - FailureOr rootOp = getRootOp(funcOp); - if (failed(rootOp)) { - return {}; - } - - unsigned numTileLevels = - mlir::iree_compiler::getNumTileLevels(rootOp.value()); - if (numTileLevels < 3) { - return {}; - } - - // Retrieve the tile sizes from the last two tiling levels (parallel and - // reduction) used for vectorization. - SmallVector canonicalVectorShape = - mlir::iree_compiler::getTileSizes(rootOp.value(), numTileLevels - 2); - SmallVector reductionTileSizes = - mlir::iree_compiler::getTileSizes(rootOp.value(), numTileLevels - 1); - - if (!reductionTileSizes.empty()) { - assert(canonicalVectorShape.size() == reductionTileSizes.size() && - "Unexpected tile sizes"); - - // Combine the reduction tile sizes with the parallel tile sizes already in - // the canonical vector shape. - for (int i = 0, end = canonicalVectorShape.size(); i < end; ++i) { - if (reductionTileSizes[i] > 0) - canonicalVectorShape[i] = reductionTileSizes[i]; - } - } - - // Replace zeros in canonical vector shape to turn it into a valid shape. - std::replace(canonicalVectorShape.begin(), canonicalVectorShape.end(), 0, 1); - return canonicalVectorShape; -} - /// Tries to infer the vector sizes from an IR using ValueBounds analysis. /// Returns failure if vector sizes can't be inferred. static FailureOr> @@ -169,46 +130,34 @@ inferVectorSizesFromIR(linalg::LinalgOp linalgOp) { return vectorSizes; } -// Give the canonical vector shape of a dispatch, returns the vector sizes for a -// particular linalg op within that dispatch. -static SmallVector -getVectorSizes(linalg::LinalgOp linalgOp, - ArrayRef canonicalVectorShape) { - // Try to infer the vector sizes from the IR. If it fails, try to get them - // from the lowering config. 
- auto inferredVectorSizes = inferVectorSizesFromIR(linalgOp); - if (succeeded(inferredVectorSizes)) { - return *inferredVectorSizes; - } - - FailureOr rootOp = getRootOp(linalgOp); - if (failed(rootOp)) { - return {}; - } +// Returns the vector sizes to vectorize a linalg operation. We try to retrieve +// them from its `lowering_config`, if available. Otherwise, we try to infer +// them from the tiled loops in the IR. +static SmallVector getVectorSizes(linalg::LinalgOp linalgOp) { + auto loweringConfig = iree_compiler::getLoweringConfig(linalgOp); + // Give priority to the operation's lowering config. + if (loweringConfig) { + TilingConfig tilingConfig(loweringConfig); + SmallVector vectorShape = tilingConfig.getVectorTileSizes(); - // TODO: Infer the tiles sizes for an op that is not the root op. - if (*rootOp != linalgOp.getOperation()) { - return {}; - } + // Replace zeros in vector shape to turn it into a valid vector shape. + std::replace(vectorShape.begin(), vectorShape.end(), 0, 1); - if (canonicalVectorShape.empty()) { - return {}; + LLVM_DEBUG(VEC_DBGS() << "Using vector sizes from 'lowering_config'\n"); + return vectorShape; } - assert(canonicalVectorShape.size() >= linalgOp.getNumLoops() && - "Unexpected canonical vector shape or number of loops"); - - // Return the valid canonical vector shape subset based on the number of loops - // of the linalg op. - SmallVector vecSize( - canonicalVectorShape.take_front(linalgOp.getNumLoops())); - for (auto [idx, val] : llvm::enumerate(linalgOp.getStaticLoopRanges())) { - if (ShapedType::isDynamic(val)) - continue; - vecSize[idx] = std::max(vecSize[idx], val); + // Try to infer the vector sizes from the IR. If it fails, we can't vectorize + // this op. + auto inferredVectorSizes = inferVectorSizesFromIR(linalgOp); + if (succeeded(inferredVectorSizes)) { + return *inferredVectorSizes; } - return vecSize; + // We couldn't infer the vector sizes for this op so we return all the vector + // sizes set to zero. 
+ LLVM_DEBUG(VEC_DBGS() << "Couldn't infer vector sizes\n"); + return SmallVector(linalgOp.getNumLoops(), 0); } class GenericVectorizationPass @@ -231,11 +180,6 @@ class GenericVectorizationPass void GenericVectorizationPass::runOnOperation() { MLIRContext *context = &getContext(); auto funcOp = getOperation(); - SmallVector canonicalVectorShape; - if (enableVectorMasking) { - canonicalVectorShape = getCanonicalVectorShape(funcOp); - } - IRRewriter rewriter(context); SmallVector candidates; funcOp.walk([&](Operation *op) { @@ -248,7 +192,7 @@ void GenericVectorizationPass::runOnOperation() { SmallVector vectorSizes; if (enableVectorMasking) { if (auto linalgOp = dyn_cast(op)) { - vectorSizes.append(getVectorSizes(linalgOp, canonicalVectorShape)); + vectorSizes = getVectorSizes(linalgOp); } else if (auto padOp = dyn_cast(op)) { auto ty = padOp.getResultType(); // TODO(hanchung): Infer the vector sizes for pad op after