From cc1aea86f2fb8d6800a0d8f7cebe503cc696bfcc Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 13 Aug 2024 10:29:15 -0700
Subject: [PATCH] handle a small edge case for tree attn decoding

---
 ring_attention_pytorch/tree_attn_decoding.py | 30 +++++++++++++-------
 setup.py                                      |  2 +-
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/ring_attention_pytorch/tree_attn_decoding.py b/ring_attention_pytorch/tree_attn_decoding.py
index 83864e4..4fc65cf 100644
--- a/ring_attention_pytorch/tree_attn_decoding.py
+++ b/ring_attention_pytorch/tree_attn_decoding.py
@@ -14,6 +14,8 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
     https://arxiv.org/abs/2408.04093
     """
 
+    device, dim_v = q.device, v.shape[-1]
+
     rank = dist.get_rank() if dist.is_initialized() else 0
     world_size = dist.get_world_size() if dist.is_initialized() else 1
 
@@ -27,21 +29,29 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
     k = k.chunk(world_size, dim = -2)
     v = v.chunk(world_size, dim = -2)
 
-    k, v = k[rank], v[rank]
+    if rank < len(k):
+        k, v = k[rank], v[rank]
+
+        # calculate local output and derive numerator and denominator
+
+        sim = einsum('... i d, ... j d -> ... i j', q, k)
 
-    # first calculate local output
+        local_max = sim.amax(dim = -1, keepdim = True)
+        sim -= local_max
+        lse = sim.logsumexp(dim = -1, keepdim = True)
 
-    sim = einsum('... i d, ... j d -> ... i j', q, k)
+        attn = sim.softmax(dim = -1)
+        out = einsum('... i j, ... j d -> ... i d', attn, v)
 
-    local_max = sim.amax(dim = -1, keepdim = True)
-    sim -= local_max
-    lse = sim.logsumexp(dim = -1, keepdim = True)
+        den = lse.exp()
+        num = out * den
 
-    attn = sim.softmax(dim = -1)
-    out = einsum('... i j, ... j d -> ... i d', attn, v)
+    else:
+        # handle edge case where seq length < world size
 
-    den = lse.exp()
-    num = out * den
+        num = q.new_zeros((*q.shape[:-1], dim_v))
+        den = q.new_zeros((*q.shape[:-1], 1))
+        local_max = torch.zeros_like(den)
 
     # first get global max through an all reduce (max)
 
diff --git a/setup.py b/setup.py
index 717b562..c54627d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.4',
+  version = '0.5.5',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
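
A minimal, single-process sketch of the edge case this patch guards against, with made-up tensor shapes for illustration: torch.Tensor.chunk returns fewer chunks than requested when the kv sequence length is shorter than world_size, so a rank with rank >= len(k) receives no kv chunk and, per the patch, contributes zero numerator/denominator placeholders to the subsequent all-reduces.

    import torch

    world_size = 8
    k = torch.randn(1, 5, 64)                    # kv sequence length 5 < world_size 8 (hypothetical shapes)

    chunks = k.chunk(world_size, dim = -2)
    print(len(chunks))                           # 5 chunks, not 8

    for rank in range(world_size):
        if rank < len(chunks):
            print(rank, chunks[rank].shape)      # ranks 0-4 each receive one kv position
        else:
            print(rank, 'no kv chunk, contributes zeros to the all-reduce')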