Skip to content

Commit

Permalink
Add a check that PredicateLoads must be used in the outermost split o…
Browse files Browse the repository at this point in the history
…f a dimension (#7788)

* add a check that PredicateLoads must be used in the outermost split of a dimension

* newline

* use the repro example

* fix

* avoid check for every other tail strategy

* update error message to point out what's not allowed

---------

Co-authored-by: Steven Johnson <srj@google.com>
  • Loading branch information
TH3CHARLie and steven-johnson authored Sep 5, 2023
1 parent 8188b42 commit 02865e2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,34 @@ void Stage::split(const string &old, const string &outer, const string &inner, c
<< "Use TailStrategy::GuardWithIf instead.";
}

bool predicate_loads_ok = !exact;
if (predicate_loads_ok && tail == TailStrategy::PredicateLoads) {
// If it's the outermost split in this dimension, PredicateLoads
// is OK. Otherwise we can't prove it's safe.
std::set<string> inner_vars;
for (const Split &s : definition.schedule().splits()) {
if (s.is_split()) {
inner_vars.insert(s.inner);
if (inner_vars.count(s.old_var)) {
inner_vars.insert(s.outer);
}
} else if (s.is_rename() || s.is_purify()) {
if (inner_vars.count(s.old_var)) {
inner_vars.insert(s.outer);
}
} else if (s.is_fuse()) {
if (inner_vars.count(s.inner) || inner_vars.count(s.outer)) {
inner_vars.insert(s.old_var);
}
}
}
predicate_loads_ok = !inner_vars.count(old_name);
user_assert(predicate_loads_ok || tail != TailStrategy::PredicateLoads)
<< "Can't use TailStrategy::PredicateLoads for splitting " << old_name
<< " in the definition of " << name() << ". "
<< "PredicateLoads may not be used to split a Var stemming from the inner Var of a prior split.";
}

if (tail == TailStrategy::Auto) {
// Select a tail strategy
if (exact) {
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ tests(GROUPS error
overflow_during_constant_folding.cpp
pointer_arithmetic.cpp
race_condition.cpp
predicate_loads_used_in_inner_splits.cpp
rdom_undefined.cpp
rdom_where_races.cpp
realization_with_too_many_outputs.cpp
Expand Down
15 changes: 15 additions & 0 deletions test/error/predicate_loads_used_in_inner_splits.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "Halide.h"

using namespace Halide;

int main(int argc, char **argv) {
Func f;
Var x, xo, xi, xio, xii;
f(x) = x;
f.split(x, xo, xi, 2, TailStrategy::Auto)
.split(xi, xio, xii, 4, TailStrategy::PredicateLoads)
.reorder(xo, xio, xii);

printf("Success!\n");
return 0;
}

0 comments on commit 02865e2

Please sign in to comment.