
Commit
facepalm yet again, got it backwards. autoregressive ring attention now works
lucidrains committed Feb 18, 2024
1 parent 16fd91d commit 51ec491
Showing 4 changed files with 6 additions and 6 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -55,7 +55,7 @@ $ python assert.py

  - [ ] testing
  - [x] make sure key padding mask works
- - [ ] make sure causal mask works
+ - [x] make sure causal mask works
  - [ ] option to auto-decide ring sequence size based on world size
  - [ ] rotary embeddings, with proper key/value offset depending on ring rank
assert.py (2 changes: 1 addition & 1 deletion)
@@ -114,7 +114,7 @@ def start(
     batch_size = 1
     batch_size_var_len = False
     use_cuda = False
-    causal = False
+    causal = True
 
     assert not use_cuda or torch.cuda.device_count() <= world_size

ring_attention_pytorch/ring.py (6 changes: 3 additions & 3 deletions)
@@ -71,20 +71,20 @@ def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank):
     dist.barrier()
 
 class OneRingPass(Function):
-    """ one ring pass to the right - assume tensor is all same shape for now """
+    """ one ring pass to the left and receive from the right - assume tensor is all same shape for now """
 
     @staticmethod
     def forward(ctx, x):
         x = x.contiguous()
         receive_buffer = torch.zeros_like(x)
-        send_and_receive_(x, receive_buffer, circular_rank_right(), circular_rank_left())
+        send_and_receive_(x, receive_buffer, circular_rank_left(), circular_rank_right())
         return receive_buffer
 
     @staticmethod
     def backward(ctx, grads):
         grads = grads.contiguous()
         receive_buffer = torch.zeros_like(grads)
-        send_and_receive_(grads, receive_buffer, circular_rank_left(), circular_rank_right())
+        send_and_receive_(grads, receive_buffer, circular_rank_right(), circular_rank_left())
         return receive_buffer
 
 one_ring_pass = OneRingPass.apply
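
Why the swap fixes things: in `forward`, each rank now sends its block one hop to the left around the ring and receives its right neighbor's block; a gradient must retrace that hop in reverse to reach the rank that produced the block, which is why `backward` swaps the two ranks back. Below is a minimal sketch of such a one-hop ring pass, not the repository's exact `send_and_receive_` helper; the `left`/`right` helpers, `ring_pass`, and the two-process `gloo` setup are illustrative assumptions standing in for `circular_rank_left()` / `circular_rank_right()`:

```python
# Minimal sketch (assumed, not the repo's code) of a one-hop ring pass.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def left(rank, world_size):
    # neighbor one hop to the left, wrapping around the ring
    return (rank - 1) % world_size

def right(rank, world_size):
    # neighbor one hop to the right, wrapping around the ring
    return (rank + 1) % world_size

def ring_pass(x, rank, world_size):
    # forward direction: send to the LEFT neighbor, receive from the RIGHT
    # (the backward pass of OneRingPass does the mirror image of this)
    receive_buffer = torch.zeros_like(x)
    send_req = dist.isend(x.contiguous(), dst=left(rank, world_size))
    recv_req = dist.irecv(receive_buffer, src=right(rank, world_size))
    send_req.wait()
    recv_req.wait()
    dist.barrier()
    return receive_buffer

def demo(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    x = torch.full((2,), float(rank))
    y = ring_pass(x, rank, world_size)

    # each rank now holds the block its right neighbor produced
    assert torch.all(y == float(right(rank, world_size)))
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(demo, args=(world_size,), nprocs=world_size)
```

Running this with two processes, rank 0 ends up holding rank 1's tensor and vice versa; reversing the ranks in `backward` sends each gradient back along the edge it arrived on, which is the symmetry this commit restores.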
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
     name = 'ring-attention-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.0.5',
+    version = '0.0.6',
     license='MIT',
     description = 'Ring Attention - Pytorch',
     author = 'Phil Wang',
