SQLAlchemy 2: add type compilation for uppercase types (databricks#240)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
Jesse authored Oct 2, 2023
1 parent 74f4126 commit efc0337
Showing 2 changed files with 110 additions and 77 deletions.
128 changes: 78 additions & 50 deletions src/databricks/sqlalchemy/test_local/test_types.py
@@ -1,30 +1,7 @@
 import enum
 
 import pytest
-from sqlalchemy.types import (
-    BigInteger,
-    Boolean,
-    Date,
-    DateTime,
-    Double,
-    Enum,
-    Float,
-    Integer,
-    Interval,
-    LargeBinary,
-    MatchType,
-    Numeric,
-    PickleType,
-    SchemaType,
-    SmallInteger,
-    String,
-    Text,
-    Time,
-    TypeEngine,
-    Unicode,
-    UnicodeText,
-    Uuid,
-)
+import sqlalchemy
 
 from databricks.sqlalchemy import DatabricksDialect

@@ -55,43 +32,49 @@ class DatabricksDataType(enum.Enum):
 # Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
 # Note: I wish I could define this within the TestCamelCaseTypesCompilation class, but pytest doesn't like that.
 camel_case_type_map = {
-    BigInteger: DatabricksDataType.BIGINT,
-    LargeBinary: DatabricksDataType.BINARY,
-    Boolean: DatabricksDataType.BOOLEAN,
-    Date: DatabricksDataType.DATE,
-    DateTime: DatabricksDataType.TIMESTAMP,
-    Double: DatabricksDataType.DOUBLE,
-    Enum: DatabricksDataType.STRING,
-    Float: DatabricksDataType.FLOAT,
-    Integer: DatabricksDataType.INT,
-    Interval: DatabricksDataType.TIMESTAMP,
-    Numeric: DatabricksDataType.DECIMAL,
-    PickleType: DatabricksDataType.BINARY,
-    SmallInteger: DatabricksDataType.SMALLINT,
-    String: DatabricksDataType.STRING,
-    Text: DatabricksDataType.STRING,
-    Time: DatabricksDataType.STRING,
-    Unicode: DatabricksDataType.STRING,
-    UnicodeText: DatabricksDataType.STRING,
-    Uuid: DatabricksDataType.STRING,
+    sqlalchemy.types.BigInteger: DatabricksDataType.BIGINT,
+    sqlalchemy.types.LargeBinary: DatabricksDataType.BINARY,
+    sqlalchemy.types.Boolean: DatabricksDataType.BOOLEAN,
+    sqlalchemy.types.Date: DatabricksDataType.DATE,
+    sqlalchemy.types.DateTime: DatabricksDataType.TIMESTAMP,
+    sqlalchemy.types.Double: DatabricksDataType.DOUBLE,
+    sqlalchemy.types.Enum: DatabricksDataType.STRING,
+    sqlalchemy.types.Float: DatabricksDataType.FLOAT,
+    sqlalchemy.types.Integer: DatabricksDataType.INT,
+    sqlalchemy.types.Interval: DatabricksDataType.TIMESTAMP,
+    sqlalchemy.types.Numeric: DatabricksDataType.DECIMAL,
+    sqlalchemy.types.PickleType: DatabricksDataType.BINARY,
+    sqlalchemy.types.SmallInteger: DatabricksDataType.SMALLINT,
+    sqlalchemy.types.String: DatabricksDataType.STRING,
+    sqlalchemy.types.Text: DatabricksDataType.STRING,
+    sqlalchemy.types.Time: DatabricksDataType.STRING,
+    sqlalchemy.types.Unicode: DatabricksDataType.STRING,
+    sqlalchemy.types.UnicodeText: DatabricksDataType.STRING,
+    sqlalchemy.types.Uuid: DatabricksDataType.STRING,
 }
 
-# Convert the dictionary into a list of tuples for use in pytest.mark.parametrize
-_as_tuple_list = [(key, value) for key, value in camel_case_type_map.items()]
+
+def dict_as_tuple_list(d: dict):
+    """Return a list of [(key, value), ...] from a dictionary."""
+    return [(key, value) for key, value in d.items()]
 
 
 class CompilationTestBase:
     dialect = DatabricksDialect()
 
-    def _assert_compiled_value(self, type_: TypeEngine, expected: DatabricksDataType):
+    def _assert_compiled_value(
+        self, type_: sqlalchemy.types.TypeEngine, expected: DatabricksDataType
+    ):
         """Assert that when type_ is compiled for the databricks dialect, it renders the DatabricksDataType name.
 
         This method initialises the type_ with no arguments.
         """
         compiled_result = type_().compile(dialect=self.dialect)  # type: ignore
         assert compiled_result == expected.name
 
-    def _assert_compiled_value_explicit(self, type_: TypeEngine, expected: str):
+    def _assert_compiled_value_explicit(
+        self, type_: sqlalchemy.types.TypeEngine, expected: str
+    ):
         """Assert that when type_ is compiled for the databricks dialect, it renders the expected string.
 
         This method expects an initialised type_ so that we can test how a TypeEngine created with arguments
@@ -117,12 +100,57 @@ class TestCamelCaseTypesCompilation(CompilationTestBase):
     [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types
     """
 
-    @pytest.mark.parametrize("type_, expected", _as_tuple_list)
+    @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(camel_case_type_map))
     def test_bare_camel_case_types_compile(self, type_, expected):
         self._assert_compiled_value(type_, expected)
 
     def test_numeric_renders_as_decimal_with_precision(self):
-        self._assert_compiled_value_explicit(Numeric(10), "DECIMAL(10)")
+        self._assert_compiled_value_explicit(
+            sqlalchemy.types.Numeric(10), "DECIMAL(10)"
+        )
 
     def test_numeric_renders_as_decimal_with_precision_and_scale(self):
-        self._assert_compiled_value_explicit(Numeric(10, 2), "DECIMAL(10, 2)")
+        return self._assert_compiled_value_explicit(
+            sqlalchemy.types.Numeric(10, 2), "DECIMAL(10, 2)"
+        )
+
+
+uppercase_type_map = {
+    sqlalchemy.types.ARRAY: DatabricksDataType.ARRAY,
+    sqlalchemy.types.BIGINT: DatabricksDataType.BIGINT,
+    sqlalchemy.types.BINARY: DatabricksDataType.BINARY,
+    sqlalchemy.types.BOOLEAN: DatabricksDataType.BOOLEAN,
+    sqlalchemy.types.DATE: DatabricksDataType.DATE,
+    sqlalchemy.types.DECIMAL: DatabricksDataType.DECIMAL,
+    sqlalchemy.types.DOUBLE: DatabricksDataType.DOUBLE,
+    sqlalchemy.types.FLOAT: DatabricksDataType.FLOAT,
+    sqlalchemy.types.INT: DatabricksDataType.INT,
+    sqlalchemy.types.SMALLINT: DatabricksDataType.SMALLINT,
+    sqlalchemy.types.TIMESTAMP: DatabricksDataType.TIMESTAMP,
+}
+
+
+class TestUppercaseTypesCompilation(CompilationTestBase):
+    """Per the sqlalchemy documentation[^1], uppercase types are considered to be specific to some
+    database backends. These tests verify that the types compile into valid Databricks SQL type strings.
+
+    [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#backend-specific-uppercase-datatypes
+    """
+
+    @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(uppercase_type_map))
+    def test_bare_uppercase_types_compile(self, type_, expected):
+        if isinstance(type_, type(sqlalchemy.types.ARRAY)):
+            # ARRAY cannot be initialised without passing an item definition so we test separately
+            # I preserve it in the uppercase_type_map for clarity
+            return True
+        return self._assert_compiled_value(type_, expected)
+
+    def test_array_string_renders_as_array_of_string(self):
+        """SQLAlchemy's ARRAY type requires an item definition. And their docs indicate that they've only tested
+        it with Postgres since that's the only first-class dialect with support for ARRAY.
+
+        https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
+        """
+        return self._assert_compiled_value_explicit(
+            sqlalchemy.types.ARRAY(sqlalchemy.types.String), "ARRAY<STRING>"
+        )
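
For orientation, here is a minimal sketch of the behaviour these tests pin down, assuming databricks-sqlalchemy is installed so that DatabricksDialect imports as shown in the diff; the values in the comments mirror the test assertions rather than verified output:

import sqlalchemy
from databricks.sqlalchemy import DatabricksDialect

dialect = DatabricksDialect()

# CamelCase generic types compile to their Databricks SQL equivalents
print(sqlalchemy.types.DateTime().compile(dialect=dialect))      # TIMESTAMP
print(sqlalchemy.types.Numeric(10, 2).compile(dialect=dialect))  # DECIMAL(10, 2)

# Uppercase (backend-specific) types, newly covered by this commit
print(sqlalchemy.types.TIMESTAMP().compile(dialect=dialect))     # TIMESTAMP
print(sqlalchemy.types.ARRAY(sqlalchemy.types.String).compile(dialect=dialect))  # ARRAY<STRING>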
59 changes: 32 additions & 27 deletions src/databricks/sqlalchemy/types.py
@@ -1,27 +1,14 @@
+import sqlalchemy
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.compiler import GenericTypeCompiler
-from sqlalchemy.types import (
-    DateTime,
-    Enum,
-    Integer,
-    LargeBinary,
-    Numeric,
-    String,
-    Text,
-    Time,
-    Unicode,
-    UnicodeText,
-    Uuid,
-)
 
 
-@compiles(Enum, "databricks")
-@compiles(String, "databricks")
-@compiles(Text, "databricks")
-@compiles(Time, "databricks")
-@compiles(Unicode, "databricks")
-@compiles(UnicodeText, "databricks")
-@compiles(Uuid, "databricks")
+@compiles(sqlalchemy.types.Enum, "databricks")
+@compiles(sqlalchemy.types.String, "databricks")
+@compiles(sqlalchemy.types.Text, "databricks")
+@compiles(sqlalchemy.types.Time, "databricks")
+@compiles(sqlalchemy.types.Unicode, "databricks")
+@compiles(sqlalchemy.types.UnicodeText, "databricks")
+@compiles(sqlalchemy.types.Uuid, "databricks")
 def compile_string_databricks(type_, compiler, **kw):
     """
     We override the default compilation for Enum(), String(), Text(), and Time() because SQLAlchemy
@@ -40,23 +27,23 @@ def compile_string_databricks(type_, compiler, **kw):
     return "STRING"
 
 
-@compiles(Integer, "databricks")
+@compiles(sqlalchemy.types.Integer, "databricks")
 def compile_integer_databricks(type_, compiler, **kw):
     """
     We need to override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER"
     """
     return "INT"
 
 
-@compiles(LargeBinary, "databricks")
+@compiles(sqlalchemy.types.LargeBinary, "databricks")
 def compile_binary_databricks(type_, compiler, **kw):
     """
     We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB"
     """
     return "BINARY"
 
 
-@compiles(Numeric, "databricks")
+@compiles(sqlalchemy.types.Numeric, "databricks")
 def compile_numeric_databricks(type_, compiler, **kw):
     """
     We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC"
@@ -67,9 +54,27 @@ def compile_numeric_databricks(type_, compiler, **kw):
     return compiler.visit_DECIMAL(type_, **kw)
 
 
-@compiles(DateTime, "databricks")
+@compiles(sqlalchemy.types.DateTime, "databricks")
 def compile_datetime_databricks(type_, compiler, **kw):
     """
     We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME"
     """
     return "TIMESTAMP"
+
+
+@compiles(sqlalchemy.types.ARRAY, "databricks")
+def compile_array_databricks(type_, compiler, **kw):
+    """
+    SQLAlchemy's default ARRAY can't compile as it's only implemented for Postgresql.
+    The Postgres implementation works for Databricks SQL, so we duplicate that here.
+
+    :type_:
+        This is an instance of sqlalchemy.types.ARRAY which always includes an item_type attribute
+        which is itself an instance of TypeEngine
+
+    https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
+    """
+
+    inner = compiler.process(type_.item_type, **kw)
+
+    return f"ARRAY<{inner}>"
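
Because compile_array_databricks delegates to compiler.process() for the element type, the item type is compiled under the same dialect-specific rules, including the Numeric-to-DECIMAL override above. A small illustration of that interplay (not part of the commit; assumes the same imports as the sketch after the test file):

import sqlalchemy
from databricks.sqlalchemy import DatabricksDialect

# The item_type is routed back through the dialect's type compiler,
# so Numeric renders as DECIMAL inside the ARRAY<> wrapper.
array_of_decimal = sqlalchemy.types.ARRAY(sqlalchemy.types.Numeric(10, 2))
print(array_of_decimal.compile(dialect=DatabricksDialect()))  # expected: ARRAY<DECIMAL(10, 2)>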
