Merge pull request #38 from mlcommons/ruff
Migrate to ruff
srinivas212 authored Apr 11, 2024
2 parents 07c94f0 + 826fc74 commit dfbfea9
Showing 14 changed files with 260 additions and 428 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/python_lint.yml
@@ -16,12 +16,11 @@ jobs:

- name: Install dependencies
run: |
pip install flake8
pip install pyre-check
pip install -r requirements-dev.txt
pip install .
- name: Run Flake8
run: flake8 .
- name: Run ruff
run: ruff format --check --diff .

- name: Run Pyre Check
run: pyre check
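
For reference, the same formatting gate can be reproduced locally before pushing. Below is a minimal sketch, not part of this commit, assuming ruff is installed from requirements-dev.txt exactly as the workflow does; the helper name is illustrative.

# check_format.py -- hypothetical local helper mirroring the CI step above.
# Assumes the "ruff" executable is on PATH (e.g. installed via requirements-dev.txt).
import subprocess
import sys


def check_formatting(path: str = ".") -> int:
    """Run the same command as the python_lint.yml workflow and return its exit code."""
    result = subprocess.run(["ruff", "format", "--check", "--diff", path])
    return result.returncode


if __name__ == "__main__":
    sys.exit(check_formatting())
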
63 changes: 18 additions & 45 deletions et_converter/et_converter.py
@@ -9,10 +9,9 @@
from .text2chakra_converter import Text2ChakraConverter
from .pytorch2chakra_converter import PyTorch2ChakraConverter


def get_logger(log_filename: str) -> logging.Logger:
formatter = logging.Formatter(
"%(levelname)s [%(asctime)s] %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p")
formatter = logging.Formatter("%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")

file_handler = FileHandler(log_filename, mode="w")
file_handler.setLevel(logging.DEBUG)
@@ -29,44 +28,23 @@ def get_logger(log_filename: str) -> logging.Logger:

return logger


def main() -> None:
parser = argparse.ArgumentParser(
description="Execution Trace Converter")
parser.add_argument(
"--input_type",
type=str,
default=None,
required=True,
help="Input execution trace type")
parser = argparse.ArgumentParser(description="Execution Trace Converter")
parser.add_argument("--input_type", type=str, default=None, required=True, help="Input execution trace type")
parser.add_argument(
"--input_filename",
type=str,
default=None,
required=True,
help="Input execution trace filename")
"--input_filename", type=str, default=None, required=True, help="Input execution trace filename"
)
parser.add_argument(
"--output_filename",
type=str,
default=None,
required=True,
help="Output Chakra execution trace filename")
"--output_filename", type=str, default=None, required=True, help="Output Chakra execution trace filename"
)
parser.add_argument(
"--num_npus",
type=int,
default=None,
required="Text" in sys.argv,
help="Number of NPUs in a system")
"--num_npus", type=int, default=None, required="Text" in sys.argv, help="Number of NPUs in a system"
)
parser.add_argument(
"--num_passes",
type=int,
default=None,
required="Text" in sys.argv,
help="Number of training passes")
parser.add_argument(
"--log_filename",
type=str,
default="debug.log",
help="Log filename")
"--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes"
)
parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename")
args = parser.parse_args()

logger = get_logger(args.log_filename)
@@ -75,17 +53,11 @@ def main() -> None:
try:
if args.input_type == "Text":
converter = Text2ChakraConverter(
args.input_filename,
args.output_filename,
args.num_npus,
args.num_passes,
logger)
args.input_filename, args.output_filename, args.num_npus, args.num_passes, logger
)
converter.convert()
elif args.input_type == "PyTorch":
converter = PyTorch2ChakraConverter(
args.input_filename,
args.output_filename,
logger)
converter = PyTorch2ChakraConverter(args.input_filename, args.output_filename, logger)
converter.convert()
else:
logger.error(f"{args.input_type} unsupported")
@@ -95,5 +67,6 @@ def main() -> None:
logger.debug(traceback.format_exc())
sys.exit(1)


if __name__ == "__main__":
main()
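
The reformatting above does not change the converter's command-line flags or the converter classes' signatures. As a rough, hypothetical illustration (not part of this commit), the Text path can also be driven programmatically, assuming the package is importable as et_converter; the file names and counts below are made up.

# Hypothetical programmatic driver mirroring main() in et_converter.py.
import logging

from et_converter.text2chakra_converter import Text2ChakraConverter

logger = logging.getLogger("et_converter_example")
logger.addHandler(logging.NullHandler())  # keep the example quiet; main() uses a file handler instead

converter = Text2ChakraConverter(
    "workload.txt",  # --input_filename: input execution trace (Text type)
    "workload.et",   # --output_filename: output Chakra execution trace
    64,              # --num_npus: number of NPUs in the system
    1,               # --num_passes: number of training passes
    logger,
)
converter.convert()
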
117 changes: 53 additions & 64 deletions et_converter/pytorch2chakra_converter.py
@@ -108,12 +108,7 @@ class PyTorch2ChakraConverter:
dependencies.
"""

def __init__(
self,
input_filename: str,
output_filename: str,
logger: logging.Logger
) -> None:
def __init__(self, input_filename: str, output_filename: str, logger: logging.Logger) -> None:
"""
Initializes the PyTorch to Chakra converter. It sets up necessary
attributes and prepares the environment for the conversion process.
@@ -157,8 +152,9 @@ def convert(self) -> None:
self.open_chakra_execution_trace()

for pytorch_nid, pytorch_node in self.pytorch_nodes.items():
if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\
or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL):
if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP) or (
pytorch_node.get_op_type() == PyTorchNodeType.LABEL
):
chakra_node = self.convert_to_chakra_node(pytorch_node)
self.chakra_nodes[chakra_node.id] = chakra_node

@@ -167,11 +163,12 @@

if chakra_node.type == COMM_COLL_NODE:
collective_comm_type = self.get_collective_comm_type(pytorch_node.name)
chakra_gpu_node.attr.extend([
ChakraAttr(name="comm_type",
int64_val=collective_comm_type),
ChakraAttr(name="comm_size",
int64_val=pytorch_gpu_node.comm_size)])
chakra_gpu_node.attr.extend(
[
ChakraAttr(name="comm_type", int64_val=collective_comm_type),
ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size),
]
)

self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node

@@ -229,14 +226,10 @@ def _parse_and_instantiate_nodes(self, pytorch_et_data: Dict) -> None:
self.pytorch_finish_ts = pytorch_et_data["finish_ts"]

pytorch_nodes = pytorch_et_data["nodes"]
pytorch_node_objects = {
node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes
}
pytorch_node_objects = {node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes}
self._establish_parent_child_relationships(pytorch_node_objects)

def _establish_parent_child_relationships(
self, pytorch_node_objects: Dict[int, PyTorchNode]
) -> None:
def _establish_parent_child_relationships(self, pytorch_node_objects: Dict[int, PyTorchNode]) -> None:
"""
Establishes parent-child relationships among PyTorch nodes and counts
the node types.
@@ -252,7 +245,7 @@ def _establish_parent_child_relationships(
"gpu_op": 0,
"record_param_comms_op": 0,
"nccl_op": 0,
"root_op": 0
"root_op": 0,
}

# Establish parent-child relationships
@@ -271,8 +264,10 @@
if pytorch_node.is_nccl_op():
parent_node.nccl_node = pytorch_node

if pytorch_node.name in ["[pytorch|profiler|execution_graph|thread]",
"[pytorch|profiler|execution_trace|thread]"]:
if pytorch_node.name in [
"[pytorch|profiler|execution_graph|thread]",
"[pytorch|profiler|execution_trace|thread]",
]:
self.pytorch_root_nids.append(pytorch_node.id)
node_type_counts["root_op"] += 1

@@ -333,17 +328,19 @@ def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode:
chakra_node.outputs.values = str(pytorch_node.outputs)
chakra_node.outputs.shapes = str(pytorch_node.output_shapes)
chakra_node.outputs.types = str(pytorch_node.output_types)
chakra_node.attr.extend([
ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id),
ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent),
ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id),
ChakraAttr(name="scope", int64_val=pytorch_node.scope),
ChakraAttr(name="tid", int64_val=pytorch_node.tid),
ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid),
ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema),
ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()),
ChakraAttr(name="ts", int64_val=pytorch_node.ts)
])
chakra_node.attr.extend(
[
ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id),
ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent),
ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id),
ChakraAttr(name="scope", int64_val=pytorch_node.scope),
ChakraAttr(name="tid", int64_val=pytorch_node.tid),
ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid),
ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema),
ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()),
ChakraAttr(name="ts", int64_val=pytorch_node.ts),
]
)
return chakra_node

def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> ChakraNodeType:
@@ -356,9 +353,7 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> ChakraNodeType:
Returns:
int: The corresponding Chakra node type.
"""
if pytorch_node.is_gpu_op() and (
"ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name
):
if pytorch_node.is_gpu_op() and ("ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name):
return COMM_COLL_NODE
elif ("c10d::" in pytorch_node.name) or ("nccl:" in pytorch_node.name):
return COMM_COLL_NODE
@@ -392,8 +387,10 @@ def get_collective_comm_type(self, name: str) -> int:
if key.lower() in name.lower():
return value

raise ValueError(f"'{name}' not found in collective communication mapping. "
"Please add this collective communication name to the mapping.")
raise ValueError(
f"'{name}' not found in collective communication mapping. "
"Please add this collective communication name to the mapping."
)

def is_root_node(self, node):
"""
@@ -412,8 +409,7 @@ def is_root_node(self, node):
Returns:
bool: True if the node is a root node, False otherwise.
"""
if node.name in ["[pytorch|profiler|execution_graph|thread]",
"[pytorch|profiler|execution_trace|thread]"]:
if node.name in ["[pytorch|profiler|execution_graph|thread]", "[pytorch|profiler|execution_trace|thread]"]:
return True

def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None:
@@ -591,9 +587,7 @@ def dfs(node_id: int, path: List[int]) -> bool:
bool: True if a cycle is detected, False otherwise.
"""
if node_id in stack:
cycle_nodes = " -> ".join(
[self.chakra_nodes[n].name for n in path + [node_id]]
)
cycle_nodes = " -> ".join([self.chakra_nodes[n].name for n in path + [node_id]])
self.logger.error(f"Cyclic dependency detected: {cycle_nodes}")
return True
if node_id in visited:
@@ -611,10 +605,7 @@ def dfs(node_id: int, path: List[int]) -> bool:

for node_id in self.chakra_nodes:
if dfs(node_id, []):
raise Exception(
f"Cyclic dependency detected starting from node "
f"{self.chakra_nodes[node_id].name}"
)
raise Exception(f"Cyclic dependency detected starting from node " f"{self.chakra_nodes[node_id].name}")

def write_chakra_et(self) -> None:
"""
@@ -642,7 +633,7 @@ def _write_global_metadata(self) -> None:
ChakraAttr(name="pid", uint64_val=self.pytorch_pid),
ChakraAttr(name="time", string_val=self.pytorch_time),
ChakraAttr(name="start_ts", uint64_val=self.pytorch_start_ts),
ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts)
ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts),
]
)
encode_message(self.chakra_et, global_metadata)
@@ -684,21 +675,18 @@ def simulate_execution(self) -> None:
execution based on the readiness determined by dependency resolution.
A simplistic global clock is used to model the execution time.
"""
self.logger.info("Simulating execution of Chakra nodes based on data "
"dependencies.")
self.logger.info("Simulating execution of Chakra nodes based on data " "dependencies.")

# Initialize queues for ready CPU and GPU nodes
ready_cpu_nodes = [
(node_id, self.chakra_nodes[node_id])
for node_id in self.chakra_nodes
if not self.chakra_nodes[node_id].data_deps and
not self.pytorch_nodes[node_id].is_gpu_op()
if not self.chakra_nodes[node_id].data_deps and not self.pytorch_nodes[node_id].is_gpu_op()
]
ready_gpu_nodes = [
(node_id, self.chakra_nodes[node_id])
for node_id in self.chakra_nodes
if not self.chakra_nodes[node_id].data_deps and
self.pytorch_nodes[node_id].is_gpu_op()
if not self.chakra_nodes[node_id].data_deps and self.pytorch_nodes[node_id].is_gpu_op()
]
ready_cpu_nodes.sort(key=lambda x: x[1].id)
ready_gpu_nodes.sort(key=lambda x: x[1].id)
@@ -709,8 +697,7 @@ def simulate_execution(self) -> None:

current_time: int = 0 # Simulated global clock in microseconds

while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node,
current_gpu_node]):
while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, current_gpu_node]):
if ready_cpu_nodes and not current_cpu_node:
cpu_node_id, cpu_node = ready_cpu_nodes.pop(0)
current_cpu_node = (cpu_node_id, current_time)
@@ -731,16 +718,18 @@ def simulate_execution(self) -> None:

current_time += 1

if current_cpu_node and current_time - current_cpu_node[1] >= \
self.chakra_nodes[current_cpu_node[0]].duration_micros:
self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed "
f"at {current_time}us")
if (
current_cpu_node
and current_time - current_cpu_node[1] >= self.chakra_nodes[current_cpu_node[0]].duration_micros
):
self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed " f"at {current_time}us")
current_cpu_node = None

if current_gpu_node and current_time - current_gpu_node[1] >= \
self.chakra_nodes[current_gpu_node[0]].duration_micros:
self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed "
f"at {current_time}us")
if (
current_gpu_node
and current_time - current_gpu_node[1] >= self.chakra_nodes[current_gpu_node[0]].duration_micros
):
self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed " f"at {current_time}us")
current_gpu_node = None

for node_id in list(issued_nodes):
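
The simulate_execution docstring above describes a simplistic global clock; the following toy, self-contained sketch of that idea is not taken from this repository, and its nodes, durations, and dependencies are made up for illustration.

# Toy model of the global-clock simulation described in simulate_execution():
# one CPU stream and one GPU stream, each running at most one node at a time,
# with a shared clock advancing one microsecond per iteration.
durations = {"cpu_a": 3, "gpu_a": 5, "cpu_b": 2}                   # node -> duration in us
data_deps = {"cpu_a": [], "gpu_a": ["cpu_a"], "cpu_b": ["gpu_a"]}  # node -> data dependencies
stream_of = {"cpu_a": "cpu", "gpu_a": "gpu", "cpu_b": "cpu"}

completed = set()
running = {}  # stream -> (node, start_time)
clock = 0

while len(completed) < len(durations):
    # Issue any ready node whose stream is idle.
    for node, deps in data_deps.items():
        stream = stream_of[node]
        is_ready = node not in completed and all(d in completed for d in deps)
        if is_ready and stream not in running:
            running[stream] = (node, clock)

    clock += 1  # advance the simulated global clock

    # Retire nodes whose duration has elapsed.
    for stream, (node, start) in list(running.items()):
        if clock - start >= durations[node]:
            print(f"{node} completed at {clock}us")
            completed.add(node)
            del running[stream]
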
16 changes: 8 additions & 8 deletions et_converter/pytorch_node.py
@@ -29,11 +29,11 @@ def __init__(self, node_data: Dict[str, Any]) -> None:
PyTorch node.
"""
self.node_data = node_data
self.data_deps: List['PyTorchNode'] = []
self.children: List['PyTorchNode'] = []
self.gpu_children: List['PyTorchNode'] = []
self.record_param_comms_node: Optional['PyTorchNode'] = None
self.nccl_node: Optional['PyTorchNode'] = None
self.data_deps: List["PyTorchNode"] = []
self.children: List["PyTorchNode"] = []
self.gpu_children: List["PyTorchNode"] = []
self.record_param_comms_node: Optional["PyTorchNode"] = None
self.nccl_node: Optional["PyTorchNode"] = None

def __repr__(self) -> str:
"""
@@ -527,7 +527,7 @@ def is_gpu_op(self) -> bool:
"""
return self.has_cat()

def add_data_dep(self, parent_node: 'PyTorchNode') -> None:
def add_data_dep(self, parent_node: "PyTorchNode") -> None:
"""
Adds a data-dependent parent node to this node.
@@ -536,7 +536,7 @@ def add_data_dep(self, parent_node: 'PyTorchNode') -> None:
"""
self.data_deps.append(parent_node)

def add_child(self, child_node: 'PyTorchNode') -> None:
def add_child(self, child_node: "PyTorchNode") -> None:
"""
Adds a child node to this node.
@@ -545,7 +545,7 @@ def add_child(self, child_node: 'PyTorchNode') -> None:
"""
self.children.append(child_node)

def add_gpu_child(self, gpu_child_node: 'PyTorchNode') -> None:
def add_gpu_child(self, gpu_child_node: "PyTorchNode") -> None:
"""
Adds a child GPU node for this node.
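
The pytorch_node.py changes above are purely quote-style; the node-wiring methods keep their documented behavior. A hypothetical illustration of that API follows (the node_data dictionaries are made up, and it is an assumption that such minimal dictionaries are sufficient; only the fields the sketch touches are assumed).

# Hypothetical sketch using the wiring methods shown above.
from et_converter.pytorch_node import PyTorchNode

parent = PyTorchNode({"id": 1, "name": "module::forward"})  # made-up node_data
child = PyTorchNode({"id": 2, "name": "aten::mm"})          # made-up node_data

parent.add_child(child)     # child appears in parent.children
child.add_data_dep(parent)  # parent recorded as a data-dependent parent of child
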
(Diffs for the remaining changed files are not included here.)