From 0ddca9df58fdb7ce547bd343289d2eb7179a4b14 Mon Sep 17 00:00:00 2001
From: Jothi Prakash
Date: Thu, 8 Aug 2024 14:58:19 +0530
Subject: [PATCH] Refactored the test code and moved to respective folders

---
 .../sql/client.py | 3 -
 .../sql/thrift_backend.py | 15 +-
 .../sql/utils.py | 31 ---
 .../tests}/__init__.py | 0
 .../tests}/e2e/__init__.py | 0
 .../tests}/e2e/common/__init__.py | 0
 .../tests}/e2e/common/core_tests.py | 0
 .../tests}/e2e/common/decimal_tests.py | 0
 .../tests}/e2e/common/large_queries_mixin.py | 0
 .../tests}/e2e/common/predicates.py | 8 +-
 .../tests}/e2e/common/retry_test_mixins.py | 44 ++--
 .../e2e/common/staging_ingestion_tests.py | 34 +--
 .../tests}/e2e/common/timestamp_tests.py | 0
 .../tests}/e2e/common/uc_volume_tests.py | 32 +--
 .../tests}/e2e/test_complex_types.py | 0
 .../tests}/e2e/test_driver.py | 36 ++--
 .../tests}/e2e/test_parameterized_queries.py | 62 +++---
 .../tests}/unit/__init__.py | 0
 .../tests}/unit/test_arrow_queue.py | 2 +-
 .../tests}/unit/test_auth.py | 14 +-
 .../tests}/unit/test_client.py | 64 +++---
 .../tests}/unit/test_cloud_fetch_queue.py | 34 +--
 .../tests}/unit/test_download_manager.py | 6 +-
 .../tests}/unit/test_downloader.py | 6 +-
 .../tests}/unit/test_endpoint.py | 20 +-
 .../tests}/unit/test_fetches.py | 6 +-
 .../tests}/unit/test_fetches_bench.py | 6 +-
 .../tests}/unit/test_init_file.py | 0
 .../tests}/unit/test_oauth_persistence.py | 4 +-
 .../tests}/unit/test_param_escaper.py | 112 +++++-----
 .../tests/unit/test_parameters.py | 204 ++++++++++++++++++
 .../tests}/unit/test_retry.py | 4 +-
 .../tests}/unit/test_thrift_backend.py | 168 +++++++--------
 examples/custom_cred_provider.py | 3 +-
 examples/insert_data.py | 4 +-
 examples/interactive_oauth.py | 7 +-
 examples/m2m_oauth.py | 4 +-
 examples/parameters.py | 8 +-
 examples/persistent_oauth.py | 34 +--
 examples/query_cancel.py | 16 +-
 examples/query_execute.py | 2 +-
 examples/set_user_agent.py | 2 +-
 examples/staging_ingestion.py | 12 +-
 examples/v3_retries_query_execute.py | 6 +-
 setup_script.py | 4 +-
 tests/unit/test_parameters.py | 204 ------------------
 46 files changed, 588 insertions(+), 633 deletions(-)
 rename {tests => databricks_sql_connector_core/tests}/__init__.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/__init__.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/__init__.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/core_tests.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/decimal_tests.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/large_queries_mixin.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/predicates.py (91%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/retry_test_mixins.py (94%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/staging_ingestion_tests.py (91%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/timestamp_tests.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/common/uc_volume_tests.py (90%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/test_complex_types.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/test_driver.py (96%)
 rename {tests => databricks_sql_connector_core/tests}/e2e/test_parameterized_queries.py (92%)
 rename {tests => databricks_sql_connector_core/tests}/unit/__init__.py (100%)
 rename {tests => databricks_sql_connector_core/tests}/unit/test_arrow_queue.py (94%)
 rename {tests =>
databricks_sql_connector_core/tests}/unit/test_auth.py (91%) rename {tests => databricks_sql_connector_core/tests}/unit/test_client.py (90%) rename {tests => databricks_sql_connector_core/tests}/unit/test_cloud_fetch_queue.py (87%) rename {tests => databricks_sql_connector_core/tests}/unit/test_download_manager.py (92%) rename {tests => databricks_sql_connector_core/tests}/unit/test_downloader.py (97%) rename {tests => databricks_sql_connector_core/tests}/unit/test_endpoint.py (91%) rename {tests => databricks_sql_connector_core/tests}/unit/test_fetches.py (98%) rename {tests => databricks_sql_connector_core/tests}/unit/test_fetches_bench.py (93%) rename {tests => databricks_sql_connector_core/tests}/unit/test_init_file.py (100%) rename {tests => databricks_sql_connector_core/tests}/unit/test_oauth_persistence.py (91%) rename {tests => databricks_sql_connector_core/tests}/unit/test_param_escaper.py (62%) create mode 100644 databricks_sql_connector_core/tests/unit/test_parameters.py rename {tests => databricks_sql_connector_core/tests}/unit/test_retry.py (93%) rename {tests => databricks_sql_connector_core/tests}/unit/test_thrift_backend.py (90%) delete mode 100644 tests/unit/test_parameters.py diff --git a/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/client.py b/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/client.py index 93675578..1cee29d0 100755 --- a/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/client.py +++ b/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/client.py @@ -784,8 +784,6 @@ def execute( parameters=prepared_params, ) - # print("Line 781") - # print(execute_response) self.active_result_set = ResultSet( self.connection, execute_response, @@ -1141,7 +1139,6 @@ def _fill_results_buffer(self): def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) - # print("Table\n",table) result = [] for row_index in range(len(table[0])): curr_row = [] diff --git a/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/thrift_backend.py b/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/thrift_backend.py index c879de58..9c888a80 100644 --- a/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/thrift_backend.py +++ b/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/thrift_backend.py @@ -743,7 +743,6 @@ def _results_message_to_execute_response(self, resp, operation_state): else: t_result_set_metadata_resp = self._get_metadata_resp(resp.operationHandle) - # print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}") if t_result_set_metadata_resp.resultFormat not in [ ttypes.TSparkRowSetType.ARROW_BASED_SET, ttypes.TSparkRowSetType.COLUMN_BASED_SET, @@ -880,18 +879,8 @@ def execute_command( # We want to receive proper Timestamp arrow types. "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" }, - # useArrowNativeTypes=spark_arrow_types, - # canReadArrowResult=True, - # # canDecompressLZ4Result=lz4_compression, - # canDecompressLZ4Result=False, - # canDownloadResult=False, - # # confOverlay={ - # # # We want to receive proper Timestamp arrow types. 
- # # "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" - # # }, - # resultDataFormat=TDBSqlResultFormat(None,None,True), - # # useArrowNativeTypes=spark_arrow_types, - parameters=parameters, + useArrowNativeTypes=spark_arrow_types, + parameters=parameters, ) resp = self.make_request(self._client.ExecuteStatement, req) return self._handle_execute_response(resp, cursor) diff --git a/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/utils.py b/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/utils.py index c259fc37..5c490c73 100644 --- a/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/utils.py +++ b/databricks_sql_connector_core/src/databricks_sql_connector_core/sql/utils.py @@ -74,18 +74,6 @@ def build_queue( Returns: ResultSetQueue """ - - # def trow_to_json(trow): - # # Step 1: Serialize TRow using Thrift's TJSONProtocol - # transport = TTransport.TMemoryBuffer() - # protocol = TJSONProtocol.TJSONProtocol(transport) - # trow.write(protocol) - # - # # Step 2: Extract JSON string from the transport - # json_str = transport.getvalue().decode('utf-8') - # - # return json_str - if row_set_type == TSparkRowSetType.ARROW_BASED_SET: arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes @@ -95,30 +83,11 @@ def build_queue( ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: - # print("Lin 79 ") - # print(type(t_row_set)) - # print(t_row_set) - # json_str = json.loads(trow_to_json(t_row_set)) - # pretty_json = json.dumps(json_str, indent=2) - # print(pretty_json) - converted_column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description) - # print(converted_column_table, column_names) return ColumnQueue(converted_column_table, column_names) - - # print(columnQueue.next_n_rows(2)) - # print(columnQueue.next_n_rows(2)) - # print(columnQueue.remaining_rows()) - # arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table( - # t_row_set.columns, description - # ) - # converted_arrow_table = convert_decimals_in_arrow_table( - # arrow_table, description - # ) - # return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( schema_bytes=arrow_schema_bytes, diff --git a/tests/__init__.py b/databricks_sql_connector_core/tests/__init__.py similarity index 100% rename from tests/__init__.py rename to databricks_sql_connector_core/tests/__init__.py diff --git a/tests/e2e/__init__.py b/databricks_sql_connector_core/tests/e2e/__init__.py similarity index 100% rename from tests/e2e/__init__.py rename to databricks_sql_connector_core/tests/e2e/__init__.py diff --git a/tests/e2e/common/__init__.py b/databricks_sql_connector_core/tests/e2e/common/__init__.py similarity index 100% rename from tests/e2e/common/__init__.py rename to databricks_sql_connector_core/tests/e2e/common/__init__.py diff --git a/tests/e2e/common/core_tests.py b/databricks_sql_connector_core/tests/e2e/common/core_tests.py similarity index 100% rename from tests/e2e/common/core_tests.py rename to databricks_sql_connector_core/tests/e2e/common/core_tests.py diff --git a/tests/e2e/common/decimal_tests.py b/databricks_sql_connector_core/tests/e2e/common/decimal_tests.py similarity index 100% rename from tests/e2e/common/decimal_tests.py rename to databricks_sql_connector_core/tests/e2e/common/decimal_tests.py diff --git 
a/tests/e2e/common/large_queries_mixin.py b/databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py similarity index 100% rename from tests/e2e/common/large_queries_mixin.py rename to databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py diff --git a/tests/e2e/common/predicates.py b/databricks_sql_connector_core/tests/e2e/common/predicates.py similarity index 91% rename from tests/e2e/common/predicates.py rename to databricks_sql_connector_core/tests/e2e/common/predicates.py index 4d01c5fe..10dbde9e 100644 --- a/tests/e2e/common/predicates.py +++ b/databricks_sql_connector_core/tests/e2e/common/predicates.py @@ -8,13 +8,13 @@ def pysql_supports_arrow(): - """Import databricks.sql and test whether Cursor has fetchall_arrow.""" - from databricks.sql import Cursor + """Import databricks_sql_connector_core.sql and test whether Cursor has fetchall_arrow.""" + from databricks_sql_connector_core.sql.client import Cursor return hasattr(Cursor, 'fetchall_arrow') def pysql_has_version(compare, version): - """Import databricks.sql, and return compare_module_version(...). + """Import databricks_sql_connector_core.sql, and return compare_module_version(...). Expected use: from common.predicates import pysql_has_version @@ -98,4 +98,4 @@ def validate_version(version): mod_version = validate_version(module.__version__) req_version = validate_version(version) - return compare_versions(compare, mod_version, req_version) + return compare_versions(compare, mod_version, req_version) \ No newline at end of file diff --git a/tests/e2e/common/retry_test_mixins.py b/databricks_sql_connector_core/tests/e2e/common/retry_test_mixins.py similarity index 94% rename from tests/e2e/common/retry_test_mixins.py rename to databricks_sql_connector_core/tests/e2e/common/retry_test_mixins.py index e7761397..6053e314 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/databricks_sql_connector_core/tests/e2e/common/retry_test_mixins.py @@ -6,8 +6,8 @@ import pytest from urllib3.exceptions import MaxRetryError -from databricks.sql import DatabricksRetryPolicy -from databricks.sql import ( +from databricks_sql_connector_core.sql.auth.retry import DatabricksRetryPolicy +from databricks_sql_connector_core.sql.exc import ( MaxRetryDurationError, NonRecoverableNetworkError, RequestError, @@ -146,7 +146,7 @@ def test_retry_urllib3_settings_are_honored(self): def test_oserror_retries(self): """If a network error occurs during make_request, the request is retried according to policy""" with patch( - "urllib3.connectionpool.HTTPSConnectionPool._validate_conn", + "urllib3.connectionpool.HTTPSConnectionPool._validate_conn", ) as mock_validate_conn: mock_validate_conn.side_effect = OSError("Some arbitrary network error") with pytest.raises(MaxRetryError) as cm: @@ -275,7 +275,7 @@ def test_retry_safe_execute_statement_retry_condition(self): ] with self.connection( - extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1} + extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1} ) as conn: with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load @@ -318,9 +318,9 @@ def test_retry_abort_close_operation_on_404(self, caplog): with self.connection(extra_params={**self._retry_policy}) as conn: with conn.cursor() as curs: with patch( - "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", - new_callable=PropertyMock, - return_value=False, + 
"databricks_sql_connector_core.sql.utils.ExecuteResponse.has_been_closed_server_side", + new_callable=PropertyMock, + return_value=False, ): # This call guarantees we have an open cursor at the server curs.execute("SELECT 1") @@ -340,10 +340,10 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( - extra_params={ - **self._retry_policy, - "_retry_max_redirects": max_redirects, - } + extra_params={ + **self._retry_policy, + "_retry_max_redirects": max_redirects, + } ): pass assert "too many redirects" == str(cm.value.reason) @@ -362,9 +362,9 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( - extra_params={ - **self._retry_policy, - } + extra_params={ + **self._retry_policy, + } ): pass @@ -394,13 +394,13 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog): with self.connection( - extra_params={ - **self._retry_policy, - **{ - "_retry_max_redirects": 100, - "_retry_stop_after_attempts_count": 1, - }, - } + extra_params={ + **self._retry_policy, + **{ + "_retry_max_redirects": 100, + "_retry_stop_after_attempts_count": 1, + }, + } ): assert "it will have no affect!" in caplog.text @@ -433,4 +433,4 @@ def test_401_not_retried(self): with pytest.raises(RequestError) as cm: with self.connection(extra_params=self._retry_policy): pass - assert isinstance(cm.value.args[1], NonRecoverableNetworkError) + assert isinstance(cm.value.args[1], NonRecoverableNetworkError) \ No newline at end of file diff --git a/tests/e2e/common/staging_ingestion_tests.py b/databricks_sql_connector_core/tests/e2e/common/staging_ingestion_tests.py similarity index 91% rename from tests/e2e/common/staging_ingestion_tests.py rename to databricks_sql_connector_core/tests/e2e/common/staging_ingestion_tests.py index d8d0429f..f1034faa 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/databricks_sql_connector_core/tests/e2e/common/staging_ingestion_tests.py @@ -2,8 +2,8 @@ import tempfile import pytest -import databricks.sql as sql -from databricks.sql import Error +import databricks_sql_connector_core.sql as sql +from databricks_sql_connector_core.sql import Error @pytest.fixture(scope="module", autouse=True) @@ -100,7 +100,7 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, in cursor.execute(query) def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path( - self, ingestion_user + self, ingestion_user ): fh, temp_path = tempfile.mkstemp() @@ -116,8 +116,8 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p base_path = os.path.join(base_path, "temp") with pytest.raises( - Error, - match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: cursor = conn.cursor() @@ -158,7 +158,7 @@ def perform_remove(): # Try to put it again with pytest.raises( - sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS" + sql.exc.ServerOperationError, 
match="FILE_IN_STAGING_PATH_ALREADY_EXISTS" ): perform_put() @@ -209,7 +209,7 @@ def perform_get(): perform_get() def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path( - self, ingestion_user + self, ingestion_user ): """ This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. @@ -222,11 +222,11 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe target_file = "/var/www/html/../html1/not_allowed.html" with pytest.raises( - Error, - match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={"staging_allowed_local_path": staging_allowed_local_path} ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -238,7 +238,7 @@ def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, inges with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={"staging_allowed_local_path": staging_allowed_local_path} ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -250,14 +250,14 @@ def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_ with pytest.raises(Error, match="INVALID_STAGING_PATH_IN_STAGING_ACCESS_QUERY"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={"staging_allowed_local_path": staging_allowed_local_path} ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values( - self, ingestion_user + self, ingestion_user ): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. 
@@ -286,7 +286,7 @@ def generate_file_and_path_and_queries(): fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() with self.connection( - extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} + extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} ) as conn: cursor = conn.cursor() @@ -294,11 +294,11 @@ def generate_file_and_path_and_queries(): cursor.execute(put_query2) with pytest.raises( - Error, - match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): cursor.execute(put_query3) # Then clean up the files we made cursor.execute(remove_query1) - cursor.execute(remove_query2) + cursor.execute(remove_query2) \ No newline at end of file diff --git a/tests/e2e/common/timestamp_tests.py b/databricks_sql_connector_core/tests/e2e/common/timestamp_tests.py similarity index 100% rename from tests/e2e/common/timestamp_tests.py rename to databricks_sql_connector_core/tests/e2e/common/timestamp_tests.py diff --git a/tests/e2e/common/uc_volume_tests.py b/databricks_sql_connector_core/tests/e2e/common/uc_volume_tests.py similarity index 90% rename from tests/e2e/common/uc_volume_tests.py rename to databricks_sql_connector_core/tests/e2e/common/uc_volume_tests.py index 21e43036..88bc2972 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/databricks_sql_connector_core/tests/e2e/common/uc_volume_tests.py @@ -2,8 +2,8 @@ import tempfile import pytest -import databricks.sql as sql -from databricks.sql import Error +import databricks_sql_connector_core.sql as sql +from databricks_sql_connector_core.sql import Error @pytest.fixture(scope="module", autouse=True) @@ -99,7 +99,7 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, s cursor.execute(query) def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( - self, catalog, schema + self, catalog, schema ): fh, temp_path = tempfile.mkstemp() @@ -115,8 +115,8 @@ def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( base_path = os.path.join(base_path, "temp") with pytest.raises( - Error, - match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: cursor = conn.cursor() @@ -157,7 +157,7 @@ def perform_remove(): # Try to put it again with pytest.raises( - sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS" + sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS" ): perform_put() @@ -165,7 +165,7 @@ def perform_remove(): perform_remove() def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path( - self, catalog, schema + self, catalog, schema ): """ This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. 
@@ -178,11 +178,11 @@ def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_ target_file = "/var/www/html/../html1/not_allowed.html" with pytest.raises( - Error, - match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={"staging_allowed_local_path": staging_allowed_local_path} ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" @@ -194,7 +194,7 @@ def test_uc_volume_empty_local_path_fails_to_parse_at_server(self, catalog, sche with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={"staging_allowed_local_path": staging_allowed_local_path} ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" @@ -206,7 +206,7 @@ def test_uc_volume_invalid_volume_path_fails_at_server(self, catalog, schema): with pytest.raises(Error, match="NOT_FOUND: Catalog"): with self.connection( - extra_params={"staging_allowed_local_path": staging_allowed_local_path} + extra_params={"staging_allowed_local_path": staging_allowed_local_path} ) as conn: cursor = conn.cursor() query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" @@ -240,7 +240,7 @@ def generate_file_and_path_and_queries(): fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() with self.connection( - extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} + extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} ) as conn: cursor = conn.cursor() @@ -248,11 +248,11 @@ def generate_file_and_path_and_queries(): cursor.execute(put_query2) with pytest.raises( - Error, - match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): cursor.execute(put_query3) # Then clean up the files we made cursor.execute(remove_query1) - cursor.execute(remove_query2) + cursor.execute(remove_query2) \ No newline at end of file diff --git a/tests/e2e/test_complex_types.py b/databricks_sql_connector_core/tests/e2e/test_complex_types.py similarity index 100% rename from tests/e2e/test_complex_types.py rename to databricks_sql_connector_core/tests/e2e/test_complex_types.py diff --git a/tests/e2e/test_driver.py b/databricks_sql_connector_core/tests/e2e/test_driver.py similarity index 96% rename from tests/e2e/test_driver.py rename to databricks_sql_connector_core/tests/e2e/test_driver.py index b49dfade..52764ad3 100644 --- a/tests/e2e/test_driver.py +++ b/databricks_sql_connector_core/tests/e2e/test_driver.py @@ -18,8 +18,8 @@ import pytest from urllib3.connectionpool import ReadTimeoutError -import databricks.sql as sql -from databricks.sql import ( +import databricks_sql_connector_core.sql as sql +from databricks_sql_connector_core.sql import ( STRING, BINARY, NUMBER, @@ -46,11 +46,11 @@ from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin -from databricks.sql import SessionAlreadyClosedError +from 
databricks_sql_connector_core.sql.exc import SessionAlreadyClosedError log = logging.getLogger(__name__) -unsafe_logger = logging.getLogger("databricks.sql.unsafe") +unsafe_logger = logging.getLogger("databricks_sql_connector_core.sql.unsafe") unsafe_logger.setLevel(logging.DEBUG) unsafe_logger.addHandler(logging.FileHandler("./tests-unsafe.log")) @@ -141,19 +141,19 @@ def test_cloud_fetch(self): # If this table is deleted or this test is run on a different host, a different table may need to be used. base_query = "SELECT * FROM store_sales WHERE ss_sold_date_sk = 2452234 " for num_limit, num_threads, lz4_compression in itertools.product( - limits, threads, [True, False] + limits, threads, [True, False] ): with self.subTest( - num_limit=num_limit, num_threads=num_threads, lz4_compression=lz4_compression + num_limit=num_limit, num_threads=num_threads, lz4_compression=lz4_compression ): cf_result, noop_result = None, None query = base_query + "LIMIT " + str(num_limit) with self.cursor( - { - "use_cloud_fetch": True, - "max_download_threads": num_threads, - "catalog": "hive_metastore", - }, + { + "use_cloud_fetch": True, + "max_download_threads": num_threads, + "catalog": "hive_metastore", + }, ) as cursor: cursor.execute(query) cf_result = cursor.fetchall() @@ -333,7 +333,7 @@ def test_get_columns(self): "col_3", 2002, "STRUCT", - ], + ], ["default", table_name + "_1", "col_4", 2000, "MAP"], ["default", table_name + "_1", "col_5", 2003, "ARRAY"], ["default", table_name + "_2", "col_1", 4, "INT"], @@ -344,7 +344,7 @@ def test_get_columns(self): "col_3", 2002, "STRUCT", - ], + ], ["default", table_name + "_2", "col_4", 2000, "MAP"], [ "default", @@ -352,7 +352,7 @@ def test_get_columns(self): "col_5", 2003, "ARRAY", - ], + ], ] assert cleaned_response == expected expected = [ @@ -608,7 +608,7 @@ def test_timestamps_arrow(self): # be UTC (what it should be by default on the server) aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) assert result_value == ( - aware_timestamp and aware_timestamp.timestamp() * 1000000 + aware_timestamp and aware_timestamp.timestamp() * 1000000 ), "timestamp {} did not match {}".format(timestamp, expected) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") @@ -696,7 +696,7 @@ def test_decimal_not_returned_as_strings_arrow(self): def test_close_connection_closes_cursors(self): - from databricks.sql import ttypes + from databricks_sql_connector_core.sql.thrift_api.TCLIService import ttypes with self.connection() as conn: cursor = conn.cursor() @@ -750,10 +750,10 @@ def test_initial_namespace(self): cursor.execute("USE CATALOG {}".format(self.arguments["catalog"])) cursor.execute("CREATE TABLE table_{} (col1 int)".format(table_name)) with self.connection( - {"catalog": self.arguments["catalog"], "schema": table_name} + {"catalog": self.arguments["catalog"], "schema": table_name} ) as connection: cursor = connection.cursor() cursor.execute("select current_catalog()") assert cursor.fetchone()[0] == self.arguments["catalog"] cursor.execute("select current_database()") - assert cursor.fetchone()[0] == table_name + assert cursor.fetchone()[0] == table_name \ No newline at end of file diff --git a/tests/e2e/test_parameterized_queries.py b/databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py similarity index 92% rename from tests/e2e/test_parameterized_queries.py rename to databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py index d63dc133..8bde1b95 100644 --- 
a/tests/e2e/test_parameterized_queries.py +++ b/databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py @@ -8,7 +8,7 @@ import pytest import pytz -from databricks.sql import ( +from databricks_sql_connector_core.sql.parameters.native import ( BigIntegerParameter, BooleanParameter, DateParameter, @@ -147,8 +147,8 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru """Applies a patch so we can test the connector's behaviour under different SPARK_CLI_SERVICE_PROTOCOL_VERSION conditions.""" with patch( - "databricks.sql.client.Connection.server_parameterized_queries_enabled", - return_value=supports_native_params, + "databricks_sql_connector_core.sql.client.Connection.server_parameterized_queries_enabled", + return_value=supports_native_params, ) as mock_parameterized_queries_enabled: try: yield mock_parameterized_queries_enabled @@ -186,10 +186,10 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle): return to_return def _native_roundtrip( - self, - parameters: Union[Dict, List[Dict]], - paramstyle: ParamStyle, - parameter_structure: ParameterStructure, + self, + parameters: Union[Dict, List[Dict]], + paramstyle: ParamStyle, + parameter_structure: ParameterStructure, ): if parameter_structure == ParameterStructure.POSITIONAL: _query = self.POSITIONAL_PARAMSTYLE_QUERY @@ -203,11 +203,11 @@ def _native_roundtrip( return cursor.fetchone() def _get_one_result( - self, - params, - approach: ParameterApproach = ParameterApproach.NONE, - paramstyle: ParamStyle = ParamStyle.NONE, - parameter_structure: ParameterStructure = ParameterStructure.NONE, + self, + params, + approach: ParameterApproach = ParameterApproach.NONE, + paramstyle: ParamStyle = ParamStyle.NONE, + parameter_structure: ParameterStructure = ParameterStructure.NONE, ): """When approach is INLINE then we use %(param)s paramstyle and a connection with use_inline_params=True When approach is NATIVE then we use :param paramstyle and a connection with use_inline_params=False @@ -243,12 +243,12 @@ def _eq(self, actual, expected: Primitive): "approach,paramstyle,parameter_structure", approach_paramstyle_combinations ) def test_primitive_single( - self, - approach, - paramstyle, - parameter_structure, - primitive: Primitive, - inline_table, + self, + approach, + paramstyle, + parameter_structure, + primitive: Primitive, + inline_table, ): """When ParameterApproach.INLINE is passed, inferrence will not be used. When ParameterApproach.NATIVE is passed, primitive inputs will be inferred. @@ -285,10 +285,10 @@ def test_primitive_single( ], ) def test_dbsqlparameter_single( - self, - primitive: Primitive, - dbsql_parameter_cls: Type[TDbsqlParameter], - parameter_structure: ParameterStructure, + self, + primitive: Primitive, + dbsql_parameter_cls: Type[TDbsqlParameter], + parameter_structure: ParameterStructure, ): dbsql_param = dbsql_parameter_cls( value=primitive.value, # type: ignore @@ -316,11 +316,11 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) cursor.execute("SELECT %(p)s", parameters={"p": 1}) if use_inline_params is True: assert ( - "Consider using native parameters." in caplog.text + "Consider using native parameters." in caplog.text ), "Log message should be suppressed" elif use_inline_params == "silent": assert ( - "Consider using native parameters." not in caplog.text + "Consider using native parameters." 
not in caplog.text ), "Log message should not be supressed" def test_positional_native_params_with_defaults(self): @@ -333,12 +333,12 @@ def test_positional_native_params_with_defaults(self): @pytest.mark.parametrize( "params", ( - [ - StringParameter(value="foo"), - StringParameter(value="bar"), - StringParameter(value="baz"), - ], - ["foo", "bar", "baz"], + [ + StringParameter(value="foo"), + StringParameter(value="bar"), + StringParameter(value="baz"), + ], + ["foo", "bar", "baz"], ), ) def test_positional_native_multiple(self, params): @@ -450,4 +450,4 @@ def test_native_like_wildcard_works(self): with self.cursor(extra_params={"use_inline_params": False}) as cursor: result = cursor.execute(query, parameters=params).fetchone() - assert result.col == 1 + assert result.col == 1 \ No newline at end of file diff --git a/tests/unit/__init__.py b/databricks_sql_connector_core/tests/unit/__init__.py similarity index 100% rename from tests/unit/__init__.py rename to databricks_sql_connector_core/tests/unit/__init__.py diff --git a/tests/unit/test_arrow_queue.py b/databricks_sql_connector_core/tests/unit/test_arrow_queue.py similarity index 94% rename from tests/unit/test_arrow_queue.py rename to databricks_sql_connector_core/tests/unit/test_arrow_queue.py index a6ef417a..93e2ed2b 100644 --- a/tests/unit/test_arrow_queue.py +++ b/databricks_sql_connector_core/tests/unit/test_arrow_queue.py @@ -2,7 +2,7 @@ import pyarrow as pa -from databricks.sql import ArrowQueue +from databricks_sql_connector_core.sql.utils import ArrowQueue class ArrowQueueSuite(unittest.TestCase): diff --git a/tests/unit/test_auth.py b/databricks_sql_connector_core/tests/unit/test_auth.py similarity index 91% rename from tests/unit/test_auth.py rename to databricks_sql_connector_core/tests/unit/test_auth.py index d88fe286..9e2b60b7 100644 --- a/tests/unit/test_auth.py +++ b/databricks_sql_connector_core/tests/unit/test_auth.py @@ -3,22 +3,22 @@ from typing import Optional from unittest.mock import patch -from databricks.sql import ( +from databricks_sql_connector_core.sql.auth.auth import ( AccessTokenAuthProvider, AuthProvider, ExternalAuthProvider, AuthType, ) -from databricks.sql import get_python_sql_connector_auth_provider -from databricks.sql import OAuthManager -from databricks.sql import DatabricksOAuthProvider -from databricks.sql import ( +from databricks_sql_connector_core.sql.auth.auth import get_python_sql_connector_auth_provider, PYSQL_OAUTH_CLIENT_ID +from databricks_sql_connector_core.sql.auth.oauth import OAuthManager +from databricks_sql_connector_core.sql.auth.authenticators import DatabricksOAuthProvider +from databricks_sql_connector_core.sql.auth.endpoint import ( CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection, ) -from databricks.sql import CredentialsProvider, HeaderFactory -from databricks.sql import OAuthPersistenceCache +from databricks_sql_connector_core.sql.auth.authenticators import CredentialsProvider, HeaderFactory +from databricks_sql_connector_core.sql.experimental.oauth_persistence import OAuthPersistenceCache class Auth(unittest.TestCase): diff --git a/tests/unit/test_client.py b/databricks_sql_connector_core/tests/unit/test_client.py similarity index 90% rename from tests/unit/test_client.py rename to databricks_sql_connector_core/tests/unit/test_client.py index 5d1a2b04..a3cdc7b7 100644 --- a/tests/unit/test_client.py +++ b/databricks_sql_connector_core/tests/unit/test_client.py @@ -8,19 +8,19 @@ from datetime import datetime, date from uuid import UUID -from 
databricks.sql import ( +from databricks_sql_connector_core.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TExecuteStatementResp, TOperationHandle, THandleIdentifier, TOperationType ) -from databricks.sql import ThriftBackend +from databricks_sql_connector_core.sql.thrift_backend import ThriftBackend -import databricks.sql -import databricks.sql.client as client -from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError -from databricks.sql import Row +import databricks_sql_connector_core.sql +import databricks_sql_connector_core.sql.client as client +from databricks_sql_connector_core.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks_sql_connector_core.sql.types import Row from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -77,7 +77,7 @@ class ClientTestSuite(unittest.TestCase): Unit tests for isolated client behaviour. """ - PACKAGE_NAME = "databricks.sql" + PACKAGE_NAME = "databricks_sql_connector_core.sql" DUMMY_CONNECTION_ARGS = { "server_hostname": "foo", "http_path": "dummy_path", @@ -92,7 +92,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): mock_open_session_resp.sessionHandle.sessionId = b'\x22' instance.open_session.return_value = mock_open_session_resp - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.close() # Check the close session request has an id of x22 @@ -120,7 +120,7 @@ def test_auth_args(self, mock_client_class): ] for args in connection_args: - connection = databricks.sql.connect(**args) + connection = databricks_sql_connector_core.sql.connect(**args) host, port, http_path, *_ = mock_client_class.call_args[0] self.assertEqual(args["server_hostname"], host) self.assertEqual(args["http_path"], http_path) @@ -129,14 +129,14 @@ def test_auth_args(self, mock_client_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( + databricks_sql_connector_core.sql.connect( **self.DUMMY_CONNECTION_ARGS, _tls_verify_hostname="hostname", _tls_trusted_ca_file="trusted ca file", @@ -152,16 +152,16 @@ def test_tls_arg_passthrough(self, mock_client_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) http_headers = mock_client_class.call_args[0][3] - user_agent_header = ("User-Agent", "{}/{}".format(databricks.sql.USER_AGENT_NAME, - databricks.sql.__version__)) + user_agent_header = ("User-Agent", "{}/{}".format(databricks_sql_connector_core.sql.USER_AGENT_NAME, + databricks_sql_connector_core.sql.__version__)) self.assertIn(user_agent_header, http_headers) - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, _user_agent_entry="foobar") + databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS, _user_agent_entry="foobar") 
user_agent_header_with_entry = ("User-Agent", "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar")) + databricks_sql_connector_core.sql.USER_AGENT_NAME, databricks_sql_connector_core.sql.__version__, "foobar")) http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) @@ -172,7 +172,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): for closed in (True, False): with self.subTest(closed=closed): mock_result_set_class.return_value = Mock() - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() cursor.execute("SELECT 1;") connection.close() @@ -182,7 +182,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) connection.close() self.assertFalse(connection.open) @@ -193,7 +193,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough(self, mock_cursor_class, mock_client_class): - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.cursor(arraysize=999, buffer_size_bytes=1234) kwargs = mock_cursor_class.call_args[1] @@ -275,7 +275,7 @@ def test_context_manager_closes_connection(self, mock_client_class): mock_open_session_resp.sessionHandle.sessionId = b'\x22' instance.open_session.return_value = mock_open_session_resp - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + with databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: pass # Check the close session request has an id of x22 @@ -363,7 +363,7 @@ def test_cancel_command_calls_the_backend(self): cursor.cancel() mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle) - @patch("databricks.sql.client.logger") + @patch("databricks_sql_connector_core.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self, logger_instance): mock_thrift_backend = Mock() @@ -375,17 +375,17 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect(_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS) + databricks_sql_connector_core.sql.connect(_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + databricks_sql_connector_core.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) def test_version_is_canonical(self): - version = databricks.sql.__version__ + version = 
databricks_sql_connector_core.sql.__version__ canonical_version_re = r'^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)' \ r'(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$' self.assertIsNotNone(re.match(canonical_version_re, version)) @@ -393,7 +393,7 @@ def test_version_is_canonical(self): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() - databricks.sql.connect( + databricks_sql_connector_core.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.return_value.open_session.call_args[0][0], @@ -404,7 +404,7 @@ def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem) + databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem) self.assertEqual(mock_client_class.return_value.open_session.call_args[0][1], mock_cat) self.assertEqual(mock_client_class.return_value.open_session.call_args[0][2], mock_schem) @@ -463,7 +463,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + c = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() def test_setinputsizes_a_noop(self): @@ -476,7 +476,7 @@ def test_setoutputsizes_a_noop(self): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + c = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): c.rollback() @@ -562,7 +562,7 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): mock_open_session_resp.sessionHandle.sessionId = b'\x22' instance.open_session.return_value = mock_open_session_resp - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) # not strictly necessary as the refcount is 0, but just to be sure gc.collect() @@ -579,7 +579,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): mock_open_session_resp.sessionHandle.sessionId = b'\x22' instance.open_session.return_value = mock_open_session_resp - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() del connection @@ -599,7 +599,7 @@ def test_staging_operation_response_is_handled(self, mock_client_class, mock_han mock_client_class.execute_command.return_value = mock_execute_response mock_client_class.return_value = mock_client_class - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() cursor.execute("Text of some staging operation command;") connection.close() @@ -610,7 +610,7 @@ def test_staging_operation_response_is_handled(self, mock_client_class, mock_han def test_access_current_query_id(self): operation_id = 'EE6A8778-21FC-438B-92D8-96AC51EE3821' - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection = 
databricks_sql_connector_core.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() self.assertIsNone(cursor.query_id) diff --git a/tests/unit/test_cloud_fetch_queue.py b/databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py similarity index 87% rename from tests/unit/test_cloud_fetch_queue.py rename to databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py index 6cddff35..abc687d2 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock, patch from ssl import create_default_context -from databricks.sql import TSparkArrowResultLink -import databricks.sql.utils as utils +from databricks_sql_connector_core.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink +import databricks_sql_connector_core.sql.utils as utils class CloudFetchQueueSuite(unittest.TestCase): @@ -43,7 +43,7 @@ def get_schema_bytes(): return sink.getvalue().to_pybytes() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=[None, None]) + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table", return_value=[None, None]) def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) @@ -72,7 +72,7 @@ def test_initializer_no_links_to_add(self): assert len(queue.download_manager._download_tasks) == 0 assert queue.table is None - @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) + @patch("databricks_sql_connector_core.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): queue = utils.CloudFetchQueue( MagicMock(), @@ -84,8 +84,8 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): assert queue._create_next_table() is None mock_get_next_downloaded_file.assert_called_with(0) - @patch("databricks.sql.utils.create_arrow_table_from_arrow_file") - @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", + @patch("databricks_sql_connector_core.sql.utils.create_arrow_table_from_arrow_file") + @patch("databricks_sql_connector_core.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=MagicMock(file_bytes=b"1234567890", row_count=4)) def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): mock_create_arrow_table.return_value = self.make_arrow_table() @@ -111,7 +111,7 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi assert table.num_rows == 4 assert queue.start_row_index == 8 - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() @@ -131,7 +131,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): assert queue.table_row_index == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_partial_table(self, 
mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() @@ -151,7 +151,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == self.make_arrow_table()[:3] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() @@ -171,7 +171,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() @@ -190,7 +190,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() @@ -207,7 +207,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): mock_create_next_table.assert_called() assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0] schema_bytes, description = MagicMock(), MagicMock() @@ -226,7 +226,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) assert result.num_rows == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() @@ -245,7 +245,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl assert result.num_rows == 2 assert result == self.make_arrow_table()[2:] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() @@ -264,7 +264,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - 
@patch("databricks.sql.utils.CloudFetchQueue._create_next_table") + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() @@ -284,7 +284,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta assert result.num_rows == 5 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) + @patch("databricks_sql_connector_core.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() @@ -298,4 +298,4 @@ def test_remaining_rows_empty_table(self, mock_create_next_table): assert queue.table is None result = queue.remaining_rows() - assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() + assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() \ No newline at end of file diff --git a/tests/unit/test_download_manager.py b/databricks_sql_connector_core/tests/unit/test_download_manager.py similarity index 92% rename from tests/unit/test_download_manager.py rename to databricks_sql_connector_core/tests/unit/test_download_manager.py index 4d5d0ea9..c44cf126 100644 --- a/tests/unit/test_download_manager.py +++ b/databricks_sql_connector_core/tests/unit/test_download_manager.py @@ -3,8 +3,8 @@ from ssl import create_default_context -import databricks.sql.cloudfetch.download_manager as download_manager -from databricks.sql import TSparkArrowResultLink +import databricks_sql_connector_core.sql.cloudfetch.download_manager as download_manager +from databricks_sql_connector_core.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink class DownloadManagerTests(unittest.TestCase): @@ -61,4 +61,4 @@ def test_schedule_downloads(self, mock_submit): manager._schedule_downloads() assert mock_submit.call_count == max_download_threads assert len(manager._pending_links) == len(links) - max_download_threads - assert len(manager._download_tasks) == max_download_threads + assert len(manager._download_tasks) == max_download_threads \ No newline at end of file diff --git a/tests/unit/test_downloader.py b/databricks_sql_connector_core/tests/unit/test_downloader.py similarity index 97% rename from tests/unit/test_downloader.py rename to databricks_sql_connector_core/tests/unit/test_downloader.py index f2974b7c..93f834b3 100644 --- a/tests/unit/test_downloader.py +++ b/databricks_sql_connector_core/tests/unit/test_downloader.py @@ -4,8 +4,8 @@ import requests from ssl import create_default_context -import databricks.sql.cloudfetch.downloader as downloader -from databricks.sql import Error +import databricks_sql_connector_core.sql.cloudfetch.downloader as downloader +from databricks_sql_connector_core.sql.exc import Error def create_response(**kwargs) -> requests.Response: @@ -116,4 +116,4 @@ def test_download_timeout(self, mock_time, mock_session): d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) with self.assertRaises(TimeoutError): - d.run() + d.run() \ No newline at end of file diff --git a/tests/unit/test_endpoint.py b/databricks_sql_connector_core/tests/unit/test_endpoint.py 
similarity index 91% rename from tests/unit/test_endpoint.py rename to databricks_sql_connector_core/tests/unit/test_endpoint.py index 199318d7..68fab22c 100644 --- a/tests/unit/test_endpoint.py +++ b/databricks_sql_connector_core/tests/unit/test_endpoint.py @@ -4,8 +4,8 @@ from unittest.mock import patch -from databricks.sql import AuthType -from databricks.sql import ( +from databricks_sql_connector_core.sql.auth.auth import AuthType +from databricks_sql_connector_core.sql.auth.endpoint import ( infer_cloud_from_host, CloudType, get_oauth_endpoints, @@ -89,13 +89,13 @@ def test_oauth_endpoint(self): ] for ( - cloud_type, - host, - use_azure_auth, - expected_auth_url, - expected_config_url, - expected_scopes, - expected_scope2, + cloud_type, + host, + use_azure_auth, + expected_auth_url, + expected_config_url, + expected_scopes, + expected_scope2, ) in param_list: with self.subTest(cloud_type): endpoint = get_oauth_endpoints(host, use_azure_auth) @@ -121,4 +121,4 @@ def test_azure_oauth_scope_mappings_from_different_tenant_id(self): "052ee82f-b79d-443c-8682-3ec1749e56b0/user_impersonation", "offline_access", ], - ) + ) \ No newline at end of file diff --git a/tests/unit/test_fetches.py b/databricks_sql_connector_core/tests/unit/test_fetches.py similarity index 98% rename from tests/unit/test_fetches.py rename to databricks_sql_connector_core/tests/unit/test_fetches.py index 35e13900..81ab6868 100644 --- a/tests/unit/test_fetches.py +++ b/databricks_sql_connector_core/tests/unit/test_fetches.py @@ -3,8 +3,8 @@ import pyarrow as pa -import databricks.sql.client as client -from databricks.sql import ExecuteResponse, ArrowQueue +import databricks_sql_connector_core.sql.client as client +from databricks_sql_connector_core.sql.utils import ExecuteResponse, ArrowQueue class FetchTests(unittest.TestCase): @@ -211,4 +211,4 @@ def test_fetchone_without_initial_results(self): if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/unit/test_fetches_bench.py b/databricks_sql_connector_core/tests/unit/test_fetches_bench.py similarity index 93% rename from tests/unit/test_fetches_bench.py rename to databricks_sql_connector_core/tests/unit/test_fetches_bench.py index 2335df3f..8072df76 100644 --- a/tests/unit/test_fetches_bench.py +++ b/databricks_sql_connector_core/tests/unit/test_fetches_bench.py @@ -6,8 +6,8 @@ import time import pytest -import databricks.sql.client as client -from databricks.sql import ExecuteResponse, ArrowQueue +import databricks_sql_connector_core.sql.client as client +from databricks_sql_connector_core.sql.utils import ExecuteResponse, ArrowQueue class FetchBenchmarkTests(unittest.TestCase): @@ -60,4 +60,4 @@ def test_benchmark_fetchall(self): if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/unit/test_init_file.py b/databricks_sql_connector_core/tests/unit/test_init_file.py similarity index 100% rename from tests/unit/test_init_file.py rename to databricks_sql_connector_core/tests/unit/test_init_file.py diff --git a/tests/unit/test_oauth_persistence.py b/databricks_sql_connector_core/tests/unit/test_oauth_persistence.py similarity index 91% rename from tests/unit/test_oauth_persistence.py rename to databricks_sql_connector_core/tests/unit/test_oauth_persistence.py index 812c918a..bcfb8499 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/databricks_sql_connector_core/tests/unit/test_oauth_persistence.py @@ -1,7 +1,7 @@ import unittest -from databricks.sql 
import DevOnlyFilePersistence, OAuthToken +from databricks_sql_connector_core.sql.experimental.oauth_persistence import DevOnlyFilePersistence, OAuthToken import tempfile import os @@ -29,4 +29,4 @@ def test_DevOnlyFilePersistence_file_does_not_exist(self): self.assertEqual(new_token, None) - # TODO moderakh add test for file with invalid format (should return None) + # TODO moderakh add test for file with invalid format (should return None) \ No newline at end of file diff --git a/tests/unit/test_param_escaper.py b/databricks_sql_connector_core/tests/unit/test_param_escaper.py similarity index 62% rename from tests/unit/test_param_escaper.py rename to databricks_sql_connector_core/tests/unit/test_param_escaper.py index 0f37bf92..817c4ea1 100644 --- a/tests/unit/test_param_escaper.py +++ b/databricks_sql_connector_core/tests/unit/test_param_escaper.py @@ -1,9 +1,9 @@ from datetime import date, datetime import unittest, pytest, decimal from typing import Any, Dict -from databricks.sql import dbsql_parameter_from_primitive +from databricks_sql_connector_core.sql.parameters.native import dbsql_parameter_from_primitive -from databricks.sql import ParamEscaper, inject_parameters, transform_paramstyle, ParameterStructure +from databricks_sql_connector_core.sql.utils import ParamEscaper, inject_parameters, transform_paramstyle, ParameterStructure pe = ParamEscaper() @@ -42,48 +42,48 @@ def test_escape_string_that_includes_special_characters(self): # Testing for the presence of these characters: '"/\😂 assert ( - pe.escape_string("his name was 'robert palmer'") - == r"'his name was \'robert palmer\''" + pe.escape_string("his name was 'robert palmer'") + == r"'his name was \'robert palmer\''" ) # These tests represent the same user input in the several ways it can be written in Python # Each argument to `escape_string` evaluates to the same bytes. But Python lets us write it differently. 
assert ( - pe.escape_string('his name was "robert palmer"') - == "'his name was \"robert palmer\"'" + pe.escape_string('his name was "robert palmer"') + == "'his name was \"robert palmer\"'" ) assert ( - pe.escape_string('his name was "robert palmer"') - == "'his name was \"robert palmer\"'" + pe.escape_string('his name was "robert palmer"') + == "'his name was \"robert palmer\"'" ) assert ( - pe.escape_string("his name was {}".format('"robert palmer"')) - == "'his name was \"robert palmer\"'" + pe.escape_string("his name was {}".format('"robert palmer"')) + == "'his name was \"robert palmer\"'" ) assert ( - pe.escape_string("his name was robert / palmer") - == r"'his name was robert / palmer'" + pe.escape_string("his name was robert / palmer") + == r"'his name was robert / palmer'" ) # If you need to include a single backslash, use an r-string to prevent Python from raising a # DeprecationWarning for an invalid escape sequence assert ( - pe.escape_string("his name was robert \\/ palmer") - == r"'his name was robert \\/ palmer'" + pe.escape_string("his name was robert \\/ palmer") + == r"'his name was robert \\/ palmer'" ) assert ( - pe.escape_string("his name was robert \\ palmer") - == r"'his name was robert \\ palmer'" + pe.escape_string("his name was robert \\ palmer") + == r"'his name was robert \\ palmer'" ) assert ( - pe.escape_string("his name was robert \\\\ palmer") - == r"'his name was robert \\\\ palmer'" + pe.escape_string("his name was robert \\\\ palmer") + == r"'his name was robert \\\\ palmer'" ) assert ( - pe.escape_string("his name was robert palmer 😂") - == r"'his name was robert palmer 😂'" + pe.escape_string("his name was robert palmer 😂") + == r"'his name was robert palmer 😂'" ) # Adding the test from PR #56 to prove escape behaviour @@ -122,8 +122,8 @@ def test_escape_sequence_float(self): def test_escape_sequence_string(self): assert ( - pe.escape_sequence(["his", "name", "was", "robert", "palmer"]) - == "('his','name','was','robert','palmer')" + pe.escape_sequence(["his", "name", "was", "robert", "palmer"]) + == "('his','name','was','robert','palmer')" ) def test_escape_sequence_sequence_of_strings(self): @@ -182,44 +182,44 @@ class TestInlineToNativeTransformer(object): @pytest.mark.parametrize( ("label", "query", "params", "expected"), ( - ("no effect", "SELECT 1", {}, "SELECT 1"), - ("one marker", "%(param)s", {"param": ""}, ":param"), - ( - "multiple markers", - "%(foo)s %(bar)s %(baz)s", - {"foo": None, "bar": None, "baz": None}, - ":foo :bar :baz", - ), - ( - "sql query", - "SELECT * FROM table WHERE field = %(param)s AND other_field IN (%(list)s)", - {"param": None, "list": None}, - "SELECT * FROM table WHERE field = :param AND other_field IN (:list)", - ), - ( - "query with like wildcard", - 'select * from table where field like "%"', - {}, - 'select * from table where field like "%"' - ), - ( - "query with named param and like wildcard", - 'select :param from table where field like "%"', - {"param": None}, - 'select :param from table where field like "%"' - ), - ( - "query with doubled wildcards", - 'select 1 where '' like "%%"', - {"param": None}, - 'select 1 where '' like "%%"', - ) + ("no effect", "SELECT 1", {}, "SELECT 1"), + ("one marker", "%(param)s", {"param": ""}, ":param"), + ( + "multiple markers", + "%(foo)s %(bar)s %(baz)s", + {"foo": None, "bar": None, "baz": None}, + ":foo :bar :baz", + ), + ( + "sql query", + "SELECT * FROM table WHERE field = %(param)s AND other_field IN (%(list)s)", + {"param": None, "list": None}, + "SELECT * FROM table 
WHERE field = :param AND other_field IN (:list)", + ), + ( + "query with like wildcard", + 'select * from table where field like "%"', + {}, + 'select * from table where field like "%"' + ), + ( + "query with named param and like wildcard", + 'select :param from table where field like "%"', + {"param": None}, + 'select :param from table where field like "%"' + ), + ( + "query with doubled wildcards", + 'select 1 where '' like "%%"', + {"param": None}, + 'select 1 where '' like "%%"', + ) ), ) def test_transformer( - self, label: str, query: str, params: Dict[str, Any], expected: str + self, label: str, query: str, params: Dict[str, Any], expected: str ): _params = [dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items()] output = transform_paramstyle(query, _params, param_structure=ParameterStructure.NAMED) - assert output == expected + assert output == expected \ No newline at end of file diff --git a/databricks_sql_connector_core/tests/unit/test_parameters.py b/databricks_sql_connector_core/tests/unit/test_parameters.py new file mode 100644 index 00000000..6b108695 --- /dev/null +++ b/databricks_sql_connector_core/tests/unit/test_parameters.py @@ -0,0 +1,204 @@ +import datetime +from decimal import Decimal +from enum import Enum +from typing import Type + +import pytest +import pytz + +from databricks_sql_connector_core.sql.client import Connection +from databricks_sql_connector_core.sql.parameters import ( + BigIntegerParameter, + BooleanParameter, + DateParameter, + DecimalParameter, + DoubleParameter, + FloatParameter, + IntegerParameter, + SmallIntParameter, + StringParameter, + TimestampNTZParameter, + TimestampParameter, + TinyIntParameter, + VoidParameter, +) +from databricks_sql_connector_core.sql.parameters.native import ( + TDbsqlParameter, + TSparkParameterValue, + dbsql_parameter_from_primitive, +) +from databricks_sql_connector_core.sql.thrift_api.TCLIService import ttypes +from databricks_sql_connector_core.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + TSparkParameterValue, +) + + +class TestSessionHandleChecks(object): + @pytest.mark.parametrize( + "test_input,expected", + [ + ( + TOpenSessionResp( + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + sessionHandle=TSessionHandle(1, None), + ), + ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + ), + # Ensure that protocol version inside sessionhandle takes precedence. 
+ ( + TOpenSessionResp( + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + sessionHandle=TSessionHandle( + 1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ), + ), + ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, + ), + ], + ) + def test_get_protocol_version_fallback_behavior(self, test_input, expected): + assert Connection.get_protocol_version(test_input) == expected + + @pytest.mark.parametrize( + "test_input,expected", + [ + ( + None, + False, + ), + ( + ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, + False, + ), + ( + ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, + True, + ), + ], + ) + def test_parameters_enabled(self, test_input, expected): + assert Connection.server_parameterized_queries_enabled(test_input) == expected + + +@pytest.mark.parametrize( + "value,expected", + ( + (Decimal("10.00"), "DECIMAL(4,2)"), + (Decimal("123456789123456789.123456789123456789"), "DECIMAL(36,18)"), + (Decimal(".12345678912345678912345678912345678912"), "DECIMAL(38,38)"), + (Decimal("123456789.123456789"), "DECIMAL(18,9)"), + (Decimal("12345678912345678912345678912345678912"), "DECIMAL(38,0)"), + (Decimal("1234.56"), "DECIMAL(6,2)"), + ), +) +def test_calculate_decimal_cast_string(value, expected): + p = DecimalParameter(value) + assert p._cast_expr() == expected + + +class Primitive(Enum): + """These are the inferrable types. This Enum is used for parametrized tests.""" + + NONE = None + BOOL = True + INT = 50 + BIGINT = 2147483648 + STRING = "Hello" + DECIMAL = Decimal("1234.56") + DATE = datetime.date(2023, 9, 6) + TIMESTAMP = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC) + DOUBLE = 3.14 + FLOAT = 3.15 + SMALLINT = 51 + + +class TestDbsqlParameter: + @pytest.mark.parametrize( + "_type, prim, expect_cast_expr", + ( + (DecimalParameter, Primitive.DECIMAL, "DECIMAL(6,2)"), + (IntegerParameter, Primitive.INT, "INT"), + (StringParameter, Primitive.STRING, "STRING"), + (BigIntegerParameter, Primitive.BIGINT, "BIGINT"), + (BooleanParameter, Primitive.BOOL, "BOOLEAN"), + (DateParameter, Primitive.DATE, "DATE"), + (DoubleParameter, Primitive.DOUBLE, "DOUBLE"), + (FloatParameter, Primitive.FLOAT, "FLOAT"), + (VoidParameter, Primitive.NONE, "VOID"), + (SmallIntParameter, Primitive.INT, "SMALLINT"), + (TimestampParameter, Primitive.TIMESTAMP, "TIMESTAMP"), + (TimestampNTZParameter, Primitive.TIMESTAMP, "TIMESTAMP_NTZ"), + (TinyIntParameter, Primitive.INT, "TINYINT"), + ), + ) + def test_cast_expression( + self, _type: TDbsqlParameter, prim: Primitive, expect_cast_expr: str + ): + p = _type(prim.value) + assert p._cast_expr() == expect_cast_expr + + @pytest.mark.parametrize( + "t, prim", + ( + (DecimalParameter, Primitive.DECIMAL), + (IntegerParameter, Primitive.INT), + (StringParameter, Primitive.STRING), + (BigIntegerParameter, Primitive.BIGINT), + (BooleanParameter, Primitive.BOOL), + (DateParameter, Primitive.DATE), + (DoubleParameter, Primitive.DOUBLE), + (FloatParameter, Primitive.FLOAT), + (VoidParameter, Primitive.NONE), + (SmallIntParameter, Primitive.INT), + (TimestampParameter, Primitive.TIMESTAMP), + (TimestampNTZParameter, Primitive.TIMESTAMP), + (TinyIntParameter, Primitive.INT), + ), + ) + def test_tspark_param_value(self, t: TDbsqlParameter, prim): + p: TDbsqlParameter = t(prim.value) + output = p._tspark_param_value() + + if prim == Primitive.NONE: + assert output == None + else: + assert output == TSparkParameterValue(stringValue=str(prim.value)) + + def test_tspark_param_named(self): + p = 
dbsql_parameter_from_primitive(Primitive.INT.value, name="p") + tsp = p.as_tspark_param(named=True) + + assert tsp.name == "p" + assert tsp.ordinal is False + + def test_tspark_param_ordinal(self): + p = dbsql_parameter_from_primitive(Primitive.INT.value, name="p") + tsp = p.as_tspark_param(named=False) + + assert tsp.name is None + assert tsp.ordinal is True + + @pytest.mark.parametrize( + "_type, prim", + ( + (DecimalParameter, Primitive.DECIMAL), + (IntegerParameter, Primitive.INT), + (StringParameter, Primitive.STRING), + (BigIntegerParameter, Primitive.BIGINT), + (BooleanParameter, Primitive.BOOL), + (DateParameter, Primitive.DATE), + (FloatParameter, Primitive.FLOAT), + (VoidParameter, Primitive.NONE), + (TimestampParameter, Primitive.TIMESTAMP), + ), + ) + def test_inference(self, _type: TDbsqlParameter, prim: Primitive): + """This method only tests inferrable types. + + Not tested are TinyIntParameter, SmallIntParameter DoubleParameter and TimestampNTZParameter + """ + + inferred_type = dbsql_parameter_from_primitive(prim.value) + assert isinstance(inferred_type, _type) \ No newline at end of file diff --git a/tests/unit/test_retry.py b/databricks_sql_connector_core/tests/unit/test_retry.py similarity index 93% rename from tests/unit/test_retry.py rename to databricks_sql_connector_core/tests/unit/test_retry.py index 7cbc957a..b80c2270 100644 --- a/tests/unit/test_retry.py +++ b/databricks_sql_connector_core/tests/unit/test_retry.py @@ -4,7 +4,7 @@ import pytest from requests import Request from urllib3 import HTTPResponse -from databricks.sql import DatabricksRetryPolicy, RequestHistory +from databricks_sql_connector_core.sql.auth.retry import DatabricksRetryPolicy, RequestHistory class TestRetry: @@ -52,4 +52,4 @@ def test_sleep__retry_after_surpassed(self, t_mock, retry_policy, error_history) retry_policy._retry_start_time = time.time() retry_policy.history = [error_history, error_history, error_history] retry_policy.sleep(HTTPResponse(status=503, headers={"Retry-After": "3"})) - t_mock.assert_called_with(4) + t_mock.assert_called_with(4) \ No newline at end of file diff --git a/tests/unit/test_thrift_backend.py b/databricks_sql_connector_core/tests/unit/test_thrift_backend.py similarity index 90% rename from tests/unit/test_thrift_backend.py rename to databricks_sql_connector_core/tests/unit/test_thrift_backend.py index 3c2b0953..aa72c4ad 100644 --- a/tests/unit/test_thrift_backend.py +++ b/databricks_sql_connector_core/tests/unit/test_thrift_backend.py @@ -7,12 +7,12 @@ import pyarrow -import databricks.sql -from databricks.sql import utils -from databricks.sql import ttypes -from databricks.sql import * -from databricks.sql import AuthProvider -from databricks.sql import ThriftBackend +import databricks_sql_connector_core.sql +from databricks_sql_connector_core.sql import utils +from databricks_sql_connector_core.sql.thrift_api.TCLIService import ttypes +from databricks_sql_connector_core.sql import * +from databricks_sql_connector_core.sql.auth.authenticators import AuthProvider +from databricks_sql_connector_core.sql.thrift_backend import ThriftBackend def retry_policy_factory(): @@ -108,7 +108,7 @@ def test_hive_schema_to_arrow_schema_preserves_column_names(self): self.assertEqual(arrow_schema.field(2).name, "column 2") self.assertEqual(arrow_schema.field(3).name, "") - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def 
test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value bad_protocol_versions = [ @@ -136,7 +136,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): self.assertIn("expected server to use a protocol version", str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value good_protocol_versions = [ @@ -153,14 +153,14 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider()) t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) def test_proxy_headers_are_set(self): - from databricks.sql import THttpClient + from databricks_sql_connector_core.sql.auth.thrift_http_client import THttpClient from urllib.parse import urlparse fake_proxy_spec = "https://someuser:somepassword@8.8.8.8:12340" @@ -174,8 +174,8 @@ def test_proxy_headers_are_set(self): assert isinstance(result, type(dict())) assert isinstance(result.get('proxy-authorization'), type(str())) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.thrift_backend.create_default_context") def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): mock_cert_key_file = Mock() mock_cert_key_password = Mock() @@ -203,8 +203,8 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.thrift_backend.create_default_context") def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_no_verify=True) @@ -213,10 +213,10 @@ def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_cl self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.thrift_backend.create_default_context") def test_tls_verify_hostname_is_respected( - self, mock_create_default_context, t_http_client_class + self, mock_create_default_context, t_http_client_class ): ThriftBackend( "foo", 
123, "bar", [], auth_provider=AuthProvider(), _tls_verify_hostname=False @@ -227,28 +227,28 @@ def test_tls_verify_hostname_is_respected( self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): ThriftBackend( "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129 @@ -395,7 +395,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertIn("some information about the error", str(cm.exception)) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch("databricks_sql_connector_core.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): for resp_type in self.execute_response_types: lz4Compressed = Mock() @@ -422,7 +422,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(self, build_ execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -437,7 +437,7 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv ) for op_state_resp, exec_resp_type in itertools.product( - [error_resp, closed_resp], self.execute_response_types + [error_resp, closed_resp], self.execute_response_types ): with self.subTest(op_state_resp=op_state_resp, exec_resp_type=exec_resp_type): tcli_service_instance = tcli_service_class.return_value @@ -457,7 +457,7 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv if op_state_resp.errorMessage: self.assertIn(op_state_resp.errorMessage, 
str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -484,7 +484,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_direct_results_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -569,7 +569,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -645,7 +645,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ttypes.TOperationState.FINISHED_STATE, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value arrow_schema_mock = MagicMock(name="Arrow schema mock") @@ -670,7 +670,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value hive_schema_mock = MagicMock(name="Hive schema mock") @@ -696,13 +696,13 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): hive_schema_mock, thrift_backend._hive_schema_to_arrow_schema.call_args[0][0] ) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( - self, tcli_service_class, build_queue + self, tcli_service_class, build_queue ): for has_more_rows, resp_type in itertools.product( - [True, False], self.execute_response_types + [True, False], self.execute_response_types ): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value @@ -734,13 +734,13 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(has_more_rows, execute_response.has_more_rows) - 
@patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( - self, tcli_service_class, build_queue + self, tcli_service_class, build_queue ): for has_more_rows, resp_type in itertools.product( - [True, False], self.execute_response_types + [True, False], self.execute_response_types ): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value @@ -786,7 +786,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( self.assertEqual(has_more_rows, has_more_rows_resp) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue tcli_service_instance = tcli_service_class.return_value @@ -831,7 +831,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -849,7 +849,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s # Check response handling thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -866,7 +866,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic # Check response handling thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -892,7 +892,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service # Check response handling thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -922,7 +922,7 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ # Check 
response handling thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -952,7 +952,7 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service # Check response handling thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -961,7 +961,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) @@ -971,7 +971,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): self.operation_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) @@ -980,7 +980,7 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() @@ -1016,10 +1016,10 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") - @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") + @patch("databricks_sql_connector_core.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") + @patch("databricks_sql_connector_core.sql.thrift_backend.convert_column_based_set_to_arrow_table") def test_create_arrow_table_calls_correct_conversion_method( - self, convert_col_mock, convert_arrow_mock + self, convert_col_mock, convert_arrow_mock ): thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) convert_arrow_mock.return_value = (MagicMock(), Mock()) @@ -1164,7 +1164,7 @@ 
def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -1189,17 +1189,17 @@ def test_handle_execute_response_sets_active_op_handle(self): self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") + @patch("databricks_sql_connector_core.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus( - self, mock_retry_policy, mock_GetOperationStatus, t_transport_class + self, mock_retry_policy, mock_GetOperationStatus, t_transport_class ): import thrift, errno - from databricks.sql import Client - from databricks.sql import RequestError - from databricks.sql import NoRetryReason + from databricks_sql_connector_core.sql.thrift_api.TCLIService.TCLIService import Client + from databricks_sql_connector_core.sql.exc import RequestError + from databricks_sql_connector_core.sql.utils import NoRetryReason this_gos_name = "GetOperationStatus" mock_GetOperationStatus.__name__ = this_gos_name @@ -1236,7 +1236,7 @@ def test_make_request_will_retry_GetOperationStatus( # Unusual OSError code mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") - with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: + with self.assertLogs("databricks_sql_connector_core.sql.thrift_backend", level=logging.WARNING) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1252,10 +1252,10 @@ def test_make_request_will_retry_GetOperationStatus( cm.output[0], ) - @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch("databricks_sql_connector_core.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") + @patch("databricks_sql_connector_core.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus_for_http_error( - self, mock_retry_policy, mock_gos + self, mock_retry_policy, mock_gos ): import urllib3.exceptions @@ -1263,10 +1263,10 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( mock_gos.side_effect = urllib3.exceptions.HTTPError("Read timed out") import thrift, errno - from databricks.sql import Client - from databricks.sql import RequestError - from databricks.sql import NoRetryReason - from databricks.sql import THttpClient + from databricks_sql_connector_core.sql.thrift_api.TCLIService.TCLIService import Client + from databricks_sql_connector_core.sql.exc import RequestError + from databricks_sql_connector_core.sql.utils 
import NoRetryReason + from databricks_sql_connector_core.sql.auth.thrift_http_client import THttpClient this_gos_name = "GetOperationStatus" mock_gos.__name__ = this_gos_name @@ -1315,10 +1315,10 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ self.assertIn("This method fails", str(cm.exception.message_with_context())) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( - self, mock_retry_policy, t_transport_class + self, mock_retry_policy, t_transport_class ): t_transport_instance = t_transport_class.return_value t_transport_instance.code = 429 @@ -1346,7 +1346,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self.assertEqual(mock_method.call_count, 14) - @patch("databricks.sql.auth.thrift_http_client.THttpClient") + @patch("databricks_sql_connector_core.sql.auth.thrift_http_client.THttpClient") def test_make_request_will_read_error_message_headers_if_set(self, t_transport_class): t_transport_instance = t_transport_class.return_value mock_method = Mock() @@ -1383,7 +1383,7 @@ def test_make_request_will_read_error_message_headers_if_set(self, t_transport_c @staticmethod def make_table_and_desc( - height, n_decimal_cols, width, precision, scale, int_constant, decimal_constant + height, n_decimal_cols, width, precision, scale, int_constant, decimal_constant ): int_col = [int_constant for _ in range(height)] decimal_col = [decimal_constant for _ in range(height)] @@ -1462,7 +1462,7 @@ def test_retry_args_passthrough(self, mock_http_client): @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} - for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): + for k, (_, _, min, max) in databricks_sql_connector_core.sql.thrift_backend._retry_policy.items(): retry_delay_test_args_and_expected_values[k] = ((min - 1, min), (max + 1, max)) for i in range(2): @@ -1478,7 +1478,7 @@ def test_retry_args_bounding(self, mock_http_client): for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_configuration_passthrough(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -1496,14 +1496,14 @@ def test_configuration_passthrough(self, tcli_client_class): open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertEqual(open_session_req.configuration, expected_config) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} backend = 
ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) - with self.assertRaises(databricks.sql.Error) as cm: + with self.assertRaises(databricks_sql_connector_core.sql.Error) as cm: backend.open_session(mock_config, None, None) self.assertIn("timestampAsString cannot be changed", str(cm.exception)) @@ -1516,7 +1516,7 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem), ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -1535,7 +1535,7 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -1546,7 +1546,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -1577,7 +1577,7 @@ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcl ) backend.open_session({}, cat, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -1597,15 +1597,15 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): "Setting initial namespace not supported by the DBR version", str(cm.exception) ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - @patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response") + @patch("databricks_sql_connector_core.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks_sql_connector_core.sql.thrift_backend.ThriftBackend._handle_execute_response") def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( - [True, False, None], [True, False, None], [True, False, None] + [True, False, None], [True, False, None], [True, False, None] ): complex_arg_types = {} if complex is not None: @@ -1637,4 +1637,4 @@ def 
test_execute_command_sets_complex_type_fields_correctly( if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/examples/custom_cred_provider.py b/examples/custom_cred_provider.py index 4c43280f..e78eb8b6 100644 --- a/examples/custom_cred_provider.py +++ b/examples/custom_cred_provider.py @@ -1,6 +1,7 @@ + # please pip install databricks-sdk prior to running this example. -from databricks import sql +from databricks_sql_connector_core import sql from databricks.sdk.oauth import OAuthClient import os diff --git a/examples/insert_data.py b/examples/insert_data.py index b304a0e9..8777a846 100644 --- a/examples/insert_data.py +++ b/examples/insert_data.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), @@ -18,4 +18,4 @@ result = cursor.fetchall() for row in result: - print(row) + print(row) \ No newline at end of file diff --git a/examples/interactive_oauth.py b/examples/interactive_oauth.py index d7c59597..33b4740d 100644 --- a/examples/interactive_oauth.py +++ b/examples/interactive_oauth.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os """databricks-sql-connector supports user to machine OAuth login which means the @@ -14,8 +14,7 @@ """ with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - auth_type="databricks-oauth") as connection: + http_path = os.getenv("DATABRICKS_HTTP_PATH")) as connection: for x in range(1, 100): cursor = connection.cursor() @@ -25,4 +24,4 @@ print(row) cursor.close() - connection.close() + connection.close() \ No newline at end of file diff --git a/examples/m2m_oauth.py b/examples/m2m_oauth.py index eba2095c..723e2261 100644 --- a/examples/m2m_oauth.py +++ b/examples/m2m_oauth.py @@ -1,7 +1,7 @@ import os from databricks.sdk.core import oauth_service_principal, Config -from databricks import sql +from databricks_sql_connector_core import sql """ This example shows how to use OAuth M2M (machine-to-machine) for service principal @@ -38,4 +38,4 @@ def credential_provider(): print(row) cursor.close() - connection.close() + connection.close() \ No newline at end of file diff --git a/examples/parameters.py b/examples/parameters.py index c0367e1b..b4384b76 100644 --- a/examples/parameters.py +++ b/examples/parameters.py @@ -3,11 +3,11 @@ """ from decimal import Decimal -from databricks import sql -from databricks.sql import * +from databricks_sql_connector_core import sql +from databricks_sql_connector_core.sql.parameters import * import os -from databricks import sql +from databricks_sql_connector_core import sql from datetime import datetime import pytz @@ -118,4 +118,4 @@ print("\nEXAMPLE 4") print("Example 4 inferred result\t→\t {}\t{}".format(result.p1, result.p3)) -print("Example 4 explicit result\t→\t {}\t\t{}".format(result.p2, result.p4)) +print("Example 4 explicit result\t→\t {}\t\t{}".format(result.p2, result.p4)) \ No newline at end of file diff --git a/examples/persistent_oauth.py b/examples/persistent_oauth.py index 22a10def..5af419c7 100644 --- a/examples/persistent_oauth.py +++ b/examples/persistent_oauth.py @@ -16,29 +16,29 @@ import os from typing import Optional -from databricks import sql -from databricks.sql import OAuthPersistence, OAuthToken, DevOnlyFilePersistence +from databricks_sql_connector_core import sql +from 
databricks_sql_connector_core.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken, DevOnlyFilePersistence class SampleOAuthPersistence(OAuthPersistence): - def persist(self, hostname: str, oauth_token: OAuthToken): - """To be implemented by the end user to persist in the preferred storage medium. + def persist(self, hostname: str, oauth_token: OAuthToken): + """To be implemented by the end user to persist in the preferred storage medium. - OAuthToken has two properties: - 1. OAuthToken.access_token - 2. OAuthToken.refresh_token + OAuthToken has two properties: + 1. OAuthToken.access_token + 2. OAuthToken.refresh_token - Both should be persisted. - """ - pass + Both should be persisted. + """ + pass - def read(self, hostname: str) -> Optional[OAuthToken]: - """To be implemented by the end user to fetch token from the preferred storage + def read(self, hostname: str) -> Optional[OAuthToken]: + """To be implemented by the end user to fetch token from the preferred storage - Fetch the access_token and refresh_token for the given hostname. - Return OAuthToken(access_token, refresh_token) - """ - pass + Fetch the access_token and refresh_token for the given hostname. + Return OAuthToken(access_token, refresh_token) + """ + pass with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path = os.getenv("DATABRICKS_HTTP_PATH"), @@ -53,4 +53,4 @@ def read(self, hostname: str) -> Optional[OAuthToken]: print(row) cursor.close() - connection.close() + connection.close() \ No newline at end of file diff --git a/examples/query_cancel.py b/examples/query_cancel.py index 4e0b74a5..c954cec0 100644 --- a/examples/query_cancel.py +++ b/examples/query_cancel.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os, threading, time """ @@ -11,12 +11,12 @@ with connection.cursor() as cursor: def execute_really_long_query(): - try: - cursor.execute("SELECT SUM(A.id - B.id) " + - "FROM range(1000000000) A CROSS JOIN range(100000000) B " + - "GROUP BY (A.id - B.id)") - except sql.exc.RequestError: - print("It looks like this query was cancelled.") + try: + cursor.execute("SELECT SUM(A.id - B.id) " + + "FROM range(1000000000) A CROSS JOIN range(100000000) B " + + "GROUP BY (A.id - B.id)") + except sql.exc.RequestError: + print("It looks like this query was cancelled.") exec_thread = threading.Thread(target=execute_really_long_query) @@ -48,4 +48,4 @@ def execute_really_long_query(): print("\n Execution was successful. 
Results appear below:") - print(cursor.fetchall()) + print(cursor.fetchall()) \ No newline at end of file diff --git a/examples/query_execute.py b/examples/query_execute.py index a851ab50..8ca5b7d0 100644 --- a/examples/query_execute.py +++ b/examples/query_execute.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), diff --git a/examples/set_user_agent.py b/examples/set_user_agent.py index 449692cf..c25de4e9 100644 --- a/examples/set_user_agent.py +++ b/examples/set_user_agent.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), diff --git a/examples/staging_ingestion.py b/examples/staging_ingestion.py index a55be477..6af657a1 100644 --- a/examples/staging_ingestion.py +++ b/examples/staging_ingestion.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os """ @@ -51,10 +51,10 @@ staging_allowed_local_path = os.path.split(_complete_path)[0] with sql.connect( - server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path=os.getenv("DATABRICKS_HTTP_PATH"), - access_token=os.getenv("DATABRICKS_TOKEN"), - staging_allowed_local_path=staging_allowed_local_path, + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + staging_allowed_local_path=staging_allowed_local_path, ) as connection: with connection.cursor() as cursor: @@ -84,4 +84,4 @@ print("Removing demo.csv from staging location") cursor.execute(query) - print("Remove was successful") + print("Remove was successful") \ No newline at end of file diff --git a/examples/v3_retries_query_execute.py b/examples/v3_retries_query_execute.py index 4b6772fe..e742583b 100644 --- a/examples/v3_retries_query_execute.py +++ b/examples/v3_retries_query_execute.py @@ -1,4 +1,4 @@ -from databricks import sql +from databricks_sql_connector_core import sql import os # Users of connector versions >= 2.9.0 and <= 3.0.0 can use the v3 retry behaviour by setting _enable_v3_retries=True @@ -26,7 +26,7 @@ # which means all redirects will be followed. In this case, a redirect will count toward the # _retry_stop_after_attempts_count which means that by default the connector will not enter an endless retry loop. 
# -# For complete information about configuring retries, see the docstring for databricks.sql.thrift_backend.ThriftBackend +# For complete information about configuring retries, see the docstring for databricks_sql_connector_core.sql.thrift_backend.ThriftBackend with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path = os.getenv("DATABRICKS_HTTP_PATH"), @@ -40,4 +40,4 @@ result = cursor.fetchall() for row in result: - print(row) + print(row) \ No newline at end of file diff --git a/setup_script.py b/setup_script.py index 798a5c50..14929888 100644 --- a/setup_script.py +++ b/setup_script.py @@ -25,6 +25,6 @@ def build_and_install_library(directory_name): if __name__ == "__main__": - # build_and_install_library("databricks_sql_connector_core") + build_and_install_library("databricks_sql_connector_core") build_and_install_library("databricks_sql_connector") - # build_and_install_library("databricks_sqlalchemy") \ No newline at end of file + build_and_install_library("databricks_sqlalchemy") \ No newline at end of file diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py deleted file mode 100644 index 9050f3ee..00000000 --- a/tests/unit/test_parameters.py +++ /dev/null @@ -1,204 +0,0 @@ -import datetime -from decimal import Decimal -from enum import Enum -from typing import Type - -import pytest -import pytz - -from databricks.sql import Connection -from databricks.sql import ( - BigIntegerParameter, - BooleanParameter, - DateParameter, - DecimalParameter, - DoubleParameter, - FloatParameter, - IntegerParameter, - SmallIntParameter, - StringParameter, - TimestampNTZParameter, - TimestampParameter, - TinyIntParameter, - VoidParameter, -) -from databricks.sql import ( - TDbsqlParameter, - TSparkParameterValue, - dbsql_parameter_from_primitive, -) -from databricks.sql import ttypes -from databricks.sql import ( - TOpenSessionResp, - TSessionHandle, - TSparkParameterValue, -) - - -class TestSessionHandleChecks(object): - @pytest.mark.parametrize( - "test_input,expected", - [ - ( - TOpenSessionResp( - serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - sessionHandle=TSessionHandle(1, None), - ), - ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - ), - # Ensure that protocol version inside sessionhandle takes precedence. 
- ( - TOpenSessionResp( - serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - sessionHandle=TSessionHandle( - 1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ), - ), - ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, - ), - ], - ) - def test_get_protocol_version_fallback_behavior(self, test_input, expected): - assert Connection.get_protocol_version(test_input) == expected - - @pytest.mark.parametrize( - "test_input,expected", - [ - ( - None, - False, - ), - ( - ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - False, - ), - ( - ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, - True, - ), - ], - ) - def test_parameters_enabled(self, test_input, expected): - assert Connection.server_parameterized_queries_enabled(test_input) == expected - - -@pytest.mark.parametrize( - "value,expected", - ( - (Decimal("10.00"), "DECIMAL(4,2)"), - (Decimal("123456789123456789.123456789123456789"), "DECIMAL(36,18)"), - (Decimal(".12345678912345678912345678912345678912"), "DECIMAL(38,38)"), - (Decimal("123456789.123456789"), "DECIMAL(18,9)"), - (Decimal("12345678912345678912345678912345678912"), "DECIMAL(38,0)"), - (Decimal("1234.56"), "DECIMAL(6,2)"), - ), -) -def test_calculate_decimal_cast_string(value, expected): - p = DecimalParameter(value) - assert p._cast_expr() == expected - - -class Primitive(Enum): - """These are the inferrable types. This Enum is used for parametrized tests.""" - - NONE = None - BOOL = True - INT = 50 - BIGINT = 2147483648 - STRING = "Hello" - DECIMAL = Decimal("1234.56") - DATE = datetime.date(2023, 9, 6) - TIMESTAMP = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC) - DOUBLE = 3.14 - FLOAT = 3.15 - SMALLINT = 51 - - -class TestDbsqlParameter: - @pytest.mark.parametrize( - "_type, prim, expect_cast_expr", - ( - (DecimalParameter, Primitive.DECIMAL, "DECIMAL(6,2)"), - (IntegerParameter, Primitive.INT, "INT"), - (StringParameter, Primitive.STRING, "STRING"), - (BigIntegerParameter, Primitive.BIGINT, "BIGINT"), - (BooleanParameter, Primitive.BOOL, "BOOLEAN"), - (DateParameter, Primitive.DATE, "DATE"), - (DoubleParameter, Primitive.DOUBLE, "DOUBLE"), - (FloatParameter, Primitive.FLOAT, "FLOAT"), - (VoidParameter, Primitive.NONE, "VOID"), - (SmallIntParameter, Primitive.INT, "SMALLINT"), - (TimestampParameter, Primitive.TIMESTAMP, "TIMESTAMP"), - (TimestampNTZParameter, Primitive.TIMESTAMP, "TIMESTAMP_NTZ"), - (TinyIntParameter, Primitive.INT, "TINYINT"), - ), - ) - def test_cast_expression( - self, _type: TDbsqlParameter, prim: Primitive, expect_cast_expr: str - ): - p = _type(prim.value) - assert p._cast_expr() == expect_cast_expr - - @pytest.mark.parametrize( - "t, prim", - ( - (DecimalParameter, Primitive.DECIMAL), - (IntegerParameter, Primitive.INT), - (StringParameter, Primitive.STRING), - (BigIntegerParameter, Primitive.BIGINT), - (BooleanParameter, Primitive.BOOL), - (DateParameter, Primitive.DATE), - (DoubleParameter, Primitive.DOUBLE), - (FloatParameter, Primitive.FLOAT), - (VoidParameter, Primitive.NONE), - (SmallIntParameter, Primitive.INT), - (TimestampParameter, Primitive.TIMESTAMP), - (TimestampNTZParameter, Primitive.TIMESTAMP), - (TinyIntParameter, Primitive.INT), - ), - ) - def test_tspark_param_value(self, t: TDbsqlParameter, prim): - p: TDbsqlParameter = t(prim.value) - output = p._tspark_param_value() - - if prim == Primitive.NONE: - assert output == None - else: - assert output == TSparkParameterValue(stringValue=str(prim.value)) - - def test_tspark_param_named(self): - p = 
dbsql_parameter_from_primitive(Primitive.INT.value, name="p") - tsp = p.as_tspark_param(named=True) - - assert tsp.name == "p" - assert tsp.ordinal is False - - def test_tspark_param_ordinal(self): - p = dbsql_parameter_from_primitive(Primitive.INT.value, name="p") - tsp = p.as_tspark_param(named=False) - - assert tsp.name is None - assert tsp.ordinal is True - - @pytest.mark.parametrize( - "_type, prim", - ( - (DecimalParameter, Primitive.DECIMAL), - (IntegerParameter, Primitive.INT), - (StringParameter, Primitive.STRING), - (BigIntegerParameter, Primitive.BIGINT), - (BooleanParameter, Primitive.BOOL), - (DateParameter, Primitive.DATE), - (FloatParameter, Primitive.FLOAT), - (VoidParameter, Primitive.NONE), - (TimestampParameter, Primitive.TIMESTAMP), - ), - ) - def test_inference(self, _type: TDbsqlParameter, prim: Primitive): - """This method only tests inferrable types. - - Not tested are TinyIntParameter, SmallIntParameter DoubleParameter and TimestampNTZParameter - """ - - inferred_type = dbsql_parameter_from_primitive(prim.value) - assert isinstance(inferred_type, _type)