From e6a3b830fb2a12871815773af6171d42e0416e89 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Sat, 28 Sep 2024 23:04:52 -0700 Subject: [PATCH] [k8s] Fix incluster auth after multi-context support (#4014) * Make incluster auth work * lint * rename * rename * pop allowed_contexts from config * lint * comments * comments * lint --- sky/authentication.py | 5 ++ sky/clouds/kubernetes.py | 42 +++++++++++--- sky/provision/kubernetes/config.py | 9 +-- sky/provision/kubernetes/instance.py | 13 +++-- sky/provision/kubernetes/network_utils.py | 15 ++--- sky/provision/kubernetes/utils.py | 69 +++++++++++++++++------ sky/utils/command_runner.py | 2 +- sky/utils/command_runner.pyi | 2 +- sky/utils/controller_utils.py | 8 +++ 9 files changed, 122 insertions(+), 43 deletions(-) diff --git a/sky/authentication.py b/sky/authentication.py index 67b4bcd576f..eb51aad02ad 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -380,6 +380,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name context = config['provider'].get( 'context', kubernetes_utils.get_current_kube_config_context_name()) + if context == kubernetes_utils.IN_CLUSTER_REGION: + # If the context is set to IN_CLUSTER_REGION, we are running in a pod + # with in-cluster configuration. We need to set the context to None + # to use the mounted service account. + context = None namespace = config['provider'].get( 'namespace', kubernetes_utils.get_kube_config_context_namespace(context)) diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 2c1e753bccf..da85246e9ea 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -129,11 +129,24 @@ def _log_skipped_contexts_once(cls, skipped_contexts: Tuple[str, 'Ignoring these contexts.') @classmethod - def _existing_allowed_contexts(cls) -> List[str]: - """Get existing allowed contexts.""" + def _existing_allowed_contexts(cls) -> List[Optional[str]]: + """Get existing allowed contexts. + + If None is returned in the list, it means that we are running in a pod + with in-cluster auth. In this case, we specify None context, which will + use the service account mounted in the pod. + """ all_contexts = kubernetes_utils.get_all_kube_config_context_names() - if all_contexts is None: + if len(all_contexts) == 0: return [] + if all_contexts == [None]: + # If only one context is found and it is None, we are running in a + # pod with in-cluster auth. In this case, we allow it to be used + # without checking against allowed_contexts. + # TODO(romilb): We may want check in-cluster auth against + # allowed_contexts in the future by adding a special context name + # for in-cluster auth. + return [None] all_contexts = set(all_contexts) allowed_contexts = skypilot_config.get_nested( @@ -164,7 +177,15 @@ def regions_with_offering(cls, instance_type: Optional[str], del accelerators, zone, use_spot # unused existing_contexts = cls._existing_allowed_contexts() - regions = [clouds.Region(context) for context in existing_contexts] + regions = [] + for context in existing_contexts: + if context is None: + # If running in-cluster, we allow the region to be set to the + # singleton region since there is no context name available. + regions.append(clouds.Region( + kubernetes_utils.IN_CLUSTER_REGION)) + else: + regions.append(clouds.Region(context)) if region is not None: regions = [r for r in regions if r.name == region] @@ -541,13 +562,20 @@ def instance_type_exists(self, instance_type: str) -> bool: def validate_region_zone(self, region: Optional[str], zone: Optional[str]): if region == self._LEGACY_SINGLETON_REGION: # For backward compatibility, we allow the region to be set to the - # legacy singletonton region. + # legacy singleton region. # TODO: Remove this after 0.9.0. return region, zone + if region == kubernetes_utils.IN_CLUSTER_REGION: + # If running incluster, we set region to IN_CLUSTER_REGION + # since there is no context name available. + return region, zone + all_contexts = kubernetes_utils.get_all_kube_config_context_names() - if all_contexts is None: - all_contexts = [] + if all_contexts == [None]: + # If [None] context is returned, use the singleton region since we + # are running in a pod with in-cluster auth. + all_contexts = [kubernetes_utils.IN_CLUSTER_REGION] if region not in all_contexts: raise ValueError( f'Context {region} not found in kubeconfig. Kubernetes only ' diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py index e377f3029b8..370430720f0 100644 --- a/sky/provision/kubernetes/config.py +++ b/sky/provision/kubernetes/config.py @@ -247,7 +247,8 @@ def _get_resource(container_resources: Dict[str, Any], resource_name: str, def _configure_autoscaler_service_account( - namespace: str, context: str, provider_config: Dict[str, Any]) -> None: + namespace: str, context: Optional[str], + provider_config: Dict[str, Any]) -> None: account_field = 'autoscaler_service_account' if account_field not in provider_config: logger.info('_configure_autoscaler_service_account: ' @@ -281,7 +282,7 @@ def _configure_autoscaler_service_account( f'{created_msg(account_field, name)}') -def _configure_autoscaler_role(namespace: str, context: str, +def _configure_autoscaler_role(namespace: str, context: Optional[str], provider_config: Dict[str, Any], role_field: str) -> None: """ Reads the role from the provider config, creates if it does not exist. @@ -330,7 +331,7 @@ def _configure_autoscaler_role(namespace: str, context: str, def _configure_autoscaler_role_binding( namespace: str, - context: str, + context: Optional[str], provider_config: Dict[str, Any], binding_field: str, override_name: Optional[str] = None, @@ -620,7 +621,7 @@ def _configure_fuse_mounting(provider_config: Dict[str, Any]) -> None: f'in namespace {fuse_device_manager_namespace!r}') -def _configure_services(namespace: str, context: str, +def _configure_services(namespace: str, context: Optional[str], provider_config: Dict[str, Any]) -> None: service_field = 'services' if service_field not in provider_config: diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index f9ee75e466b..8da13d5ad0f 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -302,7 +302,8 @@ def _check_init_containers(pod): time.sleep(1) -def _set_env_vars_in_pods(namespace: str, context: str, new_pods: List): +def _set_env_vars_in_pods(namespace: str, context: Optional[str], + new_pods: List): """Setting environment variables in pods. Once all containers are ready, we can exec into them and set env vars. @@ -330,7 +331,7 @@ def _set_env_vars_in_pods(namespace: str, context: str, new_pods: List): new_pod.metadata.name, rc, stdout) -def _check_user_privilege(namespace: str, context: str, +def _check_user_privilege(namespace: str, context: Optional[str], new_nodes: List) -> None: # Checks if the default user has sufficient privilege to set up # the kubernetes instance pod. @@ -366,7 +367,8 @@ def _check_user_privilege(namespace: str, context: str, 'from the image.') -def _setup_ssh_in_pods(namespace: str, context: str, new_nodes: List) -> None: +def _setup_ssh_in_pods(namespace: str, context: Optional[str], + new_nodes: List) -> None: # Setting up ssh for the pod instance. This is already setup for # the jump pod so it does not need to be run for it. set_k8s_ssh_cmd = ( @@ -410,7 +412,7 @@ def _setup_ssh_in_pods(namespace: str, context: str, new_nodes: List) -> None: logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}') -def _label_pod(namespace: str, context: str, pod_name: str, +def _label_pod(namespace: str, context: Optional[str], pod_name: str, label: Dict[str, str]) -> None: """Label a pod.""" kubernetes.core_api(context).patch_namespaced_pod( @@ -647,7 +649,8 @@ def stop_instances( raise NotImplementedError() -def _terminate_node(namespace: str, context: str, pod_name: str) -> None: +def _terminate_node(namespace: str, context: Optional[str], + pod_name: str) -> None: """Terminate a pod.""" logger.debug('terminate_instances: calling delete_namespaced_pod') try: diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py index a1d919a6766..b16482e5072 100644 --- a/sky/provision/kubernetes/network_utils.py +++ b/sky/provision/kubernetes/network_utils.py @@ -132,7 +132,7 @@ def fill_ingress_template(namespace: str, service_details: List[Tuple[str, int, def create_or_replace_namespaced_ingress( - namespace: str, context: str, ingress_name: str, + namespace: str, context: Optional[str], ingress_name: str, ingress_spec: Dict[str, Union[str, int]]) -> None: """Creates an ingress resource for the specified service.""" networking_api = kubernetes.networking_api(context) @@ -156,7 +156,7 @@ def create_or_replace_namespaced_ingress( _request_timeout=kubernetes.API_TIMEOUT) -def delete_namespaced_ingress(namespace: str, context: str, +def delete_namespaced_ingress(namespace: str, context: Optional[str], ingress_name: str) -> None: """Deletes an ingress resource.""" networking_api = kubernetes.networking_api(context) @@ -171,7 +171,7 @@ def delete_namespaced_ingress(namespace: str, context: str, def create_or_replace_namespaced_service( - namespace: str, context: str, service_name: str, + namespace: str, context: Optional[str], service_name: str, service_spec: Dict[str, Union[str, int]]) -> None: """Creates a service resource for the specified service.""" core_api = kubernetes.core_api(context) @@ -208,7 +208,7 @@ def delete_namespaced_service(namespace: str, service_name: str) -> None: raise e -def ingress_controller_exists(context: str, +def ingress_controller_exists(context: Optional[str], ingress_class_name: str = 'nginx') -> bool: """Checks if an ingress controller exists in the cluster.""" networking_api = kubernetes.networking_api(context) @@ -220,7 +220,7 @@ def ingress_controller_exists(context: str, def get_ingress_external_ip_and_ports( - context: str, + context: Optional[str], namespace: str = 'ingress-nginx' ) -> Tuple[Optional[str], Optional[Tuple[int, int]]]: """Returns external ip and ports for the ingress controller.""" @@ -258,7 +258,7 @@ def get_ingress_external_ip_and_ports( return external_ip, None -def get_loadbalancer_ip(context: str, +def get_loadbalancer_ip(context: Optional[str], namespace: str, service_name: str, timeout: int = 0) -> Optional[str]: @@ -284,7 +284,8 @@ def get_loadbalancer_ip(context: str, return ip -def get_pod_ip(context: str, namespace: str, pod_name: str) -> Optional[str]: +def get_pod_ip(context: Optional[str], namespace: str, + pod_name: str) -> Optional[str]: """Returns the IP address of the pod.""" core_api = kubernetes.core_api(context) pod = core_api.read_namespaced_pod(pod_name, diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index f31652030a5..0498cc7f59f 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -33,6 +33,7 @@ # TODO(romilb): Move constants to constants.py DEFAULT_NAMESPACE = 'default' +IN_CLUSTER_REGION = 'in-cluster' DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account' @@ -310,7 +311,7 @@ class KarpenterLabelFormatter(SkyPilotLabelFormatter): @functools.lru_cache() def detect_gpu_label_formatter( - context: str + context: Optional[str] ) -> Tuple[Optional[GPULabelFormatter], Dict[str, List[Tuple[str, str]]]]: """Detects the GPU label formatter for the Kubernetes cluster @@ -342,7 +343,7 @@ def detect_gpu_label_formatter( @functools.lru_cache(maxsize=10) -def detect_gpu_resource(context: str) -> Tuple[bool, Set[str]]: +def detect_gpu_resource(context: Optional[str]) -> Tuple[bool, Set[str]]: """Checks if the Kubernetes cluster has nvidia.com/gpu resource. If nvidia.com/gpu resource is missing, that typically means that the @@ -402,7 +403,7 @@ def get_all_pods_in_kubernetes_cluster( return pods -def check_instance_fits(context: str, +def check_instance_fits(context: Optional[str], instance: str) -> Tuple[bool, Optional[str]]: """Checks if the instance fits on the Kubernetes cluster. @@ -488,7 +489,7 @@ def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType', return fits, reason -def get_gpu_label_key_value(context: str, +def get_gpu_label_key_value(context: Optional[str], acc_type: str, check_mode=False) -> Tuple[str, str]: """Returns the label key and value for the given GPU type. @@ -651,11 +652,14 @@ def get_external_ip(network_mode: Optional[ return parsed_url.hostname -def check_credentials(context: str, timeout: int = kubernetes.API_TIMEOUT) -> \ +def check_credentials(context: Optional[str], + timeout: int = kubernetes.API_TIMEOUT) -> \ Tuple[bool, Optional[str]]: """Check if the credentials in kubeconfig file are valid Args: + context (Optional[str]): The Kubernetes context to use. If none, uses + in-cluster auth to check credentials, if available. timeout (int): Timeout in seconds for the test API call Returns: @@ -817,22 +821,42 @@ def get_current_kube_config_context_name() -> Optional[str]: return None -def get_all_kube_config_context_names() -> Optional[List[str]]: +def is_incluster_config_available() -> bool: + """Check if in-cluster auth is available. + + Note: We cannot use load_incluster_config() to check if in-cluster config + is available because it will load the in-cluster config (if available) + and modify the current global kubernetes config. We simply check if the + service account token file exists to determine if in-cluster config may + be available. + """ + return os.path.exists('/var/run/secrets/kubernetes.io/serviceaccount/token') + + +def get_all_kube_config_context_names() -> List[Optional[str]]: """Get all kubernetes context names from the kubeconfig file. + If running in-cluster, returns [None] to indicate in-cluster config. + We should not cache the result of this function as the admin policy may update the contexts. Returns: - List[str] | None: The list of kubernetes context names if it exists, - None otherwise + List[Optional[str]]: The list of kubernetes context names if + available, an empty list otherwise. If running in-cluster, + returns [None] to indicate in-cluster config. """ k8s = kubernetes.kubernetes try: all_contexts, _ = k8s.config.list_kube_config_contexts() + # all_contexts will always have at least one context. If kubeconfig + # does not have any contexts defined, it will raise ConfigException. return [context['name'] for context in all_contexts] except k8s.config.config_exception.ConfigException: - return None + # If running in cluster, return [None] to indicate in-cluster config + if is_incluster_config_available(): + return [None] + return [] @functools.lru_cache() @@ -1046,7 +1070,7 @@ def get_ssh_proxy_command( k8s_ssh_target: str, network_mode: kubernetes_enums.KubernetesNetworkingMode, private_key_path: str, - context: str, + context: Optional[str], namespace: str, ) -> str: """Generates the SSH proxy command to connect to the pod. @@ -1144,7 +1168,8 @@ def create_proxy_command_script() -> str: return port_fwd_proxy_cmd_path -def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str, context: str, +def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str, + context: Optional[str], service_type: kubernetes_enums.KubernetesServiceType): """Sets up Kubernetes service resource to access for SSH jump pod. @@ -1216,7 +1241,8 @@ def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str, context: str, def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str, - ssh_key_secret: str, namespace: str, context: str): + ssh_key_secret: str, namespace: str, + context: Optional[str]): """Sets up Kubernetes RBAC and pod for SSH jump host. Our Kubernetes implementation uses a SSH jump pod to reach SkyPilot clusters @@ -1296,7 +1322,8 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str, logger.info(f'Created SSH Jump Host {ssh_jump_name}.') -def clean_zombie_ssh_jump_pod(namespace: str, context: str, node_id: str): +def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str], + node_id: str): """Analyzes SSH jump pod and removes if it is in a bad state Prevents the existence of a dangling SSH jump pod. This could happen @@ -1618,7 +1645,8 @@ def check_nvidia_runtime_class(context: Optional[str] = None) -> bool: return nvidia_exists -def check_secret_exists(secret_name: str, namespace: str, context: str) -> bool: +def check_secret_exists(secret_name: str, namespace: str, + context: Optional[str]) -> bool: """Checks if a secret exists in a namespace Args: @@ -1836,7 +1864,7 @@ def get_namespace_from_config(provider_config: Dict[str, Any]) -> str: def filter_pods(namespace: str, - context: str, + context: Optional[str], tag_filters: Dict[str, str], status_filters: Optional[List[str]] = None) -> Dict[str, Any]: """Filters pods by tags and status.""" @@ -1962,6 +1990,11 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle', context=context) -def get_context_from_config(provider_config: Dict[str, Any]) -> str: - return provider_config.get('context', - get_current_kube_config_context_name()) +def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]: + context = provider_config.get('context', + get_current_kube_config_context_name()) + if context == IN_CLUSTER_REGION: + # If the context (also used as the region) is set to IN_CLUSTER_REGION + # we need to use in-cluster auth. + context = None + return context diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 1cb1dfc88e6..3d4bcb0af9a 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -653,7 +653,7 @@ class KubernetesCommandRunner(CommandRunner): def __init__( self, - node: Tuple[Tuple[str, str], str], + node: Tuple[Tuple[str, Optional[str]], str], **kwargs, ): """Initialize KubernetesCommandRunner. diff --git a/sky/utils/command_runner.pyi b/sky/utils/command_runner.pyi index e2bf2e5031c..51b22a259ea 100644 --- a/sky/utils/command_runner.pyi +++ b/sky/utils/command_runner.pyi @@ -204,7 +204,7 @@ class KubernetesCommandRunner(CommandRunner): def __init__( self, - node: Tuple[Tuple[str, str], str], + node: Tuple[Tuple[str, Optional[str]], str], ) -> None: ... diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 118f9a2b718..39045962a78 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -363,6 +363,14 @@ def shared_controller_vars_to_fill( # again on the controller. This is required since admin_policy is not # installed on the controller. local_user_config.pop('admin_policy', None) + # Remove allowed_contexts from local_user_config since the controller + # may be running in a Kubernetes cluster with in-cluster auth and may + # not have kubeconfig available to it. This is the typical case since + # remote_identity default for Kubernetes is SERVICE_ACCOUNT. + # TODO(romilb): We should check the cloud the controller is running on + # before popping allowed_contexts. If it is not on Kubernetes, + # we may be able to use allowed_contexts. + local_user_config.pop('allowed_contexts', None) with tempfile.NamedTemporaryFile( delete=False, suffix=_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX) as temp_file: