diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index 9532d601..a2a5429c 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -11,7 +11,6 @@
 from shark_turbine.kernel._support.indexing import IndexingContext
 from shark_turbine.kernel.ops.wave_ops import get_custom
 from shark_turbine.kernel.lang.global_symbols import *
-from shark_turbine.kernel.wave.visualization import visualize_graph
 
 
 def run(func: Callable[[], None]) -> Callable[[], None]:
@@ -238,7 +237,6 @@ def test_gemm():
         graph = gemm()
         IndexingContext.current().finalize()
         expand_graph(graph, constraints)
-        visualize_graph(graph.get_subgraph("region_0"), "gemm.png")
         print_trace(graph)
         # Root graph:
         # CHECK: %a
diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py
index 7bdf6aa9..2f5e5277 100644
--- a/shark_turbine/kernel/wave/visualization.py
+++ b/shark_turbine/kernel/wave/visualization.py
@@ -4,10 +4,12 @@
 except:
     disabled_graphviz = True
 from torch import fx
+import warnings
 
 
 def visualize_graph(graph: fx.Graph, file_name: str):
     if disabled_graphviz:
+        warnings.warn("pygraphviz not installed, skipping visualization.")
         return
     G = pgv.AGraph(directed=True)
     for node in graph.nodes:
diff --git a/tests/kernel/wave/visualization_test.py b/tests/kernel/wave/visualization_test.py
new file mode 100644
index 00000000..c8d64950
--- /dev/null
+++ b/tests/kernel/wave/visualization_test.py
@@ -0,0 +1,82 @@
+import logging
+from typing import Callable
+import unittest
+import shark_turbine.kernel as tk
+import shark_turbine.kernel.lang as tkl
+import shark_turbine.kernel.wave as tkw
+from shark_turbine.kernel.wave.expansion import expand_graph
+from shark_turbine.kernel._support.tracing import CapturedTrace
+from shark_turbine.kernel._support.indexing import IndexingContext
+from shark_turbine.kernel.ops.wave_ops import get_custom
+from shark_turbine.kernel.lang.global_symbols import *
+from shark_turbine.kernel.wave.visualization import visualize_graph
+
+
+def run(func: Callable[[], None]) -> Callable[[], None]:
+    """Run a function as part of the test suite."""
+    if __name__ == "__main__":
+        func()
+    return func
+
+
+# Input sizes
+M = tkl.sym.M
+N = tkl.sym.N
+K = tkl.sym.K
+
+# Workgroup tile sizes
+BLOCK_M = tkl.sym.BLOCK_M
+BLOCK_N = tkl.sym.BLOCK_N
+BLOCK_K = tkl.sym.BLOCK_K
+
+# Address space (for GPU, shared(1) or global(0))
+ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+# Induction variable for dimension K
+ARGK = tkl.sym.ARGK
+
+
+@tkw.wave_trace_only()
+def gemm(
+    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
+    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
+):
+    c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+    @tkw.reduction(K, init_args=[c_reg])
+    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+        a_reg = tkw.read(a, elements_per_thread=4)
+        b_reg = tkw.read(b, elements_per_thread=4)
+        acc = tkw.mma(a_reg, b_reg, acc)
+        return acc
+
+    tkw.write(repeat, c, elements_per_thread=4)
+
+
+@run
+def test_gemm():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
+    constraints += [
+        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
+    ]
+    with tk.gen.TestLaunchContext(
+        {
+            BLOCK_M: 32,
+            BLOCK_N: 32,
+            BLOCK_K: 32,
+        }
+    ):
+        graph = gemm()
+        IndexingContext.current().finalize()
+        expand_graph(graph, constraints)
+        visualize_graph(graph.get_subgraph("region_0"), "gemm.png")
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()