diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 50388a38cfa4de..6f8bc73fcfc267 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include +#include "llvm/ADT/SmallPtrSet.h" #define DEBUG_TYPE "pdl-predicate-tree" @@ -889,8 +890,17 @@ static void insertExitNode(std::unique_ptr *root) { template void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { while (begin != end) { + // Cannot compute sortBeforeOthers in the predicate of stable_partition + // because stable_partition will not keep the [begin, end) range intact + // while it runs. + llvm::SmallPtrSet sortBeforeOthers; + for(auto i = begin; i != end; ++i) { + if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); })) + sortBeforeOthers.insert(*i); + } + auto const next = std::stable_partition(begin, end, [&](auto const &a) { - return std::none_of(begin, end, [&](auto const &b) { return cmp(b, a); }); + return sortBeforeOthers.contains(a); }); assert(next != begin && "not a partial ordering"); begin = next;