
Zigzag attention support? #20

Closed
dwromero opened this issue Oct 23, 2024 · 20 comments

@dwromero

Hi @lucidrains ,

I hope you are doing well. And thank you for yet another useful repo! :)

I was wondering if you have any plans to support the zigzag version of ring attention. It seems to distribute compute better in autoregressive settings and is quite hot at the moment (zhuzilin/ring-flash-attention#2). I'd be happy to help with that if needed.

David

@lucidrains
Owner

hey David, no problem

could you link me to the paper?

did you see the rotation trick from Chris Fifty yet?

@dwromero
Author

Hi,

could you link me to the paper?

-> It's used in the Llama 3 paper (https://arxiv.org/abs/2407.21783), page 11, in the section on context parallelism. Though they don't actually use the form of ring attention implemented here, for GQA and attention-masking reasons.

did you see the rotation trick from Chris Fifty yet?

-> I have not. What is it about?

@lucidrains
Owner

check out the vq repo

nice! didn't even know Meta was using ring attention 🤣 I'll read the paper tomorrow

@lucidrains
Owner

guess all the big players will be using some form of sequence parallel attention soon (google, meta, and you at nvidia)

@lucidrains
Owner

@dwromero could i prompt you for a summary of what zigzag is? is it just another way to permute the sequence for better balancing?

@dwromero
Author

That's right
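
To make the balancing point concrete, here is a quick sketch (not from the thread; plain Python with made-up sizes): under a causal mask, naive contiguous sharding leaves the rank holding the earliest positions with far less work than the rank holding the latest ones, while pairing chunk r with chunk 2 * world_size - 1 - r equalizes the totals.

# rough illustration of zigzag balancing (hypothetical sizes)
seq_len, world_size = 16, 4
chunk = seq_len // (2 * world_size)

def causal_work(positions):
    # units of work for a rank under a causal mask: a query at global
    # position p attends to p + 1 keys
    return sum(p + 1 for p in positions)

naive = [list(range(r * seq_len // world_size, (r + 1) * seq_len // world_size))
         for r in range(world_size)]
zigzag = [list(range(r * chunk, (r + 1) * chunk))
          + list(range((2 * world_size - r - 1) * chunk, (2 * world_size - r) * chunk))
          for r in range(world_size)]

print([causal_work(p) for p in naive])   # [10, 26, 42, 58] -- skewed
print([causal_work(p) for p in zigzag])  # [34, 34, 34, 34] -- balanced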

@lucidrains
Owner

@dwromero ok, should be an easy add!

@dwromero
Author

🤟🤟🤟

@lucidrains
Owner

@dwromero oh, there is nothing to zigzag (did you coin that term?)

it is just an all gather for keys and values, with GQA as justification
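
For reference, a minimal sketch of that scheme (not from the thread; the function name is made up, and it assumes torch.distributed is already initialized with the sequence evenly sharded across ranks): each rank projects its own keys / values, all-gathers them, and attends with its local queries against the full key / value set. Under GQA the gathered keys / values are much smaller than the queries, which is the justification given.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def local_q_global_kv_attention(q, k, v):
    # q, k, v: (batch, heads, local_seq, dim) -- this rank's shard
    # gather keys / values from every rank so local queries can attend
    # over the full sequence; queries stay sharded
    world_size = dist.get_world_size()
    k_list = [torch.empty_like(k) for _ in range(world_size)]
    v_list = [torch.empty_like(v) for _ in range(world_size)]
    dist.all_gather(k_list, k)
    dist.all_gather(v_list, v)
    k_full = torch.cat(k_list, dim=-2)
    v_full = torch.cat(v_list, dim=-2)
    # causal masking over the sharded / permuted order is omitted here;
    # see the masking discussion further down the thread
    return F.scaled_dot_product_attention(q, k_full, v_full)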

@lucidrains
Owner

lucidrains commented Oct 24, 2024

[screenshot: Screen Shot 2024-10-24 at 7 24 35 AM]

yes i see, this is a greatly simplified version of what is here

@lucidrains
Owner

@dwromero let me break this project into two, where i first handle the permuting they do, then offer the all gather for the key / values, both configurable.

@lucidrains
Owner

@dwromero actually, maybe it should just be a separate self contained file given how different it is

@lucidrains lucidrains changed the title Zigzag ring attention support? Zigzag attention support? Oct 24, 2024
@dwromero
Author

I actually tried this with TransformerEngine and it works simply by splitting differently. Ran some tests and all seems to match. Do you think that would be sufficient here too?

Basically, using a split like:

import torch

def extract_local(value, rank, world_size, dim=1):
    # split the sequence into 2 * world_size chunks, then give each rank
    # one chunk from the front and its mirror chunk from the back, so
    # every rank's causal workload is balanced
    value_chunks = value.chunk(2 * world_size, dim=dim)
    local_value = torch.cat(
        [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
    )
    return local_value.contiguous()
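
For instance (a quick check, not from the thread), with world_size = 2 and 8 positions along dim = 1, each rank ends up with one chunk from the front and its mirror from the back:

# using extract_local as defined above
seq = torch.arange(8).unsqueeze(0)               # shape (1, 8)
print(extract_local(seq, rank=0, world_size=2))  # tensor([[0, 1, 6, 7]])
print(extract_local(seq, rank=1, world_size=2))  # tensor([[2, 3, 4, 5]])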

@lucidrains
Owner

@dwromero yea that works for sharding the sequence

but you'll need to handle the masking (maybe flex attention can come in handy here). and it seems like they project the key values on each rank separately then do the all gather?
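
One way to handle that masking, sketched below (not from the thread; both function names are made up, and the (b, h, q_idx, kv_idx) signature follows the mask_mod convention from torch.nn.attention.flex_attention): recover each permuted position's global index, then allow a query to attend to a key only if the key's global position is not later. A real distributed setup would also need per-rank offsets for local queries versus gathered keys, which is omitted here.

import torch

def zigzag_global_positions(seq_len, world_size):
    # global position of every token in the zigzag-permuted order:
    # rank r holds chunk r followed by chunk (2 * world_size - 1 - r)
    chunks = torch.arange(seq_len).chunk(2 * world_size)
    order = []
    for r in range(world_size):
        order += [chunks[r], chunks[2 * world_size - r - 1]]
    return torch.cat(order)  # index = permuted position, value = global position

pos = zigzag_global_positions(seq_len=16, world_size=4)

def causal_mask_mod(b, h, q_idx, kv_idx):
    # attend only to keys whose global position is not after the query's
    return pos[kv_idx] <= pos[q_idx]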

@lucidrains
Owner

lucidrains commented Oct 24, 2024

yea i don't know if i completely buy this. sure GQA can be enough savings that an all gather at 128k is fine, but how about 10 million? yea, this is definitely sequence parallelism in its crudest form, imo

@lucidrains lucidrains mentioned this issue Oct 24, 2024
@lucidrains
Owner

lucidrains commented Oct 24, 2024

@dwromero made a bit of progress in the linked PR but out of steam

will resume tomorrow morning

feel free to leave any comments for anything that doesn't look right

@lucidrains
Owner

@dwromero alright, think i can knock out the rest this morning

you still there?

@lucidrains
Owner

@dwromero think it is all there in 0.5.19, you can play around with it by running the assert_zig_zag.py test script

@dwromero
Author

Wow cool! Thank you so much @lucidrains ! 💪

@lucidrains
Owner

@dwromero no problem. if you can get me some nvidia cloud compute, i can throw in the flex attention logic. but not a big priority for now
