diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index e9fd9f2b..667d46da 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -16,7 +16,13 @@ def __init__(self, dialect): class DatabricksDDLCompiler(compiler.DDLCompiler): def post_create_table(self, table): - return " USING DELTA" + post = " USING DELTA" + if table.comment: + comment = self.sql_compiler.render_literal_value( + table.comment, sqltypes.String() + ) + post += " COMMENT " + comment + return post def visit_unique_constraint(self, constraint, **kw): logger.warning("Databricks does not support unique constraints") @@ -61,7 +67,7 @@ def get_column_specification(self, column, **kwargs): feature in the future, similar to the Microsoft SQL Server dialect. """ if column is column.table._autoincrement_column or column.autoincrement is True: - logger.warn( + logger.warning( "Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead." ) diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 8b9e8337..a80f37bb 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -251,6 +251,20 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis return output_rows +def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]: + """Return a list of dictionaries containing only the col_name:data_type pairs where the `col_name` + value contains the match argument. + """ + + output_rows = [] + + for row_dict in dte_output: + if match in row_dict["col_name"]: + output_rows.append(row_dict) + + return output_rows + + def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]: """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries, one dictionary per defined constraint @@ -275,6 +289,15 @@ def get_pk_strings_from_dte_output( return output +def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]: + """Returns the value of the first "Comment" col_name data in dte_output""" + output = match_dte_rows_by_key(dte_output, "Comment") + if not output: + return None + else: + return output[0]["data_type"] + + # The keys of this dictionary are the values we expect to see in a # TGetColumnsRequest's .TYPE_NAME attribute. # These are enumerated in ttypes.py as class TTypeId. diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index e770f3e3..40af61fe 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -1,5 +1,4 @@ -import re -from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple +from typing import Any, List, Optional, Dict, Union import databricks.sqlalchemy._ddl as dialect_ddl_impl import databricks.sqlalchemy._types as dialect_type_impl @@ -11,19 +10,20 @@ build_pk_dict, get_fk_strings_from_dte_output, get_pk_strings_from_dte_output, + get_comment_from_dte_output, parse_column_info_from_tgetcolumnsresponse, ) import sqlalchemy from sqlalchemy import DDL, event from sqlalchemy.engine import Connection, Engine, default, reflection -from sqlalchemy.engine.reflection import ObjectKind from sqlalchemy.engine.interfaces import ( ReflectedForeignKeyConstraint, ReflectedPrimaryKeyConstraint, ReflectedColumn, - TableKey, + ReflectedTableComment, ) +from sqlalchemy.engine.reflection import ReflectionDefaults from sqlalchemy.exc import DatabaseError, SQLAlchemyError try: @@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs): views_result = self.get_view_names(connection=connection, schema=schema) # In Databricks, SHOW TABLES FROM returns both tables and views. - # Potential optimisation: rewrite this to instead query informtation_schema + # Potential optimisation: rewrite this to instead query information_schema tables_minus_views = [ row.tableName for row in tables_result if row.tableName not in views_result ] @@ -328,7 +328,7 @@ def get_materialized_view_names( def get_temp_view_names( self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: - """A wrapper around get_view_names taht fetches only the names of temporary views""" + """A wrapper around get_view_names that fetches only the names of temporary views""" return self.get_view_names(connection, schema, only_temp=True) def do_rollback(self, dbapi_connection): @@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw): schema_list = [row[0] for row in result] return schema_list + @reflection.cache + def get_table_comment( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedTableComment: + result = self._describe_table_extended( + connection=connection, + table_name=table_name, + schema_name=schema, + ) + + if result is None: + return ReflectionDefaults.table_comment() + + comment = get_comment_from_dte_output(result) + + if comment: + return dict(text=comment) + else: + return ReflectionDefaults.table_comment() + @event.listens_for(Engine, "do_connect") def receive_do_connect(dialect, conn_rec, cargs, cparams): diff --git a/src/databricks/sqlalchemy/requirements.py b/src/databricks/sqlalchemy/requirements.py index b68f6344..75227efb 100644 --- a/src/databricks/sqlalchemy/requirements.py +++ b/src/databricks/sqlalchemy/requirements.py @@ -159,6 +159,18 @@ def table_reflection(self): """target database has general support for table reflection""" return sqlalchemy.testing.exclusions.open() + @property + def comment_reflection(self): + """Indicates if the database support table comment reflection""" + return sqlalchemy.testing.exclusions.open() + + @property + def comment_reflection_full_unicode(self): + """Indicates if the database support table comment reflection in the + full unicode range, including emoji etc. + """ + return sqlalchemy.testing.exclusions.open() + @property def temp_table_reflection(self): """ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. diff --git a/src/databricks/sqlalchemy/test/_future.py b/src/databricks/sqlalchemy/test/_future.py index cbd28575..6e470f60 100644 --- a/src/databricks/sqlalchemy/test/_future.py +++ b/src/databricks/sqlalchemy/test/_future.py @@ -13,9 +13,7 @@ ComponentReflectionTest, ComponentReflectionTestExtra, CTETest, - FutureTableDDLTest, InsertBehaviorTest, - TableDDLTest, ) from sqlalchemy.testing.suite import ( ArrayTest, @@ -53,7 +51,6 @@ class FutureFeature(Enum): PROVISION = "event-driven engine configuration" REGEXP = "_visit_regexp" SANE_ROWCOUNT = "sane_rowcount support" - TBL_COMMENTS = "table comment reflection" TBL_OPTS = "get_table_options method" TEST_DESIGN = "required test-fixture overrides" TUPLE_LITERAL = "tuple-like IN markers completely" @@ -251,36 +248,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT pass -class FutureTableDDLTest(FutureTableDDLTest): - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_add_table_comment(self): - """We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message""" - pass - - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_drop_table_comment(self): - """We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message""" - pass - - -class TableDDLTest(TableDDLTest): - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_add_table_comment(self, connection): - """We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message""" - pass - - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_drop_table_comment(self, connection): - """We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message""" - pass - - class ComponentReflectionTest(ComponentReflectionTest): - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_get_multi_table_comment(self): - """There are 84 permutations of this test that are skipped.""" - pass - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True)) def test_multi_get_table_options_tables(self): """It's not clear what the expected ouput from this method would even _be_. Requires research.""" @@ -302,22 +270,6 @@ def test_get_multi_pk_constraint(self): def test_get_multi_check_constraints(self): pass - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_get_comments(self): - pass - - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_get_comments_with_schema(self): - pass - - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_comments_unicode(self): - pass - - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS)) - def test_comments_unicode_full(self): - pass - class ComponentReflectionTestExtra(ComponentReflectionTestExtra): @pytest.mark.skip(render_future_feature(FutureFeature.CHECK)) diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index 0c47f3e7..eb490532 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -1,7 +1,7 @@ import datetime import decimal import os -from typing import Tuple, Union +from typing import Tuple, Union, List from unittest import skipIf import pytest @@ -219,7 +219,7 @@ def test_column_comment(db_engine, metadata_obj: MetaData): connection=connection, table_name=table_name ) - assert columns[0].get("comment") == "" + assert columns[0].get("comment") == None metadata_obj.drop_all(db_engine) @@ -477,7 +477,7 @@ def sample_table(metadata_obj: MetaData, db_engine: Engine): table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) - args = [ + args: List[Column] = [ Column(colname, coltype) for colname, coltype in GET_COLUMNS_TYPE_MAP.items() ] @@ -499,3 +499,45 @@ def test_get_columns(db_engine, sample_table: str): columns = inspector.get_columns(sample_table) assert True + + +class TestCommentReflection: + @pytest.fixture(scope="class") + def engine(self): + HOST = os.environ.get("host") + HTTP_PATH = os.environ.get("http_path") + ACCESS_TOKEN = os.environ.get("access_token") + CATALOG = os.environ.get("catalog") + SCHEMA = os.environ.get("schema") + + connection_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}" + connect_args = {"_user_agent_entry": USER_AGENT_TOKEN} + + engine = create_engine(connection_string, connect_args=connect_args) + return engine + + @pytest.fixture + def inspector(self, engine: Engine) -> Inspector: + return Inspector.from_engine(engine) + + @pytest.fixture + def table(self, engine): + md = MetaData() + tbl = Table( + "foo", + md, + Column("bar", String, comment="column comment"), + comment="table comment", + ) + md.create_all(bind=engine) + + yield tbl + + md.drop_all(bind=engine) + + def test_table_comment_reflection(self, inspector: Inspector, table: Table): + tbl_name = table.name + + comment = inspector.get_table_comment(tbl_name) + + assert comment == {"text": "table comment"} diff --git a/src/databricks/sqlalchemy/test_local/test_ddl.py b/src/databricks/sqlalchemy/test_local/test_ddl.py index eb8e7083..a83ff244 100644 --- a/src/databricks/sqlalchemy/test_local/test_ddl.py +++ b/src/databricks/sqlalchemy/test_local/test_ddl.py @@ -1,9 +1,15 @@ import pytest from sqlalchemy import Column, MetaData, String, Table, create_engine -from sqlalchemy.schema import CreateTable, DropColumnComment, SetColumnComment +from sqlalchemy.schema import ( + CreateTable, + DropColumnComment, + DropTableComment, + SetColumnComment, + SetTableComment, +) -class TestTableCommentDDL: +class DDLTestBase: engine = create_engine( "databricks://token:****@****?http_path=****&catalog=****&schema=****" ) @@ -11,6 +17,8 @@ class TestTableCommentDDL: def compile(self, stmt): return str(stmt.compile(bind=self.engine)) + +class TestColumnCommentDDL(DDLTestBase): @pytest.fixture def metadata(self) -> MetaData: """Assemble a metadata object with one table containing one column.""" @@ -45,3 +53,43 @@ def test_alter_table_drop_column_comment(self, column): stmt = DropColumnComment(column) output = self.compile(stmt) assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT ''" + + +class TestTableCommentDDL(DDLTestBase): + @pytest.fixture + def metadata(self) -> MetaData: + """Assemble a metadata object with one table containing one column.""" + metadata = MetaData() + + col1 = Column("foo", String) + col2 = Column("foo", String) + tbl_w_comment = Table("martin", metadata, col1, comment="foobar") + tbl_wo_comment = Table("prs", metadata, col2) + + return metadata + + @pytest.fixture + def table_with_comment(self, metadata) -> Table: + return metadata.tables.get("martin") + + @pytest.fixture + def table_without_comment(self, metadata) -> Table: + return metadata.tables.get("prs") + + def test_create_table_with_comment(self, table_with_comment): + stmt = CreateTable(table_with_comment) + output = self.compile(stmt) + assert "USING DELTA COMMENT 'foobar'" in output + + def test_alter_table_add_comment(self, table_without_comment: Table): + table_without_comment.comment = "wireless mechanical keyboard" + stmt = SetTableComment(table_without_comment) + output = self.compile(stmt) + + assert output == "COMMENT ON TABLE prs IS 'wireless mechanical keyboard'" + + def test_alter_table_drop_comment(self, table_with_comment): + """The syntax for COMMENT ON is here: https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-comment.html""" + stmt = DropTableComment(table_with_comment) + output = self.compile(stmt) + assert output == "COMMENT ON TABLE martin IS NULL" diff --git a/src/databricks/sqlalchemy/test_local/test_parsing.py b/src/databricks/sqlalchemy/test_local/test_parsing.py index f17814f9..70e6337a 100644 --- a/src/databricks/sqlalchemy/test_local/test_parsing.py +++ b/src/databricks/sqlalchemy/test_local/test_parsing.py @@ -6,6 +6,7 @@ build_fk_dict, build_pk_dict, match_dte_rows_by_value, + get_comment_from_dte_output, DatabricksSqlAlchemyParseException, ) @@ -105,6 +106,7 @@ def test_build_pk_dict(): ["Type", "MANAGED"], ["Location", "s3://us-west-2-****-/19a85dee-****/tables/ccb7***"], ["Provider", "delta"], + ["Comment", "some comment"], ["Owner", "some.user@example.com"], ["Is_managed_location", "true"], ["Predictive Optimization", "ENABLE (inherited from CATALOG main)"], @@ -152,3 +154,7 @@ def test_build_pk_dict(): def test_filter_dict_by_value(match, output): result = match_dte_rows_by_value(FMT_SAMPLE_DT_OUTPUT, match) assert result == output + + +def test_get_comment_from_dte_output(): + assert get_comment_from_dte_output(FMT_SAMPLE_DT_OUTPUT) == "some comment"