rotary positions under all circumstances complete
lucidrains committed Feb 20, 2024
1 parent eb9a1de commit bafbe46
Showing 4 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -65,8 +65,8 @@ $ python assert.py
 - [x] move flash attention back to key / value column traversal on outer loop and save on ring communication
 - [x] backwards
 - [x] forwards
+- [x] fix rotary positions for striped ring attention when flash buckets > 1
 
-- [ ] fix rotary positions for striped ring attention when flash buckets > 1
 - [ ] option to auto-decide ring sequence size based on world size
 - [ ] allow for finely specifying how to distribute sharding of batch and sequence, depending on world size
 - [ ] allow for variable ring passes per layer, for <a href="https://arxiv.org/abs/2007.03356">local -> global attention</a> in ring transformer as one goes up the layers.
4 changes: 2 additions & 2 deletions assert.py
@@ -40,7 +40,7 @@ def start(
         ring_attn = True,
         striped_ring_attn = striped_ring_attn,
         ring_seq_size = ceil(seq_len / world_size),
-        bucket_size = ceil(seq_len / world_size),
+        bucket_size = ceil(seq_len / world_size / 2),
     )
 
     flash_attention_net = RingTransformer(
@@ -114,7 +114,7 @@ def start(
     batch_size_var_len = False
     use_cuda = False
     causal = True
-    striped_ring_attn = True
+    striped_ring_attn = False
 
     assert not use_cuda or torch.cuda.device_count() <= world_size
 
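Halving bucket_size relative to ring_seq_size is what puts the test into the flash buckets > 1 regime this commit targets. A minimal sketch of the arithmetic, with seq_len and world_size chosen only for illustration (they are not taken from the test file):

from math import ceil

# hypothetical test dimensions, chosen only to illustrate the bucket count
seq_len = 256
world_size = 8

ring_seq_size = ceil(seq_len / world_size)      # 32 tokens held by each rank
bucket_size = ceil(seq_len / world_size / 2)    # 16 tokens per flash attention bucket

buckets = ring_seq_size // bucket_size          # 2 flash buckets per rank, i.e. buckets > 1
print(ring_seq_size, bucket_size, buckets)      # 32 16 2

With the old bucket_size equal to ring_seq_size, each rank held exactly one bucket, so the buckets > 1 rotary path was never exercised by the assert script.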
29 changes: 18 additions & 11 deletions ring_attention_pytorch/ring_attention.py
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 
-from einops import rearrange
+from einops import rearrange, repeat
 
 from ring_attention_pytorch.ring import (
     all_ring_pass,
@@ -375,20 +375,27 @@ def forward(
         # rotary positions
         # taking into account ring and striping
 
-        maybe_chunk_seq_len = x.shape[-1]
-
-        pos = torch.arange(maybe_chunk_seq_len, device = device)
+        pos = None
+        curr_seq_len = x.shape[-1]
 
         if auto_shard_seq:
             if self.striped_ring_attn:
-                ring_stride = get_world_size()
-                ring_offset = 1
-            else:
-                ring_stride = 1
-                ring_offset = maybe_chunk_seq_len
+                buckets = self.ring_seq_size // self.bucket_size
+                ring_stride = get_world_size() * buckets
+                ring_offset = buckets
+
+                pos = torch.arange(curr_seq_len // buckets, device = device)
+                pos = repeat(pos, 'n -> n b', b = buckets)
 
-            pos *= ring_stride
-            pos += ring_offset * get_rank()
+                pos = pos * ring_stride
+                pos += torch.arange(buckets, device = device) + (get_rank() * buckets)
+                pos = rearrange(pos, 'n b -> (b n)')
+
+            else:
+                pos = torch.arange(curr_seq_len, device = device)
+                pos += curr_seq_len * get_rank()
+        else:
+            pos = torch.arange(curr_seq_len, device = device)
 
         rotary_emb = self.rotary_emb(pos)
 
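To make the new striped branch easier to follow, here is a standalone sketch of the position indices it produces for a single rank. The concrete numbers (world_size, rank, ring_seq_size, bucket_size) are made up for illustration, and plain integers stand in for get_world_size() / get_rank():

import torch
from einops import rearrange, repeat

# hypothetical setup: 2 ranks, 2 flash buckets of 4 tokens each per rank
world_size = 2
rank = 1
ring_seq_size = 8
bucket_size = 4
curr_seq_len = ring_seq_size

buckets = ring_seq_size // bucket_size             # 2
ring_stride = world_size * buckets                 # 4: consecutive stripes are this far apart globally

pos = torch.arange(curr_seq_len // buckets)        # positions within one bucket: 0..3
pos = repeat(pos, 'n -> n b', b = buckets)         # one column per flash bucket

pos = pos * ring_stride                            # stretch by the global stripe stride
pos += torch.arange(buckets) + (rank * buckets)    # offset each bucket by its global stripe index
pos = rearrange(pos, 'n b -> (b n)')               # lay the buckets out contiguously

print(pos)   # tensor([ 2,  6, 10, 14,  3,  7, 11, 15])

Rank 0 would get tensor([ 0,  4,  8, 12,  1,  5,  9, 13]); together the two ranks cover positions 0..15 exactly once, which is presumably the property the rotary fix is after when striping with more than one flash bucket per rank.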
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.11',
+  version = '0.1.12',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
