verify causal masking #54

Open
huseinzol05 opened this issue Oct 6, 2024 · 6 comments

@huseinzol05

Hi @zhuzilin, follow up from #15

I just wanted to verify the causal masking. I simply use a loop because I don't have multiple GPUs, but it should behave the same. When I do causal attention using your ring logic, the argmax accuracy is super low, but when I do non-causal, the accuracy is almost a perfect 100%. You can check the notebook at https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/flash-ring-attention-causal.ipynb

From what I understand, let's say I have 2 devices and a seqlen of 100k, partitioned into 2 chunks of 100k // 2 = 50k each, so,

each device gets a 50k seq len,
device 0: 50k q0k0v0
device 1: 50k q1k1v1

So the blockwise attention calculation,
device 0: 50k q0k0v0 + 50k q0k1v1
device 1: 50k q1k0v0 + 50k q1k1v1

(+) denotes blockwise attention accumulation.

For the causal case, an attention mask is necessary. The original attention mask is [100k, 100k], and we must chunk it properly into mask0 = [50k, 100k] and mask1 = [50k, 100k], so the blockwise attention calculation becomes,

device 0: 50k (q0k0 * mask0[:, 0:50k])v0 + 50k (q0k1 * mask0[:, 50k:100k])v1
device 1: 50k (q1k0 * mask1[:, 0:50k])v0 + 50k (q1k1 * mask1[:, 50k:100k])v1

You can see this slicing in the original implementation at https://github.com/forhaoliu/ringattention/blob/main/ringattention/ringattention_pallas_tpu.py#L61
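
To make the toy version concrete, here is a minimal single-process sketch of that chunked computation (my own code, not from either repo; the softmax is taken over the concatenated score blocks instead of being merged blockwise with log-sum-exp):

import torch

torch.manual_seed(0)
L, dim, n_dev = 10, 16, 2
chunk = L // n_dev
q, k, v = (torch.randn(L, dim) for _ in range(3))

# reference: full causal attention
full_mask = torch.ones(L, L, dtype=torch.bool).tril()
ref = torch.softmax((q @ k.T / dim ** 0.5).masked_fill(~full_mask, float("-inf")), -1) @ v

outs = []
for r in range(n_dev):                                      # "device" r owns query chunk r
    q_r = q[r * chunk:(r + 1) * chunk]
    mask_r = full_mask[r * chunk:(r + 1) * chunk]           # [chunk, L] row slice of the mask
    blocks = []
    for b in range(n_dev):                                   # visit every key/value block, like the ring
        k_b = k[b * chunk:(b + 1) * chunk]
        s = q_r @ k_b.T / dim ** 0.5
        blocks.append(s.masked_fill(~mask_r[:, b * chunk:(b + 1) * chunk], float("-inf")))
    outs.append(torch.softmax(torch.cat(blocks, -1), -1) @ v)

assert torch.allclose(torch.cat(outs), ref, atol=1e-5)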

Correct me if I'm wrong here, thanks!

@huseinzol05
Author

This is to replicate the original JAX implementation: https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/blockwise-vanilla-attention-causal.ipynb

I just generate a global multiplicative mask (too lazy to do the additive one),

# L, S: query / key lengths; chunk_size: number of blocks; no: index of the current KV block
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()    # global causal mask
attn_bias_blocks = torch.chunk(temp_mask, chunk_size)                     # split rows per query block
attn_bias_block = attn_bias_blocks[0]                                     # mask rows for this query block
seq_chunk = Q.shape[0] // chunk_size
attn_bias_b = attn_bias_block[:, no * seq_chunk: (no + 1) * seq_chunk]    # columns for KV block `no`
scores = torch.matmul(Q_block, K_block.T) * attn_bias_b                   # multiplicative 0/1 mask
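
For comparison, the additive form of the same thing (which is what the JAX code below builds with jnp.finfo(dtype).min) would look roughly like this; my own sketch, reusing the names from the snippet above:

# additive version of the same mask: 0 where allowed, a large negative value where masked
attn_bias_add = (~attn_bias_b).float() * torch.finfo(torch.float32).min
scores = torch.matmul(Q_block, K_block.T) + attn_bias_add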

While the original JAX code generates an additive mask during the blockwise computation,

def _chunk_attention_bias(query_chunk_size, key_chunk_size,
            bias, segment_ids, deterministic, attn_dropout, attn_pdrop, causal,
            dtype, query_chunk_idx, key_chunk_idx):
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, 0, key_offset),
            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
        )

    if segment_ids is not None:
        q_segment_ids = lax.dynamic_slice(
            segment_ids,
            start_indices=(0, query_offset),
            slice_sizes=(segment_ids.shape[0], query_chunk_size)
        )
        k_segment_ids = lax.dynamic_slice(
            segment_ids,
            start_indices=(0, key_offset),
            slice_sizes=(segment_ids.shape[0], key_chunk_size)
        )
        segment_ids_mask = q_segment_ids[:, :, None] != k_segment_ids[:, None, :]
        segment_ids_mask = segment_ids_mask[:, None] # B1QK
        segment_ids_bias = segment_ids_mask * jnp.finfo(dtype).min
        chunk_bias += segment_ids_bias

    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        offset = query_offset - key_offset
        query_idx += offset
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
    return chunk_bias.astype(dtype)
_chunk_bias_fn = partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, segment_ids, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)
bias_chunk = _chunk_bias_fn(q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx)
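
For what it's worth, the if causal: branch above can be reproduced in a few lines of PyTorch to see which bias each (query chunk, key chunk) pair ends up with; this is just my own illustrative sketch, not code from either repo:

import torch

def chunk_causal_bias(query_chunk_idx, key_chunk_idx, q_size=5, k_size=5, dtype=torch.float32):
    # mirrors the `if causal:` branch: shift the query indices by (query_offset - key_offset)
    query_idx = torch.arange(q_size).view(q_size, 1) + query_chunk_idx * q_size - key_chunk_idx * k_size
    key_idx = torch.arange(k_size).view(1, k_size)
    return (query_idx < key_idx).to(dtype) * torch.finfo(dtype).min

print(chunk_causal_bias(0, 0))  # diagonal block: strictly-upper-triangular entries get the large negative bias
print(chunk_causal_bias(1, 0))  # query chunk after key chunk: all zeros, i.e. full attention
print(chunk_causal_bias(0, 1))  # query chunk before key chunk: everything masked out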

@huseinzol05
Author

Any comment, @zhuzilin?

@zhuzilin
Owner

Hmm... I'm not sure what you are aiming at. If you just want to make sure that this implementation supports the causal mask, you can try running the code in the test folder.

The code in this repo is not a step-by-step port of the original JAX implementation, and I actually haven't read that before...

@huseinzol05
Author

Sorry, I just want to verify the causal masking, because based on the code,

for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if not causal or step <= comm.rank:
            params = get_default_args(_flash_attn_forward).copy()
            params.update(
                {
                    "q": q,
                    "k": k,
                    "v": v,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": causal and step == 0,
                    "window_size": window_size,
                    "alibi_slopes": alibi_slopes,
                    "return_softmax": True and dropout_p > 0,
                }
            )
            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)

Let's say I have QKV of size [L, dim] = [10, 100], with a causal mask of [10, 10],

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

Now I chunk on the sequence dimension across 2 devices, each with QKV of size [5, 100] and a causal mask of [5, 10],

device 0: [5, 100] mask [5, 10]

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])

device 1: [5, 100] mask [5, 10]

tensor([[ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])
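
For reference, those two [5, 10] masks are just the row chunks of the full lower-triangular mask, the same way as in my earlier snippet:

import torch

full_mask = torch.ones(10, 10, dtype=torch.bool).tril()
mask0, mask1 = torch.chunk(full_mask, 2, dim=0)   # the two [5, 10] slices printed above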

On device 0, the blockwise computation is,

q0 k0 mask[:, 0:5] v0 + q0 k1 mask[:, 5:10] v1

where mask[:, 0:5],

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

where mask[:, 5:10],

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

On device 1, the blockwise computation is,

q1 k0 mask[:, 0:5] v0 + q1 k1 mask[:, 5:10] v1

where mask[:, 0:5],

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

where mask[:, 5:10],

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

Now back to the forward code,

for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if not causal or step <= comm.rank:
            params = get_default_args(_flash_attn_forward).copy()
            params.update(
                {
                    "q": q,
                    "k": k,
                    "v": v,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": causal and step == 0,
                    "window_size": window_size,
                    "alibi_slopes": alibi_slopes,
                    "return_softmax": True and dropout_p > 0,
                }
            )
            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)

Let's say we are currently at device 0 and world_size == 2, so device 0 will calculate q0k0v0 + q0k1v1.

when step == 0,
causal is True based on causal and step == 0, so flash attention will generate the lower-triangular causal mask, but we need,

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

when step == 1,
causal is False based on causal and step == 0, so flash attention will do full attention, but we need,

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

The same goes for device 1.
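
Reading the two conditions in the quoted loop literally (without assuming which kv chunk a rank holds at a given step), the compute pattern for world_size == 2 is:

world_size, causal = 2, True
for rank in range(world_size):
    for step in range(world_size):
        computes = (not causal) or step <= rank          # the `if` guard in the loop
        causal_flag = causal and step == 0               # the flag passed to flash attention
        print(f"rank={rank} step={step} computes={computes} causal_flag={causal_flag}")
# rank=0 step=0 computes=True causal_flag=True
# rank=0 step=1 computes=False causal_flag=False
# rank=1 step=0 computes=True causal_flag=True
# rank=1 step=1 computes=True causal_flag=False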

@zhuzilin Hopefully you can understand what I am trying to share, lol

@zhuzilin
Owner

Hmm... there will never be a rectangular mask, as all the k and q chunks have the same sequence length. And there will also never be an all-False mask, which would actually mean doing no calculation at all...

@huseinzol05
Author

Regardless, the calculation of qk-mask-v still happens; it just produces really, really small values, and merging later will produce the correct results. What's your thought on this?
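
For context, by merging I mean the usual log-sum-exp rescale, something like this (my own sketch of the idea behind update_out_and_lse, not the repo's exact code):

import torch

def merge_blocks(out_a, lse_a, out_b, lse_b):
    # out_*: [L, d] partial attention outputs; lse_*: [L, 1] log-sum-exp of their scores
    lse = torch.logaddexp(lse_a, lse_b)                                    # combined normalizer
    out = torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b  # rescale and sum
    return out, lse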
