From 61a8ed5fa41303821623184e92034ada298ff4ae Mon Sep 17 00:00:00 2001 From: Jakub Grzegorzewski Date: Fri, 12 Jan 2024 20:01:50 +0000 Subject: [PATCH] pipe proxy override through connection and clients --- src/databricks/sql/auth/thrift_http_client.py | 8 ++++++-- src/databricks/sql/client.py | 8 +++++--- src/databricks/sql/thrift_backend.py | 5 ++++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 11589258..8c0c5837 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,7 +1,7 @@ import base64 import logging import urllib.parse -from typing import Dict, Union +from typing import Dict, Optional, Union import six import thrift @@ -31,6 +31,7 @@ def __init__( ssl_context=None, max_connections: int = 1, retry_policy: Union[DatabricksRetryPolicy, int] = 0, + proxies: Optional[Dict[str, str]] = None, ): if port is not None: warnings.warn( @@ -60,8 +61,11 @@ def __init__( self.path = parsed.path if parsed.query: self.path += "?%s" % parsed.query + + if proxies is None: + proxies = urllib.request.getproxies() try: - proxy = urllib.request.getproxies()[self.scheme] + proxy = proxies[self.scheme] except KeyError: proxy = None else: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 7417161f..4876c6cd 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -160,6 +160,7 @@ def read(self) -> Optional[OAuthToken]: STRUCT is returned as Dict[str, Any] ARRAY is returned as numpy.ndarray When False, complex types are returned as a strings. These are generally deserializable as JSON. + :param proxies: An optional dictionary mapping protocol to the URL of the proxy. """ # Internal arguments in **kwargs: @@ -208,6 +209,7 @@ def read(self) -> Optional[OAuthToken]: self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.proxies = kwargs.get("proxies") auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs @@ -648,7 +650,7 @@ def _handle_staging_put( raise Error("Cannot perform PUT without specifying a local_file") with open(local_file, "rb") as fh: - r = requests.put(url=presigned_url, data=fh, headers=headers) + r = requests.put(url=presigned_url, data=fh, headers=headers, proxies=self.connection.proxies) # fmt: off # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 @@ -682,7 +684,7 @@ def _handle_staging_get( if local_file is None: raise Error("Cannot perform GET without specifying a local_file") - r = requests.get(url=presigned_url, headers=headers) + r = requests.get(url=presigned_url, headers=headers, proxies=self.connection.proxies) # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True @@ -697,7 +699,7 @@ def _handle_staging_get( def _handle_staging_remove(self, presigned_url: str, headers: dict = None): """Make an HTTP DELETE request to the presigned_url""" - r = requests.delete(url=presigned_url, headers=headers) + r = requests.delete(url=presigned_url, headers=headers, proxies=self.connection.proxies) if not r.ok: raise Error( diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 288c3e10..b6b164ab 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -6,7 +6,7 @@ import uuid import threading from ssl import CERT_NONE, CERT_REQUIRED, create_default_context -from typing import List, Union +from typing import List, Union, Optional, Dict import pyarrow import thrift.transport.THttpClient @@ -218,6 +218,9 @@ def __init__( additional_transport_args["retry_policy"] = self.retry_policy + if "proxies" in kwargs: + additional_transport_args["proxies"] = kwargs["proxies"] + self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri,