flash attention version #21

Open · hxdtest opened this issue Mar 6, 2024 · 11 comments

@hxdtest

hxdtest commented Mar 6, 2024

What is the minimum required flash-attention version?

@zhuzilin
Owner

zhuzilin commented Mar 7, 2024

I haven't tested this carefully, but 2.4.x and above should definitely work.

@void-main

If you want to use the llama3 approach, you need at least flash-attention 2.6.0 (which adds unpadded lse support); see this commit: Dao-AILab/flash-attention@f816dee
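
A quick way to check which flash-attn build is installed before trying the llama3 path; this is only an illustrative sketch (the 2.6.0 threshold is the one mentioned above, and the packaging module is assumed to be available):

import flash_attn
from packaging.version import Version

# flash_attn exposes __version__; 2.6.0 is the version cited above as the
# first with unpadded lse support.
if Version(flash_attn.__version__) < Version("2.6.0"):
    print(f"flash_attn {flash_attn.__version__} predates unpadded lse support; "
          "the llama3 varlen path may not work")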

@zhuzilin
Owner

I updated it a couple of days ago; unpadded lse is supported now~

@void-main

@zhuzilin After the switch to unpadded lse, older versions are no longer compatible: with an FA version that does not support unpadded lse, softmax_lse gets sliced incorrectly, which leads to an illegal memory access.
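
For context, a rough sketch of the two softmax_lse layouts involved; the shapes here are illustrative assumptions based on the flash-attention commit linked above, not code from this repository:

import torch

nheads, total_tokens = 8, 4232       # illustrative sizes
batch, max_seqlen = 3, 4096

# older flash_attn varlen kernels: lse padded per sequence
padded_lse = torch.zeros(batch, nheads, max_seqlen)   # (batch, nheads, max_seqlen)

# flash_attn with unpadded lse: one entry per token in the packed batch
unpadded_lse = torch.zeros(nheads, total_tokens)      # (nheads, total_q)

# Slicing with token offsets, e.g. lse[:, start:end], only makes sense for the
# unpadded layout; doing the same on the padded layout picks out the wrong
# region, consistent with the symptom reported here.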

@zhuzilin
Owner

@void-main I just tested this: flash_attn 2.5.9 (i.e. the old lse layout) and the latest flash_attn can both run test_zigzag_ring_flash_attn_varlen_func.py correctly...

@void-main

> @void-main I just tested this: flash_attn 2.5.9 (i.e. the old lse layout) and the latest flash_attn can both run test_zigzag_ring_flash_attn_varlen_func.py correctly...

@zhuzilin Ah, I may not have been clear: what I'm running here is the llama3 one.

@void-main

And it is a GQA setting, e.g. num_heads=64, num_kv_heads=8; in that case it crashes when heads_k_stride==1.

@void-main

Here is the test code:

import torch
import torch.distributed as dist
from flash_attn import flash_attn_varlen_func
from ring_flash_attn import (
    llama3_flash_attn_prepare_cu_seqlens,
    llama3_flash_attn_varlen_func,
)
from utils import log, set_seed


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{rank}")

    batch_size = 1
    nheads = 64
    nkv_heads = 8

    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    cu_seqlens = [0, 120, 1248, 4232]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
    total_length = cu_seqlens[-1]
    local_length = total_length // world_size
    num_seq = len(cu_seqlens) - 1

    assert cu_seqlens_tensor[-1] % world_size == 0
    assert d % 8 == 0

    q = torch.randn(
        total_length, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    k = torch.randn(
        total_length, nkv_heads, d, device=device, dtype=dtype, requires_grad=True
    )
    v = torch.randn(
        total_length, nkv_heads, d, device=device, dtype=dtype, requires_grad=True
    )
    dist.broadcast(q, src=0)
    dist.broadcast(k, src=0)
    dist.broadcast(v, src=0)

    dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    # each rank keeps a contiguous 1/world_size token slice of q, k, v
    local_q = q[rank * local_length : (rank + 1) * local_length].detach().clone()
    local_k = k[rank * local_length : (rank + 1) * local_length].detach().clone()
    local_v = v[rank * local_length : (rank + 1) * local_length].detach().clone()
    local_q.requires_grad = True
    local_k.requires_grad = True
    local_v.requires_grad = True
    local_dout = dout[rank * local_length : (rank + 1) * local_length].detach().clone()

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)

    # reference: plain flash_attn_varlen_func over the full, unsharded sequences
    out, lse, _ = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_tensor,
        cu_seqlens_tensor,
        max_seqlen,
        max_seqlen,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    print(f'qkv shape: {q.shape}, cu_seqlens shape: {cu_seqlens_tensor.shape}, lse shape: {lse.shape}', flush=True)

    local_out = out[rank * local_length : (rank + 1) * local_length]
    local_lse = lse[:, rank * local_length : (rank + 1) * local_length]

    (
        local_cu_seqlens_q,
        local_cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        local_k_slice,
    ) = llama3_flash_attn_prepare_cu_seqlens(
        cu_seqlens_tensor,
        causal=causal,
        rank=rank,
        world_size=world_size,
    )

    # llama3-style context-parallel attention on this rank's shard;
    # heads_k_stride=1 is the GQA configuration that triggers the crash
    llama3_out, llama3_lse, _ = llama3_flash_attn_varlen_func(
        local_q, local_k, local_v,
        local_cu_seqlens_q,
        local_cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        heads_k_stride=1,
        # heads_k_stride=nkv_heads,
        local_k_slice=local_k_slice,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    log("out", out, rank0_only=True)
    log("out diff", local_out - llama3_out)
    # log("lse", lse, rank0_only=True)
    # print(f'local lse shape: {local_lse.shape}, llama3 lse shape: {llama3_lse.shape}', flush=True)
    # log("lse diff", local_lse - llama3_lse)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    # backward through both paths and compare the gradient shards
    out.backward(dout)
    dq = q.grad
    dk = k.grad
    dv = v.grad
    local_dq = dq[rank * local_length : (rank + 1) * local_length]
    local_dk = dk[rank * local_length : (rank + 1) * local_length]
    local_dv = dv[rank * local_length : (rank + 1) * local_length]

    llama3_out.backward(local_dout)
    llama3_dq = local_q.grad
    llama3_dk = local_k.grad
    llama3_dv = local_v.grad

    log("dq diff", local_dq[:, 0] - llama3_dq[:, 0])
    log("dk diff", local_dk[:, 1] - llama3_dk[:, 1])
    log("dv diff", local_dv[:, 2] - llama3_dv[:, 2])

@zhuzilin
Owner

> Ah, I may not have been clear: what I'm running here is the llama3 one.

Ah, got it~ The llama3 path was written only recently, so it indeed doesn't have compatibility code for older flash attn versions.... I'll go add it~

@zhuzilin
Owner

@void-main It's fixed now; pull the latest code and give it a try~

@void-main

@zhuzilin Tested on FA v2.4.2 and it passes, thanks a lot 👍
