Skip to content

Commit

Permalink
handle a small edge case for tree attn decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 13, 2024
1 parent f763e58 commit cc1aea8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
30 changes: 20 additions & 10 deletions ring_attention_pytorch/tree_attn_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
https://arxiv.org/abs/2408.04093
"""

device, dim_v = q.device, v.shape[-1]

rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1

Expand All @@ -27,21 +29,29 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
k = k.chunk(world_size, dim = -2)
v = v.chunk(world_size, dim = -2)

k, v = k[rank], v[rank]
if rank < len(k):
k, v = k[rank], v[rank]

# calculate local output and derive numerator and denominator

sim = einsum('... i d, ... j d -> ... i j', q, k)

# first calculate local output
local_max = sim.amax(dim = -1, keepdim = True)
sim -= local_max
lse = sim.logsumexp(dim = -1, keepdim = True)

sim = einsum('... i d, ... j d -> ... i j', q, k)
attn = sim.softmax(dim = -1)
out = einsum('... i j, ... j d -> ... i d', attn, v)

local_max = sim.amax(dim = -1, keepdim = True)
sim -= local_max
lse = sim.logsumexp(dim = -1, keepdim = True)
den = lse.exp()
num = out * den

attn = sim.softmax(dim = -1)
out = einsum('... i j, ... j d -> ... i d', attn, v)
else:
# handle edge case where seq length < world size

den = lse.exp()
num = out * den
num = q.new_zeros((*q.shape[:-1], dim_v))
den = q.new_zeros((*q.shape[:-1], 1))
local_max = torch.zeros_like(den)

# first get global max through an all reduce (max)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.5.4',
version = '0.5.5',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit cc1aea8

Please sign in to comment.