Skip to content

Commit

Permalink
Add simulate option to PyTorchConverter
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jul 11, 2024
1 parent cce606b commit c3ec4ee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/converter/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def main() -> None:
parser.add_argument(
"--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes"
)
parser.add_argument("--simulate", action="store_true", help="Run simulate_execution if set")
parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename")
args = parser.parse_args()

Expand All @@ -47,7 +48,7 @@ def main() -> None:
converter = TextConverter(args.input_filename, args.output_filename, args.num_npus, args.num_passes)
converter.convert()
elif args.input_type == "PyTorch":
converter = PyTorchConverter(args.input_filename, args.output_filename)
converter = PyTorchConverter(args.input_filename, args.output_filename, simulate=args.simulate)
converter.convert()
else:
supported_types = ["Text", "PyTorch"]
Expand Down
7 changes: 5 additions & 2 deletions src/converter/pytorch_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@ class PyTorchConverter:
output_filename (str): Output file name for the converted Chakra trace.
"""

def __init__(self, input_filename: str, output_filename: str) -> None:
def __init__(self, input_filename: str, output_filename: str, simulate: bool = False) -> None:
"""
Initialize the PyTorch to Chakra converter. It sets up necessary attributes and prepares the environment.
Args:
input_filename (str): Name of the input file containing PyTorch execution trace.
output_filename (str): Name of the output file for the converted Chakra trace.
simulate (bool): Whether to run simulate_execution after conversion.
"""
self.input_filename = input_filename
self.output_filename = output_filename
self.simulate = simulate

def convert(self) -> None:
"""Convert PyTorch execution traces into the Chakra format."""
Expand Down Expand Up @@ -74,7 +76,8 @@ def convert(self) -> None:
chakra_nodes,
)
self.close_chakra_execution_trace(chakra_et)
self.simulate_execution(chakra_nodes, pytorch_nodes, parent_to_children_map)
if self.simulate:
self.simulate_execution(chakra_nodes, pytorch_nodes, parent_to_children_map)

def load_pytorch_execution_traces(self) -> Dict:
"""
Expand Down

0 comments on commit c3ec4ee

Please sign in to comment.