start on striped attention paper for workload balancing
lucidrains committed Feb 19, 2024
1 parent 51ec491 commit 7c93dda
Showing 4 changed files with 54 additions and 9 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -59,7 +59,10 @@ $ python assert.py
- [ ] option to auto-decide ring sequence size based on world size
- [ ] rotary embeddings, with proper key/value offset depending on ring rank

- [ ] figure out striped attention
- [ ] striped attention
- [x] add the permuting logic before and after transformer
- [ ] add causal masking logic - account for sub bucketing by flash attention

- [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- [ ] find a machine with 8 GPUs and test with a quarter million tokens first
- [ ] figure out batch_isend_irecv
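For reference, the permutation logic the todo above refers to can be sketched with plain einops, written here with the standard einops argument order. This is only an illustration of the striping pattern used in this commit ('b (i j) -> b (j i)' with i = ring_seq_size); the helper names stripe_for_ring and unstripe_after_ring are made up for the sketch, and it assumes the sequence is already padded to world_size * ring_seq_size.

import torch
from einops import rearrange

def stripe_for_ring(x, ring_seq_size):
    # interleave tokens so that, after splitting into contiguous chunks of
    # ring_seq_size, rank r holds tokens r, r + world_size, r + 2 * world_size, ...
    return rearrange(x, 'b (i j) -> b (j i)', i = ring_seq_size)

def unstripe_after_ring(x, ring_seq_size):
    # inverse of the permutation above
    return rearrange(x, 'b (j i) -> b (i j)', i = ring_seq_size)

seq = torch.arange(8)[None]            # world_size = 4, ring_seq_size = 2
striped = stripe_for_ring(seq, 2)      # tensor([[0, 4, 1, 5, 2, 6, 3, 7]])
assert torch.equal(unstripe_after_ring(striped, 2), seq)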
4 changes: 4 additions & 0 deletions assert.py
@@ -25,6 +25,7 @@ def start(
batch_size_var_len,
seq_len,
causal,
striped_ring_attn,
dim,
use_cuda
):
@@ -37,6 +38,7 @@ def start(
depth = 1,
dim_head = 8,
ring_attn = True,
striped_ring_attn = striped_ring_attn,
ring_seq_size = ceil(seq_len / world_size),
q_bucket_size = ceil(seq_len / world_size),
k_bucket_size = ceil(seq_len / world_size)
@@ -115,6 +117,7 @@ def start(
batch_size_var_len = False
use_cuda = False
causal = True
striped_ring_attn = True

assert not use_cuda or torch.cuda.device_count() <= world_size

@@ -129,6 +132,7 @@
batch_size_var_len,
seq_len,
causal,
striped_ring_attn,
dim,
use_cuda
),
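The test script drives this through one process per emulated rank. Below is a minimal sketch of that launch pattern, illustrative only: the real start() takes more arguments than shown here (batch size, the variable-length flag, use_cuda), and the single-node gloo process-group setup is assumed standard boilerplate rather than copied from the script. The real test then builds the ring transformer with these flags and compares its output against a non-ring reference.

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def start(rank, world_size, seq_len, causal, striped_ring_attn, dim):
    # standard single-node process group setup; the body of the real test goes here
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('gloo', rank = rank, world_size = world_size)
    ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 8
    mp.spawn(
        start,
        args = (world_size, 2048, True, True, 8),
        nprocs = world_size,
        join = True
    )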
52 changes: 45 additions & 7 deletions ring_attention_pytorch/ring_attention.py
@@ -86,24 +86,33 @@ def pad_to_multiple(
pad_length = length - remainder
return F.pad(x, (0, pad_length), value = pad_value), pad_length

def sharded_batch_to_sharded_seq(
def maybe_pad_seq_and_mask(
x: Tensor,
mask: Optional[Tensor],
seq_size: int
):
assert is_distributed()

orig_x, seq_len = x, x.shape[-1]

# auto pad sequence and mask, as ring passing makes assumption tensor is all same shape

x, pad_length = pad_to_multiple(x, seq_size)

if pad_length > 0:
if not exists(mask):
mask = torch.ones_like(orig_x).bool()
if pad_length == 0:
return x, mask

if not exists(mask):
mask = torch.ones_like(orig_x).bool()

mask, _ = pad_to_multiple(mask, seq_size, pad_value = False)
mask, _ = pad_to_multiple(mask, seq_size, pad_value = False)

return x, mask

def sharded_batch_to_sharded_seq(
x: Tensor,
mask: Optional[Tensor],
seq_size: int
):
assert is_distributed()

# all gather across batch
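Stepping back to the padding helper introduced above (maybe_pad_seq_and_mask): ring passing requires every rank to hold identically shaped tensors, so the sequence is right-padded to a multiple of the per-rank chunk size and the mask is extended with False over the padding. A standalone sketch of that behavior, where the name pad_seq_and_mask and the pad_id argument are made up for illustration:

import torch
import torch.nn.functional as F

def pad_seq_and_mask(x, mask, seq_size, pad_id = 0):
    # right-pad token ids to a multiple of seq_size; padded positions are masked out
    seq_len = x.shape[-1]
    remainder = seq_len % seq_size

    if remainder == 0:
        return x, mask

    pad_length = seq_size - remainder

    if mask is None:
        mask = torch.ones_like(x).bool()

    x = F.pad(x, (0, pad_length), value = pad_id)
    mask = F.pad(mask, (0, pad_length), value = False)
    return x, mask

tokens = torch.randint(0, 256, (2, 1000))
tokens, mask = pad_seq_and_mask(tokens, None, seq_size = 512)
print(tokens.shape, mask.shape)   # torch.Size([2, 1024]) torch.Size([2, 1024])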

@@ -272,11 +281,14 @@ def __init__(
q_bucket_size = 512,
k_bucket_size = 512,
ring_attn = False,
striped_ring_attn = False,
ring_seq_size = 512,
auto_shard_seq = None,
):
super().__init__()
self.ring_attn = ring_attn
self.striped_ring_attn = striped_ring_attn

self.ring_seq_size = ring_seq_size
self.auto_shard_seq = default(auto_shard_seq, ring_attn) # if ring attention is turned on, auto-shard across sequence dimension. this can also be turned off and done manually elsewhere in the data loading
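The comment above notes the sequence split can instead be done manually in the data loading. A rough sketch of that alternative, assuming the batch is already padded to world_size * ring_seq_size; the helper name is illustrative, not part of the library:

import torch
import torch.distributed as dist

def manually_shard_sequence(x, ring_seq_size):
    # each rank keeps only its own contiguous ring_seq_size slice of the sequence
    start = dist.get_rank() * ring_seq_size
    return x[:, start:start + ring_seq_size]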

@@ -315,9 +327,29 @@ def forward(
seq_len = x.shape[-1]
auto_shard_seq = self.auto_shard_seq & is_distributed()

# take care of padding to divide sequence across the machines

if auto_shard_seq:

# first pad to right multiple

x, mask = maybe_pad_seq_and_mask(x, mask, self.ring_seq_size)

# account for striped attention
# for workload balancing https://arxiv.org/abs/2311.09431 - MIT paper from Brandon et al.

if self.striped_ring_attn:
x = rearrange('b (i j) -> b (j i)', x, i = self.ring_seq_size)

if exists(mask):
mask = rearrange('b (i j) -> b (j i)', mask, i = self.ring_seq_size)

# gather across batch and divide across world

(x, mask), batch_sizes = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

# main transformer logic

x = self.token_emb(x)

for attn, ff in self.layers:
@@ -326,7 +358,13 @@

logits = self.to_logits(x)

# now gather all sequence chunks across machines and shard back to original batch for cross entropy loss

if auto_shard_seq:

if self.striped_ring_attn:
logits = rearrange('b (i j) d -> b (j i) d', logits, i = self.ring_seq_size)

logits, _ = sharded_seq_to_sharded_batch(logits, batch_sizes)
logits = logits[:, :seq_len]
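On the workload-balancing claim from the striped attention paper (https://arxiv.org/abs/2311.09431): with causal masking, contiguous sequence chunks hand later ranks far more unmasked query/key pairs than earlier ones, while striped assignment spreads positions evenly. A back-of-the-envelope count, illustrative only and ignoring per-ring-step scheduling details:

import torch

world_size, ring_seq_size = 4, 4
n = world_size * ring_seq_size
pos = torch.arange(n)

contiguous = pos.view(world_size, ring_seq_size)    # rank r gets positions r*ring_seq_size ... r*ring_seq_size + ring_seq_size - 1
striped = pos.view(ring_seq_size, world_size).t()   # rank r gets positions r, r + world_size, r + 2*world_size, ...

def causal_work(assignment):
    # for each rank, total number of key positions its queries attend to under causal masking
    return [(queries + 1).sum().item() for queries in assignment]

print(causal_work(contiguous))   # [10, 26, 42, 58] - heavily skewed toward later ranks
print(causal_work(striped))      # [28, 32, 36, 40] - close to balanced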

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
