Skip to content

Commit

Permalink
Don't let users disguise RVars as Vars
Browse files Browse the repository at this point in the history
Fixes #7827
  • Loading branch information
abadams committed Oct 18, 2024
1 parent 7ceddb4 commit 21b6270
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,19 @@ bool var_name_match(const string &candidate, const string &var) {
}
return Internal::ends_with(candidate, "." + var);
}

bool dim_match(const Dim &candidate, const VarOrRVar &var) {
if (var_name_match(candidate.var, var.name())) {
user_assert(candidate.is_rvar() == var.is_rvar)
<< (var.is_rvar ? "RVar " : "Var ") << var.name()
<< " used in scheduling directive has the same name as existing "
<< (candidate.is_rvar() ? "RVar " : "Var ") << candidate.var << "\n";
return true;
} else {
return false;
}
}

} // namespace

std::string Stage::name() const {
Expand Down Expand Up @@ -455,7 +468,7 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) {
bool found = false;
vector<Dim> &dims = definition.schedule().dims();
for (auto &dim : dims) {
if (var_name_match(dim.var, var.name())) {
if (dim_match(dim, var)) {
found = true;
dim.for_type = t;

Expand Down Expand Up @@ -523,7 +536,7 @@ void Stage::set_dim_device_api(const VarOrRVar &var, DeviceAPI device_api) {
bool found = false;
vector<Dim> &dims = definition.schedule().dims();
for (auto &dim : dims) {
if (var_name_match(dim.var, var.name())) {
if (dim_match(dim, var)) {
found = true;
dim.device_api = device_api;
}
Expand Down Expand Up @@ -1129,7 +1142,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c
string inner_name, outer_name, old_name;

for (size_t i = 0; (!found) && i < dims.size(); i++) {
if (var_name_match(dims[i].var, old)) {
if (dim_match(dims[i], VarOrRVar(old, exact))) {
found = true;
old_name = dims[i].var;
inner_name = old_name + "." + inner;
Expand Down Expand Up @@ -1321,7 +1334,7 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV

DimType outer_type = DimType::PureRVar;
for (size_t i = 0; (!found_outer) && i < dims.size(); i++) {
if (var_name_match(dims[i].var, outer.name())) {
if (dim_match(dims[i], outer)) {
found_outer = true;
outer_name = dims[i].var;
outer_type = dims[i].dim_type;
Expand All @@ -1337,7 +1350,7 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV
}

for (size_t i = 0; (!found_inner) && i < dims.size(); i++) {
if (var_name_match(dims[i].var, inner.name())) {
if (dim_match(dims[i], inner)) {
found_inner = true;
inner_name = dims[i].var;
fused_name = inner_name + "." + fused.name();
Expand Down Expand Up @@ -1450,7 +1463,7 @@ Stage &Stage::purify(const VarOrRVar &old_var, const VarOrRVar &new_var) {
vector<Dim> &dims = schedule.dims();

for (size_t i = 0; (!found) && i < dims.size(); i++) {
if (var_name_match(dims[i].var, old_var.name())) {
if (dim_match(dims[i], old_var)) {
found = true;
old_name = dims[i].var;
dims[i].var = new_name;
Expand Down Expand Up @@ -1592,7 +1605,7 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) {
string old_name;
vector<Dim> &dims = schedule.dims();
for (size_t i = 0; (!found) && i < dims.size(); i++) {
if (var_name_match(dims[i].var, old_var.name())) {
if (dim_match(dims[i], old_var)) {
found = true;
old_name = dims[i].var;
dims[i].var += "." + new_var.name();
Expand Down Expand Up @@ -1735,7 +1748,7 @@ Stage &Stage::partition(const VarOrRVar &var, Partition policy) {
bool found = false;
vector<Dim> &dims = definition.schedule().dims();
for (auto &dim : dims) {
if (var_name_match(dim.var, var.name())) {
if (dim_match(dim, var)) {
found = true;
dim.partition_policy = policy;
}
Expand Down Expand Up @@ -1851,7 +1864,7 @@ Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {
for (size_t i = 0; i < vars.size(); i++) {
bool found = false;
for (size_t j = 0; j < dims.size(); j++) {
if (var_name_match(dims[j].var, vars[i].name())) {
if (dim_match(dims[j], vars[i])) {
idx[i] = j;
found = true;
}
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ tests(GROUPS error
store_at_without_compute_at.cpp
thread_id_outside_block_id.cpp
too_many_args.cpp
treat_rvar_as_var.cpp
tuple_arg_select_undef.cpp
tuple_output_bounds_check.cpp
tuple_realization_to_buffer.cpp
Expand Down
21 changes: 21 additions & 0 deletions test/error/treat_rvar_as_var.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "Halide.h"

#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
Func f;
Var x, y;

RDom r(0, 10);
f(x, y) += r;

// Sneakily disguising an RVar as a Var by reusing the name should result in
// an error. Otherwise it can permit schedules that aren't legal.
Var xo, xi;
f.update().split(Var(r.x.name()), xo, xi, 8, TailStrategy::RoundUp);

printf("Success!\n");
return 0;
}

0 comments on commit 21b6270

Please sign in to comment.