diff --git a/ring_attention_pytorch/triton_flash_attn.py b/ring_attention_pytorch/triton_flash_attn.py
index 18e8640..9a4781d 100644
--- a/ring_attention_pytorch/triton_flash_attn.py
+++ b/ring_attention_pytorch/triton_flash_attn.py
@@ -28,8 +28,8 @@ def is_contiguous(x: Tensor):
 
 assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'
 
-triton_version = version('triton')
-assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), 'triton must be version 2.1 or above. `pip install triton -U` to upgrade'
+triton_version = version('triton-nightly')
+assert pkg_version.parse(triton_version) >= pkg_version.parse('3.0.0'), 'triton must be version 3.0.0 or above. `pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly` to upgrade'
 
 import triton
 import triton.language as tl
@@ -743,15 +743,16 @@ def _bwd_kernel_one_col_block(
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
             if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
-                tl.atomic_add(dq_ptrs, dq)
+                tl.atomic_add(dq_ptrs, dq, sem = 'relaxed')
             else:
                 if EVEN_HEADDIM:
-                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
+                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, sem = 'relaxed')
                 else:
                     tl.atomic_add(
                         dq_ptrs,
                         dq,
-                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                        mask = (offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                        sem = 'relaxed',
                     )
         # increment pointers
         dq_ptrs += BLOCK_M * stride_dqm
diff --git a/setup.py b/setup.py
index 3963e40..9dfe7d4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.0',
+  version = '0.5.1',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
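
Side note on the kernel change (not part of the patch): `sem` is the memory-ordering keyword that `tl.atomic_add` accepts as of Triton 3.0, which is why the nightly build is now required. Passing `'relaxed'` drops the default acquire-release ordering; that is safe for the `dq` accumulation because the buffer is only read after the backward kernel completes. Below is a minimal, hypothetical stand-alone sketch (the kernel name `block_sum_kernel` is illustrative, not from this repository) showing the same pattern of many programs atomically accumulating into a shared location with relaxed semantics.

import torch
import triton
import triton.language as tl

@triton.jit
def block_sum_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # each program reduces its own block of the input
    x = tl.load(x_ptr + offsets, mask = mask, other = 0.)
    partial = tl.sum(x, axis = 0)

    # every program adds its partial sum into the same scalar; relaxed ordering
    # suffices because the result is only read back after the kernel finishes
    tl.atomic_add(out_ptr, partial, sem = 'relaxed')

x = torch.randn(4096, device = 'cuda')
out = torch.zeros(1, device = 'cuda')

grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
block_sum_kernel[grid](x, out, x.numel(), BLOCK_SIZE = 256)

# loose tolerance: float atomics accumulate in nondeterministic order
assert torch.allclose(out, x.sum(), atol = 1e-3)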