an assert for key padding mask for ring_flash_attn_cuda and 0.3.0
lucidrains committed Apr 6, 2024
1 parent 0f38719 commit 5b202e4
Showing 2 changed files with 9 additions and 1 deletion.
ring_attention_pytorch/ring_flash_attention_cuda.py (8 additions, 0 deletions)
@@ -40,6 +40,13 @@ def is_empty(t: Tensor):
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1

+def padded_false_on_right_side(t: Tensor):
+    if t.shape[-1] <= 1:
+        return True
+
+    false_to_true = ~t[..., :-1] & t[..., 1:]
+    return not false_to_true.any()
+
 # make sure flash attention is installed for backwards

 import importlib
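
For context, a short usage sketch of the new helper (not part of the commit). It assumes the function is importable from ring_attention_pytorch.ring_flash_attention_cuda and that the mask uses True for attend and False for padding:

import torch
from ring_attention_pytorch.ring_flash_attention_cuda import padded_false_on_right_side

# True = attend, False = padding; padding must sit contiguously on the right
right_padded = torch.tensor([[True, True, True, False, False]])
mid_padded   = torch.tensor([[True, False, True, True, False]])

print(padded_false_on_right_side(right_padded))  # True  - no False is followed by a True
print(padded_false_on_right_side(mid_padded))    # False - a False is followed by a True, so padding is not right-aligned
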
@@ -488,6 +495,7 @@ def forward(
     ring_size: Optional[int]
 ):
     assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'
+    assert not exists(mask) or padded_false_on_right_side(mask), 'key padding mask must only contain True (attend) on the left hand side, and False (not attend) on the right'

     dtype = q.dtype
     softmax_scale = q.shape[-1] ** -0.5
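A small sketch (not part of the commit) of building a key padding mask that satisfies the new assert, using only plain torch; the variable names here are illustrative:

import torch

# per-sequence valid lengths in a batch, padded out to a common max length
lengths = torch.tensor([5, 3, 4])
max_len = 6

# True (attend) on the left, False (padding) on the right, as the assert requires
mask = torch.arange(max_len)[None, :] < lengths[:, None]
# tensor([[ True,  True,  True,  True,  True, False],
#         [ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True, False, False]])
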
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.28',
+  version = '0.3.0',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
