From ed81af6035ee9a5d25d755999264de10aa9130e2 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 20 Feb 2024 10:44:21 -0800
Subject: [PATCH] complete rotary embeddings in context of ring reduce

---
 README.md                                 |  8 ++--
 ring_attention_pytorch/ring_attention.py  | 38 +++++++++++++++++--
 .../ring_flash_attention.py               |  5 ---
 ring_attention_pytorch/rotary.py          |  6 +--
 setup.py                                  |  2 +-
 5 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index de5744d..d275cde 100644
--- a/README.md
+++ b/README.md
@@ -54,12 +54,10 @@ $ python assert.py
 - [x] modify flash attention to output intermediates and figure out backwards with recompute and ring passes
 - [x] functions for splitting the sequence evenly among ranks, either within attention function, or in the external ring transformer wrapper
 - [x] basic test case with two processes and check for equivalent output and gradients
-
-- [ ] testing
+- [x] testing
     - [x] make sure key padding mask works
     - [x] make sure causal mask works
-    - [ ] option to auto-decide ring sequence size based on world size
-    - [ ] rotary embeddings, with proper key/value offset depending on ring rank
+    - [x] rotary embeddings, with proper key/value offset depending on ring rank
     - [x] striped attention
         - [x] add the permutating logic before and after transformer
         - [x] add causal masking logic - account for sub bucketing by flash attention
@@ -68,6 +66,8 @@ $ python assert.py
     - [x] backwards
     - [x] forwards
+- [ ] option to auto-decide ring sequence size based on world size
+
 - [ ] allow for finely specifying how to distribute sharding of batch and sequence, depending on world size
 - [ ] allow for variable ring passes per layer, for local -> global attention in ring transformer as one goes up the layers.
 - [ ] also allow for mixed striped layers, so one can do local strided attention (used in some image transformers but not language) would be a good fit for the initial local attention layers, in addition to classic lookback local
 - [ ] find a machine with 8 GPUs and test with a quarter million tokens first
diff --git a/ring_attention_pytorch/ring_attention.py b/ring_attention_pytorch/ring_attention.py
index c813b92..9249e44 100644
--- a/ring_attention_pytorch/ring_attention.py
+++ b/ring_attention_pytorch/ring_attention.py
@@ -23,6 +23,11 @@
     AllGather
 )
 
+from ring_attention_pytorch.rotary import (
+    RotaryEmbedding,
+    apply_rotary_pos_emb
+)
+
 # helper functions
 
 def exists(v):
@@ -203,7 +208,8 @@ def __init__(
     def forward(
         self,
         x,
-        mask = None
+        mask = None,
+        rotary_emb = None
     ):
         """
         einstein notation
@@ -227,6 +233,14 @@ def forward(
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
 
+        # rotary relative positions
+
+        if exists(rotary_emb):
+            q = apply_rotary_pos_emb(rotary_emb, q)
+            k = apply_rotary_pos_emb(rotary_emb, k)
+
+        # regular attention vs flash w/ or w/o kv ring reduce
+
         if self.force_regular_attn or not is_distributed():
             out = default_attention(q, k, v, mask = mask, causal = self.causal)
         else:
@@ -304,6 +318,7 @@ def __init__(
         assert not (self.striped_ring_attn and not causal), 'striped ring attention only applies to autoregressive models'
 
         self.token_emb = nn.Embedding(num_tokens, dim)
+        self.rotary_emb = RotaryEmbedding(dim_head)
 
         self.layers = ModuleList([])
 
@@ -334,7 +349,7 @@ def forward(
         x,
         mask = None
     ):
-        seq_len = x.shape[-1]
+        seq_len, device = x.shape[-1], x.device
         auto_shard_seq = self.auto_shard_seq & is_distributed()
 
         # take care of padding to divide sequence across the machines
@@ -351,6 +366,8 @@ def forward(
             if self.striped_ring_attn:
                 x = rearrange(x, 'b (i j) -> b (j i)', i = self.bucket_size)
 
+                stripe_stride = x.shape[-1] // self.bucket_size
+
                 if exists(mask):
                     mask = rearrange(mask, 'b (i j) -> b (j i)', i = self.bucket_size)
 
@@ -358,12 +375,27 @@
             (x, mask), batch_sizes = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)
 
+        # rotary positions
+        # taking into account ring and striping
+
+        maybe_chunk_seq_len = x.shape[-1]
+
+        pos = torch.arange(maybe_chunk_seq_len, device = device)
+
+        if auto_shard_seq:
+            pos += maybe_chunk_seq_len * get_rank()
+
+        if self.striped_ring_attn:
+            pos *= stripe_stride
+
+        rotary_emb = self.rotary_emb(pos)
+
         # main transformer logic
 
         x = self.token_emb(x)
 
         for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, rotary_emb = rotary_emb) + x
             x = ff(x) + x
 
         logits = self.to_logits(x)
 
diff --git a/ring_attention_pytorch/ring_flash_attention.py b/ring_attention_pytorch/ring_flash_attention.py
index 81bdb3a..bec219c 100644
--- a/ring_attention_pytorch/ring_flash_attention.py
+++ b/ring_attention_pytorch/ring_flash_attention.py
@@ -16,11 +16,6 @@
     get_rank
 )
 
-from ring_attention_pytorch.rotary import (
-    RotaryEmbedding,
-    apply_rotary_pos_emb
-)
-
 # constants
 
 EPSILON = 1e-10
diff --git a/ring_attention_pytorch/rotary.py b/ring_attention_pytorch/rotary.py
index 0b5bc21..3abd6ea 100644
--- a/ring_attention_pytorch/rotary.py
+++ b/ring_attention_pytorch/rotary.py
@@ -16,11 +16,11 @@ def __init__(
     @autocast(enabled = False)
     def forward(
         self,
-        seq_len,
+        pos,
         offset = 0
     ):
-        t = torch.arange(seq_len + offset, device = self.inv_freq.device).type_as(self.inv_freq)
-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        pos = pos.type_as(self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', pos, self.inv_freq)
         return torch.cat((freqs, freqs), dim = -1)
 
 def rotate_half(x):
diff --git a/setup.py b/setup.py
index b468254..a4e4359 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.9',
+  version = '0.1.10',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
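To make the new position bookkeeping concrete, here is a small self-contained sketch in plain PyTorch. It is an illustration, not the library's code: there is no torch.distributed here, the rank, world size and tensor sizes are made-up values, and the helpers are a sketch of the usual rotary embedding formulation. The point the patch implements is that each ring rank rotates its local queries and keys with frequencies computed from their global positions, i.e. the local index offset by chunk_len * rank (and, for striped attention, further scaled by the stripe stride).

    import torch

    # standard rotary embedding helpers (usual formulation, for illustration only)

    def rotary_frequencies(pos, dim_head, theta = 10000):
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))
        freqs = torch.einsum('i, j -> i j', pos.float(), inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim = -1)
        return torch.cat((-x2, x1), dim = -1)

    def apply_rotary(freqs, t):
        return t * freqs.cos() + rotate_half(t) * freqs.sin()

    # pretend this process is rank 2 in a ring of 4 machines (made-up values)

    world_size, rank = 4, 2
    chunk_len, heads, dim_head = 8, 2, 64

    # local positions, offset so they reflect this rank's slice of the full sequence

    pos = torch.arange(chunk_len) + chunk_len * rank   # 16 .. 23 rather than 0 .. 7

    q = torch.randn(1, heads, chunk_len, dim_head)
    k = torch.randn(1, heads, chunk_len, dim_head)

    freqs = rotary_frequencies(pos, dim_head)
    q, k = apply_rotary(freqs, q), apply_rotary(freqs, k)

Without the chunk_len * rank offset, every rank would rotate its keys and values as if its chunk started at position 0, so the relative positions seen across ring passes would be wrong.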
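For completeness, the rotary.py change means RotaryEmbedding.forward now consumes an explicit tensor of positions, which need not start at zero, instead of a sequence length. A minimal usage sketch against the patched module follows; the (batch, heads, seq, dim_head) layout is assumed from how RingAttention arranges its queries and keys above, and the concrete numbers are illustrative.

    import torch

    from ring_attention_pytorch.rotary import RotaryEmbedding, apply_rotary_pos_emb

    rotary = RotaryEmbedding(64)     # dim_head, e.g. 64

    # a rank can pass in its own offset positions, e.g. global positions 16 .. 23

    pos = torch.arange(8) + 16

    freqs = rotary(pos)              # (8, 64) rotary frequencies

    q = torch.randn(1, 4, 8, 64)     # assumed (batch, heads, seq, dim_head) layout
    q = apply_rotary_pos_emb(freqs, q)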