Skip to content

Commit

Permalink
Add support for table comments (#308)
Browse files Browse the repository at this point in the history
* Add support for table comments

Signed-off-by: Christophe Bornet <cbornet@hotmail.com>

* Use per-table DTE to get the table comments

Signed-off-by: Christophe Bornet <cbornet@hotmail.com>

* Revert pytest.ini change

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Fix typo in test name for columns. Move .engine and .compile into a base
class. Scaffold in the Table Comment unit tests.

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Add unit tests for table comments

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Revert overrides since these aren't needed after #328

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Stop skipping table comment portions of ComponentReflectionTest

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Move DTE parsing into _parse.py

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Add e2e test using inspector

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Add unit test for new method in _parse.py

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

* Fix assertion in column comment test

This was missed in #328

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

---------

Signed-off-by: Christophe Bornet <cbornet@hotmail.com>
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
Co-authored-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
  • Loading branch information
cbornet and Jesse Whitehouse authored Jan 23, 2024
1 parent a7f4773 commit f41a996
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 61 deletions.
10 changes: 8 additions & 2 deletions src/databricks/sqlalchemy/_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ def __init__(self, dialect):

class DatabricksDDLCompiler(compiler.DDLCompiler):
def post_create_table(self, table):
    """Return the text appended after the column list of CREATE TABLE.

    The clause always starts with ``USING DELTA``. When ``table.comment``
    is set, a ``COMMENT <literal>`` clause is appended after it.
    """
    post = " USING DELTA"
    if table.comment:
        # Render the comment through the SQL compiler as a String literal
        # instead of concatenating the raw text into the DDL.
        comment = self.sql_compiler.render_literal_value(
            table.comment, sqltypes.String()
        )
        post += " COMMENT " + comment
    return post

def visit_unique_constraint(self, constraint, **kw):
logger.warning("Databricks does not support unique constraints")
Expand Down Expand Up @@ -61,7 +67,7 @@ def get_column_specification(self, column, **kwargs):
feature in the future, similar to the Microsoft SQL Server dialect.
"""
if column is column.table._autoincrement_column or column.autoincrement is True:
logger.warn(
logger.warning(
"Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
)

Expand Down
23 changes: 23 additions & 0 deletions src/databricks/sqlalchemy/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,20 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
return output_rows


def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
    """Filter DESCRIBE TABLE EXTENDED rows by their ``col_name`` field.

    Every row whose ``col_name`` value contains ``match`` as a substring is
    returned unchanged, in the order it appears in ``dte_output``.
    """
    return [row for row in dte_output if match in row["col_name"]]


def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
one dictionary per defined constraint
Expand All @@ -275,6 +289,15 @@ def get_pk_strings_from_dte_output(
return output


def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
    """Return the table comment found in DESCRIBE TABLE EXTENDED output.

    The comment is the ``data_type`` value of the row whose ``col_name`` is
    exactly ``"Comment"``. Returns ``None`` when no such row exists.

    Column rows precede the table-detail rows in the output, so a substring
    match over every row would pick up a column named e.g. ``Comments`` and
    report its data type as the table comment. To avoid that, only rows
    after the ``# Detailed Table Information`` marker are searched when the
    marker is present; without a marker every row is searched.
    """
    # NOTE(review): marker text taken from Spark/Databricks DESCRIBE TABLE
    # EXTENDED output -- confirm against a live result set.
    detail_marker = "# Detailed Table Information"

    marker_index = next(
        (
            index
            for index, row in enumerate(dte_output)
            if row.get("col_name") == detail_marker
        ),
        None,
    )
    candidates = dte_output if marker_index is None else dte_output[marker_index + 1 :]

    for row in candidates:
        if row.get("col_name") == "Comment":
            return row.get("data_type")

    return None


# The keys of this dictionary are the values we expect to see in a
# TGetColumnsRequest's .TYPE_NAME attribute.
# These are enumerated in ttypes.py as class TTypeId.
Expand Down
36 changes: 30 additions & 6 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
from typing import Any, List, Optional, Dict, Union

import databricks.sqlalchemy._ddl as dialect_ddl_impl
import databricks.sqlalchemy._types as dialect_type_impl
Expand All @@ -11,19 +10,20 @@
build_pk_dict,
get_fk_strings_from_dte_output,
get_pk_strings_from_dte_output,
get_comment_from_dte_output,
parse_column_info_from_tgetcolumnsresponse,
)

import sqlalchemy
from sqlalchemy import DDL, event
from sqlalchemy.engine import Connection, Engine, default, reflection
from sqlalchemy.engine.reflection import ObjectKind
from sqlalchemy.engine.interfaces import (
ReflectedForeignKeyConstraint,
ReflectedPrimaryKeyConstraint,
ReflectedColumn,
TableKey,
ReflectedTableComment,
)
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError

try:
Expand Down Expand Up @@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
views_result = self.get_view_names(connection=connection, schema=schema)

# In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
# Potential optimisation: rewrite this to instead query informtation_schema
# Potential optimisation: rewrite this to instead query information_schema
tables_minus_views = [
row.tableName for row in tables_result if row.tableName not in views_result
]
Expand Down Expand Up @@ -328,7 +328,7 @@ def get_materialized_view_names(
def get_temp_view_names(
self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""A wrapper around get_view_names taht fetches only the names of temporary views"""
"""A wrapper around get_view_names that fetches only the names of temporary views"""
return self.get_view_names(connection, schema, only_temp=True)

def do_rollback(self, dbapi_connection):
Expand Down Expand Up @@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw):
schema_list = [row[0] for row in result]
return schema_list

