From 724462884474b54bc664416e4b000ca79425dbdb Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 4 Nov 2024 00:32:10 -0800 Subject: [PATCH 1/6] chore: add type hints, mostly for more accurate `List[T]` --- sky/dag.py | 8 +++++--- sky/jobs/controller.py | 20 ++++++++++++-------- sky/jobs/utils.py | 5 +++-- sky/optimizer.py | 2 +- sky/serve/autoscalers.py | 4 ++-- sky/serve/load_balancer.py | 8 ++++---- sky/serve/replica_managers.py | 12 +++++++----- sky/serve/service_spec.py | 13 ++++++++----- 8 files changed, 42 insertions(+), 30 deletions(-) diff --git a/sky/dag.py b/sky/dag.py index 4af5adc76b5..b1609121447 100644 --- a/sky/dag.py +++ b/sky/dag.py @@ -2,7 +2,7 @@ import pprint import threading import typing -from typing import List, Optional +from typing import Any, List, Optional if typing.TYPE_CHECKING: from sky import task @@ -21,7 +21,7 @@ def __init__(self) -> None: self.tasks: List['task.Task'] = [] import networkx as nx # pylint: disable=import-outside-toplevel - self.graph = nx.DiGraph() + self.graph: nx.DiGraph = nx.DiGraph() self.name: Optional[str] = None def add(self, task: 'task.Task') -> None: @@ -44,7 +44,8 @@ def __enter__(self) -> 'Dag': push_dag(self) return self - def __exit__(self, exc_type, exc_value, traceback) -> None: + def __exit__(self, exec_type: Any, exec_val: Any, exec_tb: Any) -> None: + """Exit the runtime context related to this object.""" pop_dag() def __repr__(self) -> str: @@ -60,6 +61,7 @@ def is_chain(self) -> bool: visited_zero_out_degree = False for node in self.graph.nodes: out_degree = self.graph.out_degree(node) + assert isinstance(out_degree, int) if out_degree > 1: is_chain = False break diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 5219c564500..ad08387750e 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -6,7 +6,7 @@ import time import traceback import typing -from typing import Tuple +from typing import List, Tuple import filelock @@ -28,6 +28,8 @@ from sky.utils import ux_utils if typing.TYPE_CHECKING: + import networkx as nx + import sky # Use the explicit logger name so that the logger is under the @@ -55,11 +57,11 @@ def __init__(self, job_id: int, dag_yaml: str, # TODO(zhwu): this assumes the specific backend. self._backend = cloud_vm_ray_backend.CloudVmRayBackend() - # pylint: disable=line-too-long # Add a unique identifier to the task environment variables, so that # the user can have the same id for multiple recoveries. - # Example value: sky-2022-10-04-22-46-52-467694_my-spot-name_spot_id-17-0 - job_id_env_vars = [] + # Example value: + # sky-2022-10-04-22-46-52-467694_my-spot-name_spot_id-17-0 + job_id_env_vars: List[str] = [] for i, task in enumerate(self._dag.tasks): if len(self._dag.tasks) <= 1: task_name = self._dag_name @@ -416,7 +418,7 @@ def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool): jobs_controller.run() -def _handle_signal(job_id): +def _handle_signal(job_id: int) -> None: """Handle the signal if the user sent it.""" signal_file = pathlib.Path( managed_job_utils.SIGNAL_FILE_PREFIX.format(job_id)) @@ -426,9 +428,9 @@ def _handle_signal(job_id): # signal writing. with filelock.FileLock(str(signal_file) + '.lock'): with signal_file.open(mode='r', encoding='utf-8') as f: - user_signal = f.read().strip() + user_signal_str = f.read().strip() try: - user_signal = managed_job_utils.UserSignal(user_signal) + user_signal = managed_job_utils.UserSignal(user_signal_str) except ValueError: logger.warning( f'Unknown signal received: {user_signal}. Ignoring.') @@ -469,7 +471,7 @@ def _cleanup(job_id: int, dag_yaml: str): backend.teardown_ephemeral_storage(task) -def start(job_id, dag_yaml, retry_until_up): +def start(job_id: int, dag_yaml: str, retry_until_up: bool) -> None: """Start the controller.""" controller_process = None cancelling = False @@ -493,6 +495,7 @@ def start(job_id, dag_yaml, retry_until_up): task_id, _ = managed_job_state.get_latest_task_id_status(job_id) logger.info( f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}') + assert task_id is not None managed_job_state.set_cancelling( job_id=job_id, callback_func=managed_job_utils.event_callback_func( @@ -521,6 +524,7 @@ def start(job_id, dag_yaml, retry_until_up): _cleanup(job_id, dag_yaml=dag_yaml) logger.info(f'Cluster of managed job {job_id} has been cleaned up.') + assert task_id is not None if cancelling: managed_job_state.set_cancelled( job_id=job_id, diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 981f6d8286f..37bfeed6281 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -237,7 +237,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: return 'No job to cancel.' job_id_str = ', '.join(map(str, job_ids)) logger.info(f'Cancelling jobs {job_id_str}.') - cancelled_job_ids = [] + cancelled_job_ids: List[int] = [] for job_id in job_ids: # Check the status of the managed job status. If it is in # terminal state, we can safely skip it. @@ -490,7 +490,8 @@ def stream_logs(job_id: Optional[int], if controller: if job_id is None: assert job_name is not None - managed_jobs = managed_job_state.get_managed_jobs() + managed_jobs: List[Dict[ + str, Any]] = managed_job_state.get_managed_jobs() # We manually filter the jobs by name, instead of using # get_nonterminal_job_ids_by_name, as with `controller=True`, we # should be able to show the logs for jobs in terminal states. diff --git a/sky/optimizer.py b/sky/optimizer.py index 0f931e15079..0a2eb0ec4a7 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -529,7 +529,7 @@ def _optimize_by_ilp( # Prepare the constants. V = topo_order # pylint: disable=invalid-name - E = graph.edges() # pylint: disable=invalid-name + E: List[Tuple[Any, Any]] = list(graph.edges()) # pylint: disable=invalid-name k = { node: list(resource_cost_map.values()) for node, resource_cost_map in node_to_cost_map.items() diff --git a/sky/serve/autoscalers.py b/sky/serve/autoscalers.py index a4278f192fb..7a6311ad535 100644 --- a/sky/serve/autoscalers.py +++ b/sky/serve/autoscalers.py @@ -320,8 +320,8 @@ def select_outdated_replicas_to_scale_down( """Select outdated replicas to scale down.""" if self.update_mode == serve_utils.UpdateMode.ROLLING: - latest_ready_replicas = [] - old_nonterminal_replicas = [] + latest_ready_replicas: List['replica_managers.ReplicaInfo'] = [] + old_nonterminal_replicas: List['replica_managers.ReplicaInfo'] = [] for info in replica_infos: if info.version == self.latest_version: if info.is_ready: diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index c15f71e214a..64c7b0660f2 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -2,7 +2,7 @@ import asyncio import logging import threading -from typing import Dict, Union +from typing import Coroutine, Dict, List, NoReturn, Union import aiohttp import fastapi @@ -55,7 +55,7 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None: # updating it from _sync_with_controller. self._client_pool_lock: threading.Lock = threading.Lock() - async def _sync_with_controller(self): + async def _sync_with_controller(self) -> NoReturn: """Sync with controller periodically. Every `constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS` seconds, the @@ -68,7 +68,7 @@ async def _sync_with_controller(self): await asyncio.sleep(5) while True: - close_client_tasks = [] + close_client_tasks: List[Coroutine[None, None, None]] = [] async with aiohttp.ClientSession() as session: try: # Send request information @@ -101,7 +101,7 @@ async def _sync_with_controller(self): httpx.AsyncClient(base_url=replica_url)) urls_to_close = set( self._client_pool.keys()) - set(ready_replica_urls) - client_to_close = [] + client_to_close: List[httpx.AsyncClient] = [] for replica_url in urls_to_close: client_to_close.append( self._client_pool.pop(replica_url)) diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index c0e5220e779..557ec6f2227 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -548,9 +548,9 @@ def probe( f'{colorama.Style.RESET_ALL}') return self, False, probe_time - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: """Set state from pickled state, for backward compatibility.""" - version = state.pop('_version', None) + version: Optional[int] = state.pop('_version', None) # Handle old version(s) here. if version is None: version = -1 @@ -1036,8 +1036,10 @@ def _probe_all_replicas(self) -> None: (2) the consecutive failure times. The replica will be terminated if any of the thresholds exceeded. """ - probe_futures = [] - replica_to_probe = [] + # TODO(andyl): Define a TypeAlias for the return type of info.probe. + probe_futures: List[mp_pool.ApplyResult[Tuple[ReplicaInfo, bool, + float]]] = [] + replica_to_probe: List[str] = [] with mp_pool.ThreadPool() as pool: infos = serve_state.get_replica_infos(self._service_name) for info in infos: @@ -1160,7 +1162,7 @@ def get_active_replica_urls(self) -> List[str]: record = serve_state.get_service_from_name(self._service_name) assert record is not None, (f'{self._service_name} not found on ' 'controller records.') - ready_replica_urls = [] + ready_replica_urls: List[str] = [] active_versions = set(record['active_versions']) for info in serve_state.get_replica_infos(self._service_name): if (info.status == serve_state.ReplicaStatus.READY and diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 2eff6f40a9d..8e2a887974d 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -2,7 +2,7 @@ import json import os import textwrap -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml @@ -173,9 +173,12 @@ def from_yaml(yaml_path: str) -> 'SkyServiceSpec': return SkyServiceSpec.from_yaml_config(config['service']) def to_yaml_config(self) -> Dict[str, Any]: - config = dict() + config: Dict[str, Any] = dict() - def add_if_not_none(section, key, value, no_empty: bool = False): + def add_if_not_none(section: str, + key: Optional[str], + value: Any, + no_empty: bool = False): if no_empty and not value: return if value is not None: @@ -216,8 +219,8 @@ def probe_str(self): ' with custom headers') return f'{method}{headers}' - def spot_policy_str(self): - policy_strs = [] + def spot_policy_str(self) -> str: + policy_strs: List[str] = [] if (self.dynamic_ondemand_fallback is not None and self.dynamic_ondemand_fallback): policy_strs.append('Dynamic on-demand fallback') From 95a98aa3ab72ef5dc848a239dab4339ad7946eba Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 4 Nov 2024 00:32:27 -0800 Subject: [PATCH 2/6] chore: remove duplicated code --- sky/serve/core.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sky/serve/core.py b/sky/serve/core.py index ea8f380a2e7..e15e5ff00e1 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -687,10 +687,6 @@ def tail_logs( """ if isinstance(target, str): target = serve_utils.ServiceComponent(target) - if not isinstance(target, serve_utils.ServiceComponent): - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'`target` must be a string or ' - f'sky.serve.ServiceComponent, got {type(target)}.') if target == serve_utils.ServiceComponent.REPLICA: if replica_id is None: with ux_utils.print_exception_no_traceback(): From d9ba99f72d452440785cd6d5ad9ede9aff988ec8 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 4 Nov 2024 00:36:25 -0800 Subject: [PATCH 3/6] chore: more generics and protocals --- requirements-dev.txt | 2 +- sky/dag.py | 3 ++- sky/serve/replica_managers.py | 21 ++++++++++++--- sky/utils/common_utils.py | 50 ++++++++++++++++++++++++++++------- sky/utils/timeline.py | 15 +++++++---- 5 files changed, 70 insertions(+), 21 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f91ce750ad..b86630327a7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,7 +9,7 @@ toml==0.10.2 isort==5.12.0 # type checking -mypy==0.991 +mypy==1.4.1 types-PyYAML # 2.31 requires urlib3>2, which is incompatible with SkyPilot, IBM and # kubernetes packages, which require urllib3<2. diff --git a/sky/dag.py b/sky/dag.py index b1609121447..437ed18d97d 100644 --- a/sky/dag.py +++ b/sky/dag.py @@ -4,6 +4,8 @@ import typing from typing import Any, List, Optional +import networkx as nx + if typing.TYPE_CHECKING: from sky import task @@ -19,7 +21,6 @@ class Dag: def __init__(self) -> None: self.tasks: List['task.Task'] = [] - import networkx as nx # pylint: disable=import-outside-toplevel self.graph: nx.DiGraph = nx.DiGraph() self.name: Optional[str] = None diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index 557ec6f2227..89448900ecf 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -9,11 +9,12 @@ import time import traceback import typing -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple import colorama import psutil import requests +import typing_extensions import sky from sky import backends @@ -39,6 +40,9 @@ from sky import resources from sky.serve import service_spec +T = typing.TypeVar('T') +P = typing_extensions.ParamSpec('P') +R = typing.TypeVar('R') logger = sky_logging.init_logger(__name__) _JOB_STATUS_FETCH_INTERVAL = 30 @@ -197,10 +201,19 @@ def _should_use_spot(task_yaml: str, return len(spot_use_resources) == len(task.resources) -def with_lock(func): +class HasLock(Protocol): + lock: threading.Lock + + +L = typing.TypeVar('L', bound=HasLock) + + +def with_lock( + func: Callable[typing_extensions.Concatenate[L, P], R] +) -> Callable[typing_extensions.Concatenate[L, P], R]: @functools.wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: L, *args: P.args, **kwargs: P.kwargs) -> R: with self.lock: return func(self, *args, **kwargs) @@ -563,7 +576,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self.__dict__.update(state) -class ReplicaManager: +class ReplicaManager(HasLock): """Each replica manager monitors one service.""" def __init__(self, service_name: str, diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 5fce435b770..3a6bd9a8d94 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -1,5 +1,6 @@ """Utils shared between all of sky""" +from contextlib import _GeneratorContextManager import difflib import functools import getpass @@ -13,11 +14,14 @@ import socket import sys import time -from typing import Any, Callable, Dict, List, Optional, Union +import typing +from typing import Any, Callable, Dict, List, Optional, Protocol, Type, Union import uuid import jinja2 import jsonschema +import typing_extensions +from typing_extensions import Self import yaml from sky import exceptions @@ -26,6 +30,11 @@ from sky.utils import ux_utils from sky.utils import validator +T = typing.TypeVar('T') +P = typing_extensions.ParamSpec('P') +R = typing.TypeVar('R') +C = typing_extensions.ParamSpec('C') + _USER_HASH_FILE = os.path.expanduser('~/.sky/user_hash') USER_HASH_LENGTH = 8 USER_HASH_LENGTH_IN_CLUSTER_NAME = 4 @@ -340,8 +349,24 @@ def write_line_break(self, data=None): default_flow_style=False) -def make_decorator(cls, name_or_fn: Union[str, Callable], - **ctx_kwargs) -> Callable: +class ContextManagerProtocol(Protocol): + + def __init__(self, name: str, **_) -> None: + ... + + def __enter__(self) -> Self: + ... + + def __exit__(self, exec_type: Any, exec_value: Any, traceback: Any) -> None: + ... + + +# TODO(andy): more accurate typing for Callable[...] +def make_decorator( + cls: Union[Type[ContextManagerProtocol], + Callable[..., _GeneratorContextManager[Any]]], + name_or_fn: Union[str, Callable[P, R]], **ctx_kwargs: Any +) -> Union[Callable[[Callable[P, R]], Callable[P, R]], Callable[P, R]]: """Make the cls a decorator. class cls: @@ -358,10 +383,10 @@ def __exit__(self, exc_type, exc_value, traceback): """ if isinstance(name_or_fn, str): - def _wrapper(f): + def _wrapper(f: Callable[P, R]) -> Callable[P, R]: @functools.wraps(f) - def _record(*args, **kwargs): + def _record(*args: P.args, **kwargs: P.kwargs) -> R: with cls(name_or_fn, **ctx_kwargs): return f(*args, **kwargs) @@ -374,8 +399,8 @@ def _record(*args, **kwargs): 'Should directly apply the decorator to a function.') @functools.wraps(name_or_fn) - def _record(*args, **kwargs): - f = name_or_fn + def _record(*args: P.args, **kwargs: P.kwargs) -> R: + f: Callable[P, R] = name_or_fn func_name = getattr(f, '__qualname__', f.__name__) module_name = getattr(f, '__module__', '') if module_name: @@ -388,11 +413,15 @@ def _record(*args, **kwargs): return _record -def retry(method, max_retries=3, initial_backoff=1): +def retry(method: Callable[P, R], + max_retries: int = 3, + initial_backoff: float = 1) -> Callable[P, R]: """Retry a function up to max_retries times with backoff between retries.""" + assert max_retries > 0 + @functools.wraps(method) - def method_with_retries(*args, **kwargs): + def method_with_retries(*args: P.args, **kwargs: P.kwargs) -> R: backoff = Backoff(initial_backoff) try_count = 0 while try_count < max_retries: @@ -405,6 +434,7 @@ def method_with_retries(*args, **kwargs): time.sleep(backoff.current_backoff()) else: raise + assert False, 'Unreachable' return method_with_retries @@ -446,7 +476,7 @@ def decode_payload(payload_str: str) -> Any: return payload -def class_fullname(cls, skip_builtins: bool = True): +def class_fullname(cls: Type[T], skip_builtins: bool = True) -> str: """Get the full name of a class. Example: diff --git a/sky/utils/timeline.py b/sky/utils/timeline.py index 29c6c3d94ee..d58fea868f9 100644 --- a/sky/utils/timeline.py +++ b/sky/utils/timeline.py @@ -9,13 +9,18 @@ import os import threading import time -from typing import Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import filelock +import typing_extensions +from typing_extensions import Self from sky.utils import common_utils -_events = [] +_events: List[Dict[str, Any]] = [] + +P = typing_extensions.ParamSpec('P') +R = typing_extensions.TypeVar('R') class Event: @@ -30,7 +35,7 @@ def __init__(self, name: str, message: Optional[str] = None): self._name = name self._message = message # See the module doc for the event format. - self._event = { + self._event: Dict[str, Any] = { 'name': self._name, 'cat': 'event', 'pid': str(os.getpid()), @@ -62,11 +67,11 @@ def end(self): event_end['args'] = {'message': self._message} _events.append(event_end) - def __enter__(self): + def __enter__(self) -> Self: self.begin() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exec_type: Any, exec_val: Any, exec_tb: Any) -> None: self.end() From 6848f8884a47b6cb46ab424192295a2fb7848610 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 4 Nov 2024 00:48:29 -0800 Subject: [PATCH 4/6] chore: check nx types --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index b86630327a7..c0b50bacbe9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ types-requests<2.31 types-setuptools types-cachetools types-pyvmomi +types-networkx # testing pytest From 860c4ae27529ca21e733a8805ce786597cf54d71 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 4 Nov 2024 15:13:29 -0800 Subject: [PATCH 5/6] chore: make pylint check in ci use the version specified in requirements-dev --- .github/workflows/pylint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 0555fb934d0..9ff01b4e17a 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -29,8 +29,8 @@ jobs: run: | python -m pip install --upgrade pip pip install ".[all]" - pip install pylint==2.14.5 - pip install pylint-quotes==0.2.3 + pip install pylint==$(grep 'pylint==' requirements-dev.txt | cut -d'=' -f3) + pip install pylint-quotes==$(grep 'pylint-quotes==' requirements-dev.txt | cut -d'=' -f3) - name: Analysing the code with pylint run: | pylint --load-plugins pylint_quotes sky From 70b1f41361229f1ed18d59e7d00bf13e71bb91b5 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 4 Nov 2024 14:19:10 -0800 Subject: [PATCH 6/6] chore: deprecate pylint-quotes so that we can upgrade pylint --- .github/workflows/pylint.yml | 3 +-- .pylintrc | 13 +++---------- format.sh | 6 ------ requirements-dev.txt | 4 +--- sky/backends/cloud_vm_ray_backend.py | 8 +++++--- sky/cli.py | 11 +++++++---- sky/optimizer.py | 1 + sky/serve/constants.py | 4 ++++ sky/serve/replica_managers.py | 2 +- sky/serve/serve_utils.py | 8 ++++++-- 10 files changed, 29 insertions(+), 31 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 9ff01b4e17a..9f2d91ba2a1 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -30,7 +30,6 @@ jobs: python -m pip install --upgrade pip pip install ".[all]" pip install pylint==$(grep 'pylint==' requirements-dev.txt | cut -d'=' -f3) - pip install pylint-quotes==$(grep 'pylint-quotes==' requirements-dev.txt | cut -d'=' -f3) - name: Analysing the code with pylint run: | - pylint --load-plugins pylint_quotes sky + pylint sky diff --git a/.pylintrc b/.pylintrc index a2ef1829167..513f7e8b10b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -429,13 +429,6 @@ valid-metaclass-classmethod-first-arg=mcs # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException - -####### - -# https://github.com/edaniszewski/pylint-quotes#configuration -string-quote=single -triple-quote=double -docstring-quote=double +overgeneral-exceptions=builtins.StandardError, + builtins.Exception, + builtins.BaseException diff --git a/format.sh b/format.sh index b06481b4c10..882dc8a64b5 100755 --- a/format.sh +++ b/format.sh @@ -23,7 +23,6 @@ builtin cd "$ROOT" || exit 1 YAPF_VERSION=$(yapf --version | awk '{print $2}') PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') -PYLINT_QUOTES_VERSION=$(pip list | grep pylint-quotes | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') BLACK_VERSION=$(black --version | head -n 1 | awk '{print $2}') @@ -37,7 +36,6 @@ tool_version_check() { tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "pylint-quotes" $PYLINT_QUOTES_VERSION "$(grep "pylint-quotes==" requirements-dev.txt | cut -d'=' -f3)" tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "black" "$BLACK_VERSION" "$(grep black requirements-dev.txt | cut -d'=' -f3)" @@ -60,10 +58,6 @@ BLACK_INCLUDES=( 'sky/skylet/providers/ibm' ) -PYLINT_FLAGS=( - '--load-plugins' 'pylint_quotes' -) - # Format specified files format() { yapf --in-place "${YAPF_FLAGS[@]}" "$@" diff --git a/requirements-dev.txt b/requirements-dev.txt index c0b50bacbe9..7634e399ffa 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,10 +1,8 @@ # formatting yapf==0.32.0 -pylint==2.14.5 +pylint==3.2.7 # formatting the node_providers code from upstream ray-project/ray project black==22.10.0 -# https://github.com/edaniszewski/pylint-quotes -pylint-quotes==0.2.3 toml==0.10.2 isort==5.12.0 diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 0013e6cbaf9..17f257b290f 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2475,10 +2475,12 @@ def num_ips_per_node(self) -> int: num_ips = 1 return num_ips - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]): self._version = self._VERSION version = state.pop('_version', None) + head_ip: Optional[str] = None + if version is None: version = -1 state.pop('cluster_region', None) @@ -2502,7 +2504,6 @@ def __setstate__(self, state): if version < 9: # For backward compatibility, we should update the region of a # SkyPilot cluster on Kubernetes to the actual context it is using. - # pylint: disable=import-outside-toplevel launched_resources = state['launched_resources'] if isinstance(launched_resources.cloud, clouds.Kubernetes): yaml_config = common_utils.read_yaml( @@ -3272,7 +3273,7 @@ def _exec_code_on_head( mkdir_code = (f'{cd} && mkdir -p {remote_log_dir} && ' f'touch {remote_log_path}') encoded_script = shlex.quote(codegen) - create_script_code = (f'{{ echo {encoded_script} > {script_path}; }}') + create_script_code = f'{{ echo {encoded_script} > {script_path}; }}' job_submit_cmd = ( f'RAY_DASHBOARD_PORT=$({constants.SKY_PYTHON_CMD} -c "from sky.skylet import job_lib; print(job_lib.get_job_submission_port())" 2> /dev/null || echo 8265);' # pylint: disable=line-too-long f'{cd} && {constants.SKY_RAY_CMD} job submit ' @@ -3829,6 +3830,7 @@ def teardown_no_lock(self, RuntimeError: If the cluster fails to be terminated/stopped. """ cluster_status_fetched = False + prev_cluster_status = None if refresh_cluster_status: try: prev_cluster_status, _ = ( diff --git a/sky/cli.py b/sky/cli.py index 462e8a5b9de..e433bd618c3 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -210,7 +210,7 @@ def _merge_env_vars(env_dict: Optional[Dict[str, str]], default=None, type=int, required=False, - help=('OS disk size in GBs.')), + help='OS disk size in GBs.'), click.option('--disk-tier', default=None, type=click.Choice(resources_utils.DiskTier.supported_tiers(), @@ -392,7 +392,8 @@ def _install_shell_completion(ctx: click.Context, param: click.Parameter, else: click.secho(f'Unsupported shell: {value}', fg='red') ctx.exit() - + # Though `ctx.exit()` already NoReturn, we need to make pylint happy. + assert False, 'Unreachable' try: subprocess.run(cmd, shell=True, @@ -447,6 +448,8 @@ def _uninstall_shell_completion(ctx: click.Context, param: click.Parameter, else: click.secho(f'Unsupported shell: {value}', fg='red') ctx.exit() + # Though `ctx.exit()` already NoReturn, we need to make pylint happy. + assert False, 'Unreachable' try: subprocess.run(cmd, shell=True, check=True) @@ -3530,7 +3533,7 @@ def jobs(): default=None, type=str, hidden=True, - help=('Alias for --name, the name of the managed job.')) + help='Alias for --name, the name of the managed job.') @click.option('--job-recovery', default=None, type=str, @@ -4544,7 +4547,7 @@ def serve_logs( sky serve logs [SERVICE_NAME] 1 """ have_replica_id = replica_id is not None - num_flags = (controller + load_balancer + have_replica_id) + num_flags = controller + load_balancer + have_replica_id if num_flags > 1: raise click.UsageError('At most one of --controller, --load-balancer, ' '[REPLICA_ID] can be specified.') diff --git a/sky/optimizer.py b/sky/optimizer.py index 0a2eb0ec4a7..58fa4839614 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -435,6 +435,7 @@ def _optimize_by_dp( # FIXME: Account for egress costs for multi-node clusters for resources, execution_cost in node_to_cost_map[node].items(): min_pred_cost_plus_egress = np.inf + best_parent_hardware = resources_lib.Resources() for parent_resources, parent_cost in \ dp_best_objective[parent].items(): egress_cost = Optimizer._egress_cost_or_time( diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 3974293190e..327934e0dd7 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -100,3 +100,7 @@ TERMINATE_REPLICA_VERSION_MISMATCH_ERROR = ( 'The version of service is outdated and does not support manually ' 'terminating replicas. Please terminate the service and spin up again.') + +# Default timeout in seconds for HTTP requests to avoid hanging indefinitely. +# This is used for internal service communication requests. +DEFAULT_HTTP_REQUEST_TIMEOUT_SECONDS = 30 diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index 89448900ecf..ceba84741d5 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -528,7 +528,7 @@ def probe( logger.info(f'Error when probing {replica_identity}: ' 'Cannot get the endpoint.') return self, False, probe_time - readiness_path = (f'{url}{readiness_path}') + readiness_path = f'{url}{readiness_path}' logger.info(f'Probing {replica_identity} with {readiness_path}.') if post_data is not None: msg += 'POST' diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 6e7b6f6eb4a..ba4e984293a 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -300,7 +300,9 @@ def update_service_encoded(service_name: str, version: int, mode: str) -> str: json={ 'version': version, 'mode': mode, - }) + }, + timeout=constants.DEFAULT_HTTP_REQUEST_TIMEOUT_SECONDS, + ) if resp.status_code == 404: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -342,7 +344,9 @@ def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str: json={ 'replica_id': replica_id, 'purge': purge, - }) + }, + timeout=constants.DEFAULT_HTTP_REQUEST_TIMEOUT_SECONDS, + ) message: str = resp.json()['message'] if resp.status_code != 200: