Skip to content

Commit

Permalink
removed conv changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aarushjain29 committed Oct 30, 2024
1 parent 6a99db9 commit e17bb63
Showing 1 changed file with 0 additions and 43 deletions.
43 changes: 0 additions & 43 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,55 +441,13 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

if(a_ins->get_shape().scalar()){
return;
}

auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};

struct find_scalar_mul_conv
{
auto matcher() const
{
return match::name("mul")(
match::either_arg(0, 1)(
conv_const_weights().bind("conv"),
match::either_arg(0, 1)(
match::name("broadcast", "multibroadcast", "constant").bind("scalar"),
match::any().bind("scalar")
)
)
);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
auto scalar_ins = r.instructions["scalar"];
auto w_ins = r.instructions["w"];

if(scalar_ins->get_shape().elements() != 1)
return;
const auto& w_shape = w_ins->get_shape().lens();

if(scalar_ins->get_shape().ndim() != w_shape.size())
{
scalar_ins = m.insert_instruction(ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_shape}}), scalar_ins);
}

auto new_weights = m.insert_instruction(ins, make_op("mul"), scalar_ins, w_ins);

auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);

m.replace_instruction(ins, new_conv);
}
};

struct find_dot_add
{
auto matcher() const
Expand Down Expand Up @@ -2025,7 +1983,6 @@ void simplify_algebra::apply(module& m) const
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_scalar_mul_conv{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
Expand Down

0 comments on commit e17bb63

Please sign in to comment.