Skip to content

Commit

Permalink
Move visualization to new test and add warning
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-nod committed Aug 15, 2024
1 parent 14579c0 commit f116bb9
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 2 deletions.
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import get_custom
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel.wave.visualization import visualize_graph


def run(func: Callable[[], None]) -> Callable[[], None]:
Expand Down Expand Up @@ -238,7 +237,6 @@ def test_gemm():
graph = gemm()
IndexingContext.current().finalize()
expand_graph(graph, constraints)
visualize_graph(graph.get_subgraph("region_0"), "gemm.png")
print_trace(graph)
# Root graph:
# CHECK: %a
Expand Down
2 changes: 2 additions & 0 deletions shark_turbine/kernel/wave/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
except:
disabled_graphviz = True
from torch import fx
import warnings


def visualize_graph(graph: fx.Graph, file_name: str):
if disabled_graphviz:
warnings.warn("pygraphviz not installed, skipping visualization.")
return
G = pgv.AGraph(directed=True)
for node in graph.nodes:
Expand Down
82 changes: 82 additions & 0 deletions tests/kernel/wave/visualization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging
from typing import Callable
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import get_custom
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel.wave.visualization import visualize_graph


def run(func: Callable[[], None]) -> Callable[[], None]:
"""Run a function as part of the test suite."""
if __name__ == "__main__":
func()
return func


# Input sizes
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K

# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K

# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Induction variable for dimension K
ARGK = tkl.sym.ARGK


@tkw.wave_trace_only()
def gemm(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=4)
b_reg = tkw.read(b, elements_per_thread=4)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=4)


@run
def test_gemm():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
with tk.gen.TestLaunchContext(
{
BLOCK_M: 32,
BLOCK_N: 32,
BLOCK_K: 32,
}
):
graph = gemm()
IndexingContext.current().finalize()
expand_graph(graph, constraints)
visualize_graph(graph.get_subgraph("region_0"), "gemm.png")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

0 comments on commit f116bb9

Please sign in to comment.