From b8ac39adf5221028c794995f23ecfd29fbcd9e1d Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Thu, 20 Jun 2024 09:26:40 +0000 Subject: [PATCH] Fixed streamk kernel bug --- python/perf-kernels/streamk/persistent_streamk_kernel.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/perf-kernels/streamk/persistent_streamk_kernel.py b/python/perf-kernels/streamk/persistent_streamk_kernel.py index f4c26c843a18..9d5b22a3943e 100644 --- a/python/perf-kernels/streamk/persistent_streamk_kernel.py +++ b/python/perf-kernels/streamk/persistent_streamk_kernel.py @@ -44,6 +44,7 @@ def persistent_streamk_gemm( iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config(M, N, K, num_sms, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for tile_id in range(pid, total_full_tiles, num_sms): if GROUP_SIZE_M == 1: @@ -62,7 +63,7 @@ def persistent_streamk_gemm( rk = tl.arange(0, BLOCK_SIZE_K) A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + acc = acc * 0.0 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(A_BASE) b = tl.load(B_BASE) @@ -76,6 +77,7 @@ def persistent_streamk_gemm( mask = (rm < M)[:, None] & (rn < N)[None, :] tl.store(C_, acc, mask=mask) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) start_iter = total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters) last_iter = total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum(pid + 1, streamk_remainder_iters) while start_iter < last_iter: @@ -99,7 +101,7 @@ def persistent_streamk_gemm( rk = tl.arange(0, BLOCK_SIZE_K) A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + acc = acc * 0.0 for current_iter in range(start_iter, end_iter): a = tl.load(A_BASE) b = tl.load(B_BASE) @@ -110,7 +112,7 @@ def persistent_streamk_gemm( # ower iter is starting from middle of the iter # if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM tile_iter = tile_id * iters_per_tile - if start_iter != tile_iter: + if start_iter != tile_iter: rm1 = tl.arange(0, BLOCK_SIZE_M) rn1 = tl.arange(0, BLOCK_SIZE_N) P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]