diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 54f160afcee..e24184ceb69 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2226,8 +2226,8 @@ def update_cluster_ips( max_attempts: int = 1, internal_ips: Optional[List[Optional[str]]] = None, external_ips: Optional[List[Optional[str]]] = None, - cluster_info: Optional[provision_common.ClusterInfo] = None, - vpn_config: Optional[vpn_utils.VPNConfig] = None) -> None: + cluster_info: Optional[provision_common.ClusterInfo] = None + ) -> None: """Updates the cluster IPs cached in the handle. We cache the cluster IPs in the handle to avoid having to retrieve @@ -2259,12 +2259,10 @@ def update_cluster_ips( if cluster_info is not None: self.cached_cluster_info = cluster_info # Update cluster config by private IPs (if available). - if vpn_config is not None: - self._setup_vpn(vpn_config) + if self.vpn_config is not None: vpn_utils.rewrite_cluster_info_by_vpn(self.cached_cluster_info, self.cluster_name, - vpn_config) - self.vpn_config = vpn_config.to_backend_config() + self.vpn_config) cluster_feasible_ips = self.cached_cluster_info.get_feasible_ips() cluster_internal_ips = self.cached_cluster_info.get_feasible_ips( force_internal_ips=True) @@ -2462,15 +2460,14 @@ def setup_docker_user(self, cluster_config_file: str): cluster_config_file) self.docker_user = docker_user - def _setup_vpn(self, vpn_config: vpn_utils.VPNConfig) -> None: - assert self.cached_cluster_info is not None, ( - 'cluster_info should be set before setting up VPN.') + def setup_vpn(self, vpn_config: vpn_utils.VPNConfig) -> None: + self.vpn_config = vpn_config.to_backend_config() runners = self.get_command_runners() def _run_setup_commands(id_runner): node_id, runner = id_runner command = vpn_config.get_setup_command(self.cluster_name, node_id) - returncode, stdout, stderr = runner.run([command], + returncode, stdout, stderr = runner.run(command, require_outputs=True, stream_logs=False) subprocess_utils.handle_returncode( @@ -2915,10 +2912,14 @@ def _provision( # from cluster_info. handle.docker_user = cluster_info.docker_user handle.update_cluster_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS, - cluster_info=cluster_info, - vpn_config=task.vpn_config) + cluster_info=cluster_info) handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS) + # If VPN is used, we need to reconfigure cluster IPs. + if task.vpn_config is not None: + handle.setup_vpn(task.vpn_config) + handle.update_cluster_ips(cluster_info=cluster_info) + # Update launched resources. handle.launched_resources = handle.launched_resources.copy( region=provision_record.region, zone=provision_record.zone) @@ -4155,11 +4156,10 @@ def post_teardown_cleanup(self, raise if terminate and handle.vpn_config is not None: # Delete the VPN records when terminating the cluster. - vpn_config = vpn_utils.VPNConfig.from_backend_config( - handle.vpn_config) if handle.cached_cluster_info is not None: - for node_id in range(handle.cached_cluster_info.num_instances): - vpn_config.remove_node(handle.cluster_name, node_id) + vpn_utils.remove_nodes_from_vpn(handle.cached_cluster_info, + handle.cluster_name, + handle.vpn_config) # The cluster file must exist because the cluster_yaml will only # be removed after the cluster entry in the database is removed. diff --git a/sky/utils/vpn_utils.py b/sky/utils/vpn_utils.py index 669cf692a01..baac13cbce1 100644 --- a/sky/utils/vpn_utils.py +++ b/sky/utils/vpn_utils.py @@ -74,11 +74,11 @@ def use_vpn_ip(self) -> bool: return self._use_vpn_ip @staticmethod - def from_backend_config(config: Dict[str, Any]) -> 'VPNConfig': + def from_backend_config(**kwargs) -> 'VPNConfig': """Create a VPN configuration from the backend configuration.""" - vpn_type = config.pop('type') + vpn_type = kwargs.pop('vpn_type') if vpn_type == 'tailscale': - return TailscaleConfig.from_backend_config(config) + return TailscaleConfig.from_backend_config(**kwargs) with ux_utils.print_exception_no_traceback(): raise ValueError('Unsupported VPN type. Please check the backend ' 'configuration.') @@ -213,17 +213,16 @@ def to_yaml_config(self) -> Dict[str, Any]: return {'tailscale': True} @staticmethod - def from_backend_config(config: Dict[str, Any]) -> 'TailscaleConfig': - assert config.get('auth_key') is not None and config.get( - 'api_key') is not None and config.get('tailnet') is not None, ( + def from_backend_config(**kwargs) -> 'TailscaleConfig': + assert kwargs.get('auth_key') is not None and kwargs.get( + 'api_key') is not None and kwargs.get('tailnet') is not None, ( 'Tailscale VPN configuration is missing required ' 'fields. Please check the backend configuration.') - - return TailscaleConfig(**config) + return TailscaleConfig(**kwargs) def to_backend_config(self) -> Dict[str, Any]: return { - 'type': self._TYPE, + 'vpn_type': self._TYPE, 'auth_key': self._auth_key, 'api_key': self._api_key, 'tailnet': self._tailnet, @@ -234,9 +233,10 @@ def to_backend_config(self) -> Dict[str, Any]: def rewrite_cluster_info_by_vpn( cluster_info: common.ClusterInfo, cluster_name: str, - vpn_config: VPNConfig, + vpn_config_dict: Dict[str, Any], ) -> None: """Rewrite the cluster info with the VPN configuration.""" + vpn_config = VPNConfig.from_backend_config(**vpn_config_dict) if not vpn_config.use_vpn_ip: return instance_list = cluster_info.get_instances() @@ -244,6 +244,16 @@ def rewrite_cluster_info_by_vpn( instance.external_ip = vpn_config.get_private_ip(cluster_name, node_id) +def remove_nodes_from_vpn( + cluster_info: common.ClusterInfo, + cluster_name: str, + vpn_config_dict: Dict[str, Any], +) -> None: + vpn_config = VPNConfig.from_backend_config(**vpn_config_dict) + for node_id in range(cluster_info.num_instances): + vpn_config.remove_node(cluster_name, node_id) + + def check_vpn_unchanged(new_config: Optional[VPNConfig], old_config_dict: Optional[Dict[str, Any]], expose_service: Optional[bool] = None) -> Optional[str]: