Skip to content

Commit

Permalink
New provisioner GCP (#2681)
Browse files Browse the repository at this point in the history
* init

* remove ray

* update config

* update

* update

* update

* complete bootstrapping

* add start instance

* fix

* fix

* fix

* update

* wait stopping instances

* support normal gcp tpus first

* fix gcp

* support get cluster info

* fix

* update

* wait for instance starting

* rename

* hide gcp package import

* fix

* fix

* update constants

* fix comments

* remove unused methods

* fix comments

* sync 'config' & 'constants' with upstream, Nov 16

* sync 'instace_utils' with the upstream, Nov 16

* fix typing

* parallelize provisioning

* Fix TPU node

* Fix TPU NAME env for tpu node

* implement bulk provision

* refactor selflink

* format

* reduce the sleep time for autostop

* provisioner version refactoring

* refactor

* Add logging

* avoid saving the provisioner version

* format

* format

* Fix scheduling field in config

* format

* fix public key content

* Fix provisioner version for azure

* Use ray port from head node for workers

* format

* fix ray_port

* fix smoke tests

* shorter sleep time

* refactor status refresh version

* [Provisioner] Support reserved instances in GCP (#2824)

* Support reserved instances

* remove min max count

* remove unecessary fields

* Add todo

* Add todo

* remove unused reseravation config

* Fix config.yaml tests

* format

* sync with the upstream (Dec 05, 23)

* set timeout and retries

* handle GCP creation errors

* Fix provisioning errors and improve error handling

* update blocklist for GCP

* refactor code for linting issues

* fix

* show instance status during assertion error

* Refactor error handling for failover

* adopt changes in #2854

* format

* retry for wait operation

* format

* fix typo

* fix interface

* more robust zone to region

* Fix tpu vm external IP setup

* Fix get node

* format

* revert for TPU VM pod

* Fix get_cluster_info call

* fix tab

* Fix timeout case

* remvoe \t

* GCP query statuses with new provisioner

* format

* fix import

* refactor query status

* fix stopped status

* Fix stopped status

* Add head ray start command

* Add back keys

* add workers

* Fix non stopped states

* Add more logs for autostop

* format

* increase job_docker job time

* better logging

* shorter time for recovering

* fix conflicting var

* change to V1

* fix comments

* refactor constants

* refactoring

* typo

* Fix max retry

* longer sleep time for job

* add detach setup

* revert --detach-setup

* shorter time for recovering

* more retries

* Update sky/provision/instance_setup.py

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>

* Update sky/backends/cloud_vm_ray_backend.py

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>

* format

* [Provisioner] New provisioner for GCP TPU VM (#2898)

* init

* test

* test ins_type

* fix

* format..

* wip

* remove TPU config

* fix node ips

* Fix TPU VM pod

* format

* use TPU VM as default

* Fix example for TPU VM

* format

* fix optimizer random dag

* set TPU-VM

* accelerator_args False

* backward compatibility

* add tpu filter for tests

* fix

* Fix

* fix status refresh for tpu VM pod

* Support autodown for TPU VM pod

* Allow multi-node TPU VM pod

* Allow multi-node TPU VM pod

* fix

* add execute for operation

* avoid from

* Wait for pending before set_labels

* format

* refactor constants

* Fix for API changes

* remove GCP failover handler v1

* format

* remove TPU VM pod specific codes as they have been moved to new provisioner

* Add error handling for TPU pod case

* fix

* fix multiple node calculation

* refactor tpu_utils to gcp_utils

* shorter time for recovering

* format

---------

Co-authored-by: Wei-Lin Chiang <weichiang@berkeley.edu>

* better error logging

* Fix logging for TPU VM

* Fix logging

* Add insufficientCapacity to error handler

* Avoid adding duplicated resources to blocked_resources

* Fix blocked resources

* address comment

* add comment

* Add comments

* format

* Fix num_node_ips

* format

* fix smoke test for preinstalled package

* shorter wait time for recovering

* Fix TPU VM pod stop

* format

* Update sky/provision/gcp/instance_utils.py

Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>

* update

* format

* Add debug message

* revert version for handle

* disable tpu name set

---------

Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com>
Co-authored-by: Zongheng Yang <zongheng.y@gmail.com>
Co-authored-by: Wei-Lin Chiang <weichiang@berkeley.edu>
  • Loading branch information
4 people authored Jan 1, 2024
1 parent 394ec4a commit 318553b
Show file tree
Hide file tree
Showing 33 changed files with 2,937 additions and 1,163 deletions.
5 changes: 2 additions & 3 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,8 @@ Available fields and semantics:
# zero cost) if the requested resources matches the reservation.
# Ref: https://cloud.google.com/compute/docs/instances/reservations-overview#consumption-type
specific_reservations:
# Only one element is allowed in this list, as GCP disallows multiple
# specific_reservations in a single request.
- projects/my-project/reservations/my-reservation
- projects/my-project/reservations/my-reservation1
- projects/my-project/reservations/my-reservation2
# Advanced Kubernetes configurations (optional).
kubernetes:
Expand Down
2 changes: 1 addition & 1 deletion examples/job_queue/job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ setup: |
run: |
timestamp=$(date +%s)
conda env list
for i in {1..140}; do
for i in {1..180}; do
echo "$timestamp $i"
sleep 1
done
2 changes: 1 addition & 1 deletion examples/job_queue/job_docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ setup: |
run: |
timestamp=$(date +%s)
conda env list
for i in {1..120}; do
for i in {1..180}; do
echo "$timestamp $i"
sleep 1
done
23 changes: 23 additions & 0 deletions sky/adaptors/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# pylint: disable=import-outside-toplevel
import functools
import json

googleapiclient = None
google = None
Expand Down Expand Up @@ -82,3 +83,25 @@ def credential_error_exception():
"""CredentialError exception."""
from google.auth import exceptions
return exceptions.DefaultCredentialsError


@import_package
def get_credentials(cred_type: str, credentials_field: str):
"""Get GCP credentials."""
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuthCredentials

if cred_type == 'service_account':
# If parsing the gcp_credentials failed, then the user likely made a
# mistake in copying the credentials into the config yaml.
try:
service_account_info = json.loads(credentials_field)
except json.decoder.JSONDecodeError as e:
raise RuntimeError('gcp_credentials found in cluster yaml file but '
'formatted improperly.') from e
credentials = service_account.Credentials.from_service_account_info(
service_account_info)
elif cred_type == 'credentials_token':
# Otherwise the credentials type must be credentials_token.
credentials = OAuthCredentials(credentials_field)
return credentials
180 changes: 36 additions & 144 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Util constants/functions for the backends."""
from datetime import datetime
import enum
import json
import os
import pathlib
import pprint
Expand Down Expand Up @@ -39,19 +38,19 @@
from sky import skypilot_config
from sky import status_lib
from sky.backends import onprem_utils
from sky.clouds import cloud_registry
from sky.clouds.utils import gcp_utils
from sky.provision import instance_setup
from sky.skylet import constants
from sky.skylet import log_lib
from sky.usage import usage_lib
from sky.utils import cluster_yaml_utils
from sky.utils import command_runner
from sky.utils import common_utils
from sky.utils import controller_utils
from sky.utils import env_options
from sky.utils import remote_cluster_yaml_utils
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import tpu_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -938,11 +937,6 @@ def write_cluster_config(
r for r, available_resources in reservations.items()
if r in specific_reservations and available_resources > 0
]
available_specific_reservations = sum(
available_resources for r, available_resources in reservations.items()
if r in specific_reservations)
num_specific_reserved_workers = max(
min(available_specific_reservations - 1, num_nodes - 1), 0)

assert cluster_name is not None
credentials = sky_check.get_cloud_credential_file_mounts()
Expand Down Expand Up @@ -1033,7 +1027,6 @@ def write_cluster_config(
# GCP only:
'gcp_project_id': gcp_project_id,
'specific_reservations': filtered_specific_reservations,
'num_specific_reserved_workers': num_specific_reserved_workers,

# Conda setup
'conda_installation_commands':
Expand All @@ -1054,7 +1047,7 @@ def write_cluster_config(
'sky_local_path': str(local_wheel_path),
# Add yaml file path to the template variables.
'sky_ray_yaml_remote_path':
remote_cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH,
cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH,
'sky_ray_yaml_local_path':
tmp_yaml_path
if not isinstance(cloud, clouds.Local) else yaml_path,
Expand Down Expand Up @@ -1112,7 +1105,7 @@ def write_cluster_config(
usage_lib.messages.usage.update_ray_yaml(yaml_path)

# For TPU nodes. TPU VMs do not need TPU_NAME.
if tpu_utils.is_tpu(to_provision) and not tpu_utils.is_tpu_vm(to_provision):
if gcp_utils.is_tpu(to_provision) and not gcp_utils.is_tpu_vm(to_provision):
tpu_name = resources_vars.get('tpu_name')
if tpu_name is None:
tpu_name = cluster_name
Expand Down Expand Up @@ -1533,20 +1526,22 @@ def _query_head_ip_with_retries(cluster_yaml: str,


@timeline.event
def get_node_ips(cluster_yaml: str,
expected_num_nodes: int,
handle: Optional[
'cloud_vm_ray_backend.CloudVmRayResourceHandle'] = None,
head_ip_max_attempts: int = 1,
worker_ip_max_attempts: int = 1,
get_internal_ips: bool = False) -> List[str]:
def get_node_ips(
cluster_yaml: str,
expected_num_nodes: int,
# TODO: remove this argument once we remove the legacy on-prem
# support.
handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle',
head_ip_max_attempts: int = 1,
worker_ip_max_attempts: int = 1,
get_internal_ips: bool = False) -> List[str]:
"""Returns the IPs of all nodes in the cluster, with head node at front.
Args:
cluster_yaml: Path to the cluster yaml.
expected_num_nodes: Expected number of nodes in the cluster.
handle: Cloud VM Ray resource handle. It is only required for TPU VM or
on-prem clusters.
handle: Cloud VM Ray resource handle. It is only required for on-prem
clusters.
head_ip_max_attempts: Max attempts to get head ip.
worker_ip_max_attempts: Max attempts to get worker ips.
get_internal_ips: Whether to get internal IPs. When False, it is still
Expand All @@ -1557,34 +1552,21 @@ def get_node_ips(cluster_yaml: str,
exceptions.FetchIPError: if we failed to get the IPs. e.reason is
HEAD or WORKER.
"""
# When ray up launches TPU VM Pod, Pod workers (except for the head)
# won't be connected to Ray cluster. Thus "ray get-worker-ips"
# won't work and we need to query the node IPs with gcloud as
# implmented in _get_tpu_vm_pod_ips.
ray_config = common_utils.read_yaml(cluster_yaml)
# Use the new provisioner for AWS.
if '.aws' in ray_config['provider']['module']:
provider_name = cluster_yaml_utils.get_provider_name(ray_config)
cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name)
assert cloud is not None, provider_name

if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT:
metadata = provision_lib.get_cluster_info(
'aws', ray_config['provider']['region'], ray_config['cluster_name'])
provider_name, ray_config['provider']['region'],
ray_config['cluster_name'], ray_config['provider'])
if len(metadata.instances) < expected_num_nodes:
# Simulate the exception when Ray head node is not up.
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
return metadata.get_feasible_ips(get_internal_ips)

use_tpu_vm = ray_config['provider'].get('_has_tpus', False)
if use_tpu_vm:
assert expected_num_nodes == 1, (
'TPU VM only supports single node for now.')
assert handle is not None, 'handle is required for TPU VM.'
try:
ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
except exceptions.CommandError as e:
raise exceptions.FetchIPError(
exceptions.FetchIPError.Reason.HEAD) from e
if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources):
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
return ips

if get_internal_ips:
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
ray_config['provider']['use_internal_ips'] = True
Expand Down Expand Up @@ -1658,74 +1640,6 @@ def get_node_ips(cluster_yaml: str,
return head_ip_list + worker_ips


@timeline.event
def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
get_internal_ips: bool = False) -> List[str]:
"""Returns the IPs of all TPU VM Pod workers using gcloud."""

cluster_name = ray_config['cluster_name']
zone = ray_config['provider']['availability_zone']
query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
f'"(labels.ray-cluster-name={cluster_name})" '
f'--zone={zone} --format="value(name)"')
returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
subprocess_utils.handle_returncode(
returncode,
query_cmd,
'Failed to run gcloud to get TPU VM IDs.',
stderr=stdout + stderr)
if len(stdout) == 0:
logger.debug('No TPU VMs found with cluster name '
f'{cluster_name} in zone {zone}.')
if len(stdout.splitlines()) > 1:
# Rare case, this could mean resource leakage. Hint user.
logger.warning('Found more than one TPU VM/Pod with the same cluster '
f'name {cluster_name} in zone {zone}.')

all_ips = []
for tpu_id in stdout.splitlines():
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
f' --zone {zone} --format=json')
returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
os.devnull,
shell=True,
stream_logs=False,
require_outputs=True)
subprocess_utils.handle_returncode(
returncode,
tpuvm_cmd,
'Failed to run gcloud tpu-vm describe.',
stderr=stdout + stderr)

tpuvm_json = json.loads(stdout)
if tpuvm_json['state'] != 'READY':
# May be a leaked preempted resource, or terminated by user in the
# console, or still in the process of being created.
ux_utils.console_newline()
logger.debug(f'TPU VM {tpu_id} is in {tpuvm_json["state"]} '
'state. Skipping IP query... '
'Hint: make sure it is not leaked.')
continue

ips = []
for endpoint in tpuvm_json['networkEndpoints']:
# Note: if TPU VM is being preempted, its IP field may not exist.
# We use get() to avoid KeyError.
if get_internal_ips:
ip = endpoint.get('ipAddress', None)
else:
ip = endpoint['accessConfig'].get('externalIp', None)
if ip is not None:
ips.append(ip)
all_ips.extend(ips)

return all_ips


def check_network_connection():
# Tolerate 3 retries as it is observed that connections can fail.
adapter = adapters.HTTPAdapter(max_retries=retry_lib.Retry(total=3))
Expand Down Expand Up @@ -1839,15 +1753,12 @@ def _query_cluster_status_via_cloud_api(
# correctly yet.
ray_config = common_utils.read_yaml(handle.cluster_yaml)
provider_config = ray_config['provider']
region = provider_config.get('region') or provider_config.get('location')
zone = ray_config['provider'].get('availability_zone')
kwargs = {}
if isinstance(handle.launched_resources.cloud, clouds.GCP):
kwargs['use_tpu_vm'] = ray_config['provider'].get('_has_tpus', False)

# Query the cloud provider.
# TODO(suquark): move implementations of more clouds here
if isinstance(handle.launched_resources.cloud, clouds.AWS):
cloud = handle.launched_resources.cloud
assert cloud is not None, handle
if cloud.STATUS_VERSION >= clouds.StatusVersion.SKYPILOT:
cloud_name = repr(handle.launched_resources.cloud)
try:
node_status_dict = provision_lib.query_instances(
Expand All @@ -1864,30 +1775,12 @@ def _query_cluster_status_via_cloud_api(
f'status: {common_utils.format_exception(e, use_bracket=True)}'
)
else:
node_statuses = handle.launched_resources.cloud.query_status(
region = provider_config.get('region') or provider_config.get(
'location')
zone = ray_config['provider'].get('availability_zone')
node_statuses = cloud.query_status(
cluster_name_on_cloud,
tag_filter_for_cluster(cluster_name_on_cloud), region, zone,
**kwargs)
# GCP does not clean up preempted TPU VMs. We remove it ourselves.
# TODO(wei-lin): handle multi-node cases.
# TODO(zhwu): this should be moved into the GCP class, after we refactor
# the cluster termination, as the preempted TPU VM should always be
# removed.
if kwargs.get('use_tpu_vm', False) and len(node_statuses) == 0:
logger.debug(
f'Terminating preempted TPU VM cluster {cluster_name_in_hint}')
backend = backends.CloudVmRayBackend()
# Do not use refresh cluster status during teardown, as that will
# cause infinite recursion by calling cluster status refresh
# again.

# Post teardown cleanup be done later in this function, which will
# remove the cluster entry from the status table & the ssh config file.
backend.teardown_no_lock(handle,
terminate=True,
purge=False,
post_teardown_cleanup=False,
refresh_cluster_status=False)
tag_filter_for_cluster(cluster_name_on_cloud), region, zone)
return node_statuses


Expand Down Expand Up @@ -2036,9 +1929,8 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
# in the worst case we time out in the `ray status` SSH command
# below.
external_ips = handle.cached_external_ips
# This happens to a stopped TPU VM as we use gcloud to query the IP.
# Or user interrupt the `sky launch` process before the first time
# resources handle is written back to local database.
# This happens when user interrupt the `sky launch` process before
# the first time resources handle is written back to local database.
# This is helpful when user interrupt after the provision is done
# and before the skylet is restarted. After #2304 is merged, this
# helps keep the cluster status to INIT after `sky status -r`, so
Expand Down Expand Up @@ -2076,13 +1968,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
f'-- stdout --\n{output}\n-- stderr --\n{stderr}')

ready_head, ready_workers = _count_healthy_nodes_from_ray(output)

if ready_head + ready_workers == handle.launched_nodes:
total_nodes = handle.launched_nodes * handle.num_ips_per_node
if ready_head + ready_workers == total_nodes:
return True
raise RuntimeError(
f'Refreshing status ({cluster_name!r}): ray status not showing '
f'all nodes ({ready_head + ready_workers}/'
f'{handle.launched_nodes}); output: {output}; stderr: {stderr}')
f'{total_nodes}); output: {output}; stderr: {stderr}')
except exceptions.FetchIPError:
logger.debug(
f'Refreshing status ({cluster_name!r}) failed to get IPs.')
Expand Down
Loading

0 comments on commit 318553b

Please sign in to comment.