Numerical errors in backward #42

Open
grimulkan opened this issue Jun 27, 2024 · 4 comments

Comments

@grimulkan

Were you able to find out the reason for the small numerical errors in the backward pass with ring flash attention?

I found that the errors increase as you increase the world size, so it does seem to be related to flash attention returning 16-bit tensors: even though we accumulate in a 32-bit buffer, it seems that is not enough.

Maybe it would be an easy PR to flash attention to have it return raw fp32, or to do the accumulation upstream?
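
For context, here is a minimal sketch of the per-step out/lse merge a ring forward has to do (illustrative only; the function name and assumed shapes are mine, not the repo's code). The point is that `block_out` has already been rounded to 16 bits by the flash attention kernel before the fp32 accumulation ever sees it:

```python
import torch

def merge_block(out_acc, lse_acc, block_out, block_lse):
    # out_acc: (B, S, H, D) fp32 running output; lse_acc: (B, S, H, 1) fp32 running log-sum-exp.
    # block_out / block_lse: what flash attention returns for one ring step, with block_lse
    # assumed already reshaped to broadcast against out. block_out arrives in fp16/bf16,
    # so precision is lost before this fp32 merge can help.
    block_out = block_out.float()
    block_lse = block_lse.float()
    new_lse = torch.logaddexp(lse_acc, block_lse)
    out_acc = torch.exp(lse_acc - new_lse) * out_acc + torch.exp(block_lse - new_lse) * block_out
    return out_acc, new_lse
```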

@takfate

takfate commented Sep 27, 2024

Do you have any idea how to implement it?

@zhuzilin
Owner

tbh, not really... as far as I can tell, it won't be a simple PR with just a few changes to flash attention.

@grimulkan
Author

With Llama 405B there are many layers, and with ring sizes of 4 or 8 the numerical errors become catastrophic in the backward pass. The errors actually originate in the forward pass, in out & lse, but the backward grads can blow up, and it is deceptive because the forward loss value can still look quite reasonable.

A quick workaround is to use the context parallel (llama3) implementation in this repo for the forward pass alone. It is much more numerically stable, and if you have NVLink it is quite communication-efficient (if you don't, there's a penalty, but you could try overlapping compute & comm over the head strides).

The downsides of this repo's llama3 implementation are that it doesn't support zigzag and doesn't have an explicit non-varlen version, but those are actually easy to address.

The memory overhead is not that bad: you no longer need the 32-bit buffers used in the ring forward, and you can minimize the overhead by setting head stride = 1.

The backward pass can remain the usual zigzag ring implementation. This combination allowed me to scale the ring size even with very large models without large numerical errors.
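
To show what I mean by an all-gather forward over head strides, here is a rough sketch. The function name, defaults, and the non-causal simplification are mine, not the repo's API; the repo's llama3 varlen code additionally handles the causal offset of the local query block and GQA.

```python
import torch
import torch.distributed as dist
from flash_attn import flash_attn_func

def allgather_forward(q, k, v, group=None, head_stride=1):
    # q, k, v: (batch, local_seq, heads, head_dim), equal q/kv head counts assumed.
    # Gather the full-sequence K/V for `head_stride` heads at a time and run a single
    # flash attention call per stride: there is no per-ring-step out/lse merge, so
    # partial outputs are never rounded to 16 bits and re-accumulated.
    world = dist.get_world_size(group)
    batch, seq, heads, dim = k.shape
    outs = []
    for h0 in range(0, heads, head_stride):
        h1 = min(h0 + head_stride, heads)
        k_slice = k[:, :, h0:h1].contiguous()
        v_slice = v[:, :, h0:h1].contiguous()
        k_full = torch.empty((world, *k_slice.shape), dtype=k.dtype, device=k.device)
        v_full = torch.empty((world, *v_slice.shape), dtype=v.dtype, device=v.device)
        dist.all_gather_into_tensor(k_full, k_slice, group=group)
        dist.all_gather_into_tensor(v_full, v_slice, group=group)
        # (world, B, S, h, D) -> (B, world * S, h, D), chunks concatenated in rank order
        k_full = k_full.movedim(0, 1).reshape(batch, world * seq, h1 - h0, dim)
        v_full = v_full.movedim(0, 1).reshape(batch, world * seq, h1 - h0, dim)
        # non-causal for simplicity; causal needs the local query block's position offset
        outs.append(flash_attn_func(q[:, :, h0:h1].contiguous(), k_full, v_full, causal=False))
    return torch.cat(outs, dim=2)
```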

An alternative solution would be to replace the forward pass with this Triton version, which supports rings and basically does what this issue asks for (no external out/lse accumulation): https://github.com/lucidrains/ring-attention-pytorch/blob/main/ring_attention_pytorch/triton_flash_attn.py
But some modifications would be needed, this version only supports striped attention (which is a bit worse than this repo's zigzag implementation), and I am not sure it actually takes advantage of GQA (it just replicates KV heads in VRAM).

@zhuzilin
Owner

Thank you for this suggestion, I'll try to implement a zigzag version of the llama3 forward when I have a moment :).
