From a151df2b925aa4b7800919b3eefd57452af3f608 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Thu, 3 Oct 2024 18:21:15 +0530 Subject: [PATCH] [PECO-1926] Create a non pyarrow flow to handle small results for the column set (#440) * Implemented the columnar flow for non arrow users * Minor fixes * Introduced the Column Table structure * Added test for the new column table * Minor fix * Removed unnecessory fikes --- src/databricks/sql/client.py | 96 +++++++++++++++-- src/databricks/sql/thrift_backend.py | 25 +++-- src/databricks/sql/utils.py | 152 ++++++++++++++++++++++++--- tests/unit/test_column_queue.py | 22 ++++ 4 files changed, 263 insertions(+), 32 deletions(-) create mode 100644 tests/unit/test_column_queue.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index addc340e..4df67a08 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,7 +1,10 @@ from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None import requests import json import os @@ -22,6 +25,8 @@ ParamEscaper, inject_parameters, transform_paramstyle, + ColumnTable, + ColumnQueue ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]: else: raise Error("There is no active result set") - def fetchall_arrow(self) -> pyarrow.Table: + def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: raise Error("There is no active result set") - def fetchmany_arrow(self, size) -> pyarrow.Table: + def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) @@ -1143,6 +1148,18 @@ def _fill_results_buffer(self): self.results = results self.has_more_rows = has_more_rows + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + def _convert_arrow_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) @@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table): def rownumber(self): return self._next_row_index - def fetchmany_arrow(self, size: int) -> pyarrow.Table: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1210,7 +1227,46 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table: return results - def fetchall_arrow(self) -> pyarrow.Table: + def merge_columnar(self, result1, result2): + """ + Function to merge / combining the columnar results into a single result + :param result1: + :param result2: + :return: + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)] + return ColumnTable(merged_result, result1.column_names) + + def fetchmany_columnar(self, size: int): + """ + Fetch the next set of rows of a query result, returning a Columnar Table. + An empty sequence is returned when no more rows are available. + """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -1223,12 +1279,30 @@ def fetchall_arrow(self) -> pyarrow.Table: return results + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, partial_results) + self._next_row_index += partial_results.num_rows + + return results + def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ - res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + if len(res) > 0: return res[0] else: @@ -1238,7 +1312,10 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ - return self._convert_arrow_table(self.fetchall_arrow()) + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) def fetchmany(self, size: int) -> List[Row]: """ @@ -1246,7 +1323,10 @@ def fetchmany(self, size: int) -> List[Row]: An empty sequence is returned when no more rows are available. """ - return self._convert_arrow_table(self.fetchmany_arrow(size)) + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) def close(self) -> None: """ diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e89bff26..7f6ada9d 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -7,7 +7,10 @@ import threading from typing import List, Union -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket @@ -621,6 +624,7 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): + def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -726,12 +730,17 @@ def _results_message_to_execute_response(self, resp, operation_state): description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) - schema_bytes = ( - t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) - .serialize() - .to_pybytes() - ) + + if pyarrow: + schema_bytes = ( + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() + ) + else: + schema_bytes = None + lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation if direct_results and direct_results.resultSet: @@ -827,7 +836,7 @@ def execute_command( getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), - canReadArrowResult=True, + canReadArrowResult=True if pyarrow else False, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, confOverlay={ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2807bd2b..97df6d4d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pytz import datetime import decimal from abc import ABC, abstractmethod @@ -11,7 +12,10 @@ import re import lz4.frame -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None from databricks.sql import OperationalError, exc from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager @@ -33,11 +37,11 @@ class ResultSetQueue(ABC): @abstractmethod - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int): pass @abstractmethod - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self): pass @@ -76,13 +80,15 @@ def build_queue( ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: - arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table( + column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description ) - converted_arrow_table = convert_decimals_in_arrow_table( - arrow_table, description + + converted_column_table = convert_to_assigned_datatypes_in_column_table( + column_table, description ) - return ArrowQueue(converted_arrow_table, n_valid_rows) + + return ColumnQueue(ColumnTable(converted_column_table, column_names)) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( schema_bytes=arrow_schema_bytes, @@ -96,11 +102,55 @@ def build_queue( else: raise AssertionError("Row set type is not valid") +class ColumnTable: + def __init__(self, column_table, column_names): + self.column_table = column_table + self.column_names = column_names + + @property + def num_rows(self): + if len(self.column_table) == 0: + return 0 + else: + return len(self.column_table[0]) + + @property + def num_columns(self): + return len(self.column_names) + + def get_item(self, col_index, row_index): + return self.column_table[col_index][row_index] + + def slice(self, curr_index, length): + sliced_column_table = [column[curr_index : curr_index + length] for column in self.column_table] + return ColumnTable(sliced_column_table, self.column_names) + + def __eq__(self, other): + return self.column_table == other.column_table and self.column_names == other.column_names + +class ColumnQueue(ResultSetQueue): + def __init__(self, column_table: ColumnTable): + self.column_table = column_table + self.cur_row_index = 0 + self.n_valid_rows = column_table.num_rows + + def next_n_rows(self, num_rows): + length = min(num_rows, self.n_valid_rows - self.cur_row_index) + + slice = self.column_table.slice(self.cur_row_index, length) + self.cur_row_index += slice.num_rows + return slice + + def remaining_rows(self): + slice = self.column_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) + self.cur_row_index += slice.num_rows + return slice + class ArrowQueue(ResultSetQueue): def __init__( self, - arrow_table: pyarrow.Table, + arrow_table: "pyarrow.Table", n_valid_rows: int, start_row_index: int = 0, ): @@ -115,7 +165,7 @@ def __init__( self.arrow_table = arrow_table self.n_valid_rows = n_valid_rows - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """Get upto the next n rows of the Arrow dataframe""" length = min(num_rows, self.n_valid_rows - self.cur_row_index) # Note that the table.slice API is not the same as Python's slice @@ -124,7 +174,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: self.cur_row_index += slice.num_rows return slice - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self) -> "pyarrow.Table": slice = self.arrow_table.slice( self.cur_row_index, self.n_valid_rows - self.cur_row_index ) @@ -184,7 +234,7 @@ def __init__( self.table = self._create_next_table() self.table_row_index = 0 - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """ Get up to the next n rows of the cloud fetch Arrow dataframes. @@ -216,7 +266,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) return results - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self) -> "pyarrow.Table": """ Get all remaining rows of the cloud fetch Arrow dataframes. @@ -237,7 +287,7 @@ def remaining_rows(self) -> pyarrow.Table: self.table_row_index = 0 return results - def _create_next_table(self) -> Union[pyarrow.Table, None]: + def _create_next_table(self) -> Union["pyarrow.Table", None]: logger.debug( "CloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index @@ -276,7 +326,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]: return arrow_table - def _create_empty_table(self) -> pyarrow.Table: + def _create_empty_table(self) -> "pyarrow.Table": # Create a 0-row table with just the schema bytes return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) @@ -515,7 +565,7 @@ def transform_paramstyle( return output -def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table: +def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table": arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -542,7 +592,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema return arrow_table, n_rows -def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table: +def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": for i, col in enumerate(table.itercolumns()): if description[i][1] == "decimal": decimal_col = col.to_pandas().apply( @@ -560,6 +610,33 @@ def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table: return table +def convert_to_assigned_datatypes_in_column_table(column_table, description): + + converted_column_table = [] + for i, col in enumerate(column_table): + if description[i][1] == "decimal": + converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col)) + elif description[i][1] == "date": + converted_column_table.append(tuple( + v if v is None else datetime.date.fromisoformat(v) for v in col + )) + elif description[i][1] == "timestamp": + converted_column_table.append(tuple( + ( + v + if v is None + else datetime.datetime.strptime(v, "%Y-%m-%d %H:%M:%S.%f").replace( + tzinfo=pytz.UTC + ) + ) + for v in col + )) + else: + converted_column_table.append(col) + + return converted_column_table + + def convert_column_based_set_to_arrow_table(columns, description): arrow_table = pyarrow.Table.from_arrays( [_convert_column_to_arrow_array(c) for c in columns], @@ -571,6 +648,13 @@ def convert_column_based_set_to_arrow_table(columns, description): return arrow_table, arrow_table.num_rows +def convert_column_based_set_to_column_table(columns, description): + column_names = [c[0] for c in description] + column_table = [_convert_column_to_list(c) for c in columns] + + return column_table, column_names + + def _convert_column_to_arrow_array(t_col): """ Return a pyarrow array from the values in a TColumn instance. @@ -595,6 +679,26 @@ def _convert_column_to_arrow_array(t_col): raise OperationalError("Empty TColumn instance {}".format(t_col)) +def _convert_column_to_list(t_col): + SUPPORTED_FIELD_TYPES = ( + "boolVal", + "byteVal", + "i16Val", + "i32Val", + "i64Val", + "doubleVal", + "stringVal", + "binaryVal", + ) + + for field in SUPPORTED_FIELD_TYPES: + wrapper = getattr(t_col, field) + if wrapper: + return _create_python_tuple(wrapper) + + raise OperationalError("Empty TColumn instance {}".format(t_col)) + + def _create_arrow_array(t_col_value_wrapper, arrow_type): result = t_col_value_wrapper.values nulls = t_col_value_wrapper.nulls # bitfield describing which values are null @@ -609,3 +713,19 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type): result[i] = None return pyarrow.array(result, type=arrow_type) + + +def _create_python_tuple(t_col_value_wrapper): + result = t_col_value_wrapper.values + nulls = t_col_value_wrapper.nulls # bitfield describing which values are null + assert isinstance(nulls, bytes) + + # The number of bits in nulls can be both larger or smaller than the number of + # elements in result, so take the minimum of both to iterate over. + length = min(len(result), len(nulls) * 8) + + for i in range(length): + if nulls[i >> 3] & BIT_MASKS[i & 0x7]: + result[i] = None + + return tuple(result) \ No newline at end of file diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py new file mode 100644 index 00000000..130b589b --- /dev/null +++ b/tests/unit/test_column_queue.py @@ -0,0 +1,22 @@ +from databricks.sql.utils import ColumnQueue, ColumnTable + + +class TestColumnQueueSuite: + @staticmethod + def make_column_table(table): + n_cols = len(table) if table else 0 + return ColumnTable(table, [f"col_{i}" for i in range(n_cols)]) + + def test_fetchmany_respects_n_rows(self): + column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_queue = ColumnQueue(column_table) + + assert column_queue.next_n_rows(2) == column_table.slice(0, 2) + assert column_queue.next_n_rows(2) == column_table.slice(2, 2) + + def test_fetch_remaining_rows_respects_n_rows(self): + column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_queue = ColumnQueue(column_table) + + assert column_queue.next_n_rows(2) == column_table.slice(0, 2) + assert column_queue.remaining_rows() == column_table.slice(2, 2)