Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add bigframes.ml.compose.SQLScalarColumnTransformer to create custom SQL-based transformations #955

Merged
merged 52 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
1b244ed
Add support for custom transformers (not ML.) in ColumnTransformer.
Sep 4, 2024
72510d4
allow numbers in Custom-Transformer-IDs.
Sep 4, 2024
1242685
Merge branch 'googleapis:main' into main
ferenc-hechler Sep 4, 2024
43ba050
comment was moved to the end of the sql.
Sep 4, 2024
d6e7bb2
Do not offer the feedback link for missing custom transformers.
Sep 4, 2024
24e39ec
cleanup typing hints.
Sep 5, 2024
1fd25de
Add unit tests for CustomTransformer.
ferenc-hechler Sep 5, 2024
6dd9038
added unit tests for _extract_output_names() and _compile_to_sql().
ferenc-hechler Sep 5, 2024
ab9ab35
Merge branch 'googleapis:main' into main
ferenc-hechler Sep 5, 2024
9d8d8c4
run black and flake8 linter.
ferenc-hechler Sep 5, 2024
2665bc7
fixed wrong @classmethod annotation.
ferenc-hechler Sep 5, 2024
f3a6317
Merge branch 'main' into main
ferenc-hechler Sep 9, 2024
257f819
Merge branch 'main' into main
ferenc-hechler Sep 16, 2024
10672f2
Merge branch 'main' into main
ferenc-hechler Sep 19, 2024
ffc6824
on the way to SQLScalarColumnTransformer
ferenc-hechler Sep 20, 2024
a53b51c
remove pytest.main call.
ferenc-hechler Sep 20, 2024
41ca2ff
remove CustomTransformer class and implementations.
Sep 20, 2024
9af6c6b
Merge branch 'main' into main
ferenc-hechler Sep 20, 2024
fa094f8
Merge branch 'main' into main
ferenc-hechler Sep 20, 2024
2f4f459
Merge branch 'main' into main
tswast Sep 22, 2024
5cd9672
fix typing.
ferenc-hechler Sep 22, 2024
980cd1a
fix typing.
ferenc-hechler Sep 22, 2024
e9b3ab0
fixed mock typing.
ferenc-hechler Sep 22, 2024
d858035
replace _NameClass.
Sep 23, 2024
b21b2f4
black formating.
Sep 23, 2024
acb74ab
add traget_column as input_column with a "?" prefix
Sep 23, 2024
78f964a
reformatted with black version 22.3.0.
Sep 23, 2024
7dc544a
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 23, 2024
774fef6
Merge branch 'main' into main
tswast Sep 23, 2024
e9a9410
remove eclipse project files
ferenc-hechler Sep 23, 2024
744f272
SQLScalarColumnTransformer needs not to be inherited from
ferenc-hechler Sep 23, 2024
cc9840c
remove filter for "ML." sqls in _extract_output_names() of
ferenc-hechler Sep 23, 2024
bb142f5
introduced type hint SingleColTransformer
ferenc-hechler Sep 23, 2024
4ecdec6
make sql and target_column private in SQLScalarColumnTransformer
ferenc-hechler Sep 23, 2024
a36845b
Add documentation for SQLScalarColumnTransformer.
ferenc-hechler Sep 23, 2024
1890ea2
Merge branch 'googleapis:main' into main
ferenc-hechler Sep 24, 2024
517edf6
add first system test for SQLScalarColumnTransformer.
Sep 24, 2024
29fed20
SQLScalarColumnTransformer system tests for fit-transform and save-load
Sep 24, 2024
1b67957
make SQLScalarColumnTransformer comparable (equals) for comparing sets
Sep 24, 2024
9b12eac
implement hash and eq (copied from BaseTransformer)
Sep 24, 2024
48ecd38
undo accidentally checked in files
ferenc-hechler Sep 24, 2024
89298f9
remove eclipse settings accidentally checked in.
Sep 24, 2024
7078b6a
Merge branch 'main' into main
ferenc-hechler Sep 24, 2024
0b27eac
fix docs.
ferenc-hechler Sep 24, 2024
4ec3b0a
Update bigframes/ml/compose.py
tswast Sep 24, 2024
b81627a
Merge branch 'main' into main
tswast Sep 24, 2024
f48e27b
Update bigframes/ml/compose.py
tswast Sep 24, 2024
5ca8a81
add support for flexible column names.
Sep 25, 2024
a4ed2fd
remove main.
Sep 25, 2024
c35086d
add system test for output column with flexible column name
Sep 25, 2024
cc1b0b1
system tests: add new flexible output column to check-df-schema.
Sep 25, 2024
d82793c
Apply suggestions from code review
tswast Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 155 additions & 6 deletions bigframes/ml/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import re
import types
import typing
from typing import cast, Iterable, List, Optional, Set, Tuple, Union
from typing import cast, Iterable, List, Optional, Set, Tuple, Union, Dict, Type
import abc
import json

from bigframes_vendored import constants
import bigframes_vendored.sklearn.compose._column_transformer
Expand All @@ -46,6 +48,114 @@
)


# Matches a transform SQL expression that ends with a custom-transformer marker
# comment of the form ``<sql>/*CT.<id>(<config>)*/``.
#   sql    - everything before the marker comment
#   id     - transformer identifier: letters first, then letters/digits
#   config - text between the parentheses; must not contain "*"
# The dot after "CT" is written as ``[.]`` so that only a literal "." is
# accepted (an unescaped "." would match any character, e.g. "/*CTXIDENT()*/").
CUSTOM_TRANSFORMER_SQL_RX = re.compile(
    r"^(?P<sql>.*)/[*]CT[.](?P<id>[A-Z]+[A-Z0-9]*)[(](?P<config>[^*]*)[)][*]/$",
    re.IGNORECASE,
)


class CustomTransformer(base.BaseTransformer):
    """Base class for custom, SQL-based transformers used inside a ColumnTransformer.

    A subclass sets ``_CTID`` to a unique identifier, implements
    ``custom_compile_to_sql`` and ``custom_parse_from_sql``, and is made
    discoverable with ``CustomTransformer.register``. The SQL compiled for each
    column carries a trailing ``/*CT.<id>(<json-config>)*/`` comment; when SQL
    is parsed back, that comment identifies the transformer class and carries
    its persisted configuration.
    """

    # Identifier written into the SQL marker comment; every subclass must set it.
    _CTID: Optional[str] = None
    # Registry of transformer classes keyed by their _CTID. Intentionally a
    # single dict shared by CustomTransformer and all of its subclasses.
    _custom_transformer_classes: Dict[str, Type[base.BaseTransformer]] = {}

    @classmethod
    def register(cls, transformer_cls: Type[base.BaseTransformer]):
        """Add ``transformer_cls`` to the registry under its ``_CTID``.

        Raises:
            ValueError: if ``transformer_cls`` has no (or an empty) ``_CTID``.
        """
        ctid = getattr(transformer_cls, "_CTID", None)
        # Explicit check instead of ``assert``: asserts are stripped under -O.
        if not ctid:
            raise ValueError(
                f"{transformer_cls.__name__} must define a non-empty _CTID to be registered"
            )
        cls._custom_transformer_classes[ctid] = transformer_cls

    @classmethod
    def find_matching_transformer(
        cls, transform_sql: str
    ) -> Optional[Type[base.BaseTransformer]]:
        """Return the first registered class that understands ``transform_sql``, else None."""
        for transform_cls in cls._custom_transformer_classes.values():
            if transform_cls.understands(transform_sql):  # type: ignore
                return transform_cls
        return None

    @classmethod
    def understands(cls, transform_sql: str) -> bool:
        """Return True if ``transform_sql`` carries this class's marker comment.

        May be overridden to implement more advanced matching, possibly
        without relying on comments in the SQL. The pattern itself is
        case-insensitive, but the extracted id is compared to ``_CTID``
        exactly.
        """
        m = CUSTOM_TRANSFORMER_SQL_RX.match(transform_sql)
        return m is not None and m.group("id").strip() == cls._CTID

    def __init__(self):
        super().__init__()

    def _compile_to_sql(
        self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
    ) -> List[str]:
        """Compile one ``<sql> <marker comment> AS <target>`` expression per column.

        Args:
            X: the input DataFrame; all of its columns are used when
                ``columns`` is None.
            columns: the columns to transform.

        Returns:
            One SQL select expression per column.
        """
        if columns is None:
            columns = X.columns
        return [
            f"{self.custom_compile_to_sql(X, column)} {self._get_sql_comment(column)} AS {self.get_target_column_name(column)}"
            for column in columns
        ]

    def get_target_column_name(self, column: str) -> str:
        """Return the output column name: the lower-cased ``_CTID``, "_", and ``column``."""
        return f"{cast(str, self._CTID).lower()}_{column}"

    # Declared as an instance method: _compile_to_sql() calls it on ``self``.
    # NOTE(review): @abc.abstractmethod is only enforced at instantiation when
    # the class hierarchy uses abc.ABCMeta - confirm for base.BaseTransformer.
    # The NotImplementedError below covers the case where it is not enforced.
    @abc.abstractmethod
    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return the SQL expression (without comment and alias) transforming ``column``."""
        raise NotImplementedError("custom_compile_to_sql() must be implemented")

    def get_persistent_config(self, column: str) -> Optional[Union[Dict, List]]:
        """Return the structure to be persisted in the comment of the SQL.

        The default implementation persists nothing and returns None.
        """
        return None

    def _get_pc_as_args(self, column: str) -> str:
        """Return the persistent config of ``column`` as JSON, or "" if it is empty.

        The resulting text is embedded in the marker comment, and the config
        group of CUSTOM_TRANSFORMER_SQL_RX excludes "*", so a config whose JSON
        contains "*" cannot be parsed back.
        """
        config = self.get_persistent_config(column)
        if not config:
            return ""
        return json.dumps(config)

    def _get_sql_comment(self, column: str) -> str:
        """Return the ``/*CT.<id>(<config>)*/`` marker comment for ``column``."""
        args = self._get_pc_as_args(column)
        return f"/*CT.{self._CTID}({args})*/"

    @classmethod
    def _parse_from_sql(cls, transform_sql: str) -> Tuple[base.BaseTransformer, str]:
        """Rebuild a transformer from SQL produced by ``_compile_to_sql``.

        Returns:
            The transformer instance and the name of its input column, as
            returned by ``custom_parse_from_sql``.

        Raises:
            ValueError: if ``transform_sql`` has no marker comment or the
                marker belongs to a different transformer id.
        """
        m = CUSTOM_TRANSFORMER_SQL_RX.match(transform_sql)
        # Reject a missing match as well; otherwise m.group() below would fail
        # with an AttributeError on None.
        if m is None or m.group("id").strip() != cls._CTID:
            raise ValueError("understand() does not match _parse_from_sql!")
        args = m.group("config").strip()
        config = json.loads(args) if args != "" else None
        sql = m.group("sql").strip()
        return cls.custom_parse_from_sql(config, sql)

    @classmethod
    @abc.abstractmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> Tuple[base.BaseTransformer, str]:
        """Return the transformer instance and the input column name.

        Args:
            config: the persisted configuration decoded from the marker
                comment, or None when the comment carried none.
            sql: the SQL expression with the marker comment removed.
        """
        raise NotImplementedError("custom_parse_from_sql() must be implemented")

    def _keys(self):
        return ()

    # Custom transformers are meant to be used inside a ColumnTransformer only,
    # so fit() and transform() are deliberately not supported on their own.
    def fit(self, y: Union[bpd.DataFrame, bpd.Series]) -> base.BaseTransformer:
        """Not supported for a standalone custom transformer."""
        raise NotImplementedError("Unsupported")

    def transform(self, y: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        """Not supported for a standalone custom transformer."""
        raise NotImplementedError("Unsupported")


@log_adapter.class_logger
class ColumnTransformer(
base.Transformer,
Expand Down Expand Up @@ -130,10 +240,7 @@ def camel_to_snake(name):
if "transformSql" not in transform_col_dict:
continue
transform_sql: str = transform_col_dict["transformSql"]
if not transform_sql.startswith("ML."):
continue

output_names.append(transform_col_dict["name"])
found_transformer = False
for prefix in _BQML_TRANSFROM_TYPE_MAPPING:
if transform_sql.startswith(prefix):
Expand All @@ -147,10 +254,30 @@ def camel_to_snake(name):

found_transformer = True
break

if not found_transformer:
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
transformer_cls = CustomTransformer.find_matching_transformer(
transform_sql
)
if transformer_cls:
transformers_set.add(
(
camel_to_snake(transformer_cls.__name__),
*transformer_cls._parse_from_sql(transform_sql), # type: ignore
)
)
found_transformer = True

if not found_transformer:
if not transform_sql.startswith("ML.") and "/*CT." not in transform_sql:
continue # ignore other patterns, only report unhandled known patterns
if transform_sql.startswith("ML."):
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
)
raise ValueError("Missing custom transformer")
ferenc-hechler marked this conversation as resolved.
Show resolved Hide resolved

output_names.append(transform_col_dict["name"])

transformer = cls(transformers=list(transformers_set))
transformer._output_names = output_names
Expand All @@ -167,6 +294,8 @@ def _merge(

assert len(transformers) > 0
_, transformer_0, column_0 = transformers[0]
if isinstance(transformer_0, CustomTransformer):
return self # CustomTransformers only work inside ColumnTransformer
feature_columns_sorted = sorted(
[
cast(str, feature_column.name)
Expand Down Expand Up @@ -234,6 +363,26 @@ def fit(
self._extract_output_names()
return self

# Override the implementation in BaseTransformer, which only supports the "ML." transformers.
def _extract_output_names(self):
    """Collect the names of the transformed output columns into ``self._output_names``.

    A transform column is kept when it has a ``transformSql`` entry that
    either starts with "ML." or is recognized by a registered custom
    transformer; all other columns are passed over.
    """
    assert self._bqml_model is not None

    def _is_transformed(column_info: dict) -> bool:
        # Columns without transform SQL are passed through untransformed.
        if "transformSql" not in column_info:
            return False
        transform_sql: str = column_info["transformSql"]
        if transform_sql.startswith("ML."):
            return True
        return bool(CustomTransformer.find_matching_transformer(transform_sql))

    transform_columns = self._bqml_model._model._properties["transformColumns"]
    self._output_names = [
        cast(dict, entry)["name"]
        for entry in transform_columns
        if _is_transformed(cast(dict, entry))
    ]

def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
if not self._bqml_model:
raise RuntimeError("Must be fitted before transform")
Expand Down
91 changes: 91 additions & 0 deletions tests/unit/ml/compose_custom_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import bigframes.pandas as bpd
from bigframes.ml.compose import CustomTransformer
from typing import List, Optional, Union, Dict
import re


class IdentityTransformer(CustomTransformer):
    """Test transformer whose SQL is just the input column name."""

    _CTID = "IDENT"
    # ``*`` rather than ``+`` after the first character, so that
    # single-character column names are accepted as well.
    IDENT_BQSQL_RX = re.compile(r"^(?P<colname>[a-z][a-z0-9_]*)$", flags=re.IGNORECASE)

    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return the column name itself as the SQL expression."""
        return f"{column}"

    @classmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> tuple[CustomTransformer, str]:
        """Return a new instance and the column name found in ``sql``.

        Raises:
            ValueError: if ``sql`` is not a plain column name.
        """
        m = cls.IDENT_BQSQL_RX.match(sql)
        if m is None:
            raise ValueError(f"Unexpected SQL for {cls._CTID} transformer: {sql!r}")
        return cls(), m.group("colname")


# Make the transformer discoverable when transform SQL is parsed back.
CustomTransformer.register(IdentityTransformer)


class Length1Transformer(CustomTransformer):
    """Test transformer returning the length of a column, or a default for NULL.

    The default value is not persisted in the marker comment; it is recovered
    by parsing the generated SQL itself.
    """

    _CTID = "LEN1"
    # Value used for NULL input when no default_value was given.
    _DEFAULT_VALUE_DEFAULT = -1
    LEN1_BQSQL_RX = re.compile(
        "^CASE WHEN (?P<colname>[a-z][a-z0-9_]*) IS NULL THEN (?P<defaultvalue>[-]?[0-9]+) ELSE LENGTH[(](?P=colname)[)] END$",
        flags=re.IGNORECASE,
    )

    def __init__(self, default_value: Optional[int] = None):
        # Run the base-class initializers, which the override would otherwise skip.
        super().__init__()
        self.default_value = default_value

    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return a CASE expression yielding LENGTH(column), or the default for NULL."""
        default_value = (
            self.default_value
            if self.default_value is not None
            else Length1Transformer._DEFAULT_VALUE_DEFAULT
        )
        return (
            f"CASE WHEN {column} IS NULL THEN {default_value} ELSE LENGTH({column}) END"
        )

    @classmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> tuple[CustomTransformer, str]:
        """Return an instance and column name reconstructed from ``sql``.

        Raises:
            ValueError: if ``sql`` does not have the shape produced by
                ``custom_compile_to_sql``.
        """
        m = cls.LEN1_BQSQL_RX.match(sql)
        if m is None:
            raise ValueError(f"Unexpected SQL for {cls._CTID} transformer: {sql!r}")
        col_label = m.group("colname")
        default_value = int(m.group("defaultvalue"))
        return cls(default_value), col_label


# Make the transformer discoverable when transform SQL is parsed back.
CustomTransformer.register(Length1Transformer)


class Length2Transformer(CustomTransformer):
    """Test transformer like Length1Transformer, but with a persisted config.

    The default value is stored in the marker comment through
    ``get_persistent_config`` and read back from ``config`` when parsing, so
    the SQL pattern only needs to extract the column name.
    """

    _CTID = "LEN2"
    # Value used for NULL input when no default_value was given.
    _DEFAULT_VALUE_DEFAULT = -1
    LEN2_BQSQL_RX = re.compile(
        "^CASE WHEN (?P<colname>[a-z][a-z0-9_]*) .*$", flags=re.IGNORECASE
    )

    def __init__(self, default_value: Optional[int] = None):
        # Run the base-class initializers, which the override would otherwise skip.
        super().__init__()
        self.default_value = default_value

    def get_persistent_config(self, column: str) -> Optional[Union[Dict, List]]:
        """Return the default value wrapped in a one-element list."""
        return [self.default_value]

    def custom_compile_to_sql(self, X: bpd.DataFrame, column: str) -> str:
        """Return a CASE expression yielding LENGTH(column), or the default for NULL."""
        default_value = (
            self.default_value
            if self.default_value is not None
            else Length2Transformer._DEFAULT_VALUE_DEFAULT
        )
        return (
            f"CASE WHEN {column} IS NULL THEN {default_value} ELSE LENGTH({column}) END"
        )

    @classmethod
    def custom_parse_from_sql(
        cls, config: Optional[Union[Dict, List]], sql: str
    ) -> tuple[CustomTransformer, str]:
        """Return an instance and column name from ``sql`` and the persisted config.

        Raises:
            ValueError: if ``sql`` does not start with the expected CASE
                expression.
        """
        m = cls.LEN2_BQSQL_RX.match(sql)
        if m is None:
            raise ValueError(f"Unexpected SQL for {cls._CTID} transformer: {sql!r}")
        col_label = m.group("colname")
        # The marker comment may carry no config at all; then fall back to
        # None instead of failing on config[0].
        default_value = config[0] if config else None
        return cls(default_value), col_label


# Make the transformer discoverable when transform SQL is parsed back.
CustomTransformer.register(Length2Transformer)
Loading