
make the assumption of dim keys == dim values and ring pass in one communication
lucidrains committed Feb 21, 2024
1 parent bafbe46 commit 29f76f1
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -72,9 +72,9 @@ $ python assert.py
 - [ ] allow for variable ring passes per layer, for <a href="https://arxiv.org/abs/2007.03356">local -> global attention</a> in ring transformer as one goes up the layers.
 - [ ] also allow for mixed striped layers, so one can do local strided attention (used in some image transformers but not language) would be a good fit for the initial local attention layers, in addition to classic lookback local
 - [ ] find a machine with 8 GPUs and test with a quarter million tokens first
-- [ ] figure out batch_isend_irecv
 - [ ] think about how to craft a special `Dataset` that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
 - [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
+- [ ] `batch_isend_irecv` in the presence of key padding mask needing ring exchange, but not a big priority

 ## Citations

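One of the todo items above concerns `batch_isend_irecv`. For reference only (this is not code from the repository), a single ring exchange can be written with `torch.distributed.batch_isend_irecv` roughly as below. The helper name `ring_pass` is hypothetical, and the sketch assumes the default process group is already initialized, with every rank holding a tensor of the same shape and dtype.

```python
import torch
import torch.distributed as dist

def ring_pass(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch, not the repository's implementation: send this rank's
    # tensor to the next rank and receive the previous rank's tensor, batching
    # both point-to-point ops into one batch_isend_irecv call.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size

    received = torch.empty_like(t)

    ops = [
        dist.P2POp(dist.isend, t, send_to),
        dist.P2POp(dist.irecv, received, recv_from),
    ]

    for req in dist.batch_isend_irecv(ops):
        req.wait()

    return received
```

A key padding mask that also has to travel around the ring could presumably be appended as another send/recv pair in the same `ops` list, which seems to be what the todo item refers to.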
19 changes: 15 additions & 4 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -91,7 +91,11 @@ def forward(

         num_tiles = math.ceil(per_machine_seq_size / bucket_size)

-        for ring_rank, (k, v, mask) in ring_pass_fn(k, v, mask, max_iters = max_ring_passes):
+        kv = torch.stack((k, v))
+
+        for ring_rank, (kv, mask) in ring_pass_fn(kv, mask, max_iters = max_ring_passes):
+
+            k, v = kv

             col_splits = zip(
                 k.split(bucket_size, dim = -2),
@@ -209,7 +213,11 @@ def backward(ctx, do):
             dq.split(bucket_size, dim = -2)
         )

-        for ring_rank, (k, v, mask, dk, dv) in ring_pass_fn(k, v, mask, dk, dv, max_iters = max_ring_passes):
+        kv_and_dkv = torch.stack((k, v, dk, dv))
+
+        for ring_rank, (kv_and_dkv, mask) in ring_pass_fn(kv_and_dkv, mask, max_iters = max_ring_passes):
+
+            k, v, dk, dv = kv_and_dkv

             col_splits = zip(
                 k.split(bucket_size, dim = -2),
@@ -257,8 +265,11 @@ def backward(ctx, do):
                     dkc.add_(dk_chunk)
                     dvc.add_(dv_chunk)

-        dk = one_ring_pass(dk)
-        dv = one_ring_pass(dv)
+        dkv = kv_and_dkv[2:]
+
+        dkv = one_ring_pass(dkv)
+
+        dk, dv = dkv

         return dq, dk, dv, None, None, None, None, None, None, None

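The change above leans on the assumption named in the commit message: since keys and values share the same dimension, `k` and `v` can be stacked along a new leading dimension and moved around the ring as one tensor per step instead of two, and the backward pass stacks `k`, `v`, `dk`, `dv` the same way. Below is a minimal sketch of that pattern; it reuses the hypothetical `ring_pass` helper from the earlier sketch (passed in as a callable), not the repository's actual `ring_pass_fn` / `one_ring_pass`.

```python
import torch
from typing import Callable

def exchange_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    ring_pass: Callable[[torch.Tensor], torch.Tensor],  # e.g. the sketch above
):
    # Sketch of the pattern in this commit: stack once, do a single ring pass,
    # then unpack on arrival. Requires k and v to have identical shapes.
    kv = torch.stack((k, v))   # shape (2, *k.shape)
    kv = ring_pass(kv)         # one communication instead of two
    k, v = kv                  # iterating a tensor unbinds along dim 0
    return k, v

# In the backward pass the same idea covers four tensors:
#   kv_and_dkv = torch.stack((k, v, dk, dv))
# and the gradient half can keep circulating on its own as kv_and_dkv[2:].
```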
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.12',
+  version = '0.1.14',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
