diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 99b46a6a354..de6eb2a0da9 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -153,9 +153,8 @@ Available fields and semantics: # zero cost) if the requested resources matches the reservation. # Ref: https://cloud.google.com/compute/docs/instances/reservations-overview#consumption-type specific_reservations: - # Only one element is allowed in this list, as GCP disallows multiple - # specific_reservations in a single request. - - projects/my-project/reservations/my-reservation + - projects/my-project/reservations/my-reservation1 + - projects/my-project/reservations/my-reservation2 # Advanced Kubernetes configurations (optional). kubernetes: diff --git a/examples/job_queue/job.yaml b/examples/job_queue/job.yaml index aa9c3502247..e5925ed13ea 100644 --- a/examples/job_queue/job.yaml +++ b/examples/job_queue/job.yaml @@ -17,7 +17,7 @@ setup: | run: | timestamp=$(date +%s) conda env list - for i in {1..140}; do + for i in {1..180}; do echo "$timestamp $i" sleep 1 done diff --git a/examples/job_queue/job_docker.yaml b/examples/job_queue/job_docker.yaml index 8907e77bc46..da604125865 100644 --- a/examples/job_queue/job_docker.yaml +++ b/examples/job_queue/job_docker.yaml @@ -18,7 +18,7 @@ setup: | run: | timestamp=$(date +%s) conda env list - for i in {1..120}; do + for i in {1..180}; do echo "$timestamp $i" sleep 1 done diff --git a/sky/adaptors/gcp.py b/sky/adaptors/gcp.py index 6e611ee1f2b..3835d004338 100644 --- a/sky/adaptors/gcp.py +++ b/sky/adaptors/gcp.py @@ -2,6 +2,7 @@ # pylint: disable=import-outside-toplevel import functools +import json googleapiclient = None google = None @@ -82,3 +83,25 @@ def credential_error_exception(): """CredentialError exception.""" from google.auth import exceptions return exceptions.DefaultCredentialsError + + +@import_package +def get_credentials(cred_type: str, credentials_field: str): + """Get GCP credentials.""" + from google.oauth2 import service_account + from google.oauth2.credentials import Credentials as OAuthCredentials + + if cred_type == 'service_account': + # If parsing the gcp_credentials failed, then the user likely made a + # mistake in copying the credentials into the config yaml. + try: + service_account_info = json.loads(credentials_field) + except json.decoder.JSONDecodeError as e: + raise RuntimeError('gcp_credentials found in cluster yaml file but ' + 'formatted improperly.') from e + credentials = service_account.Credentials.from_service_account_info( + service_account_info) + elif cred_type == 'credentials_token': + # Otherwise the credentials type must be credentials_token. + credentials = OAuthCredentials(credentials_field) + return credentials diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 9060066499d..b8c2a947671 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1,7 +1,6 @@ """Util constants/functions for the backends.""" from datetime import datetime import enum -import json import os import pathlib import pprint @@ -39,19 +38,19 @@ from sky import skypilot_config from sky import status_lib from sky.backends import onprem_utils +from sky.clouds import cloud_registry +from sky.clouds.utils import gcp_utils from sky.provision import instance_setup from sky.skylet import constants -from sky.skylet import log_lib from sky.usage import usage_lib +from sky.utils import cluster_yaml_utils from sky.utils import command_runner from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import env_options -from sky.utils import remote_cluster_yaml_utils from sky.utils import rich_utils from sky.utils import subprocess_utils from sky.utils import timeline -from sky.utils import tpu_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -938,11 +937,6 @@ def write_cluster_config( r for r, available_resources in reservations.items() if r in specific_reservations and available_resources > 0 ] - available_specific_reservations = sum( - available_resources for r, available_resources in reservations.items() - if r in specific_reservations) - num_specific_reserved_workers = max( - min(available_specific_reservations - 1, num_nodes - 1), 0) assert cluster_name is not None credentials = sky_check.get_cloud_credential_file_mounts() @@ -1033,7 +1027,6 @@ def write_cluster_config( # GCP only: 'gcp_project_id': gcp_project_id, 'specific_reservations': filtered_specific_reservations, - 'num_specific_reserved_workers': num_specific_reserved_workers, # Conda setup 'conda_installation_commands': @@ -1054,7 +1047,7 @@ def write_cluster_config( 'sky_local_path': str(local_wheel_path), # Add yaml file path to the template variables. 'sky_ray_yaml_remote_path': - remote_cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH, + cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH, 'sky_ray_yaml_local_path': tmp_yaml_path if not isinstance(cloud, clouds.Local) else yaml_path, @@ -1112,7 +1105,7 @@ def write_cluster_config( usage_lib.messages.usage.update_ray_yaml(yaml_path) # For TPU nodes. TPU VMs do not need TPU_NAME. - if tpu_utils.is_tpu(to_provision) and not tpu_utils.is_tpu_vm(to_provision): + if gcp_utils.is_tpu(to_provision) and not gcp_utils.is_tpu_vm(to_provision): tpu_name = resources_vars.get('tpu_name') if tpu_name is None: tpu_name = cluster_name @@ -1533,20 +1526,22 @@ def _query_head_ip_with_retries(cluster_yaml: str, @timeline.event -def get_node_ips(cluster_yaml: str, - expected_num_nodes: int, - handle: Optional[ - 'cloud_vm_ray_backend.CloudVmRayResourceHandle'] = None, - head_ip_max_attempts: int = 1, - worker_ip_max_attempts: int = 1, - get_internal_ips: bool = False) -> List[str]: +def get_node_ips( + cluster_yaml: str, + expected_num_nodes: int, + # TODO: remove this argument once we remove the legacy on-prem + # support. + handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle', + head_ip_max_attempts: int = 1, + worker_ip_max_attempts: int = 1, + get_internal_ips: bool = False) -> List[str]: """Returns the IPs of all nodes in the cluster, with head node at front. Args: cluster_yaml: Path to the cluster yaml. expected_num_nodes: Expected number of nodes in the cluster. - handle: Cloud VM Ray resource handle. It is only required for TPU VM or - on-prem clusters. + handle: Cloud VM Ray resource handle. It is only required for on-prem + clusters. head_ip_max_attempts: Max attempts to get head ip. worker_ip_max_attempts: Max attempts to get worker ips. get_internal_ips: Whether to get internal IPs. When False, it is still @@ -1557,34 +1552,21 @@ def get_node_ips(cluster_yaml: str, exceptions.FetchIPError: if we failed to get the IPs. e.reason is HEAD or WORKER. """ - # When ray up launches TPU VM Pod, Pod workers (except for the head) - # won't be connected to Ray cluster. Thus "ray get-worker-ips" - # won't work and we need to query the node IPs with gcloud as - # implmented in _get_tpu_vm_pod_ips. ray_config = common_utils.read_yaml(cluster_yaml) # Use the new provisioner for AWS. - if '.aws' in ray_config['provider']['module']: + provider_name = cluster_yaml_utils.get_provider_name(ray_config) + cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name) + assert cloud is not None, provider_name + + if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT: metadata = provision_lib.get_cluster_info( - 'aws', ray_config['provider']['region'], ray_config['cluster_name']) + provider_name, ray_config['provider']['region'], + ray_config['cluster_name'], ray_config['provider']) if len(metadata.instances) < expected_num_nodes: # Simulate the exception when Ray head node is not up. raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD) return metadata.get_feasible_ips(get_internal_ips) - use_tpu_vm = ray_config['provider'].get('_has_tpus', False) - if use_tpu_vm: - assert expected_num_nodes == 1, ( - 'TPU VM only supports single node for now.') - assert handle is not None, 'handle is required for TPU VM.' - try: - ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips) - except exceptions.CommandError as e: - raise exceptions.FetchIPError( - exceptions.FetchIPError.Reason.HEAD) from e - if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources): - raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD) - return ips - if get_internal_ips: with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: ray_config['provider']['use_internal_ips'] = True @@ -1658,74 +1640,6 @@ def get_node_ips(cluster_yaml: str, return head_ip_list + worker_ips -@timeline.event -def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any], - get_internal_ips: bool = False) -> List[str]: - """Returns the IPs of all TPU VM Pod workers using gcloud.""" - - cluster_name = ray_config['cluster_name'] - zone = ray_config['provider']['availability_zone'] - query_cmd = (f'gcloud compute tpus tpu-vm list --filter=' - f'"(labels.ray-cluster-name={cluster_name})" ' - f'--zone={zone} --format="value(name)"') - returncode, stdout, stderr = log_lib.run_with_log(query_cmd, - '/dev/null', - shell=True, - stream_logs=False, - require_outputs=True) - subprocess_utils.handle_returncode( - returncode, - query_cmd, - 'Failed to run gcloud to get TPU VM IDs.', - stderr=stdout + stderr) - if len(stdout) == 0: - logger.debug('No TPU VMs found with cluster name ' - f'{cluster_name} in zone {zone}.') - if len(stdout.splitlines()) > 1: - # Rare case, this could mean resource leakage. Hint user. - logger.warning('Found more than one TPU VM/Pod with the same cluster ' - f'name {cluster_name} in zone {zone}.') - - all_ips = [] - for tpu_id in stdout.splitlines(): - tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}' - f' --zone {zone} --format=json') - returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd, - os.devnull, - shell=True, - stream_logs=False, - require_outputs=True) - subprocess_utils.handle_returncode( - returncode, - tpuvm_cmd, - 'Failed to run gcloud tpu-vm describe.', - stderr=stdout + stderr) - - tpuvm_json = json.loads(stdout) - if tpuvm_json['state'] != 'READY': - # May be a leaked preempted resource, or terminated by user in the - # console, or still in the process of being created. - ux_utils.console_newline() - logger.debug(f'TPU VM {tpu_id} is in {tpuvm_json["state"]} ' - 'state. Skipping IP query... ' - 'Hint: make sure it is not leaked.') - continue - - ips = [] - for endpoint in tpuvm_json['networkEndpoints']: - # Note: if TPU VM is being preempted, its IP field may not exist. - # We use get() to avoid KeyError. - if get_internal_ips: - ip = endpoint.get('ipAddress', None) - else: - ip = endpoint['accessConfig'].get('externalIp', None) - if ip is not None: - ips.append(ip) - all_ips.extend(ips) - - return all_ips - - def check_network_connection(): # Tolerate 3 retries as it is observed that connections can fail. adapter = adapters.HTTPAdapter(max_retries=retry_lib.Retry(total=3)) @@ -1839,15 +1753,12 @@ def _query_cluster_status_via_cloud_api( # correctly yet. ray_config = common_utils.read_yaml(handle.cluster_yaml) provider_config = ray_config['provider'] - region = provider_config.get('region') or provider_config.get('location') - zone = ray_config['provider'].get('availability_zone') - kwargs = {} - if isinstance(handle.launched_resources.cloud, clouds.GCP): - kwargs['use_tpu_vm'] = ray_config['provider'].get('_has_tpus', False) # Query the cloud provider. # TODO(suquark): move implementations of more clouds here - if isinstance(handle.launched_resources.cloud, clouds.AWS): + cloud = handle.launched_resources.cloud + assert cloud is not None, handle + if cloud.STATUS_VERSION >= clouds.StatusVersion.SKYPILOT: cloud_name = repr(handle.launched_resources.cloud) try: node_status_dict = provision_lib.query_instances( @@ -1864,30 +1775,12 @@ def _query_cluster_status_via_cloud_api( f'status: {common_utils.format_exception(e, use_bracket=True)}' ) else: - node_statuses = handle.launched_resources.cloud.query_status( + region = provider_config.get('region') or provider_config.get( + 'location') + zone = ray_config['provider'].get('availability_zone') + node_statuses = cloud.query_status( cluster_name_on_cloud, - tag_filter_for_cluster(cluster_name_on_cloud), region, zone, - **kwargs) - # GCP does not clean up preempted TPU VMs. We remove it ourselves. - # TODO(wei-lin): handle multi-node cases. - # TODO(zhwu): this should be moved into the GCP class, after we refactor - # the cluster termination, as the preempted TPU VM should always be - # removed. - if kwargs.get('use_tpu_vm', False) and len(node_statuses) == 0: - logger.debug( - f'Terminating preempted TPU VM cluster {cluster_name_in_hint}') - backend = backends.CloudVmRayBackend() - # Do not use refresh cluster status during teardown, as that will - # cause infinite recursion by calling cluster status refresh - # again. - - # Post teardown cleanup be done later in this function, which will - # remove the cluster entry from the status table & the ssh config file. - backend.teardown_no_lock(handle, - terminate=True, - purge=False, - post_teardown_cleanup=False, - refresh_cluster_status=False) + tag_filter_for_cluster(cluster_name_on_cloud), region, zone) return node_statuses @@ -2036,9 +1929,8 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: # in the worst case we time out in the `ray status` SSH command # below. external_ips = handle.cached_external_ips - # This happens to a stopped TPU VM as we use gcloud to query the IP. - # Or user interrupt the `sky launch` process before the first time - # resources handle is written back to local database. + # This happens when user interrupt the `sky launch` process before + # the first time resources handle is written back to local database. # This is helpful when user interrupt after the provision is done # and before the skylet is restarted. After #2304 is merged, this # helps keep the cluster status to INIT after `sky status -r`, so @@ -2076,13 +1968,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: f'-- stdout --\n{output}\n-- stderr --\n{stderr}') ready_head, ready_workers = _count_healthy_nodes_from_ray(output) - - if ready_head + ready_workers == handle.launched_nodes: + total_nodes = handle.launched_nodes * handle.num_ips_per_node + if ready_head + ready_workers == total_nodes: return True raise RuntimeError( f'Refreshing status ({cluster_name!r}): ray status not showing ' f'all nodes ({ready_head + ready_workers}/' - f'{handle.launched_nodes}); output: {output}; stderr: {stderr}') + f'{total_nodes}); output: {output}; stderr: {stderr}') except exceptions.FetchIPError: logger.debug( f'Refreshing status ({cluster_name!r}) failed to get IPs.') diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index ac9e9e11d40..1e6085ec222 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1,5 +1,4 @@ """Backend: runs on cloud virtual machines, managed by Ray.""" -import ast import copy import enum import getpass @@ -39,8 +38,10 @@ from sky.backends import backend_utils from sky.backends import onprem_utils from sky.backends import wheel_utils +from sky.clouds.utils import gcp_utils from sky.data import data_utils from sky.data import storage as storage_lib +from sky.provision import common as provision_common from sky.provision import instance_setup from sky.provision import metadata_utils from sky.provision import provisioner @@ -57,7 +58,6 @@ from sky.utils import rich_utils from sky.utils import subprocess_utils from sky.utils import timeline -from sky.utils import tpu_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -620,299 +620,30 @@ class GangSchedulingStatus(enum.Enum): HEAD_FAILED = 2 -class RetryingVmProvisioner(object): - """A provisioner that retries different cloud/regions/zones.""" - - class ToProvisionConfig: - """Resources to be provisioned.""" - - def __init__( - self, - cluster_name: str, - resources: resources_lib.Resources, - num_nodes: int, - prev_cluster_status: Optional[status_lib.ClusterStatus], - prev_handle: Optional['CloudVmRayResourceHandle'], - ) -> None: - assert cluster_name is not None, 'cluster_name must be specified.' - self.cluster_name = cluster_name - self.resources = resources - self.num_nodes = num_nodes - self.prev_cluster_status = prev_cluster_status - self.prev_handle = prev_handle - - def __init__(self, - log_dir: str, - dag: 'dag.Dag', - optimize_target: 'optimizer.OptimizeTarget', - requested_features: Set[clouds.CloudImplementationFeatures], - local_wheel_path: pathlib.Path, - wheel_hash: str, - blocked_resources: Optional[Iterable[ - resources_lib.Resources]] = None): - self._blocked_resources: Set[resources_lib.Resources] = set() - if blocked_resources: - # blocked_resources is not None and not empty. - self._blocked_resources.update(blocked_resources) - - self.log_dir = os.path.expanduser(log_dir) - self._dag = dag - self._optimize_target = optimize_target - self._requested_features = requested_features - self._local_wheel_path = local_wheel_path - self._wheel_hash = wheel_hash - - def _update_blocklist_on_gcp_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): - style = colorama.Style - assert zones and len(zones) == 1, zones - zone = zones[0] - splits = stderr.split('\n') - exception_list = [s for s in splits if s.startswith('Exception: ')] - httperror_str = [ - s for s in splits - # GCP API errors - if s.startswith('googleapiclient.errors.HttpError: ') or - # 'SKYPILOT_ERROR_NO_NODES_LAUNCHED': skypilot's changes to the - # underlying provisioner provider; for errors prior to provisioning - # like VPC setup. - 'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' in s - ] - if len(exception_list) == 1: - # Parse structured response {'errors': [...]}. - exception_str = exception_list[0][len('Exception: '):] - try: - exception_dict = ast.literal_eval(exception_str) - except Exception as e: - if 'wait_ready timeout exceeded' in exception_str: - # This error seems to occur when the provisioning process - # went through partially (e.g., for spot, initial - # provisioning succeeded, but while waiting for ssh/setting - # up it got preempted). - logger.error('Got the following exception, continuing: ' - f'{exception_list[0]}') - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - return - raise RuntimeError( - f'Failed to parse exception: {exception_str}') from e - # TPU VM returns a different structured response. - if 'errors' not in exception_dict: - exception_dict = {'errors': [exception_dict]} - for error in exception_dict['errors']: - code = error['code'] - message = error['message'] - logger.warning(f'Got return code {code} in {zone.name} ' - f'{style.DIM}(message: {message})' - f'{style.RESET_ALL}') - if code == 'QUOTA_EXCEEDED': - if '\'GPUS_ALL_REGIONS\' exceeded' in message: - # Global quota. All regions in GCP will fail. Ex: - # Quota 'GPUS_ALL_REGIONS' exceeded. Limit: 1.0 - # globally. - # This skip is only correct if we implement "first - # retry the region/zone of an existing cluster with the - # same name" correctly. - self._blocked_resources.add( - launchable_resources.copy(region=None, zone=None)) - else: - # Per region. Ex: Quota 'CPUS' exceeded. Limit: 24.0 - # in region us-west1. - self._blocked_resources.add( - launchable_resources.copy(zone=None)) - elif code in [ - 'ZONE_RESOURCE_POOL_EXHAUSTED', - 'ZONE_RESOURCE_POOL_EXHAUSTED_WITH_DETAILS', - 'UNSUPPORTED_OPERATION' - ]: # Per zone. - # Return codes can be found at https://cloud.google.com/compute/docs/troubleshooting/troubleshooting-vm-creation # pylint: disable=line-too-long - # However, UNSUPPORTED_OPERATION is observed empirically - # when VM is preempted during creation. This seems to be - # not documented by GCP. - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - elif code in ['RESOURCE_NOT_READY']: - # This code is returned when the VM is still STOPPING. - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - elif code in [3, 8, 9]: - # Error code 3 means TPU is preempted during creation. - # Example: - # {'code': 3, 'message': 'Cloud TPU received a bad request. update is not supported while in state PREEMPTED [EID: 0x73013519f5b7feb2]'} # pylint: disable=line-too-long - # Error code 8 means TPU resources is out of - # capacity. Example: - # {'code': 8, 'message': 'There is no more capacity in the zone "europe-west4-a"; you can try in another zone where Cloud TPU Nodes are offered (see https://cloud.google.com/tpu/docs/regions) [EID: 0x1bc8f9d790be9142]'} # pylint: disable=line-too-long - # Error code 9 means TPU resources is insufficient reserved - # capacity. Example: - # {'code': 9, 'message': 'Insufficient reserved capacity. Contact customer support to increase your reservation. [EID: 0x2f8bc266e74261a]'} # pylint: disable=line-too-long - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - elif code == 'RESOURCE_NOT_FOUND': - # https://github.com/skypilot-org/skypilot/issues/1797 - # In the inner provision loop we have used retries to - # recover but failed. This indicates this zone is most - # likely out of capacity. The provision loop will terminate - # any potentially live VMs before moving onto the next - # zone. - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - else: - assert False, error - elif len(httperror_str) >= 1: - messages = '\n\t'.join(httperror_str) - logger.warning( - f'Got error(s):\n\t{style.DIM}{messages}{style.RESET_ALL}') - if ('SKYPILOT_ERROR_NO_NODES_LAUNCHED: No VPC with name ' - in stderr): - # User has specified a VPC that does not exist. On GCP, VPC is - # global. So we skip the entire cloud. - self._blocked_resources.add( - launchable_resources.copy(region=None, zone=None)) - elif ('SKYPILOT_ERROR_NO_NODES_LAUNCHED: No subnet for region ' - in stderr): - if (region.name == 'us-central2' and - launchable_resources.accelerators is not None and - any(acc.lower().startswith('tpu-v4') - for acc in launchable_resources.accelerators)): - # us-central2 is a TPU v4 only region. The subnet for - # this region may not exist when the user does not have - # the TPU v4 quota. We should skip this region. - logger.warning('Please check if you have TPU v4 quotas ' - f'in {region.name}.') - self._blocked_resources.add( - launchable_resources.copy(region=region.name, zone=None)) - elif ('Requested disk size cannot be smaller than the image size' - in httperror_str[0]): - logger.info('Skipping all regions due to disk size issue.') - self._blocked_resources.add( - launchable_resources.copy(region=None, zone=None)) - elif ('Policy update access denied.' in httperror_str[0] or - 'IAM_PERMISSION_DENIED' in httperror_str[0]): - logger.info( - 'Skipping all regions due to service account not ' - 'having the required permissions and the user ' - 'account does not have enough permission to ' - 'update it. Please contact your administrator and ' - 'check out: https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html\n' # pylint: disable=line-too-long - f'Details: {httperror_str[0]}') - self._blocked_resources.add( - launchable_resources.copy(region=None, zone=None)) - else: - # Parse HttpError for unauthorized regions. Example: - # googleapiclient.errors.HttpError: - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - else: - # No such structured error response found. - assert not exception_list, stderr - if 'Head node fetch timed out' in stderr: - # Example: click.exceptions.ClickException: Head node fetch - # timed out. Failed to create head node. - # This is a transient error, but we have retried in need_ray_up - # and failed. So we skip this zone. - logger.info('Got \'Head node fetch timed out\' in ' - f'{zone.name}.') - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - elif 'was not found' in stderr: - # Example: The resource - # 'projects//zones/zone/acceleratorTypes/nvidia-tesla-v100' - # was not found. - logger.warning(f'Got \'resource not found\' in {zone.name}.') - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - elif 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - else: - logger.info('====== stdout ======') - for s in stdout.split('\n'): - print(s) - logger.info('====== stderr ======') - for s in splits: - print(s) - - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') +def _add_to_blocked_resources(blocked_resources: Set['resources_lib.Resources'], + resources: 'resources_lib.Resources') -> None: + # If the resources is already blocked by blocked_resources, we don't need to + # add it again to avoid duplicated entries. + for r in blocked_resources: + if resources.should_be_blocked_by(r): + return + blocked_resources.add(resources) - def _update_blocklist_on_aws_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): - assert launchable_resources.is_launchable() - assert zones is not None, 'AWS should always have zones.' - style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - # 'An error occurred': boto3 errors - # 'SKYPILOT_ERROR_NO_NODES_LAUNCHED': skypilot's changes to the AWS - # node provider; for errors prior to provisioning like VPC - # setup. - if 'An error occurred' in s or - 'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' in s - ] - # Need to handle boto3 printing error but its retry succeeded: - # error occurred (Unsupported) .. not supported in your requested - # Availability Zone (us-west-2d)...retrying - # --> it automatically succeeded in another zone - # --> failed in [4/7] Running initialization commands due to user cmd - # In this case, we should error out. - head_node_up = any( - line.startswith('<1/1> Setting up head node') - for line in stdout_splits + stderr_splits) - if not errors or head_node_up: - if 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - # TODO: Got transient 'Failed to create security group' that goes - # away after a few minutes. Should we auto retry other regions, or - # let the user retry. - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') +class FailoverCloudErrorHandlerV1: + """Handles errors during provisioning and updates the blocked_resources. - # Fill in the zones field in the Region. - region_with_zones_list = clouds.AWS.regions_with_offering( - launchable_resources.instance_type, - launchable_resources.accelerators, - launchable_resources.use_spot, - region.name, - zone=None) - assert len(region_with_zones_list) == 1, region_with_zones_list - region_with_zones = region_with_zones_list[0] - assert region_with_zones.zones is not None, region_with_zones - if set(zones) == set(region_with_zones.zones): - # The underlying AWS NodeProvider will try all specified zones of a - # region. (Each boto3 request takes one zone.) - logger.warning(f'Got error(s) in all zones of {region.name}:') - else: - zones_str = ', '.join(z.name for z in zones) - logger.warning(f'Got error(s) in {zones_str}:') - messages = '\n\t'.join(errors) - logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - for zone in zones: - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) + Deprecated: Newly added cloud should use the FailoverCloudErrorHandlerV2, + which is more robust by parsing the errors raised by the cloud's API in a + more structured way, instead of directly based on the stdout and stderr. + """ - def _update_blocklist_on_azure_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): + @staticmethod + def _azure_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], stdout: str, + stderr: str): del zones # Unused. # The underlying ray autoscaler will try all zones of a region at once. style = colorama.Style @@ -932,7 +663,8 @@ def _update_blocklist_on_azure_error( # and failed. So we skip this region. logger.info('Got \'Head node fetch timed out\' in ' f'{region.name}.') - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(region=region.name)) elif 'rsync: command not found' in stderr: with ux_utils.print_exception_no_traceback(): @@ -951,15 +683,19 @@ def _update_blocklist_on_azure_error( messages = '\n\t'.join(errors) logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') if any('(ReadOnlyDisabledSubscription)' in s for s in errors): - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, resources_lib.Resources(cloud=clouds.Azure())) else: - self._blocked_resources.add(launchable_resources.copy(zone=None)) - - def _update_blocklist_on_lambda_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=None)) + + @staticmethod + def _lambda_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], stdout: str, + stderr: str): del zones # Unused. style = colorama.Style stdout_splits = stdout.split('\n') @@ -986,19 +722,22 @@ def _update_blocklist_on_lambda_error( logger.warning(f'Got error(s) in {region.name}:') messages = '\n\t'.join(errors) logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - self._blocked_resources.add(launchable_resources.copy(zone=None)) + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=None)) # Sometimes, LambdaCloudError will list available regions. for e in errors: if e.find('Regions with capacity available:') != -1: for r in clouds.Lambda.regions(): if e.find(r.name) == -1: - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(region=r.name, zone=None)) - def _update_blocklist_on_kubernetes_error( - self, launchable_resources: 'resources_lib.Resources', region, - zones, stdout, stderr): + @staticmethod + def _kubernetes_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region, zones, stdout, stderr): del zones # Unused. style = colorama.Style stdout_splits = stdout.split('\n') @@ -1022,11 +761,13 @@ def _update_blocklist_on_kubernetes_error( logger.warning(f'Got error(s) in {region.name}:') messages = '\n\t'.join(errors) logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - self._blocked_resources.add(launchable_resources.copy(zone=None)) + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=None)) - def _update_blocklist_on_scp_error( - self, launchable_resources: 'resources_lib.Resources', region, - zones, stdout, stderr): + @staticmethod + def _scp_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', region, + zones, stdout, stderr): del zones # Unused. style = colorama.Style stdout_splits = stdout.split('\n') @@ -1053,20 +794,24 @@ def _update_blocklist_on_scp_error( logger.warning(f'Got error(s) in {region.name}:') messages = '\n\t'.join(errors) logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - self._blocked_resources.add(launchable_resources.copy(zone=None)) + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=None)) # Sometimes, SCPError will list available regions. for e in errors: if e.find('Regions with capacity available:') != -1: for r in clouds.SCP.regions(): if e.find(r.name) == -1: - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(region=r.name, zone=None)) - def _update_blocklist_on_ibm_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): + @staticmethod + def _ibm_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], stdout: str, + stderr: str): style = colorama.Style stdout_splits = stdout.split('\n') @@ -1094,13 +839,15 @@ def _update_blocklist_on_ibm_error( logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') for zone in zones: # type: ignore[union-attr] - self._blocked_resources.add( - launchable_resources.copy(zone=zone.name)) - - def _update_blocklist_on_local_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=zone.name)) + + @staticmethod + def _local_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], stdout: str, + stderr: str): del zones # Unused. style = colorama.Style stdout_splits = stdout.split('\n') @@ -1127,14 +874,17 @@ def _update_blocklist_on_local_error( logger.warning('Got error(s) on local cluster:') messages = '\n\t'.join(errors) logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(region=region.name, zone=None)) # Apr, 2023 by Hysun(hysun.he@oracle.com): Added support for OCI - def _update_blocklist_on_oci_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): + @staticmethod + def _oci_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], stdout: str, + stderr: str): style = colorama.Style stdout_splits = stdout.split('\n') @@ -1168,11 +918,14 @@ def _update_blocklist_on_oci_error( if zones is not None: for zone in zones: - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(zone=zone.name)) - def _update_blocklist_on_error( - self, launchable_resources: 'resources_lib.Resources', + @staticmethod + def update_blocklist_on_error( + blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', region: 'clouds.Region', zones: Optional[List['clouds.Zone']], stdout: Optional[str], stderr: Optional[str]) -> bool: """Handles cloud-specific errors and updates the block list. @@ -1194,31 +947,24 @@ def _update_blocklist_on_error( assert stderr is None, stderr if zones is not None: for zone in zones: - self._blocked_resources.add( + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(zone=zone.name)) return False # definitely_no_nodes_launched assert stdout is not None and stderr is not None, (stdout, stderr) # TODO(zongheng): refactor into Cloud interface? - handlers = { - clouds.AWS: self._update_blocklist_on_aws_error, - clouds.Azure: self._update_blocklist_on_azure_error, - clouds.GCP: self._update_blocklist_on_gcp_error, - clouds.Lambda: self._update_blocklist_on_lambda_error, - clouds.IBM: self._update_blocklist_on_ibm_error, - clouds.SCP: self._update_blocklist_on_scp_error, - clouds.Local: self._update_blocklist_on_local_error, - clouds.Kubernetes: self._update_blocklist_on_kubernetes_error, - clouds.OCI: self._update_blocklist_on_oci_error, - } cloud = launchable_resources.cloud - cloud_type = type(cloud) - if cloud_type not in handlers: + handler = getattr(FailoverCloudErrorHandlerV1, + f'_{str(cloud).lower()}_handler') + if handler is None: raise NotImplementedError( f'Cloud {cloud} unknown, or has not added ' - 'support for parsing and handling provision failures.') - handler = handlers[cloud_type] - handler(launchable_resources, region, zones, stdout, stderr) + 'support for parsing and handling provision failures. ' + 'Please implement a handler in FailoverCloudErrorHandlerV1 when' + 'ray-autoscaler-based provisioner is used for the cloud.') + handler(blocked_resources, launchable_resources, region, zones, stdout, + stderr) stdout_splits = stdout.split('\n') stderr_splits = stderr.split('\n') @@ -1240,6 +986,236 @@ def _update_blocklist_on_error( return definitely_no_nodes_launched + +class FailoverCloudErrorHandlerV2: + """Handles errors during provisioning and updates the blocked_resources. + + This is a more robust version of FailoverCloudErrorHandlerV1. V2 parses + the errors raised by the cloud's API using the exception, instead of the + stdout and stderr. + """ + + @staticmethod + def _gcp_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', zones: List['clouds.Zone'], + err: Exception): + assert zones and len(zones) == 1, zones + zone = zones[0] + + if not isinstance(err, provision_common.ProvisionError): + logger.warning(f'{colorama.Style.DIM}Got an unparsed error: {err}; ' + f'blocking resources by its zone {zone.name}' + f'{colorama.Style.RESET_ALL}') + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=zone.name)) + return + errors = err.errors + + for e in errors: + code = e['code'] + message = e['message'] + + if code in ('QUOTA_EXCEEDED', 'quotaExceeded'): + if '\'GPUS_ALL_REGIONS\' exceeded' in message: + # Global quota. All regions in GCP will fail. Ex: + # Quota 'GPUS_ALL_REGIONS' exceeded. Limit: 1.0 + # globally. + # This skip is only correct if we implement "first + # retry the region/zone of an existing cluster with the + # same name" correctly. + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(region=None, zone=None)) + else: + # Per region. Ex: Quota 'CPUS' exceeded. Limit: 24.0 + # in region us-west1. + _add_to_blocked_resources( + blocked_resources, launchable_resources.copy(zone=None)) + elif code in [ + 'ZONE_RESOURCE_POOL_EXHAUSTED', + 'ZONE_RESOURCE_POOL_EXHAUSTED_WITH_DETAILS', + 'UNSUPPORTED_OPERATION', + 'insufficientCapacity', + ]: # Per zone. + # Return codes can be found at https://cloud.google.com/compute/docs/troubleshooting/troubleshooting-vm-creation # pylint: disable=line-too-long + # However, UNSUPPORTED_OPERATION is observed empirically + # when VM is preempted during creation. This seems to be + # not documented by GCP. + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + elif code in ['RESOURCE_NOT_READY']: + # This code is returned when the VM is still STOPPING. + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + elif code in [3, 8, 9]: + # Error code 3 means TPU is preempted during creation. + # Example: + # {'code': 3, 'message': 'Cloud TPU received a bad request. update is not supported while in state PREEMPTED [EID: 0x73013519f5b7feb2]'} # pylint: disable=line-too-long + # Error code 8 means TPU resources is out of + # capacity. Example: + # {'code': 8, 'message': 'There is no more capacity in the zone "europe-west4-a"; you can try in another zone where Cloud TPU Nodes are offered (see https://cloud.google.com/tpu/docs/regions) [EID: 0x1bc8f9d790be9142]'} # pylint: disable=line-too-long + # Error code 9 means TPU resources is insufficient reserved + # capacity. Example: + # {'code': 9, 'message': 'Insufficient reserved capacity. Contact customer support to increase your reservation. [EID: 0x2f8bc266e74261a]'} # pylint: disable=line-too-long + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + elif code == 'RESOURCE_NOT_FOUND': + # https://github.com/skypilot-org/skypilot/issues/1797 + # In the inner provision loop we have used retries to + # recover but failed. This indicates this zone is most + # likely out of capacity. The provision loop will terminate + # any potentially live VMs before moving onto the next + # zone. + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + elif code == 'VPC_NOT_FOUND': + # User has specified a VPC that does not exist. On GCP, VPC is + # global. So we skip the entire cloud. + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(region=None, zone=None)) + elif code == 'SUBNET_NOT_FOUND_FOR_VPC': + if (any(acc.lower().startswith('tpu-v4') + for acc in launchable_resources.accelerators.keys()) and + region.name == 'us-central2'): + # us-central2 is a TPU v4 only region. The subnet for + # this region may not exist when the user does not have + # the TPU v4 quota. We should skip this region. + logger.warning('Please check if you have TPU v4 quotas ' + f'in {region.name}.') + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(region=region.name, zone=None)) + elif code == 'type.googleapis.com/google.rpc.QuotaFailure': + # TPU VM pod specific error. + if 'in region' in message: + # Example: + # "Quota 'TPUV2sPreemptiblePodPerProjectPerRegionForTPUAPI' + # exhausted. Limit 32 in region europe-west4" + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(region=region.name, + zone=None)) + elif 'in zone' in message: + # Example: + # "Quota 'TPUV2sPreemptiblePodPerProjectPerZoneForTPUAPI' + # exhausted. Limit 32 in zone europe-west4-a" + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + + elif 'Requested disk size cannot be smaller than the image size' in message: + logger.info('Skipping all regions due to disk size issue.') + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(region=None, zone=None)) + elif 'Policy update access denied.' in message or code == 'IAM_PERMISSION_DENIED': + logger.info( + 'Skipping all regions due to service account not ' + 'having the required permissions and the user ' + 'account does not have enough permission to ' + 'update it. Please contact your administrator and ' + 'check out: https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html\n' # pylint: disable=line-too-long + f'Details: {message}') + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(region=None, zone=None)) + elif 'is not found or access is unauthorized' in message: + # Parse HttpError for unauthorized regions. Example: + # googleapiclient.errors.HttpError: + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + else: + logger.debug('Got unparsed error blocking resources by zone: ' + f'{e}.') + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + + @staticmethod + def _default_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], + error: Exception) -> None: + """Handles cloud-specific errors and updates the block list.""" + del region # Unused. + logger.debug( + f'Got error(s) in {launchable_resources.cloud}:' + f'{common_utils.format_exception(error, use_bracket=True)}') + if zones is None: + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=None)) + else: + for zone in zones: + _add_to_blocked_resources( + blocked_resources, + launchable_resources.copy(zone=zone.name)) + + @staticmethod + def update_blocklist_on_error( + blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + error: Exception) -> None: + """Handles cloud-specific errors and updates the block list.""" + cloud = launchable_resources.cloud + handler = getattr(FailoverCloudErrorHandlerV2, + f'_{str(cloud).lower()}_handler', + FailoverCloudErrorHandlerV2._default_handler) + handler(blocked_resources, launchable_resources, region, zones, error) + + +class RetryingVmProvisioner(object): + """A provisioner that retries different cloud/regions/zones.""" + + class ToProvisionConfig: + """Resources to be provisioned.""" + + def __init__( + self, + cluster_name: str, + resources: resources_lib.Resources, + num_nodes: int, + prev_cluster_status: Optional[status_lib.ClusterStatus], + prev_handle: Optional['CloudVmRayResourceHandle'], + ) -> None: + assert cluster_name is not None, 'cluster_name must be specified.' + self.cluster_name = cluster_name + self.resources = resources + self.num_nodes = num_nodes + self.prev_cluster_status = prev_cluster_status + self.prev_handle = prev_handle + + def __init__(self, + log_dir: str, + dag: 'dag.Dag', + optimize_target: 'optimizer.OptimizeTarget', + requested_features: Set[clouds.CloudImplementationFeatures], + local_wheel_path: pathlib.Path, + wheel_hash: str, + blocked_resources: Optional[Iterable[ + resources_lib.Resources]] = None): + self._blocked_resources: Set[resources_lib.Resources] = set() + if blocked_resources: + # blocked_resources is not None and not empty. + self._blocked_resources.update(blocked_resources) + + self.log_dir = os.path.expanduser(log_dir) + self._dag = dag + self._optimize_target = optimize_target + self._requested_features = requested_features + self._local_wheel_path = local_wheel_path + self._wheel_hash = wheel_hash + def _yield_zones( self, to_provision: resources_lib.Resources, num_nodes: int, cluster_name: str, @@ -1611,26 +1587,26 @@ def _retry_zones( global_user_state.set_owner_identity_for_cluster( cluster_name, cloud_user_identity) - if isinstance(to_provision.cloud, clouds.AWS): - # Use the new provisioner for AWS. + if (to_provision.cloud.PROVISIONER_VERSION == + clouds.ProvisionerVersion.SKYPILOT): # TODO (suquark): Gradually move the other clouds to # the new provisioner once they are ready. assert to_provision.region == region.name, (to_provision, region) num_nodes = handle.launched_nodes - provision_record = provisioner.bulk_provision( - to_provision.cloud, - region, - zones, - provisioner.ClusterName(cluster_name, - handle.cluster_name_on_cloud), - num_nodes=num_nodes, - cluster_yaml=handle.cluster_yaml, - is_prev_cluster_healthy=is_prev_cluster_healthy, - log_dir=self.log_dir) - # NOTE: We will handle the logic of '_ensure_cluster_ray_started' - # in 'provision_utils.post_provision_runtime_setup()' in the caller. - if provision_record is not None: + try: + provision_record = provisioner.bulk_provision( + to_provision.cloud, + region, + zones, + provisioner.ClusterName(cluster_name, + handle.cluster_name_on_cloud), + num_nodes=num_nodes, + cluster_yaml=handle.cluster_yaml, + is_prev_cluster_healthy=is_prev_cluster_healthy, + log_dir=self.log_dir) + # NOTE: We will handle the logic of '_ensure_cluster_ray_started' + # in 'provision_utils.post_provision_runtime_setup()' in the caller. resources_vars = ( to_provision.cloud.make_deploy_resources_variables( to_provision, handle.cluster_name_on_cloud, region, @@ -1638,37 +1614,35 @@ def _retry_zones( config_dict['provision_record'] = provision_record config_dict['resources_vars'] = resources_vars config_dict['handle'] = handle - return config_dict - - # NOTE: We try to cleanup the cluster even if the previous - # cluster does not exist. Also we are fast at - # cleaning up clusters now if there is no existing node.. - CloudVmRayBackend().post_teardown_cleanup( - handle, terminate=not is_prev_cluster_healthy) - # TODO(suquark): other clouds may have different zone - # blocking strategy. See '_update_blocklist_on_error' - # for details. - if zones is None: - self._blocked_resources.add(to_provision.copy(zone=None)) + tpu_name = config_dict.get('tpu_name') + if tpu_name is None: + return config_dict + # tpu_name will only be set when TPU node (not TPU VM) + # is required. + logger.info( + f'{colorama.Style.BRIGHT}Provisioning TPU node on ' + f'{to_provision.cloud} ' + f'{region.name}{colorama.Style.RESET_ALL}{zone_str}') + + success = self._try_provision_tpu(to_provision, config_dict) + if success: + return config_dict + raise RuntimeError('Failed to provision TPU node.') + except Exception as e: # pylint: disable=broad-except + # NOTE: We try to cleanup the cluster even if the previous + # cluster does not exist. Also we are fast at + # cleaning up clusters now if there is no existing node.. + CloudVmRayBackend().post_teardown_cleanup( + handle, terminate=not is_prev_cluster_healthy) + # TODO(suquark): other clouds may have different zone + # blocking strategy. See '_update_blocklist_on_error' + # for details. + FailoverCloudErrorHandlerV2.update_blocklist_on_error( + self._blocked_resources, to_provision, region, zones, e) continue - for zone in zones: - self._blocked_resources.add( - to_provision.copy(zone=zone.name)) - continue # NOTE: The code below in the loop should not be reachable # with the new provisioner. - tpu_name = config_dict.get('tpu_name') - if tpu_name is not None: - logger.info( - f'{colorama.Style.BRIGHT}Provisioning TPU on ' - f'{to_provision.cloud} ' - f'{region.name}{colorama.Style.RESET_ALL}{zone_str}') - - success = self._try_provision_tpu(to_provision, config_dict) - if not success: - continue - logging_info = { 'cluster_name': cluster_name, 'region_name': region.name, @@ -1730,15 +1704,21 @@ def _retry_zones( definitely_no_nodes_launched = False if status == GangSchedulingStatus.HEAD_FAILED: # ray up failed for the head node. - definitely_no_nodes_launched = self._update_blocklist_on_error( - to_provision, region, zones, stdout, stderr) + definitely_no_nodes_launched = FailoverCloudErrorHandlerV1.update_blocklist_on_error( + self._blocked_resources, to_provision, region, zones, + stdout, stderr) else: # gang scheduling failed. assert status == GangSchedulingStatus.GANG_FAILED, status # The stdout/stderr of ray up is not useful here, since # head node is successfully provisioned. - definitely_no_nodes_launched = self._update_blocklist_on_error( - to_provision, region, zones=zones, stdout=None, stderr=None) + definitely_no_nodes_launched = FailoverCloudErrorHandlerV1.update_blocklist_on_error( + self._blocked_resources, + to_provision, + region, + zones=zones, + stdout=None, + stderr=None) # GANG_FAILED means head is up, workers failed. assert definitely_no_nodes_launched is False, ( definitely_no_nodes_launched) @@ -1792,49 +1772,6 @@ def _retry_zones( raise exceptions.ResourcesUnavailableError( message, no_failover=is_prev_cluster_healthy) - def _tpu_pod_setup(self, cluster_yaml: str, - cluster_handle: 'backends.CloudVmRayResourceHandle'): - """Completes setup and start Ray cluster on TPU VM Pod nodes. - - This is a workaround for Ray Autoscaler where `ray up` does not - run setup or launch ray cluster on TPU VM Pod nodes. - """ - ssh_credentials = backend_utils.ssh_credential_from_yaml( - cluster_yaml, cluster_handle.docker_user) - # Always fetch the latest IPs, since GCP may change them without notice - cluster_handle.update_cluster_ips() - all_ips = cluster_handle.external_ips() - num_tpu_devices = tpu_utils.get_num_tpu_devices( - cluster_handle.launched_resources) - if all_ips is None or len(all_ips) != num_tpu_devices: - raise RuntimeError( - f'Nodes IPs: {all_ips} does not' - f'match number of TPU devices: {num_tpu_devices}.') - - # Get the private IP of head node for connecting Ray cluster. - head_runner = command_runner.SSHCommandRunner( - all_ips[0], port=cluster_handle.head_ssh_port, **ssh_credentials) - cmd_str = 'python3 -c \"import ray; print(ray._private.services.get_node_ip_address())\"' # pylint: disable=line-too-long - rc, stdout, stderr = head_runner.run(cmd_str, - require_outputs=True, - stream_logs=False) - subprocess_utils.handle_returncode( - rc, - cmd_str, - 'Failed to get private IP from head node.', - stderr=stdout + stderr) - head_ip_private = stdout.strip() - - ray_config = common_utils.read_yaml(cluster_yaml) - worker_start_ray_commands = [f'echo "export RAY_HEAD_IP={head_ip_private}" >> ~/.bashrc && source ~/.bashrc'] # pylint: disable=line-too-long - worker_start_ray_commands += ray_config['worker_start_ray_commands'] - - # Setup TPU VM Pod workers and launch Ray cluster. - onprem_utils.do_filemounts_and_setup_on_local_workers( - cluster_yaml, - worker_ips=all_ips[1:], - extra_setup_cmds=worker_start_ray_commands) - # TODO(suquark): Deprecate this method # once the `provision_utils` is adopted for all the clouds. @timeline.event @@ -1948,43 +1885,6 @@ def need_ray_up( 'timeout.') return True - if isinstance(to_provision_cloud, clouds.GCP): - if ('Quota exceeded for quota metric \'List requests\' and ' - 'limit \'List requests per minute\'' in stderr): - logger.info( - 'Retrying due to list request rate limit exceeded.') - return True - - # https://github.com/skypilot-org/skypilot/issues/2666 - if ('Head node fetch timed out. Failed to create head node.' - in stderr): - logger.info( - 'Retrying head node provisioning due to head fetching ' - 'timeout.') - return True - - # https://github.com/skypilot-org/skypilot/issues/1797 - # "The resource 'projects/xxx/zones/us-central1-b/instances/ray-yyy-head--compute' was not found" # pylint: disable=line-too-long - pattern = (r'\'code\': \'RESOURCE_NOT_FOUND\'.*The resource' - r'.*instances\/.*-compute\' was not found') - result = re.search(pattern, stderr) - if result is not None: - # Retry. Unlikely will succeed if it's due to no capacity. - logger.info('Retrying due to the possibly transient ' - 'RESOURCE_NOT_FOUND error.') - logger.debug(f'-- Stderr --\n{stderr}\n ----') - return True - - # "The resource 'projects/skypilot-375900/regions/us-central1/subnetworks/default' is not ready". Details: "[{'message': "The resource 'projects/xxx/regions/us-central1/subnetworks/default' is not ready", 'domain': 'global', 'reason': 'resourceNotReady'}]"> # pylint: disable=line-too-long - pattern = (r'is not ready(.*)\'reason\': \'resourceNotReady\'') - result = re.search(pattern, stderr) - if result is not None: - # Retry. Unlikely will succeed if it's due to no capacity. - logger.info('Retrying due to the possibly transient ' - 'resourceNotReady error.') - logger.debug(f'-- Stderr --\n{stderr}\n ----') - return True - if isinstance(to_provision_cloud, clouds.Lambda): if 'Your API requests are being rate limited.' in stderr: logger.info( @@ -2033,12 +1933,6 @@ def need_ray_up( if returncode != 0: return GangSchedulingStatus.HEAD_FAILED, stdout, stderr, None, None - resources = cluster_handle.launched_resources - if tpu_utils.is_tpu_vm_pod(resources): - logger.info(f'{style.BRIGHT}Setting up TPU VM Pod workers...' - f'{style.RESET_ALL}') - self._tpu_pod_setup(cluster_config_file, cluster_handle) - # Only 1 node or head node provisioning failure. if cluster_handle.launched_nodes == 1 and returncode == 0: # Optimization: Try parse head ip from 'ray up' stdout. @@ -2228,7 +2122,8 @@ def provision_with_retries( # The exceptions above should be applicable to the whole # cloud, so we do add the cloud to the blocked resources. logger.warning(common_utils.format_exception(e)) - self._blocked_resources.add( + _add_to_blocked_resources( + self._blocked_resources, resources_lib.Resources(cloud=to_provision.cloud)) failover_history.append(e) except exceptions.ResourcesUnavailableError as e: @@ -2258,7 +2153,7 @@ def provision_with_retries( if prev_cluster_status is None: # Add failed resources to the blocklist, only when it # is in fallback mode. - self._blocked_resources.add(to_provision) + _add_to_blocked_resources(self._blocked_resources, to_provision) else: # If we reach here, it means that the existing cluster must have # a previous status of INIT, because other statuses (UP, @@ -2449,8 +2344,9 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None: """ del max_attempts # Unused. head_ssh_port = 22 - self.stable_ssh_ports = ([head_ssh_port] + [22] * - (self.num_node_ips - 1)) + self.stable_ssh_ports = ( + [head_ssh_port] + [22] * + (self.num_ips_per_node * self.launched_nodes - 1)) def update_cluster_ips( self, @@ -2487,7 +2383,8 @@ def update_cluster_ips( """ def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: - return (ips is not None and len(ips) == self.num_node_ips and + return (ips is not None and + len(ips) == self.num_ips_per_node * self.launched_nodes and all(ip is not None for ip in ips)) use_internal_ips = self._use_internal_ips() @@ -2636,13 +2533,13 @@ def head_ssh_port(self): return None @property - def num_node_ips(self) -> int: - """Returns number of IPs of the cluster, correctly handling TPU Pod.""" - is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(self.launched_resources) + def num_ips_per_node(self) -> int: + """Returns number of IPs per node in the cluster, handling TPU Pod.""" + is_tpu_vm_pod = gcp_utils.is_tpu_vm_pod(self.launched_resources) if is_tpu_vm_pod: - num_ips = tpu_utils.get_num_tpu_devices(self.launched_resources) + num_ips = gcp_utils.get_num_tpu_devices(self.launched_resources) else: - num_ips = self.launched_nodes + num_ips = 1 return num_ips def __setstate__(self, state): @@ -2989,6 +2886,10 @@ def _provision( # Update launched resources. handle.launched_resources = handle.launched_resources.copy( region=provision_record.region, zone=provision_record.zone) + + if 'tpu_name' in config_dict: + self._set_tpu_name(handle, config_dict['tpu_name']) + self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, prev_cluster_status, handle.external_ips(), @@ -3007,9 +2908,6 @@ def _provision( if 'docker' in config: handle.setup_docker_user(cluster_config_file) - if 'tpu_name' in config_dict: - self._set_tpu_name(handle, config_dict['tpu_name']) - # Get actual zone info and save it into handle. # NOTE: querying zones is expensive, observed 1node GCP >=4s. zone = handle.launched_resources.zone @@ -3569,9 +3467,9 @@ def _execute( job_id = self._add_job(handle, task_copy.name, resources_str) - is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(handle.launched_resources) + num_actual_nodes = task.num_nodes * handle.num_ips_per_node # Case: task_lib.Task(run, num_nodes=N) or TPU VM Pods - if task_copy.num_nodes > 1 or is_tpu_vm_pod: + if num_actual_nodes > 1: self._execute_task_n_nodes(handle, task_copy, job_id, detach_run) else: # Case: task_lib.Task(run, num_nodes=1) @@ -3975,8 +3873,10 @@ def teardown_no_lock(self, stdout = '' stderr = '' - # Use the new provisioner for AWS. - if isinstance(cloud, (clouds.AWS, clouds.GCP)): + if (cloud.PROVISIONER_VERSION >= + clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR): + logger.debug(f'Provisioner version: {cloud.PROVISIONER_VERSION} ' + 'using new provisioner for teardown.') # Stop the ray autoscaler first to avoid the head node trying to # re-launch the worker nodes, during the termination of the # cluster. @@ -4241,7 +4141,9 @@ def post_teardown_cleanup(self, # provision_lib.supports(cloud, 'cleanup_ports') # so that our backend do not need to know the specific details # for different clouds. - if isinstance(cloud, (clouds.AWS, clouds.GCP, clouds.Azure)): + if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion. + RAY_PROVISIONER_SKYPILOT_TERMINATOR or + isinstance(cloud, clouds.Azure)): provision_lib.cleanup_ports(repr(cloud), cluster_name_on_cloud, config['provider']) @@ -4541,7 +4443,8 @@ def _setup_tpu_name_on_node( '>> ~/.bashrc || echo "TPU_NAME already set"') returncode = runner.run(cmd, log_path=os.path.join( - self.log_dir, 'tpu_setup.log')) + self.log_dir, 'tpu_setup.log'), + stream_logs=False) subprocess_utils.handle_returncode( returncode, cmd, 'Failed to set TPU_NAME on node.') @@ -4908,14 +4811,8 @@ def _execute_task_n_nodes(self, handle: CloudVmRayResourceHandle, internal_ips = handle.internal_ips() assert internal_ips is not None, 'internal_ips is not cached in handle' - # If TPU VM Pods is used, #num_nodes should be #num_tpu_devices - is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(handle.launched_resources) - if is_tpu_vm_pod: - num_actual_nodes = tpu_utils.get_num_tpu_devices( - handle.launched_resources) - else: - num_actual_nodes = task.num_nodes - assert isinstance(num_actual_nodes, int), num_actual_nodes + # If TPU VM Pods is used, #num_nodes should be num_nodes * num_node_ips + num_actual_nodes = task.num_nodes * handle.num_ips_per_node codegen = RayCodeGen() is_local = isinstance(handle.launched_resources.cloud, clouds.Local) diff --git a/sky/cli.py b/sky/cli.py index 88aef4d7d22..f906502a506 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -1147,13 +1147,6 @@ def _make_task_or_dag_from_entrypoint_with_overrides( if name is not None: task.name = name task.update_envs(env) - # TODO(wei-lin): move this validation into Python API. - for resource in task.resources: - if resource.accelerators is not None: - acc, _ = list(resource.accelerators.items())[0] - if acc.startswith('tpu-') and task.num_nodes > 1: - raise ValueError('Multi-node TPU cluster is not supported. ' - f'Got num_nodes={task.num_nodes}.') return task diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py index d3d8aab0d9f..36d843e267e 100644 --- a/sky/clouds/__init__.py +++ b/sky/clouds/__init__.py @@ -1,7 +1,9 @@ """Clouds in Sky.""" from sky.clouds.cloud import Cloud from sky.clouds.cloud import CloudImplementationFeatures +from sky.clouds.cloud import ProvisionerVersion from sky.clouds.cloud import Region +from sky.clouds.cloud import StatusVersion from sky.clouds.cloud import Zone from sky.clouds.cloud_registry import CLOUD_REGISTRY @@ -32,4 +34,6 @@ 'Region', 'Zone', 'CLOUD_REGISTRY', + 'ProvisionerVersion', + 'StatusVersion', ] diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 374ba48788d..55cfbf9c2e1 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -112,6 +112,9 @@ class AWS(clouds.Cloud): 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html' # pylint: disable=line-too-long ) + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT + @classmethod def _unsupported_features_for_resources( cls, resources: 'resources_lib.Resources' diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index b89c39ff126..c9797767b70 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -60,6 +60,8 @@ class Azure(clouds.Cloud): _INDENT_PREFIX = ' ' * 4 + PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_AUTOSCALER + @classmethod def _unsupported_features_for_resources( cls, resources: 'resources.Resources' diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 126c3fd5ddb..75103a981dd 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -61,12 +61,47 @@ class Zone(collections.namedtuple('Zone', ['name'])): region: Region +class ProvisionerVersion(enum.Enum): + """The version of the provisioner. + + 1: [Deprecated] ray node provider based implementation + 2: [Deprecated] ray node provider for provisioning and SkyPilot provisioner + for stopping and termination + 3: SkyPilot provisioner for both provisioning and stopping + """ + RAY_AUTOSCALER = 1 + RAY_PROVISIONER_SKYPILOT_TERMINATOR = 2 + SKYPILOT = 3 + + def __ge__(self, other): + return self.value >= other.value + + +class StatusVersion(enum.Enum): + """The version of the status query. + + 1: [Deprecated] cloud-CLI based implementation + 2: SkyPilot provisioner based implementation + """ + CLOUD_CLI = 1 + SKYPILOT = 2 + + def __ge__(self, other): + return self.value >= other.value + + class Cloud: """A cloud provider.""" _REPR = '' _DEFAULT_DISK_TIER = 'medium' + # The version of provisioner and status query. This is used to determine + # the code path to use for each cloud in the backend. + # NOTE: new clouds being added should use the latest version, i.e. SKYPILOT. + PROVISIONER_VERSION = ProvisionerVersion.RAY_AUTOSCALER + STATUS_VERSION = StatusVersion.CLOUD_CLI + @classmethod def max_cluster_name_length(cls) -> Optional[int]: """Returns the maximum length limit of a cluster name. @@ -696,3 +731,9 @@ def delete_image(cls, image_id: str, region: Optional[str]) -> None: def __repr__(self): return self._REPR + + def __getstate__(self): + state = self.__dict__.copy() + state.pop('PROVISIONER_VERSION', None) + state.pop('STATUS_VERSION', None) + return state diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index ed34cb4c055..3fc89cb79fb 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -13,18 +13,16 @@ from sky import clouds from sky import exceptions from sky import sky_logging -from sky import status_lib from sky.adaptors import gcp from sky.clouds import service_catalog from sky.clouds.utils import gcp_utils -from sky.skylet import log_lib from sky.utils import common_utils from sky.utils import subprocess_utils -from sky.utils import tpu_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: from sky import resources + from sky import status_lib logger = sky_logging.init_logger(__name__) @@ -162,11 +160,14 @@ class GCP(clouds.Cloud): 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long ) + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT + @classmethod def _unsupported_features_for_resources( cls, resources: 'resources.Resources' ) -> Dict[clouds.CloudImplementationFeatures, str]: - if tpu_utils.is_tpu_vm_pod(resources): + if gcp_utils.is_tpu_vm_pod(resources): return { clouds.CloudImplementationFeatures.STOP: ( 'TPU VM pods cannot be stopped. Please refer to: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources' @@ -491,7 +492,7 @@ def _get_feasible_launchable_resources( assert len(resources.accelerators.items() ) == 1, 'cannot handle more than one accelerator candidates.' acc, acc_count = list(resources.accelerators.items())[0] - use_tpu_vm = tpu_utils.is_tpu_vm(resources) + use_tpu_vm = gcp_utils.is_tpu_vm(resources) # For TPU VMs, the instance type is fixed to 'TPU-VM'. However, we still # need to call the below function to get the fuzzy candidate list. @@ -815,7 +816,7 @@ def need_cleanup_after_preemption(self, # you must delete it and create a new one ..." # See: https://cloud.google.com/tpu/docs/preemptible#tpu-vm - return tpu_utils.is_tpu_vm(resources) + return gcp_utils.is_tpu_vm(resources) @classmethod def get_project_id(cls, dryrun: bool = False) -> str: @@ -944,75 +945,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], region: Optional[str], zone: Optional[str], **kwargs) -> List['status_lib.ClusterStatus']: """Query the status of a cluster.""" - del region # unused - - use_tpu_vm = kwargs.pop('use_tpu_vm', True) - - label_filter_str = cls._label_filter_str(tag_filters) - if use_tpu_vm: - # TPU VM's state definition is different from compute VM - # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State # pylint: disable=line-too-long - status_map = { - 'CREATING': status_lib.ClusterStatus.INIT, - 'STARTING': status_lib.ClusterStatus.INIT, - 'RESTARTING': status_lib.ClusterStatus.INIT, - 'READY': status_lib.ClusterStatus.UP, - 'REPAIRING': status_lib.ClusterStatus.INIT, - # 'STOPPED' in GCP TPU VM means stopped, with disk preserved. - 'STOPPING': status_lib.ClusterStatus.STOPPED, - 'STOPPED': status_lib.ClusterStatus.STOPPED, - 'DELETING': None, - 'PREEMPTED': None, - } - tpu_utils.check_gcp_cli_include_tpu_vm() - query_cmd = ('gcloud compute tpus tpu-vm list ' - f'--zone {zone} ' - f'--filter="({label_filter_str})" ' - '--format="value(state)"') - else: - # Ref: https://cloud.google.com/compute/docs/instances/instance-life-cycle - status_map = { - 'PROVISIONING': status_lib.ClusterStatus.INIT, - 'STAGING': status_lib.ClusterStatus.INIT, - 'RUNNING': status_lib.ClusterStatus.UP, - 'REPAIRING': status_lib.ClusterStatus.INIT, - # 'TERMINATED' in GCP means stopped, with disk preserved. - 'STOPPING': status_lib.ClusterStatus.STOPPED, - 'TERMINATED': status_lib.ClusterStatus.STOPPED, - # 'SUSPENDED' in GCP means stopped, with disk and OS memory - # preserved. - 'SUSPENDING': status_lib.ClusterStatus.STOPPED, - 'SUSPENDED': status_lib.ClusterStatus.STOPPED, - } - # TODO(zhwu): The status of the TPU attached to the cluster should - # also be checked, since TPUs are not part of the VMs. - query_cmd = ('gcloud compute instances list ' - f'--filter="({label_filter_str})" ' - '--format="value(status)"') - returncode, stdout, stderr = log_lib.run_with_log(query_cmd, - '/dev/null', - require_outputs=True, - shell=True) - logger.debug(f'{query_cmd} returned {returncode}.\n' - '**** STDOUT ****\n' - f'{stdout}\n' - '**** STDERR ****\n' - f'{stderr}') - - if returncode != 0: - with ux_utils.print_exception_no_traceback(): - raise exceptions.ClusterStatusFetchingError( - f'Failed to query GCP cluster {name!r} status: ' - f'{stdout + stderr}') - - status_list = [] - for line in stdout.splitlines(): - status = status_map.get(line.strip()) - if status is None: - continue - status_list.append(status) - - return status_list + # TODO(suquark): deprecate this method + assert False, 'This code path should not be used.' @classmethod def create_image_from_cluster(cls, cluster_name: str, diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py index 0b8256e4206..98f18d6eba8 100644 --- a/sky/clouds/utils/gcp_utils.py +++ b/sky/clouds/utils/gcp_utils.py @@ -1,23 +1,59 @@ """Utility functions for GCP. -The functions that are used to access GCP APIs. We have the reservation-related -functions here, so that the cache of the reservations can be shared across -multiple clouds.GCP() objects. +The functions that are used to access GCP APIs and TPU VM. We have the +reservation-related functions here, so that the cache of the reservations can be +shared across multiple clouds.GCP() objects. """ import dataclasses import json import time -from typing import List, Set +import typing +from typing import List, Optional, Set import cachetools from sky import sky_logging from sky.utils import subprocess_utils +if typing.TYPE_CHECKING: + from sky import resources as resources_lib + logger = sky_logging.init_logger(__name__) +def is_tpu(resources: Optional['resources_lib.Resources']) -> bool: + if resources is None or resources.accelerators is None: + return False + acc, _ = list(resources.accelerators.items())[0] + return acc.startswith('tpu') + + +def is_tpu_vm(resources: Optional['resources_lib.Resources']) -> bool: + if not is_tpu(resources): + return False + assert resources is not None + if resources.accelerator_args is None: + return True + return resources.accelerator_args.get('tpu_vm', True) + + +def is_tpu_vm_pod(resources: Optional['resources_lib.Resources']) -> bool: + if not is_tpu_vm(resources): + return False + assert resources is not None + acc, _ = list(resources.accelerators.items())[0] + return not acc.endswith('-8') + + +def get_num_tpu_devices(resources: Optional['resources_lib.Resources']) -> int: + if resources is None or not is_tpu(resources): + raise ValueError('resources must be a valid TPU resource.') + acc, _ = list(resources.accelerators.items())[0] + num_tpu_devices = int(int(acc.split('-')[2]) / 8) + return num_tpu_devices + + @dataclasses.dataclass class SpecificReservation: count: int diff --git a/sky/core.py b/sky/core.py index 76faefbe8bf..d7174e1522e 100644 --- a/sky/core.py +++ b/sky/core.py @@ -456,7 +456,7 @@ def autostop( f'{backend.__class__.__name__!r} is not supported.') # Check autostop is implemented for cloud cloud = handle.launched_resources.cloud - if not down and idle_minutes >= 0: + if not down and not is_cancel: try: cloud.check_features_are_supported( handle.launched_resources, diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index a0c717ed281..8af83f08b85 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -137,7 +137,10 @@ def wait_instances(provider_name: str, region: str, cluster_name_on_cloud: str, @_route_to_cloud_impl -def get_cluster_info(provider_name: str, region: str, - cluster_name_on_cloud: str) -> common.ClusterInfo: +def get_cluster_info( + provider_name: str, + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """Get the metadata of instances in a cluster.""" raise NotImplementedError diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index a50ee46d2ab..cc8cb583fee 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -769,9 +769,12 @@ def wait_instances(region: str, cluster_name_on_cloud: str, waiter.wait(WaiterConfig={'Delay': 5, 'MaxAttempts': 120}, Filters=filters) -def get_cluster_info(region: str, - cluster_name_on_cloud: str) -> common.ClusterInfo: +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" + del provider_config # unused ec2 = _default_ec2_resource(region) filters = [ { @@ -791,12 +794,14 @@ def get_cluster_info(region: str, tags = [(t['Key'], t['Value']) for t in inst.tags] # sort tags by key to support deterministic unit test stubbing tags.sort(key=lambda x: x[0]) - instances[inst.id] = common.InstanceInfo( - instance_id=inst.id, - internal_ip=inst.private_ip_address, - external_ip=inst.public_ip_address, - tags=dict(tags), - ) + instances[inst.id] = [ + common.InstanceInfo( + instance_id=inst.id, + internal_ip=inst.private_ip_address, + external_ip=inst.public_ip_address, + tags=dict(tags), + ) + ] instances = dict(sorted(instances.items(), key=lambda x: x[0])) return common.ClusterInfo( instances=instances, diff --git a/sky/provision/common.py b/sky/provision/common.py index 6e796666f06..4cd61c1d9b7 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -12,6 +12,11 @@ InstanceId = str +class ProvisionError(RuntimeError): + """Exception for provisioning.""" + errors: List[Dict[str, str]] + + @dataclasses.dataclass class ProvisionConfig: """Configuration for provisioning.""" @@ -82,12 +87,35 @@ def get_feasible_ip(self) -> str: @dataclasses.dataclass class ClusterInfo: """Cluster Information.""" - instances: Dict[InstanceId, InstanceInfo] + instances: Dict[InstanceId, List[InstanceInfo]] # The unique identifier of the head instance, i.e., the # `instance_info.instance_id` of the head node. head_instance_id: Optional[InstanceId] docker_user: Optional[str] = None + @property + def num_instances(self) -> int: + """Get the number of instances in the cluster.""" + return sum(len(instances) for instances in self.instances.values()) + + def get_head_instance(self) -> Optional[InstanceInfo]: + """Get the instance metadata of the head node""" + if self.head_instance_id is None: + return None + if self.head_instance_id not in self.instances: + raise ValueError('Head instance ID not in the cluster metadata.') + return self.instances[self.head_instance_id][0] + + def get_worker_instances(self) -> List[InstanceInfo]: + """Get all worker instances.""" + worker_instances = [] + for inst_id, instances in self.instances.items(): + if inst_id == self.head_instance_id: + worker_instances.extend(instances[1:]) + else: + worker_instances.extend(instances) + return worker_instances + def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: """Get IP tuples of all instances. Make sure that list always starts with head node IP, if head node exists. @@ -95,13 +123,15 @@ def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: Returns: A list of tuples (internal_ip, external_ip) of all instances. """ - head_node_ip, other_ips = [], [] - for instance in self.instances.values(): + head_node = self.get_head_instance() + if head_node is None: + head_node_ip = [] + else: + head_node_ip = [(head_node.internal_ip, head_node.external_ip)] + other_ips = [] + for instance in self.get_worker_instances(): pair = (instance.internal_ip, instance.external_ip) - if instance.instance_id == self.head_instance_id: - head_node_ip.append(pair) - else: - other_ips.append(pair) + other_ips.append(pair) return head_node_ip + other_ips def has_external_ips(self) -> bool: @@ -135,11 +165,3 @@ def _get_ips(self, use_internal_ips: bool) -> List[str]: def get_feasible_ips(self, force_internal_ips: bool = False) -> List[str]: """Get external IPs if they exist, otherwise get internal ones.""" return self._get_ips(not self.has_external_ips() or force_internal_ips) - - def get_head_instance(self) -> Optional[InstanceInfo]: - """Get the instance metadata of the head node""" - if self.head_instance_id is None: - return None - if self.head_instance_id not in self.instances: - raise ValueError('Head instance ID not in the cluster metadata.') - return self.instances[self.head_instance_id] diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index 9faaab37088..0d24a577690 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -1,6 +1,11 @@ """GCP provisioner for SkyPilot.""" +from sky.provision.gcp.config import bootstrap_instances from sky.provision.gcp.instance import cleanup_ports +from sky.provision.gcp.instance import get_cluster_info from sky.provision.gcp.instance import open_ports +from sky.provision.gcp.instance import query_instances +from sky.provision.gcp.instance import run_instances from sky.provision.gcp.instance import stop_instances from sky.provision.gcp.instance import terminate_instances +from sky.provision.gcp.instance import wait_instances diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py new file mode 100644 index 00000000000..31a1e4b9031 --- /dev/null +++ b/sky/provision/gcp/config.py @@ -0,0 +1,797 @@ +"""GCP configuration bootstrapping.""" +import copy +import logging +import time +import typing +from typing import Any, Dict, List, Set, Tuple + +from sky.adaptors import gcp +from sky.provision import common +from sky.provision.gcp import constants +from sky.provision.gcp import instance_utils + +logger = logging.getLogger(__name__) + +if typing.TYPE_CHECKING: + import google.cloud + + +def _skypilot_log_error_and_exit_for_failover(error_code: str, + error_msg: str) -> None: + """Logs an message then raises a specific RuntimeError to trigger failover. + Mainly used for handling VPC/subnet errors before nodes are launched. + """ + # NOTE: keep. The backend looks for this to know no nodes are launched. + prefix = 'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' + error = common.ProvisionError(prefix + error_msg) + error.errors = [{ + 'code': error_code, + 'domain': 'bootstrap_instance', + 'message': error_msg, + }] + raise error + + +def wait_for_crm_operation(operation, crm): + """Poll for cloud resource manager operation until finished.""" + logger.info('wait_for_crm_operation: ' + 'Waiting for operation {} to finish...'.format(operation)) + + for _ in range(constants.MAX_POLLS): + result = crm.operations().get(name=operation['name']).execute() + if 'error' in result: + raise Exception(result['error']) + + if 'done' in result and result['done']: + logger.info('wait_for_crm_operation: Operation done.') + break + + time.sleep(constants.POLL_INTERVAL) + + return result + + +def wait_for_compute_global_operation(project_name, operation, compute): + """Poll for global compute operation until finished.""" + logger.info('wait_for_compute_global_operation: ' + 'Waiting for operation {} to finish...'.format( + operation['name'])) + + for _ in range(constants.MAX_POLLS): + result = (compute.globalOperations().get( + project=project_name, + operation=operation['name'], + ).execute()) + if 'error' in result: + raise Exception(result['error']) + + if result['status'] == 'DONE': + logger.info('wait_for_compute_global_operation: Operation done.') + break + + time.sleep(constants.POLL_INTERVAL) + + return result + + +def _create_crm(gcp_credentials=None): + return gcp.build('cloudresourcemanager', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) + + +def _create_iam(gcp_credentials=None): + return gcp.build('iam', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) + + +def _create_compute(gcp_credentials=None): + return gcp.build('compute', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) + + +def _create_tpu(gcp_credentials=None): + return gcp.build( + 'tpu', + constants.TPU_VERSION, + credentials=gcp_credentials, + cache_discovery=False, + discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest', + ) + + +def construct_clients_from_provider_config(provider_config): + """Attempt to fetch and parse the JSON GCP credentials. + + tpu resource (the last element of the tuple) will be None if + `_has_tpus` in provider config is not set or False. + """ + gcp_credentials = provider_config.get('gcp_credentials') + if gcp_credentials is None: + logger.debug('gcp_credentials not found in cluster yaml file. ' + 'Falling back to GOOGLE_APPLICATION_CREDENTIALS ' + 'environment variable.') + tpu_resource = (_create_tpu() if provider_config.get( + constants.HAS_TPU_PROVIDER_FIELD, False) else None) + # If gcp_credentials is None, then discovery.build will search for + # credentials in the local environment. + return _create_crm(), _create_iam(), _create_compute(), tpu_resource + + # Note: The following code has not been used yet, as we will never set + # `gcp_credentials` in provider_config. + # It will only be used when we allow users to specify their own credeitals. + assert ('type' in gcp_credentials + ), 'gcp_credentials cluster yaml field missing "type" field.' + assert ('credentials' in gcp_credentials + ), 'gcp_credentials cluster yaml field missing "credentials" field.' + + cred_type = gcp_credentials['type'] + credentials_field = gcp_credentials['credentials'] + credentials = gcp.get_credentials(cred_type, credentials_field) + + tpu_resource = (_create_tpu(credentials) if provider_config.get( + constants.HAS_TPU_PROVIDER_FIELD, False) else None) + + return ( + _create_crm(credentials), + _create_iam(credentials), + _create_compute(credentials), + tpu_resource, + ) + + +def bootstrap_instances( + region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: + # Check if we have any TPUs defined, and if so, + # insert that information into the provider config + if instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU: + config.provider_config[constants.HAS_TPU_PROVIDER_FIELD] = True + + crm, iam, compute, _ = construct_clients_from_provider_config( + config.provider_config) + + # Setup a Google Cloud Platform Project. + + # Google Compute Platform organizes all the resources, such as storage + # buckets, users, and instances under projects. This is different from + # aws ec2 where everything is global. + + _configure_project(config.provider_config, crm) + iam_role = _configure_iam_role(config, crm, iam) + config.provider_config['iam_role'] = iam_role # temporary store + config = _configure_subnet(region, cluster_name, config, compute) + + return config + + +def _configure_project(provider_config, crm): + """Setup a Google Cloud Platform Project. + + Google Compute Platform organizes all the resources, such as storage + buckets, users, and instances under projects. This is different from + aws ec2 where everything is global. + """ + project_id = provider_config.get('project_id') + assert project_id is not None, ( + '"project_id" must be set in the "provider" section of the autoscaler' + ' config. Notice that the project id must be globally unique.') + project = _get_project(project_id, crm) + + if project is None: + # Project not found, try creating it + _create_project(project_id, crm) + project = _get_project(project_id, crm) + + assert project is not None, 'Failed to create project' + assert (project['lifecycleState'] == 'ACTIVE' + ), 'Project status needs to be ACTIVE, got {}'.format( + project['lifecycleState']) + + provider_config['project_id'] = project['projectId'] + + +def _is_permission_satisfied(service_account, crm, iam, required_permissions, + required_roles): + """Check if either of the roles or permissions are satisfied.""" + if service_account is None: + return False, None + + project_id = service_account['projectId'] + email = service_account['email'] + + member_id = 'serviceAccount:' + email + + required_permissions = set(required_permissions) + policy = crm.projects().getIamPolicy(resource=project_id, body={}).execute() + original_policy = copy.deepcopy(policy) + already_configured = True + + logger.info(f'_configure_iam_role: Checking permissions for {email}...') + + # Check the roles first, as checking the permission + # requires more API calls and permissions. + for role in required_roles: + role_exists = False + for binding in policy['bindings']: + if binding['role'] == role: + if member_id not in binding['members']: + logger.info(f'_configure_iam_role: role {role} is not ' + f'attached to {member_id}...') + binding['members'].append(member_id) + already_configured = False + role_exists = True + + if not role_exists: + logger.info(f'_configure_iam_role: role {role} does not exist.') + already_configured = False + policy['bindings'].append({ + 'members': [member_id], + 'role': role, + }) + + if already_configured: + # In some managed environments, an admin needs to grant the + # roles, so only call setIamPolicy if needed. + return True, policy + + for binding in original_policy['bindings']: + if member_id in binding['members']: + role = binding['role'] + try: + role_definition = iam.projects().roles().get( + name=role).execute() + except TypeError as e: + if 'does not match the pattern' in str(e): + logger.info('_configure_iam_role: fail to check permission ' + f'for built-in role {role}. skipped.') + permissions = [] + else: + raise + else: + permissions = role_definition['includedPermissions'] + required_permissions -= set(permissions) + if not required_permissions: + break + if not required_permissions: + # All required permissions are already granted. + return True, policy + logger.info( + f'_configure_iam_role: missing permisisons {required_permissions}') + + return False, policy + + +def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict: + """Setup a gcp service account with IAM roles. + + Creates a gcp service acconut and binds IAM roles which allow it to control + control storage/compute services. Specifically, the head node needs to have + an IAM role that allows it to create further gce instances and store items + in google cloud storage. + + TODO: Allow the name/id of the service account to be configured + """ + project_id = config.provider_config['project_id'] + email = constants.SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( + account_id=constants.SKYPILOT_SERVICE_ACCOUNT_ID, + project_id=project_id, + ) + service_account = _get_service_account(email, project_id, iam) + + permissions = constants.VM_MINIMAL_PERMISSIONS + roles = constants.DEFAULT_SERVICE_ACCOUNT_ROLES + if config.provider_config.get(constants.HAS_TPU_PROVIDER_FIELD, False): + roles = (constants.DEFAULT_SERVICE_ACCOUNT_ROLES + + constants.TPU_SERVICE_ACCOUNT_ROLES) + permissions = (constants.VM_MINIMAL_PERMISSIONS + + constants.TPU_MINIMAL_PERMISSIONS) + + satisfied, policy = _is_permission_satisfied(service_account, crm, iam, + permissions, roles) + + if not satisfied: + # SkyPilot: Fallback to the old ray service account name for + # backwards compatibility. Users using GCP before #2112 have + # the old service account setup setup in their GCP project, + # and the user may not have the permissions to create the + # new service account. This is to ensure that the old service + # account is still usable. + email = constants.SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( + account_id=constants.DEFAULT_SERVICE_ACCOUNT_ID, + project_id=project_id, + ) + logger.info(f'_configure_iam_role: Fallback to service account {email}') + + ray_service_account = _get_service_account(email, project_id, iam) + ray_satisfied, _ = _is_permission_satisfied(ray_service_account, crm, + iam, permissions, roles) + logger.info( + '_configure_iam_role: ' + f'Fallback to service account {email} succeeded? {ray_satisfied}') + + if ray_satisfied: + service_account = ray_service_account + satisfied = ray_satisfied + elif service_account is None: + logger.info('_configure_iam_role: ' + 'Creating new service account {}'.format( + constants.SKYPILOT_SERVICE_ACCOUNT_ID)) + # SkyPilot: a GCP user without the permission to create a service + # account will fail here. + service_account = _create_service_account( + constants.SKYPILOT_SERVICE_ACCOUNT_ID, + constants.SKYPILOT_SERVICE_ACCOUNT_CONFIG, + project_id, + iam, + ) + satisfied, policy = _is_permission_satisfied( + service_account, crm, iam, permissions, roles) + + assert service_account is not None, 'Failed to create service account' + + if not satisfied: + logger.info('_configure_iam_role: ' + f'Adding roles to service account {email}...') + _add_iam_policy_binding(service_account, policy, crm, iam) + + account_dict = { + 'email': service_account['email'], + # NOTE: The amount of access is determined by the scope + IAM + # role of the service account. Even if the cloud-platform scope + # gives (scope) access to the whole cloud-platform, the service + # account is limited by the IAM rights specified below. + 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], + } + iam_role: Dict[str, Any] + if (instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU): + # SKY: The API for TPU VM is slightly different from normal compute + # instances. + # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node # pylint: disable=line-too-long + account_dict['scope'] = account_dict['scopes'] + account_dict.pop('scopes') + iam_role = {'serviceAccount': account_dict} + else: + iam_role = {'serviceAccounts': [account_dict]} + + return iam_role + + +def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, + compute): + """Check if the firewall rules in the VPC are sufficient.""" + required_rules = constants.FIREWALL_RULES_REQUIRED.copy() + + operation = compute.networks().getEffectiveFirewalls(project=project_id, + network=vpc_name) + response = operation.execute() + if len(response) == 0: + return False + effective_rules = response['firewalls'] + + def _merge_and_refine_rule( + rules) -> Dict[Tuple[str, str], Dict[str, Set[int]]]: + """Returns the reformatted rules from the firewall rules + + The function translates firewall rules fetched from the cloud provider + to a format for simple comparison. + + Example of firewall rules from the cloud: + [ + { + ... + 'direction': 'INGRESS', + 'allowed': [ + {'IPProtocol': 'tcp', 'ports': ['80', '443']}, + {'IPProtocol': 'udp', 'ports': ['53']}, + ], + 'sourceRanges': ['10.128.0.0/9'], + }, + { + ... + 'direction': 'INGRESS', + 'allowed': [{ + 'IPProtocol': 'tcp', + 'ports': ['22'], + }], + 'sourceRanges': ['0.0.0.0/0'], + }, + ] + + Returns: + source2rules: Dict[(direction, sourceRanges) -> + Dict(protocol -> Set[ports])] + + Example { + ('INGRESS', '10.128.0.0/9'): {'tcp': {80, 443}, 'udp': {53}}, + ('INGRESS', '0.0.0.0/0'): {'tcp': {22}}, + } + """ + source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {} + source2allowed_list: Dict[Tuple[str, str], List[Dict[str, str]]] = {} + for rule in rules: + # Rules applied to specific VM (targetTags) may not work for the + # current VM, so should be skipped. + # Filter by targetTags == ['cluster_name'] + # See https://developers.google.com/resources/api-libraries/documentation/compute/alpha/python/latest/compute_alpha.networks.html#getEffectiveFirewalls # pylint: disable=line-too-long + tags = rule.get('targetTags', None) + if tags is not None: + if len(tags) != 1: + continue + if tags[0] != cluster_name: + continue + direction = rule.get('direction', '') + sources = rule.get('sourceRanges', []) + allowed = rule.get('allowed', []) + for source in sources: + key = (direction, source) + source2allowed_list[key] = source2allowed_list.get(key, + []) + allowed + for direction_source, allowed_list in source2allowed_list.items(): + source2rules[direction_source] = {} + for allowed in allowed_list: + # Example of port_list: ['20', '50-60'] + # If list is empty, it means all ports + port_list = allowed.get('ports', []) + port_set = set() + if port_list == []: + port_set.update(set(range(1, 65536))) + else: + for port_range in port_list: + parse_ports = port_range.split('-') + if len(parse_ports) == 1: + port_set.add(int(parse_ports[0])) + else: + assert ( + len(parse_ports) == 2 + ), f'Failed to parse the port range: {port_range}' + port_set.update( + set( + range(int(parse_ports[0]), + int(parse_ports[1]) + 1))) + if allowed['IPProtocol'] not in source2rules[direction_source]: + source2rules[direction_source][ + allowed['IPProtocol']] = set() + source2rules[direction_source][allowed['IPProtocol']].update( + port_set) + return source2rules + + effective_rules_map = _merge_and_refine_rule(effective_rules) + required_rules_map = _merge_and_refine_rule(required_rules) + + for direction_source, allowed_req in required_rules_map.items(): + if direction_source not in effective_rules_map: + return False + allowed_eff = effective_rules_map[direction_source] + # Special case: 'all' means allowing all traffic + if 'all' in allowed_eff: + continue + # Check if the required ports are a subset of the effective ports + for protocol, ports_req in allowed_req.items(): + ports_eff = allowed_eff.get(protocol, set()) + if not ports_req.issubset(ports_eff): + return False + return True + + +def _create_rules(project_id: str, compute, rules, vpc_name): + opertaions = [] + for rule in rules: + # Query firewall rule by its name (unique in a project). + # If the rule already exists, delete it first. + rule_name = rule['name'].format(VPC_NAME=vpc_name) + rule_list = _list_firewall_rules(project_id, + compute, + filter=f'(name={rule_name})') + if len(rule_list) > 0: + _delete_firewall_rule(project_id, compute, rule_name) + + body = rule.copy() + body['name'] = body['name'].format(VPC_NAME=vpc_name) + body['network'] = body['network'].format(PROJ_ID=project_id, + VPC_NAME=vpc_name) + body['selfLink'] = body['selfLink'].format(PROJ_ID=project_id, + VPC_NAME=vpc_name) + op = compute.firewalls().insert(project=project_id, body=body).execute() + opertaions.append(op) + for op in opertaions: + wait_for_compute_global_operation(project_id, op, compute) + + +def _network_interface_to_vpc_name(network_interface: Dict[str, str]) -> str: + """Returns the VPC name of a network interface.""" + return network_interface['network'].split('/')[-1] + + +def get_usable_vpc_and_subnet( + cluster_name: str, + region: str, + config: common.ProvisionConfig, + compute, +) -> Tuple[str, 'google.cloud.compute_v1.types.compute.Subnetwork']: + """Return a usable VPC and the subnet in it. + + If config.provider_config['vpc_name'] is set, return the VPC with the name + (errors out if not found). When this field is set, no firewall rules + checking or overrides will take place; it is the user's responsibility to + properly set up the VPC. + + If not found, create a new one with sufficient firewall rules. + + Returns: + vpc_name: The name of the VPC network. + subnet_name: The name of the subnet in the VPC network for the specific + region. + + Raises: + RuntimeError: if the user has specified a VPC name but the VPC is not + found. + """ + project_id = config.provider_config['project_id'] + + # For existing cluster, it is ok to return a VPC and subnet not used by + # the cluster, as AWS will ignore them. + # There is a corner case where the multi-node cluster was partially + # launched, launching the cluster again can cause the nodes located on + # different VPCs, if VPCs in the project have changed. It should be fine to + # not handle this special case as we don't want to sacrifice the performance + # for every launch just for this rare case. + + specific_vpc_to_use = config.provider_config.get('vpc_name', None) + if specific_vpc_to_use is not None: + vpcnets_all = _list_vpcnets(project_id, + compute, + filter=f'name={specific_vpc_to_use}') + # On GCP, VPC names are unique, so it'd be 0 or 1 VPC found. + assert (len(vpcnets_all) <= + 1), (f'{len(vpcnets_all)} VPCs found with the same name ' + f'{specific_vpc_to_use}') + if len(vpcnets_all) == 1: + # Skip checking any firewall rules if the user has specified a VPC. + logger.info(f'Using user-specified VPC {specific_vpc_to_use!r}.') + subnets = _list_subnets(project_id, + region, + compute, + network=specific_vpc_to_use) + if not subnets: + _skypilot_log_error_and_exit_for_failover( + 'SUBNET_NOT_FOUND_FOR_VPC', + f'No subnet for region {region} found for specified VPC ' + f'{specific_vpc_to_use!r}. ' + f'Check the subnets of VPC {specific_vpc_to_use!r} at ' + 'https://console.cloud.google.com/networking/networks') + return specific_vpc_to_use, subnets[0] + else: + # VPC with this name not found. Error out and let SkyPilot failover. + _skypilot_log_error_and_exit_for_failover( + 'VPC_NOT_FOUND', + f'No VPC with name {specific_vpc_to_use!r} is found. ' + 'To fix: specify a correct VPC name.') + # Should not reach here. + + subnets_all = _list_subnets(project_id, region, compute) + + # Check if VPC for subnet has sufficient firewall rules. + insufficient_vpcs = set() + for subnet in subnets_all: + vpc_name = _network_interface_to_vpc_name(subnet) + if vpc_name in insufficient_vpcs: + continue + if _check_firewall_rules(cluster_name, vpc_name, project_id, compute): + logger.info( + f'get_usable_vpc: Found a usable VPC network {vpc_name!r}.') + return vpc_name, subnet + else: + insufficient_vpcs.add(vpc_name) + + # No usable VPC found. Try to create one. + logger.info( + f'Creating a default VPC network, {constants.SKYPILOT_VPC_NAME}...') + + # Create a SkyPilot VPC network if it doesn't exist + vpc_list = _list_vpcnets(project_id, + compute, + filter=f'name={constants.SKYPILOT_VPC_NAME}') + if len(vpc_list) == 0: + body = constants.VPC_TEMPLATE.copy() + body['name'] = body['name'].format(VPC_NAME=constants.SKYPILOT_VPC_NAME) + body['selfLink'] = body['selfLink'].format( + PROJ_ID=project_id, VPC_NAME=constants.SKYPILOT_VPC_NAME) + _create_vpcnet(project_id, compute, body) + + _create_rules(project_id, compute, constants.FIREWALL_RULES_TEMPLATE, + constants.SKYPILOT_VPC_NAME) + + usable_vpc_name = constants.SKYPILOT_VPC_NAME + subnets = _list_subnets(project_id, + region, + compute, + network=usable_vpc_name) + if not subnets: + _skypilot_log_error_and_exit_for_failover( + 'SUBNET_NOT_FOUND_FOR_VPC', + f'No subnet for region {region} found for generated VPC ' + f'{usable_vpc_name!r}. This is probably due to the region being ' + 'disabled in the account/project_id.') + usable_subnet = subnets[0] + logger.info(f'A VPC network {constants.SKYPILOT_VPC_NAME} created.') + return usable_vpc_name, usable_subnet + + +def _configure_subnet(region: str, cluster_name: str, + config: common.ProvisionConfig, compute): + """Pick a reasonable subnet if not specified by the config.""" + node_config = config.node_config + # Rationale: avoid subnet lookup if the network is already + # completely manually configured + + # networkInterfaces is compute, networkConfig is TPU + if 'networkInterfaces' in node_config or 'networkConfig' in node_config: + return config + + # SkyPilot: make sure there's a usable VPC + _, default_subnet = get_usable_vpc_and_subnet(cluster_name, region, config, + compute) + + default_interfaces = [{ + 'subnetwork': default_subnet['selfLink'], + 'accessConfigs': [{ + 'name': 'External NAT', + 'type': 'ONE_TO_ONE_NAT', + }], + }] + if config.provider_config.get('use_internal_ips', False): + # Removing this key means the VM will not be assigned an external IP. + default_interfaces[0].pop('accessConfigs') + + # The not applicable key will be removed during node creation + + # compute + if 'networkInterfaces' not in node_config: + node_config['networkInterfaces'] = copy.deepcopy(default_interfaces) + # TPU + if 'networkConfig' not in node_config: + node_config['networkConfig'] = copy.deepcopy(default_interfaces)[0] + # TPU doesn't have accessConfigs + node_config['networkConfig'].pop('accessConfigs', None) + if config.provider_config.get('use_internal_ips', False): + node_config['networkConfig']['enableExternalIps'] = False + else: + node_config['networkConfig']['enableExternalIps'] = True + + return config + + +def _delete_firewall_rule(project_id: str, compute, name): + operation = (compute.firewalls().delete(project=project_id, + firewall=name).execute()) + response = wait_for_compute_global_operation(project_id, operation, compute) + return response + + +# pylint: disable=redefined-builtin +def _list_firewall_rules(project_id, compute, filter=None): + response = (compute.firewalls().list( + project=project_id, + filter=filter, + ).execute()) + return response['items'] if 'items' in response else [] + + +def _create_vpcnet(project_id: str, compute, body): + operation = (compute.networks().insert(project=project_id, + body=body).execute()) + response = wait_for_compute_global_operation(project_id, operation, compute) + return response + + +def _list_vpcnets(project_id: str, compute, filter=None): # pylint: disable=redefined-builtin + response = (compute.networks().list( + project=project_id, + filter=filter, + ).execute()) + + return (list(sorted(response['items'], key=lambda x: x['name'])) + if 'items' in response else []) + + +def _list_subnets( + project_id: str, + region: str, + compute, + network=None +) -> List['google.cloud.compute_v1.types.compute.Subnetwork']: + response = (compute.subnetworks().list( + project=project_id, + region=region, + ).execute()) + + items = response['items'] if 'items' in response else [] + if network is None: + return items + + # Filter by network (VPC) name. + # + # Note we do not directly use the filter (network=<...>) arg of the list() + # call above, because it'd involve constructing a long URL of the following + # format and passing it as the filter value: + # 'https://www.googleapis.com/compute/v1/projects//global/networks/' # pylint: disable=line-too-long + matched_items = [] + for item in items: + if network == _network_interface_to_vpc_name(item): + matched_items.append(item) + return matched_items + + +def _get_project(project_id: str, crm): + try: + project = crm.projects().get(projectId=project_id).execute() + except gcp.http_error_exception() as e: + if e.resp.status != 403: + raise + project = None + + return project + + +def _create_project(project_id: str, crm): + operation = (crm.projects().create(body={ + 'projectId': project_id, + 'name': project_id + }).execute()) + + result = wait_for_crm_operation(operation, crm) + + return result + + +def _get_service_account(account: str, project_id: str, iam): + full_name = 'projects/{project_id}/serviceAccounts/{account}'.format( + project_id=project_id, account=account) + try: + service_account = iam.projects().serviceAccounts().get( + name=full_name).execute() + except gcp.http_error_exception() as e: + if e.resp.status not in [403, 404]: + # SkyPilot: added 403, which means the service account doesn't + # exist, or not accessible by the current account, which is fine, as + # we do the fallback in the caller. + raise + service_account = None + + return service_account + + +def _create_service_account(account_id: str, account_config, project_id: str, + iam): + service_account = (iam.projects().serviceAccounts().create( + name='projects/{project_id}'.format(project_id=project_id), + body={ + 'accountId': account_id, + 'serviceAccount': account_config, + }, + ).execute()) + + return service_account + + +def _add_iam_policy_binding(service_account, policy, crm, iam): + """Add new IAM roles for the service account.""" + del iam + project_id = service_account['projectId'] + + result = (crm.projects().setIamPolicy( + resource=project_id, + body={ + 'policy': policy, + }, + ).execute()) + + return result diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py new file mode 100644 index 00000000000..72c4373eb81 --- /dev/null +++ b/sky/provision/gcp/constants.py @@ -0,0 +1,197 @@ +"""Constants used by the GCP provisioner.""" + +VERSION = 'v1' +# Using v2 according to +# https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-curl # pylint: disable=line-too-long +TPU_VERSION = 'v2' + +RAY = 'ray-autoscaler' +DEFAULT_SERVICE_ACCOUNT_ID = RAY + '-sa-' + VERSION +SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( + '{account_id}@{project_id}.iam.gserviceaccount.com') +DEFAULT_SERVICE_ACCOUNT_CONFIG = { + 'displayName': f'Ray Autoscaler Service Account ({VERSION})', +} + +SKYPILOT = 'skypilot' +SKYPILOT_SERVICE_ACCOUNT_ID = SKYPILOT + '-' + VERSION +SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( + '{account_id}@{project_id}.iam.gserviceaccount.com') +SKYPILOT_SERVICE_ACCOUNT_CONFIG = { + 'displayName': f'SkyPilot Service Account ({VERSION})', +} + +# Those roles will be always added. +# NOTE: `serviceAccountUser` allows the head node to create workers with +# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp. +DEFAULT_SERVICE_ACCOUNT_ROLES = [ + 'roles/storage.objectAdmin', + 'roles/compute.admin', + 'roles/iam.serviceAccountUser', + 'roles/iam.roleViewer', +] +# Those roles will only be added if there are TPU nodes defined in config. +TPU_SERVICE_ACCOUNT_ROLES = ['roles/tpu.admin'] + +# If there are TPU nodes in config, this field will be set +# to True in config['provider']. +HAS_TPU_PROVIDER_FIELD = '_has_tpus' + +# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes +# with ServiceAccounts. + +SKYPILOT_VPC_NAME = 'skypilot-vpc' + +# Below parameters are from the default VPC on GCP. +# https://cloud.google.com/vpc/docs/firewalls#more_rules_default_vpc +VPC_TEMPLATE: dict = { + 'name': '{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'autoCreateSubnetworks': True, + 'mtu': 1460, + 'routingConfig': { + 'routingMode': 'GLOBAL' + }, +} +# Required firewall rules for SkyPilot to work. +FIREWALL_RULES_REQUIRED = [ + # Allow internal connections between GCP VMs for Ray multi-node cluster. + { + 'direction': 'INGRESS', + 'allowed': [ + { + 'IPProtocol': 'tcp', + 'ports': ['0-65535'] + }, + { + 'IPProtocol': 'udp', + 'ports': ['0-65535'] + }, + ], + 'sourceRanges': ['10.128.0.0/9'], + }, + # Allow ssh connection from anywhere. + { + 'direction': 'INGRESS', + 'allowed': [{ + 'IPProtocol': 'tcp', + 'ports': ['22'], + }], + # TODO(skypilot): some users reported that this should be relaxed (e.g., + # allowlisting only certain IPs to have ssh access). + 'sourceRanges': ['0.0.0.0/0'], + }, +] + +# Template when creating firewall rules for a new VPC. +FIREWALL_RULES_TEMPLATE = [ + { + 'name': '{VPC_NAME}-allow-custom', + 'description': ('Allows connection from any source to any instance on ' + 'the network using custom protocols.'), + 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'selfLink': + ('projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom'), + 'direction': 'INGRESS', + 'priority': 65534, + 'allowed': [ + { + 'IPProtocol': 'tcp', + 'ports': ['0-65535'] + }, + { + 'IPProtocol': 'udp', + 'ports': ['0-65535'] + }, + { + 'IPProtocol': 'icmp' + }, + ], + 'sourceRanges': ['10.128.0.0/9'], + }, + { + 'name': '{VPC_NAME}-allow-ssh', + 'description': + ('Allows TCP connections from any source to any instance on the ' + 'network using port 22.'), + 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh', + 'direction': 'INGRESS', + 'priority': 65534, + 'allowed': [{ + 'IPProtocol': 'tcp', + 'ports': ['22'], + }], + # TODO(skypilot): some users reported that this should be relaxed (e.g., + # allowlisting only certain IPs to have ssh access). + 'sourceRanges': ['0.0.0.0/0'], + }, + { + 'name': '{VPC_NAME}-allow-icmp', + 'description': ('Allows ICMP connections from any source to any ' + 'instance on the network.'), + 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp', + 'direction': 'INGRESS', + 'priority': 65534, + 'allowed': [{ + 'IPProtocol': 'icmp', + }], + 'sourceRanges': ['0.0.0.0/0'], + }, +] + +# A list of permissions required to run SkyPilot on GCP. +# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html # pylint: disable=line-too-long +VM_MINIMAL_PERMISSIONS = [ + 'compute.disks.create', + 'compute.disks.list', + # TODO(skypilot): some users reported that firewalls changes + # (create/delete/update) should be removed if VPC/firewalls are separately + # set up. It is undesirable for a normal account to have these permissions. + # Note that if these permissions are removed, opening ports (e.g., via + # `resources.ports`) would fail. + 'compute.firewalls.create', + 'compute.firewalls.delete', + 'compute.firewalls.get', + 'compute.instances.create', + 'compute.instances.delete', + 'compute.instances.get', + 'compute.instances.list', + 'compute.instances.setLabels', + 'compute.instances.setServiceAccount', + 'compute.instances.start', + 'compute.instances.stop', + 'compute.networks.get', + 'compute.networks.list', + 'compute.networks.getEffectiveFirewalls', + 'compute.globalOperations.get', + 'compute.subnetworks.use', + 'compute.subnetworks.list', + 'compute.subnetworks.useExternalIp', + 'compute.projects.get', + 'compute.zoneOperations.get', + 'iam.roles.get', + 'iam.serviceAccounts.actAs', + 'iam.serviceAccounts.get', + 'serviceusage.services.enable', + 'serviceusage.services.list', + 'serviceusage.services.use', + 'resourcemanager.projects.get', + 'resourcemanager.projects.getIamPolicy', +] + +TPU_MINIMAL_PERMISSIONS = [ + 'tpu.nodes.create', + 'tpu.nodes.delete', + 'tpu.nodes.list', + 'tpu.nodes.get', + 'tpu.nodes.update', + 'tpu.operations.get', +] + +# The maximum number of times to poll for the status of an operation. +POLL_INTERVAL = 1 +MAX_POLLS = 60 // POLL_INTERVAL +# Stopping instances can take several minutes, so we increase the timeout +MAX_POLLS_STOP = MAX_POLLS * 8 diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 99b80ab5074..0d2070ee8db 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -1,20 +1,22 @@ """GCP instance provisioning.""" import collections +import copy +from multiprocessing import pool import re import time from typing import Any, Callable, Dict, Iterable, List, Optional, Type from sky import sky_logging +from sky import status_lib from sky.adaptors import gcp +from sky.provision import common +from sky.provision.gcp import constants from sky.provision.gcp import instance_utils +from sky.utils import common_utils logger = sky_logging.init_logger(__name__) -MAX_POLLS = 12 -# Stopping instances can take several minutes, so we increase the timeout -MAX_POLLS_STOP = MAX_POLLS * 8 -POLL_INTERVAL = 5 - +TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' # Tag uniquely identifying all nodes of a cluster TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_RAY_NODE_KIND = 'ray-node-type' @@ -37,10 +39,11 @@ def _filter_instances( instances = set() logger.debug(f'handlers: {handlers}') for instance_handler in handlers: - instances |= set( - instance_handler.filter(project_id, zone, label_filters, - status_filters_fn(instance_handler), - included_instances, excluded_instances)) + instance_dict = instance_handler.filter( + project_id, zone, label_filters, + status_filters_fn(instance_handler), included_instances, + excluded_instances) + instances |= set(instance_dict.keys()) handler_to_instances = collections.defaultdict(list) for instance in instances: handler = instance_utils.instance_to_handler(instance) @@ -49,6 +52,63 @@ def _filter_instances( return handler_to_instances +# TODO(suquark): Does it make sense to not expose this and always assume +# non_terminated_only=True? +# Will there be callers who would want this to be False? +# stop() and terminate() for example already implicitly assume non-terminated. +@common_utils.retry +def query_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + zone = provider_config['availability_zone'] + project_id = provider_config['project_id'] + label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + + handler: Type[ + instance_utils.GCPInstance] = instance_utils.GCPComputeInstance + use_tpu_vms = provider_config.get('_has_tpus', False) + if use_tpu_vms: + handler = instance_utils.GCPTPUVMInstance + + instances = handler.filter( + project_id, + zone, + label_filters, + status_filters=None, + ) + + raw_statuses = {} + statuses = {} + for inst_id, instance in instances.items(): + raw_status = instance[handler.STATUS_FIELD] + raw_statuses[inst_id] = raw_status + if raw_status in handler.PENDING_STATES: + status = status_lib.ClusterStatus.INIT + elif raw_status in handler.STOPPING_STATES + handler.STOPPED_STATES: + status = status_lib.ClusterStatus.STOPPED + elif raw_status == handler.RUNNING_STATE: + status = status_lib.ClusterStatus.UP + else: + status = None + if non_terminated_only and status is None: + continue + statuses[inst_id] = status + + # GCP does not clean up preempted TPU VMs. We remove it ourselves. + if handler == instance_utils.GCPTPUVMInstance: + all_preempted = all(s == 'PREEMPTED' for s in raw_statuses.values()) + if all_preempted: + logger.info( + f'Terminating preempted TPU VM cluster {cluster_name_on_cloud}') + terminate_instances(cluster_name_on_cloud, provider_config) + # TODO(zhwu): TPU node should check the status of the attached TPU as well. + return statuses + + def _wait_for_operations( handlers_to_operations: Dict[Type[instance_utils.GCPInstance], List[dict]], project_id: str, @@ -65,13 +125,305 @@ def _wait_for_operations( logger.debug( f'wait_for_compute_{op_type}_operation: ' f'Waiting for operation {operation["name"]} to finish...') - while total_polls < MAX_POLLS: + while total_polls < constants.MAX_POLLS: if handler.wait_for_operation(operation, project_id, zone): break - time.sleep(POLL_INTERVAL) + time.sleep(constants.POLL_INTERVAL) total_polls += 1 +def _get_head_instance_id(instances: List) -> Optional[str]: + head_instance_id = None + for inst in instances: + labels = inst.get('labels', {}) + if (labels.get(TAG_RAY_NODE_KIND) == 'head' or + labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'): + head_instance_id = inst['name'] + break + return head_instance_id + + +def _run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """See sky/provision/__init__.py""" + # NOTE: although google cloud instances have IDs, but they are + # not used for indexing. Instead, we use the instance name. + labels = config.tags # gcp uses 'labels' instead of aws 'tags' + labels = dict(sorted(copy.deepcopy(labels).items())) + resumed_instance_ids: List[str] = [] + created_instance_ids: List[str] = [] + + node_type = instance_utils.get_node_type(config.node_config) + project_id = config.provider_config['project_id'] + availability_zone = config.provider_config['availability_zone'] + + # SKY: 'TERMINATED' for compute VM, 'STOPPED' for TPU VM + # 'STOPPING' means the VM is being stopped, which needs + # to be included to avoid creating a new VM. + resource: Type[instance_utils.GCPInstance] + if node_type == instance_utils.GCPNodeType.COMPUTE: + resource = instance_utils.GCPComputeInstance + elif node_type == instance_utils.GCPNodeType.TPU: + resource = instance_utils.GCPTPUVMInstance + else: + raise ValueError(f'Unknown node type {node_type}') + + filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + + # wait until all stopping instances are stopped/terminated + while True: + instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=resource.STOPPING_STATES, + ) + if not instances: + break + logger.info(f'run_instances: Waiting for {len(instances)} instances in ' + 'STOPPING status') + time.sleep(constants.POLL_INTERVAL) + + exist_instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=None, + ) + exist_instances = list(exist_instances.values()) + head_instance_id = _get_head_instance_id(exist_instances) + + # NOTE: We are not handling REPAIRING, SUSPENDING, SUSPENDED status. + pending_instances = [] + running_instances = [] + stopping_instances = [] + stopped_instances = [] + + # SkyPilot: We try to use the instances with the same matching launch_config + # first. If there is not enough instances with matching launch_config, we + # then use all the instances with the same matching launch_config plus some + # instances with wrong launch_config. + def get_order_key(node): + import datetime # pylint: disable=import-outside-toplevel + + timestamp = node.get('lastStartTimestamp') + if timestamp is not None: + return datetime.datetime.strptime(timestamp, + '%Y-%m-%dT%H:%M:%S.%f%z') + return node['id'] + + logger.info(str(exist_instances)) + for inst in exist_instances: + state = inst[resource.STATUS_FIELD] + if state in resource.PENDING_STATES: + pending_instances.append(inst) + elif state == resource.RUNNING_STATE: + running_instances.append(inst) + elif state in resource.STOPPING_STATES: + stopping_instances.append(inst) + elif state in resource.STOPPED_STATES: + stopped_instances.append(inst) + else: + raise RuntimeError(f'Unsupported state "{state}".') + + pending_instances.sort(key=get_order_key, reverse=True) + running_instances.sort(key=get_order_key, reverse=True) + stopping_instances.sort(key=get_order_key, reverse=True) + stopped_instances.sort(key=get_order_key, reverse=True) + + if stopping_instances: + raise RuntimeError( + 'Some instances are being stopped during provisioning. ' + 'Please wait a while and retry.') + + if head_instance_id is None: + if running_instances: + head_instance_id = resource.create_node_tag( + project_id, + availability_zone, + running_instances[0]['name'], + is_head=True, + ) + elif pending_instances: + head_instance_id = resource.create_node_tag( + project_id, + availability_zone, + pending_instances[0]['name'], + is_head=True, + ) + # TODO(suquark): Maybe in the future, users could adjust the number + # of instances dynamically. Then this case would not be an error. + if config.resume_stopped_nodes and len(exist_instances) > config.count: + raise RuntimeError( + 'The number of running/stopped/stopping ' + f'instances combined ({len(exist_instances)}) in ' + f'cluster "{cluster_name_on_cloud}" is greater than the ' + f'number requested by the user ({config.count}). ' + 'This is likely a resource leak. ' + 'Use "sky down" to terminate the cluster.') + + to_start_count = (config.count - len(running_instances) - + len(pending_instances)) + + # Try to reuse previously stopped nodes with compatible configs + if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances: + resumed_instance_ids = [n['name'] for n in stopped_instances] + if resumed_instance_ids: + for instance_id in resumed_instance_ids: + resource.start_instance(instance_id, project_id, + availability_zone) + resource.set_labels(project_id, availability_zone, instance_id, + labels) + to_start_count -= len(resumed_instance_ids) + + if head_instance_id is None: + head_instance_id = resource.create_node_tag( + project_id, + availability_zone, + resumed_instance_ids[0], + is_head=True, + ) + + if to_start_count > 0: + errors, created_instance_ids = resource.create_instances( + cluster_name_on_cloud, project_id, availability_zone, + config.node_config, labels, to_start_count, + head_instance_id is None) + if errors: + error = common.ProvisionError('Failed to launch instances.') + error.errors = errors + raise error + if head_instance_id is None: + head_instance_id = created_instance_ids[0] + + while True: + # wait until all instances are running + instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=resource.PENDING_STATES, + ) + if not instances: + break + logger.debug(f'run_instances: Waiting for {len(instances)} instances ' + 'in PENDING status.') + time.sleep(constants.POLL_INTERVAL) + + # Check if the number of running instances is the same as the requested. + instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=[resource.RUNNING_STATE], + ) + if len(instances) != config.count: + logger.warning('The number of running instances is different from ' + 'the requested number after provisioning ' + f'(requested: {config.count}, ' + f'observed: {len(instances)}). ' + 'This could be some instances failed to start ' + 'or some resource leak.') + + assert head_instance_id is not None, 'head_instance_id is None' + return common.ProvisionRecord(provider_name='gcp', + region=region, + zone=availability_zone, + cluster_name=cluster_name_on_cloud, + head_instance_id=head_instance_id, + resumed_instance_ids=resumed_instance_ids, + created_instance_ids=created_instance_ids) + + +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """See sky/provision/__init__.py""" + try: + return _run_instances(region, cluster_name_on_cloud, config) + except gcp.http_error_exception() as e: + error_details = getattr(e, 'error_details') + errors = [] + if isinstance(error_details, list): + for detail in error_details: + errors.append({ + 'code': detail.get('reason'), + 'domain': detail.get('domain'), + 'message': detail.get('message', str(e)), + }) + elif isinstance(error_details, str): + errors.append({ + 'code': None, + 'domain': 'run_instances', + 'message': error_details, + }) + else: + raise + error = common.ProvisionError('Failed to launch instances.') + error.errors = errors + raise error from e + + +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: + """See sky/provision/__init__.py""" + del region, cluster_name_on_cloud, state + # We already wait for the instances to be running in run_instances. + # So we don't need to wait here. + + +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: + """See sky/provision/__init__.py""" + del region + assert provider_config is not None, cluster_name_on_cloud + zone = provider_config['availability_zone'] + project_id = provider_config['project_id'] + label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + + handlers: List[Type[instance_utils.GCPInstance]] = [ + instance_utils.GCPComputeInstance + ] + use_tpu_vms = provider_config.get('_has_tpus', False) + if use_tpu_vms: + handlers.append(instance_utils.GCPTPUVMInstance) + + handler_to_instances = _filter_instances( + handlers, + project_id, + zone, + label_filters, + lambda h: [h.RUNNING_STATE], + ) + instances: Dict[str, List[common.InstanceInfo]] = {} + for res, insts in handler_to_instances.items(): + with pool.ThreadPool() as p: + inst_info = p.starmap(res.get_instance_info, + [(project_id, zone, inst) for inst in insts]) + instances.update(zip(insts, inst_info)) + + head_instances = _filter_instances( + handlers, + project_id, + zone, + { + **label_filters, TAG_RAY_NODE_KIND: 'head' + }, + lambda h: [h.RUNNING_STATE], + ) + head_instance_id = None + for insts in head_instances.values(): + if insts and insts[0]: + head_instance_id = insts[0] + break + + return common.ClusterInfo( + instances=instances, + head_instance_id=head_instance_id, + ) + + def stop_instances( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None, @@ -110,7 +462,7 @@ def stop_instances( # Check if the instance is actually stopped. # GCP does not fully stop an instance even after # the stop operation is finished. - for _ in range(MAX_POLLS_STOP): + for _ in range(constants.MAX_POLLS_STOP): handler_to_instances = _filter_instances( handler_to_instances.keys(), project_id, @@ -121,10 +473,10 @@ def stop_instances( ) if not handler_to_instances: break - time.sleep(POLL_INTERVAL) + time.sleep(constants.POLL_INTERVAL) else: raise RuntimeError(f'Maximum number of polls: ' - f'{MAX_POLLS_STOP} reached. ' + f'{constants.MAX_POLLS_STOP} reached. ' f'Instance {all_instances} is still not in ' 'STOPPED status.') @@ -157,6 +509,7 @@ def terminate_instances( for handler, instances in handler_to_instances.items(): for instance in instances: try: + logger.debug(f'Terminating instance: {instance}.') operations[handler].append( handler.terminate(project_id, zone, instance)) except gcp.http_error_exception() as e: diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index cf46bef52fa..f823344fc4c 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1,21 +1,126 @@ """Utilities for GCP instances.""" +import copy +import enum +import functools +from multiprocessing import pool import re -from typing import Dict, List, Optional +import time +from typing import Any, Dict, List, Optional, Tuple +import uuid from sky import sky_logging from sky.adaptors import gcp +from sky.clouds import gcp as gcp_cloud +from sky.provision import common +from sky.provision.gcp import constants +from sky.utils import common_utils from sky.utils import ux_utils -logger = sky_logging.init_logger(__name__) +# Tag uniquely identifying all nodes of a cluster +TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' +TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' +# Tag for the name of the node +INSTANCE_NAME_MAX_LEN = 64 +INSTANCE_NAME_UUID_LEN = 8 +TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' +TAG_RAY_NODE_KIND = 'ray-node-type' + +# This is the maximum number of times we will retry a GCP API call. +# The number is identical to those we use for AWS boto3. +GCP_MAX_RETRIES = 12 +GCP_CREATE_MAX_RETRIES = 5 +GCP_RETRY_INTERVAL_SECONDS = 5 +GCP_TIMEOUT = 300 -# Using v2 according to -# https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-curl # pylint: disable=line-too-long -TPU_VERSION = 'v2' +logger = sky_logging.init_logger(__name__) _FIREWALL_RESOURCE_NOT_FOUND_PATTERN = re.compile( r'The resource \'projects/.*/global/firewalls/.*\' was not found') +def _retry_on_http_exception( + regex: Optional[str] = None, + max_retries: int = GCP_MAX_RETRIES, + retry_interval_s: int = GCP_RETRY_INTERVAL_SECONDS, +): + """Retry a function call n-times for as long as it throws an exception.""" + + def dec(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + exception_type = gcp.http_error_exception() + + def try_catch_exc(): + try: + value = func(*args, **kwargs) + return value + except Exception as e: # pylint: disable=broad-except + if not isinstance(e, exception_type) or ( + regex and not re.search(regex, str(e))): + raise + return e + + for _ in range(max_retries): + ret = try_catch_exc() + if not isinstance(ret, Exception): + break + time.sleep(retry_interval_s) + if isinstance(ret, Exception): + raise ret + return ret + + return wrapper + + return dec + + +def _generate_node_name(cluster_name: str, node_suffix: str, + is_head: bool) -> str: + """Generate node name from labels and suffix. + + This is required so that the correct resource can be selected + when the only information autoscaler has is the name of the node. + + The suffix is expected to be one of 'compute' or 'tpu' + (as in ``GCPNodeType``). + """ + suffix_id = common_utils.base36_encode(uuid.uuid4().hex) + suffix = f'-{suffix_id[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}' + if is_head: + suffix = f'-head{suffix}' + else: + suffix = f'-worker{suffix}' + node_name = cluster_name + suffix + assert len(node_name) <= INSTANCE_NAME_MAX_LEN, cluster_name + return node_name + + +def _log_errors(errors: List[Dict[str, str]], e: Any, zone: str) -> None: + """Format errors into a string.""" + if errors: + plural = 's' if len(errors) > 1 else '' + codes = ', '.join(repr(e.get('code', 'N/A')) for e in errors) + messages = '; '.join( + repr(e.get('message', 'N/A').strip('.')) for e in errors) + logger.warning(f'create_instances: Got return code{plural} {codes} in ' + f'{zone}: {messages}') + else: + logger.warning(f'create_instances: Failed with reason: {e}') + + +def selflink_to_name(selflink: str) -> str: + """Converts a selflink to a name. + + Args: + selflink: The selflink to convert. + + Returns: + The name of the resource. + """ + return selflink.rsplit('/', 1)[-1] + + def instance_to_handler(instance: str): instance_type = instance.split('-')[-1] if instance_type == 'compute': @@ -28,9 +133,14 @@ def instance_to_handler(instance: str): class GCPInstance: """Base class for GCP instance handlers.""" + PENDING_STATES: List[str] = [] NEED_TO_STOP_STATES: List[str] = [] NON_STOPPED_STATES: List[str] = [] NEED_TO_TERMINATE_STATES: List[str] = [] + RUNNING_STATE: str = '' + STOPPING_STATES: List[str] = [] + STOPPED_STATES: List[str] = [] + STATUS_FIELD: str = '' @classmethod def load_resource(cls): @@ -73,7 +183,7 @@ def filter( status_filters: Optional[List[str]], included_instances: Optional[List[str]] = None, excluded_instances: Optional[List[str]] = None, - ) -> List[str]: + ) -> Dict[str, Any]: raise NotImplementedError @classmethod @@ -114,16 +224,80 @@ def add_network_tag_if_not_exist( ) -> None: raise NotImplementedError + @classmethod + def create_instances( + cls, + cluster_name: str, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + include_head_node: bool, + ) -> Tuple[Optional[List], List[str]]: + """Creates multiple instances and returns result. + + Returns a tuple of (errors, list[instance_names]). + """ + raise NotImplementedError + + @classmethod + def start_instance(cls, node_id: str, project_id: str, zone: str) -> bool: + """Start a stopped instance.""" + raise NotImplementedError + + @classmethod + def set_labels(cls, project_id: str, availability_zone: str, node_id: str, + labels: dict) -> bool: + raise NotImplementedError + + @classmethod + def create_node_tag(cls, + project_id: str, + availability_zone: str, + target_instance_id: str, + is_head: bool = True) -> str: + if is_head: + node_tag = { + TAG_SKYPILOT_HEAD_NODE: '1', + TAG_RAY_NODE_KIND: 'head', + } + else: + node_tag = { + TAG_SKYPILOT_HEAD_NODE: '0', + TAG_RAY_NODE_KIND: 'worker', + } + cls.set_labels(project_id=project_id, + availability_zone=availability_zone, + node_id=target_instance_id, + labels=node_tag) + + return target_instance_id + + @classmethod + def get_instance_info(cls, project_id: str, availability_zone: str, + instance_id: str) -> List[common.InstanceInfo]: + raise NotImplementedError + + @classmethod + def resize_disk(cls, project_id: str, availability_zone: str, + node_config: dict, instance_name: str) -> bool: + """Resize a Google Cloud disk based on the provided configuration. + Returns the response of resize operation. + """ + raise NotImplementedError + class GCPComputeInstance(GCPInstance): """Instance handler for GCP compute instances.""" - NEED_TO_STOP_STATES = [ - 'PROVISIONING', - 'STAGING', - 'RUNNING', - ] + PENDING_STATES = ['PROVISIONING', 'STAGING', 'REPAIRING'] + STOPPING_STATES = ['STOPPING', 'SUSPENDING'] + STOPPED_STATES = ['TERMINATED', 'SUSPENDED'] + RUNNING_STATE = 'RUNNING' + STATUS_FIELD = 'status' + NEED_TO_STOP_STATES = PENDING_STATES + [RUNNING_STATE] - NON_STOPPED_STATES = NEED_TO_STOP_STATES + ['STOPPING'] + NON_STOPPED_STATES = NEED_TO_STOP_STATES + STOPPING_STATES @classmethod def load_resource(cls): @@ -174,7 +348,7 @@ def filter( status_filters: Optional[List[str]], included_instances: Optional[List[str]] = None, excluded_instances: Optional[List[str]] = None, - ) -> List[str]: + ) -> Dict[str, Any]: if label_filters: label_filter_expr = ('(' + ' AND '.join([ '(labels.{key} = {value})'.format(key=key, value=value) @@ -204,13 +378,19 @@ def filter( project=project_id, filter=filter_expr, zone=zone, - ).execute()) + ).execute(num_retries=GCP_MAX_RETRIES)) instances = response.get('items', []) - instances = [i['name'] for i in instances] + instances = {i['name']: i for i in instances} if included_instances: - instances = [i for i in instances if i in included_instances] + instances = { + k: v for k, v in instances.items() if k in included_instances + } if excluded_instances: - instances = [i for i in instances if i not in excluded_instances] + instances = { + k: v + for k, v in instances.items() + if k not in excluded_instances + } return instances @classmethod @@ -253,7 +433,7 @@ def get_vpc_name( ).execute() # Format: projects/PROJECT_ID/global/networks/VPC_NAME vpc_link = response['networkInterfaces'][0]['network'] - return vpc_link.split('/')[-1] + return selflink_to_name(vpc_link) except gcp.http_error_exception() as e: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -338,7 +518,8 @@ def create_or_update_firewall_rule( ) from e body = { 'name': firewall_rule_name, - 'description': f'Allow user-specified port {ports} for cluster {cluster_name_on_cloud}', + 'description': (f'Allow user-specified port {ports} for ' + f'cluster {cluster_name_on_cloud}'), 'network': f'projects/{project_id}/global/networks/{vpc_name}', 'selfLink': f'projects/{project_id}/global/firewalls/' + firewall_rule_name, @@ -357,23 +538,347 @@ def create_or_update_firewall_rule( ).execute() return operation + @classmethod + def set_labels(cls, project_id: str, availability_zone: str, node_id: str, + labels: dict) -> bool: + node = cls.load_resource().instances().get( + project=project_id, + instance=node_id, + zone=availability_zone, + ).execute(num_retries=GCP_CREATE_MAX_RETRIES) + body = { + 'labels': dict(node['labels'], **labels), + 'labelFingerprint': node['labelFingerprint'], + } + operation = (cls.load_resource().instances().setLabels( + project=project_id, + zone=availability_zone, + instance=node_id, + body=body, + ).execute(num_retries=GCP_CREATE_MAX_RETRIES)) + + result = cls.wait_for_operation(operation, project_id, + availability_zone) + return result + + @classmethod + def create_instances( + cls, + cluster_name: str, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + include_head_node: bool, + ) -> Tuple[Optional[List], List[str]]: + # NOTE: The syntax for bulkInsert() is different from insert(). + # bulkInsert expects resource names without prefix. Otherwise + # it causes a 503 error. + config = copy.deepcopy(node_config) + + if 'scheduling' in config and isinstance(config['scheduling'], list): + # For backeward compatibility: converting the list of dictionaries + # to a dictionary due to the use of deprecated API. + # [{'preemptible': True}, {'onHostMaintenance': 'TERMINATE'}] + # to {'preemptible': True, 'onHostMaintenance': 'TERMINATE'} + config['scheduling'] = { + k: v for d in config['scheduling'] for k, v in d.items() + } + + for disk in config.get('disks', []): + disk_type = disk.get('initializeParams', {}).get('diskType') + if disk_type: + disk['initializeParams']['diskType'] = selflink_to_name( + disk_type) + config['machineType'] = selflink_to_name(config['machineType']) + for accelerator in config.get('guestAccelerators', []): + accelerator['acceleratorType'] = selflink_to_name( + accelerator['acceleratorType']) + + # removing TPU-specific default key set in config.py + config.pop('networkConfig', None) + + head_tag_needed = [False] * count + if include_head_node: + head_tag_needed[0] = True + + names = [] + for i in range(count): + names.append( + _generate_node_name(cluster_name, + GCPNodeType.COMPUTE.value, + is_head=head_tag_needed[i])) + + labels = dict(config.get('labels', {}), **labels) + + config.update({ + 'labels': dict( + labels, **{ + TAG_RAY_CLUSTER_NAME: cluster_name, + TAG_SKYPILOT_CLUSTER_NAME: cluster_name + }), + }) + + all_names = [] + if 'reservationAffinity' in config: + reservations = gcp_cloud.GCP().get_reservations_available_resources( + config['machineType'], + region=zone.rpartition('-')[0], + zone=zone, + specific_reservations=set( + config['reservationAffinity']['values'])) + # Sort the reservations by the number of available resources + reservation_list = sorted(reservations.items(), + key=lambda x: x[1], + reverse=True) + # TODO(zhwu): Convert this to parallel execution. + # TODO(zhwu): This is not atomic as the reservation count may change + # between the time we check and the time we create the instances, as + # other users may be creating instances at the same time. + # Our current implementation will skip the current region if the + # reservation count is not enough, which is suboptimal. + for reservation, reservation_count in reservation_list: + if reservation_count <= 0: + continue + reservation_count = min(reservation_count, count) + logger.debug(f'Creating {reservation_count} instances ' + f'with reservation {reservation}') + config['reservationAffinity']['values'] = [reservation] + errors, created_names = cls._create_instances( + names[:reservation_count], project_id, zone, config, + reservation_count, head_tag_needed[:reservation_count]) + all_names.extend(names) + if errors: + return errors, all_names + count -= reservation_count + if count <= 0: + return None, all_names + names = names[reservation_count:] + head_tag_needed = head_tag_needed[reservation_count:] + config.pop('reservationAffinity', None) + + errors, created_names = cls._create_instances(names, project_id, zone, + config, count, + head_tag_needed) + + all_names.extend(created_names) + return errors, all_names + + @classmethod + def _create_instances( + cls, + names: List[str], + project_id: str, + zone: str, + config: dict, + count: int, + head_tag_needed: List[bool], + ) -> Tuple[Optional[List], List[str]]: + source_instance_template = config.pop('sourceInstanceTemplate', None) + body = { + 'count': count, + 'instanceProperties': config, + 'sourceInstanceTemplate': source_instance_template, + 'perInstanceProperties': {n: {} for n in names} + } + + # Allow Google Compute Engine instance templates. + # + # Config example: + # + # ... + # node_config: + # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # machineType: e2-standard-16 + # ... + # + # node_config parameters override matching template parameters, if any. + # + # https://cloud.google.com/compute/docs/instance-templates + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + try: + logger.debug('Launching GCP instances with "bulkInsert" ...') + request = cls.load_resource().instances().bulkInsert( + project=project_id, + zone=zone, + body=body, + ) + operation = request.execute(num_retries=0) + except gcp.http_error_exception() as e: + # NOTE: Error example: + # { + # 'message': "Quota '...' exceeded. Limit: ... in region xx-xxxx.", # pylint: disable=line-too-long + # 'domain': 'usageLimits', + # 'reason': 'quotaExceeded' + # } + error_details = getattr(e, 'error_details', []) + errors = [] + for detail in error_details: + # To be consistent with error messages returned by operation wait. + errors.append({ + 'code': detail.get('reason'), + 'domain': detail.get('domain'), + 'message': detail.get('message', str(e)), + }) + logger.debug( + f'create_instances: googleapiclient.errors.HttpError: {e}') + _log_errors(errors, e, zone) + return errors, names + errors = operation.get('error', {}).get('errors') + if errors: + logger.debug('create_instances: Failed to create instances. ' + f'Reason: {errors}') + _log_errors(errors, operation, zone) + return errors, names + + logger.debug('Waiting GCP instances to be ready ...') + wait_start = time.time() + success = False + while time.time() - wait_start < GCP_TIMEOUT: + # Retry the wait() call until it succeeds or times out. + # This is because the wait() call is only best effort, and does not + # guarantee that the operation is done when it returns. + # Reference: https://cloud.google.com/workflows/docs/reference/googleapis/compute/v1/zoneOperations/wait # pylint: disable=line-too-long + request = cls.load_resource().zoneOperations().wait( + project=project_id, + operation=operation['name'], + zone=zone, + ) + request.http.timeout = GCP_TIMEOUT - (time.time() - wait_start) + result = request.execute(num_retries=GCP_CREATE_MAX_RETRIES) + success = result['status'] == 'DONE' + if success: + break + logger.debug(f'create_instances: Retry waiting for operation ' + f'{operation["name"]} to finish (result: {result})...') + else: + logger.warning('create_instances: Timeout waiting for creation ' + 'operation, cancelling the operation ...') + request = cls.load_resource().zoneOperations().delete( + project=project_id, + operation=operation['name'], + zone=zone, + ) + request.http.timeout = GCP_TIMEOUT - (time.time() - wait_start) + request.execute(num_retries=GCP_CREATE_MAX_RETRIES) + errors = [{ + 'code': 'TIMEOUT', + 'message': 'Timeout waiting for creation operation', + 'domain': 'create_instances' + }] + _log_errors(errors, None, zone) + return errors, names + + # NOTE: Error example: + # { + # 'code': 'VM_MIN_COUNT_NOT_REACHED', + # 'message': 'Requested minimum count of 4 VMs could not be created.' + # } + errors = result.get('error', {}).get('errors') + if errors: + logger.debug( + 'create_instances: Failed to create instances. Reason: ' + f'{errors}') + _log_errors(errors, result, zone) + return errors, names + assert success, ('Failed to create instances, but there is no error. ' + f'Instance status: {result}') + # assign labels for head node + with pool.ThreadPool() as p: + p.starmap(cls.create_node_tag, + [(project_id, zone, names[i], head_tag_needed[i]) + for i in range(count)]) + return None, names + + @classmethod + def start_instance(cls, node_id: str, project_id: str, zone: str) -> bool: + operation = (cls.load_resource().instances().start( + project=project_id, + zone=zone, + instance=node_id, + ).execute()) + + result = cls.wait_for_operation(operation, project_id, zone) + return result + + @classmethod + def get_instance_info(cls, project_id: str, availability_zone: str, + instance_id: str) -> List[common.InstanceInfo]: + result = cls.load_resource().instances().get( + project=project_id, + zone=availability_zone, + instance=instance_id, + ).execute() + external_ip = (result.get('networkInterfaces', + [{}])[0].get('accessConfigs', + [{}])[0].get('natIP', None)) + internal_ip = result.get('networkInterfaces', [{}])[0].get('networkIP') + + return [ + common.InstanceInfo( + instance_id=instance_id, + internal_ip=internal_ip, + external_ip=external_ip, + tags=result.get('labels', {}), + ) + ] + + @classmethod + def resize_disk(cls, project_id: str, availability_zone: str, + node_config: dict, instance_name: str) -> bool: + """Resize a Google Cloud disk based on the provided configuration.""" + + # Extract the specified disk size from the configuration + new_size_gb = node_config['disks'][0]['initializeParams']['diskSizeGb'] + + # Fetch the instance details to get the disk name and current disk size + response = (cls.load_resource().instances().get( + project=project_id, + zone=availability_zone, + instance=instance_name, + ).execute()) + disk_name = selflink_to_name(response['disks'][0]['source']) + + try: + # Execute the resize request and return the response + operation = (cls.load_resource().disks().resize( + project=project_id, + zone=availability_zone, + disk=disk_name, + body={ + 'sizeGb': str(new_size_gb), + }, + ).execute()) + except gcp.http_error_exception() as e: + # Catch HttpError when provided with invalid value for new disk + # size. Allowing users to create instances with the same size as the + # image. + logger.warning(f'googleapiclient.errors.HttpError: {e.reason}') + return False + + result = cls.wait_for_operation(operation, project_id, + availability_zone) + + return result + class GCPTPUVMInstance(GCPInstance): - """Instance handler for GCP TPU node.""" - NEED_TO_STOP_STATES = [ - 'CREATING', - 'STARTING', - 'READY', - 'RESTARTING', - ] + """Instance handler for GCP TPU VM.""" + PENDING_STATES = ['CREATING', 'STARTING', 'RESTARTING', 'REPAIRING'] + RUNNING_STATE = 'READY' + STOPPING_STATES = ['STOPPING'] + STOPPED_STATES = ['STOPPED'] + STATUS_FIELD = 'state' + NEED_TO_STOP_STATES = PENDING_STATES + [RUNNING_STATE] - NON_STOPPED_STATES = NEED_TO_STOP_STATES + ['STOPPING'] + NON_STOPPED_STATES = NEED_TO_STOP_STATES + STOPPING_STATES @classmethod def load_resource(cls): return gcp.build( 'tpu', - TPU_VERSION, + constants.TPU_VERSION, credentials=None, cache_discovery=False, discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest') @@ -384,7 +889,7 @@ def wait_for_operation(cls, operation: dict, project_id: str, """Poll for TPU operation until finished.""" del project_id, zone # unused result = (cls.load_resource().projects().locations().operations().get( - name=str(operation['name'])).execute()) + name=str(operation['name'])).execute(num_retries=GCP_MAX_RETRIES)) if 'error' in result: raise Exception(result['error']) @@ -403,17 +908,18 @@ def filter( status_filters: Optional[List[str]], included_instances: Optional[List[str]] = None, excluded_instances: Optional[List[str]] = None, - ) -> List[str]: + ) -> Dict[str, Any]: path = f'projects/{project_id}/locations/{zone}' try: response = (cls.load_resource().projects().locations().nodes().list( - parent=path).execute()) + parent=path).execute(num_retries=GCP_MAX_RETRIES)) except gcp.http_error_exception() as e: # SKY: Catch HttpError when accessing unauthorized region. - # Return empty list instead of raising exception to not break - # ray down. - logger.warning(f'googleapiclient.errors.HttpError: {e.reason}') - return [] + # Return empty dict instead of raising exception to not break. + if 'is not found or access is unauthorized.' in str(e): + return {} + logger.debug(f'filter: googleapiclient.errors.HttpError: {e}') + raise instances = response.get('nodes', []) @@ -439,18 +945,23 @@ def filter_instance(instance) -> bool: return True instances = list(filter(filter_instance, instances)) - instances = [i['name'] for i in instances] + instances = {i['name']: i for i in instances} if included_instances: - instances = [i for i in instances if i in included_instances] + instances = { + k: v for k, v in instances.items() if k in included_instances + } if excluded_instances: - instances = [i for i in instances if i not in excluded_instances] - + instances = { + k: v + for k, v in instances.items() + if k not in excluded_instances + } return instances @classmethod def stop(cls, project_id: str, zone: str, instance: str) -> dict: - """Stop a TPU node.""" + """Stop a TPU VM.""" del project_id, zone # unused operation = cls.load_resource().projects().locations().nodes().stop( name=instance).execute() @@ -458,7 +969,7 @@ def stop(cls, project_id: str, zone: str, instance: str) -> dict: @classmethod def terminate(cls, project_id: str, zone: str, instance: str) -> dict: - """Terminate a TPU node.""" + """Terminate a TPU VM.""" del project_id, zone # unused operation = cls.load_resource().projects().locations().nodes().delete( name=instance).execute() @@ -507,8 +1018,308 @@ def get_vpc_name( response = cls.load_resource().projects().locations().nodes().get( name=instance).execute() vpc_link = response['networkConfig']['network'] - return vpc_link.split('/')[-1] + return selflink_to_name(vpc_link) except gcp.http_error_exception() as e: with ux_utils.print_exception_no_traceback(): raise ValueError( f'Failed to get VPC name for instance {instance}') from e + + @classmethod + @_retry_on_http_exception('unable to queue the operation') + def set_labels(cls, project_id: str, availability_zone: str, node_id: str, + labels: dict) -> bool: + while True: + # wait until the instance become ready before setting labels + # as Cloud TPU API does not allow setting labels on pending + # instances + instances = cls.filter( + project_id=project_id, + zone=availability_zone, + label_filters=None, + status_filters=cls.PENDING_STATES, + included_instances=[node_id], + ) + if not instances: + break + logger.debug(f'set_labels: Waiting for instance {node_id} to be ' + 'ready...') + time.sleep(constants.POLL_INTERVAL) + + node = (cls.load_resource().projects().locations().nodes().get( + name=node_id).execute(num_retries=GCP_CREATE_MAX_RETRIES)) + body = { + 'labels': dict(node['labels'], **labels), + } + update_mask = 'labels' + + operation = (cls.load_resource().projects().locations().nodes().patch( + name=node_id, + updateMask=update_mask, + body=body, + ).execute(num_retries=GCP_CREATE_MAX_RETRIES)) + + result = cls.wait_for_operation(operation, project_id, + availability_zone) + + return result + + @classmethod + def create_instances( + cls, + cluster_name: str, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + include_head_node: bool, + ) -> Tuple[Optional[List], List[str]]: + config = copy.deepcopy(node_config) + # removing Compute-specific default key set in config.py + config.pop('networkInterfaces', None) + + head_tag_needed = [False] * count + if include_head_node: + head_tag_needed[0] = True + + names = [] + for i in range(count): + names.append( + _generate_node_name(cluster_name, + GCPNodeType.TPU.value, + is_head=head_tag_needed[i])) + + labels = dict(config.get('labels', {}), **labels) + + config.update({ + 'labels': dict( + labels, **{ + TAG_RAY_CLUSTER_NAME: cluster_name, + TAG_SKYPILOT_CLUSTER_NAME: cluster_name + }), + }) + + if 'reservationAffinity' in config: + raise NotImplementedError( + 'TPU VMs do not support reservations yet.') + + # Allow Google Compute Engine instance templates. + # + # Config example: + # + # ... + # node_config: + # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # machineType: e2-standard-16 + # ... + # + # node_config parameters override matching template parameters, if any. + # + # https://cloud.google.com/compute/docs/instance-templates + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + operations = [] + for i, name in enumerate(names): + node_config = config.copy() + if i == 0: + node_config['labels'][TAG_SKYPILOT_HEAD_NODE] = '1' + node_config['labels'][TAG_RAY_NODE_KIND] = 'head' + else: + node_config['labels'][TAG_SKYPILOT_HEAD_NODE] = '0' + node_config['labels'][TAG_RAY_NODE_KIND] = 'worker' + try: + logger.debug('Launching GCP TPU VM ...') + request = ( + cls.load_resource().projects().locations().nodes().create( + parent=f'projects/{project_id}/locations/{zone}', + body=node_config, + nodeId=name, + )) + operation = request.execute(num_retries=0) + operations.append(operation) + except gcp.http_error_exception() as e: + # NOTE: Error example: + # { + # 'message': "Quota '...' exceeded. Limit: ... in region xx-xxxx.", # pylint: disable=line-too-long + # 'domain': 'usageLimits', + # 'reason': 'quotaExceeded' + # } + error_details = getattr(e, 'error_details', []) + logger.debug( + f'create_instances: googleapiclient.errors.HttpError: {e}') + errors = [] + if isinstance(error_details, str): + errors.append({ + 'code': 'CREATION_FAILED', + 'domain': 'create_instances', + 'message': error_details, + }) + _log_errors(errors, e, zone) + return errors, names + for detail in error_details: + # To be consistent with error messages returned by operation + # wait. + viloations = detail.get('violations', []) + if not viloations: + errors.append({ + 'code': detail.get('reason'), + 'domain': detail.get('domain'), + 'message': detail.get('message', str(e)), + }) + else: + for violation in viloations: + errors.append({ + 'code': detail.get('@type'), + 'domain': violation.get('subject'), + 'message': violation.get('description'), + }) + _log_errors(errors, e, zone) + return errors, names + errors = [] + logger.info(str(operations)) + for operation in operations: + error = operation.get('error', {}).get('details') + if error: + errors.extend(error) + if errors: + logger.debug('create_instances: Failed to create instances. ' + f'Reason: {errors}') + _log_errors(errors, operations, zone) + return errors, names + + logger.debug('Waiting GCP instances to be ready ...') + wait_start = time.time() + success = [False] * len(operations) + results: List[dict] = [{} for _ in range(len(operations))] + while time.time() - wait_start < GCP_TIMEOUT: + # Retry the wait() call until it succeeds or times out. + # This is because the wait() call is only best effort, and does not + # guarantee that the operation is done when it returns. + # Reference: https://cloud.google.com/workflows/docs/reference/googleapis/compute/v1/zoneOperations/wait # pylint: disable=line-too-long + for i, operation in enumerate(operations): + if success[i]: + continue + request = ( + cls.load_resource().projects().locations().operations().get( + name=operation['name'],)) + request.http.timeout = GCP_TIMEOUT - (time.time() - wait_start) + result = request.execute(num_retries=GCP_CREATE_MAX_RETRIES) + results[i] = result + success[i] = result['done'] + if all(success): + logger.debug(f'create_instances: Finished {results}') + break + logger.debug('create_instances: Retry waiting for TPU operations ' + f'to finish: {results}...') + else: + logger.warning('create_instances: Timeout waiting for TPU creation ' + 'operation, cancelling the operation ...') + for i, operation in enumerate(operations): + if success[i]: + continue + request = cls.load_resource().projects().locations().operations( + ).cancel(name=operation['name'],) + request.http.timeout = GCP_TIMEOUT - (time.time() - wait_start) + request.execute(num_retries=GCP_CREATE_MAX_RETRIES) + errors = [{ + 'code': 'TIMEOUT', + 'message': 'Timeout waiting for creation operation', + 'domain': 'create_instances' + }] + _log_errors(errors, None, zone) + return errors, names + + # NOTE: Error example: + # { + # 'code': 8, + # 'message': 'There is no more capacity in the zone ... + # } + errors = [] + for result in results: + error = result.get('error', {}) + if error: + errors.append(error) + if errors: + logger.debug( + 'create_instances: Failed to create instances. Reason: ' + f'{errors}') + _log_errors(errors, results, zone) + return errors, names + assert all(success), ( + 'Failed to create instances, but there is no error. ' + f'Instance status: {results}') + return None, names + + @classmethod + def start_instance(cls, node_id: str, project_id: str, zone: str) -> bool: + operation = (cls.load_resource().projects().locations().nodes().start( + name=node_id).execute()) + + # FIXME: original implementation has the 'max_polls=MAX_POLLS' option. + result = cls.wait_for_operation(operation, project_id, zone) + + return result + + @classmethod + def resize_disk(cls, project_id: str, availability_zone: str, + node_config: dict, instance_name: str) -> bool: + """Resize the disk a machine image with a different size is used. + + TODO: Implement the feature to attach persistent disks for TPU VMs. + The boot disk of TPU VMs is not resizable, and users need to add a + persistent disk to expand disk capacity. Related issue: #2387 + """ + return False + + @classmethod + def get_instance_info(cls, project_id: str, availability_zone: str, + instance_id: str) -> List[common.InstanceInfo]: + del project_id, availability_zone # unused + result = cls.load_resource().projects().locations().nodes().get( + name=instance_id).execute() + network_endpoints = result.get('networkEndpoints', [{}]) + external_ips = [] + internal_ips = [] + for endpoint in network_endpoints: + external_ips.append( + endpoint.get('accessConfig', {}).get('externalIp', None)) + internal_ips.append(endpoint.get('ipAddress', None)) + + return [ + common.InstanceInfo( + instance_id=instance_id, + internal_ip=internal_ip, + external_ip=external_ip, + tags=result.get('labels', {}), + ) for internal_ip, external_ip in zip(internal_ips, external_ips) + ] + + +class GCPNodeType(enum.Enum): + """Enum for GCP node types (compute & tpu)""" + + COMPUTE = 'compute' + TPU = 'tpu' + + +def get_node_type(node: dict) -> GCPNodeType: + """Returns node type based on the keys in ``node``. + + This is a very simple check. If we have a ``machineType`` key, + this is a Compute instance. If we don't have a ``machineType`` key, + but we have ``acceleratorType``, this is a TPU. Otherwise, it's + invalid and an exception is raised. + + This works for both node configs and API returned nodes. + """ + + if 'machineType' not in node and 'acceleratorType' not in node: + raise ValueError( + 'Invalid node. For a Compute instance, "machineType" is ' + 'required. ' + 'For a TPU instance, "acceleratorType" and no "machineType" ' + 'is required. ' + f'Got {list(node)}') + + if 'machineType' not in node and 'acceleratorType' in node: + return GCPNodeType.TPU + return GCPNodeType.COMPUTE diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 9a21f6ed9ca..68ed6665b35 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -37,7 +37,9 @@ _RAY_PORT_COMMAND = ( 'RAY_PORT=$(python -c "from sky.skylet import job_lib; ' - 'print(job_lib.get_ray_port())" 2> /dev/null || echo 6379)') + 'print(job_lib.get_ray_port())" 2> /dev/null || echo 6379);' + 'python -c "from sky.utils import common_utils; ' + 'print(common_utils.encode_payload({\'ray_port\': $RAY_PORT}))"') # Command that calls `ray status` with SkyPilot's Ray port set. RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND = ( @@ -65,7 +67,9 @@ def retry(*args, **kwargs): if retry_cnt >= _MAX_RETRY - 1: raise e sleep = backoff.current_backoff() - logger.info(f'Retrying in {sleep:.1f} seconds.') + logger.info( + f'{func.__name__}: Retrying in {sleep:.1f} seconds, ' + f'due to {e}') time.sleep(sleep) return retry @@ -86,7 +90,7 @@ def wrapper(*args, **kwargs): def _hint_worker_log_path(cluster_name: str, cluster_info: common.ClusterInfo, stage_name: str): - if len(cluster_info.instances) > 1: + if cluster_info.num_instances > 1: worker_log_path = metadata_utils.get_instance_log_dir( cluster_name, '*') / (stage_name + '.log') logger.info(f'Logs of worker nodes can be found at: {worker_log_path}') @@ -98,21 +102,22 @@ def _parallel_ssh_with_cache(func, cluster_name: str, stage_name: str, ssh_credentials: Dict[str, Any]) -> List[Any]: with futures.ThreadPoolExecutor(max_workers=32) as pool: results = [] - for instance_id, metadata in cluster_info.instances.items(): - runner = command_runner.SSHCommandRunner(metadata.get_feasible_ip(), - port=22, - **ssh_credentials) - wrapper = metadata_utils.cache_func(cluster_name, instance_id, - stage_name, digest) - if cluster_info.head_instance_id == instance_id: - # Log the head node's output to the provision.log - log_path_abs = str(provision_logging.get_log_path()) - else: - log_dir_abs = metadata_utils.get_instance_log_dir( - cluster_name, instance_id) - log_path_abs = str(log_dir_abs / (stage_name + '.log')) - results.append( - pool.submit(wrapper(func), runner, metadata, log_path_abs)) + for instance_id, metadatas in cluster_info.instances.items(): + for i, metadata in enumerate(metadatas): + cache_id = f'{instance_id}-{i}' + runner = command_runner.SSHCommandRunner( + metadata.get_feasible_ip(), port=22, **ssh_credentials) + wrapper = metadata_utils.cache_func(cluster_name, cache_id, + stage_name, digest) + if (cluster_info.head_instance_id == instance_id and i == 0): + # Log the head node's output to the provision.log + log_path_abs = str(provision_logging.get_log_path()) + else: + log_dir_abs = metadata_utils.get_instance_log_dir( + cluster_name, cache_id) + log_path_abs = str(log_dir_abs / (stage_name + '.log')) + results.append( + pool.submit(wrapper(func), runner, metadata, log_path_abs)) return [future.result() for future in results] @@ -240,20 +245,27 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], @_log_start_end @_auto_retry def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, - custom_resource: Optional[str], + custom_resource: Optional[str], ray_port: int, cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: """Start Ray on the worker nodes.""" - if len(cluster_info.instances) <= 1: + if cluster_info.num_instances <= 1: return _hint_worker_log_path(cluster_name, cluster_info, 'ray_cluster') ip_list = cluster_info.get_feasible_ips() ssh_runners = command_runner.SSHCommandRunner.make_runner_list( ip_list[1:], port_list=None, **ssh_credentials) - worker_ids = [ - instance_id for instance_id in cluster_info.instances - if instance_id != cluster_info.head_instance_id - ] + worker_instances = cluster_info.get_worker_instances() + cache_ids = [] + prev_instance_id = None + cnt = 0 + for instance in worker_instances: + if instance.instance_id != prev_instance_id: + cnt = 0 + prev_instance_id = instance.instance_id + cache_ids.append(f'{prev_instance_id}-{cnt}') + cnt += 1 + head_instance = cluster_info.get_head_instance() assert head_instance is not None, cluster_info head_private_ip = head_instance.internal_ip @@ -269,16 +281,16 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, cmd = (f'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' f'ray start --disable-usage-stats {ray_options} || exit 1;' + - _RAY_PRLIMIT + _DUMP_RAY_PORTS) + _RAY_PRLIMIT) if no_restart: # We do not use ray status to check whether ray is running, because # on worker node, if the user started their own ray cluster, ray status # will return 0, i.e., we don't know skypilot's ray cluster is running. # Instead, we check whether the raylet process is running on gcs address # that is connected to the head with the correct port. - cmd = (f'{_RAY_PORT_COMMAND}; ps aux | grep "ray/raylet/raylet" | ' + cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | ' f'grep "gcs-address={head_private_ip}:${{RAY_PORT}}" || ' - f'{{ {cmd}; }}') + f'{{ {cmd} }}') else: cmd = 'ray stop; ' + cmd @@ -298,7 +310,7 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.SSHCommandRunner, log_path=log_path_abs) results = subprocess_utils.run_in_parallel( - _setup_ray_worker, list(zip(ssh_runners, worker_ids))) + _setup_ray_worker, list(zip(ssh_runners, cache_ids))) for returncode, stdout, stderr in results: if returncode: with ux_utils.print_exception_no_traceback(): diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index cbcf5c2bcee..a6e810b1386 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -22,6 +22,7 @@ from sky.provision import instance_setup from sky.provision import logging as provision_logging from sky.provision import metadata_utils +from sky.skylet import constants from sky.utils import command_runner from sky.utils import common_utils from sky.utils import rich_utils @@ -136,16 +137,17 @@ def bulk_provision( cluster_yaml: str, is_prev_cluster_healthy: bool, log_dir: str, -) -> Optional[provision_common.ProvisionRecord]: +) -> provision_common.ProvisionRecord: """Provisions a cluster and wait until fully provisioned.""" original_config = common_utils.read_yaml(cluster_yaml) + head_node_type = original_config['head_node_type'] bootstrap_config = provision_common.ProvisionConfig( provider_config=original_config['provider'], authentication_config=original_config['auth'], docker_config=original_config.get('docker', {}), # NOTE: (might be a legacy issue) we call it # 'ray_head_default' in 'gcp-ray.yaml' - node_config=original_config['available_node_types']['ray.head.default'] + node_config=original_config['available_node_types'][head_node_type] ['node_config'], count=num_nodes, tags={}, @@ -179,7 +181,7 @@ def bulk_provision( cluster_name, terminate=terminate, provider_config=original_config['provider']) - return None + raise def teardown_cluster(cloud_name: str, cluster_name: ClusterName, @@ -319,11 +321,14 @@ def _post_provision_setup( cloud_name: str, cluster_name: ClusterName, cluster_yaml: str, provision_record: provision_common.ProvisionRecord, custom_resource: Optional[str]) -> provision_common.ClusterInfo: + config_from_yaml = common_utils.read_yaml(cluster_yaml) + provider_config = config_from_yaml.get('provider') cluster_info = provision.get_cluster_info(cloud_name, provision_record.region, - cluster_name.name_on_cloud) + cluster_name.name_on_cloud, + provider_config=provider_config) - if len(cluster_info.instances) > 1: + if cluster_info.num_instances > 1: # Only worker nodes have logs in the per-instance log directory. Head # node's log will be redirected to the main log file. per_instance_log_dir = metadata_utils.get_instance_log_dir( @@ -343,17 +348,9 @@ def _post_provision_setup( 'Could not find any head instance.') # TODO(suquark): Move wheel build here in future PRs. - config_from_yaml = common_utils.read_yaml(cluster_yaml) ip_list = cluster_info.get_feasible_ips() ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml) - # TODO(suquark): Handle TPU VMs when dealing with GCP later. - # if tpu_utils.is_tpu_vm_pod(handle.launched_resources): - # logger.info(f'{style.BRIGHT}Setting up TPU VM Pod workers...' - # f'{style.RESET_ALL}') - # RetryingVmProvisioner._tpu_pod_setup( - # None, handle.cluster_yaml, handle) - with rich_utils.safe_status( '[bold cyan]Launching - Waiting for SSH access[/]') as status: @@ -416,16 +413,19 @@ def _post_provision_setup( status.update( runtime_preparation_str.format(step=3, step_name='runtime')) full_ray_setup = True + ray_port = constants.SKY_REMOTE_RAY_PORT if not provision_record.is_instance_just_booted( head_instance.instance_id): # Check if head node Ray is alive - returncode = head_runner.run( + returncode, stdout, _ = head_runner.run( instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND, - stream_logs=False) + stream_logs=False, + require_outputs=True) if returncode: logger.info('Ray cluster on head is not up. Restarting...') else: logger.debug('Ray cluster on head is up.') + ray_port = common_utils.decode_payload(stdout)['ray_port'] full_ray_setup = bool(returncode) if full_ray_setup: @@ -450,6 +450,11 @@ def _post_provision_setup( cluster_name.name_on_cloud, no_restart=not full_ray_setup, custom_resource=custom_resource, + # Pass the ray_port to worker nodes for backward compatibility + # as in some existing clusters the ray_port is not dumped with + # instance_setup._DUMP_RAY_PORTS. We should use the ray_port + # from the head node for worker nodes. + ray_port=ray_port, cluster_info=cluster_info, ssh_credentials=ssh_credentials) diff --git a/sky/resources.py b/sky/resources.py index f83d31c5f1b..b44470869a3 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -19,7 +19,6 @@ from sky.utils import log_utils from sky.utils import resources_utils from sky.utils import schemas -from sky.utils import tpu_utils from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -515,8 +514,6 @@ def _set_accelerators( if accelerator_args is None: accelerator_args = {} use_tpu_vm = accelerator_args.get('tpu_vm', True) - if use_tpu_vm: - tpu_utils.check_gcp_cli_include_tpu_vm() if self.instance_type is not None and use_tpu_vm: if self.instance_type != 'TPU-VM': with ux_utils.print_exception_no_traceback(): diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 98f9ebf4925..136183b0973 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -44,7 +44,7 @@ # e.g., when we add new events to skylet, or we fix a bug in skylet. # # TODO(zongheng,zhanghao): make the upgrading of skylet automatic? -SKYLET_VERSION = '4' +SKYLET_VERSION = '5' SKYLET_VERSION_FILE = '~/.sky/skylet_version' # `sky spot dashboard`-related diff --git a/sky/skylet/events.py b/sky/skylet/events.py index 4287acd394b..712ede57729 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -10,14 +10,16 @@ import psutil import yaml +from sky import clouds from sky import sky_logging from sky.backends import cloud_vm_ray_backend +from sky.clouds import cloud_registry from sky.serve import serve_utils from sky.skylet import autostop_lib from sky.skylet import job_lib from sky.spot import spot_utils +from sky.utils import cluster_yaml_utils from sky.utils import common_utils -from sky.utils import remote_cluster_yaml_utils from sky.utils import ux_utils # Seconds of sleep between the processing of skylet events. @@ -104,8 +106,6 @@ class AutostopEvent(SkyletEvent): def __init__(self): super().__init__() autostop_lib.set_last_active_time_to_now() - self._ray_yaml_path = ( - remote_cluster_yaml_utils.get_cluster_yaml_absolute_path()) def _run(self): autostop_config = autostop_lib.get_autostop_config() @@ -139,10 +139,16 @@ def _stop_cluster(self, autostop_config): cloud_vm_ray_backend.CloudVmRayBackend.NAME): autostop_lib.set_autostopping_started() - config = remote_cluster_yaml_utils.load_cluster_yaml() - provider_name = remote_cluster_yaml_utils.get_provider_name(config) + config_path = os.path.abspath( + os.path.expanduser( + cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH)) + config = common_utils.read_yaml(config_path) + provider_name = cluster_yaml_utils.get_provider_name(config) + cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name) + assert cloud is not None, f'Unknown cloud: {provider_name}' - if provider_name in ('aws', 'gcp'): + if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion. + RAY_PROVISIONER_SKYPILOT_TERMINATOR): logger.info('Using new provisioner to stop the cluster.') self._stop_cluster_with_new_provisioner(autostop_config, config, provider_name) @@ -154,8 +160,7 @@ def _stop_cluster(self, autostop_config): # Even for !is_cluster_multinode, we want to call this to replace # cache_stopped_nodes. - self._replace_yaml_for_stopping(self._ray_yaml_path, - autostop_config.down) + self._replace_yaml_for_stopping(config_path, autostop_config.down) # Use environment variables to disable the ray usage collection (to # avoid overheads and potential issues with the usage) as sdk does @@ -182,7 +187,7 @@ def _stop_cluster(self, autostop_config): logger.info('Running ray up.') script = (cloud_vm_ray_backend. write_ray_up_script_with_patched_launch_hash_fn( - self._ray_yaml_path, + config_path, ray_up_kwargs={'restart_only': True})) # Passing env inherited from os.environ is technically not # needed, because we call `python