[PECO-1803] Splitting the PySql connector into the core and the non-core part #443

Closed · wants to merge 6 commits
23 changes: 23 additions & 0 deletions databricks_sql_connector/pyproject.toml
@@ -0,0 +1,23 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.3.0"
description = "Databricks SQL Connector for Python"
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
license = "Apache-2.0"


[tool.poetry.dependencies]
databricks_sql_connector_core = { version = ">=1.0.0", extras=["all"]}
databricks_sqlalchemy = { version = ">=1.0.0", optional = true }

[tool.poetry.extras]
databricks_sqlalchemy = ["databricks_sqlalchemy"]

[tool.poetry.urls]
"Homepage" = "https://github.com/databricks/databricks-sql-python"
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
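The new umbrella pyproject.toml above keeps the databricks-sql-connector name and delegates everything to databricks-sql-connector-core, exposing the SQLAlchemy dialect only through the opt-in databricks_sqlalchemy extra. A minimal sketch, not part of this PR, of how downstream code could check at runtime whether that extra was installed; the distribution name is taken from the dependency table above:

# Minimal sketch (not part of this PR): detect whether the optional
# databricks-sqlalchemy distribution was installed, e.g. via
# `pip install "databricks-sql-connector[databricks_sqlalchemy]"`.
from importlib.metadata import PackageNotFoundError, version


def sqlalchemy_plugin_installed() -> bool:
    try:
        # version() raises PackageNotFoundError when the extra was not requested.
        return bool(version("databricks-sqlalchemy"))
    except PackageNotFoundError:
        return False


if __name__ == "__main__":
    print("databricks-sqlalchemy installed:", sqlalchemy_plugin_installed())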

26 changes: 7 additions & 19 deletions pyproject.toml → databricks_sql_connector_core/pyproject.toml
@@ -1,36 +1,26 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.3.0"
description = "Databricks SQL Connector for Python"
name = "databricks-sql-connector-core"
version = "1.0.0"
description = "Databricks SQL Connector core for Python"
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
license = "Apache-2.0"
readme = "README.md"
packages = [{ include = "databricks", from = "src" }]
include = ["CHANGELOG.md"]

[tool.poetry.dependencies]
python = "^3.8.0"
thrift = ">=0.16.0,<0.21.0"
pandas = [
{ version = ">=1.2.5,<2.3.0", python = ">=3.8" }
{ version = ">=1.2.5,<2.2.0", python = ">=3.8" }
]
pyarrow = ">=14.0.1,<17"

lz4 = "^4.0.2"
requests = "^2.18.1"
oauthlib = "^3.1.0"
numpy = [
{ version = "^1.16.6", python = ">=3.8,<3.11" },
{ version = "^1.23.4", python = ">=3.11" },
]
sqlalchemy = { version = ">=2.0.21", optional = true }
openpyxl = "^3.0.10"
alembic = { version = "^1.0.11", optional = true }
urllib3 = ">=1.26"
pyarrow = {version = ">=14.0.1,<17", optional = true}

[tool.poetry.extras]
sqlalchemy = ["sqlalchemy"]
alembic = ["sqlalchemy", "alembic"]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
@@ -43,8 +33,6 @@ pytest-dotenv = "^0.5.2"
"Homepage" = "https://github.com/databricks/databricks-sql-python"
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

[tool.poetry.plugins."sqlalchemy.dialects"]
"databricks" = "databricks.sqlalchemy:DatabricksDialect"

[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -62,5 +50,5 @@ markers = {"reviewed" = "Test case has been reviewed by Databricks"}
minversion = "6.0"
log_cli = "false"
log_cli_level = "INFO"
testpaths = ["tests", "src/databricks/sqlalchemy/test_local"]
testpaths = ["tests", "databricks_sql_connector_core/tests"]
env_files = ["test.env"]
@@ -1,7 +1,6 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
import pyarrow
import requests
import json
import os
@@ -43,6 +42,10 @@
    TSparkParameter,
)

try:
    import pyarrow
except ImportError:
    pyarrow = None

logger = logging.getLogger(__name__)
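This guarded import is the pattern the PR applies across the codebase: bind pyarrow to None when the package is absent, quote the pyarrow.Table annotations so they are never evaluated at import time, and raise only when an Arrow-dependent code path actually runs. A self-contained sketch of that pattern, assuming the pyarrow extra name from the core pyproject.toml; the function body is illustrative, not the connector's implementation:

from typing import TYPE_CHECKING

try:
    import pyarrow
except ImportError:
    pyarrow = None

if TYPE_CHECKING:
    # Give type checkers the real module even when it is absent at runtime.
    import pyarrow


def fetchall_arrow() -> "pyarrow.Table":
    # Gate the Arrow-only code path at call time, not at import time.
    if pyarrow is None:
        raise ImportError(
            "pyarrow is required for Arrow result sets; "
            "install databricks-sql-connector-core[pyarrow]"
        )
    return pyarrow.table({"value": [1, 2, 3]})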

@@ -977,14 +980,14 @@ def fetchmany(self, size: int) -> List[Row]:
        else:
            raise Error("There is no active result set")

    def fetchall_arrow(self) -> pyarrow.Table:
    def fetchall_arrow(self) -> "pyarrow.Table":
        self._check_not_closed()
        if self.active_result_set:
            return self.active_result_set.fetchall_arrow()
        else:
            raise Error("There is no active result set")

    def fetchmany_arrow(self, size) -> pyarrow.Table:
    def fetchmany_arrow(self, size) -> "pyarrow.Table":
        self._check_not_closed()
        if self.active_result_set:
            return self.active_result_set.fetchmany_arrow(size)
@@ -1171,7 +1174,7 @@ def _convert_arrow_table(self, table):
    def rownumber(self):
        return self._next_row_index

    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
        """
        Fetch the next set of rows of a query result, returning a PyArrow table.

@@ -1196,7 +1199,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

        return results

    def fetchall_arrow(self) -> pyarrow.Table:
    def fetchall_arrow(self) -> "pyarrow.Table":
        """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
        results = self.results.remaining_rows()
        self._next_row_index += results.num_rows
File renamed without changes.
@@ -8,7 +8,6 @@
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

import pyarrow
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
@@ -37,6 +36,11 @@
    convert_column_based_set_to_arrow_table,
)

try:
    import pyarrow
except ImportError:
    pyarrow = None

logger = logging.getLogger(__name__)

unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -652,6 +656,12 @@ def _get_metadata_resp(self, op_handle):

    @staticmethod
    def _hive_schema_to_arrow_schema(t_table_schema):

        if pyarrow is None:
            raise ImportError(
                "pyarrow is required to convert Hive schema to Arrow schema"
            )

        def map_type(t_type_entry):
            if t_type_entry.primitiveEntry:
                return {
@@ -858,7 +868,7 @@ def execute_command(
            getDirectResults=ttypes.TSparkGetDirectResults(
                maxRows=max_rows, maxBytes=max_bytes
            ),
            canReadArrowResult=True,
            canReadArrowResult=True if pyarrow else False,
            canDecompressLZ4Result=lz4_compression,
            canDownloadResult=use_cloud_fetch,
            confOverlay={
@@ -12,7 +12,6 @@
from ssl import SSLContext

import lz4.frame
import pyarrow

from databricks.sql import OperationalError, exc
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -28,16 +27,21 @@

import logging

try:
    import pyarrow
except ImportError:
    pyarrow = None

logger = logging.getLogger(__name__)


class ResultSetQueue(ABC):
    @abstractmethod
    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
    def next_n_rows(self, num_rows: int):
        pass

    @abstractmethod
    def remaining_rows(self) -> pyarrow.Table:
    def remaining_rows(self):
        pass


@@ -100,7 +104,7 @@ def build_queue(
class ArrowQueue(ResultSetQueue):
    def __init__(
        self,
        arrow_table: pyarrow.Table,
        arrow_table: "pyarrow.Table",
        n_valid_rows: int,
        start_row_index: int = 0,
    ):
@@ -115,7 +119,7 @@ def __init__(
        self.arrow_table = arrow_table
        self.n_valid_rows = n_valid_rows

    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
        """Get up to the next n rows of the Arrow dataframe"""
        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
        # Note that the table.slice API is not the same as Python's slice
@@ -124,7 +128,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
        self.cur_row_index += slice.num_rows
        return slice

    def remaining_rows(self) -> pyarrow.Table:
    def remaining_rows(self) -> "pyarrow.Table":
        slice = self.arrow_table.slice(
            self.cur_row_index, self.n_valid_rows - self.cur_row_index
        )
@@ -184,7 +188,7 @@ def __init__(
        self.table = self._create_next_table()
        self.table_row_index = 0

    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
        """
        Get up to the next n rows of the cloud fetch Arrow dataframes.

@@ -216,7 +220,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
        logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
        return results

    def remaining_rows(self) -> pyarrow.Table:
    def remaining_rows(self) -> "pyarrow.Table":
        """
        Get all remaining rows of the cloud fetch Arrow dataframes.

@@ -237,7 +241,7 @@ def remaining_rows(self) -> pyarrow.Table:
        self.table_row_index = 0
        return results

    def _create_next_table(self) -> Union[pyarrow.Table, None]:
    def _create_next_table(self) -> Union["pyarrow.Table", None]:
        logger.debug(
            "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                self.start_row_index
@@ -276,7 +280,7 @@

        return arrow_table

    def _create_empty_table(self) -> pyarrow.Table:
    def _create_empty_table(self) -> "pyarrow.Table":
        # Create a 0-row table with just the schema bytes
        return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
@@ -515,7 +519,7 @@ def transform_paramstyle(
    return output


def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
    arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
    return convert_decimals_in_arrow_table(arrow_table, description)

@@ -542,7 +546,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
    return arrow_table, n_rows


def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
    for i, col in enumerate(table.itercolumns()):
        if description[i][1] == "decimal":
            decimal_col = col.to_pandas().apply(
@@ -0,0 +1,6 @@
try:
    from databricks_sqlalchemy import *
except ImportError:
    import warnings

    warnings.warn("Install the databricks-sqlalchemy plugin before using this module")
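Hypothetical usage of the shim above, assuming it is installed as databricks/sqlalchemy/__init__.py (the file path is not shown in this diff): the import always succeeds, but when the plugin is missing the namespace stays empty and a UserWarning is emitted.

# Hypothetical check; module path and behavior are assumptions based on the shim above.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import databricks.sqlalchemy  # warns only on first import when the plugin is absent

if caught:
    print(caught[0].message)  # "Install the databricks-sqlalchemy plugin ..."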
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,11 +1,20 @@
from decimal import Decimal

import pyarrow
import pytest

try:
    import pyarrow
except ImportError:
    pyarrow = None

class DecimalTestsMixin:
    decimal_and_expected_results = [
from tests.e2e.common.predicates import pysql_supports_arrow

def decimal_and_expected_results():

    if pyarrow is None:
        return []

    return [
        ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
        ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
        ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
@@ -17,7 +26,12 @@ class DecimalTestsMixin:
        ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
    ]

    multi_decimals_and_expected_results = [
def multi_decimals_and_expected_results():

    if pyarrow is None:
        return []

    return [
        (
            ["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
            [Decimal("1.00"), Decimal("100.001"), None],
@@ -30,7 +44,9 @@ class DecimalTestsMixin:
        ),
    ]

    @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results)
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class DecimalTestsMixin:
    @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results())
    def test_decimals(self, decimal, expected_value, expected_type):
        with self.cursor({}) as cursor:
            query = "SELECT CAST ({})".format(decimal)
@@ -39,9 +55,7 @@ def test_decimals(self, decimal, expected_value, expected_type):
            assert table.field(0).type == expected_type
            assert table.to_pydict().popitem()[1][0] == expected_value

    @pytest.mark.parametrize(
        "decimals, expected_values, expected_type", multi_decimals_and_expected_results
    )
    @pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results())
    def test_multi_decimals(self, decimals, expected_values, expected_type):
        with self.cursor({}) as cursor:
            union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
@@ -1,6 +1,10 @@
import logging
import math
import time
from unittest import skipUnless

import pytest
from tests.e2e.common.predicates import pysql_supports_arrow

log = logging.getLogger(__name__)

@@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
            + "assuming 10K fetch size."
        )

    @skipUnless(pysql_supports_arrow(), "Without pyarrow lz4 compression is not supported")
    def test_query_with_large_wide_result_set(self):
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8192  # B
@@ -8,9 +8,13 @@


def pysql_supports_arrow():
    """Import databricks.sql and test whether Cursor has fetchall_arrow."""
    from databricks.sql.client import Cursor
    return hasattr(Cursor, 'fetchall_arrow')
    """Check whether the pyarrow library is installed."""
    try:
        import pyarrow

        return True
    except ImportError:
        return False
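An alternative predicate, not what this PR ships, could use importlib.util.find_spec to test for pyarrow without paying the cost of actually importing it:

# Alternative sketch: probe for the module without importing it.
from importlib.util import find_spec


def pysql_supports_arrow() -> bool:
    """Check whether pyarrow is installed, without importing it."""
    return find_spec("pyarrow") is not None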


def pysql_has_version(compare, version):