
Sequence length 1 GEMV alternative for fused attention #779

Open
cloudhan opened this issue Jun 30, 2023 · 1 comment
Labels
enhancement New feature or request

Comments

@cloudhan (Contributor)

Sequence length 1 is extremely important for decoding (ASR, text generation, etc.).

In onnxruntime, we found that rocblas gemm + softmax kernel + rocblas gemm is much faster for this case:

> KERNEL_EXPLORER_BUILD_DIR=./ python ../../onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py float16 2 1 1500 6 64 0 0 --scale 0.125
 27.26 us  0.17 tflops float16 B=2 S=1 T=1500 N=6 H=64, Generic   # <------------- this is rocblas gemm + softmax kernel + rocblas gemm
187.71 us  0.02 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 64, 32, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 256, 128, 32, 8, 8, 256, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 256, 128, 32, 8, 8, 256, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 256, 32, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 32, 8, 8, 64, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 32, 8, 8, 64, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 64, 8, 8, 64, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 64, 8, 8, 64, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
185.33 us  0.03 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
185.19 us  0.03 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
187.60 us  0.02 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 64, 32, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
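
For scale, a quick back-of-the-envelope on the timings above (the Generic decomposed path versus the best fused CK instance):

# Hedged arithmetic on the numbers reported in the run above.
generic_us = 27.26      # rocblas gemm + softmax kernel + rocblas gemm
best_fused_us = 185.19  # fastest DeviceBatchedGemmSoftmaxGemmPermute instance
print(f"decomposed path is ~{best_fused_us / generic_us:.1f}x faster")  # ~6.8x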

The shapes for the run above are as follows:

a_gs_ms_ks: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
b0_gs_ns_ks: dim, 4, lengths {2, 6, 1500, 64}, strides {576000, 96000, 64, 1}
b1_gs_os_ns: dim, 4, lengths {2, 6, 64, 1500}, strides {576000, 96000, 1, 64}
c_gs_ms_os: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
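
To make the structural point concrete, here is a minimal NumPy sketch of the computation with these shapes (semantics assumed to be standard scaled-dot-product attention; the permuted memory layouts encoded by the strides above are ignored, and float32 is used for simplicity). With M = 1, both matrix products collapse into batched matrix-vector products, which is exactly why a GEMV-shaped kernel fits better than a GEMM tile:

import numpy as np

G0, G1, M, N, K, O = 2, 6, 1, 1500, 64, 64  # batch, heads, seq len, total len, head dims
scale = 0.125

q  = np.random.randn(G0, G1, M, K).astype(np.float32)  # a_gs_ms_ks
k  = np.random.randn(G0, G1, N, K).astype(np.float32)  # b0_gs_ns_ks
vt = np.random.randn(G0, G1, O, N).astype(np.float32)  # b1_gs_os_ns

s = scale * (q @ k.transpose(0, 1, 3, 2))   # [2, 6, 1, 1500]; with M=1 this is a batched GEMV
p = np.exp(s - s.max(axis=-1, keepdims=True))
p = p / p.sum(axis=-1, keepdims=True)       # softmax over N
c = p @ vt.transpose(0, 1, 3, 2)            # [2, 6, 1, 64]; also a batched GEMV
assert c.shape == (G0, G1, M, O)            # c_gs_ms_os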

Another case is:

a_gs_ms_ks: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
b0_gs_ns_ks: dim, 4, lengths {2, 6, 21, 64}, strides {49152, 8192, 64, 1}  <---- 21 will increase during decoding 
b1_gs_os_ns: dim, 4, lengths {2, 6, 64, 21}, strides {49152, 8192, 1, 64}  <---- 21 will increase during decoding 
c_gs_ms_os: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}

It seems the current fused attention pads the matrices and calls into the tensor cores in any case, hence wasting computing power for small sequence lengths. We might need a DeviceBatchedGemvSoftmaxGemvPermute variant for this case.
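
A rough sketch of where that compute goes, assuming the second template argument in the instance names above is MPerBlock and that MNKOPadding pads M up to one full tile (both read off the instance names, not verified against the CK source):

# Hypothetical tile-utilization estimate for the first GEMM at M = 1.
M = 1              # query sequence length during decoding
m_per_block = 128  # assumed MPerBlock of the fastest fused instance above
print(f"useful rows per M tile: {M / m_per_block:.2%}")  # ~0.78%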

@zjing14 added the enhancement (New feature or request) label on Jun 30, 2023
@ppanchad-amd

@cloudhan Apologies for the lack of response. Can you please check whether this is still an issue with the latest ROCm 6.2? If not, please close the ticket. Thanks!
