Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed streamk kernel bug #602

Open
wants to merge 1 commit into
base: streamk-no-atomic
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/perf-kernels/streamk/persistent_streamk_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def persistent_streamk_gemm(
iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config(M, N, K, num_sms, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)

acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

for tile_id in range(pid, total_full_tiles, num_sms):
if GROUP_SIZE_M == 1:
Expand All @@ -62,7 +63,7 @@ def persistent_streamk_gemm(
rk = tl.arange(0, BLOCK_SIZE_K)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
acc = acc * 0.0
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(A_BASE)
b = tl.load(B_BASE)
Expand All @@ -76,6 +77,7 @@ def persistent_streamk_gemm(
mask = (rm < M)[:, None] & (rn < N)[None, :]
tl.store(C_, acc, mask=mask)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we could remove this line, we will get better perf

start_iter = total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters)
last_iter = total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum(pid + 1, streamk_remainder_iters)
while start_iter < last_iter:
Expand All @@ -99,7 +101,7 @@ def persistent_streamk_gemm(
rk = tl.arange(0, BLOCK_SIZE_K)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
acc = acc * 0.0
for current_iter in range(start_iter, end_iter):
a = tl.load(A_BASE)
b = tl.load(B_BASE)
Expand All @@ -110,7 +112,7 @@ def persistent_streamk_gemm(
# ower iter is starting from middle of the iter
# if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM
tile_iter = tile_id * iters_per_tile
if start_iter != tile_iter:
if start_iter != tile_iter:
rm1 = tl.arange(0, BLOCK_SIZE_M)
rn1 = tl.arange(0, BLOCK_SIZE_N)
P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
Expand Down