diff --git a/README.md b/README.md
index efebd02..9f4fcad 100644
--- a/README.md
+++ b/README.md
@@ -72,9 +72,9 @@ $ python assert.py
 - [ ] allow for variable ring passes per layer, for local -> global attention in ring transformer as one goes up the layers.
 - [ ] also allow for mixed striped layers, so one can do local strided attention (used in some image transformers but not language) would be a good fit for the initial local attention layers, in addition to classic lookback local
 - [ ] find a machine with 8 GPUs and test with a quarter million tokens first
-- [ ] figure out batch_isend_irecv
 - [ ] think about how to craft a special `Dataset` that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
 - [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
+- [ ] `batch_isend_irecv` in the presence of key padding mask needing ring exchange, but not a big priority
 
 ## Citations
 
diff --git a/ring_attention_pytorch/ring_flash_attention.py b/ring_attention_pytorch/ring_flash_attention.py
index bec219c..f13a5e1 100644
--- a/ring_attention_pytorch/ring_flash_attention.py
+++ b/ring_attention_pytorch/ring_flash_attention.py
@@ -91,7 +91,11 @@ def forward(
 
         num_tiles = math.ceil(per_machine_seq_size / bucket_size)
 
-        for ring_rank, (k, v, mask) in ring_pass_fn(k, v, mask, max_iters = max_ring_passes):
+        kv = torch.stack((k, v))
+
+        for ring_rank, (kv, mask) in ring_pass_fn(kv, mask, max_iters = max_ring_passes):
+
+            k, v = kv
 
             col_splits = zip(
                 k.split(bucket_size, dim = -2),
@@ -209,7 +213,11 @@ def backward(ctx, do):
             dq.split(bucket_size, dim = -2)
         )
 
-        for ring_rank, (k, v, mask, dk, dv) in ring_pass_fn(k, v, mask, dk, dv, max_iters = max_ring_passes):
+        kv_and_dkv = torch.stack((k, v, dk, dv))
+
+        for ring_rank, (kv_and_dkv, mask) in ring_pass_fn(kv_and_dkv, mask, max_iters = max_ring_passes):
+
+            k, v, dk, dv = kv_and_dkv
 
             col_splits = zip(
                 k.split(bucket_size, dim = -2),
@@ -257,8 +265,11 @@ def backward(ctx, do):
                     dkc.add_(dk_chunk)
                     dvc.add_(dv_chunk)
 
-        dk = one_ring_pass(dk)
-        dv = one_ring_pass(dv)
+        dkv = kv_and_dkv[2:]
+
+        dkv = one_ring_pass(dkv)
+
+        dk, dv = dkv
 
         return dq, dk, dv, None, None, None, None, None, None, None
 
diff --git a/setup.py b/setup.py
index 61aacd6..7cfdb96 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.12',
+  version = '0.1.14',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
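
Aside, for context on the ring_flash_attention.py change above: stacking k and v (and dk, dv in the backward) into a single tensor means each ring step exchanges one contiguous buffer instead of several, which is also what makes a `batch_isend_irecv`-style exchange straightforward. Below is a minimal sketch of that idea; the `ring_pass` helper and the commented usage are hypothetical illustrations, not the repo's actual `ring_pass_fn` / `one_ring_pass`, and it assumes `torch.distributed` has already been initialized.

```python
# Illustrative sketch only (not part of the diff): one ring exchange moving a
# stacked tensor. Assumes dist.init_process_group(...) was already called.
import torch
import torch.distributed as dist

def ring_pass(t: torch.Tensor) -> torch.Tensor:
    # send this rank's shard to the next rank, receive from the previous rank
    rank, world = dist.get_rank(), dist.get_world_size()
    received = torch.empty_like(t)
    ops = [
        dist.P2POp(dist.isend, t, (rank + 1) % world),
        dist.P2POp(dist.irecv, received, (rank - 1) % world),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return received

# k, v: (batch, heads, seq, dim) shards held by each rank
# kv = torch.stack((k, v))   # one tensor, so a single exchange per ring step
# kv = ring_pass(kv)
# k, v = kv                  # unbind back into keys and values
```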