Skip to content

Commit

Permalink
Merge pull request #126 from mlcommons/stream-id-encode
Browse files: browse the repository at this point in the history
Add stream ID to GPU operators and support multiple streams in simulate_execution
  • Loading branch information
srinivas212 authored Jul 13, 2024
2 parents (7f8b892 + c3ec4ee), commit 2d16e16
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
3 changes: 2 additions & 1 deletion src/converter/converter.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,6 +36,7 @@ def main() -> None:
parser.add_argument(
"--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes"
)
parser.add_argument("--simulate", action="store_true", help="Run simulate_execution if set")
parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename")
args = parser.parse_args()

Expand All @@ -47,7 +48,7 @@ def main() -> None:
converter = TextConverter(args.input_filename, args.output_filename, args.num_npus, args.num_passes)
converter.convert()
elif args.input_type == "PyTorch":
converter = PyTorchConverter(args.input_filename, args.output_filename)
converter = PyTorchConverter(args.input_filename, args.output_filename, simulate=args.simulate)
converter.convert()
else:
supported_types = ["Text", "PyTorch"]
Expand Down
59 changes: 39 additions & 20 deletions src/converter/pytorch_converter.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,16 +33,18 @@ class PyTorchConverter:
output_filename (str): Output file name for the converted Chakra trace.
"""

def __init__(self, input_filename: str, output_filename: str) -> None:
def __init__(self, input_filename: str, output_filename: str, simulate: bool = False) -> None:
"""
Initialize the PyTorch to Chakra converter. It sets up necessary attributes and prepares the environment.
Args:
input_filename (str): Name of the input file containing PyTorch execution trace.
output_filename (str): Name of the output file for the converted Chakra trace.
simulate (bool): Whether to run simulate_execution after conversion.
"""
self.input_filename = input_filename
self.output_filename = output_filename
self.simulate = simulate

def convert(self) -> None:
"""Convert PyTorch execution traces into the Chakra format."""
Expand Down Expand Up @@ -74,7 +76,8 @@ def convert(self) -> None:
chakra_nodes,
)
self.close_chakra_execution_trace(chakra_et)
self.simulate_execution(chakra_nodes, pytorch_nodes, parent_to_children_map)
if self.simulate:
self.simulate_execution(chakra_nodes, pytorch_nodes, parent_to_children_map)

def load_pytorch_execution_traces(self) -> Dict:
"""
Expand Down Expand Up @@ -337,6 +340,8 @@ def convert_to_chakra_node(
ChakraAttr(name="is_cpu_op", bool_val=not pytorch_node.is_gpu_op()),
]
)
if pytorch_node.stream is not None:
chakra_node.attr.append(ChakraAttr(name="stream", int64_val=pytorch_node.stream))
return chakra_node

def get_chakra_node_type_from_pytorch_node(
Expand Down Expand Up @@ -677,6 +682,7 @@ def close_chakra_execution_trace(self, chakra_et: IO[bytes]) -> None:
if chakra_et and not chakra_et.closed:
chakra_et.close()

# ruff: noqa: C901
def simulate_execution(
self,
chakra_nodes: Dict[int, ChakraNode],
Expand Down Expand Up @@ -711,44 +717,57 @@ def simulate_execution(

issued_nodes: Set[int] = set()
current_cpu_node: Optional[Tuple[int, int]] = None
current_gpu_node: Optional[Tuple[int, int]] = None
current_gpu_nodes: Dict[int, Tuple[int, int]] = {}

current_time: int = 0 # Simulated global clock in microseconds

while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, current_gpu_node]):
while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, current_gpu_nodes]):
if ready_cpu_nodes and not current_cpu_node:
cpu_node_id, cpu_node = ready_cpu_nodes.pop(0)
current_cpu_node = (cpu_node_id, current_time)
issued_nodes.add(cpu_node_id)
tid = pytorch_nodes[cpu_node_id].tid
logging.info(
f"Issuing CPU Node ID {cpu_node_id} ({cpu_node.name}) at {current_time}us with duration "
f"{cpu_node.duration_micros}us"
f"{cpu_node.duration_micros}us, tid: {tid}"
)

if ready_gpu_nodes and not current_gpu_node:
gpu_node_id, gpu_node = ready_gpu_nodes.pop(0)
current_gpu_node = (gpu_node_id, current_time)
issued_nodes.add(gpu_node_id)
logging.info(
f"Issuing GPU Node ID {gpu_node_id} ({gpu_node.name}) at {current_time}us with duration "
f"{gpu_node.duration_micros}us"
)
if ready_gpu_nodes:
for gpu_node_id, gpu_node in ready_gpu_nodes[:]:
pytorch_node = pytorch_nodes[gpu_node_id]
stream_id = pytorch_node.stream
if stream_id not in current_gpu_nodes:
ready_gpu_nodes.remove((gpu_node_id, gpu_node))
current_gpu_nodes[stream_id] = (gpu_node_id, current_time)
issued_nodes.add(gpu_node_id)
tid = f"stream {stream_id}"
logging.info(
f"Issuing GPU Node ID {gpu_node_id} ({gpu_node.name}) at {current_time}us on stream "
f"{stream_id} with duration {gpu_node.duration_micros}us, tid: {tid}"
)

current_time += 1

if (
current_cpu_node
and current_time - current_cpu_node[1] >= chakra_nodes[current_cpu_node[0]].duration_micros
):
logging.info(f"CPU Node ID {current_cpu_node[0]} completed at {current_time}us")
cpu_node_id, _ = current_cpu_node
tid = pytorch_nodes[cpu_node_id].tid
logging.info(f"CPU Node ID {cpu_node_id} completed at {current_time}us, tid: {tid}")
current_cpu_node = None

if (
current_gpu_node
and current_time - current_gpu_node[1] >= chakra_nodes[current_gpu_node[0]].duration_micros
):
logging.info(f"GPU Node ID {current_gpu_node[0]} completed at {current_time}us")
current_gpu_node = None
completed_streams = []
for stream_id, (gpu_node_id, start_time) in current_gpu_nodes.items():
if current_time - start_time >= chakra_nodes[gpu_node_id].duration_micros:
logging.info(
f"GPU Node ID {gpu_node_id} on stream {stream_id} completed at {current_time}us, "
f"tid: stream {stream_id}"
)
completed_streams.append(stream_id)

for stream_id in completed_streams:
del current_gpu_nodes[stream_id]

for node_id in list(issued_nodes):
children_ids = parent_to_children_map.get(node_id, [])
Expand Down
4 changes: 2 additions & 2 deletions src/converter/pytorch_node.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,7 +41,7 @@ class PyTorchNode:
ts (Optional[float]): Timestamp of the node.
inter_thread_dep (Any): Inter-thread dependency of the node.
cat (Any): Category of the node.
stream (Any): Stream associated with the node.
stream (int): Stream associated with the node.
"""

SUPPORTED_VERSIONS = ["1.0.2-chakra.0.0.4", "1.0.3-chakra.0.0.4", "1.1.0-chakra.0.0.4"]
Expand Down Expand Up @@ -102,7 +102,7 @@ def _parse_data_1_0_3_chakra_0_0_4(self, node_data: Dict[str, Any]) -> None:
self.ts = node_data.get("ts")
self.inter_thread_dep = node_data.get("inter_thread_dep")
self.cat = node_data.get("cat")
self.stream = node_data.get("stream")
self.stream = node_data.get("stream", 0)

for attr in node_data.get("attrs", []):
setattr(self, attr["name"], attr["value"])
Expand Down

0 comments on commit 2d16e16

Please sign in to comment.