Setup working dynamic change from ColumnQueue to ArrowQueue
jprakash-db committed Aug 7, 2024
1 parent 1cfaae2 commit c576110
Showing 4 changed files with 70 additions and 53 deletions.
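In outline, the commit stops forcing every result set through Arrow: execute_command now advertises Arrow support only when pyarrow is importable, build_queue returns a ColumnQueue for COLUMN_BASED_SET results instead of converting them to an Arrow table, and the cursor's fetchone/fetchmany dispatch on the queue type backing the result. Debug prints from earlier iterations are commented out along the way. Below is a minimal sketch of the dispatch pattern, using only names that appear in the hunks that follow; the method shell itself is illustrative:

# Sketch of the queue-type dispatch this commit introduces. ColumnQueue,
# fetchmany_columnar, _convert_columnar_table and _convert_arrow_table are
# the diff's own names; everything else here is illustrative.
def fetchmany(self, size: int):
    if isinstance(self.results, ColumnQueue):      # columnar result set
        return self._convert_columnar_table(self.fetchmany_columnar(size))
    else:                                          # Arrow-backed result set
        return self._convert_arrow_table(self.fetchmany_arrow(size))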
@@ -784,8 +784,8 @@ def execute(
parameters=prepared_params,
)

print("Line 781")
print(execute_response)
# print("Line 781")
# print(execute_response)
self.active_result_set = ResultSet(
self.connection,
execute_response,
@@ -1141,7 +1141,7 @@ def _fill_results_buffer(self):
def _convert_columnar_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
print("Table\n",table)
# print("Table\n",table)
result = []
for row_index in range(len(table[0])):
curr_row = []
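The hunk truncates just after curr_row = []. A hypothetical completion of the loop, assuming table is column-major (a list of columns), as the len(table[0]) row count suggests:

# Assumed completion (not shown in the diff): row i is assembled from the
# i-th element of every column, then wrapped in the Row named tuple.
for row_index in range(len(table[0])):
    curr_row = [table[col_index][row_index] for col_index in range(len(table))]
    result.append(ResultRow(*curr_row))
return result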
@@ -1164,23 +1164,20 @@ def _convert_arrow_table(self, table):
# Need to use nullable types, as otherwise type can change when there are missing values.
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types
# NOTE: This API is experimental: https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
-        try:
-            dtype_mapping = {
-                pyarrow.int8(): pandas.Int8Dtype(),
-                pyarrow.int16(): pandas.Int16Dtype(),
-                pyarrow.int32(): pandas.Int32Dtype(),
-                pyarrow.int64(): pandas.Int64Dtype(),
-                pyarrow.uint8(): pandas.UInt8Dtype(),
-                pyarrow.uint16(): pandas.UInt16Dtype(),
-                pyarrow.uint32(): pandas.UInt32Dtype(),
-                pyarrow.uint64(): pandas.UInt64Dtype(),
-                pyarrow.bool_(): pandas.BooleanDtype(),
-                pyarrow.float32(): pandas.Float32Dtype(),
-                pyarrow.float64(): pandas.Float64Dtype(),
-                pyarrow.string(): pandas.StringDtype(),
-            }
-        except AttributeError:
-            print("pyarrow is not present")
+        dtype_mapping = {
+            pyarrow.int8(): pandas.Int8Dtype(),
+            pyarrow.int16(): pandas.Int16Dtype(),
+            pyarrow.int32(): pandas.Int32Dtype(),
+            pyarrow.int64(): pandas.Int64Dtype(),
+            pyarrow.uint8(): pandas.UInt8Dtype(),
+            pyarrow.uint16(): pandas.UInt16Dtype(),
+            pyarrow.uint32(): pandas.UInt32Dtype(),
+            pyarrow.uint64(): pandas.UInt64Dtype(),
+            pyarrow.bool_(): pandas.BooleanDtype(),
+            pyarrow.float32(): pandas.Float32Dtype(),
+            pyarrow.float64(): pandas.Float64Dtype(),
+            pyarrow.string(): pandas.StringDtype(),
+        }

# Need to rename columns, as the to_pandas function cannot handle duplicate column names
table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
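The mapping above matters because pandas' default Arrow conversion widens an integer column containing nulls to float64, while the nullable extension dtypes preserve the logical type. A standalone illustration, separate from the diff (types_mapper is the standard pyarrow-to-pandas hook):

import pandas
import pyarrow

table = pyarrow.table({"n": pyarrow.array([1, None, 3], type=pyarrow.int64())})

# Default conversion: the null forces the column to float64.
print(table.to_pandas()["n"].dtype)  # float64

# With the nullable mapping, the column stays integer-typed.
dtype_mapping = {pyarrow.int64(): pandas.Int64Dtype()}
print(table.to_pandas(types_mapper=dtype_mapping.get)["n"].dtype)  # Int64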
@@ -1222,6 +1219,20 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":

return results

+    def fetchmany_columnar(self, size: int):
+        """
+        Fetch the next set of rows of a query result, returning a Columnar Table.
+        An empty sequence is returned when no more rows are available.
+        """
+        if size < 0:
+            raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)
+
+        results = self.results.next_n_rows(size)
+        self._next_row_index += results.num_rows
+
+        return results

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
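fetchmany_columnar above only assumes that next_n_rows/remaining_rows return an object exposing num_rows, and _convert_columnar_table additionally indexes that object column-first (table[0], len(table)). A hypothetical minimal ColumnQueue consistent with those call sites; the real classes live elsewhere in the connector:

# Hypothetical sketch, inferred purely from the calls made in this diff.
class ColumnTable:
    """Column-major result block: table[i] is the i-th column's values."""

    def __init__(self, columns, column_names):
        self.columns = columns
        self.column_names = column_names

    @property
    def num_rows(self):
        return len(self.columns[0]) if self.columns else 0

    def __len__(self):
        return len(self.columns)

    def __getitem__(self, index):
        return self.columns[index]

class ColumnQueue:
    def __init__(self, columns, column_names):
        self.table = ColumnTable(columns, column_names)
        self.cur_row_index = 0

    def next_n_rows(self, num_rows):
        end = min(self.cur_row_index + num_rows, self.table.num_rows)
        sliced = [col[self.cur_row_index : end] for col in self.table.columns]
        self.cur_row_index = end
        return ColumnTable(sliced, self.table.column_names)

    def remaining_rows(self):
        return self.next_n_rows(self.table.num_rows - self.cur_row_index)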
@@ -1245,7 +1256,11 @@ def fetchone(self) -> Optional[Row]:
Fetch the next row of a query result set, returning a single sequence,
or None when no more data is available.
"""
-        res = self._convert_arrow_table(self.fetchmany_arrow(1))
+        if isinstance(self.results, ColumnQueue):
+            res = self._convert_columnar_table(self.fetchmany_columnar(1))
+        else:
+            res = self._convert_arrow_table(self.fetchmany_arrow(1))

if len(res) > 0:
return res[0]
else:
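The hunk cuts off inside this else branch; presumably it simply falls through to:

            return None  # assumed tail of the truncated branch, per the docstring above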
@@ -1260,14 +1275,16 @@ def fetchall(self) -> List[Row]:
else:
return self._convert_arrow_table(self.fetchall_arrow())


def fetchmany(self, size: int) -> List[Row]:
"""
Fetch the next set of rows of a query result, returning a list of rows.
An empty sequence is returned when no more rows are available.
"""
-        return self._convert_arrow_table(self.fetchmany_arrow(size))
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchmany_columnar(size))
+        else:
+            return self._convert_arrow_table(self.fetchmany_arrow(size))

def close(self) -> None:
"""
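The net effect on the public cursor API is nil: rows come back through the same DB-API surface whether the server sent a columnar or an Arrow result set. A hedged usage sketch (connection parameters elided; this is the connector's documented entry point):

from databricks import sql

# Fetch results arrive as Row objects regardless of which queue type
# ends up backing the result set.
with sql.connect(server_hostname="...", http_path="...", access_token="...") as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1 AS n")
        print(cursor.fetchone())  # Row(n=1)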
@@ -743,7 +743,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
else:
t_result_set_metadata_resp = self._get_metadata_resp(resp.operationHandle)

print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
# print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
if t_result_set_metadata_resp.resultFormat not in [
ttypes.TSparkRowSetType.ARROW_BASED_SET,
ttypes.TSparkRowSetType.COLUMN_BASED_SET,
@@ -873,7 +873,7 @@ def execute_command(
getDirectResults=ttypes.TSparkGetDirectResults(
maxRows=max_rows, maxBytes=max_bytes
),
-            canReadArrowResult=False,
+            canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
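canReadArrowResult=True if pyarrow else False implies that pyarrow is bound either to the module or to None. Presumably the import sits behind a guard along these lines (an assumption; the guard itself is outside this diff):

# Assumed optional-import guard: `pyarrow` is the module when installed
# and None otherwise, so truthiness checks like `True if pyarrow else
# False` degrade gracefully without the package.
try:
    import pyarrow
except ImportError:
    pyarrow = None

Stylistically, bool(pyarrow) would be an equivalent, slightly tighter spelling of the same condition.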
@@ -75,16 +75,16 @@ def build_queue(
ResultSetQueue
"""

-        def trow_to_json(trow):
-            # Step 1: Serialize TRow using Thrift's TJSONProtocol
-            transport = TTransport.TMemoryBuffer()
-            protocol = TJSONProtocol.TJSONProtocol(transport)
-            trow.write(protocol)
-
-            # Step 2: Extract JSON string from the transport
-            json_str = transport.getvalue().decode('utf-8')
-
-            return json_str
+        # def trow_to_json(trow):
+        #     # Step 1: Serialize TRow using Thrift's TJSONProtocol
+        #     transport = TTransport.TMemoryBuffer()
+        #     protocol = TJSONProtocol.TJSONProtocol(transport)
+        #     trow.write(protocol)
+        #
+        #     # Step 2: Extract JSON string from the transport
+        #     json_str = transport.getvalue().decode('utf-8')
+        #
+        #     return json_str

if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
@@ -95,30 +95,30 @@ def trow_to_json(trow):
)
return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
print("Lin 79 ")
print(type(t_row_set))
print(t_row_set)
json_str = json.loads(trow_to_json(t_row_set))
pretty_json = json.dumps(json_str, indent=2)
print(pretty_json)
# print("Lin 79 ")
# print(type(t_row_set))
# print(t_row_set)
# json_str = json.loads(trow_to_json(t_row_set))
# pretty_json = json.dumps(json_str, indent=2)
# print(pretty_json)

converted_column_table, column_names = convert_column_based_set_to_column_table(
t_row_set.columns,
description)
-            print(converted_column_table, column_names)
+            # print(converted_column_table, column_names)

return ColumnQueue(converted_column_table, column_names)

-            print(columnQueue.next_n_rows(2))
-            print(columnQueue.next_n_rows(2))
-            print(columnQueue.remaining_rows())
-            arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
-                t_row_set.columns, description
-            )
-            converted_arrow_table = convert_decimals_in_arrow_table(
-                arrow_table, description
-            )
-            return ArrowQueue(converted_arrow_table, n_valid_rows)
+            # print(columnQueue.next_n_rows(2))
+            # print(columnQueue.next_n_rows(2))
+            # print(columnQueue.remaining_rows())
+            # arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
+            #     t_row_set.columns, description
+            # )
+            # converted_arrow_table = convert_decimals_in_arrow_table(
+            #     arrow_table, description
+            # )
+            # return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
return CloudFetchQueue(
schema_bytes=arrow_schema_bytes,
2 changes: 1 addition & 1 deletion setup_script.py
@@ -25,6 +25,6 @@ def build_and_install_library(directory_name):


if __name__ == "__main__":
-    build_and_install_library("databricks_sql_connector_core")
+    # build_and_install_library("databricks_sql_connector_core")
build_and_install_library("databricks_sql_connector")
# build_and_install_library("databricks_sqlalchemy")
