From 0b46be1236a9a08b063732b3c5674154f49cc66f Mon Sep 17 00:00:00 2001 From: Jacky Hu Date: Thu, 1 Feb 2024 13:43:16 -0800 Subject: [PATCH] [PECO-1414] Support Databricks InHouse OAuth in Azure Signed-off-by: Jacky Hu --- src/databricks/sql/auth/auth.py | 18 +-- src/databricks/sql/auth/authenticators.py | 10 +- src/databricks/sql/auth/endpoint.py | 29 ++++- src/databricks/sql/client.py | 9 +- tests/unit/test_auth.py | 148 +++++++++++++++------- tests/unit/test_endpoint.py | 127 ++++++++++++++----- 6 files changed, 243 insertions(+), 98 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 928898cd..fcb2219e 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -8,12 +8,11 @@ ExternalAuthProvider, DatabricksOAuthProvider, ) -from databricks.sql.auth.endpoint import infer_cloud_from_host, CloudType -from databricks.sql.experimental.oauth_persistence import OAuthPersistence class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" + AZURE_OAUTH = "azure-oauth" # other supported types (access_token, user/pass) can be inferred # we can add more types as needed later @@ -51,7 +50,7 @@ def __init__( def get_auth_provider(cfg: ClientContext): if cfg.credentials_provider: return ExternalAuthProvider(cfg.credentials_provider) - if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: + if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None assert cfg.oauth_scopes is not None @@ -62,6 +61,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes, + cfg.auth_type, ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) @@ -87,20 +87,22 @@ def normalize_host_name(hostname: str): return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" -def get_client_id_and_redirect_port(hostname: str): - cloud_type = infer_cloud_from_host(hostname) +def get_client_id_and_redirect_port(use_azure_auth: bool): return ( (PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE) - if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP + if not use_azure_auth else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE) ) def get_python_sql_connector_auth_provider(hostname: str, **kwargs): - (client_id, redirect_port_range) = get_client_id_and_redirect_port(hostname) + auth_type = kwargs.get("auth_type") + (client_id, redirect_port_range) = get_client_id_and_redirect_port( + auth_type == AuthType.AZURE_OAUTH.value + ) cfg = ClientContext( hostname=normalize_host_name(hostname), - auth_type=kwargs.get("auth_type"), + auth_type=auth_type, access_token=kwargs.get("access_token"), username=kwargs.get("_username"), password=kwargs.get("_password"), diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 1cd68f90..e89c2bd5 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -18,6 +18,7 @@ def add_headers(self, request_headers: Dict[str, str]): HeaderFactory = Callable[[], Dict[str, str]] + # In order to keep compatibility with SDK class CredentialsProvider(abc.ABC): """CredentialsProvider is the protocol (call-side interface) @@ -69,16 +70,13 @@ def __init__( redirect_port_range: List[int], client_id: str, scopes: List[str], + auth_type: str = "databricks-oauth", ): try: - cloud_type = infer_cloud_from_host(hostname) - if not cloud_type: 
-                raise NotImplementedError("Cannot infer the cloud type from hostname")
-
-            idp_endpoint = get_oauth_endpoints(cloud_type)
+            idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth")
             if not idp_endpoint:
                 raise NotImplementedError(
-                    f"OAuth is not supported for cloud ${cloud_type.value}"
+                    f"OAuth is not supported for host {hostname}"
                 )
 
             # Convert to the corresponding scopes in the corresponding IdP
diff --git a/src/databricks/sql/auth/endpoint.py b/src/databricks/sql/auth/endpoint.py
index bfcc15f7..5cb26ae3 100644
--- a/src/databricks/sql/auth/endpoint.py
+++ b/src/databricks/sql/auth/endpoint.py
@@ -1,9 +1,9 @@
 #
 # It implements all the cloud specific OAuth configuration/metadata
 #
-# Azure: It uses AAD
+# Azure: It uses Databricks internal IdP or Azure AD
 # AWS: It uses Databricks internal IdP
-# GCP: Not support yet
+# GCP: It uses Databricks internal IdP
 #
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -37,6 +37,9 @@ class CloudType(Enum):
 ]
 DATABRICKS_GCP_DOMAINS = [".gcp.databricks.com"]
 
