
matmulnbits zero_point fix #3566

Open: wants to merge 3 commits into base: develop
Conversation

@lakhinderwalia (Contributor)

Currently the matmulnbits parsing hard-codes uint8_type for zero_point and misses int8_type.
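
A minimal sketch of the intended behavior (illustrative C++ only, not the actual MIGraphX parser code; the helper name and dtype enum are hypothetical):

```cpp
// Illustrative sketch only, not the MIGraphX parser API: the idea is to take
// the zero_point element type from the supplied initializer instead of
// hard-coding uint8, so the literal keeps the model's own element type.
enum class dtype { uint8, int8, int32, float16, float32 };

// Hypothetical helper: 'has_zero_point' and 'provided' stand in for whatever
// the real parser reads from the ONNX inputs.
dtype zero_point_type(bool has_zero_point, dtype provided)
{
    if(has_zero_point)
        return provided;  // follow the model instead of forcing uint8_type
    return dtype::uint8;  // default when no zero_point input is given
}
```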

@kahmed10 self-requested a review on October 29, 2024 at 15:09
@kahmed10 (Collaborator)

According to the spec (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits), zero_point can only be uint8/int32/float16/float?

@lakhinderwalia (Contributor, Author)

lakhinderwalia commented Oct 29, 2024

> According to the spec (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits), zero_point can only be uint8/int32/float16/float?

My guess is that if the zero point isn't specified, it can be inferred to be uint8 or int32. I should fix my test case and use int32 instead of int8.

@kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.

@TedThemistokleous (Collaborator)

> According to the spec (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits), zero_point can only be uint8/int32/float16/float?
>
> My guess is that if the zero point isn't specified, it can be inferred to be uint8 or int32. I should fix my test case and use int32 instead of int8.
>
> @kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.

I don't think we can assume it is handled. If things are type constrained, you'll need to add that for those inputs.

@lakhinderwalia (Contributor, Author)

> According to the spec (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits), zero_point can only be uint8/int32/float16/float?
>
> My guess is that if the zero point isn't specified, it can be inferred to be uint8 or int32. I should fix my test case and use int32 instead of int8.
>
> @kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.

> I don't think we can assume it is handled. If things are type constrained, you'll need to add that for those inputs.

Ok. Let me enhance this PR to include the basic type checking for this operator.
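
A minimal sketch of the kind of basic type check discussed above (hypothetical helper, not the actual MIGraphX parser code). It rejects a zero_point whose element type is outside the set allowed by the contrib-op spec quoted earlier (uint8/int32/float16/float):

```cpp
#include <set>
#include <stdexcept>

// Hypothetical stand-in for the parser's element-type enumeration.
enum class dtype { uint8, int8, int32, float16, float32 };

// Throw if the zero_point type is outside the set allowed by the spec.
void check_zero_point_type(dtype t)
{
    static const std::set<dtype> allowed = {
        dtype::uint8, dtype::int32, dtype::float16, dtype::float32};
    if(allowed.count(t) == 0)
        throw std::runtime_error("MatMulNBits: unsupported zero_point type");
}
```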

@lakhinderwalia changed the title from "matmulnbits zezo_point fix" to "matmulnbits zero_point fix" on Oct 29, 2024
@migraphx-bot (Collaborator)

| Test | Batch | Rate new (14084c) | Rate old (71fd27) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,259.22 | 3,258.47 | 0.02% | |
| torchvision-resnet50_fp16 | 64 | nan | 6,989.35 | nan% | |
| torchvision-densenet121 | 32 | 2,435.53 | 2,437.00 | -0.06% | |
| torchvision-densenet121_fp16 | 32 | nan | 4,070.38 | nan% | |
| torchvision-inceptionv3 | 32 | 1,638.28 | 1,639.85 | -0.10% | |
| torchvision-inceptionv3_fp16 | 32 | nan | 2,763.41 | nan% | |
| cadene-inceptionv4 | 16 | 776.15 | 776.50 | -0.04% | |
| cadene-resnext64x4 | 16 | 811.84 | 808.24 | 0.44% | |
| slim-mobilenet | 64 | 7,536.31 | 7,538.65 | -0.03% | |
| slim-nasnetalarge | 64 | 211.48 | 211.54 | -0.03% | |
| slim-resnet50v2 | 64 | nan | 3,507.21 | nan% | |
| bert-mrpc-onnx | 8 | 1,151.20 | 1,150.51 | 0.06% | |
| bert-mrpc-tf | 1 | 499.42 | 475.19 | 5.10% | 🔆 |
| pytorch-examples-wlang-gru | 1 | 476.40 | 426.61 | 11.67% | 🔆 |
| pytorch-examples-wlang-lstm | 1 | 382.40 | 376.26 | 1.63% | |
| torchvision-resnet50_1 | 1 | 782.58 | 785.19 | -0.33% | |
| cadene-dpn92_1 | 1 | 398.76 | 399.09 | -0.08% | |
| cadene-resnext101_1 | 1 | 383.83 | 383.01 | 0.21% | |
| onnx-taau-downsample | 1 | nan | 343.03 | nan% | |
| dlrm-criteoterabyte | 1 | 33.34 | 33.33 | 0.05% | |
| dlrm-criteoterabyte_fp16 | 1 | 52.71 | 52.73 | -0.04% | |
| agentmodel | 1 | 8,200.96 | 8,178.96 | 0.27% | |
| unet_fp16 | 2 | nan | 58.92 | nan% | |
| resnet50v1_fp16 | 1 | nan | 925.30 | nan% | |
| resnet50v1_int8 | 1 | nan | 1,011.55 | nan% | |
| bert_base_cased_fp16 | 64 | nan | 1,169.93 | nan% | |
| bert_large_uncased_fp16 | 32 | nan | 363.31 | nan% | |
| bert_large_fp16 | 1 | nan | 200.50 | nan% | |
| distilgpt2_fp16 | 16 | nan | 2,194.69 | nan% | |
| yolov5s | 1 | 545.17 | 533.06 | 2.27% | |
| tinyllama | 1 | nan | 43.45 | nan% | |
| vicuna-fastchat | 1 | 170.69 | 172.29 | -0.92% | |
| whisper-tiny-encoder | 1 | 417.98 | 417.95 | 0.01% | |
| whisper-tiny-decoder | 1 | nan | 425.65 | nan% | |

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
❌ bert_base_cased_fp16: ERROR - check error output
❌ bert_large_uncased_fp16: ERROR - check error output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
❌ tinyllama: ERROR - check error output
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
❌ whisper-tiny-decoder: ERROR - check error output
❌ distilgpt2_fp16: ERROR - check error output

Labels: None yet
Projects: None yet
5 participants