Skip to content

Commit

Permalink
Use per-table DTE to get the table comments
Browse files Browse the repository at this point in the history
Signed-off-by: Christophe Bornet <cbornet@hotmail.com>
  • Loading branch information
cbornet committed Jan 2, 2024
1 parent cfdf3f3 commit e49eed3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 100 deletions.
23 changes: 0 additions & 23 deletions src/databricks/sqlalchemy/_information_schema.py

This file was deleted.

95 changes: 18 additions & 77 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
from typing import Any, List, Optional, Dict, Union

import databricks.sqlalchemy._ddl as dialect_ddl_impl
import databricks.sqlalchemy._types as dialect_type_impl
from databricks import sql
from databricks.sqlalchemy._information_schema import tables
from databricks.sqlalchemy._parse import (
_describe_table_extended_result_to_dict_list,
_match_table_not_found_string,
Expand All @@ -15,16 +14,15 @@
)

import sqlalchemy
from sqlalchemy import DDL, event, select, bindparam, exc
from sqlalchemy import DDL, event, exc
from sqlalchemy.engine import Connection, Engine, default, reflection
from sqlalchemy.engine.interfaces import (
ReflectedForeignKeyConstraint,
ReflectedPrimaryKeyConstraint,
ReflectedColumn,
ReflectedTableComment,
TableKey,
)
from sqlalchemy.engine.reflection import ReflectionDefaults, ObjectKind, ObjectScope
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError

try:
Expand Down Expand Up @@ -376,86 +374,29 @@ def get_schema_names(self, connection, **kw):
schema_list = [row[0] for row in result]
return schema_list

def get_multi_table_comment(
self,
connection,
schema=None,
filter_names=None,
scope=ObjectScope.ANY,
kind=ObjectKind.ANY,
**kw,
) -> Iterable[Tuple[TableKey, ReflectedTableComment]]:
result = []
_schema = schema or self.schema
if ObjectScope.DEFAULT in scope:
query = (
select(tables.c.table_name, tables.c.comment)
.select_from(tables)
.where(
tables.c.table_catalog == self.catalog,
tables.c.table_schema == _schema,
)
)

if ObjectKind.ANY not in kind:
where_in = set()
if ObjectKind.TABLE in kind:
where_in.update(
["BASE TABLE", "MANAGED", "EXTERNAL", "STREAMING_TABLE"]
)
if ObjectKind.VIEW in kind:
where_in.update(["VIEW"])
if ObjectKind.MATERIALIZED_VIEW in kind:
where_in.update(["MATERIALIZED_VIEW"])
query = query.where(tables.c.table_type.in_(where_in))

if filter_names:
query = query.where(tables.c.table_name.in_(bindparam("filter_names")))
result = connection.execute(
query, {"filter_names": [f.lower() for f in filter_names]}
)
else:
result = connection.execute(query)

if ObjectScope.TEMPORARY in scope and ObjectKind.VIEW in kind:
result = list(result)
temp_views = self.get_view_names(connection, schema, only_temp=True)
if filter_names:
temp_views = set(temp_views).intersection(
[f.lower() for f in filter_names]
)
result.extend(zip(temp_views, [None] * len(temp_views)))

# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
return (
(
(schema, table),
{"text": comment}
if comment is not None
else ReflectionDefaults.table_comment(),
)
for table, comment in result
) # type: ignore

@reflection.cache
def get_table_comment(
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
) -> ReflectedTableComment:
data = self.get_multi_table_comment(
connection,
schema,
[table_name],
**kw,
result = self._describe_table_extended(
connection=connection,
table_name=table_name,
schema_name=schema,
)
# Type ignore is because mypy knows that self._describe_table_extended *can*
# return None (even though it never will since expect_result defaults to True)
comment_row: Dict[str, str] = next(
filter(lambda r: r["col_name"] == "Comment", result), None
) # type: ignore
return (
{"text": comment_row["data_type"]}
if comment_row
else ReflectionDefaults.table_comment()
)
try:
return dict(data)[(schema, table_name.lower())]
except KeyError:
raise exc.NoSuchTableError(
f"{schema}.{table_name}" if schema else table_name
) from None


@event.listens_for(Engine, "do_connect")
Expand Down

0 comments on commit e49eed3

Please sign in to comment.