// Patch summary (leading text before the first "diff --git" header; not part
// of any hunk):
//
//  * Adds an include of "iree/compiler/Codegen/LLVMCPU/TileSizeSelection.h".
//  * Removes getCanonicalVectorShape(), which combined the tile sizes of the
//    last two tiling levels of the root op into one per-function shape.
//  * Replaces getVectorSizes(linalgOp, canonicalVectorShape) with
//    getVectorSizes(linalgOp). The new version:
//      1. if the op has a lowering config, builds a TilingConfig from it,
//         takes getVectorTileSizes(), replaces zeros with ones and returns it;
//      2. otherwise returns the result of inferVectorSizesFromIR() when that
//         succeeds;
//      3. otherwise returns linalgOp.getNumLoops() entries, all set to zero.
//  * runOnOperation() no longer computes a canonical vector shape up front and
//    now assigns `vectorSizes = getVectorSizes(linalgOp)` instead of appending.
//
// NOTE(review): this copy of the patch appears to have lost its template
// argument lists (e.g. after SmallVector / FailureOr / dyn_cast), and the
// indentation and blank context lines below were reconstructed from the hunk
// line counts -- confirm against the original commit before applying.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index 4a6359a46b2f..1d1683edb42d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Codegen/Common/PassDetail.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/LLVMCPU/TileSizeSelection.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
@@ -71,46 +72,6 @@ static FailureOr getRootOp(Operation *op) {
   return rootOp;
 }
 
-/// Computes the canonical shape used to vectorize this dispatch. Retrieves
-/// the vectorization tile sizes (parallel and reduction levels) out of the
-/// lowering config and adjusts them to the format expected by the Linalg
-/// vectorizer.
-static SmallVector getCanonicalVectorShape(func::FuncOp funcOp) {
-  FailureOr rootOp = getRootOp(funcOp);
-  if (failed(rootOp)) {
-    return {};
-  }
-
-  unsigned numTileLevels =
-      mlir::iree_compiler::getNumTileLevels(rootOp.value());
-  if (numTileLevels < 3) {
-    return {};
-  }
-
-  // Retrieve the tile sizes from the last two tiling levels (parallel and
-  // reduction) used for vectorization.
-  SmallVector canonicalVectorShape =
-      mlir::iree_compiler::getTileSizes(rootOp.value(), numTileLevels - 2);
-  SmallVector reductionTileSizes =
-      mlir::iree_compiler::getTileSizes(rootOp.value(), numTileLevels - 1);
-
-  if (!reductionTileSizes.empty()) {
-    assert(canonicalVectorShape.size() == reductionTileSizes.size() &&
-           "Unexpected tile sizes");
-
-    // Combine the reduction tile sizes with the parallel tile sizes already in
-    // the canonical vector shape.
-    for (int i = 0, end = canonicalVectorShape.size(); i < end; ++i) {
-      if (reductionTileSizes[i] > 0)
-        canonicalVectorShape[i] = reductionTileSizes[i];
-    }
-  }
-
-  // Replace zeros in canonical vector shape to turn it into a valid shape.
-  std::replace(canonicalVectorShape.begin(), canonicalVectorShape.end(), 0, 1);
-  return canonicalVectorShape;
-}
-
 /// Tries to infer the vector sizes from an IR using ValueBounds analysis.
 /// Returns failure if vector sizes can't be inferred.
 static FailureOr>
@@ -169,46 +130,34 @@ inferVectorSizesFromIR(linalg::LinalgOp linalgOp) {
   return vectorSizes;
 }
 
-// Give the canonical vector shape of a dispatch, returns the vector sizes for a
-// particular linalg op within that dispatch.
-static SmallVector
-getVectorSizes(linalg::LinalgOp linalgOp,
-               ArrayRef canonicalVectorShape) {
-  // Try to infer the vector sizes from the IR. If it fails, try to get them
-  // from the lowering config.
-  auto inferredVectorSizes = inferVectorSizesFromIR(linalgOp);
-  if (succeeded(inferredVectorSizes)) {
-    return *inferredVectorSizes;
-  }
-
-  FailureOr rootOp = getRootOp(linalgOp);
-  if (failed(rootOp)) {
-    return {};
-  }
+// Returns the vector sizes to vectorize a linalg operation. We try to retrieve
+// them from its `lowering_config`, if available. Otherwise, we try to infer
+// them from the tiled loops in the IR.
+static SmallVector getVectorSizes(linalg::LinalgOp linalgOp) {
+  auto loweringConfig = iree_compiler::getLoweringConfig(linalgOp);
+  // Give priority to the operation's lowering config.
+  if (loweringConfig) {
+    TilingConfig tilingConfig(loweringConfig);
+    SmallVector vectorShape = tilingConfig.getVectorTileSizes();
 
-  // TODO: Infer the tiles sizes for an op that is not the root op.
-  if (*rootOp != linalgOp.getOperation()) {
-    return {};
-  }
+    // Replace zeros in vector shape to turn it into a valid vector shape.
+    std::replace(vectorShape.begin(), vectorShape.end(), 0, 1);
 
-  if (canonicalVectorShape.empty()) {
-    return {};
+    LLVM_DEBUG(VEC_DBGS() << "Using vector sizes from 'lowering_config'\n");
+    return vectorShape;
   }
 
-  assert(canonicalVectorShape.size() >= linalgOp.getNumLoops() &&
-         "Unexpected canonical vector shape or number of loops");
-
-  // Return the valid canonical vector shape subset based on the number of loops
-  // of the linalg op.
-  SmallVector vecSize(
-      canonicalVectorShape.take_front(linalgOp.getNumLoops()));
-  for (auto [idx, val] : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (ShapedType::isDynamic(val))
-      continue;
-    vecSize[idx] = std::max(vecSize[idx], val);
+  // Try to infer the vector sizes from the IR. If it fails, we can't vectorize
+  // this op.
+  auto inferredVectorSizes = inferVectorSizesFromIR(linalgOp);
+  if (succeeded(inferredVectorSizes)) {
+    return *inferredVectorSizes;
   }
 
-  return vecSize;
+  // We couldn't infer the vector sizes for this op so we return all the vector
+  // sizes set to zero.
+  LLVM_DEBUG(VEC_DBGS() << "Couldn't infer vector sizes\n");
+  return SmallVector(linalgOp.getNumLoops(), 0);
 }
 
 class GenericVectorizationPass
@@ -231,11 +180,6 @@ class GenericVectorizationPass
 void GenericVectorizationPass::runOnOperation() {
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
-  SmallVector canonicalVectorShape;
-  if (enableVectorMasking) {
-    canonicalVectorShape = getCanonicalVectorShape(funcOp);
-  }
-
   IRRewriter rewriter(context);
   SmallVector candidates;
   funcOp.walk([&](Operation *op) {
@@ -248,7 +192,7 @@ void GenericVectorizationPass::runOnOperation() {
     SmallVector vectorSizes;
     if (enableVectorMasking) {
       if (auto linalgOp = dyn_cast(op)) {
-        vectorSizes.append(getVectorSizes(linalgOp, canonicalVectorShape));
+        vectorSizes = getVectorSizes(linalgOp);
       } else if (auto padOp = dyn_cast(op)) {
         auto ty = padOp.getResultType();
         // TODO(hanchung): Infer the vector sizes for pad op after