Removing deprecated types #837

Merged · 4 commits · Oct 24, 2024
Changes from 2 commits
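Reviewer note on the overall pattern: every hunk below swaps the `typing` aliases deprecated by PEP 585 (`Dict`, `List`, `Set`, `Tuple`) for the builtin generics (`dict`, `list`, `set`, `tuple`), and imports `Callable`, `Iterator`, and `Sequence` from `collections.abc` instead of `typing`; `Optional`, `ClassVar`, `cast`, `Hashable`, and `TYPE_CHECKING` remain in `typing` in these diffs. A minimal before/after sketch of the same change (`summarize` is an illustrative function, not code from this PR):

```python
# Before: typing aliases, deprecated since Python 3.9 (PEP 585).
from typing import Callable, Dict, List, Optional, Tuple

def summarize(rows: List[Tuple[str, int]], fmt: Callable[[int], str]) -> Optional[Dict[str, str]]:
    return {name: fmt(count) for name, count in rows} if rows else None

# After: builtin generics plus collections.abc, as this PR does.
# Optional stays in typing, matching the diffs below.
from collections.abc import Callable
from typing import Optional

def summarize(rows: list[tuple[str, int]], fmt: Callable[[int], str]) -> Optional[dict[str, str]]:
    return {name: fmt(count) for name, count in rows} if rows else None
```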
27 changes: 12 additions & 15 deletions dbt/adapters/databricks/api_client.py
@@ -1,15 +1,12 @@
 import base64
+from collections.abc import Callable
 import time
 from abc import ABC
 from abc import abstractmethod
 from dataclasses import dataclass
 import re
 from typing import Any
-from typing import Callable
-from typing import Dict
-from typing import List
 from typing import Optional
-from typing import Set

 from dbt.adapters.databricks import utils
 from dbt.adapters.databricks.__version__ import version
@@ -34,17 +31,17 @@ def __init__(self, session: Session, host: str, api: str):
         self.session = session

     def get(
-        self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None
+        self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None
     ) -> Response:
         return self.session.get(f"{self.prefix}{suffix}", json=json, params=params)

     def post(
-        self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None
+        self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None
     ) -> Response:
         return self.session.post(f"{self.prefix}{suffix}", json=json, params=params)

     def put(
-        self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None
+        self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None
     ) -> Response:
         return self.session.put(f"{self.prefix}{suffix}", json=json, params=params)
@@ -230,7 +227,7 @@ def _poll_api(
         url: str,
         params: dict,
         get_state_func: Callable[[Response], str],
-        terminal_states: Set[str],
+        terminal_states: set[str],
         expected_end_state: str,
         unexpected_end_state_func: Callable[[Response], None],
     ) -> Response:
@@ -261,7 +258,7 @@ class CommandExecution(object):
     context_id: str
     cluster_id: str

-    def model_dump(self) -> Dict[str, Any]:
+    def model_dump(self) -> dict[str, Any]:
         return {
             "commandId": self.command_id,
             "contextId": self.context_id,
@@ -328,7 +325,7 @@ def __init__(self, session: Session, host: str, polling_interval: int, timeout:
         super().__init__(session, host, "/api/2.1/jobs/runs", polling_interval, timeout)

     def submit(
-        self, run_name: str, job_spec: Dict[str, Any], **additional_job_settings: Dict[str, Any]
+        self, run_name: str, job_spec: dict[str, Any], **additional_job_settings: dict[str, Any]
     ) -> str:
         submit_response = self.session.post(
             "/submit", json={"run_name": run_name, "tasks": [job_spec], **additional_job_settings}
@@ -388,7 +385,7 @@ class JobPermissionsApi(DatabricksApi):
     def __init__(self, session: Session, host: str):
         super().__init__(session, host, "/api/2.0/permissions/jobs")

-    def put(self, job_id: str, access_control_list: List[Dict[str, Any]]) -> None:
+    def put(self, job_id: str, access_control_list: list[dict[str, Any]]) -> None:
         request_body = {"access_control_list": access_control_list}

         response = self.session.put(f"/{job_id}", json=request_body)
@@ -397,7 +394,7 @@ def put(self, job_id: str, access_control_list: List[Dict[str, Any]]) -> None:
         if response.status_code != 200:
             raise DbtRuntimeError(f"Error updating Databricks workflow.\n {response.content!r}")

-    def get(self, job_id: str) -> Dict[str, Any]:
+    def get(self, job_id: str) -> dict[str, Any]:
         response = self.session.get(f"/{job_id}")

         if response.status_code != 200:
@@ -413,15 +410,15 @@ class WorkflowJobApi(DatabricksApi):
     def __init__(self, session: Session, host: str):
         super().__init__(session, host, "/api/2.1/jobs")

-    def search_by_name(self, job_name: str) -> List[Dict[str, Any]]:
+    def search_by_name(self, job_name: str) -> list[dict[str, Any]]:
         response = self.session.get("/list", json={"name": job_name})

         if response.status_code != 200:
             raise DbtRuntimeError(f"Error fetching job by name.\n {response.content!r}")

         return response.json().get("jobs", [])

-    def create(self, job_spec: Dict[str, Any]) -> str:
+    def create(self, job_spec: dict[str, Any]) -> str:
         """
         :return: the job_id
         """
@@ -434,7 +431,7 @@ def create(self, job_spec: Dict[str, Any]) -> str:
         logger.info(f"New workflow created with job id {job_id}")
         return job_id

-    def update_job_settings(self, job_id: str, job_spec: Dict[str, Any]) -> None:
+    def update_job_settings(self, job_id: str, job_spec: dict[str, Any]) -> None:
         request_body = {
             "job_id": job_id,
             "new_settings": job_spec,
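On the `_poll_api` hunk above (now `terminal_states: set[str]` with `collections.abc.Callable` callbacks): a minimal sketch of how a poller with that shape fits together, assuming a generic `fetch` callback and a fixed interval. The body is illustrative; the adapter's real class also carries a timeout (see the `polling_interval, timeout` constructor above), omitted here for brevity:

```python
import time
from collections.abc import Callable

# "Response" is whatever object the callbacks understand; in the adapter it is
# an HTTP response. fetch, the interval default, and the function name are
# assumptions for this sketch.
def poll_until_terminal(
    fetch: Callable[[], "Response"],
    get_state_func: Callable[["Response"], str],
    terminal_states: set[str],
    expected_end_state: str,
    unexpected_end_state_func: Callable[["Response"], None],
    polling_interval: float = 1.0,
) -> "Response":
    while True:
        response = fetch()
        state = get_state_func(response)
        if state in terminal_states:
            if state != expected_end_state:
                # The callback is expected to raise a descriptive error.
                unexpected_end_state_func(response)
            return response
        time.sleep(polling_interval)
```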
5 changes: 2 additions & 3 deletions dbt/adapters/databricks/auth.py
@@ -1,5 +1,4 @@
 from typing import Any
-from typing import Dict
 from typing import Optional

 from databricks.sdk.core import Config
@@ -34,7 +33,7 @@ def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]:
     def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
         static_credentials = {"Authorization": f"Bearer {self._token}"}

-        def inner() -> Dict[str, str]:
+        def inner() -> dict[str, str]:
             return static_credentials

         return inner
@@ -81,7 +80,7 @@ def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> Crede
         return c

     def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
-        def inner() -> Dict[str, str]:
+        def inner() -> dict[str, str]:
             token = self._token_source.token()  # type: ignore
             return {"Authorization": f"{token.token_type} {token.access_token}"}
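Both `auth.py` hunks touch closures that implement `HeaderFactory`: a zero-argument callable producing auth headers. A minimal sketch of the static-token variant under that assumption (the alias definition and `static_token_headers` are illustrative, not the module's actual layout):

```python
from collections.abc import Callable

# Assumed shape of HeaderFactory, inferred from the inner() signatures above.
HeaderFactory = Callable[[], dict[str, str]]

def static_token_headers(token: str) -> HeaderFactory:
    # Built once; the closure returns the same mapping on every call.
    static_credentials = {"Authorization": f"Bearer {token}"}

    def inner() -> dict[str, str]:
        return static_credentials

    return inner

# Usage: the SDK invokes the factory whenever it needs headers.
headers = static_token_headers("dapi-fake-token")()
```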
13 changes: 6 additions & 7 deletions dbt/adapters/databricks/behaviors/columns.py
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import List
 from dbt.adapters.sql import SQLAdapter
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.relation import DatabricksRelation
@@ -14,13 +13,13 @@ class GetColumnsBehavior(ABC):
     @abstractmethod
     def get_columns_in_relation(
         cls, adapter: SQLAdapter, relation: DatabricksRelation
-    ) -> List[DatabricksColumn]:
+    ) -> list[DatabricksColumn]:
         pass

     @staticmethod
     def _get_columns_with_comments(
         adapter: SQLAdapter, relation: DatabricksRelation, macro_name: str
-    ) -> List[AttrDict]:
+    ) -> list[AttrDict]:
         return list(
             handle_missing_objects(
                 lambda: adapter.execute_macro(macro_name, kwargs={"relation": relation}),
@@ -33,12 +32,12 @@ class GetColumnsByDescribe(GetColumnsBehavior):
     @classmethod
     def get_columns_in_relation(
         cls, adapter: SQLAdapter, relation: DatabricksRelation
-    ) -> List[DatabricksColumn]:
+    ) -> list[DatabricksColumn]:
         rows = cls._get_columns_with_comments(adapter, relation, "get_columns_comments")
         return cls._parse_columns(rows)

     @classmethod
-    def _parse_columns(cls, rows: List[AttrDict]) -> List[DatabricksColumn]:
+    def _parse_columns(cls, rows: list[AttrDict]) -> list[DatabricksColumn]:
         columns = []

         for row in rows:
@@ -57,7 +56,7 @@ class GetColumnsByInformationSchema(GetColumnsByDescribe):
     @classmethod
     def get_columns_in_relation(
         cls, adapter: SQLAdapter, relation: DatabricksRelation
-    ) -> List[DatabricksColumn]:
+    ) -> list[DatabricksColumn]:
         if relation.is_hive_metastore() or relation.type == DatabricksRelation.View:
             return super().get_columns_in_relation(adapter, relation)
@@ -67,5 +66,5 @@ def get_columns_in_relation(
         return cls._parse_columns(rows)

     @classmethod
-    def _parse_columns(cls, rows: List[AttrDict]) -> List[DatabricksColumn]:
+    def _parse_columns(cls, rows: list[AttrDict]) -> list[DatabricksColumn]:
         return [DatabricksColumn(column=row[0], dtype=row[1], comment=row[2]) for row in rows]
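The file above is a small strategy pattern: `GetColumnsByInformationSchema` falls back to the `DESCRIBE`-based parent for Hive metastore relations and views, and otherwise queries the information schema. A stripped-down sketch of that dispatch shape with stand-in types and data (none of these bodies are the adapter's real queries):

```python
from abc import ABC, abstractmethod

class GetColumnsBehavior(ABC):
    @classmethod
    @abstractmethod
    def get_columns_in_relation(cls, relation: str) -> list[dict[str, str]]: ...

class ByDescribe(GetColumnsBehavior):
    @classmethod
    def get_columns_in_relation(cls, relation: str) -> list[dict[str, str]]:
        # Stand-in for parsing DESCRIBE TABLE output.
        return [{"column": "id", "dtype": "bigint"}]

class ByInformationSchema(ByDescribe):
    @classmethod
    def get_columns_in_relation(cls, relation: str) -> list[dict[str, str]]:
        if relation.startswith("hive_metastore."):
            # Information schema is unavailable here; reuse the parent strategy.
            return super().get_columns_in_relation(relation)
        # Stand-in for an information_schema.columns query.
        return [{"column": "id", "dtype": "bigint", "comment": "pk"}]

print(ByInformationSchema.get_columns_in_relation("hive_metastore.default.tbl"))
```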
3 changes: 1 addition & 2 deletions dbt/adapters/databricks/column.py
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 from typing import ClassVar
-from typing import Dict
 from typing import Optional

 from dbt.adapters.spark.column import SparkColumn
@@ -11,7 +10,7 @@ class DatabricksColumn(SparkColumn):
     table_comment: Optional[str] = None
     comment: Optional[str] = None

-    TYPE_LABELS: ClassVar[Dict[str, str]] = {
+    TYPE_LABELS: ClassVar[dict[str, str]] = {
         "LONG": "BIGINT",
     }
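One detail in the `column.py` hunk: `ClassVar` has no builtin replacement and stays in `typing`, but it subscripts cleanly with builtin generics on Python 3.9+. A quick check (the `Example` class is hypothetical):

```python
from typing import ClassVar

class Example:
    TYPE_LABELS: ClassVar[dict[str, str]] = {"LONG": "BIGINT"}

print(Example.TYPE_LABELS["LONG"])  # BIGINT
```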
49 changes: 22 additions & 27 deletions dbt/adapters/databricks/connections.py
@@ -1,3 +1,4 @@
+from collections.abc import Callable, Iterator, Sequence
 import decimal
 import os
 import re
@@ -11,15 +12,9 @@
 from numbers import Number
 from threading import get_ident
 from typing import Any
-from typing import Callable
 from typing import cast
-from typing import Dict
 from typing import Hashable
-from typing import Iterator
-from typing import List
 from typing import Optional
-from typing import Sequence
-from typing import Tuple
 from typing import TYPE_CHECKING

 import databricks.sql as dbsql
@@ -107,7 +102,7 @@ class DatabricksSQLConnectionWrapper:

     _conn: DatabricksSQLConnection
     _is_cluster: bool
-    _cursors: List[DatabricksSQLCursor]
+    _cursors: list[DatabricksSQLCursor]
     _creds: DatabricksCredentials
     _user_agent: str
@@ -140,7 +135,7 @@ def cursor(self) -> "DatabricksSQLCursorWrapper":
     def cancel(self) -> None:
         logger.debug(ConnectionCancel(self._conn))

-        cursors: List[DatabricksSQLCursor] = self._cursors
+        cursors: list[DatabricksSQLCursor] = self._cursors

         for cursor in cursors:
             try:
@@ -159,10 +154,10 @@ def close(self) -> None:
     def rollback(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: rollback")

-    _dbr_version: Tuple[int, int]
+    _dbr_version: tuple[int, int]

     @property
-    def dbr_version(self) -> Tuple[int, int]:
+    def dbr_version(self) -> tuple[int, int]:
         if not hasattr(self, "_dbr_version"):
             if self._is_cluster:
                 with self._conn.cursor() as cursor:
@@ -214,13 +209,13 @@ def close(self) -> None:
         except Error as exc:
             logger.warning(CursorCloseError(self._cursor, exc))

-    def fetchall(self) -> Sequence[Tuple]:
+    def fetchall(self) -> Sequence[tuple]:
         return self._cursor.fetchall()

-    def fetchone(self) -> Optional[Tuple]:
+    def fetchone(self) -> Optional[tuple]:
         return self._cursor.fetchone()

-    def fetchmany(self, size: int) -> Sequence[Tuple]:
+    def fetchmany(self, size: int) -> Sequence[tuple]:
         return self._cursor.fetchmany(size)

     def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
@@ -315,7 +310,7 @@ def poll_refresh_pipeline(self, pipeline_id: str) -> None:
             return

     @classmethod
-    def findUpdate(cls, updates: List, id: str) -> Optional[Dict]:
+    def findUpdate(cls, updates: list, id: str) -> Optional[dict]:
         matches = [x for x in updates if x.get("update_id") == id]
         if matches:
             return matches[0]
@@ -343,7 +338,7 @@ def _fix_binding(cls, value: Any) -> Any:
         return value

     @property
-    def description(self) -> Optional[List[Tuple]]:
+    def description(self) -> Optional[list[tuple]]:
         return self._cursor.description

     def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None:
@@ -406,7 +401,7 @@ class DatabricksDBTConnection(Connection):
     acquire_release_count: int = 0
     compute_name: str = ""
     http_path: str = ""
-    thread_identifier: Tuple[int, int] = (0, 0)
+    thread_identifier: tuple[int, int] = (0, 0)
     max_idle_time: int = DEFAULT_MAX_IDLE_TIME

     # If the connection is being used for a model we want to track the model language.
@@ -479,7 +474,7 @@ class DatabricksConnectionManager(SparkConnectionManager):
     credentials_provider: Optional[TCredentialProvider] = None
     _user_agent = f"dbt-databricks/{__version__}"

-    def cancel_open(self) -> List[str]:
+    def cancel_open(self) -> list[str]:
         cancelled = super().cancel_open()
         creds = cast(DatabricksCredentials, self.profile.credentials)
         api_client = DatabricksApiClient.create(creds, 15 * 60)
@@ -494,7 +489,7 @@ def compare_dbr_version(self, major: int, minor: int) -> int:
         dbr_version = connection.dbr_version
         return (dbr_version > version) - (dbr_version < version)

-    def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
+    def set_query_header(self, query_header_context: dict[str, Any]) -> None:
         self.query_header = DatabricksMacroQueryStringSetter(self.profile, query_header_context)

     @contextmanager
@@ -572,7 +567,7 @@ def add_query(
         abridge_sql_log: bool = False,
         *,
         close_cursor: bool = False,
-    ) -> Tuple[Connection, Any]:
+    ) -> tuple[Connection, Any]:
         connection = self.get_thread_connection()
         if auto_begin and connection.transaction_open is False:
             self.begin()
@@ -622,7 +617,7 @@ def execute(
         auto_begin: bool = False,
         fetch: bool = False,
         limit: Optional[int] = None,
-    ) -> Tuple[DatabricksAdapterResponse, "Table"]:
+    ) -> tuple[DatabricksAdapterResponse, "Table"]:
         sql = self._add_query_comment(sql)
         _, cursor = self.add_query(sql, auto_begin)
         try:
@@ -735,7 +730,7 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn

         connection_parameters = creds.connection_parameters.copy()  # type: ignore[union-attr]

-        http_headers: List[Tuple[str, str]] = list(
+        http_headers: list[tuple[str, str]] = list(
             creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
         )
@@ -805,8 +800,8 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) ->
             USE_LONG_SESSIONS
         ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled"
         super().__init__(profile, mp_context)
-        self.threads_compute_connections: Dict[
-            Hashable, Dict[Hashable, DatabricksDBTConnection]
+        self.threads_compute_connections: dict[
+            Hashable, dict[Hashable, DatabricksDBTConnection]
         ] = {}

     def set_connection_name(
@@ -908,7 +903,7 @@ def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None:

     def _get_compute_connections(
         self,
-    ) -> Dict[Hashable, DatabricksDBTConnection]:
+    ) -> dict[Hashable, DatabricksDBTConnection]:
         """Retrieve a map of compute name to connection for the current thread."""

         thread_id = self.get_thread_identifier()
@@ -973,7 +968,7 @@ def _create_compute_connection(
         conn.compute_name = compute_name
         creds = cast(DatabricksCredentials, self.profile.credentials)
         conn.http_path = _get_http_path(query_header_context, creds=creds) or ""
-        conn.thread_identifier = cast(Tuple[int, int], self.get_thread_identifier())
+        conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier())
         conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds)

         conn.handle = LazyHandle(self.open)
@@ -1027,7 +1022,7 @@ def open(cls, connection: Connection) -> Connection:

         connection_parameters = creds.connection_parameters.copy()  # type: ignore[union-attr]

-        http_headers: List[Tuple[str, str]] = list(
+        http_headers: list[tuple[str, str]] = list(
             creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
         )
@@ -1094,7 +1089,7 @@ def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
     return response.json()


-def _find_update(pipeline: dict, id: str = "") -> Optional[Dict]:
+def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
     updates = pipeline.get("latest_updates", [])
     if not updates:
         raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")
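A subtlety from the cursor hunks in `connections.py`: `Sequence` now comes from `collections.abc`, and bare builtins work as generics (`Sequence[tuple]`, `Optional[tuple]`). A toy cursor backed by a plain list, purely to exercise the same annotations (not the adapter's wrapper):

```python
from collections.abc import Sequence
from typing import Optional

class ToyCursor:
    # Rows live in a list; _pos tracks how far fetches have consumed them.
    def __init__(self, rows: list[tuple]):
        self._rows = rows
        self._pos = 0

    def fetchall(self) -> Sequence[tuple]:
        remaining = self._rows[self._pos:]
        self._pos = len(self._rows)
        return remaining

    def fetchone(self) -> Optional[tuple]:
        if self._pos >= len(self._rows):
            return None
        row = self._rows[self._pos]
        self._pos += 1
        return row

    def fetchmany(self, size: int) -> Sequence[tuple]:
        batch = self._rows[self._pos : self._pos + size]
        self._pos += len(batch)
        return batch

cursor = ToyCursor([("a", 1), ("b", 2), ("c", 3)])
print(cursor.fetchone())    # ('a', 1)
print(cursor.fetchmany(2))  # [('b', 2), ('c', 3)]
```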