Skip to content

Commit

Permalink
Solve nits
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
  • Loading branch information
raikonenfnu committed Oct 21, 2024
1 parent 3d6fbba commit 3bd9d00
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
7 changes: 4 additions & 3 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,9 +963,10 @@ def wrapper(f):

def get_root_graph(self):
"""
Get root graph from some child graph inside a nested graph.
Using the assumption that any child/nested graph should have a parent_op,
who we can query for it's owner graph from to go up one level.
Return the "root"/most outter layer of our computation graph.
This is done by iteratively accessing parent_graph of current
graph. This is done until we find the "root" graph who
will have "subgraph" attribute.
"""
cur_graph = self.graph
while not hasattr(cur_graph, "subgraphs"):
Expand Down
12 changes: 9 additions & 3 deletions iree/turbine/kernel/wave/thread_shape_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
# Anchor Indicies and Conflict resolution helpers
#################################################################

anchorOpTypes = (Read, Write, MMA, ReduceOp, GetResult)
anchorOpTypes = (Read, Write, MMA, ReduceOp)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
legalSubtypes = (IterArg,)
nonPropagatableTypes = anchorOpTypes + noHandleTypes
Expand All @@ -68,7 +68,13 @@ def propagatable_op(node: fx.Node):
)


def handle_binaryop_conflict(custom_node: CustomOp):
def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]:
"""
This function will attempt to resolve binaryOp conflicts
by inserting broadcastOp. It will then propagate the resolutions,
and return the list of fx.Nodes that we have resolved.
"""

# Analyze if we can resolve conflict with broadcast.
lhs = get_custom(custom_node.lhs)
rhs = get_custom(custom_node.rhs)
Expand Down Expand Up @@ -161,7 +167,7 @@ def determine_thread_shapes(trace: CapturedTrace):
for anchor_op in anchor_ops:
custom = get_custom(anchor_op)
index_sizes = get_custom_dim_sizes(custom)
if isinstance(custom, (Read, GetResult)):
if isinstance(custom, Read):
fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
index_sizes, set([])
Expand Down

0 comments on commit 3bd9d00

Please sign in to comment.