Commit
complete rotary embeddings in context of ring reduce
lucidrains committed Feb 20, 2024
1 parent ebc5e58 commit ed81af6
Showing 5 changed files with 43 additions and 16 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -54,12 +54,10 @@ $ python assert.py
- [x] modify flash attention to output intermediates and figure out backwards with recompute and ring passes
- [x] functions for splitting the sequence evenly among ranks, either within attention function, or in the external ring transformer wrapper
- [x] basic test case with two processes and check for equivalent output and gradients

- - [ ] testing
+ - [x] testing
- [x] make sure key padding mask works
- [x] make sure causal mask works
- - [ ] option to auto-decide ring sequence size based on world size
- - [ ] rotary embeddings, with proper key/value offset depending on ring rank
+ - [x] rotary embeddings, with proper key/value offset depending on ring rank
- [x] striped attention
- [x] add the permutating logic before and after transformer
- [x] add causal masking logic - account for sub bucketing by flash attention
@@ -68,6 +66,8 @@ $ python assert.py
- [x] backwards
- [x] forwards

+ - [ ] option to auto-decide ring sequence size based on world size
+ - [ ] allow for finely specifying how to distribute sharding of batch and sequence, depending on world size
- [ ] allow for variable ring passes per layer, for <a href="https://arxiv.org/abs/2007.03356">local -> global attention</a> in ring transformer as one goes up the layers.
- [ ] also allow for mixed striped layers, so one can do local strided attention (used in some image transformers but not language), which would be a good fit for the initial local attention layers, in addition to classic lookback local
- [ ] find a machine with 8 GPUs and test with a quarter million tokens first
38 changes: 35 additions & 3 deletions ring_attention_pytorch/ring_attention.py
@@ -23,6 +23,11 @@
AllGather
)

+ from ring_attention_pytorch.rotary import (
+     RotaryEmbedding,
+     apply_rotary_pos_emb
+ )

# helper functions

def exists(v):
@@ -203,7 +208,8 @@ def __init__(
def forward(
self,
x,
- mask = None
+ mask = None,
+ rotary_emb = None
):
"""
einstein notation
@@ -227,6 +233,14 @@ def forward(
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)

+ # rotary relative positions
+
+ if exists(rotary_emb):
+     q = apply_rotary_pos_emb(rotary_emb, q)
+     k = apply_rotary_pos_emb(rotary_emb, k)
+
+ # regular attention vs flash w/ or w/o kv ring reduce

if self.force_regular_attn or not is_distributed():
out = default_attention(q, k, v, mask = mask, causal = self.causal)
else:
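For reference, a minimal sketch of the two helpers the added lines call, assuming the standard half-rotation convention implied by the `torch.cat((freqs, freqs), dim = -1)` in rotary.py below (the body of `rotate_half` is not shown in this diff, so treat the exact implementation as an assumption):

import torch
from torch import Tensor

def rotate_half(x: Tensor) -> Tensor:
    # split the head dimension into two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(freqs: Tensor, t: Tensor) -> Tensor:
    # freqs: (seq, dim_head) angles from RotaryEmbedding
    # t:     (batch, heads, seq, dim_head) queries or keys
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

# usage mirroring the added lines above
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
freqs = torch.randn(1024, 64)   # stand-in for the rotary_emb tensor passed into forward

q = apply_rotary_pos_emb(freqs, q)
k = apply_rotary_pos_emb(freqs, k)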
@@ -304,6 +318,7 @@ def __init__(
assert not (self.striped_ring_attn and not causal), 'striped ring attention only applies to autoregressive models'

self.token_emb = nn.Embedding(num_tokens, dim)
+ self.rotary_emb = RotaryEmbedding(dim_head)

self.layers = ModuleList([])

@@ -334,7 +349,7 @@ def forward(
x,
mask = None
):
- seq_len = x.shape[-1]
+ seq_len, device = x.shape[-1], x.device
auto_shard_seq = self.auto_shard_seq & is_distributed()

# take care of padding to divide sequence across the machines
@@ -351,19 +366,36 @@ def forward(
if self.striped_ring_attn:
x = rearrange(x, 'b (i j) -> b (j i)', i = self.bucket_size)

+ stripe_stride = x.shape[-1] // self.bucket_size

if exists(mask):
mask = rearrange(mask, 'b (i j) -> b (j i)', i = self.bucket_size)

# gather across batch and divide across world

(x, mask), batch_sizes = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

+ # rotary positions
+ # taking into account ring and striping
+
+ maybe_chunk_seq_len = x.shape[-1]
+
+ pos = torch.arange(maybe_chunk_seq_len, device = device)
+
+ if auto_shard_seq:
+     pos += maybe_chunk_seq_len * get_rank()
+
+ if self.striped_ring_attn:
+     pos *= stripe_stride
+
+ rotary_emb = self.rotary_emb(pos)

# main transformer logic

x = self.token_emb(x)

for attn, ff in self.layers:
- x = attn(x, mask = mask) + x
+ x = attn(x, mask = mask, rotary_emb = rotary_emb) + x
x = ff(x) + x

logits = self.to_logits(x)
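A minimal sketch of the position bookkeeping added above, assuming each rank holds one contiguous chunk of the (possibly striped) sequence and that `get_rank()` returns that rank's index: positions are made global by offsetting the local arange with the chunk length times the rank, and under striped attention they are scaled by the stripe stride, since neighbouring tokens in the striped layout sit `stripe_stride` positions apart in the original ordering. The helper name below is hypothetical.

import torch

def rotary_positions(chunk_seq_len, rank, striped = False, stripe_stride = 1, device = None):
    # local positions within this rank's chunk
    pos = torch.arange(chunk_seq_len, device = device)

    # offset by the rank so every machine sees its global positions
    pos = pos + chunk_seq_len * rank

    # striped attention permutes the sequence, so scale back to the original spacing
    if striped:
        pos = pos * stripe_stride

    return pos

# e.g. world size 2, each rank holding 4 tokens
print(rotary_positions(4, rank = 0))   # tensor([0, 1, 2, 3])
print(rotary_positions(4, rank = 1))   # tensor([4, 5, 6, 7])
print(rotary_positions(4, rank = 1, striped = True, stripe_stride = 2))   # tensor([ 8, 10, 12, 14])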
5 changes: 0 additions & 5 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -16,11 +16,6 @@
get_rank
)

- from ring_attention_pytorch.rotary import (
-     RotaryEmbedding,
-     apply_rotary_pos_emb
- )

# constants

EPSILON = 1e-10
6 changes: 3 additions & 3 deletions ring_attention_pytorch/rotary.py
@@ -16,11 +16,11 @@ def __init__(
@autocast(enabled = False)
def forward(
self,
- seq_len,
+ pos,
offset = 0
):
- t = torch.arange(seq_len + offset, device = self.inv_freq.device).type_as(self.inv_freq)
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+ pos = pos.type_as(self.inv_freq)
+ freqs = torch.einsum('i , j -> i j', pos, self.inv_freq)
return torch.cat((freqs, freqs), dim = -1)

def rotate_half(x):
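With this change the module no longer assumes positions start at zero: the caller supplies an explicit `pos` tensor (as the ring transformer now does with its rank-offset positions) and gets back angles of shape `(len(pos), dim)`, with the frequencies duplicated across both halves. A short usage sketch, assuming the module is constructed with the head dimension as in ring_attention.py above:

import torch
from ring_attention_pytorch.rotary import RotaryEmbedding, apply_rotary_pos_emb

rotary = RotaryEmbedding(64)          # dim_head, matching RingTransformer above

# positions need not start at 0, e.g. the second rank's chunk of 8 tokens
pos = torch.arange(8) + 8

freqs = rotary(pos)                   # (8, 64) rotation angles

q = torch.randn(1, 8, 8, 64)
q = apply_rotary_pos_emb(freqs, q)    # the same call is made for the keys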
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
- version = '0.1.9',
+ version = '0.1.10',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
