Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to visualize fx graphs #84

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ wheelhouse
*.safetensors
*.gguf
*.vmfb
*.png
20 changes: 20 additions & 0 deletions shark_turbine/kernel/wave/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
graphviz_disabled = False
try:
import pygraphviz as pgv
except:
graphviz_disabled = True
from torch import fx
import warnings


def visualize_graph(graph: fx.Graph, file_name: str):
if graphviz_disabled:
raise ImportError("pygraphviz not installed, cannot visualize graph")
G = pgv.AGraph(directed=True)
for node in graph.nodes:
G.add_node(node.name)
for node in graph.nodes:
for user in node.users.keys():
G.add_edge(node.name, user.name)
G.layout(prog="dot")
G.draw(file_name)
93 changes: 93 additions & 0 deletions tests/kernel/wave/visualization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
from typing import Callable
import unittest
import os
import pytest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import get_custom
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel.wave.visualization import visualize_graph


def run(func: Callable[[], None]) -> Callable[[], None]:
"""Run a function as part of the test suite."""
if __name__ == "__main__":
func()
return func


# Input sizes
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K

# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K

# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Induction variable for dimension K
ARGK = tkl.sym.ARGK


@tkw.wave_trace_only()
def gemm(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=4)
b_reg = tkw.read(b, elements_per_thread=4)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=4)


graphviz_disabled = False
try:
import pygraphviz
except:
graphviz_disabled = True


@pytest.mark.xfail(condition=graphviz_disabled, reason="pygraphviz not installed")
@run
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
def test_gemm():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
with tk.gen.TestLaunchContext(
{
BLOCK_M: 32,
BLOCK_N: 32,
BLOCK_K: 32,
}
):
graph = gemm()
IndexingContext.current().finalize()
expand_graph(graph, constraints)
visualize_graph(graph.get_subgraph("region_0"), "gemm.png")
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
assert os.path.exists("gemm.png")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
3 changes: 2 additions & 1 deletion tests/kernel/wave/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.lang.global_symbols import *


class Test(unittest.TestCase):
Expand Down Expand Up @@ -65,7 +66,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

hyperparams = {
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 1,
BLOCK_M: 32,
Expand Down
Loading