@reflection.cache
def get_table_comment(
    self,
    connection: Connection,
    table_name: str,
    schema: Optional[str] = None,
    **kw: Any,
) -> ReflectedTableComment:
    """Reflect the comment of ``table_name`` from its DESCRIBE TABLE EXTENDED output.

    Returns ``{"text": <comment>}`` when a non-empty comment is found. When
    the describe call returns ``None`` or yields no comment, the value of
    ``ReflectionDefaults.table_comment()`` is returned instead.
    """
    dte_rows = self._describe_table_extended(
        connection=connection,
        table_name=table_name,
        schema_name=schema,
    )

    comment_text = None if dte_rows is None else get_comment_from_dte_output(dte_rows)

    if not comment_text:
        return ReflectionDefaults.table_comment()

    return {"text": comment_text}


@event.listens_for(Engine, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
Expand Down
12 changes: 12 additions & 0 deletions src/databricks/sqlalchemy/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ def table_reflection(self):
"""target database has general support for table reflection"""
return sqlalchemy.testing.exclusions.open()

@property
def comment_reflection(self):
    """Indicates that the database supports table comment reflection."""
    return sqlalchemy.testing.exclusions.open()

@property
def comment_reflection_full_unicode(self):
    """Indicates that the database supports table comment reflection across
    the full unicode range, including emoji.
    """
    return sqlalchemy.testing.exclusions.open()

@property
def temp_table_reflection(self):
"""ComponentReflection test is intricate and simply cannot function without this exclusion being defined here.
Expand Down
48 changes: 0 additions & 48 deletions src/databricks/sqlalchemy/test/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
ComponentReflectionTest,
ComponentReflectionTestExtra,
CTETest,
FutureTableDDLTest,
InsertBehaviorTest,
TableDDLTest,
)
from sqlalchemy.testing.suite import (
ArrayTest,
Expand Down Expand Up @@ -53,7 +51,6 @@ class FutureFeature(Enum):
PROVISION = "event-driven engine configuration"
REGEXP = "_visit_regexp"
SANE_ROWCOUNT = "sane_rowcount support"
TBL_COMMENTS = "table comment reflection"
TBL_OPTS = "get_table_options method"
TEST_DESIGN = "required test-fixture overrides"
TUPLE_LITERAL = "tuple-like IN markers completely"
Expand Down Expand Up @@ -251,36 +248,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT
pass


class FutureTableDDLTest(FutureTableDDLTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_add_table_comment(self):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_drop_table_comment(self):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass


class TableDDLTest(TableDDLTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_add_table_comment(self, connection):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_drop_table_comment(self, connection):
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
pass


class ComponentReflectionTest(ComponentReflectionTest):
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_multi_table_comment(self):
"""There are 84 permutations of this test that are skipped."""
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True))
def test_multi_get_table_options_tables(self):
"""It's not clear what the expected ouput from this method would even _be_. Requires research."""
Expand All @@ -302,22 +270,6 @@ def test_get_multi_pk_constraint(self):
def test_get_multi_check_constraints(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_comments(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_get_comments_with_schema(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_comments_unicode(self):
pass

@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
def test_comments_unicode_full(self):
pass


class ComponentReflectionTestExtra(ComponentReflectionTestExtra):
@pytest.mark.skip(render_future_feature(FutureFeature.CHECK))
Expand Down
48 changes: 45 additions & 3 deletions src/databricks/sqlalchemy/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import decimal
import os
from typing import Tuple, Union
from typing import Tuple, Union, List
from unittest import skipIf

import pytest
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_column_comment(db_engine, metadata_obj: MetaData):
connection=connection, table_name=table_name
)

assert columns[0].get("comment") == ""
assert columns[0].get("comment") == None

metadata_obj.drop_all(db_engine)

Expand Down Expand Up @@ -477,7 +477,7 @@ def sample_table(metadata_obj: MetaData, db_engine: Engine):

table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))

args = [
args: List[Column] = [
Column(colname, coltype) for colname, coltype in GET_COLUMNS_TYPE_MAP.items()
]

Expand All @@ -499,3 +499,45 @@ def test_get_columns(db_engine, sample_table: str):
columns = inspector.get_columns(sample_table)

assert True


class TestCommentReflection:
    """End-to-end check that a table comment can be read back through an Inspector."""

    @pytest.fixture(scope="class")
    def engine(self):
        """Create an engine from the connection settings found in the environment."""
        host = os.environ.get("host")
        http_path = os.environ.get("http_path")
        access_token = os.environ.get("access_token")
        catalog = os.environ.get("catalog")
        schema = os.environ.get("schema")

        url = f"databricks://token:{access_token}@{host}?http_path={http_path}&catalog={catalog}&schema={schema}"

        return create_engine(url, connect_args={"_user_agent_entry": USER_AGENT_TOKEN})

    @pytest.fixture
    def inspector(self, engine: Engine) -> Inspector:
        """Inspector bound to the class-scoped engine."""
        return Inspector.from_engine(engine)

    @pytest.fixture
    def table(self, engine):
        """Create a commented table for the test and drop it afterwards."""
        metadata = MetaData()
        commented_table = Table(
            "foo",
            metadata,
            Column("bar", String, comment="column comment"),
            comment="table comment",
        )
        metadata.create_all(bind=engine)

        yield commented_table

        metadata.drop_all(bind=engine)

    def test_table_comment_reflection(self, inspector: Inspector, table: Table):
        reflected = inspector.get_table_comment(table.name)

        assert reflected == {"text": "table comment"}
52 changes: 50 additions & 2 deletions src/databricks/sqlalchemy/test_local/test_ddl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import pytest
from sqlalchemy import Column, MetaData, String, Table, create_engine
from sqlalchemy.schema import CreateTable, DropColumnComment, SetColumnComment
from sqlalchemy.schema import (
CreateTable,
DropColumnComment,
DropTableComment,
SetColumnComment,
SetTableComment,
)


class TestTableCommentDDL:
class DDLTestBase:
    """Shared base for DDL tests: one engine plus a helper that compiles statements."""

    # Connection URL with masked placeholder values; compile() binds statements to it.
    engine = create_engine(
        "databricks://token:****@****?http_path=****&catalog=****&schema=****"
    )

    def compile(self, stmt):
        """Compile ``stmt`` against the shared engine and return the SQL as a string."""
        compiled = stmt.compile(bind=self.engine)
        return str(compiled)


class TestColumnCommentDDL(DDLTestBase):
@pytest.fixture
def metadata(self) -> MetaData:
"""Assemble a metadata object with one table containing one column."""
Expand Down Expand Up @@ -45,3 +53,43 @@ def test_alter_table_drop_column_comment(self, column):
stmt = DropColumnComment(column)
output = self.compile(stmt)
assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT ''"


class TestTableCommentDDL(DDLTestBase):
    """Compile-only tests for table comment DDL: CREATE TABLE ... COMMENT and COMMENT ON TABLE."""

    @pytest.fixture
    def metadata(self) -> MetaData:
        """Assemble a metadata object with two single-column tables.

        ``martin`` is declared with a comment and ``prs`` without one.
        """
        metadata = MetaData()

        # Constructing a Table registers it on ``metadata``, so the objects
        # do not need to be kept in local variables.
        Table("martin", metadata, Column("foo", String), comment="foobar")
        Table("prs", metadata, Column("foo", String))

        return metadata

    @pytest.fixture
    def table_with_comment(self, metadata) -> Table:
        # Subscript rather than .get() so a missing table fails here with a
        # KeyError instead of handing None to the test.
        return metadata.tables["martin"]

    @pytest.fixture
    def table_without_comment(self, metadata) -> Table:
        return metadata.tables["prs"]

    def test_create_table_with_comment(self, table_with_comment):
        stmt = CreateTable(table_with_comment)
        output = self.compile(stmt)
        assert "USING DELTA COMMENT 'foobar'" in output

    def test_alter_table_add_comment(self, table_without_comment: Table):
        table_without_comment.comment = "wireless mechanical keyboard"
        stmt = SetTableComment(table_without_comment)
        output = self.compile(stmt)

        assert output == "COMMENT ON TABLE prs IS 'wireless mechanical keyboard'"

    def test_alter_table_drop_comment(self, table_with_comment):
        """The syntax for COMMENT ON is here: https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-comment.html"""
        stmt = DropTableComment(table_with_comment)
        output = self.compile(stmt)
        assert output == "COMMENT ON TABLE martin IS NULL"
Loading

0 comments on commit f41a996

Please sign in to comment.