Parser changes to handle MatMulIntegerToFloat #3445
base: develop
Changes from 11 commits: b19ce16, ae9f722, 7f62a33, 92d8ea4, cdb307d, f912e61, 02c3918, 3ca3e6a, 547826d, c6d8679, 6d89fdd, c0c8120
@@ -35,7 +35,9 @@
 {
     std::vector<op_desc> operators() const
     {
-        return {{"MatMul", "dot"}, {"MatMulInteger", "quant_dot"}};
+        return {{"MatMul", "dot"},
+                {"MatMulInteger", "quant_dot"},
+                {"MatMulIntegerToFloat", "quant_dot_scaled"}};
     }

     static void broadcast_dimensions(const onnx_parser::node_info& info,
@@ -106,6 +108,62 @@
         return all_zeros;
     }

+    static instruction_ref set_scale_arg(const onnx_parser::node_info& info,
+                                         const std::vector<instruction_ref>& args,
+                                         const int index)
+    {
+        instruction_ref scale_arg = args[index];
+        std::set<migraphx::shape::type_t> supported_dq_types = {migraphx::shape::float_type,
+                                                                migraphx::shape::half_type};
+
+        if(not(contains(supported_dq_types, scale_arg->get_shape().type())))
+        {
+            MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALDED: Scales must be float or half_type");
+        }
+
+        if(scale_arg->get_shape().scalar())
+        {
+            scale_arg = info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), scale_arg);
+        }
+
+        return scale_arg;
+    }
+
+    static instruction_ref set_scale_bias(const std::vector<instruction_ref>& args,
+                                          const int index,
+                                          const migraphx::shape& scale_arg_shape,
+                                          const instruction_ref& compare_arg,
+                                          bool& has_valid_scale_bias)
+    {
+        has_valid_scale_bias = false;
+
+        if(args.size() > index)
Review comment: I don't see an index defined in MatMulIntegerToFloat. Is this for some other operator? Thanks.
Reply: It's the argument index. We're doing the check here so it's done for every arg.
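For orientation, here is the argument layout this parser appears to assume for MatMulIntegerToFloat, inferred from the indices used elsewhere in this diff; the names are illustrative, not from the PR:

    // Hypothetical index map, matching the set_scale_arg / set_bias_arg / set_scale_bias calls below:
    // args[0] = A             args[1] = B              (uint8/int8 matrices)
    // args[2] = a_scale       args[3] = b_scale        (float/half)
    // args[4] = a_zero_point  args[5] = b_zero_point   (optional)
    // args[6] = bias          (optional, float/half)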
+        {
+            instruction_ref scale_bias_arg = args[index];
+            std::set<migraphx::shape::type_t> supported_dq_types = {migraphx::shape::float_type,
+                                                                    migraphx::shape::half_type};
+
+            if(not(contains(supported_dq_types, scale_bias_arg->get_shape().type())))
+            {
+                MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALDED: Bias must be float or half_type");
+            }
+
+            if(scale_bias_arg->get_shape().type() != scale_arg_shape.type())
+            {
+                MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias must be the same type as scales");
+            }
+
+            if(scale_bias_arg->get_shape().lens().at(0) != compare_arg->get_shape().lens().at(1))
+            {
+                MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias must have the same dim as matrix B column");
+            }
+
+            has_valid_scale_bias = true;
Review comment: As against
Reply: If the scale bias doesn't exist, then there is no bias added at the end of the MatMulIntegerToFloat.
+
+            return scale_bias_arg;
+        }
+        return compare_arg;
+    }
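A quick worked trace of the optional-bias path above, under the hypothetical seven-input layout sketched earlier: if a node supplies only six inputs, args.size() > 6 is false, so set_scale_bias(args, 6, ...) skips all validation, leaves has_valid_scale_bias false, and returns compare_arg unchanged; no bias subtraction is emitted later in the graph.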
+
     static instruction_ref set_bias_arg(const std::vector<instruction_ref>& args,
                                         const int index,
                                         const instruction_ref& input,
@@ -148,7 +206,109 @@
         }
     }

+    static void handle_scaled_transposes(const onnx_parser::node_info& info,
+                                         instruction_ref& scale_a0,
+                                         instruction_ref& zp_a0,
+                                         bool no_zp)
+    {
+        if(no_zp)
Review comment: It seems only
Reply: No, the second input, since it's bound by the column of the input vector (which is always 1-D). I broke this out instead of adding it inline to encapsulate the logic.
+        {
+            scale_a0 =
+                info.add_instruction(make_op("transpose", {{"permutation", {0, 1}}}), scale_a0);
+        }
+        else
+        {
+            scale_a0 =
+                info.add_instruction(make_op("transpose", {{"permutation", {0, 1}}}), scale_a0);
+            zp_a0 = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), zp_a0);
+        }
+    }
+
+    static instruction_ref handle_dequantized(const onnx_parser::node_info& info,
+                                              const instruction_ref& a0,
+                                              const instruction_ref& scale_a0,
+                                              const instruction_ref& zp_a0,
+                                              bool no_zp)
+    {
+        instruction_ref dequantized_op;
+
+        if(no_zp)
+        {
+            auto bc_scale_a0 = info.add_instruction(
+                make_op("multibroadcast", {{"out_lens", a0->get_shape().lens()}}), scale_a0);
+            dequantized_op = info.add_instruction(make_op("dequantizelinear"), a0, bc_scale_a0);
+        }
+        else
+        {
+            auto bc_scale_a0 = info.add_instruction(
+                make_op("multibroadcast", {{"out_lens", a0->get_shape().lens()}}), scale_a0);
+
+            auto bc_zp_a0 = info.add_instruction(
+                make_op("multibroadcast", {{"out_lens", a0->get_shape().lens()}}), zp_a0);
+
+            dequantized_op =
+                info.add_instruction(make_op("dequantizelinear"), a0, bc_scale_a0, bc_zp_a0);
+        }
+        return dequantized_op;
+    }
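As a refresher on what this helper builds (standard dequantizelinear semantics; the numbers are illustrative, not values from the PR):

    y = (x - zero_point) * scale
    e.g. x = -56 (int8), zero_point = -60, scale = 0.05f  ->  y = (-56 - (-60)) * 0.05f = 0.2f

When no zero point is provided, the no_zp branch omits that input, which is equivalent to a zero point of 0.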
+
+    static instruction_ref handle_scaled_output(const onnx_parser::node_info& info,
Review comment: Too many parameters. Ideally they should be handled by a
Reply: They're the same number of parameters gathered by the operator. These are all needed for the dequantize steps and for adding the proper unsqueeze->transpose paths. Order matters here with respect to matrix input A or B.
+                                                const instruction_ref& a0,
+                                                const instruction_ref& a1,
+                                                const instruction_ref& scale_a0,
+                                                const instruction_ref& scale_a1,
+                                                const instruction_ref& zp_a0,
+                                                const instruction_ref& zp_a1,
+                                                const instruction_ref& scaled_bias,
+                                                const bool has_scale_bias)
+    {
+        instruction_ref unsq_zp_a0;
+        instruction_ref unsq_zp_a1;
+
+        bool a0_has_no_zp = (a0 == zp_a0);
+        bool a1_has_no_zp = (a1 == zp_a1);
+
+        auto unsq_scale_a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), scale_a0);
+        if(not a0_has_no_zp)
+        {
+            unsq_zp_a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), zp_a0);
+            if(zp_a0->get_shape().scalar())
+            {
+                unsq_zp_a0 =
+                    info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), unsq_zp_a0);
+            }
+        }
+
+        if(not a1_has_no_zp)
+        {
+            unsq_zp_a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), zp_a1);
+            if(zp_a1->get_shape().scalar())
+            {
+                unsq_zp_a1 =
+                    info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), unsq_zp_a1);
+            }
+        }
+
+        auto dq_a0 = handle_dequantized(info, a0, unsq_scale_a0, unsq_zp_a0, a0_has_no_zp);
+
+        // Transpose second input to get column dims before we broadcast to dequantizelinear
+        auto unsq_scale_a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), scale_a1);
+        instruction_ref scale_a1_tp = unsq_scale_a1;
+        instruction_ref zp_a1_tp    = unsq_zp_a1;
+        handle_scaled_transposes(info, scale_a1_tp, zp_a1_tp, a1_has_no_zp);
+
+        auto dq_a1 = handle_dequantized(info, a1, scale_a1_tp, zp_a1_tp, a1_has_no_zp);
+        auto res   = info.add_instruction(make_op("dot"), dq_a0, dq_a1);
+
+        // Handle the case of the bias applied after scaling
+        if(has_scale_bias)
+            res = info.add_common_op("sub", res, scaled_bias);
+
+        return res;
+    }
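Putting the pieces together, my reading of the graph this helper emits (a paraphrase of the code above, not wording from the PR):

    dq_A = dequantizelinear(A, scale_A, zp_A)   // (A - zp_A) * scale_A, broadcast per handle_dequantized
    dq_B = dequantizelinear(B, scale_B, zp_B)   // B's scale/zero point are unsqueezed and transposed first
    res  = dot(dq_A, dq_B)
    res  = res - scaled_bias                    // only when has_scale_bias; note the op is "sub", not "add"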
+
     instruction_ref parse(const op_desc& opd,
                           const onnx_parser& /*parser*/,
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
@@ -173,12 +333,20 @@
         }

         auto is_quant_dot = opd.op_name == "quant_dot";
+        auto has_scales   = opd.op_name == "quant_dot_scaled";
Review comment: A little confusing naming convention. Between
Reply: What do you suggest I name it? quant_dot_dequant? This operator essentially takes in quantized input and dequantizes the output.
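For context, the table added in the first hunk maps the ONNX operator name to this internal name, so a node such as the following (a hypothetical example, not from the PR) is parsed with opd.op_name == "quant_dot_scaled":

    MatMulIntegerToFloat(A, B, a_scale, b_scale, a_zero_point, b_zero_point, bias) -> Y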

         if(s0.dynamic() or s1.dynamic())
         {
             if(is_quant_dot)
             {
                 MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
             }

+            if(has_scales)
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_MATMULINTEGERTOFLOAT: dynamic MatMulIntegerToFloat not supported");
+            }
+
             auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
             auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
@@ -200,23 +368,50 @@
         auto s0_lens = a0->get_shape().lens();
         auto s1_lens = a1->get_shape().lens();

-        if(not is_quant_dot and args.size() > 2)
+        if(not is_quant_dot and args.size() > 2 and not has_scales)
Review comment: Would it be simpler to just check if it is just a
Reply: Sure, that can easily be swapped.
         {
             MIGRAPHX_THROW("PARSE_MATMUL: Bias Args not supported for MatMul");
         }

         bool has_ba0 = false;
         bool has_ba1 = false;
-        instruction_ref ba0 = set_bias_arg(args, 2, a0, has_ba0);
-        instruction_ref ba1 = set_bias_arg(args, 3, a1, has_ba1);
+        bool has_scale_bias = false;
+
+        int a0_zp_index = 2;
+        int a1_zp_index = 3;
+
+        instruction_ref scale_a0;
+        instruction_ref scale_a1;
+        // Handles the case when scales are present in the operator
+        if(has_scales)
+        {
+            a0_zp_index = 4;
+            a1_zp_index = 5;
+            scale_a0    = set_scale_arg(info, args, 2);
+            scale_a1    = set_scale_arg(info, args, 3);
+            if(scale_a0->get_shape().type() != scale_a1->get_shape().type())
+            {
+                MIGRAPHX_THROW("PARSE_MATMULINTEGERTOFLOAT: Scales must be the same type");
+            }
+        }
+
+        instruction_ref ba0 = set_bias_arg(args, a0_zp_index, a0, has_ba0);
+        instruction_ref ba1 = set_bias_arg(args, a1_zp_index, a1, has_ba1);
+
+        // handle the optional bias arg applied to the result
+        instruction_ref scaled_bias;
+        if(has_scales)
+        {
+            scaled_bias = set_scale_bias(args, 6, scale_a1->get_shape(), a1, has_scale_bias);
+        }

         // Only INT8 or UINT8 type currently supported
         std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::uint8_type,
                                                              migraphx::shape::int8_type};
         const auto a0_type = a0->get_shape().type();
         const auto a1_type = a1->get_shape().type();

-        if(is_quant_dot and
+        if((is_quant_dot or has_scales) and
Review comment: If it
Reply: Sure, it's simple to just add op.name() here as part of the string. Both MatMulInteger and MatMulIntegerToFloat have the same error on this.
            (not contains(supported_types, a0_type) or not contains(supported_types, a1_type)))
         {
             MIGRAPHX_THROW("PARSE_MATMULINTEGER: Unsupported type");
@@ -254,7 +449,18 @@

             broadcast_dimensions(info, s0_lens, s1_lens, a0, a1, ba0, ba1);

-            dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
+            // Apply the scales to dequantize the inputs after the zero points are
+            // applied, then perform a simple dot; otherwise produce an int32 output
+            // from the quantized equivalent. Ensure these are broadcast accordingly
+            // before we perform the dot.
+            if(has_scales)
+            {
+                dot_res = handle_scaled_output(
+                    info, a0, a1, scale_a0, scale_a1, ba0, ba1, scaled_bias, has_scale_bias);
+            }
+            else
+            {
+                dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
+            }
         }

         // squeeze the appended or prepended dimensions
Review comment: SCALDED?
Review comment: Also, would this message be proper for the MatMul operator?
Reply: It won't reach there, as it's gated by whether the operator contains the scaled inputs. This variant of MatMul also includes the dequantize to convert the quantized input types to float.
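To make the gating in that last reply concrete, this restates the guards already present in the diff above (a paraphrase, not new code in the PR):

    // Plain MatMul ("dot") rejects extra args before any scale handling can run:
    if(not is_quant_dot and args.size() > 2 and not has_scales)
        MIGRAPHX_THROW("PARSE_MATMUL: Bias Args not supported for MatMul");

    // Scale and zero-point validation only executes for opd.op_name == "quant_dot_scaled":
    if(has_scales) { /* set_scale_arg, set_scale_bias, ... */ }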