Skip to content

Commit

Permalink
refactor mask in preparation for moving column traversal to outer loop
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 19, 2024
1 parent 00d9afa commit 8c19b24
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions ring_attention_pytorch/ring_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from einx import rearrange

from ring_attention_pytorch.ring import (
maybe,
all_ring_pass,
null_ring_pass,
one_ring_pass,
Expand Down Expand Up @@ -71,7 +72,7 @@ def forward(
bucket_size = min(per_machine_seq_size, bucket_size)
per_machine_buckets = per_machine_seq_size // bucket_size

orig_k, orig_v, device = k, v, q.device
orig_k, orig_v, orig_mask, device = k, v, mask, q.device

row_ring_rank = get_rank() if ring_reduce_col else 0

Expand All @@ -87,29 +88,22 @@ def forward(

num_tiles = math.ceil(per_machine_seq_size / bucket_size)

if exists(mask):
mask = rearrange('b n -> b 1 1 n', mask)

mask = ((mask,) * num_tiles)
orig_mask = mask

row_splits = zip(
q.split(bucket_size, dim = -2),
o.split(bucket_size, dim = -2),
mask,
all_row_sums.split(bucket_size, dim = -2),
all_row_maxes.split(bucket_size, dim = -2),
)

for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
for ind, (qc, oc, row_sums, row_maxes) in enumerate(row_splits):
row_bucket_index = row_ring_rank * per_machine_buckets + ind

for ring_rank, (k, v, row_mask) in ring_pass_fn(k, v, row_mask):
for ring_rank, (k, v, mask) in ring_pass_fn(k, v, mask):

col_splits = zip(
k.split(bucket_size, dim = -2),
v.split(bucket_size, dim = -2),
maybe_split(row_mask, bucket_size, dim = -1)
maybe_split(mask, bucket_size, dim = -1)
)

for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
Expand All @@ -118,7 +112,7 @@ def forward(
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

if exists(col_mask):
attn_weights.masked_fill_(~col_mask, max_neg_value)
attn_weights = einx.where('b j, b h i j, -> b h i j', col_mask, attn_weights, max_neg_value)

if causal:
if striped_ring_attn:
Expand All @@ -139,7 +133,7 @@ def forward(
exp_weights = torch.exp(attn_weights - new_row_maxes)

if exists(col_mask):
exp_weights.masked_fill_(~col_mask, 0.)
exp_weights = einx.where('b j, b h i j, -> b h i j', col_mask, exp_weights, 0.)

block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

Expand All @@ -154,6 +148,8 @@ def forward(
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)

mask = maybe(one_ring_pass)(mask)

oc.div_(row_sums)

lse = all_row_sums.log() + all_row_maxes
Expand Down Expand Up @@ -190,22 +186,21 @@ def backward(ctx, do):
q.split(bucket_size, dim = -2),
o.split(bucket_size, dim = -2),
do.split(bucket_size, dim = -2),
mask,
lse.split(bucket_size, dim = -2),
dq.split(bucket_size, dim = -2)
)

for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
for ind, (qc, oc, doc, lsec, dqc) in enumerate(row_splits):
row_bucket_index = row_ring_rank * per_machine_buckets + ind

for ring_rank, (k, v, row_mask, dk, dv) in ring_pass_fn(k, v, row_mask, dk, dv):
for ring_rank, (k, v, mask, dk, dv) in ring_pass_fn(k, v, mask, dk, dv):

col_splits = zip(
k.split(bucket_size, dim = -2),
v.split(bucket_size, dim = -2),
dk.split(bucket_size, dim = -2),
dv.split(bucket_size, dim = -2),
maybe_split(row_mask, bucket_size, dim = -1)
maybe_split(mask, bucket_size, dim = -1)
)

for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
Expand All @@ -228,7 +223,7 @@ def backward(ctx, do):
p = torch.exp(attn_weights - lsec)

if exists(col_mask):
p.masked_fill_(~col_mask, 0.)
p = einx.where('b j, b h i j, -> b h i j', col_mask, p, 0.)

dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
Expand All @@ -243,6 +238,8 @@ def backward(ctx, do):
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)

mask = maybe(one_ring_pass)(mask)

dk = one_ring_pass(dk)
dv = one_ring_pass(dv)

Expand Down

0 comments on commit 8c19b24

Please sign in to comment.