
Commit

Refactored the test code and moved it to the respective folders
jprakash-db committed Aug 8, 2024
1 parent c576110 commit 0ddca9d
Showing 46 changed files with 588 additions and 633 deletions.
@@ -784,8 +784,6 @@ def execute(
parameters=prepared_params,
)

# print("Line 781")
# print(execute_response)
self.active_result_set = ResultSet(
self.connection,
execute_response,
@@ -1141,7 +1139,6 @@ def _fill_results_buffer(self):
def _convert_columnar_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
# print("Table\n",table)
result = []
for row_index in range(len(table[0])):
curr_row = []
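For context on the hunk above: _convert_columnar_table rebuilds Row objects from a column-oriented result, where the server sends one list per column. A minimal standalone sketch of that conversion, using namedtuple in place of the connector's Row helper (the names below are illustrative only, not the library's API):

    from collections import namedtuple

    def columnar_to_rows(column_names, columns):
        """Zip per-column lists back into row tuples (columns must be equal length)."""
        ResultRow = namedtuple("ResultRow", column_names)
        return [ResultRow(*(col[i] for col in columns)) for i in range(len(columns[0]))]

    rows = columnar_to_rows(["id", "name"], [[1, 2], ["a", "b"]])
    # rows == [ResultRow(id=1, name='a'), ResultRow(id=2, name='b')]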
@@ -743,7 +743,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
else:
t_result_set_metadata_resp = self._get_metadata_resp(resp.operationHandle)

# print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
if t_result_set_metadata_resp.resultFormat not in [
ttypes.TSparkRowSetType.ARROW_BASED_SET,
ttypes.TSparkRowSetType.COLUMN_BASED_SET,
@@ -880,18 +879,8 @@ def execute_command(
# We want to receive proper Timestamp arrow types.
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
},
# useArrowNativeTypes=spark_arrow_types,
# canReadArrowResult=True,
# # canDecompressLZ4Result=lz4_compression,
# canDecompressLZ4Result=False,
# canDownloadResult=False,
# # confOverlay={
# # # We want to receive proper Timestamp arrow types.
# # "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
# # },
# resultDataFormat=TDBSqlResultFormat(None,None,True),
# # useArrowNativeTypes=spark_arrow_types,
parameters=parameters,
useArrowNativeTypes=spark_arrow_types,
parameters=parameters,
)
resp = self.make_request(self._client.ExecuteStatement, req)
return self._handle_execute_response(resp, cursor)
@@ -74,18 +74,6 @@ def build_queue(
Returns:
ResultSetQueue
"""

# def trow_to_json(trow):
# # Step 1: Serialize TRow using Thrift's TJSONProtocol
# transport = TTransport.TMemoryBuffer()
# protocol = TJSONProtocol.TJSONProtocol(transport)
# trow.write(protocol)
#
# # Step 2: Extract JSON string from the transport
# json_str = transport.getvalue().decode('utf-8')
#
# return json_str

if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
@@ -95,30 +83,11 @@
)
return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
# print("Lin 79 ")
# print(type(t_row_set))
# print(t_row_set)
# json_str = json.loads(trow_to_json(t_row_set))
# pretty_json = json.dumps(json_str, indent=2)
# print(pretty_json)

converted_column_table, column_names = convert_column_based_set_to_column_table(
t_row_set.columns,
description)
# print(converted_column_table, column_names)

return ColumnQueue(converted_column_table, column_names)

# print(columnQueue.next_n_rows(2))
# print(columnQueue.next_n_rows(2))
# print(columnQueue.remaining_rows())
# arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
# t_row_set.columns, description
# )
# converted_arrow_table = convert_decimals_in_arrow_table(
# arrow_table, description
# )
# return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
return CloudFetchQueue(
schema_bytes=arrow_schema_bytes,
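The removed debug comments above referenced next_n_rows and remaining_rows on the column queue. As a rough illustration only (an assumed interface, not the connector's actual ColumnQueue implementation), a column-based queue that serves row slices on demand could look roughly like this:

    class SimpleColumnQueue:
        """Illustrative sketch: hold a column-major table and hand out row slices."""

        def __init__(self, column_table, column_names):
            self.column_table = column_table
            self.column_names = column_names
            self.cur_row = 0
            self.n_rows = len(column_table[0]) if column_table else 0

        def next_n_rows(self, n):
            # Return the next n rows, still column-major, and advance the cursor.
            end = min(self.cur_row + n, self.n_rows)
            chunk = [col[self.cur_row:end] for col in self.column_table]
            self.cur_row = end
            return chunk

        def remaining_rows(self):
            # Return everything not yet consumed.
            chunk = [col[self.cur_row:] for col in self.column_table]
            self.cur_row = self.n_rows
            return chunk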
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -8,13 +8,13 @@


def pysql_supports_arrow():
"""Import databricks.sql and test whether Cursor has fetchall_arrow."""
from databricks.sql import Cursor
"""Import databricks_sql_connector_core.sql and test whether Cursor has fetchall_arrow."""
from databricks_sql_connector_core.sql.client import Cursor
return hasattr(Cursor, 'fetchall_arrow')


def pysql_has_version(compare, version):
"""Import databricks.sql, and return compare_module_version(...).
"""Import databricks_sql_connector_core.sql, and return compare_module_version(...).
Expected use:
from common.predicates import pysql_has_version
@@ -98,4 +98,4 @@ def validate_version(version):

mod_version = validate_version(module.__version__)
req_version = validate_version(version)
return compare_versions(compare, mod_version, req_version)
return compare_versions(compare, mod_version, req_version)
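Following the "Expected use" note in the docstring above, a hedged usage sketch of these predicates as pytest skip conditions (the test name and version string below are illustrative, not taken from the repository):

    import pytest
    from common.predicates import pysql_has_version, pysql_supports_arrow

    @pytest.mark.skipif(not pysql_has_version(">=", "2"), reason="requires connector >= 2")
    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Cursor lacks fetchall_arrow")
    def test_fetchall_arrow_returns_table():
        ...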
@@ -6,8 +6,8 @@
import pytest
from urllib3.exceptions import MaxRetryError

from databricks.sql import DatabricksRetryPolicy
from databricks.sql import (
from databricks_sql_connector_core.sql.auth.retry import DatabricksRetryPolicy
from databricks_sql_connector_core.sql.exc import (
MaxRetryDurationError,
NonRecoverableNetworkError,
RequestError,
@@ -146,7 +146,7 @@ def test_retry_urllib3_settings_are_honored(self):
def test_oserror_retries(self):
"""If a network error occurs during make_request, the request is retried according to policy"""
with patch(
"urllib3.connectionpool.HTTPSConnectionPool._validate_conn",
"urllib3.connectionpool.HTTPSConnectionPool._validate_conn",
) as mock_validate_conn:
mock_validate_conn.side_effect = OSError("Some arbitrary network error")
with pytest.raises(MaxRetryError) as cm:
@@ -275,7 +275,7 @@ def test_retry_safe_execute_statement_retry_condition(self):
]

with self.connection(
extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1}
extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1}
) as conn:
with conn.cursor() as cursor:
# Code 502 is a Bad Gateway, which we commonly see in production under heavy load
@@ -318,9 +318,9 @@ def test_retry_abort_close_operation_on_404(self, caplog):
with self.connection(extra_params={**self._retry_policy}) as conn:
with conn.cursor() as curs:
with patch(
"databricks.sql.utils.ExecuteResponse.has_been_closed_server_side",
new_callable=PropertyMock,
return_value=False,
"databricks_sql_connector_core.sql.utils.ExecuteResponse.has_been_closed_server_side",
new_callable=PropertyMock,
return_value=False,
):
# This call guarantees we have an open cursor at the server
curs.execute("SELECT 1")
@@ -340,10 +340,10 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self):
with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj:
with pytest.raises(MaxRetryError) as cm:
with self.connection(
extra_params={
**self._retry_policy,
"_retry_max_redirects": max_redirects,
}
extra_params={
**self._retry_policy,
"_retry_max_redirects": max_redirects,
}
):
pass
assert "too many redirects" == str(cm.value.reason)
@@ -362,9 +362,9 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self):
with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj:
with pytest.raises(MaxRetryError) as cm:
with self.connection(
extra_params={
**self._retry_policy,
}
extra_params={
**self._retry_policy,
}
):
pass

@@ -394,13 +394,13 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self):

def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog):
with self.connection(
extra_params={
**self._retry_policy,
**{
"_retry_max_redirects": 100,
"_retry_stop_after_attempts_count": 1,
},
}
extra_params={
**self._retry_policy,
**{
"_retry_max_redirects": 100,
"_retry_stop_after_attempts_count": 1,
},
}
):
assert "it will have no affect!" in caplog.text

@@ -433,4 +433,4 @@ def test_401_not_retried(self):
with pytest.raises(RequestError) as cm:
with self.connection(extra_params=self._retry_policy):
pass
assert isinstance(cm.value.args[1], NonRecoverableNetworkError)
assert isinstance(cm.value.args[1], NonRecoverableNetworkError)
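As a hedged sketch of how these tests exercise the retry knobs: the option names (_retry_stop_after_attempts_count, _retry_max_redirects) come from the test bodies above, while self.connection is the test mixin's helper, not public API.

    retry_overrides = {
        "_retry_stop_after_attempts_count": 1,  # fail fast after a single attempt
        "_retry_max_redirects": 2,              # bound redirect-driven retries
    }

    # with self.connection(extra_params={**self._retry_policy, **retry_overrides}) as conn:
    #     with conn.cursor() as cursor:
    #         cursor.execute("SELECT 1")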
@@ -2,8 +2,8 @@
import tempfile

import pytest
import databricks.sql as sql
from databricks.sql import Error
import databricks_sql_connector_core.sql as sql
from databricks_sql_connector_core.sql import Error


@pytest.fixture(scope="module", autouse=True)
@@ -100,7 +100,7 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, in
cursor.execute(query)

def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path(
self, ingestion_user
self, ingestion_user
):

fh, temp_path = tempfile.mkstemp()
@@ -116,8 +116,8 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p
base_path = os.path.join(base_path, "temp")

with pytest.raises(
Error,
match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
Error,
match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
):
with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn:
cursor = conn.cursor()
@@ -158,7 +158,7 @@ def perform_remove():

# Try to put it again
with pytest.raises(
sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"
sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"
):
perform_put()

@@ -209,7 +209,7 @@ def perform_get():
perform_get()

def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path(
self, ingestion_user
self, ingestion_user
):
"""
This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths.
@@ -222,11 +222,11 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe
target_file = "/var/www/html/../html1/not_allowed.html"

with pytest.raises(
Error,
match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
Error,
match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
):
with self.connection(
extra_params={"staging_allowed_local_path": staging_allowed_local_path}
extra_params={"staging_allowed_local_path": staging_allowed_local_path}
) as conn:
cursor = conn.cursor()
query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE"
@@ -238,7 +238,7 @@ def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, inges

with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"):
with self.connection(
extra_params={"staging_allowed_local_path": staging_allowed_local_path}
extra_params={"staging_allowed_local_path": staging_allowed_local_path}
) as conn:
cursor = conn.cursor()
query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE"
@@ -250,14 +250,14 @@ def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_

with pytest.raises(Error, match="INVALID_STAGING_PATH_IN_STAGING_ACCESS_QUERY"):
with self.connection(
extra_params={"staging_allowed_local_path": staging_allowed_local_path}
extra_params={"staging_allowed_local_path": staging_allowed_local_path}
) as conn:
cursor = conn.cursor()
query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE"
cursor.execute(query)

def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values(
self, ingestion_user
self, ingestion_user
):
"""staging_allowed_local_path may be either a path-like object or a list of path-like objects.
@@ -286,19 +286,19 @@ def generate_file_and_path_and_queries():
fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries()

with self.connection(
extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}
extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}
) as conn:
cursor = conn.cursor()

cursor.execute(put_query1)
cursor.execute(put_query2)

with pytest.raises(
Error,
match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
Error,
match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
):
cursor.execute(put_query3)

# Then clean up the files we made
cursor.execute(remove_query1)
cursor.execute(remove_query2)
cursor.execute(remove_query2)
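For orientation, a hedged sketch of the ingestion flow these tests exercise. The PUT/GET/REMOVE query shapes and the staging_allowed_local_path option mirror the test bodies above; the connect() keyword arguments are assumed from the pre-rename databricks.sql API, and the hostname, token, paths, and user are placeholders.

    import databricks_sql_connector_core.sql as sql

    with sql.connect(
        server_hostname="...",                      # placeholder credentials
        http_path="...",
        access_token="...",
        staging_allowed_local_path="/tmp/uploads",  # local files must resolve under this path
    ) as conn:
        with conn.cursor() as cursor:
            user = "some.user@example.com"
            cursor.execute(f"PUT '/tmp/uploads/file1.csv' INTO 'stage://tmp/{user}/tmp/11/15/file1.csv' OVERWRITE")
            cursor.execute(f"GET 'stage://tmp/{user}/tmp/11/15/file1.csv' TO '/tmp/uploads/file1_copy.csv'")
            cursor.execute(f"REMOVE 'stage://tmp/{user}/tmp/11/15/file1.csv'")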
File renamed without changes.
