Skip to content

Commit

Permalink
Use llvm::EquivalenceClasses
Browse files Browse the repository at this point in the history
  • Loading branch information
ezhulenev committed Sep 1, 2023
1 parent 3036110 commit bd828c0
Showing 1 changed file with 42 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <iterator>

#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -42,94 +41,63 @@ namespace {
// We use union-find algorithm to construct it from pairs of aliasing values.
class ValueAliasingSet {
public:
// A mutable builder to construct value aliasing set. Once builder is freezed
// value aliasing set provides efficient access to discovered value alises.
class Builder {
public:
void addAlias(Value aliasee, Value aliaser) {
int64_t aliaseeRoot = getRoot(aliasee);
int64_t aliaserRoot = getRoot(aliaser);
root[aliaserRoot] = aliaseeRoot;
}

ValueAliasingSet build();

private:
int64_t getId(Value value) {
auto [iterator, inserted] = id.try_emplace(value, id.size());
if (inserted)
root.push_back(iterator->second);
return iterator->second;
}
void addAlias(Value aliasee, Value aliaser) {
valueAliasing.unionSets(getWithId(aliasee), getWithId(aliaser));
}

int64_t getRoot(Value value) {
int64_t id = getId(value);
while (root[id] != id) {
id = root[id] = root[root[id]];
SmallVector<SmallVector<Value>> getValueAliasSets() const {
SmallVector<SmallVector<Value>> result;
for (auto it = valueAliasing.begin(); it != valueAliasing.end(); ++it) {
if (!it->isLeader())
continue; // Ignore non-leader sets.
auto &aliasSet = result.emplace_back();
for (auto mi = valueAliasing.member_begin(it);
mi != valueAliasing.member_end(); ++mi) {
aliasSet.push_back(mi->value);
}
return id;
}

llvm::DenseMap<Value, int64_t> id;
SmallVector<int64_t> root;
};

ArrayRef<SmallVector<Value>> getValueAliasSets() const {
return valueAliasSets;
return result;
}

auto getValueAliases(Value value) const {
ArrayRef<Value> aliasers;
if (auto it = valueAliasSet.find(value); it != valueAliasSet.end()) {
aliasers = valueAliasSets[it->second];
}
return llvm::make_filter_range(
aliasers, [=](Value aliaser) { return aliaser != value; });
llvm::map_range(
llvm::make_range(valueAliasing.findLeader(getWithId(value)),
valueAliasing.member_end()),
NumberedValue::getValue),
[=](Value aliaser) { return aliaser != value; });
}

private:
ValueAliasingSet(SmallVector<SmallVector<Value>> valueAliasSets,
llvm::DenseMap<Value, int64_t> valueAliasSet)
: valueAliasSets(std::move(valueAliasSets)),
valueAliasSet(std::move(valueAliasSet)) {}
// EquivalenceClasses require ordering for value type to return deterministic
// results, so we provide it by assigning id to all values added to the set.
struct NumberedValue {
Value value;
int64_t id;

SmallVector<SmallVector<Value>> valueAliasSets;
llvm::DenseMap<Value, int64_t> valueAliasSet;
};
static Value getValue(const NumberedValue &value) { return value.value; }
};

ValueAliasingSet ValueAliasingSet::Builder::build() {
SmallVector<SmallVector<Value>> valueAliasSets;
llvm::DenseMap<Value, int64_t> valueAliasSet;

// Sort all values to guarantee that we return them in determenistic order.
SmallVector<std::pair<Value, int64_t>> values(id.begin(), id.end());
llvm::sort(values, [](auto a, auto b) { return a.second < b.second; });

// Run path compression to propagate roots to all values and guarantee that in
// the next step we'll get only "real" roots.
for (auto &[value, index] : values)
(void)getRoot(value);

// Collect unique roots, and sort them to guarantee determenistic order.
auto roots = llvm::SetVector<int64_t>(root.begin(), root.end()).takeVector();
llvm::sort(roots);

valueAliasSets.resize(roots.size());
for (auto &[value, index] : values) {
int64_t aliasSetIndex =
std::distance(roots.begin(), llvm::find(roots, getRoot(value)));
assert(aliasSetIndex < valueAliasSets.size() && "root was not found");
valueAliasSets[valueAliasSet[value] = aliasSetIndex].push_back(value);
struct Comparator {
int operator()(const NumberedValue &a, const NumberedValue &b) const {
return a.id < b.id;
}
};

NumberedValue getWithId(Value value) const {
auto [iterator, inserted] = id.try_emplace(value, id.size());
return {value, iterator->second};
}

return ValueAliasingSet(std::move(valueAliasSets), std::move(valueAliasSet));
}
mutable llvm::DenseMap<Value, int64_t> id;
llvm::EquivalenceClasses<NumberedValue, Comparator> valueAliasing;
};

// Builds a map of value aliases from aliasee to a set of aliasers.
// Only values that alias will be present in the map. The map may contain
// values nested within the |regionOp|.
static void computeRegionValueAliases(Operation *regionOp,
ValueAliasingSet::Builder &valueAliases) {
ValueAliasingSet &valueAliases) {
auto *block = &regionOp->getRegion(0).front();

// Filter out to only resource results - some regions may return additional
Expand Down Expand Up @@ -182,9 +150,9 @@ static void computeRegionValueAliases(Operation *regionOp,
// set. The set may contain values nested within the |executeOp|.
static ValueAliasingSet
computeExecutionRegionValueAliases(IREE::Stream::AsyncExecuteOp executeOp) {
ValueAliasingSet::Builder valueAliases;
ValueAliasingSet valueAliases;
computeRegionValueAliases(executeOp, valueAliases);
return valueAliases.build();
return valueAliases;
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit bd828c0

Please sign in to comment.