Skip to content

Commit

Permalink
Rewrite the logic of updating VPN IPs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Conless committed Nov 4, 2024
1 parent 3178cc0 commit 3f875d5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
32 changes: 16 additions & 16 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,8 +2226,8 @@ def update_cluster_ips(
max_attempts: int = 1,
internal_ips: Optional[List[Optional[str]]] = None,
external_ips: Optional[List[Optional[str]]] = None,
cluster_info: Optional[provision_common.ClusterInfo] = None,
vpn_config: Optional[vpn_utils.VPNConfig] = None) -> None:
cluster_info: Optional[provision_common.ClusterInfo] = None
) -> None:
"""Updates the cluster IPs cached in the handle.
We cache the cluster IPs in the handle to avoid having to retrieve
Expand Down Expand Up @@ -2259,12 +2259,10 @@ def update_cluster_ips(
if cluster_info is not None:
self.cached_cluster_info = cluster_info
# Update cluster config by private IPs (if available).
if vpn_config is not None:
self._setup_vpn(vpn_config)
if self.vpn_config is not None:
vpn_utils.rewrite_cluster_info_by_vpn(self.cached_cluster_info,
self.cluster_name,
vpn_config)
self.vpn_config = vpn_config.to_backend_config()
self.vpn_config)
cluster_feasible_ips = self.cached_cluster_info.get_feasible_ips()
cluster_internal_ips = self.cached_cluster_info.get_feasible_ips(
force_internal_ips=True)
Expand Down Expand Up @@ -2462,15 +2460,14 @@ def setup_docker_user(self, cluster_config_file: str):
cluster_config_file)
self.docker_user = docker_user

def _setup_vpn(self, vpn_config: vpn_utils.VPNConfig) -> None:
assert self.cached_cluster_info is not None, (
'cluster_info should be set before setting up VPN.')
def setup_vpn(self, vpn_config: vpn_utils.VPNConfig) -> None:
self.vpn_config = vpn_config.to_backend_config()
runners = self.get_command_runners()

def _run_setup_commands(id_runner):
node_id, runner = id_runner
command = vpn_config.get_setup_command(self.cluster_name, node_id)
returncode, stdout, stderr = runner.run([command],
returncode, stdout, stderr = runner.run(command,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
Expand Down Expand Up @@ -2915,10 +2912,14 @@ def _provision(
# from cluster_info.
handle.docker_user = cluster_info.docker_user
handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS,
cluster_info=cluster_info,
vpn_config=task.vpn_config)
cluster_info=cluster_info)
handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS)

# If VPN is used, we need to reconfigure cluster IPs.
if task.vpn_config is not None:
handle.setup_vpn(task.vpn_config)
handle.update_cluster_ips(cluster_info=cluster_info)

# Update launched resources.
handle.launched_resources = handle.launched_resources.copy(
region=provision_record.region, zone=provision_record.zone)
Expand Down Expand Up @@ -4155,11 +4156,10 @@ def post_teardown_cleanup(self,
raise
if terminate and handle.vpn_config is not None:
# Delete the VPN records when terminating the cluster.
vpn_config = vpn_utils.VPNConfig.from_backend_config(
handle.vpn_config)
if handle.cached_cluster_info is not None:
for node_id in range(handle.cached_cluster_info.num_instances):
vpn_config.remove_node(handle.cluster_name, node_id)
vpn_utils.remove_nodes_from_vpn(handle.cached_cluster_info,
handle.cluster_name,
handle.vpn_config)

# The cluster file must exist because the cluster_yaml will only
# be removed after the cluster entry in the database is removed.
Expand Down
30 changes: 20 additions & 10 deletions sky/utils/vpn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def use_vpn_ip(self) -> bool:
return self._use_vpn_ip

@staticmethod
def from_backend_config(config: Dict[str, Any]) -> 'VPNConfig':
def from_backend_config(**kwargs) -> 'VPNConfig':
"""Create a VPN configuration from the backend configuration."""
vpn_type = config.pop('type')
vpn_type = kwargs.pop('vpn_type')
if vpn_type == 'tailscale':
return TailscaleConfig.from_backend_config(config)
return TailscaleConfig.from_backend_config(**kwargs)
with ux_utils.print_exception_no_traceback():
raise ValueError('Unsupported VPN type. Please check the backend '
'configuration.')
Expand Down Expand Up @@ -213,17 +213,16 @@ def to_yaml_config(self) -> Dict[str, Any]:
return {'tailscale': True}

@staticmethod
def from_backend_config(config: Dict[str, Any]) -> 'TailscaleConfig':
assert config.get('auth_key') is not None and config.get(
'api_key') is not None and config.get('tailnet') is not None, (
def from_backend_config(**kwargs) -> 'TailscaleConfig':
assert kwargs.get('auth_key') is not None and kwargs.get(
'api_key') is not None and kwargs.get('tailnet') is not None, (
'Tailscale VPN configuration is missing required '
'fields. Please check the backend configuration.')

return TailscaleConfig(**config)
return TailscaleConfig(**kwargs)

def to_backend_config(self) -> Dict[str, Any]:
return {
'type': self._TYPE,
'vpn_type': self._TYPE,
'auth_key': self._auth_key,
'api_key': self._api_key,
'tailnet': self._tailnet,
Expand All @@ -234,16 +233,27 @@ def to_backend_config(self) -> Dict[str, Any]:
def rewrite_cluster_info_by_vpn(
cluster_info: common.ClusterInfo,
cluster_name: str,
vpn_config: VPNConfig,
vpn_config_dict: Dict[str, Any],
) -> None:
"""Rewrite the cluster info with the VPN configuration."""
vpn_config = VPNConfig.from_backend_config(**vpn_config_dict)
if not vpn_config.use_vpn_ip:
return
instance_list = cluster_info.get_instances()
for (node_id, instance) in enumerate(instance_list):
instance.external_ip = vpn_config.get_private_ip(cluster_name, node_id)


def remove_nodes_from_vpn(
    cluster_info: common.ClusterInfo,
    cluster_name: str,
    vpn_config_dict: Dict[str, Any],
) -> None:
    """Remove every node of a cluster from the VPN.

    The VPN configuration object is rebuilt from its backend dict form via
    ``VPNConfig.from_backend_config`` and ``remove_node`` is then invoked
    once per node index in ``range(cluster_info.num_instances)``.

    Args:
        cluster_info: Cluster metadata; only ``num_instances`` is read.
        cluster_name: Name of the cluster, forwarded to ``remove_node``.
        vpn_config_dict: Backend-config dict, unpacked as keyword
            arguments into ``VPNConfig.from_backend_config``.
    """
    # Unpacking with ** hands the factory its own kwargs dict, so the
    # caller's dict object is not the one the factory receives.
    vpn = VPNConfig.from_backend_config(**vpn_config_dict)
    node_count = cluster_info.num_instances
    for index in range(node_count):
        vpn.remove_node(cluster_name, index)


def check_vpn_unchanged(new_config: Optional[VPNConfig],
old_config_dict: Optional[Dict[str, Any]],
expose_service: Optional[bool] = None) -> Optional[str]:
Expand Down

0 comments on commit 3f875d5

Please sign in to comment.