-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move visualization to new test and add warning
- Loading branch information
Showing
3 changed files
with
84 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import logging | ||
from typing import Callable | ||
import unittest | ||
import shark_turbine.kernel as tk | ||
import shark_turbine.kernel.lang as tkl | ||
import shark_turbine.kernel.wave as tkw | ||
from shark_turbine.kernel.wave.expansion import expand_graph | ||
from shark_turbine.kernel._support.tracing import CapturedTrace | ||
from shark_turbine.kernel._support.indexing import IndexingContext | ||
from shark_turbine.kernel.ops.wave_ops import get_custom | ||
from shark_turbine.kernel.lang.global_symbols import * | ||
from shark_turbine.kernel.wave.visualization import visualize_graph | ||
|
||
|
||
def run(func: Callable[[], None]) -> Callable[[], None]: | ||
"""Run a function as part of the test suite.""" | ||
if __name__ == "__main__": | ||
func() | ||
return func | ||
|
||
|
||
# Input sizes | ||
M = tkl.sym.M | ||
N = tkl.sym.N | ||
K = tkl.sym.K | ||
|
||
# Workgroup tile sizes | ||
BLOCK_M = tkl.sym.BLOCK_M | ||
BLOCK_N = tkl.sym.BLOCK_N | ||
BLOCK_K = tkl.sym.BLOCK_K | ||
|
||
# Address space (for GPU, shared(1) or global(0)) | ||
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE | ||
|
||
# Induction variable for dimension K | ||
ARGK = tkl.sym.ARGK | ||
|
||
|
||
@tkw.wave_trace_only() | ||
def gemm( | ||
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], | ||
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], | ||
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], | ||
): | ||
c_reg = tkl.Register[M, N, tkl.f32](0.0) | ||
|
||
@tkw.reduction(K, init_args=[c_reg]) | ||
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: | ||
a_reg = tkw.read(a, elements_per_thread=4) | ||
b_reg = tkw.read(b, elements_per_thread=4) | ||
acc = tkw.mma(a_reg, b_reg, acc) | ||
return acc | ||
|
||
tkw.write(repeat, c, elements_per_thread=4) | ||
|
||
|
||
@run | ||
def test_gemm(): | ||
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] | ||
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] | ||
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] | ||
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)] | ||
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)] | ||
constraints += [ | ||
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1)) | ||
] | ||
with tk.gen.TestLaunchContext( | ||
{ | ||
BLOCK_M: 32, | ||
BLOCK_N: 32, | ||
BLOCK_K: 32, | ||
} | ||
): | ||
graph = gemm() | ||
IndexingContext.current().finalize() | ||
expand_graph(graph, constraints) | ||
visualize_graph(graph.get_subgraph("region_0"), "gemm.png") | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig(level=logging.DEBUG) | ||
unittest.main() |