Skip to content

Commit

Permalink
[Spot/UX] Print spot jobs' resources before confirmation (#2524)
Browse files Browse the repository at this point in the history
* Print spot jobs' resources before confirmation

* format

* Update sky/utils/dag_utils.py

Co-authored-by: Tian Xia <cblmemo@gmail.com>

* Fix typings

* Fix typing

* Additional typing

* format

---------

Co-authored-by: Tian Xia <cblmemo@gmail.com>
  • Loading branch information
Michaelvll and cblmemo authored Sep 10, 2023
1 parent dd9c5d2 commit c87052d
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 40 deletions.
8 changes: 7 additions & 1 deletion sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3602,9 +3602,15 @@ def spot_launch(
dag.name = name

dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)
dag_utils.fill_default_spot_config_in_dag(dag)

click.secho(
f'Managed spot job {dag.name!r} will be launched on (estimated):',
fg='yellow')
dag = sky.optimize(dag)

if not yes:
prompt = f'Launching a new spot job {dag.name!r}. Proceed?'
prompt = f'Launching the spot job {dag.name!r}. Proceed?'
if prompt is not None:
click.confirm(prompt, default=True, abort=True, show_default=True)

Expand Down
9 changes: 3 additions & 6 deletions sky/clouds/cloud_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
import sky
from sky.clouds import cloud


class _CloudRegistry(dict):
"""Registry of clouds."""

def from_str(self,
name: Optional[str]) -> Optional['sky.clouds.cloud.Cloud']:
def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']:
if name is None:
return None
if name.lower() not in self:
Expand All @@ -22,9 +21,7 @@ def from_str(self,
f'{list(self.keys())}')
return self.get(name.lower())

