[PECO-1414] Support Databricks InHouse OAuth in Azure
Signed-off-by: Jacky Hu <jacky.hu@databricks.com>
jackyhu-db committed Feb 11, 2024
1 parent 1b469c0 commit 0b46be1
Showing 6 changed files with 243 additions and 98 deletions.
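
With this change, Azure workspaces can authenticate through Databricks' in-house OAuth as well as the Microsoft Entra ID flow. A minimal connection sketch; the hostname and warehouse path are hypothetical placeholders:

```python
from databricks import sql

# On an Azure workspace host (*.azuredatabricks.net), "databricks-oauth"
# now selects Databricks' in-house IdP; "azure-oauth" requests the
# Microsoft Entra ID flow instead.
with sql.connect(
    server_hostname="adb-1234567890123456.7.azuredatabricks.net",  # hypothetical
    http_path="/sql/1.0/warehouses/abc123",                        # hypothetical
    auth_type="databricks-oauth",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())
```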
18 changes: 10 additions & 8 deletions src/databricks/sql/auth/auth.py
@@ -8,12 +8,11 @@
     ExternalAuthProvider,
     DatabricksOAuthProvider,
 )
-from databricks.sql.auth.endpoint import infer_cloud_from_host, CloudType
 from databricks.sql.experimental.oauth_persistence import OAuthPersistence
 
 
 class AuthType(Enum):
     DATABRICKS_OAUTH = "databricks-oauth"
+    AZURE_OAUTH = "azure-oauth"
     # other supported types (access_token, user/pass) can be inferred
     # we can add more types as needed later

@@ -51,7 +50,7 @@ def __init__(
 def get_auth_provider(cfg: ClientContext):
     if cfg.credentials_provider:
         return ExternalAuthProvider(cfg.credentials_provider)
-    if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
+    if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
         assert cfg.oauth_redirect_port_range is not None
         assert cfg.oauth_client_id is not None
         assert cfg.oauth_scopes is not None
@@ -62,6 +61,7 @@ def get_auth_provider(cfg: ClientContext):
             cfg.oauth_redirect_port_range,
             cfg.oauth_client_id,
             cfg.oauth_scopes,
+            cfg.auth_type,
         )
     elif cfg.access_token is not None:
         return AccessTokenAuthProvider(cfg.access_token)
@@ -87,20 +87,22 @@ def normalize_host_name(hostname: str):
return f"{maybe_scheme}{hostname}{maybe_trailing_slash}"


def get_client_id_and_redirect_port(hostname: str):
cloud_type = infer_cloud_from_host(hostname)
def get_client_id_and_redirect_port(use_azure_auth: bool):
return (
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP
if not use_azure_auth
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
)


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
(client_id, redirect_port_range) = get_client_id_and_redirect_port(hostname)
auth_type = kwargs.get("auth_type")
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
)
cfg = ClientContext(
hostname=normalize_host_name(hostname),
auth_type=kwargs.get("auth_type"),
auth_type=auth_type,
access_token=kwargs.get("access_token"),
username=kwargs.get("_username"),
password=kwargs.get("_password"),
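
The flag passed to `get_client_id_and_redirect_port` is now derived from the requested `auth_type` rather than from the hostname's cloud. A quick sketch of the dispatch:

```python
from databricks.sql.auth.auth import AuthType, get_client_id_and_redirect_port

auth_type = "azure-oauth"  # as passed through sql.connect(..., auth_type=...)

# True only for the Entra ID flow; "databricks-oauth" on an Azure host
# now keeps the in-house client id and redirect port range.
client_id, redirect_port_range = get_client_id_and_redirect_port(
    auth_type == AuthType.AZURE_OAUTH.value
)
```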
10 changes: 4 additions & 6 deletions src/databricks/sql/auth/authenticators.py
@@ -18,6 +18,7 @@ def add_headers(self, request_headers: Dict[str, str]):

 HeaderFactory = Callable[[], Dict[str, str]]
 
+
 # In order to keep compatibility with SDK
 class CredentialsProvider(abc.ABC):
     """CredentialsProvider is the protocol (call-side interface)
@@ -69,16 +70,13 @@ def __init__(
         redirect_port_range: List[int],
         client_id: str,
         scopes: List[str],
+        auth_type: str = "databricks-oauth",
     ):
         try:
-            cloud_type = infer_cloud_from_host(hostname)
-            if not cloud_type:
-                raise NotImplementedError("Cannot infer the cloud type from hostname")
-
-            idp_endpoint = get_oauth_endpoints(cloud_type)
+            idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth")
             if not idp_endpoint:
                 raise NotImplementedError(
-                    f"OAuth is not supported for cloud ${cloud_type.value}"
+                    f"OAuth is not supported for host ${hostname}"
                 )
 
             # Convert to the corresponding scopes in the corresponding IdP
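
For illustration, a direct construction of the provider with the new `auth_type` parameter. This is a sketch: the `hostname` and `oauth_persistence` keyword names are assumed from the parameters above this hunk, every value shown is a placeholder, and instantiating the provider starts the OAuth flow, so real callers normally go through `sql.connect` instead:

```python
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence

# Placeholder arguments; the real defaults come from auth.py.
provider = DatabricksOAuthProvider(
    hostname="adb-1234567890123456.7.azuredatabricks.net",  # hypothetical
    oauth_persistence=DevOnlyFilePersistence("oauth_token.json"),
    redirect_port_range=[8020],
    client_id="databricks-sql-python",  # placeholder client id
    scopes=["sql", "offline_access"],
    auth_type="azure-oauth",  # routes endpoint selection to Azure AD
)
```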
29 changes: 24 additions & 5 deletions src/databricks/sql/auth/endpoint.py
@@ -1,9 +1,9 @@
 #
 # It implements all the cloud specific OAuth configuration/metadata
 #
-# Azure: It uses AAD
+# Azure: It uses Databricks internal IdP or Azure AD
 # AWS: It uses Databricks internal IdP
-# GCP: Not support yet
+# GCP: It uses Databricks internal IdP
 #
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -37,6 +37,9 @@ class CloudType(Enum):
 ]
 DATABRICKS_GCP_DOMAINS = [".gcp.databricks.com"]
 
+# Domain supported by Databricks InHouse OAuth
+DATABRICKS_OAUTH_AZURE_DOMAINS = [".azuredatabricks.net"]
+
 
 # Infer cloud type from Databricks SQL instance hostname
 def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
@@ -53,6 +56,14 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
     return None
 
 
+def is_supported_databricks_oauth_host(hostname: str) -> bool:
+    host = hostname.lower().replace("https://", "").split("/")[0]
+    domains = (
+        DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS
+    )
+    return any(e for e in domains if host.endswith(e))
+
+
 def get_databricks_oidc_url(hostname: str):
     maybe_scheme = "https://" if not hostname.startswith("https://") else ""
     maybe_trailing_slash = "/" if not hostname.endswith("/") else ""
@@ -112,10 +123,18 @@ def get_openid_config_url(self, hostname: str):
return f"{idp_url}/.well-known/oauth-authorization-server"


def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
if cloud == CloudType.AWS or cloud == CloudType.GCP:
def get_oauth_endpoints(
hostname: str, use_azure_auth: bool
) -> Optional[OAuthEndpointCollection]:
cloud = infer_cloud_from_host(hostname)

if cloud in [CloudType.AWS, CloudType.GCP]:
return InHouseOAuthEndpointCollection()
elif cloud == CloudType.AZURE:
return AzureOAuthEndpointCollection()
return (
InHouseOAuthEndpointCollection()
if is_supported_databricks_oauth_host(hostname) and not use_azure_auth
else AzureOAuthEndpointCollection()
)
else:
return None
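
The resolution logic above can be exercised directly. A small sketch of the three interesting cases; the hostname is a placeholder:

```python
from databricks.sql.auth.endpoint import get_oauth_endpoints

host = "adb-1234567890123456.7.azuredatabricks.net"  # hypothetical

# Azure host + in-house OAuth requested -> Databricks internal IdP
print(type(get_oauth_endpoints(host, use_azure_auth=False)).__name__)
# InHouseOAuthEndpointCollection

# Azure host + Entra ID flow requested -> Azure AD endpoints
print(type(get_oauth_endpoints(host, use_azure_auth=True)).__name__)
# AzureOAuthEndpointCollection

# Unrecognized host -> None (OAuth unsupported)
print(get_oauth_endpoints("example.com", use_azure_auth=False))
# None
```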
9 changes: 5 additions & 4 deletions src/databricks/sql/client.py
@@ -96,7 +96,8 @@ def __init__(
             legacy purposes and will be deprecated in a future release. When this parameter is `True` you will see
             a warning log message. To suppress this log message, set `use_inline_params="silent"`.
         auth_type: `str`, optional
-            `databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`.
+            `databricks-oauth` : to use Databricks OAuth with fine-grained permission scopes, set to `databricks-oauth`.
+            `azure-oauth` : to use Microsoft Entra ID OAuth flow, set to `azure-oauth`.
         oauth_client_id: `str`, optional
             custom oauth client_id. If not specified, it will use the built-in client_id of databricks-sql-python.
@@ -107,9 +108,9 @@ def __init__(
         experimental_oauth_persistence: configures preferred storage for persisting oauth tokens.
             This has to be a class implementing `OAuthPersistence`.
-            When `auth_type` is set to `databricks-oauth` without persisting the oauth token in a persistence storage
-            the oauth tokens will only be maintained in memory and if the python process restarts the end user
-            will have to login again.
+            When `auth_type` is set to `databricks-oauth` or `azure-oauth` without persisting the oauth token in a
+            persistence storage the oauth tokens will only be maintained in memory and if the python process
+            restarts the end user will have to login again.
             Note this is beta (private preview)
             For persisting the oauth token in a prod environment you should subclass and implement OAuthPersistence
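
As the updated docstring notes, both OAuth flows hold tokens only in memory unless a persistence implementation is supplied. A sketch using the library's dev-only file persistence; the hostname, warehouse path, and file name are placeholders:

```python
from databricks import sql
from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence

connection = sql.connect(
    server_hostname="adb-1234567890123456.7.azuredatabricks.net",  # hypothetical
    http_path="/sql/1.0/warehouses/abc123",                        # hypothetical
    auth_type="azure-oauth",
    # Dev-only: stores the refresh token in a local JSON file so a
    # restarted process does not force a new browser login.
    experimental_oauth_persistence=DevOnlyFilePersistence("oauth_token.json"),
)
```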