+# Domain supported by Databricks InHouse OAuth
+DATABRICKS_OAUTH_AZURE_DOMAINS = [".azuredatabricks.net"]
+
 
 # Infer cloud type from Databricks SQL instance hostname
 def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
@@ -53,6 +56,14 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
     return None
 
 
+def is_supported_databricks_oauth_host(hostname: str) -> bool:
+    host = hostname.lower().replace("https://", "").split("/")[0]
+    domains = (
+        DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS
+    )
+    return any(e for e in domains if host.endswith(e))
+
+
 def get_databricks_oidc_url(hostname: str):
     maybe_scheme = "https://" if not hostname.startswith("https://") else ""
     maybe_trailing_slash = "/" if not hostname.endswith("/") else ""
@@ -112,10 +123,18 @@ def get_openid_config_url(self, hostname: str):
         return f"{idp_url}/.well-known/oauth-authorization-server"
 
 
-def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
-    if cloud == CloudType.AWS or cloud == CloudType.GCP:
+def get_oauth_endpoints(
+    hostname: str, use_azure_auth: bool
+) -> Optional[OAuthEndpointCollection]:
+    cloud = infer_cloud_from_host(hostname)
+
+    if cloud in [CloudType.AWS, CloudType.GCP]:
         return InHouseOAuthEndpointCollection()
     elif cloud == CloudType.AZURE:
-        return AzureOAuthEndpointCollection()
+        return (
+            InHouseOAuthEndpointCollection()
+            if is_supported_databricks_oauth_host(hostname) and not use_azure_auth
+            else AzureOAuthEndpointCollection()
+        )
     else:
         return None
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 45f116f0..b4484fb7 100644
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -96,7 +96,8 @@ def __init__(
             legacy purposes and will be deprecated in a future release. When this parameter is `True` you will see
             a warning log message. To suppress this log message, set `use_inline_params="silent"`.
         auth_type: `str`, optional
-            `databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`.
+            `databricks-oauth` : to use Databricks OAuth with fine-grained permission scopes, set to `databricks-oauth`.
+            `azure-oauth` : to use Microsoft Entra ID OAuth flow, set to `azure-oauth`.
         oauth_client_id: `str`, optional
             custom oauth client_id. If not specified, it will use the built-in client_id of databricks-sql-python.
@@ -107,9 +108,9 @@ def __init__(
         experimental_oauth_persistence: configures preferred storage for persisting oauth tokens.
             This has to be a class implementing `OAuthPersistence`.
-            When `auth_type` is set to `databricks-oauth` without persisting the oauth token in a persistence storage
-            the oauth tokens will only be maintained in memory and if the python process restarts the end user
-            will have to login again.
+            When `auth_type` is set to `databricks-oauth` or `azure-oauth` without persisting the oauth token in a
+            persistence storage the oauth tokens will only be maintained in memory and if the python process
+            restarts the end user will have to login again.
             Note this is beta (private preview)
 
             For persisting the oauth token in a prod environment you should subclass and implement OAuthPersistence
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 1ed45445..5b81f2b7 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -3,47 +3,59 @@
 from typing import Optional
 from unittest.mock import patch
 
-from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
+from databricks.sql.auth.auth import (
+    AccessTokenAuthProvider,
+    BasicAuthProvider,
+    AuthProvider,
+    ExternalAuthProvider,
+    AuthType,
+)
 from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
 from databricks.sql.auth.oauth import OAuthManager
 from databricks.sql.auth.authenticators import DatabricksOAuthProvider
-from databricks.sql.auth.endpoint import CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection
+from databricks.sql.auth.endpoint import (
+    CloudType,
+    InHouseOAuthEndpointCollection,
+    AzureOAuthEndpointCollection,
+)
 from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
 from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
 
 
 class Auth(unittest.TestCase):
-
     def test_access_token_provider(self):
         access_token = "aBc2"
         auth = AccessTokenAuthProvider(access_token=access_token)
 
-        http_request = {'myKey': 'myVal'}
+        http_request = {"myKey": "myVal"}
         auth.add_headers(http_request)
-        self.assertEqual(http_request['Authorization'], 'Bearer aBc2')
+        self.assertEqual(http_request["Authorization"], "Bearer aBc2")
         self.assertEqual(len(http_request.keys()), 2)
-        self.assertEqual(http_request['myKey'], 'myVal')
+        self.assertEqual(http_request["myKey"], "myVal")
 
     def test_basic_auth_provider(self):
         username = "moderakh"
         password = "Elevate Databricks 123!!!"
auth = BasicAuthProvider(username=username, password=password) - http_request = {'myKey': 'myVal'} + http_request = {"myKey": "myVal"} auth.add_headers(http_request) - self.assertEqual(http_request['Authorization'], 'Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==') + self.assertEqual( + http_request["Authorization"], + "Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==", + ) self.assertEqual(len(http_request.keys()), 2) - self.assertEqual(http_request['myKey'], 'myVal') + self.assertEqual(http_request["myKey"], "myVal") def test_noop_auth_provider(self): auth = AuthProvider() - http_request = {'myKey': 'myVal'} + http_request = {"myKey": "myVal"} auth.add_headers(http_request) self.assertEqual(len(http_request.keys()), 1) - self.assertEqual(http_request['myKey'], 'myVal') + self.assertEqual(http_request["myKey"], "myVal") @patch.object(OAuthManager, "check_and_refresh_access_token") @patch.object(OAuthManager, "get_tokens") @@ -55,90 +67,136 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh): mock_get_tokens.return_value = (access_token, refresh_token) mock_check_and_refresh.return_value = (access_token, refresh_token, False) - params = [(CloudType.AWS, "foo.cloud.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql"), - (CloudType.AZURE, "foo.1.azuredatabricks.net", AzureOAuthEndpointCollection, - f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access"), - (CloudType.GCP, "foo.gcp.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql")] - - for cloud_type, host, expected_endpoint_type, expected_scopes in params: + params = [ + ( + CloudType.AWS, + "foo.cloud.databricks.com", + False, + InHouseOAuthEndpointCollection, + "offline_access sql", + ), + ( + CloudType.AZURE, + "foo.1.azuredatabricks.net", + True, + AzureOAuthEndpointCollection, + f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access", + ), + ( + CloudType.AZURE, + "foo.1.azuredatabricks.net", + False, + InHouseOAuthEndpointCollection, + "offline_access sql", + ), + ( + CloudType.GCP, + "foo.gcp.databricks.com", + False, + InHouseOAuthEndpointCollection, + "offline_access sql", + ), + ] + + for ( + cloud_type, + host, + use_azure_auth, + expected_endpoint_type, + expected_scopes, + ) in params: with self.subTest(cloud_type.value): oauth_persistence = OAuthPersistenceCache() - auth_provider = DatabricksOAuthProvider(hostname=host, - oauth_persistence=oauth_persistence, - redirect_port_range=[8020], - client_id=client_id, - scopes=scopes) - - self.assertIsInstance(auth_provider.oauth_manager.idp_endpoint, expected_endpoint_type) + auth_provider = DatabricksOAuthProvider( + hostname=host, + oauth_persistence=oauth_persistence, + redirect_port_range=[8020], + client_id=client_id, + scopes=scopes, + auth_type=AuthType.AZURE_OAUTH.value + if use_azure_auth + else AuthType.DATABRICKS_OAUTH.value, + ) + + self.assertIsInstance( + auth_provider.oauth_manager.idp_endpoint, expected_endpoint_type + ) self.assertEqual(auth_provider.oauth_manager.port_range, [8020]) self.assertEqual(auth_provider.oauth_manager.client_id, client_id) - self.assertEqual(oauth_persistence.read(host).refresh_token, refresh_token) + self.assertEqual( + oauth_persistence.read(host).refresh_token, refresh_token + ) mock_get_tokens.assert_called_with(hostname=host, scope=expected_scopes) headers = {} auth_provider.add_headers(headers) - self.assertEqual(headers['Authorization'], f"Bearer {access_token}") + 
self.assertEqual(headers["Authorization"], f"Bearer {access_token}") def test_external_provider(self): class MyProvider(CredentialsProvider): - def auth_type(self) -> str: - return "mine" + def auth_type(self) -> str: + return "mine" - def __call__(self, *args, **kwargs) -> HeaderFactory: - return lambda: {"foo": "bar"} + def __call__(self, *args, **kwargs) -> HeaderFactory: + return lambda: {"foo": "bar"} auth = ExternalAuthProvider(MyProvider()) - http_request = {'myKey': 'myVal'} + http_request = {"myKey": "myVal"} auth.add_headers(http_request) - self.assertEqual(http_request['foo'], 'bar') + self.assertEqual(http_request["foo"], "bar") self.assertEqual(len(http_request.keys()), 2) - self.assertEqual(http_request['myKey'], 'myVal') + self.assertEqual(http_request["myKey"], "myVal") def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" - kwargs = {'access_token': 'dpi123'} + kwargs = {"access_token": "dpi123"} auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} auth_provider.add_headers(headers) - self.assertEqual(headers['Authorization'], 'Bearer dpi123') + self.assertEqual(headers["Authorization"], "Bearer dpi123") def test_get_python_sql_connector_auth_provider_external(self): - class MyProvider(CredentialsProvider): - def auth_type(self) -> str: - return "mine" + def auth_type(self) -> str: + return "mine" - def __call__(self, *args, **kwargs) -> HeaderFactory: - return lambda: {"foo": "bar"} + def __call__(self, *args, **kwargs) -> HeaderFactory: + return lambda: {"foo": "bar"} hostname = "moderakh-test.cloud.databricks.com" - kwargs = {'credentials_provider': MyProvider()} + kwargs = {"credentials_provider": MyProvider()} auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} auth_provider.add_headers(headers) - self.assertEqual(headers['foo'], 'bar') + self.assertEqual(headers["foo"], "bar") def test_get_python_sql_connector_auth_provider_username_password(self): username = "moderakh" password = "Elevate Databricks 123!!!" 
hostname = "moderakh-test.cloud.databricks.com" - kwargs = {'_username': username, '_password': password} + kwargs = {"_username": username, "_password": password} auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) self.assertTrue(type(auth_provider).__name__, "BasicAuthProvider") headers = {} auth_provider.add_headers(headers) - self.assertEqual(headers['Authorization'], 'Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==') + self.assertEqual( + headers["Authorization"], + "Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==", + ) def test_get_python_sql_connector_auth_provider_noop(self): tls_client_cert_file = "fake.cert" use_cert_as_auth = "abc" hostname = "moderakh-test.cloud.databricks.com" - kwargs = {'_tls_client_cert_file': tls_client_cert_file, '_use_cert_as_auth': use_cert_as_auth} + kwargs = { + "_tls_client_cert_file": tls_client_cert_file, + "_use_cert_as_auth": use_cert_as_auth, + } auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") diff --git a/tests/unit/test_endpoint.py b/tests/unit/test_endpoint.py index 63393039..1f7d7cdd 100644 --- a/tests/unit/test_endpoint.py +++ b/tests/unit/test_endpoint.py @@ -4,54 +4,121 @@ from unittest.mock import patch -from databricks.sql.auth.endpoint import infer_cloud_from_host, CloudType, get_oauth_endpoints, \ - AzureOAuthEndpointCollection +from databricks.sql.auth.auth import AuthType +from databricks.sql.auth.endpoint import ( + infer_cloud_from_host, + CloudType, + get_oauth_endpoints, + AzureOAuthEndpointCollection, +) aws_host = "foo-bar.cloud.databricks.com" azure_host = "foo-bar.1.azuredatabricks.net" +azure_cn_host = "foo-bar2.databricks.azure.cn" +gcp_host = "foo.1.gcp.databricks.com" class EndpointTest(unittest.TestCase): def test_infer_cloud_from_host(self): - param_list = [(CloudType.AWS, aws_host), (CloudType.AZURE, azure_host), (None, "foo.example.com")] + param_list = [ + (CloudType.AWS, aws_host), + (CloudType.AZURE, azure_host), + (None, "foo.example.com"), + ] for expected_type, host in param_list: with self.subTest(expected_type or "None", expected_type=expected_type): self.assertEqual(infer_cloud_from_host(host), expected_type) - self.assertEqual(infer_cloud_from_host(f"https://{host}/to/path"), expected_type) + self.assertEqual( + infer_cloud_from_host(f"https://{host}/to/path"), expected_type + ) def test_oauth_endpoint(self): scopes = ["offline_access", "sql", "admin"] scopes2 = ["sql", "admin"] - azure_scope = f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation" - - param_list = [(CloudType.AWS, - aws_host, - f"https://{aws_host}/oidc/oauth2/v2.0/authorize", - f"https://{aws_host}/oidc/.well-known/oauth-authorization-server", - scopes, - scopes2 - ), - ( - CloudType.AZURE, - azure_host, - f"https://{azure_host}/oidc/oauth2/v2.0/authorize", - "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration", - [azure_scope, "offline_access"], - [azure_scope] - )] - - for cloud_type, host, expected_auth_url, expected_config_url, expected_scopes, expected_scope2 in param_list: + azure_scope = ( + f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation" + ) + + param_list = [ + ( + CloudType.AWS, + aws_host, + False, + f"https://{aws_host}/oidc/oauth2/v2.0/authorize", + f"https://{aws_host}/oidc/.well-known/oauth-authorization-server", + scopes, + scopes2, + ), + ( + CloudType.AZURE, + azure_cn_host, + False, + 
f"https://{azure_cn_host}/oidc/oauth2/v2.0/authorize", + "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration", + [azure_scope, "offline_access"], + [azure_scope], + ), + ( + CloudType.AZURE, + azure_host, + True, + f"https://{azure_host}/oidc/oauth2/v2.0/authorize", + "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration", + [azure_scope, "offline_access"], + [azure_scope], + ), + ( + CloudType.AZURE, + azure_host, + False, + f"https://{azure_host}/oidc/oauth2/v2.0/authorize", + f"https://{azure_host}/oidc/.well-known/oauth-authorization-server", + scopes, + scopes2, + ), + ( + CloudType.GCP, + gcp_host, + False, + f"https://{gcp_host}/oidc/oauth2/v2.0/authorize", + f"https://{gcp_host}/oidc/.well-known/oauth-authorization-server", + scopes, + scopes2, + ), + ] + + for ( + cloud_type, + host, + use_azure_auth, + expected_auth_url, + expected_config_url, + expected_scopes, + expected_scope2, + ) in param_list: with self.subTest(cloud_type): - endpoint = get_oauth_endpoints(cloud_type) - self.assertEqual(endpoint.get_authorization_url(host), expected_auth_url) - self.assertEqual(endpoint.get_openid_config_url(host), expected_config_url) + endpoint = get_oauth_endpoints(host, use_azure_auth) + self.assertEqual( + endpoint.get_authorization_url(host), expected_auth_url + ) + self.assertEqual( + endpoint.get_openid_config_url(host), expected_config_url + ) self.assertEqual(endpoint.get_scopes_mapping(scopes), expected_scopes) self.assertEqual(endpoint.get_scopes_mapping(scopes2), expected_scope2) - @patch.dict(os.environ, {'DATABRICKS_AZURE_TENANT_ID': '052ee82f-b79d-443c-8682-3ec1749e56b0'}) + @patch.dict( + os.environ, + {"DATABRICKS_AZURE_TENANT_ID": "052ee82f-b79d-443c-8682-3ec1749e56b0"}, + ) def test_azure_oauth_scope_mappings_from_different_tenant_id(self): scopes = ["offline_access", "sql", "all"] - endpoint = get_oauth_endpoints(CloudType.AZURE) - self.assertEqual(endpoint.get_scopes_mapping(scopes), - ['052ee82f-b79d-443c-8682-3ec1749e56b0/user_impersonation', "offline_access"]) + endpoint = get_oauth_endpoints(azure_host, True) + self.assertEqual( + endpoint.get_scopes_mapping(scopes), + [ + "052ee82f-b79d-443c-8682-3ec1749e56b0/user_impersonation", + "offline_access", + ], + )