From 78aa5bae44bc4ee3dfa06a785e28face79b5c4e5 Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Thu, 24 Oct 2024 15:16:42 -0700 Subject: [PATCH 1/3] removing deprecated types --- dbt/adapters/databricks/api_client.py | 27 ++--- dbt/adapters/databricks/auth.py | 5 +- dbt/adapters/databricks/behaviors/columns.py | 13 +- dbt/adapters/databricks/column.py | 3 +- dbt/adapters/databricks/connections.py | 49 ++++---- dbt/adapters/databricks/credentials.py | 27 ++--- .../databricks/events/connection_events.py | 3 +- dbt/adapters/databricks/impl.py | 113 ++++++++---------- .../databricks/python_models/python_config.py | 17 +-- .../python_models/python_submissions.py | 53 ++++---- .../databricks/python_models/run_tracking.py | 5 +- dbt/adapters/databricks/relation.py | 13 +- .../databricks/relation_configs/base.py | 13 +- .../relation_configs/incremental.py | 3 +- .../relation_configs/materialized_view.py | 3 +- .../relation_configs/partitioning.py | 5 +- .../relation_configs/streaming_table.py | 3 +- .../databricks/relation_configs/tags.py | 8 +- .../relation_configs/tblproperties.py | 10 +- dbt/adapters/databricks/utils.py | 5 +- .../materialized_view_tests/test_basic.py | 3 +- .../test_persist_constraints.py | 4 +- .../adapter/streaming_tables/test_st_basic.py | 3 +- .../tblproperties/test_set_tblproperties.py | 4 +- tests/profiles.py | 5 +- tests/unit/fixtures.py | 10 +- tests/unit/macros/base.py | 5 +- tests/unit/test_adapter.py | 3 +- 28 files changed, 182 insertions(+), 233 deletions(-) diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py index 893c09255..1b5b075ba 100644 --- a/dbt/adapters/databricks/api_client.py +++ b/dbt/adapters/databricks/api_client.py @@ -1,15 +1,12 @@ import base64 +from collections.abc import Callable import time from abc import ABC from abc import abstractmethod from dataclasses import dataclass import re from typing import Any -from typing import Callable -from typing import Dict -from typing import List from typing import Optional -from typing import Set from dbt.adapters.databricks import utils from dbt.adapters.databricks.__version__ import version @@ -34,17 +31,17 @@ def __init__(self, session: Session, host: str, api: str): self.session = session def get( - self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None + self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None ) -> Response: return self.session.get(f"{self.prefix}{suffix}", json=json, params=params) def post( - self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None + self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None ) -> Response: return self.session.post(f"{self.prefix}{suffix}", json=json, params=params) def put( - self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None + self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None ) -> Response: return self.session.put(f"{self.prefix}{suffix}", json=json, params=params) @@ -230,7 +227,7 @@ def _poll_api( url: str, params: dict, get_state_func: Callable[[Response], str], - terminal_states: Set[str], + terminal_states: set[str], expected_end_state: str, unexpected_end_state_func: Callable[[Response], None], ) -> Response: @@ -261,7 +258,7 @@ class CommandExecution(object): context_id: str cluster_id: str - def model_dump(self) -> Dict[str, Any]: + def model_dump(self) -> dict[str, Any]: return { 
"commandId": self.command_id, "contextId": self.context_id, @@ -328,7 +325,7 @@ def __init__(self, session: Session, host: str, polling_interval: int, timeout: super().__init__(session, host, "/api/2.1/jobs/runs", polling_interval, timeout) def submit( - self, run_name: str, job_spec: Dict[str, Any], **additional_job_settings: Dict[str, Any] + self, run_name: str, job_spec: dict[str, Any], **additional_job_settings: dict[str, Any] ) -> str: submit_response = self.session.post( "/submit", json={"run_name": run_name, "tasks": [job_spec], **additional_job_settings} @@ -388,7 +385,7 @@ class JobPermissionsApi(DatabricksApi): def __init__(self, session: Session, host: str): super().__init__(session, host, "/api/2.0/permissions/jobs") - def put(self, job_id: str, access_control_list: List[Dict[str, Any]]) -> None: + def put(self, job_id: str, access_control_list: list[dict[str, Any]]) -> None: request_body = {"access_control_list": access_control_list} response = self.session.put(f"/{job_id}", json=request_body) @@ -397,7 +394,7 @@ def put(self, job_id: str, access_control_list: List[Dict[str, Any]]) -> None: if response.status_code != 200: raise DbtRuntimeError(f"Error updating Databricks workflow.\n {response.content!r}") - def get(self, job_id: str) -> Dict[str, Any]: + def get(self, job_id: str) -> dict[str, Any]: response = self.session.get(f"/{job_id}") if response.status_code != 200: @@ -413,7 +410,7 @@ class WorkflowJobApi(DatabricksApi): def __init__(self, session: Session, host: str): super().__init__(session, host, "/api/2.1/jobs") - def search_by_name(self, job_name: str) -> List[Dict[str, Any]]: + def search_by_name(self, job_name: str) -> list[dict[str, Any]]: response = self.session.get("/list", json={"name": job_name}) if response.status_code != 200: @@ -421,7 +418,7 @@ def search_by_name(self, job_name: str) -> List[Dict[str, Any]]: return response.json().get("jobs", []) - def create(self, job_spec: Dict[str, Any]) -> str: + def create(self, job_spec: dict[str, Any]) -> str: """ :return: the job_id """ @@ -434,7 +431,7 @@ def create(self, job_spec: Dict[str, Any]) -> str: logger.info(f"New workflow created with job id {job_id}") return job_id - def update_job_settings(self, job_id: str, job_spec: Dict[str, Any]) -> None: + def update_job_settings(self, job_id: str, job_spec: dict[str, Any]) -> None: request_body = { "job_id": job_id, "new_settings": job_spec, diff --git a/dbt/adapters/databricks/auth.py b/dbt/adapters/databricks/auth.py index 51d894e05..8662f794d 100644 --- a/dbt/adapters/databricks/auth.py +++ b/dbt/adapters/databricks/auth.py @@ -1,5 +1,4 @@ from typing import Any -from typing import Dict from typing import Optional from databricks.sdk.core import Config @@ -34,7 +33,7 @@ def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]: def __call__(self, _: Optional[Config] = None) -> HeaderFactory: static_credentials = {"Authorization": f"Bearer {self._token}"} - def inner() -> Dict[str, str]: + def inner() -> dict[str, str]: return static_credentials return inner @@ -81,7 +80,7 @@ def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> Crede return c def __call__(self, _: Optional[Config] = None) -> HeaderFactory: - def inner() -> Dict[str, str]: + def inner() -> dict[str, str]: token = self._token_source.token() # type: ignore return {"Authorization": f"{token.token_type} {token.access_token}"} diff --git a/dbt/adapters/databricks/behaviors/columns.py b/dbt/adapters/databricks/behaviors/columns.py index 978823737..91f6c351a 100644 --- 
a/dbt/adapters/databricks/behaviors/columns.py +++ b/dbt/adapters/databricks/behaviors/columns.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import List from dbt.adapters.sql import SQLAdapter from dbt.adapters.databricks.column import DatabricksColumn from dbt.adapters.databricks.relation import DatabricksRelation @@ -14,13 +13,13 @@ class GetColumnsBehavior(ABC): @abstractmethod def get_columns_in_relation( cls, adapter: SQLAdapter, relation: DatabricksRelation - ) -> List[DatabricksColumn]: + ) -> list[DatabricksColumn]: pass @staticmethod def _get_columns_with_comments( adapter: SQLAdapter, relation: DatabricksRelation, macro_name: str - ) -> List[AttrDict]: + ) -> list[AttrDict]: return list( handle_missing_objects( lambda: adapter.execute_macro(macro_name, kwargs={"relation": relation}), @@ -33,12 +32,12 @@ class GetColumnsByDescribe(GetColumnsBehavior): @classmethod def get_columns_in_relation( cls, adapter: SQLAdapter, relation: DatabricksRelation - ) -> List[DatabricksColumn]: + ) -> list[DatabricksColumn]: rows = cls._get_columns_with_comments(adapter, relation, "get_columns_comments") return cls._parse_columns(rows) @classmethod - def _parse_columns(cls, rows: List[AttrDict]) -> List[DatabricksColumn]: + def _parse_columns(cls, rows: list[AttrDict]) -> list[DatabricksColumn]: columns = [] for row in rows: @@ -57,7 +56,7 @@ class GetColumnsByInformationSchema(GetColumnsByDescribe): @classmethod def get_columns_in_relation( cls, adapter: SQLAdapter, relation: DatabricksRelation - ) -> List[DatabricksColumn]: + ) -> list[DatabricksColumn]: if relation.is_hive_metastore() or relation.type == DatabricksRelation.View: return super().get_columns_in_relation(adapter, relation) @@ -67,5 +66,5 @@ def get_columns_in_relation( return cls._parse_columns(rows) @classmethod - def _parse_columns(cls, rows: List[AttrDict]) -> List[DatabricksColumn]: + def _parse_columns(cls, rows: list[AttrDict]) -> list[DatabricksColumn]: return [DatabricksColumn(column=row[0], dtype=row[1], comment=row[2]) for row in rows] diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py index 74083b4bc..f9511201c 100644 --- a/dbt/adapters/databricks/column.py +++ b/dbt/adapters/databricks/column.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from typing import ClassVar -from typing import Dict from typing import Optional from dbt.adapters.spark.column import SparkColumn @@ -11,7 +10,7 @@ class DatabricksColumn(SparkColumn): table_comment: Optional[str] = None comment: Optional[str] = None - TYPE_LABELS: ClassVar[Dict[str, str]] = { + TYPE_LABELS: ClassVar[dict[str, str]] = { "LONG": "BIGINT", } diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 4eb292eb8..83b55a999 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -1,3 +1,4 @@ +from collections.abc import Callable, Iterator, Sequence import decimal import os import re @@ -11,15 +12,9 @@ from numbers import Number from threading import get_ident from typing import Any -from typing import Callable from typing import cast -from typing import Dict from typing import Hashable -from typing import Iterator -from typing import List from typing import Optional -from typing import Sequence -from typing import Tuple from typing import TYPE_CHECKING import databricks.sql as dbsql @@ -107,7 +102,7 @@ class DatabricksSQLConnectionWrapper: _conn: DatabricksSQLConnection _is_cluster: bool - _cursors: List[DatabricksSQLCursor] + 
_cursors: list[DatabricksSQLCursor] _creds: DatabricksCredentials _user_agent: str @@ -140,7 +135,7 @@ def cursor(self) -> "DatabricksSQLCursorWrapper": def cancel(self) -> None: logger.debug(ConnectionCancel(self._conn)) - cursors: List[DatabricksSQLCursor] = self._cursors + cursors: list[DatabricksSQLCursor] = self._cursors for cursor in cursors: try: @@ -159,10 +154,10 @@ def close(self) -> None: def rollback(self, *args: Any, **kwargs: Any) -> None: logger.debug("NotImplemented: rollback") - _dbr_version: Tuple[int, int] + _dbr_version: tuple[int, int] @property - def dbr_version(self) -> Tuple[int, int]: + def dbr_version(self) -> tuple[int, int]: if not hasattr(self, "_dbr_version"): if self._is_cluster: with self._conn.cursor() as cursor: @@ -214,13 +209,13 @@ def close(self) -> None: except Error as exc: logger.warning(CursorCloseError(self._cursor, exc)) - def fetchall(self) -> Sequence[Tuple]: + def fetchall(self) -> Sequence[tuple]: return self._cursor.fetchall() - def fetchone(self) -> Optional[Tuple]: + def fetchone(self) -> Optional[tuple]: return self._cursor.fetchone() - def fetchmany(self, size: int) -> Sequence[Tuple]: + def fetchmany(self, size: int) -> Sequence[tuple]: return self._cursor.fetchmany(size) def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None: @@ -315,7 +310,7 @@ def poll_refresh_pipeline(self, pipeline_id: str) -> None: return @classmethod - def findUpdate(cls, updates: List, id: str) -> Optional[Dict]: + def findUpdate(cls, updates: list, id: str) -> Optional[dict]: matches = [x for x in updates if x.get("update_id") == id] if matches: return matches[0] @@ -343,7 +338,7 @@ def _fix_binding(cls, value: Any) -> Any: return value @property - def description(self) -> Optional[List[Tuple]]: + def description(self) -> Optional[list[tuple]]: return self._cursor.description def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None: @@ -406,7 +401,7 @@ class DatabricksDBTConnection(Connection): acquire_release_count: int = 0 compute_name: str = "" http_path: str = "" - thread_identifier: Tuple[int, int] = (0, 0) + thread_identifier: tuple[int, int] = (0, 0) max_idle_time: int = DEFAULT_MAX_IDLE_TIME # If the connection is being used for a model we want to track the model language. 
@@ -479,7 +474,7 @@ class DatabricksConnectionManager(SparkConnectionManager): credentials_provider: Optional[TCredentialProvider] = None _user_agent = f"dbt-databricks/{__version__}" - def cancel_open(self) -> List[str]: + def cancel_open(self) -> list[str]: cancelled = super().cancel_open() creds = cast(DatabricksCredentials, self.profile.credentials) api_client = DatabricksApiClient.create(creds, 15 * 60) @@ -494,7 +489,7 @@ def compare_dbr_version(self, major: int, minor: int) -> int: dbr_version = connection.dbr_version return (dbr_version > version) - (dbr_version < version) - def set_query_header(self, query_header_context: Dict[str, Any]) -> None: + def set_query_header(self, query_header_context: dict[str, Any]) -> None: self.query_header = DatabricksMacroQueryStringSetter(self.profile, query_header_context) @contextmanager @@ -572,7 +567,7 @@ def add_query( abridge_sql_log: bool = False, *, close_cursor: bool = False, - ) -> Tuple[Connection, Any]: + ) -> tuple[Connection, Any]: connection = self.get_thread_connection() if auto_begin and connection.transaction_open is False: self.begin() @@ -622,7 +617,7 @@ def execute( auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None, - ) -> Tuple[DatabricksAdapterResponse, "Table"]: + ) -> tuple[DatabricksAdapterResponse, "Table"]: sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) try: @@ -735,7 +730,7 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - http_headers: List[Tuple[str, str]] = list( + http_headers: list[tuple[str, str]] = list( creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() ) @@ -805,8 +800,8 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> USE_LONG_SESSIONS ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled" super().__init__(profile, mp_context) - self.threads_compute_connections: Dict[ - Hashable, Dict[Hashable, DatabricksDBTConnection] + self.threads_compute_connections: dict[ + Hashable, dict[Hashable, DatabricksDBTConnection] ] = {} def set_connection_name( @@ -908,7 +903,7 @@ def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: def _get_compute_connections( self, - ) -> Dict[Hashable, DatabricksDBTConnection]: + ) -> dict[Hashable, DatabricksDBTConnection]: """Retrieve a map of compute name to connection for the current thread.""" thread_id = self.get_thread_identifier() @@ -973,7 +968,7 @@ def _create_compute_connection( conn.compute_name = compute_name creds = cast(DatabricksCredentials, self.profile.credentials) conn.http_path = _get_http_path(query_header_context, creds=creds) or "" - conn.thread_identifier = cast(Tuple[int, int], self.get_thread_identifier()) + conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds) conn.handle = LazyHandle(self.open) @@ -1027,7 +1022,7 @@ def open(cls, connection: Connection) -> Connection: connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - http_headers: List[Tuple[str, str]] = list( + http_headers: list[tuple[str, str]] = list( creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() ) @@ -1094,7 +1089,7 @@ def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict: return response.json() -def _find_update(pipeline: 
dict, id: str = "") -> Optional[Dict]: +def _find_update(pipeline: dict, id: str = "") -> Optional[dict]: updates = pipeline.get("latest_updates", []) if not updates: raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}") diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 27b3a4ca7..8f1a1d89a 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import itertools import json import os @@ -6,11 +7,7 @@ from dataclasses import dataclass from typing import Any from typing import cast -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Tuple from typing import Union import keyring @@ -55,21 +52,21 @@ class DatabricksCredentials(Credentials): client_id: Optional[str] = None client_secret: Optional[str] = None oauth_redirect_url: Optional[str] = None - oauth_scopes: Optional[List[str]] = None - session_properties: Optional[Dict[str, Any]] = None - connection_parameters: Optional[Dict[str, Any]] = None + oauth_scopes: Optional[list[str]] = None + session_properties: Optional[dict[str, Any]] = None + connection_parameters: Optional[dict[str, Any]] = None auth_type: Optional[str] = None # Named compute resources specified in the profile. Used for # creating a connection when a model specifies a compute resource. - compute: Optional[Dict[str, Any]] = None + compute: Optional[dict[str, Any]] = None connect_retries: int = 1 connect_timeout: Optional[int] = None retry_all: bool = False connect_max_idle: Optional[int] = None - _credentials_provider: Optional[Dict[str, Any]] = None + _credentials_provider: Optional[dict[str, Any]] = None _lock = threading.Lock() # to avoid concurrent auth _ALIASES = { @@ -78,7 +75,7 @@ class DatabricksCredentials(Credentials): } @classmethod - def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]: + def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) if "database" not in data: data["database"] = None @@ -169,12 +166,12 @@ def get_invocation_env(cls) -> Optional[str]: return invocation_env @classmethod - def get_all_http_headers(cls, user_http_session_headers: Dict[str, str]) -> Dict[str, str]: + def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: http_session_headers_str: Optional[str] = os.environ.get( DBT_DATABRICKS_HTTP_SESSION_HEADERS ) - http_session_headers_dict: Dict[str, str] = ( + http_session_headers_dict: dict[str, str] = ( { k: v if isinstance(v, str) else json.dumps(v) for k, v in json.loads(http_session_headers_str).items() @@ -204,17 +201,17 @@ def type(self) -> str: def unique_field(self) -> str: return cast(str, self.host) - def connection_info(self, *, with_aliases: bool = False) -> Iterable[Tuple[str, Any]]: + def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]: as_dict = self.to_dict(omit_none=False) connection_keys = set(self._connection_keys(with_aliases=with_aliases)) - aliases: List[str] = [] + aliases: list[str] = [] if with_aliases: aliases = [k for k, v in self._ALIASES.items() if v in connection_keys] for key in itertools.chain(self._connection_keys(with_aliases=with_aliases), aliases): if key in as_dict: yield key, as_dict[key] - def _connection_keys(self, *, with_aliases: bool = False) -> Tuple[str, ...]: + def _connection_keys(self, *, with_aliases: 
bool = False) -> tuple[str, ...]: # Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)` # is called from only: # diff --git a/dbt/adapters/databricks/events/connection_events.py b/dbt/adapters/databricks/events/connection_events.py index c49d12278..f2b099940 100644 --- a/dbt/adapters/databricks/events/connection_events.py +++ b/dbt/adapters/databricks/events/connection_events.py @@ -1,7 +1,6 @@ from abc import ABC from typing import Any from typing import Optional -from typing import Tuple from databricks.sql.client import Connection from dbt.adapters.databricks.events.base import SQLErrorEvent @@ -64,7 +63,7 @@ def __init__( description: str, model: Optional[Any], compute_name: Optional[str], - thread_identifier: Tuple[int, int], + thread_identifier: tuple[int, int], ): message = f"Acquired connection on thread {thread_identifier}, using " if not compute_name: diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index b275aa4fc..130d45214 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable, Iterator from multiprocessing.context import SpawnContext import os import re @@ -10,16 +11,8 @@ from typing import Any from typing import cast from typing import ClassVar -from typing import Dict -from typing import FrozenSet from typing import Generic -from typing import Iterable -from typing import Iterator -from typing import List from typing import Optional -from typing import Set -from typing import Tuple -from typing import Type from typing import TYPE_CHECKING from typing import Union from uuid import uuid4 @@ -117,16 +110,16 @@ class DatabricksConfig(AdapterConfig): table_format: TableFormat = TableFormat.DEFAULT location_root: Optional[str] = None include_full_name_in_path: bool = False - partition_by: Optional[Union[List[str], str]] = None - clustered_by: Optional[Union[List[str], str]] = None - liquid_clustered_by: Optional[Union[List[str], str]] = None + partition_by: Optional[Union[list[str], str]] = None + clustered_by: Optional[Union[list[str], str]] = None + liquid_clustered_by: Optional[Union[list[str], str]] = None buckets: Optional[int] = None - options: Optional[Dict[str, str]] = None + options: Optional[dict[str, str]] = None merge_update_columns: Optional[str] = None merge_exclude_columns: Optional[str] = None - databricks_tags: Optional[Dict[str, str]] = None - tblproperties: Optional[Dict[str, str]] = None - zorder: Optional[Union[List[str], str]] = None + databricks_tags: Optional[dict[str, str]] = None + tblproperties: Optional[dict[str, str]] = None + zorder: Optional[Union[list[str], str]] = None unique_tmp_table_suffix: bool = False skip_non_matched_step: Optional[bool] = None skip_matched_step: Optional[bool] = None @@ -139,7 +132,7 @@ class DatabricksConfig(AdapterConfig): merge_with_schema_evolution: Optional[bool] = None -def get_identifier_list_string(table_names: Set[str]) -> str: +def get_identifier_list_string(table_names: set[str]) -> str: """Returns `"|".join(table_names)` by default. 
Returns `"*"` if `DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS` == `"true"` @@ -163,7 +156,7 @@ class DatabricksAdapter(SparkAdapter): Column = DatabricksColumn if USE_LONG_SESSIONS: - ConnectionManager: Type[DatabricksConnectionManager] = ExtendedSessionConnectionManager + ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager else: ConnectionManager = DatabricksConnectionManager @@ -188,13 +181,13 @@ def __init__(self, config: Any, mp_context: SpawnContext) -> None: self.get_column_behavior = GetColumnsByDescribe() @property - def _behavior_flags(self) -> List[BehaviorFlag]: + def _behavior_flags(self) -> list[BehaviorFlag]: return [USE_INFO_SCHEMA_FOR_COLUMNS] @available.parse(lambda *a, **k: 0) def update_tblproperties_for_iceberg( - self, config: BaseConfig, tblproperties: Optional[Dict[str, str]] = None - ) -> Dict[str, str]: + self, config: BaseConfig, tblproperties: Optional[dict[str, str]] = None + ) -> dict[str, str]: result = tblproperties or config.get("tblproperties", {}) if config.get("table_format") == TableFormat.ICEBERG: if self.compare_dbr_version(14, 3) < 0: @@ -266,7 +259,7 @@ def compare_dbr_version(self, major: int, minor: int) -> int: """ return self.connections.compare_dbr_version(major, minor) - def list_schemas(self, database: Optional[str]) -> List[str]: + def list_schemas(self, database: Optional[str]) -> list[str]: """ Get a list of existing schemas in database. @@ -291,7 +284,7 @@ def execute( limit: Optional[int] = None, *, staging_table: Optional[BaseRelation] = None, - ) -> Tuple[AdapterResponse, "Table"]: + ) -> tuple[AdapterResponse, "Table"]: try: return super().execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit) finally: @@ -300,8 +293,8 @@ def execute( def list_relations_without_caching( # type: ignore[override] self, schema_relation: DatabricksRelation - ) -> List[DatabricksRelation]: - empty: List[Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]] = [] + ) -> list[DatabricksRelation]: + empty: list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]] = [] results = handle_missing_objects( lambda: self.get_relations_without_caching(schema_relation), empty ) @@ -326,14 +319,14 @@ def list_relations_without_caching( # type: ignore[override] def get_relations_without_caching( self, relation: DatabricksRelation - ) -> List[Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]: + ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]: if relation.is_hive_metastore(): return self._get_hive_relations(relation) return self._get_uc_relations(relation) def _get_uc_relations( self, relation: DatabricksRelation - ) -> List[Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]: + ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]: kwargs = {"relation": relation} results = self.execute_macro("get_uc_tables", kwargs=kwargs) return [ @@ -343,10 +336,10 @@ def _get_uc_relations( def _get_hive_relations( self, relation: DatabricksRelation - ) -> List[Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]: + ) -> list[tuple[Optional[str], Optional[str], Optional[str], Optional[str]]]: kwargs = {"relation": relation} - new_rows: List[Tuple[str, Optional[str]]] + new_rows: list[tuple[str, Optional[str]]] if all([relation.database, relation.schema]): tables = self.connections.list_tables( database=relation.database, schema=relation.schema # type: ignore[arg-type] @@ -377,11 +370,11 @@ def _get_hive_relations( return 
[(row[0], row[1], None, None) for row in new_rows] @available.parse(lambda *a, **k: []) - def get_column_schema_from_query(self, sql: str) -> List[DatabricksColumn]: + def get_column_schema_from_query(self, sql: str) -> list[DatabricksColumn]: """Get a list of the Columns with names and data types from the given sql.""" _, cursor = self.connections.add_select_query(sql) try: - columns: List[DatabricksColumn] = [ + columns: list[DatabricksColumn] = [ self.Column.create( column_name, self.connections.data_type_code_to_name(column_type_code) ) @@ -410,8 +403,8 @@ def get_relation( return self._set_relation_information(cached) if cached else None def parse_describe_extended( # type: ignore[override] - self, relation: DatabricksRelation, raw_rows: List["Row"] - ) -> Tuple[Dict[str, Any], List[DatabricksColumn]]: + self, relation: DatabricksRelation, raw_rows: list["Row"] + ) -> tuple[dict[str, Any], list[DatabricksColumn]]: # Convert the Row to a dict dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows] # Find the separator between the rows and the metadata provided @@ -442,12 +435,12 @@ def parse_describe_extended( # type: ignore[override] def get_columns_in_relation( # type: ignore[override] self, relation: DatabricksRelation - ) -> List[DatabricksColumn]: + ) -> list[DatabricksColumn]: return self.get_column_behavior.get_columns_in_relation(self, relation) def _get_updated_relation( self, relation: DatabricksRelation - ) -> Tuple[DatabricksRelation, List[DatabricksColumn]]: + ) -> tuple[DatabricksRelation, list[DatabricksColumn]]: rows = list( handle_missing_objects( lambda: self.execute_macro( @@ -482,7 +475,7 @@ def _set_relation_information(self, relation: DatabricksRelation) -> DatabricksR def parse_columns_from_information( # type: ignore[override] self, relation: DatabricksRelation, information: str - ) -> List[DatabricksColumn]: + ) -> list[DatabricksColumn]: owner_match = re.findall(self.INFORMATION_OWNER_REGEX, information) owner = owner_match[0] if owner_match else None matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, information) @@ -511,9 +504,9 @@ def parse_columns_from_information( # type: ignore[override] return columns def get_catalog_by_relations( - self, used_schemas: FrozenSet[Tuple[str, str]], relations: Set[BaseRelation] - ) -> Tuple["Table", List[Exception]]: - relation_map: Dict[Tuple[str, str], Set[str]] = defaultdict(set) + self, used_schemas: frozenset[tuple[str, str]], relations: set[BaseRelation] + ) -> tuple["Table", list[Exception]]: + relation_map: dict[tuple[str, str], set[str]] = defaultdict(set) for relation in relations: if relation.identifier: relation_map[ @@ -525,9 +518,9 @@ def get_catalog_by_relations( def get_catalog( self, relation_configs: Iterable[RelationConfig], - used_schemas: FrozenSet[Tuple[str, str]], - ) -> Tuple["Table", List[Exception]]: - relation_map: Dict[Tuple[str, str], Set[str]] = defaultdict(set) + used_schemas: frozenset[tuple[str, str]], + ) -> tuple["Table", list[Exception]]: + relation_map: dict[tuple[str, str], set[str]] = defaultdict(set) for relation in relation_configs: relation_map[(relation.database or "hive_metastore", relation.schema or "default")].add( relation.identifier @@ -537,11 +530,11 @@ def get_catalog( def _get_catalog_for_relation_map( self, - relation_map: Dict[Tuple[str, str], Set[str]], - used_schemas: FrozenSet[Tuple[str, str]], - ) -> Tuple["Table", List[Exception]]: + relation_map: dict[tuple[str, str], set[str]], + used_schemas: frozenset[tuple[str, str]], + ) -> tuple["Table", 
list[Exception]]: with executor(self.config) as tpe: - futures: List[Future["Table"]] = [] + futures: list[Future["Table"]] = [] for schema, relations in relation_map.items(): if schema in used_schemas: identifier = get_identifier_list_string(relations) @@ -561,10 +554,10 @@ def _get_catalog_for_relation_map( def _list_relations_with_information( self, schema_relation: DatabricksRelation - ) -> List[Tuple[DatabricksRelation, str]]: + ) -> list[tuple[DatabricksRelation, str]]: results = self._show_table_extended(schema_relation) - relations: List[Tuple[DatabricksRelation, str]] = [] + relations: list[tuple[DatabricksRelation, str]] = [] if results: for name, information in results.select(["tableName", "information"]): rel_type = RelationType.View if "Type: VIEW" in information else RelationType.Table @@ -592,7 +585,7 @@ def _get_schema_for_catalog(self, catalog: str, schema: str, identifier: str) -> from agate import Table from dbt_common.clients.agate_helper import DEFAULT_TYPE_TESTER - columns: List[Dict[str, Any]] = [] + columns: list[dict[str, Any]] = [] if identifier: schema_relation = self.Relation.create( @@ -607,7 +600,7 @@ def _get_schema_for_catalog(self, catalog: str, schema: str, identifier: str) -> def _get_columns_for_catalog( # type: ignore[override] self, relation: DatabricksRelation, information: str - ) -> Iterable[Dict[str, Any]]: + ) -> Iterable[dict[str, Any]]: columns = self.parse_columns_from_information(relation, information) for column in columns: @@ -625,14 +618,14 @@ def add_query( abridge_sql_log: bool = False, *, close_cursor: bool = False, - ) -> Tuple[Connection, Any]: + ) -> tuple[Connection, Any]: return self.connections.add_query( sql, auto_begin, bindings, abridge_sql_log, close_cursor=close_cursor ) def run_sql_for_tests( self, sql: str, fetch: str, conn: Connection - ) -> Optional[Union[Optional[Tuple], List[Tuple]]]: + ) -> Optional[Union[Optional[tuple], list[tuple]]]: cursor = conn.handle.cursor() try: cursor.execute(sql) @@ -650,11 +643,11 @@ def run_sql_for_tests( cursor.close() conn.transaction_open = False - def valid_incremental_strategies(self) -> List[str]: + def valid_incremental_strategies(self) -> list[str]: return ["append", "merge", "insert_overwrite", "replace_where", "microbatch"] @property - def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]: + def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]: return { "job_cluster": JobClusterPythonJobHelper, "all_purpose_cluster": AllPurposeClusterPythonJobHelper, @@ -690,8 +683,8 @@ def _catalog(self, catalog: Optional[str]) -> Iterator[None]: @available.parse(lambda *a, **k: {}) def get_persist_doc_columns( - self, existing_columns: List[DatabricksColumn], columns: Dict[str, Any] - ) -> Dict[str, Any]: + self, existing_columns: list[DatabricksColumn], columns: dict[str, Any] + ) -> dict[str, Any]: """Returns a dictionary of columns that have updated comments.""" return_columns = {} @@ -755,7 +748,7 @@ class RelationAPIBase(ABC, Generic[DatabricksRelationConfig]): @classmethod @abstractmethod - def config_type(cls) -> Type[DatabricksRelationConfig]: + def config_type(cls) -> type[DatabricksRelationConfig]: """Get the config class for delegating calls.""" raise NotImplementedError("Must be implemented by subclass") @@ -814,7 +807,7 @@ class MaterializedViewAPI(DeltaLiveTableAPIBase[MaterializedViewConfig]): relation_type = DatabricksRelationType.MaterializedView @classmethod - def config_type(cls) -> Type[MaterializedViewConfig]: + def config_type(cls) -> 
type[MaterializedViewConfig]: return MaterializedViewConfig @classmethod @@ -833,7 +826,7 @@ def _describe_relation( return results @staticmethod - def _get_information_schema_views(adapter: DatabricksAdapter, kwargs: Dict[str, Any]) -> "Row": + def _get_information_schema_views(adapter: DatabricksAdapter, kwargs: dict[str, Any]) -> "Row": return get_first_row(adapter.execute_macro("get_view_description", kwargs=kwargs)) @@ -841,7 +834,7 @@ class StreamingTableAPI(DeltaLiveTableAPIBase[StreamingTableConfig]): relation_type = DatabricksRelationType.StreamingTable @classmethod - def config_type(cls) -> Type[StreamingTableConfig]: + def config_type(cls) -> type[StreamingTableConfig]: return StreamingTableConfig @classmethod @@ -864,7 +857,7 @@ class IncrementalTableAPI(RelationAPIBase[IncrementalTableConfig]): relation_type = DatabricksRelationType.Table @classmethod - def config_type(cls) -> Type[IncrementalTableConfig]: + def config_type(cls) -> type[IncrementalTableConfig]: return IncrementalTableConfig @classmethod diff --git a/dbt/adapters/databricks/python_models/python_config.py b/dbt/adapters/databricks/python_models/python_config.py index 6398397d9..08d06d425 100644 --- a/dbt/adapters/databricks/python_models/python_config.py +++ b/dbt/adapters/databricks/python_models/python_config.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional +from typing import Any +from typing import Optional import uuid from pydantic import BaseModel, Field @@ -10,10 +11,10 @@ class PythonJobConfig(BaseModel): """Pydantic model for config found in python_job_config.""" name: Optional[str] = None - grants: Dict[str, List[Dict[str, str]]] = Field(exclude=True, default_factory=dict) + grants: dict[str, list[dict[str, str]]] = Field(exclude=True, default_factory=dict) existing_job_id: str = Field("", exclude=True) - post_hook_tasks: List[Dict[str, Any]] = Field(exclude=True, default_factory=list) - additional_task_settings: Dict[str, Any] = Field(exclude=True, default_factory=dict) + post_hook_tasks: list[dict[str, Any]] = Field(exclude=True, default_factory=list) + additional_task_settings: dict[str, Any] = Field(exclude=True, default_factory=dict) class Config: extra = "allow" @@ -27,11 +28,11 @@ class PythonModelConfig(BaseModel): user_folder_for_python: bool = False timeout: int = Field(DEFAULT_TIMEOUT, gt=0) - job_cluster_config: Dict[str, Any] = Field(default_factory=dict) - access_control_list: List[Dict[str, str]] = Field(default_factory=list) - packages: List[str] = Field(default_factory=list) + job_cluster_config: dict[str, Any] = Field(default_factory=dict) + access_control_list: list[dict[str, str]] = Field(default_factory=list) + packages: list[str] = Field(default_factory=list) index_url: Optional[str] = None - additional_libs: List[Dict[str, Any]] = Field(default_factory=list) + additional_libs: list[dict[str, Any]] = Field(default_factory=list) python_job_config: PythonJobConfig = Field(default_factory=lambda: PythonJobConfig(**{})) cluster_id: Optional[str] = None http_path: Optional[str] = None diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index 4b564f1c9..a28a15619 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -1,9 +1,6 @@ from abc import ABC, abstractmethod from typing import Any -from typing import Dict -from typing import List from typing import Optional -from typing import Tuple from attr 
import dataclass from typing_extensions import override @@ -33,7 +30,7 @@ class BaseDatabricksHelper(PythonJobHelper): tracker = PythonRunTracker() - def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: self.credentials = credentials self.credentials.validate_creds() self.parsed_model = ParsedPythonModel(**parsed_model) @@ -114,8 +111,8 @@ class PythonJobDetails: """Details required to submit a Python job run to Databricks.""" run_name: str - job_spec: Dict[str, Any] - additional_job_config: Dict[str, Any] + job_spec: dict[str, Any] + additional_job_config: dict[str, Any] class PythonPermissionBuilder: @@ -127,7 +124,7 @@ def __init__( ) -> None: self.api_client = api_client - def _get_job_owner_for_config(self) -> Tuple[str, str]: + def _get_job_owner_for_config(self) -> tuple[str, str]: """Get the owner of the job (and type) for the access control list.""" curr_user = self.api_client.curr_user.get_username() is_service_principal = self.api_client.curr_user.is_service_principal(curr_user) @@ -137,15 +134,15 @@ def _get_job_owner_for_config(self) -> Tuple[str, str]: @staticmethod def _build_job_permission( - job_grants: List[Dict[str, Any]], permission: str - ) -> List[Dict[str, Any]]: + job_grants: list[dict[str, Any]], permission: str + ) -> list[dict[str, Any]]: return [{**grant, **{"permission_level": permission}} for grant in job_grants] def build_job_permissions( self, - job_grants: Dict[str, List[Dict[str, Any]]], - acls: List[Dict[str, str]], - ) -> List[Dict[str, Any]]: + job_grants: dict[str, list[dict[str, Any]]], + acls: list[dict[str, str]], + ) -> list[dict[str, Any]]: """Build the access control list for the job.""" access_control_list = [] @@ -171,10 +168,10 @@ def build_job_permissions( def get_library_config( - packages: List[str], + packages: list[str], index_url: Optional[str], - additional_libraries: List[Dict[str, Any]], -) -> Dict[str, Any]: + additional_libraries: list[dict[str, Any]], +) -> dict[str, Any]: """Update the job configuration with the required libraries.""" libraries = [] @@ -199,7 +196,7 @@ def __init__( api_client: DatabricksApiClient, permission_builder: PythonPermissionBuilder, parsed_model: ParsedPythonModel, - cluster_spec: Dict[str, Any], + cluster_spec: dict[str, Any], ) -> None: self.api_client = api_client self.permission_builder = permission_builder @@ -215,7 +212,7 @@ def __init__( def compile(self, path: str) -> PythonJobDetails: - job_spec: Dict[str, Any] = { + job_spec: dict[str, Any] = { "task_key": "inner_notebook", "notebook_task": { "notebook_path": path, @@ -253,7 +250,7 @@ def create( api_client: DatabricksApiClient, tracker: PythonRunTracker, parsed_model: ParsedPythonModel, - cluster_spec: Dict[str, Any], + cluster_spec: dict[str, Any], ) -> "PythonNotebookSubmitter": notebook_uploader = PythonNotebookUploader(api_client, parsed_model) permission_builder = PythonPermissionBuilder(api_client) @@ -306,7 +303,7 @@ class AllPurposeClusterPythonJobHelper(BaseDatabricksHelper): Top level helper for Python models using job runs or Command API on an all-purpose cluster. 
""" - def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: self.credentials = credentials self.credentials.validate_creds() self.parsed_model = ParsedPythonModel(**parsed_model) @@ -359,10 +356,10 @@ class PythonWorkflowConfigCompiler: def __init__( self, - task_settings: Dict[str, Any], - workflow_spec: Dict[str, Any], + task_settings: dict[str, Any], + workflow_spec: dict[str, Any], existing_job_id: str, - post_hook_tasks: List[Dict[str, Any]], + post_hook_tasks: list[dict[str, Any]], ) -> None: self.task_settings = task_settings self.existing_job_id = existing_job_id @@ -395,11 +392,11 @@ def workflow_name(parsed_model: ParsedPythonModel) -> str: ) @staticmethod - def cluster_settings(parsed_model: ParsedPythonModel) -> Dict[str, Any]: + def cluster_settings(parsed_model: ParsedPythonModel) -> dict[str, Any]: config = parsed_model.config job_cluster_config = config.job_cluster_config - cluster_settings: Dict[str, Any] = {} + cluster_settings: dict[str, Any] = {} if job_cluster_config: cluster_settings["new_cluster"] = job_cluster_config elif config.cluster_id: @@ -407,7 +404,7 @@ def cluster_settings(parsed_model: ParsedPythonModel) -> Dict[str, Any]: return cluster_settings - def compile(self, path: str) -> Tuple[Dict[str, Any], str]: + def compile(self, path: str) -> tuple[dict[str, Any], str]: notebook_task = { "task_key": "inner_notebook", "notebook_task": { @@ -429,7 +426,7 @@ def __init__(self, workflows: WorkflowJobApi) -> None: def create_or_update( self, - workflow_spec: Dict[str, Any], + workflow_spec: dict[str, Any], existing_job_id: Optional[str], ) -> str: """ @@ -465,8 +462,8 @@ def __init__( config_compiler: PythonWorkflowConfigCompiler, permission_builder: PythonPermissionBuilder, workflow_creater: PythonWorkflowCreator, - job_grants: Dict[str, List[Dict[str, str]]], - acls: List[Dict[str, str]], + job_grants: dict[str, list[dict[str, str]]], + acls: list[dict[str, str]], ) -> None: self.api_client = api_client self.tracker = tracker diff --git a/dbt/adapters/databricks/python_models/run_tracking.py b/dbt/adapters/databricks/python_models/run_tracking.py index 01f8ea1ec..e8f95a522 100644 --- a/dbt/adapters/databricks/python_models/run_tracking.py +++ b/dbt/adapters/databricks/python_models/run_tracking.py @@ -1,5 +1,4 @@ import threading -from typing import Set from dbt.adapters.databricks.api_client import CommandExecution from dbt.adapters.databricks.api_client import DatabricksApiClient @@ -8,8 +7,8 @@ class PythonRunTracker(object): - _run_ids: Set[str] = set() - _commands: Set[CommandExecution] = set() + _run_ids: set[str] = set() + _commands: set[CommandExecution] = set() _lock = threading.Lock() @classmethod diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 12c47cdea..61961e002 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,11 +1,8 @@ +from collections.abc import Iterable from dataclasses import dataclass from dataclasses import field from typing import Any -from typing import Dict -from typing import Iterable from typing import Optional -from typing import Set -from typing import Type from dbt.adapters.base.relation import BaseRelation from dbt.adapters.base.relation import InformationSchema @@ -66,10 +63,10 @@ class DatabricksRelation(BaseRelation): include_policy: Policy = field(default_factory=lambda: DatabricksIncludePolicy()) quote_character: str = "`" 
- metadata: Optional[Dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None @classmethod - def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]: + def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) if "database" not in data["path"]: data["path"]["database"] = None @@ -138,7 +135,7 @@ def matches( return match @classproperty - def get_relation_type(cls) -> Type[DatabricksRelationType]: + def get_relation_type(cls) -> type[DatabricksRelationType]: # type: ignore return DatabricksRelationType def information_schema(self, view_name: Optional[str] = None) -> InformationSchema: @@ -160,5 +157,5 @@ def is_hive_metastore(database: Optional[str]) -> bool: return database is None or database.lower() == "hive_metastore" -def extract_identifiers(relations: Iterable[BaseRelation]) -> Set[str]: +def extract_identifiers(relations: Iterable[BaseRelation]) -> set[str]: return {r.identifier for r in relations if r.identifier is not None} diff --git a/dbt/adapters/databricks/relation_configs/base.py b/dbt/adapters/databricks/relation_configs/base.py index 6fa36ee9a..9366226c3 100644 --- a/dbt/adapters/databricks/relation_configs/base.py +++ b/dbt/adapters/databricks/relation_configs/base.py @@ -2,16 +2,13 @@ from abc import abstractmethod from typing import Any from typing import ClassVar -from typing import Dict from typing import Generic -from typing import List from typing import Optional from typing import TypeVar from pydantic import BaseModel from pydantic import ConfigDict from typing_extensions import Self -from typing_extensions import Type from dbt.adapters.contracts.relation import RelationConfig from dbt.adapters.relation_configs.config_base import RelationResults @@ -49,7 +46,7 @@ class DatabricksRelationChangeSet(BaseModel): """Class for encapsulating the changes that need to be applied to a Databricks relation.""" model_config = ConfigDict(frozen=True) - changes: Dict[str, DatabricksComponentConfig] + changes: dict[str, DatabricksComponentConfig] requires_full_refresh: bool = False @property @@ -101,14 +98,14 @@ class DatabricksRelationConfigBase(BaseModel, ABC): # The list of components that make up the relation config. In the base implemenation, these # components are applied sequentially to either the existing relation, or the model node, to # build up the config. 
- config_components: ClassVar[List[Type[DatabricksComponentProcessor]]] - config: Dict[str, DatabricksComponentConfig] + config_components: ClassVar[list[type[DatabricksComponentProcessor]]] + config: dict[str, DatabricksComponentConfig] @classmethod def from_relation_config(cls, relation_config: RelationConfig) -> Self: """Build the relation config from a model node.""" - config_dict: Dict[str, DatabricksComponentConfig] = {} + config_dict: dict[str, DatabricksComponentConfig] = {} for component in cls.config_components: relation_component = component.from_relation_config(relation_config) if relation_component: @@ -120,7 +117,7 @@ def from_relation_config(cls, relation_config: RelationConfig) -> Self: def from_results(cls, results: RelationResults) -> Self: """Build the relation config from the results of a query against the existing relation.""" - config_dict: Dict[str, DatabricksComponentConfig] = {} + config_dict: dict[str, DatabricksComponentConfig] = {} for component in cls.config_components: result_component = component.from_relation_results(results) if result_component: diff --git a/dbt/adapters/databricks/relation_configs/incremental.py b/dbt/adapters/databricks/relation_configs/incremental.py index 3af6c3e0d..f31baf581 100644 --- a/dbt/adapters/databricks/relation_configs/incremental.py +++ b/dbt/adapters/databricks/relation_configs/incremental.py @@ -1,4 +1,3 @@ -from typing import Dict from typing import Optional from dbt.adapters.databricks.relation_configs.base import DatabricksComponentConfig @@ -14,7 +13,7 @@ class IncrementalTableConfig(DatabricksRelationConfigBase): def get_changeset( self, existing: "IncrementalTableConfig" ) -> Optional[DatabricksRelationChangeSet]: - changes: Dict[str, DatabricksComponentConfig] = {} + changes: dict[str, DatabricksComponentConfig] = {} for component in self.config_components: key = component.name diff --git a/dbt/adapters/databricks/relation_configs/materialized_view.py b/dbt/adapters/databricks/relation_configs/materialized_view.py index b95b61890..9a57edd9f 100644 --- a/dbt/adapters/databricks/relation_configs/materialized_view.py +++ b/dbt/adapters/databricks/relation_configs/materialized_view.py @@ -1,4 +1,3 @@ -from typing import Dict from typing import Optional from dbt.adapters.databricks.relation_configs.base import DatabricksComponentConfig @@ -33,7 +32,7 @@ class MaterializedViewConfig(DatabricksRelationConfigBase): def get_changeset( self, existing: "MaterializedViewConfig" ) -> Optional[DatabricksRelationChangeSet]: - changes: Dict[str, DatabricksComponentConfig] = {} + changes: dict[str, DatabricksComponentConfig] = {} requires_refresh = False for component in self.config_components: diff --git a/dbt/adapters/databricks/relation_configs/partitioning.py b/dbt/adapters/databricks/relation_configs/partitioning.py index ea8e58b38..89120b8ac 100644 --- a/dbt/adapters/databricks/relation_configs/partitioning.py +++ b/dbt/adapters/databricks/relation_configs/partitioning.py @@ -1,6 +1,5 @@ import itertools from typing import ClassVar -from typing import List from typing import Union from dbt.adapters.contracts.relation import RelationConfig @@ -13,7 +12,7 @@ class PartitionedByConfig(DatabricksComponentConfig): """Component encapsulating the partitioning of relations.""" - partition_by: List[str] + partition_by: list[str] class PartitionedByProcessor(DatabricksComponentProcessor): @@ -35,7 +34,7 @@ def from_relation_results(cls, results: RelationResults) -> PartitionedByConfig: @classmethod def from_relation_config(cls, 
relation_config: RelationConfig) -> PartitionedByConfig: - partition_by: Union[str, List[str], None] = base.get_config_value( + partition_by: Union[str, list[str], None] = base.get_config_value( relation_config, "partition_by" ) if not partition_by: diff --git a/dbt/adapters/databricks/relation_configs/streaming_table.py b/dbt/adapters/databricks/relation_configs/streaming_table.py index 7c00d549c..d891da13d 100644 --- a/dbt/adapters/databricks/relation_configs/streaming_table.py +++ b/dbt/adapters/databricks/relation_configs/streaming_table.py @@ -1,4 +1,3 @@ -from typing import Dict from typing import Optional from dbt.adapters.databricks.relation_configs.base import DatabricksComponentConfig @@ -31,7 +30,7 @@ def get_changeset( """Get the changeset that must be applied to the existing relation to make it match the current state of the dbt project. """ - changes: Dict[str, DatabricksComponentConfig] = {} + changes: dict[str, DatabricksComponentConfig] = {} requires_refresh = False for component in self.config_components: diff --git a/dbt/adapters/databricks/relation_configs/tags.py b/dbt/adapters/databricks/relation_configs/tags.py index ddd356908..a739abe19 100644 --- a/dbt/adapters/databricks/relation_configs/tags.py +++ b/dbt/adapters/databricks/relation_configs/tags.py @@ -1,6 +1,4 @@ from typing import ClassVar -from typing import Dict -from typing import List from typing import Optional from dbt.adapters.contracts.relation import RelationConfig @@ -14,8 +12,8 @@ class TagsConfig(DatabricksComponentConfig): """Component encapsulating the tblproperties of a relation.""" - set_tags: Dict[str, str] - unset_tags: List[str] = [] + set_tags: dict[str, str] + unset_tags: list[str] = [] def get_diff(self, other: "TagsConfig") -> Optional["TagsConfig"]: to_unset = [] @@ -46,7 +44,7 @@ def from_relation_config(cls, relation_config: RelationConfig) -> TagsConfig: tags = base.get_config_value(relation_config, "databricks_tags") if not tags: return TagsConfig(set_tags=dict()) - if isinstance(tags, Dict): + if isinstance(tags, dict): tags = {str(k): str(v) for k, v in tags.items()} return TagsConfig(set_tags=tags) else: diff --git a/dbt/adapters/databricks/relation_configs/tblproperties.py b/dbt/adapters/databricks/relation_configs/tblproperties.py index 060cf2a47..dad00e537 100644 --- a/dbt/adapters/databricks/relation_configs/tblproperties.py +++ b/dbt/adapters/databricks/relation_configs/tblproperties.py @@ -1,7 +1,5 @@ from typing import Any from typing import ClassVar -from typing import Dict -from typing import List from typing import Optional from dbt.adapters.contracts.relation import RelationConfig @@ -15,12 +13,12 @@ class TblPropertiesConfig(DatabricksComponentConfig): """Component encapsulating the tblproperties of a relation.""" - tblproperties: Dict[str, str] + tblproperties: dict[str, str] pipeline_id: Optional[str] = None # List of tblproperties that should be ignored when comparing configs. These are generally # set by Databricks and are not user-configurable. 
- ignore_list: List[str] = [ + ignore_list: list[str] = [ "pipelines.pipelineId", "delta.enableChangeDataFeed", "delta.minReaderVersion", @@ -47,7 +45,7 @@ def __eq__(self, __value: Any) -> bool: if not isinstance(__value, TblPropertiesConfig): return False - def _without_ignore_list(d: Dict[str, str]) -> Dict[str, str]: + def _without_ignore_list(d: dict[str, str]) -> dict[str, str]: return {k: v for k, v in d.items() if k not in self.ignore_list} return _without_ignore_list(self.tblproperties) == _without_ignore_list( @@ -77,7 +75,7 @@ def from_relation_config(cls, relation_config: RelationConfig) -> TblPropertiesC tblproperties = base.get_config_value(relation_config, "tblproperties") or {} is_iceberg = base.get_config_value(relation_config, "table_format") == "iceberg" - if not isinstance(tblproperties, Dict): + if not isinstance(tblproperties, dict): raise DbtRuntimeError("tblproperties must be a dictionary") # If the table format is Iceberg, we need to set the iceberg-specific tblproperties diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 552458f03..2fbf73115 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -1,9 +1,8 @@ +from collections.abc import Callable import functools import inspect import re from typing import Any -from typing import Callable -from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -46,7 +45,7 @@ def remove_undefined(v: Any) -> Any: return None if isinstance(v, Undefined) else v -def undefined_proof(cls: Type[A]) -> Type[A]: +def undefined_proof(cls: type[A]) -> type[A]: for name in cls._available_: func = getattr(cls, name) if not callable(func): diff --git a/tests/functional/adapter/materialized_view_tests/test_basic.py b/tests/functional/adapter/materialized_view_tests/test_basic.py index 08c49bba8..1cb17d8e1 100644 --- a/tests/functional/adapter/materialized_view_tests/test_basic.py +++ b/tests/functional/adapter/materialized_view_tests/test_basic.py @@ -1,5 +1,4 @@ from typing import Optional -from typing import Tuple import pytest @@ -11,7 +10,7 @@ class TestMaterializedViewsMixin: @staticmethod - def insert_record(project, table: BaseRelation, record: Tuple[int, int]) -> None: + def insert_record(project, table: BaseRelation, record: tuple[int, int]) -> None: project.run_sql(f"insert into {table} values {record}") @staticmethod diff --git a/tests/functional/adapter/persist_constraints/test_persist_constraints.py b/tests/functional/adapter/persist_constraints/test_persist_constraints.py index 324cb1fa6..42c9862fa 100644 --- a/tests/functional/adapter/persist_constraints/test_persist_constraints.py +++ b/tests/functional/adapter/persist_constraints/test_persist_constraints.py @@ -1,5 +1,3 @@ -from typing import Dict - import pytest from dbt.contracts.results import RunResult @@ -44,7 +42,7 @@ def project_config_update(self): "snapshots": {"+persist_constraints": True}, } - def check_constraints(self, project, model_name: str, expected: Dict[str, str]): + def check_constraints(self, project, model_name: str, expected: dict[str, str]): rows = project.run_sql("show tblproperties {database}.{schema}." 
+ model_name, fetch="all") constraints = { row.key: row.value for row in rows if row.key.startswith("delta.constraints") diff --git a/tests/functional/adapter/streaming_tables/test_st_basic.py b/tests/functional/adapter/streaming_tables/test_st_basic.py index ad5d8a253..ba95143b2 100644 --- a/tests/functional/adapter/streaming_tables/test_st_basic.py +++ b/tests/functional/adapter/streaming_tables/test_st_basic.py @@ -1,5 +1,4 @@ from typing import Optional -from typing import Tuple import pytest @@ -16,7 +15,7 @@ @pytest.mark.skip_profile("databricks_cluster", "databricks_uc_cluster") class TestStreamingTablesBasic: @staticmethod - def insert_record(project, table: BaseRelation, record: Tuple[int, int]): + def insert_record(project, table: BaseRelation, record: tuple[int, int]): project.run_sql(f"insert into {table} values {record}") @staticmethod diff --git a/tests/functional/adapter/tblproperties/test_set_tblproperties.py b/tests/functional/adapter/tblproperties/test_set_tblproperties.py index 104db59e4..e6e5ca78f 100644 --- a/tests/functional/adapter/tblproperties/test_set_tblproperties.py +++ b/tests/functional/adapter/tblproperties/test_set_tblproperties.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from dbt.tests import util @@ -26,7 +24,7 @@ def seeds(self): "expected.csv": fixtures.seed_csv, } - def check_tblproperties(self, project, model_name: str, properties: List[str]): + def check_tblproperties(self, project, model_name: str, properties: list[str]): results = util.run_sql_with_adapter( project.adapter, f"show tblproperties {project.test_schema}.{model_name}", diff --git a/tests/profiles.py b/tests/profiles.py index f21d728c1..e34c5073f 100644 --- a/tests/profiles.py +++ b/tests/profiles.py @@ -1,6 +1,5 @@ import os from typing import Any -from typing import Dict from typing import Optional @@ -19,9 +18,9 @@ def _build_databricks_cluster_target( http_path: str, catalog: Optional[str] = None, schema: Optional[str] = None, - session_properties: Optional[Dict[str, str]] = None, + session_properties: Optional[dict[str, str]] = None, ): - profile: Dict[str, Any] = { + profile: dict[str, Any] = { "type": "databricks", "host": os.getenv("DBT_DATABRICKS_HOST_NAME"), "http_path": http_path, diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index 89a206ef7..b6d478ccb 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -1,12 +1,10 @@ -from typing import List - from agate import Table def gen_describe_extended( - columns: List[List[str]] = [["col_a", "int", "This is a comment"]], - partition_info: List[List[str]] = [], - detailed_table_info: List[List[str]] = [], + columns: list[list[str]] = [["col_a", "int", "This is a comment"]], + partition_info: list[list[str]] = [], + detailed_table_info: list[list[str]] = [], ) -> Table: return Table( rows=[ @@ -24,5 +22,5 @@ def gen_describe_extended( ) -def gen_tblproperties(rows: List[List[str]] = [["prop", "1"], ["other", "other"]]) -> Table: +def gen_tblproperties(rows: list[list[str]] = [["prop", "1"], ["other", "other"]]) -> Table: return Table(rows=rows, column_names=["key", "value"]) diff --git a/tests/unit/macros/base.py b/tests/unit/macros/base.py index e17f2aec4..80912ba0a 100644 --- a/tests/unit/macros/base.py +++ b/tests/unit/macros/base.py @@ -1,6 +1,5 @@ import re from typing import Any -from typing import Dict import pytest from jinja2 import Environment @@ -25,7 +24,7 @@ def config(self, context) -> dict: """ Anything you put in this dict will be returned by config in the rendered template 
""" - local_config: Dict[str, Any] = {} + local_config: dict[str, Any] = {} context["config"].get = lambda key, default=None, **kwargs: local_config.get(key, default) return local_config @@ -34,7 +33,7 @@ def var(self, context) -> dict: """ Anything you put in this dict will be returned by var in the rendered template """ - local_var: Dict[str, Any] = {} + local_var: dict[str, Any] = {} context["var"] = lambda key, default=None, **kwargs: local_var.get(key, default) return local_var diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 5d7afb34f..3dd929bd3 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1,6 +1,5 @@ from multiprocessing import get_context from typing import Any -from typing import Dict from typing import Optional import dbt.flags as flags @@ -57,7 +56,7 @@ def setUp(self): def _get_config( self, token: Optional[str] = "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", - session_properties: Optional[Dict[str, str]] = {"spark.sql.ansi.enabled": "true"}, + session_properties: Optional[dict[str, str]] = {"spark.sql.ansi.enabled": "true"}, **kwargs: Any, ) -> RuntimeConfig: if token: From 48c82c340c2b93c6149b5cba2a76ccd614c75a21 Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Thu, 24 Oct 2024 15:24:09 -0700 Subject: [PATCH 2/3] one of the type usages doesn't work --- dbt/adapters/databricks/relation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 61961e002..efb36a01d 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from dataclasses import dataclass from dataclasses import field -from typing import Any +from typing import Any, Type from typing import Optional from dbt.adapters.base.relation import BaseRelation @@ -135,7 +135,7 @@ def matches( return match @classproperty - def get_relation_type(cls) -> type[DatabricksRelationType]: # type: ignore + def get_relation_type(cls) -> Type[DatabricksRelationType]: return DatabricksRelationType def information_schema(self, view_name: Optional[str] = None) -> InformationSchema: From 11ea77eab732f471f7b68e2f4d18ee9b42448efc Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Thu, 24 Oct 2024 15:52:23 -0700 Subject: [PATCH 3/3] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb759c855..ac2f67949 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ - Fix places where we were not properly closing cursors, and other test warnings ([713](https://github.com/databricks/dbt-databricks/pull/713)) - Drop support for Python 3.8 ([713](https://github.com/databricks/dbt-databricks/pull/713)) - Upgrade databricks-sql-connector dependency to 3.5.0 ([833](https://github.com/databricks/dbt-databricks/pull/833)) +- Prepare for python typing deprecations ([837](https://github.com/databricks/dbt-databricks/pull/837)) ## dbt-databricks 1.8.7 (October 10, 2024)