
now ring and flash attention work together, needed to scale sequence length
lucidrains committed Feb 19, 2024
1 parent 8c19b24 commit b632792
Showing 4 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -61,8 +61,9 @@ $ python assert.py
- [x] striped attention
- [x] add the permutating logic before and after transformer
- [x] add causal masking logic - account for sub bucketing by flash attention
- [x] fix issue with ring attention when flash buckets > 1

- [ ] fix autoregressive when there is greater than 1 flash attention tile per machine
- [ ] move flash attention back to key / value column traversal on outer loop and save on ring communication
- [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- [ ] find a machine with 8 GPUs and test with a quarter million tokens first
- [ ] figure out batch_isend_irecv
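The last unchecked item above, "figure out batch_isend_irecv", refers to torch.distributed's batched point-to-point API. As a hedged illustration only, and not code from this commit, a ring pass of key / value chunks built on torch.distributed.batch_isend_irecv could look roughly like the sketch below; the helper name ring_pass_kv is invented for the example, and it assumes the process group has already been initialized.

import torch
import torch.distributed as dist

def ring_pass_kv(k, v):
    # send the local k / v to the next rank, receive the previous rank's k / v
    rank, world_size = dist.get_rank(), dist.get_world_size()
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size

    recv_k, recv_v = torch.empty_like(k), torch.empty_like(v)

    ops = [
        dist.P2POp(dist.isend, k, send_to),
        dist.P2POp(dist.irecv, recv_k, recv_from),
        dist.P2POp(dist.isend, v, send_to),
        dist.P2POp(dist.irecv, recv_v, recv_from),
    ]

    # all sends / receives are posted together, then waited on as one batch
    for req in dist.batch_isend_irecv(ops):
        req.wait()

    return recv_k, recv_v

Each rank posts its sends and receives as one batch, so a ring step costs a single exchange with each neighbour.
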
10 changes: 5 additions & 5 deletions assert.py
@@ -40,7 +40,7 @@ def start(
ring_attn = True,
striped_ring_attn = striped_ring_attn,
ring_seq_size = ceil(seq_len / world_size),
bucket_size = ceil(seq_len / world_size),
bucket_size = ceil(seq_len / world_size / 2),
)

flash_attention_net = RingTransformer(
@@ -101,24 +101,24 @@ def start(
assert torch.allclose(
ring_embed_grad,
flash_embed_grad,
atol = 1e-3
atol = 1e-2
), 'grad is not the same'

print('✅ outputs and gradients are same between ring attention and non-ring attention')

cleanup()

if __name__ == '__main__':
world_size = 8
world_size = 2
batch_size = 2
batch_size_var_len = False
use_cuda = False
causal = False
striped_ring_attn = False
striped_ring_attn = True

assert not use_cuda or torch.cuda.device_count() <= world_size

seq_len = 32
seq_len = 3
dim = 8

mp.spawn(
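The halved bucket_size in the first hunk is the crux of the commit: with bucket_size = ceil(seq_len / world_size / 2), each device now runs two flash attention buckets over its ring chunk instead of one, which is the "flash buckets > 1" case ticked off in the README. A purely illustrative sketch of the sizing arithmetic, not part of the repository and with numbers chosen only for the example, is:

from math import ceil

def ring_sizes(seq_len, world_size, flash_buckets_per_machine = 2):
    # tokens each device is responsible for in the ring
    ring_seq_size = ceil(seq_len / world_size)
    # flash attention tile size within that ring chunk
    bucket_size = ceil(seq_len / world_size / flash_buckets_per_machine)
    return ring_seq_size, bucket_size

# e.g. an illustrative seq_len of 32 spread over world_size = 2 machines:
# each device holds 16 tokens and processes them as two buckets of 8
print(ring_sizes(32, 2))  # (16, 8)
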
4 changes: 4 additions & 0 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -148,6 +148,8 @@ def forward(
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)

k = one_ring_pass(k)
v = one_ring_pass(v)
mask = maybe(one_ring_pass)(mask)

oc.div_(row_sums)
@@ -238,6 +240,8 @@ def backward(ctx, do):
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)

k = one_ring_pass(k)
v = one_ring_pass(v)
mask = maybe(one_ring_pass)(mask)

dk = one_ring_pass(dk)
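The four added lines rotate k and v (alongside the optional mask) to the neighbouring device at the end of each inner step of both forward and backward, which is what lets the flash attention bucketing and the ring communication coexist. As a hedged, single-process sketch of the same pattern, and not the repository's implementation, the loop below rotates key / value chunks through a list in place of a real ring pass and folds each chunk into a numerically stable running softmax, analogous to the row_maxes / row_sums bookkeeping above:

import torch

def ring_attention_reference(q, k_chunks, v_chunks, scale):
    # q: (n, d); k_chunks / v_chunks: one (c, d) chunk per simulated device
    n, d = q.shape
    out = torch.zeros(n, d)
    row_maxes = torch.full((n, 1), float('-inf'))
    row_sums = torch.zeros(n, 1)

    for _ in range(len(k_chunks)):
        k, v = k_chunks[0], v_chunks[0]

        sim = (q @ k.t()) * scale                          # logits for this block
        block_maxes = sim.amax(dim = -1, keepdim = True)
        new_maxes = torch.maximum(row_maxes, block_maxes)

        exp_sim = (sim - new_maxes).exp()
        scale_old = (row_maxes - new_maxes).exp()          # rescale earlier blocks

        row_sums = row_sums * scale_old + exp_sim.sum(dim = -1, keepdim = True)
        out = out * scale_old + exp_sim @ v
        row_maxes = new_maxes

        # the "one ring pass": rotate so the next step sees the neighbour's k / v
        k_chunks = k_chunks[1:] + k_chunks[:1]
        v_chunks = v_chunks[1:] + v_chunks[:1]

    return out / row_sums

# sanity check against ordinary dense attention
q, k, v = (torch.randn(16, 8) for _ in range(3))
scale = 8 ** -0.5
dense = torch.softmax((q @ k.t()) * scale, dim = -1) @ v
ringed = ring_attention_reference(q, list(k.chunk(4)), list(v.chunk(4)), scale)
assert torch.allclose(dense, ringed, atol = 1e-5)
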
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
