Skip to content

Commit

Permalink
[AWS] Fix check/status failure when no permission is granted for the …
Browse files Browse the repository at this point in the history
…account (#2415)

* Fail sky check for AWS when no enough permission

* Fix

* revert unexpected docstr change

* Address comments

* Update sky/clouds/service_catalog/config.py

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>

* format

* Update sky/resources.py

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>

* Update sky/resources.py

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>

* address comments

---------

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>
  • Loading branch information
Michaelvll and concretevitamin authored Aug 21, 2023
1 parent 66611b9 commit 0249308
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 25 deletions.
2 changes: 1 addition & 1 deletion sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3112,7 +3112,7 @@ def check(verbose: bool):
('The region to use. If not specified, shows accelerators from all regions.'
),
)
@service_catalog.use_default_catalog
@service_catalog.fallback_to_default_catalog
@usage_lib.entrypoint
def show_gpus(
accelerator_str: Optional[str],
Expand Down
10 changes: 9 additions & 1 deletion sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,15 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
from sky.clouds.service_catalog import aws_catalog

# Trigger the fetch of the availability zones mapping.
aws_catalog.get_default_instance_type()
try:
aws_catalog.get_default_instance_type()
except RuntimeError as e:
return False, (
'Failed to fetch the availability zones for the account. It is '
'likely due to permission issues, please check the minimal '
'permission required for AWS: https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/aws.html' # pylint: disable=
f'\n{cls._INDENT_PREFIX}Details: '
f'{common_utils.format_exception(e, use_bracket=True)}')
return True, hints

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing
from typing import Dict, List, Optional, Set, Tuple, Union

from sky.clouds.service_catalog.config import use_default_catalog
from sky.clouds.service_catalog.config import fallback_to_default_catalog
from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION
from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL
from sky.clouds.service_catalog.constants import LOCAL_CATALOG_DIR
Expand Down Expand Up @@ -45,7 +45,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
return results


@use_default_catalog
@fallback_to_default_catalog
def list_accelerators(
gpus_only: bool = True,
name_filter: Optional[str] = None,
Expand Down Expand Up @@ -330,7 +330,7 @@ def is_image_tag_valid(tag: str,
'get_image_id_from_tag',
'is_image_tag_valid',
# Configuration
'use_default_catalog',
'fallback_to_default_catalog',
# Constants
'HOSTED_CATALOG_DIR_URL',
'CATALOG_SCHEMA_VERSION',
Expand Down
23 changes: 15 additions & 8 deletions sky/clouds/service_catalog/aws_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sky.clouds.service_catalog import common
from sky.clouds.service_catalog import config
from sky.clouds.service_catalog.data_fetchers import fetch_aws
from sky.utils import common_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud
Expand Down Expand Up @@ -148,14 +149,19 @@ def _fetch_and_apply_az_mapping(df: pd.DataFrame) -> pd.DataFrame:


def _get_df() -> pd.DataFrame:
if config.get_use_default_catalog():
return _default_df
else:
global _user_df
with _apply_az_mapping_lock:
if _user_df is None:
global _user_df
with _apply_az_mapping_lock:
if _user_df is None:
try:
_user_df = _fetch_and_apply_az_mapping(_default_df)
return _user_df
except RuntimeError as e:
if config.get_use_default_catalog_if_failed():
logger.warning('Failed to fetch availability zone mapping. '
f'{common_utils.format_exception(e)}')
return _default_df
else:
raise
return _user_df


def get_quota_code(instance_type: str, use_spot: bool) -> Optional[str]:
Expand Down Expand Up @@ -248,7 +254,8 @@ def get_instance_type_for_accelerator(
"""Filter the instance types based on resource requirements.
Returns a list of instance types satisfying the required count of
accelerators with sorted prices and a list of candidates with fuzzy search.
accelerators/cpus/memory with sorted prices and a list of candidates with
fuzzy search.
"""
return common.get_instance_type_for_accelerator_impl(df=_get_df(),
acc_name=acc_name,
Expand Down
29 changes: 17 additions & 12 deletions sky/clouds/service_catalog/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,27 @@


@contextlib.contextmanager
def _set_use_default_catalog(value: bool):
old_value = get_use_default_catalog()
def _set_use_default_catalog_if_failed(value: bool):
old_value = get_use_default_catalog_if_failed()
_thread_local_config.use_default_catalog = value
try:
yield
finally:
_thread_local_config.use_default_catalog = old_value


# Whether the caller requires the catalog to be narrowed down
# to the account-specific catalog (e.g., removing regions not
# enabled for the current account) or just the raw catalog
# fetched from SkyPilot catalog service. The former is used
# for launching clusters, while the latter for commands like
# `show-gpus`.
def get_use_default_catalog() -> bool:
def get_use_default_catalog_if_failed() -> bool:
"""Whether to use default catalog if failed to fetch account-specific one.
Whether the caller requires the catalog to be narrowed down to the account-
specific catalog (e.g., removing regions not enabled for the current account
or use zone name assigned to the AWS account).
When set to True, the caller allows to use the default service catalog,
which may have inaccurate information (e.g., AWS's zone names are account-
specific), but it is ok for the read-only operators, such as `show-gpus` or
`sky status`.
"""
if not hasattr(_thread_local_config, 'use_default_catalog'):
# Should not set it globally, as the global assignment
# will be executed only once if the module is imported
Expand All @@ -33,16 +38,16 @@ def get_use_default_catalog() -> bool:
return _thread_local_config.use_default_catalog


def use_default_catalog(func):
"""Decorator: disable fetching account-specific catalog.
def fallback_to_default_catalog(func):
"""Decorator: allow failure for fetching account-specific catalog.
The account-specific catalog requires the credentials of the
cloud to be set.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with _set_use_default_catalog(True):
with _set_use_default_catalog_if_failed(True):
return func(*args, **kwargs)

return wrapper
13 changes: 13 additions & 0 deletions sky/resources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Resources: compute requirements of Tasks."""
import functools
from typing import Dict, List, Optional, Set, Union

import colorama
Expand All @@ -10,6 +11,7 @@
from sky import skypilot_config
from sky import spot
from sky.backends import backend_utils
from sky.clouds import service_catalog
from sky.skylet import constants
from sky.utils import accelerator_registry
from sky.utils import schemas
Expand All @@ -24,6 +26,10 @@
class Resources:
"""Resources: compute requirements of Tasks.
This class is immutable once created (to ensure some validations are done
whenever properties change). To update the property of an instance of
Resources, use `resources.copy(**new_properties)`.
Used:
* for representing resource requests for tasks/apps
Expand Down Expand Up @@ -171,6 +177,12 @@ def __init__(
self._try_validate_disk_tier()
self._try_validate_ports()

# When querying the accelerators inside this func (we call self.accelerators
# which is a @property), we will check the cloud's catalog, which can error
# if it fails to fetch some account specific catalog information (e.g., AWS
# zone mapping). It is fine to use the default catalog as this function is
# only for display purposes.
@service_catalog.fallback_to_default_catalog
def __repr__(self) -> str:
"""Returns a string representation for display.
Expand Down Expand Up @@ -303,6 +315,7 @@ def memory(self) -> Optional[str]:
return self._memory

@property
@functools.lru_cache(maxsize=1)
def accelerators(self) -> Optional[Dict[str, int]]:
"""Returns the accelerators field directly or by inferring.
Expand Down

0 comments on commit 0249308

Please sign in to comment.