Skip to content

Commit

Permalink
pipe proxy override through connection and clients
Browse files Browse the repository at this point in the history
  • Loading branch information
jakewski committed Jan 12, 2024
1 parent 3f6834c commit 61a8ed5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
8 changes: 6 additions & 2 deletions src/databricks/sql/auth/thrift_http_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import logging
import urllib.parse
from typing import Dict, Union
from typing import Dict, Optional, Union

import six
import thrift
Expand Down Expand Up @@ -31,6 +31,7 @@ def __init__(
ssl_context=None,
max_connections: int = 1,
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
proxies: Optional[Dict[str, str]] = None,
):
if port is not None:
warnings.warn(
Expand Down Expand Up @@ -60,8 +61,11 @@ def __init__(
self.path = parsed.path
if parsed.query:
self.path += "?%s" % parsed.query

if proxies is None:
proxies = urllib.request.getproxies()
try:
proxy = urllib.request.getproxies()[self.scheme]
proxy = proxies[self.scheme]
except KeyError:
proxy = None
else:
Expand Down
8 changes: 5 additions & 3 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def read(self) -> Optional[OAuthToken]:
STRUCT is returned as Dict[str, Any]
ARRAY is returned as numpy.ndarray
When False, complex types are returned as a strings. These are generally deserializable as JSON.
:param proxies: An optional dictionary mapping protocol to the URL of the proxy.
"""

# Internal arguments in **kwargs:
Expand Down Expand Up @@ -208,6 +209,7 @@ def read(self) -> Optional[OAuthToken]:
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.proxies = kwargs.get("proxies")

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
Expand Down Expand Up @@ -648,7 +650,7 @@ def _handle_staging_put(
raise Error("Cannot perform PUT without specifying a local_file")

with open(local_file, "rb") as fh:
r = requests.put(url=presigned_url, data=fh, headers=headers)
r = requests.put(url=presigned_url, data=fh, headers=headers, proxies=self.connection.proxies)

# fmt: off
# Design borrowed from: https://stackoverflow.com/a/2342589/5093960
Expand Down Expand Up @@ -682,7 +684,7 @@ def _handle_staging_get(
if local_file is None:
raise Error("Cannot perform GET without specifying a local_file")

r = requests.get(url=presigned_url, headers=headers)
r = requests.get(url=presigned_url, headers=headers, proxies=self.connection.proxies)

# response.ok verifies the status code is not between 400-600.
# Any 2xx or 3xx will evaluate r.ok == True
Expand All @@ -697,7 +699,7 @@ def _handle_staging_get(
def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
"""Make an HTTP DELETE request to the presigned_url"""

r = requests.delete(url=presigned_url, headers=headers)
r = requests.delete(url=presigned_url, headers=headers, proxies=self.connection.proxies)

if not r.ok:
raise Error(
Expand Down
5 changes: 4 additions & 1 deletion src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
import threading
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union
from typing import List, Union, Optional, Dict

import pyarrow
import thrift.transport.THttpClient
Expand Down Expand Up @@ -218,6 +218,9 @@ def __init__(

additional_transport_args["retry_policy"] = self.retry_policy

if "proxies" in kwargs:
additional_transport_args["proxies"] = kwargs["proxies"]

self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
Expand Down

0 comments on commit 61a8ed5

Please sign in to comment.