diff --git a/requirements-dev.txt b/requirements-dev.txt index b86630327a7..c0b50bacbe9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ types-requests<2.31 types-setuptools types-cachetools types-pyvmomi +types-networkx # testing pytest diff --git a/sky/dag.py b/sky/dag.py index 0f74484518e..3b789b96ddd 100644 --- a/sky/dag.py +++ b/sky/dag.py @@ -36,7 +36,7 @@ def __init__(self, name: Optional[str] = None) -> None: """ self.name = name self._task_name_lookup: Dict[str, 'task.Task'] = {} - self.graph = nx.DiGraph() + self.graph: nx.DiGraph['task.Task'] = nx.DiGraph() @property def tasks(self) -> List['task.Task']: @@ -213,7 +213,7 @@ def __exit__(self, exec_type: Any, exec_val: Any, exec_tb: Any) -> None: def __repr__(self) -> str: """Return a string representation of the DAG.""" - task_info = [] + task_info: List[str] = [] for task in self.tasks: downstream = self.get_downstream(task) downstream_names = ','.join( @@ -236,11 +236,10 @@ def is_chain(self) -> bool: True if the DAG is a linear chain, False otherwise. """ nodes = list(self.graph.nodes) - out_degrees = [self.graph.out_degree(node) for node in nodes] return (len(nodes) <= 1 or - (all(degree <= 1 for degree in out_degrees) and - sum(degree == 0 for degree in out_degrees) == 1)) + (all(degree <= 1 for _, degree in self.graph.out_degree) and + sum(degree == 0 for _, degree in self.graph.out_degree) == 1)) def is_connected_dag(self) -> bool: """Check if the graph is a connected directed acyclic graph (DAG).