diff --git a/CHANGELOG.md b/CHANGELOG.md index ba686648..7930ba81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,10 +3,12 @@ ## 2.7.x (Unreleased) - Add support for Cloud Fetch (#146, #151, #154) +- SQLAlchemy has_table function now honours schema= argument and adds catalog= argument (#174) - Fix: Revised SQLAlchemy dialect and examples for compatibility with SQLAlchemy==1.3.x (#173) - Fix: oauth would fail if expired credentials appeared in ~/.netrc (#122) - Fix: Python HTTP proxies were broken after switch to urllib3 (#158) - Other: Connector now logs operation handle guids as hexadecimal instead of bytes (#170) +- Add support for Cloud Fetch ## 2.7.0 (2023-06-26) diff --git a/src/databricks/sqlalchemy/dialect/__init__.py b/src/databricks/sqlalchemy/dialect/__init__.py index 0f96c2bc..cfb7d857 100644 --- a/src/databricks/sqlalchemy/dialect/__init__.py +++ b/src/databricks/sqlalchemy/dialect/__init__.py @@ -267,17 +267,22 @@ def do_rollback(self, dbapi_connection): # Databricks SQL Does not support transactions pass - def has_table(self, connection, table_name, schema=None, **kwargs) -> bool: + def has_table( + self, connection, table_name, schema=None, catalog=None, **kwargs + ) -> bool: """SQLAlchemy docstrings say dialect providers must implement this method""" - schema = schema or "default" + _schema = schema or self.schema + _catalog = catalog or self.catalog # DBR >12.x uses underscores in error messages DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found" DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND" try: - res = connection.execute(f"DESCRIBE TABLE {table_name}") + res = connection.execute( + f"DESCRIBE TABLE {_catalog}.{_schema}.{table_name}" + ) return True except DatabaseError as e: if DBR_GT_12_NOT_FOUND_STRING in str( diff --git a/tests/e2e/sqlalchemy/test_basic.py b/tests/e2e/sqlalchemy/test_basic.py index 89ceb07e..1d3125f2 100644 --- a/tests/e2e/sqlalchemy/test_basic.py +++ b/tests/e2e/sqlalchemy/test_basic.py @@ -340,3 +340,44 @@ def test_get_table_names_smoke_test(samples_engine: Engine): with samples_engine.connect() as conn: _names = samples_engine.table_names(schema="nyctaxi", connection=conn) _names is not None, "get_table_names did not succeed" + + +def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): + """For this test to pass these conditions must be met: + - Table samples.nyctaxi.trips must exist + - Table samples.tpch.customer must exist + - The `catalog` and `schema` environment variables must be set and valid + """ + + with samples_engine.connect() as conn: + + # 1) Check for table within schema declared at engine creation time + assert samples_engine.dialect.has_table(connection=conn, table_name="trips") + + # 2) Check for table within another schema in the same catalog + assert samples_engine.dialect.has_table( + connection=conn, table_name="customer", schema="tpch" + ) + + # 3) Check for a table within a different catalog + other_catalog = os.environ.get("catalog") + other_schema = os.environ.get("schema") + + # Create a table in a different catalog + with db_engine.connect() as conn: + conn.execute("CREATE TABLE test_has_table (numbers_are_cool INT);") + + try: + # Verify that this table is not found in the samples catalog + assert not samples_engine.dialect.has_table( + connection=conn, table_name="test_has_table" + ) + # Verify that this table is found in a separate catalog + assert samples_engine.dialect.has_table( + connection=conn, + table_name="test_has_table", + schema=other_schema, + catalog=other_catalog, + ) + finally: + conn.execute("DROP TABLE test_has_table;")