Skip to content

Commit

Permalink
Using the correct datatype on prefix prefill for fp8 kv cache (#242)
Browse files · Browse the repository at this point in the history
  • Loading branch information
gshtras authored Oct 23, 2024
1 parent 2a3f461 commit 46aa3d2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion — vllm/attention/ops/prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import triton.language as tl

from vllm.platforms import current_platform
from vllm.utils import is_hip

if triton.__version__ >= "2.1.0":

Expand Down Expand Up @@ -724,7 +725,8 @@ def context_attention_fwd(q,
assert (v_cache.dtype == torch.uint8)

if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = torch.float8_e4m3fn
target_dtype = torch.float8_e4m3fn if not is_hip(
) else torch.float8_e4m3fnuz
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
Expand Down

0 comments on commit 46aa3d2

Please sign in to comment.