diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index 2c55894..9ef934b 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -40,6 +40,13 @@ def is_empty(t: Tensor):
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1
 
+def padded_false_on_right_side(t: Tensor):
+    if t.shape[-1] <= 1:
+        return True
+
+    false_to_true = ~t[..., :-1] & t[..., 1:]
+    return not false_to_true.any()
+
 # make sure flash attention is installed for backwards
 
 import importlib
@@ -488,6 +495,7 @@ def forward(
         ring_size: Optional[int]
     ):
         assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'
+        assert not exists(mask) or padded_false_on_right_side(mask), 'key padding mask must only contain True (attend) on the left hand side, and False (not attend) on the right'
 
         dtype = q.dtype
         softmax_scale = q.shape[-1] ** -0.5
diff --git a/setup.py b/setup.py
index 2e3349a..f65df16 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.28',
+  version = '0.3.0',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
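
For context, the new padded_false_on_right_side helper accepts a boolean key padding mask only when every True (attend) position comes before every False (ignore) position along the last dimension, i.e. the mask contains no False-to-True transition; presumably this keeps padding confined to the trailing positions when keys and values are sharded around the ring. Below is a minimal standalone sketch of how the check behaves; the helper is copied from the diff, while the example masks are illustrative:

```python
import torch
from torch import Tensor

def padded_false_on_right_side(t: Tensor) -> bool:
    # with zero or one key positions there is nothing that can violate the rule
    if t.shape[-1] <= 1:
        return True

    # a violation is any adjacent pair where a False (pad) is followed by a True (attend)
    false_to_true = ~t[..., :-1] & t[..., 1:]
    return not false_to_true.any()

# right-padded mask (padding only on the right) passes the new assert
right_padded = torch.tensor([[True, True, True, False, False]])
assert padded_false_on_right_side(right_padded)

# padding on the left or in the middle trips the assert
left_padded = torch.tensor([[False, True, True, True, True]])
assert not padded_false_on_right_side(left_padded)
```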