diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index 9474c55..ea383a5 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -457,6 +457,15 @@ def convert_ctrl_dep_to_data_dep( last_visited_non_gpu = current_node last_visited_any = current_node + if json_node.sync_dep: + for sync_dep in json_node.sync_dep: + if sync_dep not in current_node.data_deps: + current_node.data_deps.append(sync_dep) + logging.info( + f"Node ID {current_node.id} now has an synchonization dependency on Node ID {sync_dep}" + ) + + # Add children to the stack children_chakra_ids = [child.id for child in json_node.children] for child_chakra_id in sorted(children_chakra_ids, reverse=True): child_chakra_node = protobuf_node_map.get(child_chakra_id) diff --git a/src/converter/pytorch_node.py b/src/converter/pytorch_node.py index 50feb4a..86b59ac 100644 --- a/src/converter/pytorch_node.py +++ b/src/converter/pytorch_node.py @@ -110,6 +110,7 @@ def _parse_data_1_0_3_chakra_0_0_4(self, node_data: Dict[str, Any]) -> None: self.exclusive_dur = node_data.get("exclusive_dur", 0) self.ts = node_data.get("ts") self.inter_thread_dep = node_data.get("inter_thread_dep") + self.sync_dep = node_data.get("sync_dep") self.cat = node_data.get("cat") self.stream = node_data.get("stream", 0) # In Colletive comms nodes, pg_name is in node_data if exists.