Parser changes to handle MatMulIntegerToFloat #3445
base: develop
Conversation
TODO:
Codecov Report
Attention: Patch coverage is

@@            Coverage Diff             @@
##           develop    #3445     +/-   ##
===========================================
- Coverage    92.17%   92.17%    -0.01%
===========================================
  Files          512      512
  Lines        21387    21459       +72
===========================================
+ Hits         19714    19779       +65
- Misses        1673     1680        +7
Updated parser to handle bias case as well as bad scale conditions
Initial float/half tests
bad scale tests
bad bias tests
avoid tidy screaming about complexity
Force-pushed from 74f8ae0 to cdb307d
Use dequantizelinear, which eliminates the need to add in shifts due to int8/uint8 mismatches; still needs parser tests
if(not(contains(supported_dq_types, scale_arg->get_shape().type())))
{
    MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALDED: Scales must be float or half_type");
SCALDED?
Also, would this message be proper for the MatMul operator?
It won't reach there, as it's gated by whether the operator contains the scaled inputs. This variant of MatMul also includes the dequantize to convert the quantized input types to float.
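For background (a paraphrase of the scaled path, not text from this PR): dequantization computes (x - zero_point) * scale, so the scale's element type is what fixes the floating-point output type, hence the float/half restriction in the check above. A minimal standalone sketch with hypothetical names:

#include <cstdint>
#include <vector>

// Sketch only (not the MIGraphX parser): the scale's float type becomes the
// element type of the dequantized output.
std::vector<float> dequantize(const std::vector<std::int8_t>& x, float scale, std::int8_t zero_point)
{
    std::vector<float> out;
    out.reserve(x.size());
    for(auto v : x)
        out.push_back((static_cast<float>(v) - static_cast<float>(zero_point)) * scale);
    return out;
}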
{
    has_valid_scale_bias = false;

if(args.size() > index)
I don't see an index defined in MatMulIntegerToFloat. Is this for some other operator? Thanks.
It's the argument index. We're doing the check here so it's done for every arg.
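As a sketch of the pattern being described (a hypothetical helper, not code from this PR): every optional trailing ONNX input is fetched with the same bounds check.

#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper illustrating the per-argument guard: returns the
// optional input at `index` only if the node actually supplied it.
template <class T>
std::optional<T> optional_arg(const std::vector<T>& args, std::size_t index)
{
    if(args.size() > index)
        return args[index];
    return std::nullopt;
}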
MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias have same dim as matrix B column"); | ||
} | ||
|
||
has_valid_scale_bias = true; |
As against invalid? ;-)
If the scale bias doesn't exist, there isn't a bias added at the end of the MatMulIntegerToFloat.
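In other words (a standalone sketch of that control flow with made-up names, not parser code): the flag decides whether the trailing bias add is emitted at all.

#include <optional>

// Sketch: when no valid scale bias was parsed, the result is returned as-is
// and no trailing add is inserted.
float finish(float scaled_dot, std::optional<float> bias)
{
    if(bias.has_value()) // corresponds to has_valid_scale_bias == true
        return scaled_dot + *bias;
    return scaled_dot;
}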
instruction_ref& zp_a0, | ||
bool no_zp) | ||
{ | ||
if(no_zp) |
It seems only zp_a0 needs to be in the if clause?
No, the second input too, as it's bound by the column of the input vector (which is always 1-D). I broke this out instead of adding it inline to encapsulate the logic.
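A plain C++ sketch of that point (made-up names, not MIGraphX code): a 1-D zero point for B shifts each column j before the float matmul, so it broadcasts along B's column dimension.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: subtract a per-column zero point from a k x n row-major int8 matrix.
std::vector<float> shift_columns(const std::vector<std::int8_t>& b,
                                 std::size_t k,
                                 std::size_t n,
                                 const std::vector<std::int8_t>& zp_b) // size n
{
    std::vector<float> out(k * n);
    for(std::size_t i = 0; i < k; ++i)
        for(std::size_t j = 0; j < n; ++j)
            out[i * n + j] = static_cast<float>(b[i * n + j]) - static_cast<float>(zp_b[j]);
    return out;
}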
    return dequantized_op;
}

static instruction_ref handle_scaled_output(const onnx_parser::node_info& info,
Too many parameters. Ideally they should be handled by a struct parameter.
They're the same number of parameters gathered by the operator. These are all needed for the dequantize steps and for adding the proper unsqueeze->transpose paths. Order matters here with respect to matrix input A or B.
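A plain sketch of why operand order matters for those paths (made-up names, not the parser): a 1-D scale for A applies per row, while a 1-D scale for B applies per column, so the two operands need differently oriented unsqueezes/broadcasts.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: dequantize an m x k row-major matrix A with a per-row scale; for B,
// the same scale vector would instead be indexed by the column.
std::vector<float> dequant_rows(const std::vector<std::int8_t>& a,
                                std::size_t m,
                                std::size_t k,
                                const std::vector<float>& row_scale) // size m
{
    std::vector<float> out(m * k);
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < k; ++j)
            out[i * k + j] = static_cast<float>(a[i * k + j]) * row_scale[i];
    return out;
}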
@@ -173,12 +333,20 @@ struct parse_matmul : op_parser<parse_matmul>
}

auto is_quant_dot = opd.op_name == "quant_dot";
auto has_scales   = opd.op_name == "quant_dot_scaled";
A little confusing naming convention between the quant_dots and MatMul. And then there is has_scales, which is presumably also a quant_dot.
What do you suggest I name it? quant_dot_dequant? This operator essentially takes in quantized input and dequantizes the output.
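As a scalar sketch of that contract (made-up values; a paraphrase, not spec text): both quantized operands are dequantized and the multiply happens in float.

#include <cstdint>
#include <iostream>

int main()
{
    std::int8_t a = 12, b = -3;      // quantized inputs
    float a_scale = 0.5f, b_scale = 0.25f;
    std::int8_t a_zp = 2, b_zp = -1; // zero points
    float deq_a = (a - a_zp) * a_scale; // (12 - 2) * 0.5  =  5.0
    float deq_b = (b - b_zp) * b_scale; // (-3 + 1) * 0.25 = -0.5
    std::cout << deq_a * deq_b << "\n"; // float result: -2.5
    return 0;
}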
@@ -200,23 +368,50 @@ struct parse_matmul : op_parser<parse_matmul>
auto s0_lens = a0->get_shape().lens();
auto s1_lens = a1->get_shape().lens();

-if(not is_quant_dot and args.size() > 2)
+if(not is_quant_dot and args.size() > 2 and not has_scales)
Would it be simpler to just check if it is just a dot, instead of looking for quant_dot and quant_dot_scaled, as this clause seems to be doing? Thanks.
Sure, that can easily be swapped.
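A sketch of the suggested swap (hypothetical; assumes a plain MatMul maps to the op name "dot"):

// Hypothetical: test the one plain case instead of excluding both quantized
// variants.
auto is_dot = opd.op_name == "dot";
if(is_dot and args.size() > 2)
{
    // ... existing bias handling ...
}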
// Only INT8 or UINT8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::uint8_type,
                                                     migraphx::shape::int8_type};
const auto a0_type = a0->get_shape().type();
const auto a1_type = a1->get_shape().type();

-if(is_quant_dot and
+if((is_quant_dot or has_scales) and
If it has_scales, then it is perhaps not a MATMULINTEGER, as shown in the exception message.
Sure, simple to just add op.name() here as part of the string. Both MatMulInteger and MatMulIntegerToFloat have the same error on this.
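A sketch of what that could look like (assumes opd.op_name is in scope, as in the surrounding parser code):

// Hypothetical message: prefixing the parsed op's name lets MatMulInteger and
// MatMulIntegerToFloat raise distinct errors from this shared check.
MIGRAPHX_THROW(opd.op_name + ": only INT8/UINT8 input types are supported");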
Check results before merge 🔆
🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
Changes to the MatMul parser to handle the Microsoft Contrib operator MatMulIntegerToFloat.
Since we have the scales and zero points in our operands, we can just perform a multiply after the int8 biases are added, and then insert a regular dot on the scaled input values, which should give the same output as computing in the input data types.
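A standalone sketch of why that decomposition is equivalent (plain C++ with made-up values, not MIGraphX code): the per-tensor scales factor out of the accumulation, so dequantize-then-dot matches scaling the integer matmul.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // 2x2 row-major int8 operands with per-tensor scale/zero point
    std::vector<std::int8_t> a = {1, 2, 3, 4}, b = {5, 6, 7, 8};
    float a_scale = 0.5f, b_scale = 0.25f;
    std::int8_t a_zp = 1, b_zp = 2;

    for(std::size_t i = 0; i < 2; ++i)
        for(std::size_t j = 0; j < 2; ++j)
        {
            std::int32_t acc = 0;    // integer path: shifted int8 accumulation
            float facc       = 0.0f; // float path: dot of dequantized values
            for(std::size_t k = 0; k < 2; ++k)
            {
                std::int32_t av = a[i * 2 + k] - a_zp;
                std::int32_t bv = b[k * 2 + j] - b_zp;
                acc += av * bv;
                facc += (av * a_scale) * (bv * b_scale);
            }
            // both paths print the same value for every output element
            std::cout << acc * a_scale * b_scale << " == " << facc << "\n";
        }
    return 0;
}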
Able to leverage the existing set of tests for matmul
Needs #3526, as there's a bug with dequantizelinear that this has uncovered.