diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 52aa2a3ec6..2af6c621ae 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -441,9 +441,6 @@ struct find_mul_add auto x_ins = r.instructions["x"]; assert(x_ins != b_ins); - if(a_ins->get_shape().scalar()){ - return; - } auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins); auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins); @@ -451,45 +448,6 @@ struct find_mul_add } }; -struct find_scalar_mul_conv -{ - auto matcher() const - { - return match::name("mul")( - match::either_arg(0, 1)( - conv_const_weights().bind("conv"), - match::either_arg(0, 1)( - match::name("broadcast", "multibroadcast", "constant").bind("scalar"), - match::any().bind("scalar") - ) - ) - ); - } - void apply(module& m, const match::matcher_result& r) const - { - auto ins = r.result; - auto conv_ins = r.instructions["conv"]; - auto scalar_ins = r.instructions["scalar"]; - auto w_ins = r.instructions["w"]; - - if(scalar_ins->get_shape().elements() != 1) - return; - const auto& w_shape = w_ins->get_shape().lens(); - - if(scalar_ins->get_shape().ndim() != w_shape.size()) - { - scalar_ins = m.insert_instruction(ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_shape}}), scalar_ins); - } - - auto new_weights = m.insert_instruction(ins, make_op("mul"), scalar_ins, w_ins); - - auto new_conv = m.insert_instruction( - ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights); - - m.replace_instruction(ins, new_conv); - } -}; - struct find_dot_add { auto matcher() const @@ -2025,7 +1983,6 @@ void simplify_algebra::apply(module& m) const find_dot_slice{}, find_dot_mul{}, find_mul_add{}, - find_scalar_mul_conv{}, find_unit_ops{}, find_neg_unit_ops{}, eliminate_zero_point{},