feat: add bigframes.ml.compose.SQLScalarColumnTransformer to create custom SQL-based transformations #955

Merged · 52 commits · Sep 25, 2024

Commits
1b244ed
Add support for custom transformers (not ML.) in ColumnTransformer.
Sep 4, 2024
72510d4
allow numbers in Custom-Transformer-IDs.
Sep 4, 2024
1242685
Merge branch 'googleapis:main' into main
ferenc-hechler Sep 4, 2024
43ba050
comment was moved to the end of the sql.
Sep 4, 2024
d6e7bb2
Do not offer the feedback link for missing custom transformers.
Sep 4, 2024
24e39ec
cleanup typing hints.
Sep 5, 2024
1fd25de
Add unit tests for CustomTransformer.
ferenc-hechler Sep 5, 2024
6dd9038
added unit tests for _extract_output_names() and _compile_to_sql().
ferenc-hechler Sep 5, 2024
ab9ab35
Merge branch 'googleapis:main' into main
ferenc-hechler Sep 5, 2024
9d8d8c4
run black and flake8 linter.
ferenc-hechler Sep 5, 2024
2665bc7
fixed wrong @classmethod annotation.
ferenc-hechler Sep 5, 2024
f3a6317
Merge branch 'main' into main
ferenc-hechler Sep 9, 2024
257f819
Merge branch 'main' into main
ferenc-hechler Sep 16, 2024
10672f2
Merge branch 'main' into main
ferenc-hechler Sep 19, 2024
ffc6824
on the way to SQLScalarColumnTransformer
ferenc-hechler Sep 20, 2024
a53b51c
remove pytest.main call.
ferenc-hechler Sep 20, 2024
41ca2ff
remove CustomTransformer class and implementations.
Sep 20, 2024
9af6c6b
Merge branch 'main' into main
ferenc-hechler Sep 20, 2024
fa094f8
Merge branch 'main' into main
ferenc-hechler Sep 20, 2024
2f4f459
Merge branch 'main' into main
tswast Sep 22, 2024
5cd9672
fix typing.
ferenc-hechler Sep 22, 2024
980cd1a
fix typing.
ferenc-hechler Sep 22, 2024
e9b3ab0
fixed mock typing.
ferenc-hechler Sep 22, 2024
d858035
replace _NameClass.
Sep 23, 2024
b21b2f4
black formatting.
Sep 23, 2024
acb74ab
add target_column as input_column with a "?" prefix
Sep 23, 2024
78f964a
reformatted with black version 22.3.0.
Sep 23, 2024
7dc544a
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 23, 2024
774fef6
Merge branch 'main' into main
tswast Sep 23, 2024
e9a9410
remove eclipse project files
ferenc-hechler Sep 23, 2024
744f272
SQLScalarColumnTransformer needs not to be inherited from
ferenc-hechler Sep 23, 2024
cc9840c
remove filter for "ML." sqls in _extract_output_names() of
ferenc-hechler Sep 23, 2024
bb142f5
introduced type hint SingleColTransformer
ferenc-hechler Sep 23, 2024
4ecdec6
make sql and target_column private in SQLScalarColumnTransformer
ferenc-hechler Sep 23, 2024
a36845b
Add documentation for SQLScalarColumnTransformer.
ferenc-hechler Sep 23, 2024
1890ea2
Merge branch 'googleapis:main' into main
ferenc-hechler Sep 24, 2024
517edf6
add first system test for SQLScalarColumnTransformer.
Sep 24, 2024
29fed20
SQLScalarColumnTransformer system tests for fit-transform and save-load
Sep 24, 2024
1b67957
make SQLScalarColumnTransformer comparable (equals) for comparing sets
Sep 24, 2024
9b12eac
implement hash and eq (copied from BaseTransformer)
Sep 24, 2024
48ecd38
undo accidentally checked in files
ferenc-hechler Sep 24, 2024
89298f9
remove eclipse settings accidentally checked in.
Sep 24, 2024
7078b6a
Merge branch 'main' into main
ferenc-hechler Sep 24, 2024
0b27eac
fix docs.
ferenc-hechler Sep 24, 2024
4ec3b0a
Update bigframes/ml/compose.py
tswast Sep 24, 2024
b81627a
Merge branch 'main' into main
tswast Sep 24, 2024
f48e27b
Update bigframes/ml/compose.py
tswast Sep 24, 2024
5ca8a81
add support for flexible column names.
Sep 25, 2024
a4ed2fd
remove main.
Sep 25, 2024
c35086d
add system test for output column with flexible column name
Sep 25, 2024
cc1b0b1
system tests: add new flexible output column to check-df-schema.
Sep 25, 2024
d82793c
Apply suggestions from code review
tswast Sep 25, 2024
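
Before the file-by-file diff, here is a minimal usage sketch of the new transformer, adapted from the docstring this PR adds to bigframes/ml/compose.py (the DataFrame contents are illustrative):

from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer
import bigframes.pandas as bpd

# Sample frame with NULLs so the custom SQL branch is exercised.
df = bpd.DataFrame({"name": ["James", None, "Mary"], "city": ["New York", "Boston", None]})

# '{0}' is the placeholder for the input column; by default the output
# column is the input name prefixed with 'transformed_'.
col_trans = ColumnTransformer([
    (
        "strlen",
        SQLScalarColumnTransformer("CASE WHEN {0} IS NULL THEN 15 ELSE LENGTH({0}) END"),
        ["name", "city"],
    ),
])
col_trans = col_trans.fit(df)
df_transformed = col_trans.transform(df)
# Yields the columns transformed_name and transformed_city.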
4 changes: 0 additions & 4 deletions bigframes/ml/base.py
@@ -198,10 +198,6 @@ def _extract_output_names(self):
# pass the columns that are not transformed
if "transformSql" not in transform_col_dict:
continue
transform_sql: str = transform_col_dict["transformSql"]
if not transform_sql.startswith("ML."):
continue

output_names.append(transform_col_dict["name"])

self._output_names = output_names
133 changes: 123 additions & 10 deletions bigframes/ml/compose.py
@@ -46,6 +46,101 @@
)


class SQLScalarColumnTransformer:
r"""
Wrapper for plain SQL code contained in a ColumnTransformer.

Create a single-column transformer in plain SQL.
This transformer can only be used inside ColumnTransformer.

When creating an instance, '{0}' can be used as a placeholder
for the column to transform:

SQLScalarColumnTransformer("{0}+1")

The default target column gets the prefix 'transformed\_'
but can also be changed when creating an instance:

SQLScalarColumnTransformer("{0}+1", "inc_{0}")

**Examples:**

>>> from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer
>>> import bigframes.pandas as bpd
<BLANKLINE>
>>> df = bpd.DataFrame({'name': ["James", None, "Mary"], 'city': ["New York", "Boston", None]})
>>> col_trans = ColumnTransformer([
... ("strlen",
... SQLScalarColumnTransformer("CASE WHEN {0} IS NULL THEN 15 ELSE LENGTH({0}) END"),
... ['name', 'city']),
... ])
>>> col_trans = col_trans.fit(df)
>>> df_transformed = col_trans.transform(df)
>>> df_transformed
transformed_name transformed_city
0 5 8
1 15 6
2 4 15
<BLANKLINE>
[3 rows x 2 columns]

SQLScalarColumnTransformer can be combined with other transformers, like StandardScaler:

>>> col_trans = ColumnTransformer([
... ("identity", SQLScalarColumnTransformer("{0}", target_column="{0}"), ["col1", col5"]),
tswast marked this conversation as resolved.
Show resolved Hide resolved
... ("increment", SQLScalarColumnTransformer("{0}+1", target_column="inc_{0}"), "col2"),
... ("stdscale", preprocessing.StandardScaler(), "col3"),
... ...
... ])

"""

def __init__(self, sql: str, target_column: str = "transformed_{0}"):
super().__init__()
self._sql = sql
self._target_column = target_column.replace("`", "")

PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)

def escape(self, colname: str):
colname = colname.replace("`", "")
if self.PLAIN_COLNAME_RX.match(colname):
return colname
return f"`{colname}`"

def _compile_to_sql(
self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
) -> List[str]:
if columns is None:
columns = X.columns
result = []
for column in columns:
current_sql = self._sql.format(self.escape(column))
current_target_column = self.escape(self._target_column.format(column))
result.append(f"{current_sql} AS {current_target_column}")
return result

def __repr__(self):
return f"SQLScalarColumnTransformer(sql='{self._sql}', target_column='{self._target_column}')"

def __eq__(self, other) -> bool:
return type(self) is type(other) and self._keys() == other._keys()

def __hash__(self) -> int:
return hash(self._keys())

def _keys(self):
return (self._sql, self._target_column)


# Type hints for transformers contained in ColumnTransformer
SingleColTransformer = Union[
preprocessing.PreprocessingType,
impute.SimpleImputer,
SQLScalarColumnTransformer,
]


@log_adapter.class_logger
class ColumnTransformer(
base.Transformer,
@@ -60,7 +60,7 @@ def __init__(
transformers: Iterable[
Tuple[
str,
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
SingleColTransformer,
Union[str, Iterable[str]],
]
],
@@ -78,14 +78,12 @@ def _keys(self):
@property
def transformers_(
self,
) -> List[
Tuple[str, Union[preprocessing.PreprocessingType, impute.SimpleImputer], str]
]:
) -> List[Tuple[str, SingleColTransformer, str,]]:
"""The collection of transformers as tuples of (name, transformer, column)."""
result: List[
Tuple[
str,
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
SingleColTransformer,
str,
]
] = []
@@ -103,6 +103,8 @@ def transformers_(

return result

AS_FLEXNAME_SUFFIX_RX = re.compile("^(.*)\\bAS\\s*`[^`]+`\\s*$", re.IGNORECASE)

@classmethod
def _extract_from_bq_model(
cls,
@@ -114,7 +114,7 @@ def _extract_from_bq_model(
transformers_set: Set[
Tuple[
str,
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
SingleColTransformer,
Union[str, List[str]],
]
] = set()
@@ -130,8 +130,10 @@ def camel_to_snake(name):
if "transformSql" not in transform_col_dict:
continue
transform_sql: str = transform_col_dict["transformSql"]
if not transform_sql.startswith("ML."):
continue

# workaround for bug in bq_model returning " AS `...`" suffix for flexible names
if cls.AS_FLEXNAME_SUFFIX_RX.match(transform_sql):
transform_sql = cls.AS_FLEXNAME_SUFFIX_RX.match(transform_sql).group(1)

output_names.append(transform_col_dict["name"])
found_transformer = False
@@ -148,8 +148,22 @@ def camel_to_snake(name):
found_transformer = True
break
if not found_transformer:
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
if transform_sql.startswith("ML."):
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
)

target_column = transform_col_dict["name"]
sql_transformer = SQLScalarColumnTransformer(
transform_sql, target_column=target_column
)
input_column_name = f"?{target_column}"
transformers_set.add(
(
camel_to_snake(sql_transformer.__class__.__name__),
sql_transformer,
input_column_name,
)
)

transformer = cls(transformers=list(transformers_set))
@@ -167,6 +167,8 @@ def _merge(

assert len(transformers) > 0
_, transformer_0, column_0 = transformers[0]
if isinstance(transformer_0, SQLScalarColumnTransformer):
return self  # SQLScalarColumnTransformer only works inside ColumnTransformer
feature_columns_sorted = sorted(
[
cast(str, feature_column.name)
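One consequence of the _extract_from_bq_model change above worth noting: when a fitted ColumnTransformer is reloaded from a BQML model, each plain-SQL transform column is reconstructed as a SQLScalarColumnTransformer whose input column is the target column prefixed with '?', because the original input columns cannot be recovered from the model. A sketch of one reconstructed tuple, using names from the system test below:

(
    "sql_scalar_column_transformer",
    SQLScalarColumnTransformer(
        "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END",
        target_column="len_species",
    ),
    "?len_species",  # '?' marks a placeholder input derived from the target column
)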
103 changes: 103 additions & 0 deletions tests/system/large/ml/test_compose.py
@@ -36,6 +36,32 @@ def test_columntransformer_standalone_fit_and_transform(
preprocessing.MinMaxScaler(),
["culmen_length_mm"],
),
(
"increment",
compose.SQLScalarColumnTransformer("{0}+1"),
["culmen_length_mm", "flipper_length_mm"],
),
(
"length",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="len_{0}",
),
"species",
),
(
"ohe",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0}='Adelie Penguin (Pygoscelis adeliae)' THEN 1 ELSE 0 END",
target_column="ohe_adelie",
),
"species",
),
(
"identity",
compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
["culmen_length_mm", "flipper_length_mm"],
),
]
)

@@ -51,6 +51,12 @@ def test_columntransformer_standalone_fit_and_transform(
"standard_scaled_culmen_length_mm",
"min_max_scaled_culmen_length_mm",
"standard_scaled_flipper_length_mm",
"transformed_culmen_length_mm",
"transformed_flipper_length_mm",
"len_species",
"ohe_adelie",
"culmen_length_mm",
"flipper_length_mm",
],
index=[1633, 1672, 1690],
col_exact=False,
@@ -70,6 +70,19 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"length",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="len_{0}",
),
"species",
),
(
"identity",
compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
["culmen_length_mm", "flipper_length_mm"],
),
]
)

@@ -83,6 +83,9 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
"onehotencoded_species",
"standard_scaled_culmen_length_mm",
"standard_scaled_flipper_length_mm",
"len_species",
"culmen_length_mm",
"flipper_length_mm",
],
index=[1633, 1672, 1690],
col_exact=False,
@@ -102,6 +102,27 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"length",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="len_{0}",
),
"species",
),
(
"identity",
compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
["culmen_length_mm", "flipper_length_mm"],
),
(
"flexname",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="Flex {0} Name",
),
"species",
),
]
)
transformer.fit(
@@ -122,6 +122,36 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
),
("standard_scaler", preprocessing.StandardScaler(), "culmen_length_mm"),
("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END",
target_column="len_species",
),
"?len_species",
),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"flipper_length_mm", target_column="flipper_length_mm"
),
"?flipper_length_mm",
),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"culmen_length_mm", target_column="culmen_length_mm"
),
"?culmen_length_mm",
),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END ",
target_column="Flex species Name",
),
"?Flex species Name",
),
]
assert set(reloaded_transformer.transformers) == set(expected)
assert reloaded_transformer._bqml_model is not None
@@ -136,6 +136,10 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
"onehotencoded_species",
"standard_scaled_culmen_length_mm",
"standard_scaled_flipper_length_mm",
"len_species",
"culmen_length_mm",
"flipper_length_mm",
"Flex species Name",
],
index=[1633, 1672, 1690],
col_exact=False,