From abf4234ba380d9d32aafc5f82810341e3abcd9e3 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 26 Jan 2024 08:52:04 +0100 Subject: [PATCH] PDLToPDLInterp: Ensure dependencies between native constraints and their arguments --- .../PDLToPDLInterp/PredicateTree.cpp | 31 +++++++++++++++++++ .../PDLToPDLInterp/use-constraint-result.mlir | 21 +++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 8079390f46ebb3..50388a38cfa4de 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -884,6 +884,19 @@ static void insertExitNode(std::unique_ptr *root) { *root = std::make_unique(); } +/// Sorts the range begin/end with the partial order given by cmp. +/// cmp must be a partial ordering. +template +void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { + while (begin != end) { + auto const next = std::stable_partition(begin, end, [&](auto const &a) { + return std::none_of(begin, end, [&](auto const &b) { return cmp(b, a); }); + }); + assert(next != begin && "not a partial ordering"); + begin = next; + } +} + /// Given a module containing PDL pattern operations, generate a matcher tree /// using the patterns within the given module and return the root matcher node. std::unique_ptr @@ -964,6 +977,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, return *lhs < *rhs; }); + // Mostly keep the now established order, but also ensure that + // ConstraintQuestions come after the results they use. + stableTopologicalSort(ordered.begin(), ordered.end(), + [](OrderedPredicate *a, OrderedPredicate *b) { + auto *cqa = dyn_cast(a->question); + auto *cqb = dyn_cast(b->question); + if (cqa && cqb) { + // Does any argument of b use a? Then b must be + // sorted after a. + return llvm::any_of( + cqb->getArgs(), [&](Position *p) { + auto *cp = dyn_cast(p); + return cp && cp->getQuestion() == cqa; + }); + } + return false; + }); + // Build the matchers for each of the pattern predicate lists. std::unique_ptr root; for (OrderedPredicateList &list : lists) diff --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir new file mode 100644 index 00000000000000..cdd51ff0ad6379 --- /dev/null +++ b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s + +// Ensuse that the dependency between add & less +// causes them to be in the correct order. +// CHECK: apply_constraint "__builtin_add" +// CHECK: apply_constraint "__builtin_less" + +module { + pdl.pattern @test : benefit(1) { + %0 = attribute + %1 = types + %2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range) + %3 = attribute = 0 : i32 + %4 = attribute = 1 : i32 + %5 = apply_native_constraint "__builtin_add"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute + apply_native_constraint "__builtin_less"(%0, %5 : !pdl.attribute, !pdl.attribute) + rewrite %2 { + replace %2 with %2 + } + } +}