Skip to content

Commit

Permalink
enforce triton nightly
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 8, 2024
1 parent 6210e0f commit 76266d5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions ring_attention_pytorch/triton_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def is_contiguous(x: Tensor):

assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'

triton_version = version('triton')
assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), 'triton must be version 2.1 or above. `pip install triton -U` to upgrade'
triton_version = version('triton-nightly')
assert pkg_version.parse(triton_version) >= pkg_version.parse('3.0.0'), 'triton must be version 3.0.0 or above. `pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly` to upgrade'

import triton
import triton.language as tl
Expand Down Expand Up @@ -743,15 +743,16 @@ def _bwd_kernel_one_col_block(
else: # If we're parallelizing across the seqlen_k dimension
dq = tl.dot(ds, k)
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
tl.atomic_add(dq_ptrs, dq)
tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
else:
if EVEN_HEADDIM:
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, sem = 'relaxed')
else:
tl.atomic_add(
dq_ptrs,
dq,
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
mask = (offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
sem = 'relaxed',
)
# increment pointers
dq_ptrs += BLOCK_M * stride_dqm
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.5.0',
version = '0.5.1',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 76266d5

Please sign in to comment.