account for scenario where keys and values are already sharded in tree attn decoding
lucidrains committed Aug 14, 2024
1 parent f2f6d7a commit d42c0c5
Showing 4 changed files with 38 additions and 26 deletions.
18 changes: 18 additions & 0 deletions ring_attention_pytorch/distributed.py
@@ -1,3 +1,5 @@
+from functools import partial, lru_cache
+
 import torch
 from torch import nn
 from torch.nn import Module
@@ -20,6 +22,22 @@ def pad_dim_to(t, length, dim = 0):
     zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
     return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
 
+cache = partial(lru_cache, maxsize = None)
+
+# distributed helpers
+
+@cache()
+def get_rank():
+    return dist.get_rank() if dist.is_initialized() else 0
+
+@cache()
+def get_world_size():
+    return dist.get_world_size() if dist.is_initialized() else 1
+
+@cache()
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
 def all_gather_same_dim(t):
     t = t.contiguous()
     world_size = dist.get_world_size()
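For context, a small sketch of how these cached helpers behave in a single process with no process group initialized; the printed values follow from the definitions above, and the memoization note is a property of lru_cache rather than anything specific to this repository.

import torch.distributed as dist
from ring_attention_pytorch.distributed import get_rank, get_world_size, is_distributed

# without an initialized process group the helpers fall back to single-process values
assert not dist.is_initialized()

print(get_rank())        # 0
print(get_world_size())  # 1
print(is_distributed())  # False

# note: lru_cache memoizes the first result, so these helpers should be called
# after (or consistently without) torch.distributed initialization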
19 changes: 2 additions & 17 deletions ring_attention_pytorch/ring.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from functools import lru_cache, partial, wraps
+from functools import wraps, partial
 from collections import namedtuple
 
 import torch
@@ -9,6 +9,7 @@
 from torch.autograd import Function
 
 import torch.distributed as dist
+from ring_attention_pytorch.distributed import get_rank, get_world_size, is_distributed
 
 # helper functions
 
@@ -21,22 +22,6 @@ def default(v, d):
 def cast_tuple(t, length = 1):
     return t if isinstance(t, tuple) else ((t,) * length)
 
-cache = partial(lru_cache, maxsize = None)
-
-# distributed globals
-
-@cache()
-def get_rank():
-    return dist.get_rank() if dist.is_initialized() else 0
-
-@cache()
-def get_world_size():
-    return dist.get_world_size() if dist.is_initialized() else 1
-
-@cache()
-def is_distributed():
-    return dist.is_initialized() and dist.get_world_size() > 1
-
 # ring functions
 
 def circular_index_left(pos, ring_size, num = 1):
25 changes: 17 additions & 8 deletions ring_attention_pytorch/tree_attn_decoding.py
@@ -2,8 +2,17 @@
 from torch import einsum
 import torch.distributed as dist
 
+from ring_attention_pytorch.distributed import get_rank, get_world_size
+
+def exists(v):
+    return v is not None
+
 @torch.no_grad()
-def tree_attn_decode(q, k, v, eps = 1e-8):
+def tree_attn_decode(
+    q, k, v,
+    eps = 1e-8,
+    shard_kv_seq = False
+):
 
     assert k.shape[:-1] == v.shape[:-1]
     assert q.shape[-2:] == (1, k.shape[-1])
@@ -16,8 +25,8 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
 
     device, dim_v = q.device, v.shape[-1]
 
-    rank = dist.get_rank() if dist.is_initialized() else 0
-    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank = get_rank()
+    world_size = get_world_size()
 
     # scale queries
 
@@ -26,12 +35,12 @@
 
     # each machine (rank) takes care of a chunk of kv sequence within the world of many machines
 
-    k = k.chunk(world_size, dim = -2)
-    v = v.chunk(world_size, dim = -2)
-
-    if rank < len(k):
-        k, v = k[rank], v[rank]
+    if shard_kv_seq:
+        k = k.chunk(world_size, dim = -2)
+        v = v.chunk(world_size, dim = -2)
+        k, v = (k[rank], v[rank]) if rank < len(k) else (None, None)
 
+    if exists(k) and exists(v):
         # calculate local output and derive numerator and denominator
 
         sim = einsum('... i d, ... j d -> ... i j', q, k)
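For illustration, a usage sketch of the new flag (tensor shapes are assumed, chosen only to satisfy the asserts above): with the default shard_kv_seq = False the caller passes keys and values that are already local to this rank, while shard_kv_seq = True keeps the previous behavior of handing in the full KV sequence and letting the function slice out this rank's chunk.

import torch
from ring_attention_pytorch.tree_attn_decoding import tree_attn_decode

q = torch.randn(2, 8, 1, 64)        # (batch, heads, 1 decoding query, dim)

# keys / values already sharded: each rank holds only its local chunk
k_local = torch.randn(2, 8, 1024, 64)
v_local = torch.randn(2, 8, 1024, 64)
out = tree_attn_decode(q, k_local, v_local)                     # shard_kv_seq = False (default)

# full kv sequence on every rank: let the function chunk it by rank, as before
k_full = torch.randn(2, 8, 8192, 64)
v_full = torch.randn(2, 8, 8192, 64)
out = tree_attn_decode(q, k_full, v_full, shard_kv_seq = True)

Either way, each rank attends only over its own chunk; the local numerator and denominator derived above are then combined across ranks in the portion of the file not shown in this diff.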
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.5',
+  version = '0.5.6',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
