diff --git a/src/trace_link/kineto_operator.py b/src/trace_link/kineto_operator.py index 4cb588fc..70bad07d 100644 --- a/src/trace_link/kineto_operator.py +++ b/src/trace_link/kineto_operator.py @@ -5,31 +5,25 @@ class KinetoOperator: """ - Represents a single operator in a Kineto trace by default, with fields primarily sourced - from the Kineto traces. In addition to the default fields from Kineto traces, additional - fields have been introduced for postprocessing purposes. These additional fields facilitate - the correlation of PyTorch operators and the enforcement of dependencies among them, - enhancing trace analysis and utility. + Represents a single operator in a Kineto trace. Attributes: - op_dict (Dict[str, Any]): Dictionary containing the operator data. + id (Optional[int]): Identifier of the operator. category (str): Category of the operator. name (str): Name of the operator. - phase (Optional[str]): Phase of the operator. - inclusive_dur (int): Inclusive duration of the operator in microseconds. - exclusive_dur (int): Exclusive duration of the operator in microseconds. - timestamp (int): Timestamp of the operator in microseconds. - external_id (str): External ID associated with the operator. - ev_idx (str): Event index associated with the operator. - tid (int): Thread ID associated with the operator. - pytorch_op (Optional[PyTorchOperator]): Associated PyTorch operator. + phase (Optional[str]): Execution phase of the operator. + inclusive_dur (int): Total duration of the operator, including its children. + exclusive_dur (int): Duration of the operator execution alone. Corresponds to the self time field in chrome://tracing. + timestamp (int): Start time of the operator in microseconds. + external_id (int): An external identifier associated with the operator. + ev_idx (int): Event index of the operator. + tid (int): Thread identifier where the operator was executed. + pytorch_op (Optional[PyTorchOperator]): Corresponding PyTorch operator object. parent_pytorch_op_id (Optional[int]): ID of the parent PyTorch operator. - inter_thread_dep (Optional[int]): ID of the latest CPU node from other - threads before the gap. - stream (Optional[int]): Stream ID associated with the operator. - rf_id (Optional[int]): Record function ID. - correlation (int): Correlation ID used to link CUDA runtime operations - with their GPU counterparts. + inter_thread_dep (Optional[int]): Identifier for inter-thread dependencies. + stream (Optional[int]): CUDA stream identifier associated with the operator. + rf_id (Optional[int]): Record function identifier. + correlation (int): Identifier used to correlate CUDA runtime and GPU operations. """ def __init__(self, kineto_op: Dict[str, Any]) -> None: @@ -40,21 +34,21 @@ def __init__(self, kineto_op: Dict[str, Any]) -> None: kineto_op (Dict[str, Any]): The dictionary representing the operator data. 
""" - self.op_dict: Dict[str, Any] = kineto_op + self.id: Optional[int] = kineto_op.get("id") self.category: str = kineto_op.get("cat", "") self.name: str = kineto_op.get("name", "") self.phase: Optional[str] = kineto_op.get("ph") self.inclusive_dur: int = kineto_op.get("dur", 0) self.exclusive_dur: int = kineto_op.get("dur", 0) self.timestamp: int = kineto_op.get("ts", 0) - self.external_id: str = kineto_op.get("args", {}).get("External id", "") - self.ev_idx: str = kineto_op.get("args", {}).get("Ev Idx", "") + self.external_id: int = int(kineto_op.get("args", {}).get("External id", -1)) + self.ev_idx: int = int(kineto_op.get("args", {}).get("Ev Idx", -1)) self.tid: int = kineto_op.get("tid", 0) self.pytorch_op: Optional[PyTorchOperator] = None self.parent_pytorch_op_id: Optional[int] = None self.inter_thread_dep: Optional[int] = None - self.stream: Optional[int] = kineto_op.get("args", {}).get("stream") - self.rf_id: Optional[int] = kineto_op.get("args", {}).get("Record function id") + self.stream: Optional[int] = kineto_op.get("args", {}).get("stream", None) + self.rf_id: Optional[int] = kineto_op.get("args", {}).get("Record function id", None) self.correlation: int = kineto_op.get("args", {}).get("correlation", -1) def __repr__(self) -> str: @@ -65,40 +59,65 @@ def __repr__(self) -> str: str: A string representation of the KinetoOperator. """ return ( - f"KinetoOperator(category={self.category}, name={self.name}, phase={self.phase}, " - f"inclusive_dur={self.inclusive_dur}, exclusive_dur={self.exclusive_dur}, " - f"timestamp={self.timestamp}, external_id={self.external_id}, ev_idx={self.ev_idx}, " - f"tid={self.tid}, parent_pytorch_op_id={self.parent_pytorch_op_id}, " - f"inter_thread_dep={self.inter_thread_dep}, stream={self.stream}, rf_id={self.rf_id}, " - f"correlation={self.correlation})" + f"KinetoOperator(id={self.id}, category={self.category}, name={self.name}, " + f"phase={self.phase}, inclusive_dur={self.inclusive_dur}, " + f"exclusive_dur={self.exclusive_dur}, timestamp={self.timestamp}, " + f"external_id={self.external_id}, ev_idx={self.ev_idx}, tid={self.tid}, " + f"parent_pytorch_op_id={self.parent_pytorch_op_id}, inter_thread_dep={self.inter_thread_dep}, " + f"stream={self.stream}, rf_id={self.rf_id}, correlation={self.correlation})" ) - def is_valid( - self, - category: str, - name_exception: str = "ProfilerStep", - phase: Optional[str] = None, - ) -> bool: + def is_cpu_op(self) -> bool: """ - Checks if the operator matches specified filtering criteria. + Determines if the operator is simulatable based on its category and name. + The categories 'cpu_op' and 'user_annotation' are considered CPU operators. + Notably, 'user_annotation' operators often include the duration of CPU operator launch times. + Ignoring the duration measured in 'user_annotation' can lead to inaccuracies in simulation. + An exception to this is 'ProfilerStep', which should be completely ignored. + Ideally, a more general rule should be developed to identify such exception nodes. - Comment (TODO): - This is legacy code from a previous implementation. Ideally, we should merge this logic - into trace_linker.py. The purpose of is_valid is ambiguous, and it is unclear whether - the function is essential. However, we keep it as it is to avoid breaking downstream - tools. After properly setting up CI/CD pipelines and testing, we can consider removing it. + Returns: + bool: True if the operator is simulatable, False otherwise. 
+ """ + simulatable_categories = {"cpu_op", "user_annotation"} + name_exceptions = {"ProfilerStep"} + if self.category in simulatable_categories and all(exc not in self.name for exc in name_exceptions): + return True + return False - Args: - category (str): The category to check against. - name_exception (str): A name to exclude in the check. - phase (Optional[str]): The phase to check against, if any. + def is_cuda_launch_op(self) -> bool: + """ + Determines whether the operator is a kernel-launching CUDA runtime operator. Returns: - bool: True if the operator matches the criteria, False otherwise. + bool: True if it's a launch operation, otherwise False. """ - return ( - self.category is not None - and name_exception not in self.name - and self.category == category - and (phase is None or self.phase == phase) - ) + cuda_launch_categories = {"cuda_runtime", "cuda_driver"} + cuda_launch_operations = { + "cudaLaunchKernel", + "cudaLaunchKernelExC", + "cudaMemcpy", + "cudaMemcpyAsync", + "cudaMemcpyToSymbol", + "cudaMemcpyFromSymbol", + } + return self.category in cuda_launch_categories and self.name in cuda_launch_operations + + def is_gpu_op(self) -> bool: + """ + Checks if the operator is a GPU-side operator based on its category. + + Returns: + bool: True if it's a GPU-side operation, otherwise False. + """ + gpu_categories = {"kernel", "gpu_memcpy"} + return self.category in gpu_categories + + def is_arrow_op(self) -> bool: + """ + Checks if the operator is categorized as 'ac2g', which stands for arrows from CPU to GPU. + + Returns: + bool: True if the operator is an 'ac2g' type, otherwise False. + """ + return self.category == "ac2g" diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py index 292f4c70..ded5506e 100644 --- a/src/trace_link/trace_linker.py +++ b/src/trace_link/trace_linker.py @@ -29,41 +29,32 @@ class TraceLinker: """ Links PyTorch Execution Traces (ET) and Kineto Traces to generate PyTorch ET plus. - This class handles the process of loading, processing, and linking - PyTorch Execution Traces with Kineto Traces, enriching the PyTorch - Execution Trace with detailed performance data. - Attributes: pytorch_et_file (str): Path to the PyTorch execution trace file. kineto_file (str): Path to the Kineto trace file. - pytorch_ops (List[PyTorchOperator]): PyTorch operators from ET trace. - kineto_ops (List[KinetoOperator]): Kineto operators from the trace. - sorted_kineto_ops (List[KinetoOperator]): Sorted list of Kineto operators based on timestamps. - sorted_ts (List[int]): Sorted list of timestamps extracted from Kineto operators for efficient temporal queries. - kineto_ops_by_tid (Dict[int, List[KinetoOperator]]): Operators grouped by thread ID. - kineto_cuda_runtime (Dict[int, KinetoOperator]): Mapping of CUDA runtime - API calls to Kineto operators, indexed by their correlation ID. This - includes operations like `cudaLaunchKernel` and `cudaMemcpyAsync`, - crucial for mapping GPU activities back to their initiating CPU calls. - kineto_ac2g_s_ops (Dict[str, KinetoOperator]): Start ops for CPU to GPU. - kineto_ac2g_f_ops (Dict[str, KinetoOperator]): Final ops for CPU to GPU. - kineto_cpu_launcher_ops (Dict[str, KinetoOperator]): CPU launcher ops. - kineto_gpu_ops (List[KinetoOperator]): GPU operators. - kineto_process_start_time (int): Start time of the process, based on the - earliest operator timestamp. - kineto_process_end_time (int): End time of the process, based on the - latest operator timestamp. 
- kineto_thread_info (Dict[int, Tuple[int, int]]): Information about threads, - mapping thread IDs to a tuple of start and end times. - kineto_rf_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping from - rf_id to KinetoOperator instances. - pytorch_op_id_to_kineto_ops_map (Dict[int, List[KinetoOperator]]): - Map from PyTorch op IDs to Kineto GPU ops. + pytorch_ops (List[PyTorchOperator]): PyTorch operators. + kineto_cpu_ops (List[KinetoOperator]): Kineto CPU operators. + sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators based on timestamps. + sorted_kineto_cpu_op_ts (List[int]): Sorted list of timestamps extracted from Kineto operators for efficient + temporal queries. + kineto_tid_cpu_ops_map (Dict[int, List[KinetoOperator]]): Kineto CPU operators grouped by thread ID. + kineto_correlation_cuda_runtime_map (Dict[int, KinetoOperator]): Mapping between correlation IDs and + kernel-launching CUDA runtime operators. + kineto_gpu_ops (List[KinetoOperator]): Kineto GPU operators. + kineto_id_arrow_op_map (Dict[int, KinetoOperator]): Arrows from a CPU op to a GPU op. + kineto_id_cuda_launch_op_map (Dict[int, KinetoOperator]): Mapping between external ID and kernel-launching CUDA + runtime operators. + kineto_process_start_time (int): Start time of the process, based on the earliest operator timestamp. + kineto_process_end_time (int): End time of the process, based on the latest operator timestamp. + kineto_thread_info (Dict[int, Tuple[int, int]]): Information about threads, mapping thread IDs to a tuple of + start and end times. + kineto_rf_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping from rf_id to KinetoOperator instances. + pytorch_op_id_to_kineto_ops_map (Dict[int, List[KinetoOperator]]): Map from PyTorch op IDs to Kineto GPU ops. pytorch_op_id_to_inclusive_dur_map (Dict[int, int]): Inclusive duration map for PyTorch ops. - pytorch_op_id_to_inclusive_dur_map (Dict[int, int]): Exclusive duration map for PyTorch ops. + pytorch_op_id_to_exclusive_dur_map (Dict[int, int]): Exclusive duration map for PyTorch ops. pytorch_op_id_to_timestamp_map (Dict[int, int]): Timestamp map for PyTorch ops. - pytorch_op_id_to_inter_thread_dep_map (Dict[int, int]): Mapping of PyTorch - operator IDs to IDs of latest CPU node from other threads before the gap. + pytorch_op_id_to_inter_thread_dep_map (Dict[int, int]): Mapping of PyTorch operator IDs to IDs of latest CPU + node from other threads before the gap. id_assigner (UniqueIdAssigner): Assigns unique IDs to operators. pytorch_et_plus_data (Optional[Dict]): PyTorch ET plus data. logger (logging.Logger): Logger for the class. @@ -79,18 +70,17 @@ def __init__(self, pytorch_et_file: str, kineto_file: str, log_level: str = "INF kineto_file (str): Path to the Kineto trace file. log_level (str): Logging level for the class. 
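+
+        Example (illustrative sketch; the file paths are placeholders):
+            linker = TraceLinker("pytorch_et.json", "kineto_trace.json", log_level="DEBUG")
+            linker.load_traces()
+            linker.link_traces()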
""" - self.pytorch_et_file = pytorch_et_file - self.kineto_file = kineto_file + self.pytorch_et_file: str = pytorch_et_file + self.kineto_file: str = kineto_file self.pytorch_ops: List[PyTorchOperator] = [] - self.kineto_ops: List[KinetoOperator] = [] - self.sorted_kineto_ops: List[KinetoOperator] = [] - self.sorted_ts: List[int] = [] - self.kineto_ops_by_tid: Dict[int, List[KinetoOperator]] = {} - self.kineto_cuda_runtime: Dict[int, KinetoOperator] = {} - self.kineto_ac2g_s_ops: Dict[str, KinetoOperator] = {} - self.kineto_ac2g_f_ops: Dict[str, KinetoOperator] = {} - self.kineto_cpu_launcher_ops: Dict[str, KinetoOperator] = {} + self.kineto_cpu_ops: List[KinetoOperator] = [] + self.sorted_kineto_cpu_ops: List[KinetoOperator] = [] + self.sorted_kineto_cpu_op_ts: List[int] = [] + self.kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]] = {} + self.kineto_correlation_cuda_runtime_map: Dict[int, KinetoOperator] = {} self.kineto_gpu_ops: List[KinetoOperator] = [] + self.kineto_id_arrow_op_map: Dict[int, KinetoOperator] = {} + self.kineto_id_cuda_launch_op_map: Dict[int, KinetoOperator] = {} self.kineto_process_start_time: int = 0 self.kineto_process_end_time: int = 0 self.kineto_thread_info: Dict[int, Tuple[int, int]] = {} @@ -102,14 +92,12 @@ def __init__(self, pytorch_et_file: str, kineto_file: str, log_level: str = "INF self.pytorch_op_id_to_inter_thread_dep_map: Dict[int, int] = {} self.id_assigner = UniqueIdAssigner() self.pytorch_et_plus_data: Optional[Dict] = None - self.logger = logging.getLogger(__name__) + self.logger: logging.Logger = logging.getLogger(__name__) self.logger.setLevel(log_level.upper()) def load_traces(self) -> None: """ Loads both PyTorch Execution Traces and Kineto Traces. - This method is a high-level orchestrator that calls specific methods to load - and process the PyTorch and Kineto traces individually. """ self.load_pytorch_et() self.load_kineto_trace() @@ -132,9 +120,8 @@ def extract_pytorch_ops(self, node: PyTorchOperator) -> List[PyTorchOperator]: """ Extracts and sorts nodes from the PyTorch execution trace recursively. - This method traverses the execution trace starting from the provided node, - extracting all the operator nodes recursively, and then returns them sorted - by their identifiers. + This method traverses the execution trace starting from the provided node, extracting all the operator nodes + recursively, and then returns them sorted by their identifiers. Args: node (PyTorchOperator): Starting node for extraction. @@ -155,9 +142,8 @@ def traverse(node: PyTorchOperator): def load_kineto_trace(self) -> None: """ Loads and processes the Kineto Trace. - This method parses the Kineto trace file, creating KinetoOperator instances - for each operator in the trace. It then categorizes and segments these - operators for further processing and linking with PyTorch operators. + This method parses the Kineto trace file, creating KinetoOperator instances for each operator in the trace. + It then categorizes and segments these operators for further processing and linking with PyTorch operators. 
""" self.logger.info("Starting to load Kineto Trace.") kineto_trace_data = read_dictionary_from_json_file(self.kineto_file) @@ -166,31 +152,30 @@ def load_kineto_trace(self) -> None: key=lambda op: op.timestamp, ) - self.categorize_and_track_kineto_ops(sorted_kineto_ops) - self.construct_kineto_rf_id_map() + self.construct_kineto_data_structures(sorted_kineto_ops) self.calculate_exclusive_dur() - self.sorted_kineto_ops = sorted(self.kineto_ops, key=lambda op: op.timestamp) - self.sorted_kineto_ts = [op.timestamp for op in self.sorted_kineto_ops] + self.sorted_kineto_cpu_ops = sorted(self.kineto_cpu_ops, key=lambda op: op.timestamp) + self.sorted_kineto_cpu_op_ts = [op.timestamp for op in self.sorted_kineto_cpu_ops] self.logger.info( - f"Processed Kineto trace with {len(self.kineto_ops)} CPU ops, " - f"{len(self.kineto_cpu_launcher_ops)} CPU launcher ops, " + f"Processed Kineto trace with {len(self.kineto_cpu_ops)} CPU ops, " + f"{len(self.kineto_id_cuda_launch_op_map)} CPU launcher ops, " f"and {len(self.kineto_gpu_ops)} GPU ops." ) self.logger.info("Kineto Trace loaded successfully.") - def categorize_and_track_kineto_ops(self, kineto_ops: List[KinetoOperator]) -> None: + def construct_kineto_data_structures(self, kineto_ops: List[KinetoOperator]) -> None: """ - Categorizes Kineto operators based on their properties and assigns them to - corresponding groups for CPU, GPU, and other operations. + Constructs necessary data structures required for trace linking from the provided Kineto operators. This method + identifies process start time, end time, thread start time, and end time, and also categorizes operators into + CPU, GPU, and other relevant groups. Args: kineto_ops (List[KinetoOperator]): List of Kineto operators to categorize. Raises: - ValueError: If duplicate correlation IDs are found in 'cuda_runtime' - category operators. + ValueError: If duplicate correlation IDs are found in 'cuda_runtime' category operators. """ self.logger.info("Categorizing Kineto operators and calculating timing boundaries.") process_start_time = sys.maxsize @@ -198,33 +183,35 @@ def categorize_and_track_kineto_ops(self, kineto_ops: List[KinetoOperator]) -> N thread_info = {} for op in kineto_ops: - if op.is_valid("cpu_op") or op.is_valid("user_annotation"): - self.kineto_ops.append(op) - self.kineto_ops_by_tid.setdefault(op.tid, []).append(op) + if op.is_cpu_op(): + self.kineto_cpu_ops.append(op) + self.kineto_tid_cpu_ops_map.setdefault(op.tid, []).append(op) self.logger.debug(f"Added CPU or user annotation op: {op.name}") - elif op.is_valid("ac2g", phase="s"): - self._add_op_to_dict(op, self.kineto_ac2g_s_ops, "id") - elif op.is_valid("ac2g", phase="f"): - self._add_op_to_dict(op, self.kineto_ac2g_f_ops, "id") - elif ( - op.is_valid("cuda_runtime") - and op.name - in [ - "cudaLaunchKernel", - "cudaLaunchKernelExC", - "cudaMemcpyAsync", - ] - ) or (op.category == "cuda_driver" and op.name in ["cuLaunchKernel"]): - self._add_op_to_dict(op, self.kineto_cpu_launcher_ops, "args", "External id") + + elif op.is_cuda_launch_op(): + self.kineto_id_cuda_launch_op_map[op.external_id] = op + if op.correlation in self.kineto_correlation_cuda_runtime_map: + raise ValueError( + f"Duplicate correlation ID {op.correlation} found in self.kineto_id_cuda_launch_op_map." 
+ ) + self.kineto_correlation_cuda_runtime_map[op.correlation] = op self.logger.debug(f"Added CPU launcher op: {op.name}") - elif op.is_valid("kernel") or op.is_valid("gpu_memcpy"): + + elif op.is_gpu_op(): self.kineto_gpu_ops.append(op) self.logger.debug(f"Added GPU op: {op.name}") - if (op.category == "cuda_runtime") or (op.category == "cuda_driver"): - if op.correlation in self.kineto_cuda_runtime: - raise ValueError(f"Duplicate correlation ID {op.correlation} found in cuda_runtime operators.") - self.kineto_cuda_runtime[op.correlation] = op + elif op.is_arrow_op(): + assert (op.phase == "s") or (op.phase == "f") + if op.id is None: + error_msg = ( + f"'id' field is None in Kineto operator, {op}. This is unexpected as 'id' " + "should generally be populated for 'ac2g' operators. Please verify the validity of " + "the Kineto trace and the 'op' data." + ) + self.logger.error(error_msg) + raise KeyError(error_msg) + self.kineto_id_arrow_op_map[op.id] = op # Update timing boundaries if op.tid is not None: @@ -234,32 +221,23 @@ def categorize_and_track_kineto_ops(self, kineto_ops: List[KinetoOperator]) -> N thread_start_end[0] = min(thread_start_end[0], op.timestamp) thread_start_end[1] = max(thread_start_end[1], op.timestamp + op.inclusive_dur) + self.kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in self.kineto_cpu_ops if op.rf_id is not None} + # Apply collected timing info self.kineto_process_start_time = process_start_time self.kineto_process_end_time = process_end_time self.kineto_thread_info = thread_info self.logger.info("Kineto operators categorized and timing boundaries calculated.") - def construct_kineto_rf_id_map(self) -> None: - """ - Constructs a map from rf_id to KinetoOperator instances. - """ - self.kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in self.kineto_ops if op.rf_id is not None} - def calculate_exclusive_dur(self) -> None: """ - Calculates the exclusive duration of each operator in the Kineto traces - in parallel. The exclusive duration is defined as the total duration of - the operator minus any time spent in child operators, effectively - representing the time spent exclusively in that operator. This approach - significantly improves the performance of calculating exclusive durations, - especially for traces with a large number of operators. Additionally, by - processing each thread's operators in parallel, the method takes advantage - of concurrent execution capabilities to further speed up the computation. + Calculates the exclusive duration of each operator in the Kineto traces in parallel. The exclusive duration is + defined as the total duration of the operator minus any time spent in child operators, effectively representing + the time spent exclusively in that operator. 
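+
+        For example (illustrative numbers): an operator with an inclusive duration of 100 microseconds whose
+        children cover 30 and 20 microseconds of that window has an exclusive duration of 100 - (30 + 20) = 50
+        microseconds.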
""" self.logger.info("Calculating exclusive durations for Kineto operators in parallel.") - def process_ops_for_thread(ops: List["KinetoOperator"]) -> None: + def process_ops_for_thread(ops: List[KinetoOperator]) -> None: self.logger.info(f"Processing {len(ops)} operators in thread.") sorted_ops = sorted(ops, key=lambda op: (op.timestamp, op.inclusive_dur)) for i, op in enumerate(sorted_ops): @@ -285,10 +263,8 @@ def process_ops_for_thread(ops: List["KinetoOperator"]) -> None: # Check if exclusive_dur is not negative or zero if exclusive_dur < 0: error_msg = ( - f"Exclusive duration calculation error for node " - f"'{op.name}' (ts: {op.timestamp}, " - f"inclusive_dur: {op.inclusive_dur}, " - f"rf_id: {op.rf_id}): " + f"Exclusive duration calculation error for node '{op.name}' " + f"(ts: {op.timestamp}, inclusive_dur: {op.inclusive_dur}, rf_id: {op.rf_id}): " f"Duration cannot be less than zero." ) self.logger.error(error_msg) @@ -296,14 +272,12 @@ def process_ops_for_thread(ops: List["KinetoOperator"]) -> None: op.exclusive_dur = exclusive_dur self.logger.debug( - f"Node '{op.name}' (ts: {op.timestamp}, " - f"inclusive_dur: {op.inclusive_dur}, " - f"rf_id: {op.rf_id}) " - f"exclusive duration: {op.exclusive_dur} microseconds." + f"Node '{op.name}' (ts: {op.timestamp}, inclusive_dur: {op.inclusive_dur}, " + f"rf_id: {op.rf_id}) exclusive duration: {op.exclusive_dur} microseconds." ) with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_ops_for_thread, ops) for ops in self.kineto_ops_by_tid.values()] + futures = [executor.submit(process_ops_for_thread, ops) for ops in self.kineto_tid_cpu_ops_map.values()] for future in as_completed(futures): future.result() # Wait for all threads to complete and handle any exceptions @@ -339,47 +313,18 @@ def merge_overlapping_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[ return merged - def _add_op_to_dict(self, op: KinetoOperator, target_dict: Dict, *keys: str) -> None: - """ - Adds an operator to a specific dictionary based on provided keys. - The method navigates through the operator's dictionary using the keys - and adds the operator to the target dictionary. - - Args: - op (KinetoOperator): The operator to be added. - target_dict (Dict): The dictionary to which the operator should be added. - *keys (str): Keys used to navigate through the operator's dictionary. - - Raises: - KeyError: If any of the keys are not found in the operator's dictionary. - """ - value = op.op_dict - for key in keys: - if key not in value: - error_msg = f"Key '{key}' not found in operator dictionary for op {op.name}." - self.logger.error(error_msg) - raise KeyError(error_msg) - value = value[key] - - target_dict[value] = op - def enforce_inter_thread_order(self, threshold: int = 1000) -> None: """ - Enforces order between groups of operators in different threads. In - Kineto traces with multiple threads, operators are executed in turns, - creating groups. This function identifies these groups by detecting - significant gaps in execution within each thread. It then establishes - dependencies between these groups across different threads, ensuring - the final Chakra execution traces reflect inter-thread dependencies - realistically. + Enforces order between groups of operators in different threads. In Kineto traces with multiple threads, + operators are executed in turns, creating groups. This function identifies these groups by detecting + significant gaps in execution within each thread. 
It then establishes dependencies between these groups across + different threads, ensuring the final Chakra execution traces reflect inter-thread dependencies realistically. - An isolated group is formed when there's a significant gap in execution - within a thread. Each new group relies on the last CPU operator from - other threads, enforcing order and dependency across threads. + An isolated group is formed when there's a significant gap in execution within a thread. Each new group relies + on the last CPU operator from other threads, enforcing order and dependency across threads. Args: - threshold (int): Threshold for significant gap detection in - microseconds, used to define group boundaries. + threshold (int): Threshold for significant gap detection in microseconds, used to define group boundaries. """ self.logger.info("Enforcing inter-thread order in Kineto traces.") @@ -401,8 +346,7 @@ def process_thread( last_cpu_node_rf_id = self.find_last_cpu_node_before_timestamp(ops_by_tid, tid, op.timestamp) if last_cpu_node_rf_id: self.logger.debug( - f"Thread {tid}: Linking op '{op.name}' " - f"to CPU node before gap with rf_id " + f"Thread {tid}: Linking op '{op.name}' to CPU node before gap with rf_id " f"'{last_cpu_node_rf_id}'." ) @@ -411,8 +355,8 @@ def process_thread( with ThreadPoolExecutor() as executor: futures = { - executor.submit(process_thread, tid, ops, self.kineto_ops_by_tid): tid - for tid, ops in self.kineto_ops_by_tid.items() + executor.submit(process_thread, tid, ops, self.kineto_tid_cpu_ops_map): tid + for tid, ops in self.kineto_tid_cpu_ops_map.items() } for future in as_completed(futures): @@ -430,13 +374,11 @@ def find_last_cpu_node_before_timestamp( timestamp: int, ) -> Optional[int]: """ - Finds the last CPU node ID before a given timestamp in threads other - than the excluded one. This ID is used to establish dependencies - between groups across threads. + Finds the last CPU node ID before a given timestamp in threads other than the excluded one. This ID is used + to establish dependencies between groups across threads. Args: - ops_by_tid (Dict[int, List[KinetoOperator]]): Operators grouped by - thread ID. + ops_by_tid (Dict[int, List[KinetoOperator]]): Operators grouped by thread ID. exclude_tid (int): Thread ID to exclude from the search. timestamp (int): Timestamp to compare against. @@ -464,9 +406,8 @@ def find_last_cpu_node_before_timestamp( def link_traces(self) -> None: """ - Initiates the linking process between PyTorch Execution Traces (ET) and - Kineto Traces to produce an enhanced PyTorch Execution Trace (ET+). This - process relies on the assumption of an 'exact match' between these traces. + Initiates the linking process between PyTorch Execution Traces (ET) and Kineto Traces to produce an enhanced + PyTorch Execution Trace (ET+). This process relies on the assumption of an 'exact match' between these traces. """ self.logger.info("Starting the process of linking PyTorch and Kineto traces.") self.add_thread_and_process_annotations() @@ -476,13 +417,11 @@ def link_traces(self) -> None: def add_thread_and_process_annotations(self) -> None: """ - Adds thread and process annotations to Kineto operators based on - previously tracked timing information. These annotations are crucial - for aligning Kineto operators with PyTorch ET nodes, ensuring - completeness and compatibility of trace data for analysis. 
This method - uses the process start and end times, as well as thread start and end - times, collected during the categorization process to insert - appropriate annotations directly into the Kineto operators list. + Adds thread and process annotations to Kineto operators based on previously tracked timing information. These + annotations are crucial for aligning Kineto operators with PyTorch ET nodes, ensuring completeness and + compatibility of trace data for analysis. This method uses the process start and end times, as well as thread + start and end times, collected during the categorization process to insert appropriate annotations directly + into the Kineto operators list. """ self.logger.info("Adding process and thread annotations to Kineto operators.") @@ -496,7 +435,7 @@ def add_thread_and_process_annotations(self) -> None: "exclusive_dur": 0, # Process exclusive duration not applicable } ) - self.kineto_ops.insert(0, process_annotation_op) + self.kineto_cpu_ops.insert(0, process_annotation_op) self.logger.debug( "Process annotation added with start time {} and duration {}.".format( self.kineto_process_start_time, @@ -513,55 +452,52 @@ def add_thread_and_process_annotations(self) -> None: "name": EXECUTION_TRACE_THREAD_ANNOTATION, "ts": start_ts, "inclusive_dur": inclusive_dur, - # Exclusive duration is set to zero in the final annotation. - # This is to avoid constraining the execution schedule to the - # original trace, allowing more flexibility in analyzing + # Exclusive duration is set to zero in the final annotation. This is to avoid constraining + # the execution schedule to the original trace, allowing more flexibility in analyzing # dependencies without being bound by specific execution timings. "exclusive_dur": 0, } ) # Find the correct position to insert the thread annotation position = next( - (i for i, op in enumerate(self.kineto_ops) if op.tid == tid and op.timestamp >= start_ts), + (i for i, op in enumerate(self.kineto_cpu_ops) if op.tid == tid and op.timestamp >= start_ts), None, ) if position is not None: - self.kineto_ops.insert(position, thread_annotation_op) + self.kineto_cpu_ops.insert(position, thread_annotation_op) else: - self.kineto_ops.append(thread_annotation_op) + self.kineto_cpu_ops.append(thread_annotation_op) self.logger.debug( "Thread {} annotation added with start time {} and duration {}.".format(tid, start_ts, inclusive_dur) ) def map_pytorch_to_kineto_ops(self) -> None: """ - Maps PyTorch ET nodes to corresponding Kineto operators, ensuring - each PyTorch node has a matching Kineto operator. + Maps PyTorch ET nodes to corresponding Kineto operators, ensuring each PyTorch node has a matching Kineto + operator. """ self.logger.info("Mapping PyTorch ET nodes to Kineto operators.") cpu_ev_idx_to_gpu_ops_map = self.group_gpu_ops_by_cpu_launchers() pytorch_ops_count = len(self.pytorch_ops) - kineto_ops_count = len(self.kineto_ops) + kineto_ops_count = len(self.kineto_cpu_ops) if pytorch_ops_count > kineto_ops_count: # The specific comment is placed within the if block as requested. self.logger.warning( - f"Number of PyTorch operators ({pytorch_ops_count}) is larger " - f"than the number of Kineto operators ({kineto_ops_count}). " - f"It is expected that the number of PyTorch operators (CPU only) " - f"will be smaller than the number of Kineto operators (CPU and GPU)." - f" A warning is logged if this is not the case, which is a rare " - f"but possible scenario." 
+                f"Number of PyTorch operators ({pytorch_ops_count}) is larger than the number of Kineto operators "
+                f"({kineto_ops_count}). Expected PyTorch ops (CPU only) to be fewer than Kineto ops (CPU and GPU). "
+                f"Logging this rare but possible scenario."
             )

         for _, pytorch_op in enumerate(self.pytorch_ops):
             if (pytorch_op.rf_id is not None) and (pytorch_op.rf_id in self.kineto_rf_id_to_kineto_op_map):
                 kineto_op = self.kineto_rf_id_to_kineto_op_map[pytorch_op.rf_id]
                 if kineto_op is None:
-                    self.logger.warning(
-                        f"No corresponding Kineto op found for PyTorch op "
-                        f"ID: {pytorch_op.id}, Name: '{pytorch_op.name}'."
-                    )
+                    self.logger.warning(
+                        f"No corresponding Kineto op found for PyTorch op ID: "
+                        f"{pytorch_op.id}, Name: '{pytorch_op.name}'."
+                    )
                     continue
                 self.link_ops(pytorch_op, kineto_op, cpu_ev_idx_to_gpu_ops_map)

@@ -571,12 +507,10 @@ def group_gpu_ops_by_cpu_launchers(self) -> Dict[str, List[KinetoOperator]]:
         """
         Groups GPU operators based on their corresponding CPU launchers.

-        This is determined by the 'ev_idx' which links GPU operators to their
-        initiating CPU launcher events.
+        This is determined by the 'ev_idx' which links GPU operators to their initiating CPU launcher events.

         Returns:
-            Dict[str, List[KinetoOperator]]: Mapping from CPU launch event indices
-            to GPU operators.
+            Dict[str, List[KinetoOperator]]: Mapping from CPU launch event indices to GPU operators.

         Raises:
             ValueError: If 'ev_idx' is missing for any GPU operator.
@@ -605,11 +539,9 @@ def find_parent_cpu_op(self, kineto_gpu_op: KinetoOperator) -> Optional[KinetoOperator]:
         """
-        Finds the parent CPU operator for a given GPU operator by identifying
-        the corresponding CUDA runtime operator through the correlation ID. It
-        then locates the closest preceding CPU operator based on the CUDA runtime's
-        timestamp, considering the temporal distance between the GPU operation's
-        start and the initiating CPU operation.
+        Finds the parent CPU operator for a given GPU operator by identifying the corresponding CUDA runtime operator
+        through the correlation ID. It then locates the closest preceding CPU operator based on the CUDA runtime's
+        timestamp, considering the temporal distance between the GPU operation's start and the initiating CPU operation.

         Args:
             kineto_gpu_op (KinetoOperator): The GPU operator.
@@ -618,29 +550,28 @@ def find_parent_cpu_op(self, kineto_gpu_op: KinetoOp

         Returns:
             Optional[KinetoOperator]: The parent CPU operator if found.

         Raises:
-            ValueError: If no CUDA runtime operator is found for the given
-            correlation ID.
+            ValueError: If no CUDA runtime operator is found for the given correlation ID.
         """
-        if kineto_gpu_op.correlation not in self.kineto_cuda_runtime:
-            warning_msg = "No CUDA runtime operator found for correlation ID {kineto_gpu_op.correlation}."
+        if kineto_gpu_op.correlation not in self.kineto_correlation_cuda_runtime_map:
+            warning_msg = f"No CUDA runtime operator found for correlation ID {kineto_gpu_op.correlation}."
             self.logger.warning(warning_msg)
             return None

-        kineto_cuda_runtime_op = self.kineto_cuda_runtime[kineto_gpu_op.correlation]
-        kineto_gpu_op.tid = kineto_cuda_runtime_op.tid
+        kineto_runtime_op = self.kineto_correlation_cuda_runtime_map[kineto_gpu_op.correlation]
+        kineto_gpu_op.tid = kineto_runtime_op.tid
         self.logger.debug(
-            f"Found CUDA runtime operation '{kineto_cuda_runtime_op.name}' for GPU operator '{kineto_gpu_op.name}'."
+            f"Found CUDA runtime operation '{kineto_runtime_op.name}' for GPU operator '{kineto_gpu_op.name}'."
) kineto_gpu_op.timestamp = self.get_start_timestamp_for_gpu_op(kineto_gpu_op) # Find the closest CPU operator that precedes the CUDA runtime operation - parent_cpu_op = self.find_closest_op(kineto_gpu_op, self.sorted_kineto_ops, kineto_cuda_runtime_op.timestamp) + parent_cpu_op = self.find_closest_op(kineto_gpu_op, self.sorted_kineto_cpu_ops, kineto_runtime_op.timestamp) if not parent_cpu_op: self.logger.warning( f"No parent CPU operator found for GPU operator '{kineto_gpu_op.name}' " - f"linked to CUDA runtime operation '{kineto_cuda_runtime_op.name}' " - f"(ts: {kineto_cuda_runtime_op.timestamp})." + f"linked to CUDA runtime operation '{kineto_runtime_op.name}' " + f"(ts: {kineto_runtime_op.timestamp})." ) return parent_cpu_op @@ -658,21 +589,19 @@ def get_start_timestamp_for_gpu_op(self, kineto_gpu_op: KinetoOperator) -> int: Raises: RuntimeError: If no valid timestamp is found for the GPU operator. """ - if kineto_gpu_op.external_id in self.kineto_cpu_launcher_ops: - cpu_launcher_op = self.kineto_cpu_launcher_ops[kineto_gpu_op.external_id] + if kineto_gpu_op.external_id in self.kineto_id_cuda_launch_op_map: + cpu_launcher_op = self.kineto_id_cuda_launch_op_map[kineto_gpu_op.external_id] return cpu_launcher_op.timestamp + cpu_launcher_op.inclusive_dur - if kineto_gpu_op.external_id in self.kineto_ac2g_s_ops: - return self.kineto_ac2g_s_ops[kineto_gpu_op.external_id].timestamp - if kineto_gpu_op.external_id in self.kineto_ac2g_f_ops: - return self.kineto_ac2g_f_ops[kineto_gpu_op.external_id].timestamp + if kineto_gpu_op.external_id in self.kineto_id_arrow_op_map: + return self.kineto_id_arrow_op_map[kineto_gpu_op.external_id].timestamp raise RuntimeError(f"No valid timestamp found for GPU operator: {kineto_gpu_op.name}") def find_closest_op( self, kineto_gpu_op: KinetoOperator, kineto_ops: List[KinetoOperator], ts: int ) -> Optional[KinetoOperator]: """ - Finds the Kineto operator that is closest in start time to a given timestamp - and has a duration that covers the timestamp. + Finds the Kineto operator that is closest in start time to a given timestamp and has a duration that covers + the timestamp. Args: kineto_gpu_op (KinetoOperator): The GPU operator being compared. @@ -683,7 +612,7 @@ def find_closest_op( Optional[KinetoOperator]: The closest Kineto operator if found. """ # Searching for the closest timestamp index - index = bisect.bisect_left(self.sorted_kineto_ts, ts) + index = bisect.bisect_left(self.sorted_kineto_cpu_op_ts, ts) if index == 0: # All operators are later than the timestamp @@ -759,11 +688,11 @@ def link_gpu_ops(self, pytorch_op: PyTorchOperator, kineto_gpu_ops: List[KinetoO def construct_et_plus_data(self) -> None: """ - Constructs the enhanced PyTorch Execution Trace (ET+) data structure by - integrating Kineto data into the original PyTorch Execution Trace. + Constructs the enhanced PyTorch Execution Trace (ET+) data structure by integrating Kineto data into the + original PyTorch Execution Trace. - This method enriches the PyTorch execution trace with detailed performance - data from the Kineto trace, offering a comprehensive view of the execution. + This method enriches the PyTorch execution trace with detailed performance data from the Kineto trace, offering + a comprehensive view of the execution. 
""" self.logger.info("Constructing ET+ data.") with open(self.pytorch_et_file, "r") as file: @@ -786,8 +715,8 @@ def construct_et_plus_data(self) -> None: def process_op_and_dependents(self, op: Dict) -> List[Dict]: """ - Processes a single operator in the PyTorch ET data, assigns a new unique ID, - and processes any dependent GPU operators. + Processes a single operator in the PyTorch ET data, assigns a new unique ID, and processes any dependent GPU + operators. Args: op (Dict): The operator to be processed. @@ -821,9 +750,8 @@ def process_op_and_dependents(self, op: Dict) -> List[Dict]: def process_dependent_gpu_ops(self, cpu_op: Dict, orig_op_id: int) -> List[Dict]: """ - Creates and returns a list of GPU operators that are dependent on a - specific CPU operator, sorted by their timestamp. The GPU operators are - deep copies of the existing operators with updated IDs and other relevant + Creates and returns a list of GPU operators that are dependent on a specific CPU operator, sorted by their + timestamp. The GPU operators are deep copies of the existing operators with updated IDs and other relevant fields from the CPU operator. Args: diff --git a/src/trace_link/unique_id_assigner.py b/src/trace_link/unique_id_assigner.py index a5179175..a26d0441 100644 --- a/src/trace_link/unique_id_assigner.py +++ b/src/trace_link/unique_id_assigner.py @@ -5,16 +5,14 @@ class UniqueIdAssigner: """ Assigns unique IDs to items, ensuring each item gets a distinct ID. - This class is used to maintain a consistent and unique mapping of original - identifiers to new unique identifiers. It's particularly useful in scenarios - where the uniqueness of IDs across different entities or iterations needs to + This class is used to maintain a consistent and unique mapping of original identifiers to new unique identifiers. + It's particularly useful in scenarios where the uniqueness of IDs across different entities or iterations needs to be preserved. Attributes: next_id (int): The next unique ID to be assigned. - original_to_new_ids (Dict[int, int]): A mapping from original IDs to their - corresponding new unique IDs. This helps in retrieving already assigned - unique IDs and ensures the same original ID always maps to the same + original_to_new_ids (Dict[int, int]): A mapping from original IDs to their corresponding new unique IDs. This + helps in retrieving already assigned unique IDs and ensures the same original ID always maps to the same unique ID. """ @@ -22,13 +20,13 @@ def __init__(self) -> None: """ Initializes the UniqueIdAssigner with a starting ID of 0. """ - self.next_id = 0 + self.next_id: int = 0 self.original_to_new_ids: Dict[int, int] = {} def assign_or_retrieve_id(self, original_id: int) -> int: """ - Assigns a new unique ID to the given original ID if it doesn't have one already; - otherwise, returns the previously assigned unique ID. + Assigns a new unique ID to the given original ID if it doesn't have one already; otherwise, returns the + previously assigned unique ID. Args: original_id (int): The original ID for which a unique ID is needed. @@ -46,8 +44,7 @@ def generate_new_id(self) -> int: """ Generates a new unique ID without needing an original ID. - This is useful for cases where new entities are created that do not - have an existing identifier. + This is useful for cases where new entities are created that do not have an existing identifier. Returns: int: A new unique ID. 
@@ -60,14 +57,12 @@ def lookup_new_id(self, original_id: int) -> int: """ Retrieves the new unique ID for a given original ID, if it has been assigned. - This method is useful for checking if a unique ID has already been - assigned to an original ID and retrieving it. + This method is useful for checking if a unique ID has already been assigned to an original ID and retrieving it. Args: original_id (int): The original ID to look up. Returns: - int: The new unique ID if it has been assigned, otherwise returns - the original ID. + int: The new unique ID if it has been assigned, otherwise returns the original ID. """ return self.original_to_new_ids.get(original_id, original_id) diff --git a/tests/trace_link/test_kineto_operator.py b/tests/trace_link/test_kineto_operator.py index 8561609b..0c3f131b 100644 --- a/tests/trace_link/test_kineto_operator.py +++ b/tests/trace_link/test_kineto_operator.py @@ -13,7 +13,7 @@ def sample_operator_data(): "dur": 100, "ts": 1590000000, "tid": 1234, - "args": {"External id": "ext123", "Ev Idx": "ev456", "stream": 7, "Record function id": 12, "correlation": 99}, + "args": {"External id": "123", "Ev Idx": "456", "stream": 7, "Record function id": 12, "correlation": 99}, } @@ -26,8 +26,8 @@ def test_init_kineto_operator(sample_operator_data): assert operator.inclusive_dur == 100 assert operator.exclusive_dur == 100 assert operator.timestamp == 1590000000 - assert operator.external_id == "ext123" - assert operator.ev_idx == "ev456" + assert operator.external_id == 123 + assert operator.ev_idx == 456 assert operator.tid == 1234 assert operator.stream == 7 assert operator.rf_id == 12 @@ -41,20 +41,9 @@ def test_repr_method(sample_operator_data): """Test the __repr__ method output.""" operator = KinetoOperator(sample_operator_data) expected_repr = ( - "KinetoOperator(category=Kernel, name=cudaLaunchKernel, phase=X, " - "inclusive_dur=100, exclusive_dur=100, timestamp=1590000000, external_id=ext123, ev_idx=ev456, " + "KinetoOperator(id=None, category=Kernel, name=cudaLaunchKernel, phase=X, " + "inclusive_dur=100, exclusive_dur=100, timestamp=1590000000, external_id=123, ev_idx=456, " "tid=1234, parent_pytorch_op_id=None, inter_thread_dep=None, stream=7, rf_id=12, " "correlation=99)" ) assert repr(operator) == expected_repr - - -def test_is_valid_method(sample_operator_data): - """Test the is_valid method under various conditions.""" - operator = KinetoOperator(sample_operator_data) - assert operator.is_valid(category="Kernel") # Matching category - assert not operator.is_valid(category="Memory") # Non-matching category - assert operator.is_valid(category="Kernel", name_exception="cudaMalloc") # Matching category, name not excluded - assert not operator.is_valid(category="Kernel", name_exception="cudaLaunchKernel") # Name excluded - assert operator.is_valid(category="Kernel", phase="X") # Matching phase - assert not operator.is_valid(category="Kernel", phase="B") # Non-matching phase
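
Note (illustrative, not part of the patch above): removing test_is_valid_method along with is_valid leaves the new predicate helpers without direct test coverage. The sketch below shows how replacement tests could look, reusing the test module's existing KinetoOperator import; the constructor dicts are hypothetical minimal inputs rather than repository fixtures.

def test_is_cpu_op_excludes_profiler_step():
    # 'cpu_op' category is simulatable unless the name contains 'ProfilerStep'.
    assert KinetoOperator({"cat": "cpu_op", "name": "aten::matmul"}).is_cpu_op()
    assert not KinetoOperator({"cat": "cpu_op", "name": "ProfilerStep#2"}).is_cpu_op()


def test_is_cuda_launch_op():
    # Only recognized kernel-launching runtime/driver calls should match.
    assert KinetoOperator({"cat": "cuda_runtime", "name": "cudaLaunchKernel"}).is_cuda_launch_op()
    assert not KinetoOperator({"cat": "cuda_runtime", "name": "cudaStreamSynchronize"}).is_cuda_launch_op()


def test_is_gpu_op_and_is_arrow_op():
    # GPU-side categories are 'kernel' and 'gpu_memcpy'; 'ac2g' marks CPU-to-GPU arrows.
    assert KinetoOperator({"cat": "kernel", "name": "gemm_kernel"}).is_gpu_op()
    assert KinetoOperator({"cat": "ac2g", "ph": "s", "id": 1}).is_arrow_op()
    assert not KinetoOperator({"cat": "cpu_op", "name": "aten::add"}).is_gpu_op()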