Skip to content

Commit

Permalink
Throw error if pygraphviz not present
Browse files Browse the repository at this point in the history
Also add assert to check for image and xfail
test if graphviz is not available.
  • Loading branch information
harsh-nod committed Aug 15, 2024
1 parent b9a6941 commit e994f9e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
9 changes: 4 additions & 5 deletions shark_turbine/kernel/wave/visualization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
disabled_graphviz = False
graphviz_disabled = False
try:
import pygraphviz as pgv
except:
disabled_graphviz = True
graphviz_disabled = True
from torch import fx
import warnings


def visualize_graph(graph: fx.Graph, file_name: str):
if disabled_graphviz:
warnings.warn("pygraphviz not installed, skipping visualization.")
return
if graphviz_disabled:
raise ImportError("pygraphviz not installed, cannot visualize graph")
G = pgv.AGraph(directed=True)
for node in graph.nodes:
G.add_node(node.name)
Expand Down
11 changes: 11 additions & 0 deletions tests/kernel/wave/visualization_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from typing import Callable
import unittest
import os
import pytest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
Expand Down Expand Up @@ -54,6 +56,14 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
tkw.write(repeat, c, elements_per_thread=4)


graphviz_disabled = False
try:
import pygraphviz
except:
graphviz_disabled = True


@pytest.mark.xfail(condition=graphviz_disabled, reason="pygraphviz not installed")
@run
def test_gemm():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
Expand All @@ -75,6 +85,7 @@ def test_gemm():
IndexingContext.current().finalize()
expand_graph(graph, constraints)
visualize_graph(graph.get_subgraph("region_0"), "gemm.png")
assert os.path.exists("gemm.png")


if __name__ == "__main__":
Expand Down

0 comments on commit e994f9e

Please sign in to comment.