From 7fade51097d943c7f0fbe2b43f762bf837656f49 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:16:12 -0700 Subject: [PATCH] Fix cookie setting (#379) * fix cookie setting Signed-off-by: Ben Cassell * Removing cookie code Signed-off-by: Ben Cassell --------- Signed-off-by: Ben Cassell --- src/databricks/sql/auth/thrift_http_client.py | 4 - tests/unit/test_thrift_backend.py | 684 +++++++++++------- 2 files changed, 424 insertions(+), 264 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 11589258..0a862aaf 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -189,10 +189,6 @@ def flush(self): self.message = self.__resp.reason self.headers = self.__resp.headers - # Saves the cookie sent by the server response - if "Set-Cookie" in self.headers: - self.setCustomHeaders(dict("Cookie", self.headers["Set-Cookie"])) - @staticmethod def basic_proxy_auth_header(proxy): if proxy is None or not proxy.username: diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 92c664a0..ee4fc4b7 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -16,12 +16,12 @@ def retry_policy_factory(): - return { # (type, default, min, max) - "_retry_delay_min": (float, 1, None, None), - "_retry_delay_max": (float, 60, None, None), - "_retry_stop_after_attempts_count": (int, 30, None, None), - "_retry_stop_after_attempts_duration": (float, 900, None, None), - "_retry_delay_default": (float, 5, 1, 60) + return { # (type, default, min, max) + "_retry_delay_min": (float, 1, None, None), + "_retry_delay_max": (float, 60, None, None), + "_retry_stop_after_attempts_count": (int, 30, None, None), + "_retry_stop_after_attempts_duration": (float, 900, None, None), + "_retry_delay_default": (float, 5, 1, 60), } @@ -35,14 +35,17 @@ class ThriftBackendTestSuite(unittest.TestCase): operation_handle = ttypes.TOperationHandle( operationId=ttypes.THandleIdentifier(guid=0x33, secret=0x35), - operationType=ttypes.TOperationType.EXECUTE_STATEMENT) + operationType=ttypes.TOperationType.EXECUTE_STATEMENT, + ) session_handle = ttypes.TSessionHandle( - sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37)) + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37) + ) open_session_resp = ttypes.TOpenSessionResp( status=okay_status, - serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4) + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, + ) metadata_resp = ttypes.TGetResultSetMetadataResp( status=okay_status, @@ -51,8 +54,11 @@ class ThriftBackendTestSuite(unittest.TestCase): ) execute_response_types = [ - ttypes.TExecuteStatementResp, ttypes.TGetCatalogsResp, ttypes.TGetSchemasResp, - ttypes.TGetTablesResp, ttypes.TGetColumnsResp + ttypes.TExecuteStatementResp, + ttypes.TGetCatalogsResp, + ttypes.TGetSchemasResp, + ttypes.TGetTablesResp, + ttypes.TGetColumnsResp, ] def test_make_request_checks_thrift_status_code(self): @@ -66,7 +72,9 @@ def test_make_request_checks_thrift_status_code(self): thrift_backend.make_request(mock_method, Mock()) def _make_type_desc(self, type): - return ttypes.TTypeDesc(types=[ttypes.TTypeEntry(ttypes.TTAllowedParameterValueEntry(type=type))]) + return ttypes.TTypeDesc( + types=[ttypes.TTypeEntry(ttypes.TTAllowedParameterValueEntry(type=type))] + ) def _make_fake_thrift_backend(self): thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) @@ -79,13 +87,17 @@ def _make_fake_thrift_backend(self): def test_hive_schema_to_arrow_schema_preserves_column_names(self): columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE)), + columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE)), + columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE)), + columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + ), ttypes.TColumnDesc( - columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE)) + columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + ), ] t_table_schema = ttypes.TTableSchema(columns) @@ -115,7 +127,8 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): for protocol_version in bad_protocol_versions: t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp( - status=self.okay_status, serverProtocolVersion=protocol_version) + status=self.okay_status, serverProtocolVersion=protocol_version + ) with self.assertRaises(OperationalError) as cm: thrift_backend = self._make_fake_thrift_backend() @@ -129,12 +142,13 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): good_protocol_versions = [ ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V2, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, - ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4 + ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, ] for protocol_version in good_protocol_versions: t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp( - status=self.okay_status, serverProtocolVersion=protocol_version) + status=self.okay_status, serverProtocolVersion=protocol_version + ) thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) @@ -151,7 +165,7 @@ def test_proxy_headers_are_set(self): fake_proxy_spec = "https://someuser:somepassword@8.8.8.8:12340" parsed_proxy = urlparse(fake_proxy_spec) - + try: result = THttpClient.basic_proxy_auth_header(parsed_proxy) except TypeError as e: @@ -170,17 +184,20 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ ThriftBackend( "foo", 123, - "bar", [], + "bar", + [], auth_provider=AuthProvider(), _tls_client_cert_file=mock_cert_file, _tls_client_cert_key_file=mock_cert_key_file, _tls_client_cert_key_password=mock_cert_key_password, - _tls_trusted_ca_file=mock_trusted_ca_file) + _tls_trusted_ca_file=mock_trusted_ca_file, + ) mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) mock_ssl_context = mock_create_default_context.return_value mock_ssl_context.load_cert_chain.assert_called_once_with( - certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password) + certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password + ) self.assertTrue(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) @@ -197,9 +214,12 @@ def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_cl @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend.create_default_context") - def test_tls_verify_hostname_is_respected(self, mock_create_default_context, - t_http_client_class): - ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_verify_hostname=False) + def test_tls_verify_hostname_is_respected( + self, mock_create_default_context, t_http_client_class + ): + ThriftBackend( + "foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_verify_hostname=False + ) mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) @@ -209,41 +229,54 @@ def test_tls_verify_hostname_is_respected(self, mock_create_default_context, @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) - self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], - "https://hostname:123/path_value") + self.assertEqual( + t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider()) - self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], - "https://hostname:123/path_value") - + self.assertEqual( + t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + ) + @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider()) - self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], - "https://hostname:123/path_value") + self.assertEqual( + t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129) + ThriftBackend( + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129 + ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=0) + ThriftBackend( + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=0 + ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=None) + ThriftBackend( + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=None + ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) def test_non_primitive_types_raise_error(self): columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE)), + columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + ), ttypes.TColumnDesc( columnName="column 2", - typeDesc=ttypes.TTypeDesc(types=[ - ttypes.TTypeEntry(userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo")) - ])) + typeDesc=ttypes.TTypeDesc( + types=[ + ttypes.TTypeEntry(userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo")) + ] + ), + ), ] t_table_schema = ttypes.TTableSchema(columns) @@ -257,46 +290,62 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): # canary test columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE)), + columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.BOOLEAN_TYPE)), + columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.BOOLEAN_TYPE) + ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.MAP_TYPE)), + columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.MAP_TYPE) + ), ttypes.TColumnDesc( - columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.STRUCT_TYPE)) + columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.STRUCT_TYPE) + ), ] t_table_schema = ttypes.TTableSchema(columns) description = ThriftBackend._hive_schema_to_description(t_table_schema) - self.assertEqual(description, [ - ("column 1", "int", None, None, None, None, None), - ("column 2", "boolean", None, None, None, None, None), - ("column 2", "map", None, None, None, None, None), - ("", "struct", None, None, None, None, None), - ]) + self.assertEqual( + description, + [ + ("column 1", "int", None, None, None, None, None), + ("column 2", "boolean", None, None, None, None, None), + ("column 2", "map", None, None, None, None, None), + ("", "struct", None, None, None, None, None), + ], + ) def test_hive_schema_to_description_preserves_scale_and_precision(self): columns = [ ttypes.TColumnDesc( columnName="column 1", - typeDesc=ttypes.TTypeDesc(types=[ - ttypes.TTypeEntry( - ttypes.TTAllowedParameterValueEntry( - type=ttypes.TTypeId.DECIMAL_TYPE, - typeQualifiers=ttypes.TTypeQualifiers( - qualifiers={ - "precision": ttypes.TTypeQualifierValue(i32Value=10), - "scale": ttypes.TTypeQualifierValue(i32Value=100), - }))) - ])), + typeDesc=ttypes.TTypeDesc( + types=[ + ttypes.TTypeEntry( + ttypes.TTAllowedParameterValueEntry( + type=ttypes.TTypeId.DECIMAL_TYPE, + typeQualifiers=ttypes.TTypeQualifiers( + qualifiers={ + "precision": ttypes.TTypeQualifierValue(i32Value=10), + "scale": ttypes.TTypeQualifierValue(i32Value=100), + } + ), + ) + ) + ] + ), + ), ] t_table_schema = ttypes.TTableSchema(columns) description = ThriftBackend._hive_schema_to_description(t_table_schema) - self.assertEqual(description, [ - ("column 1", "decimal", None, None, 10, 100, None), - ]) + self.assertEqual( + description, + [ + ("column 1", "decimal", None, None, 10, 100, None), + ], + ) def test_make_request_checks_status_code(self): error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] @@ -311,8 +360,9 @@ def test_make_request_checks_status_code(self): self.assertIn("a detailed error message", str(cm.exception)) success_codes = [ - ttypes.TStatusCode.SUCCESS_STATUS, ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS, - ttypes.TStatusCode.STILL_EXECUTING_STATUS + ttypes.TStatusCode.SUCCESS_STATUS, + ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS, + ttypes.TStatusCode.STILL_EXECUTING_STATUS, ] for code in success_codes: @@ -329,11 +379,16 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): operationStatus=ttypes.TGetOperationStatusResp( status=self.okay_status, operationState=ttypes.TOperationState.ERROR_STATE, - errorMessage="some information about the error"), + errorMessage="some information about the error", + ), resultSetMetadata=None, resultSet=None, - closeOperation=None)) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + closeOperation=None, + ), + ) + thrift_backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider() + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -342,22 +397,25 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): for resp_type in self.execute_response_types: - lz4Compressed=Mock() - resultSet=MagicMock() + lz4Compressed = Mock() + resultSet = MagicMock() resultSet.results.startRowOffset = 0 t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus= Mock(), + operationStatus=Mock(), resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, schema=MagicMock(), arrowSchema=MagicMock(), - lz4Compressed=lz4Compressed), + lz4Compressed=lz4Compressed, + ), resultSet=resultSet, - closeOperation=None)) + closeOperation=None, + ), + ) thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -370,22 +428,28 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv error_resp = ttypes.TGetOperationStatusResp( status=self.okay_status, operationState=ttypes.TOperationState.ERROR_STATE, - errorMessage="some information about the error") + errorMessage="some information about the error", + ) closed_resp = ttypes.TGetOperationStatusResp( - status=self.okay_status, operationState=ttypes.TOperationState.CLOSED_STATE) + status=self.okay_status, operationState=ttypes.TOperationState.CLOSED_STATE + ) - for op_state_resp, exec_resp_type in itertools.product([error_resp, closed_resp], - self.execute_response_types): + for op_state_resp, exec_resp_type in itertools.product( + [error_resp, closed_resp], self.execute_response_types + ): with self.subTest(op_state_resp=op_state_resp, exec_resp_type=exec_resp_type): tcli_service_instance = tcli_service_class.return_value t_execute_resp = exec_resp_type( status=self.okay_status, directResults=None, - operationHandle=self.operation_handle) + operationHandle=self.operation_handle, + ) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider() + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -403,10 +467,12 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): operationState=ttypes.TOperationState.ERROR_STATE, errorMessage="foo", displayMessage=display_message, - diagnosticInfo=diagnostic_info) + diagnosticInfo=diagnostic_info, + ) t_execute_resp = ttypes.TExecuteStatementResp( - status=self.okay_status, directResults=None, operationHandle=self.operation_handle) + status=self.okay_status, directResults=None, operationHandle=self.operation_handle + ) tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp @@ -428,7 +494,8 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla operationState=ttypes.TOperationState.ERROR_STATE, errorMessage="foo", displayMessage=display_message, - diagnosticInfo=diagnostic_info) + diagnosticInfo=diagnostic_info, + ) t_execute_resp = ttypes.TExecuteStatementResp( status=self.okay_status, @@ -436,7 +503,9 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla operationStatus=t_get_operation_status_resp, resultSetMetadata=None, resultSet=None, - closeOperation=None)) + closeOperation=None, + ), + ) tcli_service_instance.ExecuteStatement.return_value = t_execute_resp @@ -455,7 +524,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): operationStatus=ttypes.TGetOperationStatusResp(status=self.bad_status), resultSetMetadata=None, resultSet=None, - closeOperation=None)) + closeOperation=None, + ), + ) resp_2 = resp_type( status=self.okay_status, @@ -463,7 +534,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): operationStatus=None, resultSetMetadata=ttypes.TGetResultSetMetadataResp(status=self.bad_status), resultSet=None, - closeOperation=None)) + closeOperation=None, + ), + ) resp_3 = resp_type( status=self.okay_status, @@ -471,7 +544,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): operationStatus=None, resultSetMetadata=None, resultSet=ttypes.TFetchResultsResp(status=self.bad_status), - closeOperation=None)) + closeOperation=None, + ), + ) resp_4 = resp_type( status=self.okay_status, @@ -479,11 +554,15 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): operationStatus=None, resultSetMetadata=None, resultSet=None, - closeOperation=ttypes.TCloseOperationResp(status=self.bad_status))) + closeOperation=ttypes.TCloseOperationResp(status=self.bad_status), + ), + ) for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider() + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -513,17 +592,24 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se ) op_state_3 = ttypes.TGetOperationStatusResp( - status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE) + status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE + ) tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp tcli_service_instance.GetOperationStatus.side_effect = [ - op_state_1, op_state_2, op_state_3 + op_state_1, + op_state_2, + op_state_3, ] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider() + ) results_message_response = thrift_backend._handle_execute_response( - execute_resp, Mock()) - self.assertEqual(results_message_response.status, - ttypes.TOperationState.FINISHED_STATE) + execute_resp, Mock() + ) + self.assertEqual( + results_message_response.status, ttypes.TOperationState.FINISHED_STATE + ) def test_handle_execute_response_can_handle_with_direct_results(self): result_set_metadata_mock = Mock() @@ -535,16 +621,20 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ), resultSetMetadata=result_set_metadata_mock, resultSet=Mock(), - closeOperation=Mock()) + closeOperation=Mock(), + ) for resp_type in self.execute_response_types: with self.subTest(resp_type=resp_type): execute_resp = resp_type( status=self.okay_status, directResults=direct_results_message, - operationHandle=self.operation_handle) + operationHandle=self.operation_handle, + ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider() + ) thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -564,7 +654,8 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, schema=hive_schema_mock, - arrowSchema=arrow_schema_mock) + arrowSchema=arrow_schema_mock, + ) t_execute_resp = ttypes.TExecuteStatementResp( status=self.okay_status, @@ -587,7 +678,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, arrowSchema=None, - schema=hive_schema_mock) + schema=hive_schema_mock, + ) t_execute_resp = ttypes.TExecuteStatementResp( status=self.okay_status, @@ -599,15 +691,18 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) - self.assertEqual(hive_schema_mock, - thrift_backend._hive_schema_to_arrow_schema.call_args[0][0]) + self.assertEqual( + hive_schema_mock, thrift_backend._hive_schema_to_arrow_schema.call_args[0][0] + ) @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( - self, tcli_service_class, build_queue): - for has_more_rows, resp_type in itertools.product([True, False], - self.execute_response_types): + self, tcli_service_class, build_queue + ): + for has_more_rows, resp_type in itertools.product( + [True, False], self.execute_response_types + ): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() @@ -623,11 +718,13 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( hasMoreRows=has_more_rows, results=results_mock, ), - closeOperation=Mock()) + closeOperation=Mock(), + ) execute_resp = resp_type( status=self.okay_status, directResults=direct_results_message, - operationHandle=self.operation_handle) + operationHandle=self.operation_handle, + ) tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp thrift_backend = self._make_fake_thrift_backend() @@ -639,9 +736,11 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( - self, tcli_service_class, build_queue): - for has_more_rows, resp_type in itertools.product([True, False], - self.execute_response_types): + self, tcli_service_class, build_queue + ): + for has_more_rows, resp_type in itertools.product( + [True, False], self.execute_response_types + ): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() @@ -650,7 +749,8 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( execute_resp = resp_type( status=self.okay_status, directResults=None, - operationHandle=self.operation_handle) + operationHandle=self.operation_handle, + ) fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, @@ -658,13 +758,14 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET - ) + ), ) operation_status_resp = ttypes.TGetOperationStatusResp( status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE, - errorMessage="some information about the error") + errorMessage="some information about the error", + ) tcli_service_instance.FetchResults.return_value = fetch_results_resp tcli_service_instance.GetOperationStatus.return_value = operation_status_resp @@ -679,7 +780,8 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( expected_row_start_offset=0, lz4_compressed=False, arrow_schema_bytes=Mock(), - description=Mock()) + description=Mock(), + ) self.assertEqual(has_more_rows, has_more_rows_resp) @@ -695,19 +797,25 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): rows=[], arrowBatches=[ ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10) - ] + ], ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET - ) + ), ) tcli_service_instance.FetchResults.return_value = t_fetch_results_resp - schema = pyarrow.schema([ - pyarrow.field("column1", pyarrow.int32()), - pyarrow.field("column2", pyarrow.string()), - pyarrow.field("column3", pyarrow.float64()), - pyarrow.field("column3", pyarrow.binary()) - ]).serialize().to_pybytes() + schema = ( + pyarrow.schema( + [ + pyarrow.field("column1", pyarrow.int32()), + pyarrow.field("column2", pyarrow.string()), + pyarrow.field("column3", pyarrow.float64()), + pyarrow.field("column3", pyarrow.binary()), + ] + ) + .serialize() + .to_pybytes() + ) thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) arrow_queue, has_more_results = thrift_backend.fetch_results( @@ -717,7 +825,8 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): expected_row_start_offset=0, lz4_compressed=False, arrow_schema_bytes=schema, - description=MagicMock()) + description=MagicMock(), + ) self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @@ -771,7 +880,8 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service 200, cursor_mock, catalog_name="catalog_pattern", - schema_name="schema_pattern") + schema_name="schema_pattern", + ) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -798,7 +908,8 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", - table_types=["type1", "type2"]) + table_types=["type1", "type2"], + ) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -827,7 +938,8 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", - column_name="column_pattern") + column_name="column_pattern", + ) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) @@ -853,16 +965,19 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend.close_command(self.operation_handle) - self.assertEqual(tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, - self.operation_handle) + self.assertEqual( + tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, + self.operation_handle, + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend.close_session(self.session_handle) - self.assertEqual(tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, - self.session_handle) + self.assertEqual( + tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_class): @@ -871,7 +986,8 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl results_mock.startRowOffset = 0 execute_statement_resp = ttypes.TExecuteStatementResp( - status=self.okay_status, directResults=None, operationHandle=self.operation_handle) + status=self.okay_status, directResults=None, operationHandle=self.operation_handle + ) metadata_resp = ttypes.TGetResultSetMetadataResp( status=self.okay_status, @@ -881,7 +997,8 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl operation_status_resp = ttypes.TGetOperationStatusResp( status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE, - errorMessage="some information about the error") + errorMessage="some information about the error", + ) tcli_service_instance.ExecuteStatement.return_value = execute_statement_resp tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp @@ -900,8 +1017,9 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") - def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mock, - convert_arrow_mock): + def test_create_arrow_table_calls_correct_conversion_method( + self, convert_col_mock, convert_arrow_mock + ): thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -925,14 +1043,23 @@ def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mo @patch("pyarrow.ipc.open_stream") def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_decompress_mock): thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) - - lz4_decompress_mock.return_value = bytearray('Testing','utf-8') - - schema = pyarrow.schema([ - pyarrow.field("column1", pyarrow.int32()), - ]).serialize().to_pybytes() - - arrow_batches = [ttypes.TSparkArrowBatch(batch=bytearray('Testing','utf-8'), rowCount=1) for _ in range(10)] + + lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") + + schema = ( + pyarrow.schema( + [ + pyarrow.field("column1", pyarrow.int32()), + ] + ) + .serialize() + .to_pybytes() + ) + + arrow_batches = [ + ttypes.TSparkArrowBatch(batch=bytearray("Testing", "utf-8"), rowCount=1) + for _ in range(10) + ] utils.convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) lz4_decompress_mock.assert_not_called() @@ -942,19 +1069,20 @@ def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_deco def test_convert_column_based_set_to_arrow_table_without_nulls(self): # Deliberately duplicate the column name to check that dups work field_names = ["column1", "column2", "column3", "column3"] - description = [(name, ) for name in field_names] + description = [(name,) for name in field_names] t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1))), + stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) + ), ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1))) + binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) + ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( - t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -972,48 +1100,50 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self): self.assertEqual(arrow_table.column(0).to_pylist(), [1, 2, 3]) self.assertEqual(arrow_table.column(1).to_pylist(), ["s1", "s2", "s3"]) self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) - self.assertEqual(arrow_table.column(3).to_pylist(), [b'\x11', b'\x22', b'\x33']) + self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) def test_convert_column_based_set_to_arrow_table_with_nulls(self): field_names = ["column1", "column2", "column3", "column3"] - description = [(name, ) for name in field_names] + description = [(name,) for name in field_names] t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1]))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes([2]))), + stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes([2])) + ), ttypes.TColumn( - doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes([4]))), + doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes([4])) + ), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn( - values=[b'\x11', b'\x22', b'\x33'], nulls=bytes([3]))) + binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3])) + ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( - t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) self.assertEqual(n_rows, 3) # Check data self.assertEqual(arrow_table.column(0).to_pylist(), [None, 2, 3]) self.assertEqual(arrow_table.column(1).to_pylist(), ["s1", None, "s3"]) self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, None]) - self.assertEqual(arrow_table.column(3).to_pylist(), [None, None, b'\x33']) + self.assertEqual(arrow_table.column(3).to_pylist(), [None, None, b"\x33"]) def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): field_names = ["column1", "column2", "column3", "column3"] - description = [(name, ) for name in field_names] + description = [(name,) for name in field_names] t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1))), + stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) + ), ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1))) + binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) + ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( - t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1031,7 +1161,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): self.assertEqual(arrow_table.column(0).to_pylist(), [1, 2, 3]) self.assertEqual(arrow_table.column(1).to_pylist(), ["s1", "s2", "s3"]) self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) - self.assertEqual(arrow_table.column(3).to_pylist(), [b'\x11', b'\x22', b'\x33']) + self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_cancel_command_uses_active_op_handle(self, tcli_service_class): @@ -1041,8 +1171,10 @@ def test_cancel_command_uses_active_op_handle(self, tcli_service_class): active_op_handle_mock = Mock() thrift_backend.cancel_command(active_op_handle_mock) - self.assertEqual(tcli_service_instance.CancelOperation.call_args[0][0].operationHandle, - active_op_handle_mock) + self.assertEqual( + tcli_service_instance.CancelOperation.call_args[0][0].operationHandle, + active_op_handle_mock, + ) def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() @@ -1060,7 +1192,8 @@ def test_handle_execute_response_sets_active_op_handle(self): @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus( - self, mock_retry_policy, mock_GetOperationStatus, t_transport_class): + self, mock_retry_policy, mock_GetOperationStatus, t_transport_class + ): import thrift, errno from databricks.sql.thrift_api.TCLIService.TCLIService import Client @@ -1084,17 +1217,20 @@ def test_make_request_will_retry_GetOperationStatus( thrift_backend = ThriftBackend( "foobar", 443, - "path", [], + "path", + [], auth_provider=AuthProvider(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, - _retry_delay_default=1) - + _retry_delay_default=1, + ) with self.assertRaises(RequestError) as cm: thrift_backend.make_request(client.GetOperationStatus, req) - self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"]) - self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) + self.assertEqual( + NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] + ) + self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) # Unusual OSError code mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") @@ -1110,14 +1246,19 @@ def test_make_request_will_retry_GetOperationStatus( self.assertEqual(cm.output[1], cm.output[0]) # The warnings should include this text - self.assertIn(f"{this_gos_name} failed with code {errno.EEXIST} and will attempt to retry", cm.output[0]) + self.assertIn( + f"{this_gos_name} failed with code {errno.EEXIST} and will attempt to retry", + cm.output[0], + ) @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus_for_http_error( - self, mock_retry_policy, mock_gos): + self, mock_retry_policy, mock_gos + ): import urllib3.exceptions + mock_gos.side_effect = urllib3.exceptions.HTTPError("Read timed out") import thrift, errno @@ -1142,37 +1283,20 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( thrift_backend = ThriftBackend( "foobar", 443, - "path", [], + "path", + [], auth_provider=AuthProvider(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, - _retry_delay_default=1) - + _retry_delay_default=1, + ) with self.assertRaises(RequestError) as cm: thrift_backend.make_request(client.GetOperationStatus, req) - - self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"]) - self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) - - - - - @patch("thrift.transport.THttpClient.THttpClient") - def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class): - t_transport_instance = t_transport_class.return_value - t_transport_instance.code = 429 - t_transport_instance.headers = {"foo": "bar"} - mock_method = Mock() - mock_method.__name__ = "method name" - mock_method.side_effect = Exception("This method fails") - - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) - - with self.assertRaises(OperationalError) as cm: - thrift_backend.make_request(mock_method, Mock()) - - self.assertIn("This method fails", str(cm.exception.message_with_context())) + self.assertEqual( + NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] + ) + self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) @patch("thrift.transport.THttpClient.THttpClient") def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_class): @@ -1193,7 +1317,8 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( - self, mock_retry_policy, t_transport_class): + self, mock_retry_policy, t_transport_class + ): t_transport_instance = t_transport_class.return_value t_transport_instance.code = 429 t_transport_instance.headers = {"Retry-After": "0"} @@ -1204,11 +1329,13 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( thrift_backend = ThriftBackend( "foobar", 443, - "path", [], + "path", + [], auth_provider=AuthProvider(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, - _retry_delay_min=0) + _retry_delay_min=0, + ) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1227,15 +1354,23 @@ def test_make_request_will_read_error_message_headers_if_set(self, t_transport_c thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) - error_headers = [[("x-thriftserver-error-message", "thrift server error message")], - [("x-databricks-error-or-redirect-message", "databricks error message")], - [("x-databricks-error-or-redirect-message", "databricks error message"), - ("x-databricks-reason-phrase", "databricks error reason")], - [("x-thriftserver-error-message", "thrift server error message"), - ("x-databricks-error-or-redirect-message", "databricks error message"), - ("x-databricks-reason-phrase", "databricks error reason")], - [("x-thriftserver-error-message", "thrift server error message"), - ("x-databricks-error-or-redirect-message", "databricks error message")]] + error_headers = [ + [("x-thriftserver-error-message", "thrift server error message")], + [("x-databricks-error-or-redirect-message", "databricks error message")], + [ + ("x-databricks-error-or-redirect-message", "databricks error message"), + ("x-databricks-reason-phrase", "databricks error reason"), + ], + [ + ("x-thriftserver-error-message", "thrift server error message"), + ("x-databricks-error-or-redirect-message", "databricks error message"), + ("x-databricks-reason-phrase", "databricks error reason"), + ], + [ + ("x-thriftserver-error-message", "thrift server error message"), + ("x-databricks-error-or-redirect-message", "databricks error message"), + ], + ] for headers in error_headers: t_transport_instance.headers = dict(headers) @@ -1246,16 +1381,17 @@ def test_make_request_will_read_error_message_headers_if_set(self, t_transport_c self.assertIn(header[1], str(cm.exception)) @staticmethod - def make_table_and_desc(height, n_decimal_cols, width, precision, scale, int_constant, - decimal_constant): + def make_table_and_desc( + height, n_decimal_cols, width, precision, scale, int_constant, decimal_constant + ): int_col = [int_constant for _ in range(height)] decimal_col = [decimal_constant for _ in range(height)] data = OrderedDict({"col{}".format(i): int_col for i in range(width - n_decimal_cols)}) decimals = OrderedDict({"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)}) data.update(decimals) - int_desc = ([("", "int")] * (width - n_decimal_cols)) - decimal_desc = ([("", "decimal", None, None, precision, scale, None)] * n_decimal_cols) + int_desc = [("", "int")] * (width - n_decimal_cols) + decimal_desc = [("", "decimal", None, None, precision, scale, None)] * n_decimal_cols description = int_desc + decimal_desc table = pyarrow.Table.from_pydict(data) @@ -1271,30 +1407,39 @@ def test_arrow_decimal_conversion(self): for n_decimal_cols in [0, 1, 10]: for height in [0, 1, 10]: with self.subTest(n_decimal_cols=n_decimal_cols, height=height): - table, description = self.make_table_and_desc(height, n_decimal_cols, width, - precision, scale, int_constant, - decimal_constant) + table, description = self.make_table_and_desc( + height, + n_decimal_cols, + width, + precision, + scale, + int_constant, + decimal_constant, + ) decimal_converted_table = utils.convert_decimals_in_arrow_table( - table, description) + table, description + ) for i in range(width): if height > 0: if i < width - n_decimal_cols: self.assertEqual( - decimal_converted_table.field(i).type, pyarrow.int64()) + decimal_converted_table.field(i).type, pyarrow.int64() + ) else: self.assertEqual( decimal_converted_table.field(i).type, - pyarrow.decimal128(precision=precision, scale=scale)) + pyarrow.decimal128(precision=precision, scale=scale), + ) int_col = [int_constant for _ in range(height)] decimal_col = [Decimal(decimal_constant) for _ in range(height)] expected_result = OrderedDict( - {"col{}".format(i): int_col - for i in range(width - n_decimal_cols)}) + {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} + ) decimals = OrderedDict( - {"col_dec{}".format(i): decimal_col - for i in range(n_decimal_cols)}) + {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} + ) expected_result.update(decimals) self.assertEqual(decimal_converted_table.to_pydict(), expected_result) @@ -1305,29 +1450,31 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_delay_min": 6, "_retry_delay_max": 10, "_retry_stop_after_attempts_count": 1, - "_retry_stop_after_attempts_duration": 100 + "_retry_stop_after_attempts_duration": 100, } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args) - for (arg, val) in retry_delay_args.items(): + backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args + ) + for arg, val in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} - for (k, (_, _, min, max)) in databricks.sql.thrift_backend._retry_policy.items(): + for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): retry_delay_test_args_and_expected_values[k] = ((min - 1, min), (max + 1, max)) for i in range(2): retry_delay_args = { - k: v[i][0] - for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args) + backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args + ) retry_delay_expected_vals = { - k: v[i][1] - for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() } - for (arg, val) in retry_delay_expected_vals.items(): + for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) @@ -1339,7 +1486,7 @@ def test_configuration_passthrough(self, tcli_client_class): "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false", "foo": "bar", "baz": "True", - "42": "42" + "42": "42", } backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) @@ -1365,7 +1512,8 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch status=self.okay_status, serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, canUseMultipleCatalogs=can_use_multiple_cats, - initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem)) + initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem), + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): @@ -1376,8 +1524,9 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): for cat, schem in initial_cat_schem_args: with self.subTest(cat=cat, schem=schem): - tcli_service_instance.OpenSession.return_value = \ + tcli_service_instance.OpenSession.return_value = ( self._construct_open_session_with_namespace(True, cat, schem) + ) backend.open_session({}, cat, schem) @@ -1408,48 +1557,55 @@ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcl passing_ns_args = [(None, None), (None, "schem")] for cat, schem in failing_ns_args: - tcli_service_instance.OpenSession.return_value = \ + tcli_service_instance.OpenSession.return_value = ( self._construct_open_session_with_namespace(False, cat, schem) + ) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, cat, schem) - self.assertIn("server does not support multiple catalogs", str(cm.exception), - "incorrect error thrown for initial namespace {}".format((cat, schem))) + self.assertIn( + "server does not support multiple catalogs", + str(cm.exception), + "incorrect error thrown for initial namespace {}".format((cat, schem)), + ) for cat, schem in passing_ns_args: - tcli_service_instance.OpenSession.return_value = \ + tcli_service_instance.OpenSession.return_value = ( self._construct_open_session_with_namespace(False, cat, schem) + ) backend.open_session({}, cat, schem) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - tcli_service_instance.OpenSession.return_value = \ - ttypes.TOpenSessionResp( - status=self.okay_status, - serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, - canUseMultipleCatalogs=True, - initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem") - ) + tcli_service_instance.OpenSession.return_value = ttypes.TOpenSessionResp( + status=self.okay_status, + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, + canUseMultipleCatalogs=True, + initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), + ) backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") - self.assertIn("Setting initial namespace not supported by the DBR version", - str(cm.exception)) + self.assertIn( + "Setting initial namespace not supported by the DBR version", str(cm.exception) + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) @patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response") - def test_execute_command_sets_complex_type_fields_correctly(self, mock_handle_execute_response, - tcli_service_class): + def test_execute_command_sets_complex_type_fields_correctly( + self, mock_handle_execute_response, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value # Iterate through each possible combination of native types (True, False and unset) - for (complex, timestamp, decimals) in itertools.product( - [True, False, None], [True, False, None], [True, False, None]): + for complex, timestamp, decimals in itertools.product( + [True, False, None], [True, False, None], [True, False, None] + ): complex_arg_types = {} if complex is not None: complex_arg_types["_use_arrow_native_complex_types"] = complex @@ -1458,18 +1614,26 @@ def test_execute_command_sets_complex_type_fields_correctly(self, mock_handle_ex if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), **complex_arg_types) + thrift_backend = ThriftBackend( + "foobar", 443, "path", [], auth_provider=AuthProvider(), **complex_arg_types + ) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] # If the value is unset, the native type should default to True - self.assertEqual(t_execute_statement_req.useArrowNativeTypes.timestampAsArrow, - complex_arg_types.get("_use_arrow_native_timestamps", True)) - self.assertEqual(t_execute_statement_req.useArrowNativeTypes.decimalAsArrow, - complex_arg_types.get("_use_arrow_native_decimals", True)) - self.assertEqual(t_execute_statement_req.useArrowNativeTypes.complexTypesAsArrow, - complex_arg_types.get("_use_arrow_native_complex_types", True)) + self.assertEqual( + t_execute_statement_req.useArrowNativeTypes.timestampAsArrow, + complex_arg_types.get("_use_arrow_native_timestamps", True), + ) + self.assertEqual( + t_execute_statement_req.useArrowNativeTypes.decimalAsArrow, + complex_arg_types.get("_use_arrow_native_decimals", True), + ) + self.assertEqual( + t_execute_statement_req.useArrowNativeTypes.complexTypesAsArrow, + complex_arg_types.get("_use_arrow_native_complex_types", True), + ) self.assertFalse(t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()