Skip to content

Commit

Permalink
Fine tune tile sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 23, 2023
1 parent f7cad21 commit 1c4e61a
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,14 @@ static SmallVector<int64_t> getPackVectorTileSizes(func::FuncOp entryPointFn,
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
tensor::PackOp op) {
assert(!getLoweringConfig(op) && "expected lowering_config is not set");

auto tilingOp = cast<TilingInterface>(op.getOperation());
unsigned numLoops = tilingOp.getLoopIteratorTypes().size();
SmallVector<int64_t> minTileSizes(numLoops, 1);
SmallVector<int64_t> maxTileSizes(numLoops, defaultDistTileSize);

SmallVector<int64_t> distTileSizes =
getDefaultDistributionTileSizes(cast<TilingInterface>(op.getOperation()));
getDefaultDistributedLevelTileSizes(tilingOp, minTileSizes, maxTileSizes);

// The default function aims to return the number of workloads per workgroup,
// but it does not know that it is working on packed domain. We need to take
Expand Down Expand Up @@ -1302,7 +1308,7 @@ setUnPackOpRootConfig(func::FuncOp entryPointFn, tensor::UnPackOp op,
auto tilingOp = cast<TilingInterface>(op.getOperation());
unsigned numLoops = tilingOp.getLoopIteratorTypes().size();
SmallVector<int64_t> minTileSizes(numLoops, 1);
SmallVector<int64_t> maxTileSizes(numLoops, defaultDistTileSize);
SmallVector<int64_t> maxTileSizes(numLoops, defaultDistTileSize * 2);

if (numLoops > 2) {
for (unsigned i = 0; i < numLoops - 2; i++) {
Expand Down Expand Up @@ -1600,7 +1606,10 @@ static LogicalResult setElementwiseGenericOpRootConfig(

SmallVector<int64_t> minTileSizes = getMinTilingSizesForEachDim(
entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo);
SmallVector<int64_t> maxTileSizes(numLoops, defaultDistTileSize);
SmallVector<int64_t> maxTileSizes(numLoops, defaultDistTileSize * 2);
if (numLoops > 2) {
maxTileSizes[0] = 1;
}
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(genericOp, minTileSizes, maxTileSizes,
/*allowIncompleteTile=*/true);
Expand Down

0 comments on commit 1c4e61a

Please sign in to comment.