diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 515bb3bea75d..a3ad95f80049 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -35,6 +35,7 @@ iree_compiler_cc_library(
         "CleanupTensorShapes.cpp",
         "CloneProducersIntoDispatchRegions.cpp",
         "CollapseDimensions.cpp",
+        "CollapseReductionDims.cpp",
         "Convert1X1FilterConv2DToMatmul.cpp",
         "ConvertRegionToWorkgroups.cpp",
         "ConvertToFlow.cpp",
@@ -99,7 +100,6 @@ iree_compiler_cc_library(
        "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineDialect",
-        "@llvm-project//mlir:AffineUtils",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithUtils",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index bab9e87869b9..c3b86ccf1eef 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -34,6 +34,7 @@ iree_cc_library(
     "CleanupTensorShapes.cpp"
     "CloneProducersIntoDispatchRegions.cpp"
     "CollapseDimensions.cpp"
+    "CollapseReductionDims.cpp"
     "Convert1X1FilterConv2DToMatmul.cpp"
     "ConvertRegionToWorkgroups.cpp"
     "ConvertToFlow.cpp"
@@ -81,7 +82,6 @@ iree_cc_library(
     IREELinalgTransformDialect
     LLVMSupport
     MLIRAffineDialect
-    MLIRAffineUtils
     MLIRAnalysis
     MLIRArithDialect
     MLIRArithUtils
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
index 19a0a43f7325..6b0445e60ac5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
@@ -12,8 +12,6 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -29,8 +27,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#include <deque>
-
 #define DEBUG_TYPE "iree-flow-collapse-dimensions"
 
 namespace mlir {
 namespace iree_compiler {
@@ -52,88 +48,47 @@ struct CollapseDimensionsPass
 
 /// Searches the same sequence in all the affine maps and collapses these
 /// dimensions. It only applies these to "parallel" loops without mixing them
-/// with "reduction" types. It is expected that the `genericOp` has projected
-/// permutations only as indexing maps. (Checked using `isEligibleForCollapse`).
+/// with "reduction" types.
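+/// For example, given maps (d0, d1, d2) -> (d0, d1, d2) and
+/// (d0, d1, d2) -> (d0, d1) with all loops parallel, only [d0, d1] is
+/// collapsible: d2 is absent from the second map, so no sequence containing
+/// it appears in every map.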
 static SmallVector<ReassociationIndices>
 getCollapsibleLoops(linalg::GenericOp genericOp) {
   SmallVector<ReassociationIndices> contiguousLoops;
 
-  SmallVector<unsigned> pDims, rDims;
+  SmallVector<unsigned> pDims;
   genericOp.getParallelDims(pDims);
-  genericOp.getReductionDims(rDims);
-  llvm::SmallDenseSet<unsigned> pDimsSet, rDimsSet;
-  pDimsSet.insert(pDims.begin(), pDims.end());
-  rDimsSet.insert(rDims.begin(), rDims.end());
+  if (pDims.size() < 2)
+    return contiguousLoops;
+
+  llvm::SmallDenseSet<unsigned> pLoops(pDims.begin(), pDims.end());
 
   auto hasAllMapsSameSequence = [&](AffineExpr preExpr, AffineExpr nextExpr) {
-    // Check that all indexing maps of the `genericOp`
-    // - Either both `preExpr` and `nextExpr` contiguous, or
-    // - are missing in
-    // Then `preExpr` and `nextExpr` can be collapsed.
     for (AffineMap map : genericOp.getIndexingMapsArray()) {
-      // If map has no results, no need to check.
-      if (map.getNumResults() == 0) {
-        continue;
-      }
+      bool foundSeq = false;
       for (auto [index, resultExpr] : llvm::enumerate(map.getResults())) {
-        // If we find the preExpr, we should find the nextExpr.
-        if (resultExpr == preExpr) {
-          if (index == map.getNumResults() - 1) {
-            // Reached end of list. Return false;
-            return false;
-          }
-          if (map.getResult(index + 1) != nextExpr) {
-            return false;
-          }
-        }
-        // If we find nextExpr the previous one should be `prevExpr`.
-        // This is redundant check for the most part, but is cheap enough, so
-        // #YOLO
         if (resultExpr == nextExpr) {
-          if (index == 0) {
-            // match at beginning of the list. Return false;
-            return false;
-          }
-          if (map.getResult(index - 1) != preExpr) {
-            return false;
-          }
+          foundSeq = (index > 0 && preExpr == map.getResult(index - 1));
+          break;
         }
       }
+      if (!foundSeq)
+        return false;
     }
     return true;
   };
-  auto hasSameIteratorType = [&](AffineExpr preExpr, AffineExpr nextExpr) {
-    unsigned prePos = preExpr.cast<AffineDimExpr>().getPosition();
-    unsigned nextPos = nextExpr.cast<AffineDimExpr>().getPosition();
-    return (pDimsSet.count(prePos) && pDimsSet.count(nextPos)) ||
-           (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
-  };
 
   ReassociationIndices range;
   AffineExpr preExpr;
-  // Find the largest sequence of dimensions that are
-  // - Either preserved in all maps, or
-  // - are completely absent
-  // This sequence can be collapsed. To find the sequence,
-  // 1) Take the result expressions of one of the indexing maps
-  // 2) Find a sequence of 2 that is found in all maps
-  // 3) Then take last element of this sequence and the next
-  //    result expression, and check if this sequence of 2 is
-  //    found in all maps. If so, add to sequence (to get a sequence of 3)
-  //    and repeat till the last element of sequence and the next result
-  //    expression is not found as a sequence in all maps.
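+  // Walk the results of the first indexing map, growing a run of parallel
+  // dims that stay adjacent in every map; the run is flushed whenever the
+  // sequence or the parallelism is broken.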
   for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
+    unsigned pos = nextExpr.cast<AffineDimExpr>().getPosition();
     if (!range.empty()) {
-      if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
-          !hasSameIteratorType(preExpr, nextExpr)) {
-        if (range.size() > 1) {
+      if (!hasAllMapsSameSequence(preExpr, nextExpr) || !pLoops.count(pos)) {
+        if (range.size() > 1)
           contiguousLoops.push_back({range.begin(), range.end()});
-        }
         range.clear();
       }
     }
-    range.push_back(nextExpr.cast<AffineDimExpr>().getPosition());
     preExpr = nextExpr;
+    if (pLoops.count(pos))
+      range.push_back(pos);
   }
   if (range.size() > 1)
     contiguousLoops.push_back(range);
@@ -152,6 +107,22 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
   return contiguousLoops;
 }
 
+/// Collapse the possible dimensions of the given linalg.generic op.
+static FailureOr<SmallVector<Value>>
+collapseLinalgGeneric(IRRewriter &rewriter, linalg::GenericOp genericOp,
+                      SmallVector<ReassociationIndices> &collapseIndices) {
+  rewriter.setInsertionPoint(genericOp->getParentOp());
+  FailureOr<SmallVector<Value>> replacements =
+      mlir::linalg::collapseGenericOpIterationDims(genericOp, collapseIndices,
+                                                   rewriter);
+  if (failed(replacements) || replacements->empty()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "failed to collapse dimensions");
+  }
+
+  return replacements;
+}
+
 /// Returns true if the given op is collapsable.
 static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
   // TODO(guray) There is no mechanism to tell the collapsed indexes to
@@ -183,298 +154,101 @@ static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
 /// without any producers.
 static FailureOr<linalg::GenericOp>
 findRootGenericOp(DispatchRegionOp regionOp) {
-  if (!llvm::hasSingleElement(regionOp.getBody())) {
-    return failure();
+  SmallVector<Operation *> computeOps;
+  auto &ops = regionOp.getBody().front().getOperations();
+  for (Operation &op : ops) {
+    if (isa<TilingInterface>(op))
+      computeOps.push_back(&op);
   }
-
-  // Check the yielded value is from a single `linalg.generic`.
-  auto returnOp =
-      cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
-  auto collapsibleOp = dyn_cast_or_null<linalg::GenericOp>(
-      returnOp->getOperand(0).getDefiningOp());
-  if (!collapsibleOp) {
+  // Looking for a root op without a producer.
+  if (computeOps.size() != 1 || ops.size() != 2)
     return failure();
-  }
-  for (auto returnVal : returnOp->getOperands().drop_front()) {
-    if (returnVal.getDefiningOp() != collapsibleOp.getOperation()) {
-      return failure();
-    }
-  }
-
-  // Check that the operands of the generic op are defined outside the dispatch.
-  for (OpOperand *inputOperands : collapsibleOp.getDpsInputOperands()) {
-    Operation *definingOp = inputOperands->get().getDefiningOp();
-    if (definingOp &&
-        definingOp->getParentOfType<DispatchRegionOp>() == regionOp) {
-      return failure();
-    }
-  }
-
-  // Check that the output is either a `tensor.empty` or a `linalg.fill` op by
-  // traversing the operations that define the `init` operands of the
-  // `collapsibleOp`.
-  std::deque<Operation *> worklist;
-  llvm::SmallDenseSet<Operation *> visited;
-  auto addDefiningOpToWorklist = [&](Value v) {
-    Operation *definingOp = v.getDefiningOp();
-    if (definingOp &&
-        definingOp->getParentOfType<DispatchRegionOp>() == regionOp &&
-        !visited.count(definingOp)) {
-      worklist.push_back(definingOp);
-      visited.insert(definingOp);
-    }
-  };
-  for (Value initOperand : collapsibleOp.getDpsInits()) {
-    addDefiningOpToWorklist(initOperand);
-  }
-
-  while (!worklist.empty()) {
-    Operation *op = worklist.front();
-    worklist.pop_front();
-    if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
-      addDefiningOpToWorklist(fillOp.getDpsInitOperand(0)->get());
-      continue;
-    }
-    if (isa<tensor::EmptyOp>(op)) {
-      continue;
-    }
+  auto genericOp = llvm::dyn_cast<linalg::GenericOp>(computeOps.front());
+  if (!genericOp)
     return failure();
-  }
-  return collapsibleOp;
+  return genericOp;
 }
 
-/// Hoist `tensor.collapse_shape` ops at the beginning of the `dispatchOp`
-/// and `tensor.expand_shape` ops at the end of the `dispatchOp`, out of the
-/// dispatch.
-static FailureOr<DispatchRegionOp>
-hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
-                                       DispatchRegionOp dispatchOp) {
-  // Only do this for `dispatchOp` with a single operation.
-  if (!llvm::hasSingleElement(dispatchOp.getBody())) {
-    return failure();
-  }
-  Block &body = dispatchOp.getBody().front();
-  auto returnOp = cast<Flow::ReturnOp>(body.getTerminator());
-
-  // 1. Get the slice of operations within `dispatchOp` that produce the yielded
-  //    value.
-  BackwardSliceOptions sliceOptions;
-  sliceOptions.filter = [&](Operation *op) {
-    return op->getParentOfType<DispatchRegionOp>();
-  };
-  SetVector<Operation *> slice;
-  getBackwardSlice(returnOp, &slice, sliceOptions);
-
-  // 2. Get the leaf operations that are tensor.collapse_shape ops.
-  SmallVector<tensor::CollapseShapeOp> leafs;
-  for (Operation *op : slice) {
-    auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op);
-    if (!collapseShapeOp) {
-      continue;
-    }
-    if (llvm::all_of(op->getOperands(), [&](Value operand) {
-          Operation *definingOp = operand.getDefiningOp();
-          return !definingOp || slice.count(definingOp) == 0;
-        })) {
-      leafs.push_back(collapseShapeOp);
-    }
-  }
-
-  // 3. Clone the leaf `tensor.collapse_shape` ops outside the dispatch.
+/// Generate a new dispatch.region and workload for the collapsed
+/// linalg.generic op.
+static LogicalResult
+generateNewDispatchRegion(IRRewriter &rewriter, DispatchRegionOp regionOp,
+                          SmallVector<Value> collapseResults,
+                          linalg::GenericOp newGenericOp) {
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(dispatchOp);
-  for (auto reshapeOp : leafs) {
-    Operation *clonedOp = rewriter.clone(*reshapeOp.getOperation());
-    rewriter.replaceOp(reshapeOp, clonedOp->getResults());
-  }
-
-  // 4. From the yielded values find any that are produced by
-  //    `tensor.expand_shape` operation and move them out of the dispatch. For
-  //    this a new `DispatchRegionOp` is needed. For values that are yielded and
-  //    produced from `tensor.expand_shape`, the type of the result changes. The
-  //    dynamic dimensions of the result type also need to be updated.
-  SmallVector<Type> newReturnTypes;
-  SmallVector<Value> newDynamicDims;
-  SmallVector<Value> newYieldVals;
-  SmallVector<SmallVector<ReassociationIndices>> allReassociationIndices;
-  ValueRange dynamicDimsList = dispatchOp.getResultDims();
-  Location loc = dispatchOp.getLoc();
-  for (Value yieldedValue : returnOp->getOperands()) {
-    auto expandShapeOp = yieldedValue.getDefiningOp<tensor::ExpandShapeOp>();
-    if (!expandShapeOp) {
-      // 4a. Keep the same yield value if the producer is not a
-      //     `tensor.expand_shape` op.
-      newReturnTypes.push_back(yieldedValue.getType());
-      newYieldVals.push_back(yieldedValue);
-      continue;
-    }
+  rewriter.setInsertionPoint(regionOp->getParentOp());
 
-    // 4b. The return type is same as the type of the source of the
-    //     `tensor.expand_shape`.
-    RankedTensorType collapsedShapeType = expandShapeOp.getSrcType();
-    newReturnTypes.push_back(collapsedShapeType);
-    newYieldVals.push_back(expandShapeOp.getSrc());
-    SmallVector<ReassociationIndices> reassociation =
-        expandShapeOp.getReassociationIndices();
-    ArrayRef<int64_t> expandedShape = expandShapeOp.getResultType().getShape();
-
-    // 4c. Dynamic dims of the result shape is obtained by taking the static
-    //     shape + dynamic dims and collapsing them using the same reassociation
-    //     map as the `tensor.expand_shape`.
-    for (auto [index, shape] : llvm::enumerate(collapsedShapeType.getShape())) {
-      int64_t staticCollapsedShape = 1;
-      SmallVector<Value> dynamicCollapsedDims;
-      for (auto collapsedDim : reassociation[index]) {
-        if (expandedShape[collapsedDim] == ShapedType::kDynamic) {
-          dynamicCollapsedDims.push_back(dynamicDimsList.front());
-          dynamicDimsList = dynamicDimsList.drop_front();
-        } else {
-          staticCollapsedShape *= expandedShape[collapsedDim];
-        }
-      }
-
-      if (dynamicCollapsedDims.empty()) {
-        // If there are no dynamic dims, there is nothing to do.
-        continue;
-      }
-      SmallVector<AffineExpr> exprs(dynamicCollapsedDims.size());
-      bindSymbolsList(rewriter.getContext(),
-                      MutableArrayRef(exprs));
-      AffineExpr multiplyAll = exprs.front();
-      for (auto expr : ArrayRef(exprs).drop_front()) {
-        multiplyAll = multiplyAll * expr;
-      }
-      if (staticCollapsedShape != 1) {
-        multiplyAll = multiplyAll * staticCollapsedShape;
-      }
-      OpFoldResult collapsedShape = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, multiplyAll, dynamicCollapsedDims);
-      newDynamicDims.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, collapsedShape));
-    }
-    allReassociationIndices.emplace_back(std::move(reassociation));
-  }
-
-  // 5. Create the new dispatch op.
-  auto newDispatchOp = rewriter.create<DispatchRegionOp>(
-      loc, newReturnTypes, newDynamicDims, dispatchOp.getWorkload());
-
-  // 5a. Move the body over, but replace the `flow.return` to use the new yield
-  //     values.
-  Region &newBody = newDispatchOp.getBody();
-  rewriter.inlineRegionBefore(dispatchOp.getBody(), newBody, newBody.begin());
-  {
-    Operation *terminator = newBody.front().getTerminator();
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(terminator);
-    rewriter.replaceOpWithNewOp<Flow::ReturnOp>(terminator, newYieldVals);
-  }
+  auto maybeRegionOp = Flow::wrapOpInDispatchRegion(rewriter, newGenericOp);
+  if (failed(maybeRegionOp))
+    return failure();
 
-  // 5b. Move the workgroup count region over.
-  Region &workgroupCountRegion = dispatchOp.getWorkgroupCount();
-  if (!workgroupCountRegion.empty()) {
-    Region &newWorkgroupCountRegion = newDispatchOp.getWorkgroupCount();
-    rewriter.inlineRegionBefore(workgroupCountRegion, newWorkgroupCountRegion,
-                                newWorkgroupCountRegion.begin());
-  }
+  // Replace the old regionOp with the results of the collapse.
+  rewriter.replaceOp(regionOp, collapseResults);
 
-  // 6. Map the modified result values back to their original shape using
-  //    `tensor.expand_shape` operations.
-  ArrayRef<SmallVector<ReassociationIndices>> allReassociationIndicesRef(
-      allReassociationIndices);
-  for (auto [index, returnValue] :
-       llvm::enumerate(newDispatchOp.getResults())) {
-    Value origResult = dispatchOp->getResult(index);
-    if (returnValue.getType() == origResult.getType()) {
-      rewriter.replaceAllUsesWith(origResult, returnValue);
-      continue;
-    }
-    auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc, origResult.getType(), returnValue,
-        allReassociationIndicesRef.front());
-    allReassociationIndicesRef = allReassociationIndicesRef.drop_front();
-    rewriter.replaceAllUsesWith(origResult, newExpandShapeOp.getResult());
-  }
-  rewriter.eraseOp(dispatchOp);
-  return newDispatchOp;
+  return success();
 }
 
 /// Traverses DispatchRegionOps to find linalg genericOps that has no
 /// producers and tries to collapse its dimensions.
-static bool collapseDimensions(IRRewriter &rewriter,
-                               DispatchRegionOp &regionOp) {
+static LogicalResult collapseDimensions(IRRewriter &rewriter,
                                         DispatchRegionOp &regionOp) {
   // Step 1. Find the root linalg.generic Op with no producer
   std::optional<linalg::GenericOp> genericOp = findRootGenericOp(regionOp);
   if (!genericOp.has_value())
-    return false;
+    return success();
 
   // Step 2. Check whether it is possible to collapse
   if (!isEligibleForCollapse(genericOp.value()))
-    return false;
+    return success();
   SmallVector<ReassociationIndices> collapseIndices;
   collapseIndices = getCollapsibleLoops(genericOp.value());
   if (collapseIndices.empty())
-    return false;
+    return success();
 
   // Step 3. Collapse dimensions
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(genericOp.value());
-
-  FailureOr<SmallVector<Value>> maybeReplacements =
-      mlir::linalg::collapseGenericOpIterationDims(genericOp.value(),
-                                                   collapseIndices, rewriter);
+  auto maybeReplacements =
+      collapseLinalgGeneric(rewriter, genericOp.value(), collapseIndices);
   if (failed(maybeReplacements))
-    return false;
-  rewriter.replaceOp(genericOp.value(), maybeReplacements.value());
-  return true;
+    return failure();
+  auto expandshapeOp =
+      maybeReplacements->front().getDefiningOp<tensor::ExpandShapeOp>();
+  if (!expandshapeOp)
+    return failure();
+  auto newGenericOp =
+      expandshapeOp.getOperand().getDefiningOp<linalg::GenericOp>();
+  if (!newGenericOp)
+    return failure();
+
+  // Step 4. Generate the new dispatch region and replace the old one's users.
+  if (failed(generateNewDispatchRegion(rewriter, regionOp, *maybeReplacements,
+                                       newGenericOp)))
+    return failure();
+
+  return success();
 }
 
 void CollapseDimensionsPass::runOnOperation() {
   mlir::FunctionOpInterface funcOp = getOperation();
-  MLIRContext *context = funcOp->getContext();
-  IRRewriter rewriter(context);
-
-  SmallVector<Operation *> modifiedDispatchOps;
-  funcOp->walk([&](DispatchRegionOp dispatchOp) {
-    if (collapseDimensions(rewriter, dispatchOp)) {
-      modifiedDispatchOps.push_back(dispatchOp);
-    }
-  });
+  IRRewriter rewriter(funcOp->getContext());
 
-  LLVM_DEBUG({
-    llvm::dbgs() << "[CollapseDims] : After collapsing generic ops: \n";
-    funcOp.print(llvm::dbgs());
-    llvm::dbgs() << "\n";
+  auto walkResult = funcOp->walk([&](DispatchRegionOp regionOp) {
+    if (failed(collapseDimensions(rewriter, regionOp)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
   });
+  if (walkResult.wasInterrupted()) {
+    funcOp->emitOpError("failed in collapsing dimensions pass");
+    return signalPassFailure();
+  }
 
-  // Move all the `tensor.collapse_shape` leafs and `tensor.expand_shape` roots
-  // of the modified dispatches out of the dispatch.
-  for (auto dispatchOp : modifiedDispatchOps) {
-    Region &body = dispatchOp.getBody();
-    assert(llvm::hasSingleElement(body) && "expected op with a single body");
-    Block &block = body.front();
-    RewritePatternSet moveReshapeOps(&getContext());
-    linalg::FillOp::getCanonicalizationPatterns(moveReshapeOps, context);
-    memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
-    tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
-    SmallVector<Operation *> candidateOps;
-    block.walk([&](Operation *op) {
-      if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
-        candidateOps.push_back(op);
-      }
-    });
-    if (failed(
-            applyOpPatternsAndFold(candidateOps, std::move(moveReshapeOps)))) {
-      funcOp.emitOpError(
-          "failed to propagate reshape ops introduced during collapse");
-      return signalPassFailure();
-    }
-
-    if (failed(hoistTensorReshapesOutOfDispatchRegion(
-            rewriter, cast<DispatchRegionOp>(dispatchOp)))) {
-      dispatchOp->emitOpError("failed to hoist reshapes out of dispatch");
-      return signalPassFailure();
-    }
+  RewritePatternSet canonicalizationPatterns(&getContext());
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(
+      canonicalizationPatterns);
+  tensor::populateFoldTensorEmptyPatterns(canonicalizationPatterns);
+  if (failed(applyPatternsAndFoldGreedily(
+          funcOp, std::move(canonicalizationPatterns)))) {
+    funcOp->emitOpError("failed to apply cleanup patterns");
+    return signalPassFailure();
   }
 }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
new file mode 100644
index 000000000000..407777bfdbeb
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
@@ -0,0 +1,95 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+/// Check whether the given dimensions are contiguous in the result map.
+/// If none of the dimensions are present in the map, return true as well.
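+/// For example, with dims = [2, 3]: (d0, d1, d2, d3) -> (d0, d1, d2, d3) and
+/// (d0, d1, d2, d3) -> (d0, d1) return true, while
+/// (d0, d1, d2, d3) -> (d0, d3, d2) returns false.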
+static bool hasContiguousDims(AffineMap map, ArrayRef<unsigned> dims) {
+  if (!map.isProjectedPermutation())
+    return false;
+  llvm::SmallDenseSet<unsigned> existingDims(dims.begin(), dims.end());
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    if (map.getDimPosition(i) != dims[0]) {
+      if (existingDims.count(map.getDimPosition(i))) {
+        return false;
+      }
+      continue;
+    }
+    // Check that the following dimensions match the order of `dims`.
+    for (unsigned j = 1, numDims = dims.size(); j < numDims; j++) {
+      unsigned pos = i + j;
+      if (pos >= map.getNumResults() || map.getDimPosition(pos) != dims[j]) {
+        return false;
+      }
+    }
+    break;
+  }
+  return true;
+}
+
+static SmallVector<ReassociationIndices>
+collapseDimensions(linalg::GenericOp genericOp) {
+  SmallVector<ReassociationIndices> collapseIndices;
+
+  if (!isNonNullAndOutsideDispatch(genericOp)) {
+    return collapseIndices;
+  }
+
+  SmallVector<unsigned> reductionDims;
+  genericOp.getReductionDims(reductionDims);
+  if (reductionDims.size() < 2)
+    return collapseIndices;
+
+  for (AffineMap map : genericOp.getIndexingMapsArray()) {
+    if (!hasContiguousDims(map, reductionDims))
+      return collapseIndices;
+  }
+  ReassociationIndices indices;
+  for (unsigned dim : reductionDims) {
+    indices.push_back(int64_t(dim));
+  }
+  collapseIndices.push_back(indices);
+  return collapseIndices;
+}
+
+struct CollapseDimsPass : public CollapseDimsBase<CollapseDimsPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    linalg::populateCollapseDimensions(patterns, collapseDimensions);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> createCollapseDimsPass() {
+  return std::make_unique<CollapseDimsPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index a3f1f6747fcb..20746e728650 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -114,16 +114,10 @@ static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand) {
   // broadcast this ends up redundantly computing operations without more
   // parallelism.
   if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
-    if (linalgConsumerOp.getNumParallelLoops() ==
-        linalgConsumerOp.getNumLoops()) {
-      return true;
-    }
-    if (linalgConsumerOp.getNumReductionLoops() != 1 ||
-        !linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
-             .isPermutation()) {
-      return false;
-    }
-    return true;
+    return linalgConsumerOp.getNumParallelLoops() ==
+               linalgConsumerOp.getNumLoops() ||
+           linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
+               .isPermutation();
   }
 
   // All other cases dont fuse.
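Note: the fusability check above now accepts any linalg consumer whose indexing map on the fused operand is a permutation (for example a linalg.generic with iterator_types = ["parallel", "parallel", "reduction", "reduction"]); the old code additionally required the consumer to have exactly one reduction loop.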
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 524977e31d43..beece70682d8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -129,6 +129,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
       // Preprocess the input to a form more amenable for fusion
       .addPass(createRaiseSpecialOps)
       .addPass(createInterchangeGenericOpsPass)
+      .addPass(createCollapseDimsPass)
       .addPass(memref::createResolveShapedTypeResultDimsPass)
       .addPass(mlir::createCanonicalizerPass)
      .addPass(mlir::createCSEPass)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 16f4c1e74442..b8db1df54c13 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -239,6 +239,9 @@
 // Create a pass to split reduction dimension.
 std::unique_ptr<Pass> createSplitReductionPass();
 
+// Create a pass to collapse reduction dimensions.
+std::unique_ptr<Pass> createCollapseDimsPass();
+
 //===----------------------------------------------------------------------===//
 // Module Analysis and Finalization
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 52fa8fb2e1c3..2aa8ef36f600 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -26,6 +26,12 @@ def CleanupNumericNarrowing :
   let constructor = "mlir::iree_compiler::IREE::Flow::createCleanupNumericNarrowingPass()";
 }
 
+def CollapseDims :
+    Pass<"iree-flow-collapse-dims", ""> {
+  let summary = "Collapse reduction dimensions when possible.";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createCollapseDimsPass()";
+}
+
 def Convert1X1FilterConv2DToMatmul:
     Pass<"iree-flow-convert-1x1-filter-conv2d-to-matmul", ""> {
   let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index b8cd4bc2710d..e506ec9fa09f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -20,6 +20,7 @@ iree_lit_test_suite(
             "cleanup_numeric_narrowing.mlir",
             "cleanup_tensor_shapes.mlir",
             "clone_producers_into_dispatch_regions.mlir",
+            "collapse_reduction.mlir",
             "conv1x1_to_matmul.mlir",
             "convert_region_to_workgroups.mlir",
             "deduplicate_executables.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 2bf5af23db8e..9c7242033bfc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -19,6 +19,7 @@ iree_lit_test_suite(
     "cleanup_tensor_shapes.mlir"
     "clone_producers_into_dispatch_regions.mlir"
     "collapse_linalg_generic_on_tensors.mlir"
+    "collapse_reduction.mlir"
    "conv1x1_to_matmul.mlir"
    "convert_region_to_workgroups.mlir"
    "deduplicate_executables.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
index 5e4ec9b3eba1..21b6fe95133a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-collapse-dimensions, cse))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-collapse-dimensions))" %s | FileCheck %s
 
 !type = tensor<2x4x8x16x32x64xf32>
 util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
@@ -23,14 +23,14 @@ func.func @collapse1() -> !type {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @collapse1
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse1
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
 
 // -----
 
@@ -58,15 +58,15 @@ func.func @collapse2() -> !type {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @collapse2
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @collapse2
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
 
 // -----
 
 !type = tensor<2x4x8x16x32x64x128x256xf32>
@@ -93,14 +93,14 @@ func.func @collapse3() -> !type {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func.func @collapse3
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
-// CHECK-SAME: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @collapse3
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
 
 // -----
 
@@ -127,15 +127,15 @@ func.func @collapse4() -> !type {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-// CHECK-LABEL: func.func @collapse4
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse4
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
 
 // -----
 
@@ -167,18 +167,18 @@ func.func @collapse5() -> !type {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
-// CHECK-LABEL: func.func @collapse5
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
-// CHECK-SAME: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
+// CHECK-LABEL: func.func @collapse5
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
+// CHECK: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
 
 // -----
 
@@ -205,15 +205,15 @@ func.func @collapse6() -> !type {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-// CHECK-LABEL: func.func @collapse6
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse6
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
 
 // -----
 
@@ -239,23 +239,24 @@ func.func @collapse7() -> !type_out {
   return %result: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-LABEL: func.func @collapse7
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
 
 // -----
 
 !type_in = tensor<16x4x32x2xf32>
 !type_out = tensor<8x16x4x32x8x2xf32>
-func.func @collapse8(%input : !type_in) -> !type_out {
+func.func @collapse8() -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
+  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
   // Can collapse (d3, d0, d1)
   %6 = linalg.generic { indexing_maps = [
@@ -271,16 +272,15 @@ func.func @collapse8() -> !type_out {
   return %6: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @collapse8
-// CHECK-SAME: (%[[IN:.+]]: tensor<16x4x32x2xf32>)
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1, 2], [3]{{\]}}
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
+// CHECK: %[[IN:.+]] = tensor.empty() : tensor<2048x2xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
 
 // -----
 
@@ -304,7 +304,7 @@ func.func @dont_collapse() -> !type_out {
   return %6: !type_out
 }
 // CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
 
 // -----
 
@@ -333,11 +333,11 @@ func.func @collapse9() -> !type_out {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
 // CHECK-LABEL: func.func @collapse9
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
 
 // -----
 
@@ -345,9 +345,10 @@ func.func @collapse9() -> !type_out {
 !type_in = tensor<10x10x30xf32>
 !type_out = tensor<20x10x10x30x20xf32>
 
-func.func @collapse10(%input : !type_in) -> !type_out {
+func.func @collapse10() -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
+  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
   // Can collapse (d1, d3, d0)
@@ -363,18 +364,21 @@ func.func @collapse10() -> !type_out {
   return %result: !type_out
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 // CHECK-LABEL: func.func @collapse10
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
 
 // -----
 
 !type_in = tensor<10x20xf32>
 !type_out = tensor<10x20xf32>
 
-func.func @collapse11(%input : !type_in) -> !type_out {
+func.func @collapse11() -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
+  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
   // Can collapse (d1, d0)
@@ -390,10 +394,10 @@ func.func @collapse11() -> !type_out {
 }
 
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @collapse11
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
 
 // -----
 
@@ -416,7 +420,7 @@ func.func @dont_collapse_dueto_index(%height : index, %width : index) -> !type {
 }
 
 // CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
 
 // -----
 
@@ -452,146 +456,8 @@ func.func @collapse12() -> (!type,!type,!type,!type) {
   return %6, %7, %8, %9 : !type,!type,!type,!type
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @collapse12
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
-
-// -----
-
-func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
-  %cst = arith.constant -0.000000e+00 : f32
-  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
-  %1 = tensor.empty() : tensor<2x32xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
-  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):
-    %6 = arith.addf %arg1, %arg2 : f32
-    linalg.yield %6 : f32
-  } -> tensor<2x32xf32>
-  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
-  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
-  return %5 : !hal.buffer_view
-}
-
-// Check that we collapse dimensions.
-// CHECK-LABEL: @multi_reduce_dim(
-// CHECK-DAG: %[[ARG0:.+]] = hal.tensor.import
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
-// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[EMPTY]] :
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[COLLAPSE]] :
-// CHECK-SAME: outs(%[[FILL]] :
-// CHECK: flow.return %[[GENERIC]]
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]] {{\[}}[0, 1]{{\]}}
-
-// -----
-
-// Collapsing is not supported when an input is broadcasted; we can't collapse
-// the input from tensor<4xf32> to tensor<32xf32> for example.
-
-func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
-  %empty = tensor.empty() : tensor<f32>
-  %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
-  ^bb0(%arg2: f32, %arg3: f32, %out: f32):
-    %div = arith.divf %arg2, %arg3 : f32
-    %add = arith.addf %out, %div : f32
-    linalg.yield %add : f32
-  } -> tensor<f32>
-  return %reduce : tensor<f32>
-}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
 
-// CHECK: @input_broadcast
-// CHECK-NOT: tensor.collapse_shape
-
-// -----
-
-// Do nothing if the dispatch is not a single elementwise op (with tensor.empty/linalg.fill producers)
-
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
-#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-module {
-  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
-    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
-    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
-    %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) {
-      %cst_1 = arith.constant 0.000000e+00 : f32
-      %1 = tensor.empty() : tensor<1x1x4096xf32>
-      %2 = tensor.empty() : tensor<4096x32x128xf32>
-      %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
-      %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) {
-      ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
-        %6 = arith.extui %in : i8 to i32
-        %7 = arith.uitofp %6 : i32 to f32
-        %8 = arith.subf %7, %in_3 : f32
-        %9 = arith.mulf %8, %in_2 : f32
-        linalg.yield %9 : f32
-      } -> tensor<4096x32x128xf32>
-      %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) {
-      ^bb0(%in: f32, %in_2: f32, %out: f32):
-        %6 = arith.mulf %in, %in_2 : f32
-        %7 = arith.addf %6, %out : f32
-        linalg.yield %7 : f32
-      } -> tensor<1x1x4096xf32>
-      flow.return %5 : tensor<1x1x4096xf32>
-    }
-    return %0 : tensor<1x1x4096xf32>
-  }
-}
-
-// CHECK-LABEL: func.func @quantized_matmul
-// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
-// CHECK: linalg.generic
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
-// CHECK: linalg.generic
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
-// CHECK: flow.return
-// CHECK: return %[[DISPATCH]]
-
-// -----
-
-module {
-  func.func @batchnorm_failure_repro(%arg0 : tensor<2x4xf32>, %arg1 : tensor<4xf32>) -> tensor<2x4xf32> {
-    %0 = tensor.empty() : tensor<2x4xf32>
-    %1 = linalg.generic {
-        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-        iterator_types = ["parallel", "parallel"]}
-        ins(%arg0, %arg1 : tensor<2x4xf32>, tensor<4xf32>) outs(%0 : tensor<2x4xf32>) {
-      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
-        %2 = arith.addf %b0, %b1 : f32
-        linalg.yield %2 : f32
-    } -> tensor<2x4xf32>
-    return %1 : tensor<2x4xf32>
-  }
-}
-// CHECK-LABEL: func @batchnorm_failure_repro
-// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK: flow.return %[[GENERIC]]
-// CHECK: return %[[DISPATCH]]
-
-// -----
-
-module {
-  func.func @catch_invalid_collapse(%arg0 : tensor<10x20x30xf32>) -> tensor<10x30x40xf32> {
-    %0 = tensor.empty() : tensor<10x30x40xf32>
-    %1 = linalg.generic {
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>],
-        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-        ins(%arg0 : tensor<10x20x30xf32>) outs(%0 : tensor<10x30x40xf32>) {
-      ^bb0(%b0 : f32, %b1 : f32):
-        linalg.yield %b0 : f32
-    } -> tensor<10x30x40xf32>
-    return %1 : tensor<10x30x40xf32>
-  }
-}
-// CHECK-LABEL: func @catch_invalid_collapse
-// CHECK: linalg.generic
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
new file mode 100644
index 000000000000..5409dfba5909
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
@@ -0,0 +1,64 @@
+// RUN: iree-opt --split-input-file -iree-flow-collapse-dims %s | FileCheck %s
+
+func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  %cst = arith.constant -0.000000e+00 : f32
+  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
+  %1 = tensor.empty() : tensor<2x32xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %6 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<2x32xf32>
+  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
+  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
+  return %5 : !hal.buffer_view
+}
+
+// Check that we collapse dimensions.
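+// The two reduction dims (sizes 10 and 4096) should collapse into a single
+// 40960-element reduction dim, leaving iterator types
+// ["parallel", "parallel", "reduction"].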
+// CHECK: @multi_reduce_dim
+// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction"]
+
+// -----
+
+// Collapsing is not supported when an input is broadcasted; we can't collapse
+// the input from tensor<4xf32> to tensor<32xf32> for example.
+
+func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
+  %empty = tensor.empty() : tensor<f32>
+  %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %out: f32):
+    %div = arith.divf %arg2, %arg3 : f32
+    %add = arith.addf %out, %div : f32
+    linalg.yield %add : f32
+  } -> tensor<f32>
+  return %reduce : tensor<f32>
+}
+
+// CHECK: @input_broadcast
+// CHECK-NOT: tensor.collapse_shape
+
+// -----
+
+// Collapsing should not happen to ops inside flow.dispatch.region or
+// flow.dispatch.workgroups.
+
+func.func @multi_reduce_dim_dispatch(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  %cst = arith.constant -0.000000e+00 : f32
+  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
+  %1 = tensor.empty() : tensor<2x32xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
+  %3 = flow.dispatch.region -> (tensor<2x32xf32>) {
+    %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %7 = arith.addf %arg1, %arg2 : f32
+      linalg.yield %7 : f32
+    } -> tensor<2x32xf32>
+    flow.return %6 : tensor<2x32xf32>
+  }
+  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
+  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
+  return %5 : !hal.buffer_view
+}
+
+// CHECK: @multi_reduce_dim_dispatch
+// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction", "reduction"]
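For reference, the new pass can be exercised standalone the same way the lit test above runs it (assuming a built iree-opt):

  iree-opt --split-input-file -iree-flow-collapse-dims collapse_reduction.mlir

On @multi_reduce_dim this collapses the trailing 10x4096 reduction dims into a single 40960-element reduction dim, as the first CHECK lines verify.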