From 3bd9d00337b919043482703ea2a06bfeb2376ebe Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Sun, 20 Oct 2024 18:26:41 -0700 Subject: [PATCH] Solve nits Signed-off-by: Stanley Winata --- iree/turbine/kernel/ops/wave_ops.py | 7 ++++--- iree/turbine/kernel/wave/thread_shape_analysis.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 43db5fdf..e11a27c1 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -963,9 +963,10 @@ def wrapper(f): def get_root_graph(self): """ - Get root graph from some child graph inside a nested graph. - Using the assumption that any child/nested graph should have a parent_op, - who we can query for it's owner graph from to go up one level. + Return the "root"/most outter layer of our computation graph. + This is done by iteratively accessing parent_graph of current + graph. This is done until we find the "root" graph who + will have "subgraph" attribute. """ cur_graph = self.graph while not hasattr(cur_graph, "subgraphs"): diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 5371b0bf..7bcb6f57 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -51,7 +51,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): # Anchor Indicies and Conflict resolution helpers ################################################################# -anchorOpTypes = (Read, Write, MMA, ReduceOp, GetResult) +anchorOpTypes = (Read, Write, MMA, ReduceOp) noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) legalSubtypes = (IterArg,) nonPropagatableTypes = anchorOpTypes + noHandleTypes @@ -68,7 +68,13 @@ def propagatable_op(node: fx.Node): ) -def handle_binaryop_conflict(custom_node: CustomOp): +def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: + """ + This function will attempt to resolve binaryOp conflicts + by inserting broadcastOp. It will then propagate the resolutions, + and return the list of fx.Nodes that we have resolved. + """ + # Analyze if we can resolve conflict with broadcast. lhs = get_custom(custom_node.lhs) rhs = get_custom(custom_node.rhs) @@ -161,7 +167,7 @@ def determine_thread_shapes(trace: CapturedTrace): for anchor_op in anchor_ops: custom = get_custom(anchor_op) index_sizes = get_custom_dim_sizes(custom) - if isinstance(custom, (Read, GetResult)): + if isinstance(custom, Read): fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) thread_size_to_ops[index_sizes] = thread_size_to_ops.get( index_sizes, set([])