Skip to content

Commit

Permalink
[Flow] Make CollapseDimensions iterative (iree-org#18203)
Browse files Browse the repository at this point in the history
Each operation must propagate its collapse information to every other node in the dispatch to ensure that no collapse_shape ops are needed between ops in the dispatch.
  • Loading branch information
IanWood1 authored Aug 20, 2024
1 parent 0247962 commit 3af05b9
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace {
// Pass wrapper: collapses dimensions of Linalg ops inside dispatch regions.
// The actual logic lives in runOnOperation() (defined later in the file).
struct CollapseDimensionsPass
    : public IREE::Flow::impl::CollapseDimensionsPassBase<
          CollapseDimensionsPass> {
  // Inherit the generated base-class constructors so tablegen-declared pass
  // options (e.g. `maxIterations`, see Passes.td) are wired up automatically.
  using Base::Base;
  void runOnOperation() override;
};
} // namespace
Expand Down Expand Up @@ -269,9 +270,10 @@ class CollapseInfo {
using CollapsableLoopsSet = llvm::SmallSetVector<int64_t, 8>;

CollapseInfo() = default;
CollapseInfo(ArrayRef<ReassociationIndices> reassociation)
: reassociation(reassociation),
collapsableLoops(getCollapsedFromReassociation(reassociation)) {}
CollapseInfo(linalg::GenericOp genericOp) {
reassociation = Flow::getCollapsibleLoops(genericOp);
collapsableLoops = getCollapsedFromReassociation(reassociation);
}

// Print the current operation & reassociation indicies
void print(raw_ostream &os) const;
Expand All @@ -281,11 +283,13 @@ class CollapseInfo {

// Update `collapsableLoops` by taking the set intersection with
// `otherCollapsable` and update the reassociation indicies accordingly.
void updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable);
// Returns true if the operation modified the number of collapsable loops.
bool updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable);

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
// reassociation indicies accordingly.
void updateCollapseViaSubtract(const CollapsableLoopsSet &uncollapsable);
// Returns true if the operation modified the number of collapsable loops.
bool updateCollapseViaSubtract(const CollapsableLoopsSet &uncollapsable);

// Get `collapsableLoops` after applying the transformation provided by `map`.
// Note: doesn't modify `collapsableLoops`, the tranformation is applied to a
Expand Down Expand Up @@ -404,9 +408,8 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {

// Update `collapsableLoops` by taking the set intersection with
// `otherCollapsable` and update the reassociation indicies accordingly.
void CollapseInfo::updateCollapseViaIntersect(
bool CollapseInfo::updateCollapseViaIntersect(
const CollapsableLoopsSet &otherCollapsable) {

CollapsableLoopsSet toRemove;
for (auto elem : collapsableLoops) {
if (!otherCollapsable.contains(elem)) {
Expand All @@ -415,15 +418,17 @@ void CollapseInfo::updateCollapseViaIntersect(
}
collapsableLoops.set_subtract(toRemove);
updateReassociation();
return toRemove.size();
}

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
// reassociation indicies accordingly.
void CollapseInfo::updateCollapseViaSubtract(
bool CollapseInfo::updateCollapseViaSubtract(
const CollapsableLoopsSet &uncollapsable) {

auto initialSize = collapsableLoops.size();
collapsableLoops.set_subtract(uncollapsable);
updateReassociation();
return initialSize != collapsableLoops.size();
}

void CollapseInfo::print(raw_ostream &os) const {
Expand Down Expand Up @@ -648,18 +653,20 @@ hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
// For each consumer, use it's producers to constrain which dimensions it will
// collapse. `slice` is expected to be topologically sorted (getBackwardSlice
// does this automatically).
static void updateConsumersFromProducers(
// Returns true if the operation modified any op's `CollapseInfo`.
static bool updateConsumersFromProducers(
ArrayRef<Operation *> slice,
llvm::DenseMap<linalg::GenericOp, CollapseInfo> &opMap) {
bool didChange = false;

// Slice is topologically sorted to ensure that `op`'s producers have been
// updated before we visit it.
for (auto op : slice) {
auto genericOp = cast<linalg::GenericOp>(op);
assert(opMap.contains(genericOp));
CollapseInfo &consumerInfo = opMap.find(genericOp)->second;
auto consumerOp = cast<linalg::GenericOp>(op);
assert(opMap.contains(consumerOp));
CollapseInfo &consumerInfo = opMap.find(consumerOp)->second;

for (auto operand : genericOp.getDpsInputOperands()) {
for (auto operand : consumerOp.getDpsInputOperands()) {
auto definingOp = operand->get().getDefiningOp();
if (!definingOp || isNonNullAndOutsideDispatch(definingOp)) {
continue;
Expand All @@ -671,7 +678,8 @@ static void updateConsumersFromProducers(
// cannot be done via union of producer and consumer collapsable loops
// because the consumer may have loops that the producer does not.
CollapseInfo::CollapsableLoopsSet producerUncollapsable;
for (auto expr : genericOp.getMatchingIndexingMap(operand).getResults()) {
for (auto expr :
consumerOp.getMatchingIndexingMap(operand).getResults()) {
producerUncollapsable.insert(cast<AffineDimExpr>(expr).getPosition());
}

Expand All @@ -682,7 +690,8 @@ static void updateConsumersFromProducers(
// If the producer is not a generic or there is no mapping, the tensor is
// not collapsable. So, all dimensions of the producer are uncollapsable.
if (!producerOp || !opMap.contains(producerOp) || failed(mapping)) {
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
didChange |=
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
continue;
}

Expand All @@ -693,17 +702,21 @@ static void updateConsumersFromProducers(
producerUncollapsable.set_subtract(producerCollapsable.value());
}

consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
didChange |=
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
}
}
return didChange;
}

// For each producer, use it's consumers to constrain which dimensions it will
// collapse. `slice` is expected to be topologically sorted (getBackwardSlice
// does this automatically).
static void updateProducersFromConsumers(
// Returns true if the operation modified any op's `CollapseInfo`.
static bool updateProducersFromConsumers(
ArrayRef<Operation *> slice,
llvm::DenseMap<linalg::GenericOp, CollapseInfo> &opMap) {
bool didChange = false;

// Iterate over `slice` in reverse so that we visit each `op` 's consumer
// before visiting `op`.
Expand All @@ -727,6 +740,7 @@ static void updateProducersFromConsumers(
FailureOr<AffineMap> consumerToProducerMap =
getConsumerLoopToProducerLoopsMap(*operand);
if (failed(consumerToProducerMap)) {
didChange |= producerInfo.getCollapsibleLoops().size();
producerInfo.clear();
continue;
}
Expand All @@ -741,16 +755,19 @@ static void updateProducersFromConsumers(
}
// Only loops collapsable in both the consumer and producer may be
// collapsed.
producerInfo.updateCollapseViaIntersect(consumerCollapsable.value());
didChange |=
producerInfo.updateCollapseViaIntersect(consumerCollapsable.value());
}
}
return didChange;
}

// Construct a DAG of `linalg.generic` operations with 1 root op. Find
// dimensions that can be collapsed all the way from the root to the leaves,
// ensuring that all `collapse_shape` ops can be hoisted out of the dispatch.
static bool collapseDimensionsForDispatch(IRRewriter &rewriter,
DispatchRegionOp &regionOp) {
DispatchRegionOp &regionOp,
int maxIterations) {
// Only collapse dispatches with 1 block
if (!llvm::hasSingleElement(regionOp.getBody())) {
return false;
Expand All @@ -776,9 +793,10 @@ static bool collapseDimensionsForDispatch(IRRewriter &rewriter,
// Step 3. Populate each op's info with a maximally collapsable reassociation
// indicies
llvm::DenseMap<linalg::GenericOp, CollapseInfo> opMap;
opMap.reserve(slice.size());
for (auto *op : slice) {
auto genericOp = cast<linalg::GenericOp>(op);
opMap[genericOp] = CollapseInfo(getCollapsibleLoops(genericOp));
opMap[genericOp] = CollapseInfo(genericOp);
}

LLVM_DEBUG({
Expand All @@ -792,36 +810,50 @@ static bool collapseDimensionsForDispatch(IRRewriter &rewriter,
llvm::dbgs() << "\n";
});

// Step 4. For each producer, reduce the number of collapsed dimensions
// based on the dimensions that it's consumers can collapse.
updateProducersFromConsumers(slice.getArrayRef(), opMap);

LLVM_DEBUG({
llvm::dbgs() << "[CollapseDims] : After updating producers: \n";
for (auto &[op, info] : opMap) {
info.dump();
llvm::dbgs() << "\n";
op.dump();
llvm::dbgs() << "\n";
bool didUpdateProducers = true;
bool didUpdateConsumers = true;
int iterationCount = 0;
while (didUpdateProducers || didUpdateConsumers) {
// Cap the max number of iterations at 10. If it hasn't converged by then,
// don't collapse any ops in this dispatch.
iterationCount++;
if (iterationCount > maxIterations) {
return false;
}
llvm::dbgs() << "\n";
});

// Step 5. For each consumer, update it's CollapseInfo to only collapse
// dimensions that all of its producers can collapse. This ensures that all
// reshapes can be propagated to leafs and be hoisted out of the dispatch.
updateConsumersFromProducers(slice.getArrayRef(), opMap);

LLVM_DEBUG({
llvm::dbgs() << "[CollapseDims] : After updating consumers: \n";
for (auto &[op, info] : opMap) {
info.dump();
// Step 4. For each producer, reduce the number of collapsed dimensions
// based on the dimensions that it's consumers can collapse.
didUpdateProducers =
updateProducersFromConsumers(slice.getArrayRef(), opMap);

LLVM_DEBUG({
llvm::dbgs() << "[CollapseDims] : After updating producers: \n";
for (auto &[op, info] : opMap) {
info.dump();
llvm::dbgs() << "\n";
op.dump();
llvm::dbgs() << "\n";
}
llvm::dbgs() << "\n";
op.dump();
});

// Step 5. For each consumer, update it's CollapseInfo to only collapse
// dimensions that all of its producers can collapse. This ensures that all
// reshapes can be propagated to leafs and be hoisted out of the dispatch.
didUpdateConsumers =
updateConsumersFromProducers(slice.getArrayRef(), opMap);

LLVM_DEBUG({
llvm::dbgs() << "[CollapseDims] : After updating consumers: \n";
for (auto &[op, info] : opMap) {
info.dump();
llvm::dbgs() << "\n";
op.dump();
llvm::dbgs() << "\n";
}
llvm::dbgs() << "\n";
}
llvm::dbgs() << "\n";
});
});
}

bool didCollapse = false;

// Step 6. Collapse dimensions based on each op's CollapseInfo
Expand Down Expand Up @@ -850,7 +882,7 @@ void CollapseDimensionsPass::runOnOperation() {

SmallVector<DispatchRegionOp> modifiedDispatchOps;
funcOp->walk([&](DispatchRegionOp dispatchOp) {
if (collapseDimensionsForDispatch(rewriter, dispatchOp)) {
if (collapseDimensionsForDispatch(rewriter, dispatchOp, maxIterations)) {
modifiedDispatchOps.push_back(dispatchOp);
}
});
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ def CloneProducersIntoDispatchRegionsPass :
def CollapseDimensionsPass :
InterfacePass<"iree-flow-collapse-dimensions", "mlir::FunctionOpInterface"> {
let summary = "Collapse dimensions of Linalg Ops on tensor ops.";
let options = [
Option<"maxIterations", "max-iterations", "int",
/*default=*/"10",
"Maximum number of iterations to wait for collapse dimensions to converge"
>,
];
let description = [{
Collapse dimensions of Linalg Ops on tensor ops inside dispatch.region ops
and hoist the reshaping operations out of the dispatch.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,42 @@ util.func public @uncollapsable_op(%arg0 : tensor<10x10xi64>) -> tensor<10x10xi6
// CHECK-SAME: ins(%[[VAL0]] : tensor<10x10xi64>)
// CHECK-SAME: outs(%{{.*}} : tensor<10x10xi64>)
// CHECK: flow.return

// -----

// Regression test for iterative collapse-info propagation: %barrier is an
// uncollapsable producer of %elementwise4, and that constraint must propagate
// back to %elementwise2 so no collapse_shape is introduced inside the
// dispatch (the CHECK lines below expect all generics to stay 4-D).
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
util.func public @propagate_uncollapsable(%arg0: tensor<2x320x128x128xf32>) -> tensor<2x320x128x128xf32> {
%0 = flow.dispatch.region -> (tensor<2x320x128x128xf32>) {
%empty = tensor.empty() : tensor<2x320x128x128xf32>
%cst = arith.constant 3.14 : f32

// Elementwise producer; in isolation its loops would be collapsable.
%elementwise2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0: tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) {
^bb0(%in : f32, %out : f32):
%22 = arith.mulf %cst, %in : f32
linalg.yield %22 : f32
} -> tensor<2x320x128x128xf32>
// The optimization barrier is not a linalg.generic, so its result tensor
// makes the corresponding consumer dimensions uncollapsable.
%barrier = util.optimization_barrier %arg0: tensor<2x320x128x128xf32>
%elementwise4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%barrier, %elementwise2 : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>) outs(%empty : tensor<2x320x128x128xf32>) {
^bb0(%in : f32, %in_1 : f32, %out : f32):
%22 = arith.mulf %in_1, %in : f32
linalg.yield %22 : f32
} -> tensor<2x320x128x128xf32>

flow.return %elementwise4 : tensor<2x320x128x128xf32>
}
util.return %0 : tensor<2x320x128x128xf32>
}

// CHECK-LABEL: util.func public @propagate_uncollapsable
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x320x128x128xf32>
// CHECK: flow.dispatch.region
// CHECK: %[[VAL1:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]] : tensor<2x320x128x128xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>)
// CHECK: %[[VAL2:.*]] = util.optimization_barrier %[[ARG0]] : tensor<2x320x128x128xf32>
// CHECK: %[[VAL3:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[VAL2]], %[[VAL1]] : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>)
// CHECK: flow.return %[[VAL3]]

0 comments on commit 3af05b9

Please sign in to comment.