From c3ec4eed07db1f01d46ca02379ad40e462d03c05 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 11 Jul 2024 08:15:48 -0400 Subject: [PATCH] Add simulate option to PyTorchConverter --- src/converter/converter.py | 3 ++- src/converter/pytorch_converter.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/converter/converter.py b/src/converter/converter.py index 606fb94c..9aecf507 100644 --- a/src/converter/converter.py +++ b/src/converter/converter.py @@ -36,6 +36,7 @@ def main() -> None: parser.add_argument( "--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes" ) + parser.add_argument("--simulate", action="store_true", help="Run simulate_execution if set") parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename") args = parser.parse_args() @@ -47,7 +48,7 @@ def main() -> None: converter = TextConverter(args.input_filename, args.output_filename, args.num_npus, args.num_passes) converter.convert() elif args.input_type == "PyTorch": - converter = PyTorchConverter(args.input_filename, args.output_filename) + converter = PyTorchConverter(args.input_filename, args.output_filename, simulate=args.simulate) converter.convert() else: supported_types = ["Text", "PyTorch"] diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index 0e49b2dd..c0e7995c 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -33,16 +33,18 @@ class PyTorchConverter: output_filename (str): Output file name for the converted Chakra trace. """ - def __init__(self, input_filename: str, output_filename: str) -> None: + def __init__(self, input_filename: str, output_filename: str, simulate: bool = False) -> None: """ Initialize the PyTorch to Chakra converter. It sets up necessary attributes and prepares the environment. Args: input_filename (str): Name of the input file containing PyTorch execution trace. output_filename (str): Name of the output file for the converted Chakra trace. + simulate (bool): Whether to run simulate_execution after conversion. """ self.input_filename = input_filename self.output_filename = output_filename + self.simulate = simulate def convert(self) -> None: """Convert PyTorch execution traces into the Chakra format.""" @@ -74,7 +76,8 @@ def convert(self) -> None: chakra_nodes, ) self.close_chakra_execution_trace(chakra_et) - self.simulate_execution(chakra_nodes, pytorch_nodes, parent_to_children_map) + if self.simulate: + self.simulate_execution(chakra_nodes, pytorch_nodes, parent_to_children_map) def load_pytorch_execution_traces(self) -> Dict: """