remove redundant slice when chunked prefill feature is disabled (open…
sanyalington authored Sep 20, 2024
1 parent 9d8035b commit 0e80e85
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -517,7 +517,10 @@ def forward(

                 # common code for prefill
                 assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
+                if output.shape[0] > num_prefill_tokens:
+                    output[:num_prefill_tokens] = out
+                else:
+                    output = out
             else:
                 # prefix-enabled attention
                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(
@@ -564,12 +567,16 @@ def forward(
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
+                if num_prefill_tokens > 0:
+                    out = output[num_prefill_tokens:]
+                else:
+                    out = output
                 ops.paged_attention_rocm(
-                    output[num_prefill_tokens:], exp_sums, max_logits,
-                    tmp_output, decode_query, key_cache, value_cache,
-                    self.num_kv_heads, self.scale, decode_meta.block_tables,
-                    decode_meta.seq_lens_tensor, block_size, max_seq_len,
-                    self.alibi_slopes, self.kv_cache_dtype, k_scale, v_scale)
+                    out, exp_sums, max_logits, tmp_output, decode_query,
+                    key_cache, value_cache, self.num_kv_heads, self.scale,
+                    decode_meta.block_tables, decode_meta.seq_lens_tensor,
+                    block_size, max_seq_len, self.alibi_slopes,
+                    self.kv_cache_dtype, k_scale, v_scale)
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
                     decode_query,
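
Why the slice is redundant, as a minimal sketch: this is not taken from the vLLM source; the names (output, out, num_prefill_tokens) mirror the diff, but the shapes and surrounding code are illustrative assumptions. When chunked prefill is disabled, every token in the batch is a prefill token, so output[:num_prefill_tokens] spans the whole buffer and the slice assignment performs a full element-wise copy that a plain rebinding avoids.

# Minimal sketch, assuming PyTorch; shapes are illustrative only.
import torch

num_prefill_tokens = 8
out = torch.randn(num_prefill_tokens, 32, 128)     # result of the prefill attention kernel
output = torch.empty(num_prefill_tokens, 32, 128)  # pre-allocated output buffer

if output.shape[0] > num_prefill_tokens:
    # Chunked prefill enabled: decode tokens follow the prefill tokens in the
    # buffer, so only the prefill region is overwritten (an element-wise copy).
    output[:num_prefill_tokens] = out
else:
    # Chunked prefill disabled: the slice would cover the entire buffer, so the
    # copy is redundant; rebinding the name to `out` skips it.
    output = out

The decode-side hunk follows the same pattern: when num_prefill_tokens is zero, the full output tensor is handed to ops.paged_attention_rocm rather than the trivial slice output[num_prefill_tokens:].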