def register(
self, cloud_cls: Type['sky.clouds.cloud.Cloud']
) -> Type['sky.clouds.cloud.Cloud']:
def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']:
name = cloud_cls.__name__.lower()
assert name not in self, f'{name} already registered'
self[name] = cloud_cls()
Expand Down
4 changes: 2 additions & 2 deletions sky/dag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""DAGs: user applications to be run."""
import pprint
import threading
from typing import List
from typing import List, Optional


class Dag:
Expand Down Expand Up @@ -86,7 +86,7 @@ def pop_dag(self):
self._current_dag = None
return old_dag

def get_current_dag(self):
def get_current_dag(self) -> Optional[Dag]:
return self._current_dag


Expand Down
17 changes: 4 additions & 13 deletions sky/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import getpass
import os
import tempfile
from typing import Any, Dict, List, Optional, Union
from typing import Any, List, Optional, Union
import uuid

import colorama
Expand Down Expand Up @@ -628,19 +628,10 @@ def spot_launch(
'and comment out the task names (so that they will be auto-'
'generated) .')
task_names.add(task_.name)
for task_ in dag.tasks:
assert len(task_.resources) == 1, task_
resources = list(task_.resources)[0]

change_default_value: Dict[str, Any] = {}
if not resources.use_spot_specified:
change_default_value['use_spot'] = True
if resources.spot_recovery is None:
change_default_value['spot_recovery'] = spot.SPOT_DEFAULT_STRATEGY

new_resources = resources.copy(**change_default_value)
task_.set_resources({new_resources})
dag_utils.fill_default_spot_config_in_dag(dag)

for task_ in dag.tasks:
_maybe_translate_local_file_mounts_and_sync_up(task_)

with tempfile.NamedTemporaryFile(prefix=f'spot-dag-{dag.name}-',
Expand Down Expand Up @@ -761,7 +752,7 @@ def spot_launch(
assert len(controller_task.resources) == 1

print(f'{colorama.Fore.YELLOW}'
f'Launching managed spot job {dag.name} from spot controller...'
f'Launching managed spot job {dag.name!r} from spot controller...'
f'{colorama.Style.RESET_ALL}')
print('Launching spot controller...')
_execute(
Expand Down
10 changes: 6 additions & 4 deletions sky/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class Optimizer:

@staticmethod
def _egress_cost(src_cloud: clouds.Cloud, dst_cloud: clouds.Cloud,
gigabytes: float):
gigabytes: float) -> float:
"""Returns estimated egress cost."""
if isinstance(src_cloud, DummyCloud) or isinstance(
dst_cloud, DummyCloud):
return 0.0
Expand All @@ -74,7 +75,7 @@ def _egress_cost(src_cloud: clouds.Cloud, dst_cloud: clouds.Cloud,

@staticmethod
def _egress_time(src_cloud: clouds.Cloud, dst_cloud: clouds.Cloud,
gigabytes: float):
gigabytes: float) -> float:
"""Returns estimated egress time in seconds."""
# FIXME: estimate bandwidth between each cloud-region pair.
if isinstance(src_cloud, DummyCloud) or isinstance(
Expand Down Expand Up @@ -181,7 +182,7 @@ def _get_egress_info(
parent_resources: resources_lib.Resources,
node: task_lib.Task,
resources: resources_lib.Resources,
) -> Tuple[Optional[clouds.Cloud], Optional[clouds.Cloud], float]:
) -> Tuple[Optional[clouds.Cloud], Optional[clouds.Cloud], Optional[float]]:
if isinstance(parent_resources.cloud, DummyCloud):
# Special case. The current 'node' is a real
# source node, and its input may be on a different
Expand All @@ -205,7 +206,8 @@ def _egress_cost_or_time(minimize_cost: bool, parent: task_lib.Task,
"""Computes the egress cost or time depending on 'minimize_cost'."""
src_cloud, dst_cloud, nbytes = Optimizer._get_egress_info(
parent, parent_resources, node, resources)
if nbytes == 0:
if not nbytes:
# nbytes can be None, if the task has no inputs/outputs.
return 0
assert src_cloud is not None and dst_cloud is not None

Expand Down
30 changes: 17 additions & 13 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sky import global_user_state
from sky import sky_logging
from sky.backends import backend_utils
import sky.dag
from sky.data import data_utils
from sky.data import storage as storage_lib
from sky.skylet import constants
Expand Down Expand Up @@ -44,7 +45,7 @@
'a list of node ip addresses (List[str]). Got {run_sig}')


def _is_valid_name(name: str) -> bool:
def _is_valid_name(name: Optional[str]) -> bool:
"""Checks if the task name is valid.
Valid is defined as either NoneType or str with ASCII characters which may
Expand Down Expand Up @@ -107,7 +108,7 @@ def replace_var(match):
def _with_docker_login_config(
resources_set: Set['resources_lib.Resources'],
task_envs: Dict[str, str],
) -> 'resources_lib.Resources':
) -> Set['resources_lib.Resources']:
all_keys = {
constants.DOCKER_USERNAME_ENV_VAR,
constants.DOCKER_PASSWORD_ENV_VAR,
Expand Down Expand Up @@ -223,10 +224,10 @@ def __init__(
# https://github.com/python/mypy/issues/3004
self.num_nodes = num_nodes # type: ignore

self.inputs = None
self.outputs = None
self.estimated_inputs_size_gigabytes = None
self.estimated_outputs_size_gigabytes = None
self.inputs: Optional[str] = None
self.outputs: Optional[str] = None
self.estimated_inputs_size_gigabytes: Optional[float] = None
self.estimated_outputs_size_gigabytes: Optional[float] = None
# Default to CPUNode
self.resources = {sky.Resources()}
# Resources that this task cannot run on.
Expand Down Expand Up @@ -423,7 +424,7 @@ def from_yaml(yaml_path: str) -> 'Task':
ValueError: if the path gets loaded into a str instead of a dict; or
if there are any other parsing errors.
"""
with open(os.path.expanduser(yaml_path), 'r') as f:
with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f:
# TODO(zongheng): use
# https://github.com/yaml/pyyaml/issues/165#issuecomment-430074049
# to raise errors on duplicate keys.
Expand Down Expand Up @@ -503,16 +504,17 @@ def need_spot_recovery(self) -> bool:
def use_spot(self) -> bool:
return any(r.use_spot for r in self.resources)

def set_inputs(self, inputs, estimated_size_gigabytes) -> 'Task':
def set_inputs(self, inputs: str,
estimated_size_gigabytes: float) -> 'Task':
# E.g., 's3://bucket', 'gs://bucket', or None.
self.inputs = inputs
self.estimated_inputs_size_gigabytes = estimated_size_gigabytes
return self

def get_inputs(self):
def get_inputs(self) -> Optional[str]:
return self.inputs

def get_estimated_inputs_size_gigabytes(self):
def get_estimated_inputs_size_gigabytes(self) -> Optional[float]:
return self.estimated_inputs_size_gigabytes

def get_inputs_cloud(self):
Expand All @@ -528,15 +530,16 @@ def get_inputs_cloud(self):
with ux_utils.print_exception_no_traceback():
raise ValueError(f'cloud path not supported: {self.inputs}')

def set_outputs(self, outputs, estimated_size_gigabytes) -> 'Task':
def set_outputs(self, outputs: str,
estimated_size_gigabytes: float) -> 'Task':
self.outputs = outputs
self.estimated_outputs_size_gigabytes = estimated_size_gigabytes
return self

def get_outputs(self):
def get_outputs(self) -> Optional[str]:
return self.outputs

def get_estimated_outputs_size_gigabytes(self):
def get_estimated_outputs_size_gigabytes(self) -> Optional[float]:
return self.estimated_outputs_size_gigabytes

def set_resources(
Expand Down Expand Up @@ -804,6 +807,7 @@ def get_preferred_store_type(self) -> storage_lib.StoreType:
if storage_cloud is None:
storage_cloud = clouds.CLOUD_REGISTRY.from_str(
enabled_storage_clouds[0])
assert storage_cloud is not None, enabled_storage_clouds[0]

store_type = storage_lib.get_storetype_from_cloud(storage_cloud)
return store_type
Expand Down
18 changes: 17 additions & 1 deletion sky/utils/dag_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Utilities for loading and dumping DAGs from/to YAML files."""
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from sky import dag as dag_lib
from sky import spot
from sky import task as task_lib
from sky.backends import backend_utils
from sky.utils import common_utils
Expand Down Expand Up @@ -86,3 +87,18 @@ def maybe_infer_and_fill_dag_and_task_names(dag: dag_lib.Dag) -> None:
for task_id, task in enumerate(dag.tasks):
if task.name is None:
task.name = f'{dag.name}-{task_id}'


def fill_default_spot_config_in_dag(dag: dag_lib.Dag) -> None:
    """Fills in spot defaults on the resources of every task in the DAG.

    Each task must carry exactly one resources entry. A copy of that entry
    is made with the following overrides, and the task's resources are then
    replaced by a set holding only that copy:
      - `use_spot=True`, if the resources' `use_spot_specified` is falsy.
      - `spot_recovery=spot.SPOT_DEFAULT_STRATEGY`, if the resources'
        `spot_recovery` is None.

    The DAG's tasks are updated in place; nothing is returned.

    Args:
        dag: The DAG whose tasks should be updated.
    """
    for task_ in dag.tasks:
        # Exactly one resources entry is expected per task at this point.
        assert len(task_.resources) == 1, task_
        only_resources = list(task_.resources)[0]

        # Collect only the fields the resources left unset, so that any
        # explicitly set value is carried over by copy() untouched.
        overrides: Dict[str, Any] = {}
        if not only_resources.use_spot_specified:
            overrides['use_spot'] = True
        if only_resources.spot_recovery is None:
            overrides['spot_recovery'] = spot.SPOT_DEFAULT_STRATEGY

        task_.set_resources({only_resources.copy(**overrides)})

0 comments on commit c87052d

Please sign in to comment.