From a41b150d2118cc650889ef0dd6b70cba1e2ef6a3 Mon Sep 17 00:00:00 2001 From: Sara Kokkila Schumacher Date: Fri, 22 Mar 2024 14:43:35 -0500 Subject: [PATCH] fix: replace if/else statement with tl.where --- python/perf-kernels/flash-attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 91ec10a9e61d..c4f2e5b7628e 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -335,7 +335,7 @@ def attn_fwd( return is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + off_h_k = tl.where(is_mqa, off_h_q % hk, off_h_q) need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: