From 46aa3d28d19cd0dba0c1e4db96d72f7d52b329af Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Wed, 23 Oct 2024 17:36:24 -0400
Subject: [PATCH] Using the correct datatype on prefix prefill for fp8 kv
 cache (#242)

---
 vllm/attention/ops/prefix_prefill.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index a2a649c8ebcfd..f40df684d4baf 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -6,6 +6,7 @@
 import triton.language as tl
 
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
 
 if triton.__version__ >= "2.1.0":
@@ -724,7 +725,8 @@ def context_attention_fwd(q,
         assert (v_cache.dtype == torch.uint8)
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = torch.float8_e4m3fn if not is_hip(
+            ) else torch.float8_e4m3fnuz
         elif kv_cache_dtype == "fp8_e5m2":
             target_dtype = torch.float8_e5m2
         else:
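
For context, the dtype-selection logic the patch arrives at can be read as the
sketch below. This is a minimal illustration, not part of the patch: the helper
name fp8_target_dtype and the explicit on_hip parameter are hypothetical; the
patch itself inlines this logic in context_attention_fwd() and queries the
platform via vllm.utils.is_hip().

    import torch

    def fp8_target_dtype(kv_cache_dtype: str, on_hip: bool) -> torch.dtype:
        """Pick the torch dtype used to reinterpret the uint8-stored FP8 KV cache."""
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            # ROCm hardware implements the "fnuz" e4m3 variant (no negative
            # zero, no infinities, exponent bias 8), which differs from the
            # OCP-style torch.float8_e4m3fn used on CUDA; reinterpreting the
            # cache bytes with the wrong variant mis-decodes every value.
            return torch.float8_e4m3fnuz if on_hip else torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            # The patch leaves the e5m2 branch untouched.
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unsupported FP8 KV cache dtype: {kv_cache_dtype}")

Note that the patch only corrects the e4m3 branch; the fp8_e5m2 case still maps
to torch.float8_e5m2 on both platforms.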