Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Improve vector size selection in vectorization pass #14403

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 24 additions & 80 deletions compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/TileSizeSelection.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
Expand Down Expand Up @@ -71,46 +72,6 @@ static FailureOr<Operation *> getRootOp(Operation *op) {
return rootOp;
}

/// Computes the canonical shape used to vectorize this dispatch. Reads the
/// vectorization tile sizes (the last two tiling levels: parallel and
/// reduction) out of the root op's lowering config and massages them into the
/// form the Linalg vectorizer expects. Returns an empty vector when no root op
/// or too few tiling levels are available.
static SmallVector<int64_t> getCanonicalVectorShape(func::FuncOp funcOp) {
  FailureOr<Operation *> root = getRootOp(funcOp);
  if (failed(root))
    return {};

  unsigned tileLevelCount = mlir::iree_compiler::getNumTileLevels(root.value());
  // Need at least distribution + parallel + reduction levels to proceed.
  if (tileLevelCount < 3)
    return {};

  // The second-to-last level carries the parallel tile sizes and seeds the
  // canonical shape; the last level carries the reduction tile sizes.
  SmallVector<int64_t> vectorShape =
      mlir::iree_compiler::getTileSizes(root.value(), tileLevelCount - 2);
  SmallVector<int64_t> reductionSizes =
      mlir::iree_compiler::getTileSizes(root.value(), tileLevelCount - 1);

  if (!reductionSizes.empty()) {
    assert(vectorShape.size() == reductionSizes.size() &&
           "Unexpected tile sizes");

    // Fold the reduction tile sizes into the parallel ones: a positive
    // reduction size overrides the corresponding parallel entry.
    for (size_t dim = 0, numDims = vectorShape.size(); dim < numDims; ++dim) {
      if (reductionSizes[dim] > 0)
        vectorShape[dim] = reductionSizes[dim];
    }
  }

  // A zero entry would be an invalid vector dimension; promote zeros to 1 so
  // the result is a valid shape.
  for (int64_t &dimSize : vectorShape) {
    if (dimSize == 0)
      dimSize = 1;
  }
  return vectorShape;
}

/// Tries to infer the vector sizes from an IR using ValueBounds analysis.
/// Returns failure if vector sizes can't be inferred.
static FailureOr<SmallVector<int64_t>>
Expand Down Expand Up @@ -169,46 +130,34 @@ inferVectorSizesFromIR(linalg::LinalgOp linalgOp) {
return vectorSizes;
}

// Returns the vector sizes to vectorize a linalg operation. We try to retrieve
// them from its `lowering_config`, if available. Otherwise, we try to infer
// them from the tiled loops in the IR.
static SmallVector<int64_t> getVectorSizes(linalg::LinalgOp linalgOp) {
  auto loweringConfig = iree_compiler::getLoweringConfig(linalgOp);
  // Give priority to the operation's lowering config.
  if (loweringConfig) {
    TilingConfig tilingConfig(loweringConfig);
    SmallVector<int64_t> vectorShape = tilingConfig.getVectorTileSizes();

    // Replace zeros in vector shape to turn it into a valid vector shape.
    std::replace(vectorShape.begin(), vectorShape.end(), 0, 1);

    LLVM_DEBUG(VEC_DBGS() << "Using vector sizes from 'lowering_config'\n");
    return vectorShape;
  }

  // Try to infer the vector sizes from the IR. If it fails, we can't vectorize
  // this op.
  auto inferredVectorSizes = inferVectorSizesFromIR(linalgOp);
  if (succeeded(inferredVectorSizes)) {
    return *inferredVectorSizes;
  }

  // We couldn't infer the vector sizes for this op so we return all the vector
  // sizes set to zero.
  LLVM_DEBUG(VEC_DBGS() << "Couldn't infer vector sizes\n");
  return SmallVector<int64_t>(linalgOp.getNumLoops(), 0);
}

class GenericVectorizationPass
Expand All @@ -231,11 +180,6 @@ class GenericVectorizationPass
void GenericVectorizationPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
SmallVector<int64_t> canonicalVectorShape;
if (enableVectorMasking) {
canonicalVectorShape = getCanonicalVectorShape(funcOp);
}

IRRewriter rewriter(context);
SmallVector<Operation *> candidates;
funcOp.walk([&](Operation *op) {
Expand All @@ -248,7 +192,7 @@ void GenericVectorizationPass::runOnOperation() {
SmallVector<int64_t> vectorSizes;
if (enableVectorMasking) {
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
vectorSizes.append(getVectorSizes(linalgOp, canonicalVectorShape));
vectorSizes = getVectorSizes(linalgOp);
} else if (auto padOp = dyn_cast<tensor::PadOp>(op)) {
auto ty = padOp.getResultType();
// TODO(hanchung): Infer the vector sizes for pad op after
Expand Down
Loading