[PECO-1926] Create a non pyarrow flow to handle small results for the column set (#440)

* Implemented the columnar flow for non-arrow users

* Minor fixes

* Introduced the Column Table structure

* Added test for the new column table

* Minor fix

* Removed unnecessary files
jprakash-db authored Oct 3, 2024
1 parent d31063c commit a151df2
Showing 4 changed files with 263 additions and 32 deletions.
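Only two of the four changed files are rendered below; the ColumnTable and ColumnQueue structures that client.py imports are defined in a file whose diff is not shown. As a rough sketch only, inferred from how the diff below uses it (not the commit's actual code), ColumnTable plausibly looks like:

    class ColumnTable:
        # Column-major container: one Python list per column, plus the column names.
        def __init__(self, column_table, column_names):
            self.column_table = column_table    # e.g. [[1, 2], ["a", "b"]]
            self.column_names = column_names    # e.g. ["id", "name"]

        @property
        def num_rows(self):
            # Columns are equal-length; an empty table has zero rows.
            return len(self.column_table[0]) if self.column_table else 0

        @property
        def num_columns(self):
            return len(self.column_names)

        def get_item(self, col_index, row_index):
            # Row-wise access into the column-major storage.
            return self.column_table[col_index][row_index]

A matching ColumnQueue would wrap a ColumnTable and expose next_n_rows(size) and remaining_rows(), the two methods the fetch loops below rely on.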
96 changes: 88 additions & 8 deletions src/databricks/sql/client.py
@@ -1,7 +1,10 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
import pyarrow
try:
    import pyarrow
except ImportError:
    pyarrow = None
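# pyarrow is optional as of this commit: when it is missing, the connector
# falls back to the columnar (non-arrow) result path added below.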
import requests
import json
import os
@@ -22,6 +25,8 @@
    ParamEscaper,
    inject_parameters,
    transform_paramstyle,
    ColumnTable,
    ColumnQueue
)
from databricks.sql.parameters.native import (
    DbsqlParameterBase,
@@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]:
        else:
            raise Error("There is no active result set")

    def fetchall_arrow(self) -> pyarrow.Table:
    def fetchall_arrow(self) -> "pyarrow.Table":
        self._check_not_closed()
        if self.active_result_set:
            return self.active_result_set.fetchall_arrow()
        else:
            raise Error("There is no active result set")

    def fetchmany_arrow(self, size) -> pyarrow.Table:
    def fetchmany_arrow(self, size) -> "pyarrow.Table":
        self._check_not_closed()
        if self.active_result_set:
            return self.active_result_set.fetchmany_arrow(size)
@@ -1143,6 +1148,18 @@ def _fill_results_buffer(self):
        self.results = results
        self.has_more_rows = has_more_rows

    def _convert_columnar_table(self, table):
        column_names = [c[0] for c in self.description]
        ResultRow = Row(*column_names)
        result = []
        for row_index in range(table.num_rows):
            curr_row = []
            for col_index in range(table.num_columns):
                curr_row.append(table.get_item(col_index, row_index))
            result.append(ResultRow(*curr_row))

        return result
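    # Example (hypothetical values): a ColumnTable([[1, 2], ["a", "b"]], ["id", "name"])
    # converts to [Row(id=1, name='a'), Row(id=2, name='b')].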

    def _convert_arrow_table(self, table):
        column_names = [c[0] for c in self.description]
        ResultRow = Row(*column_names)
@@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table):
    def rownumber(self):
        return self._next_row_index

    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
        """
        Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -1210,7 +1227,46 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

        return results

    def fetchall_arrow(self) -> pyarrow.Table:
    def merge_columnar(self, result1, result2):
        """
        Merge two columnar result sets into a single result.
        :param result1: the first ColumnTable
        :param result2: the second ColumnTable, with the same column names
        :return: a single merged ColumnTable
        """

        if result1.column_names != result2.column_names:
            raise ValueError("The columns in the results don't match")

        merged_result = [
            result1.column_table[i] + result2.column_table[i]
            for i in range(result1.num_columns)
        ]
        return ColumnTable(merged_result, result1.column_names)
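    # Example (hypothetical values):
    #     t1 = ColumnTable([[1, 2], ["a", "b"]], ["id", "name"])
    #     t2 = ColumnTable([[3], ["c"]], ["id", "name"])
    #     merge_columnar(t1, t2).column_table == [[1, 2, 3], ["a", "b", "c"]]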

    def fetchmany_columnar(self, size: int):
        """
        Fetch the next set of rows of a query result, returning a Columnar Table.
        An empty sequence is returned when no more rows are available.
        """
        if size < 0:
            raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)

        results = self.results.next_n_rows(size)
        n_remaining_rows = size - results.num_rows
        self._next_row_index += results.num_rows

        while (
            n_remaining_rows > 0
            and not self.has_been_closed_server_side
            and self.has_more_rows
        ):
            self._fill_results_buffer()
            partial_results = self.results.next_n_rows(n_remaining_rows)
            results = self.merge_columnar(results, partial_results)
            n_remaining_rows -= partial_results.num_rows
            self._next_row_index += partial_results.num_rows

        return results

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows
@@ -1223,12 +1279,30 @@ def fetchall_arrow(self) -> pyarrow.Table:

        return results

    def fetchall_columnar(self):
        """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
        results = self.results.remaining_rows()
        self._next_row_index += results.num_rows

        while not self.has_been_closed_server_side and self.has_more_rows:
            self._fill_results_buffer()
            partial_results = self.results.remaining_rows()
            results = self.merge_columnar(results, partial_results)
            self._next_row_index += partial_results.num_rows

        return results

    def fetchone(self) -> Optional[Row]:
        """
        Fetch the next row of a query result set, returning a single sequence,
        or None when no more data is available.
        """
        res = self._convert_arrow_table(self.fetchmany_arrow(1))

        if isinstance(self.results, ColumnQueue):
            res = self._convert_columnar_table(self.fetchmany_columnar(1))
        else:
            res = self._convert_arrow_table(self.fetchmany_arrow(1))

        if len(res) > 0:
            return res[0]
        else:
@@ -1238,15 +1312,21 @@ def fetchall(self) -> List[Row]:
"""
Fetch all (remaining) rows of a query result, returning them as a list of rows.
"""
return self._convert_arrow_table(self.fetchall_arrow())
if isinstance(self.results, ColumnQueue):
return self._convert_columnar_table(self.fetchall_columnar())
else:
return self._convert_arrow_table(self.fetchall_arrow())

    def fetchmany(self, size: int) -> List[Row]:
        """
        Fetch the next set of rows of a query result, returning a list of rows.
        An empty sequence is returned when no more rows are available.
        """
        return self._convert_arrow_table(self.fetchmany_arrow(size))
        if isinstance(self.results, ColumnQueue):
            return self._convert_columnar_table(self.fetchmany_columnar(size))
        else:
            return self._convert_arrow_table(self.fetchmany_arrow(size))
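        # Both dispatch branches return List[Row], so callers see the same
        # behavior whether or not pyarrow is installed.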

    def close(self) -> None:
        """
(diff truncated)
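Taken together, these changes let the DB-API fetch methods above run without pyarrow. A minimal usage sketch, assuming placeholder connection parameters (not values from this commit):

    from databricks import sql

    # Works on a client without pyarrow installed: the cursor transparently
    # uses the columnar result path for the fetch* methods.
    with sql.connect(
        server_hostname="<hostname>",
        http_path="<http-path>",
        access_token="<token>",
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1 AS x")
            print(cursor.fetchall())  # [Row(x=1)]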
25 changes: 17 additions & 8 deletions src/databricks/sql/thrift_backend.py
@@ -7,7 +7,10 @@
import threading
from typing import List, Union

import pyarrow
try:
    import pyarrow
except ImportError:
    pyarrow = None
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
@@ -621,6 +624,7 @@ def _get_metadata_resp(self, op_handle):

    @staticmethod
    def _hive_schema_to_arrow_schema(t_table_schema):

        def map_type(t_type_entry):
            if t_type_entry.primitiveEntry:
                return {
@@ -726,12 +730,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
        description = self._hive_schema_to_description(
            t_result_set_metadata_resp.schema
        )
        schema_bytes = (
            t_result_set_metadata_resp.arrowSchema
            or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
            .serialize()
            .to_pybytes()
        )

        if pyarrow:
            schema_bytes = (
                t_result_set_metadata_resp.arrowSchema
                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
                .serialize()
                .to_pybytes()
            )
        else:
            schema_bytes = None
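            # Without pyarrow there is no arrow schema to serialize; downstream
            # consumers must handle schema_bytes=None and columnar result data.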

        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
        if direct_results and direct_results.resultSet:
@@ -827,7 +836,7 @@ def execute_command(
            getDirectResults=ttypes.TSparkGetDirectResults(
                maxRows=max_rows, maxBytes=max_bytes
            ),
            canReadArrowResult=True,
            canReadArrowResult=True if pyarrow else False,
            canDecompressLZ4Result=lz4_compression,
            canDownloadResult=use_cloud_fetch,
            confOverlay={
(Diff truncated; the remaining changed files are not rendered on this page.)
