Skip to content

Commit

Permalink
Add serialization test and fix NumUniqueSeparators (#122)
Browse files Browse the repository at this point in the history
* add serialization test and fix NumUniqueSeparators

* update release notes

* skip CFM with Elmo

* reorder test calls

* lint fix
  • Loading branch information
thehomebrewnerd authored Apr 7, 2022
1 parent e6cb726 commit ee6dc7b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Future Release
==============
* Enhancements
* Fixes
* Fix ``NumUniqueSeparators`` to allow for serialization and deserialization (:pr:`122`)
* Changes
* Speed up LSA primitive initialization (:pr:`118`)
* Documentation Changes
Expand Down
16 changes: 8 additions & 8 deletions nlp_primitives/num_unique_separators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import IntegerNullable, NaturalLanguage

NATURAL_LANGUAGE_SEPARATORS = " .,!?;\n"
NATURAL_LANGUAGE_SEPARATORS = [" ", ".", ",", "!", "?", ";", "\n"]


class NumUniqueSeparators(TransformPrimitive):
"""Calculates the number of unique separators.
Description:
Given a string and an iterable of separators, determine
Given a string and a list of separators, determine
the number of unique separators in each string. If a string
is null determined by pd.isnull return pd.NA.
Args:
separators (str, optional): an iterable of characters to count.
" .,!?;\n" is used by default.
separators (list, optional): a list of separator characters to count.
`[`" ", ".", ",", "!", "?", ";", "\n"]` is used by default.
Examples:
>>> x = ['First. Line.', 'This. is the second, line!', 'notinlist@#$%^%&']
>>> num_unique_separators = NumUniqueSeparators(".,!")
>>> x = ["First. Line.", "This. is the second, line!", "notinlist@#$%^%&"]
>>> num_unique_separators = NumUniqueSeparators([".", ",", "!"])
>>> num_unique_separators(x).tolist()
[1, 3, 0]
"""
Expand All @@ -31,13 +31,13 @@ class NumUniqueSeparators(TransformPrimitive):

def __init__(self, separators=NATURAL_LANGUAGE_SEPARATORS):
assert separators is not None, "separators needs to be defined"
self.separators = set(separators)
self.separators = separators

def get_function(self):
def count_unique_separator(s):
if pd.isnull(s):
return pd.NA
return len(self.separators.intersection(set(s)))
return len(set(self.separators).intersection(set(s)))

def get_separator_count(column):
return column.apply(count_unique_separator)
Expand Down
33 changes: 0 additions & 33 deletions nlp_primitives/tests/test_lsa.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import numpy as np
import pandas as pd
from featuretools import (
calculate_feature_matrix,
dfs,
load_features,
save_features
)

from ..lsa import LSA
from .test_utils import PrimitiveT, find_applicable_primitives, valid_dfs
Expand Down Expand Up @@ -56,30 +50,3 @@ def test_with_featuretools(self, es):
primitive_instance = self.primitive()
transform.append(primitive_instance)
valid_dfs(es, aggregation, transform, self.primitive.name.upper(), multi_output=True)

def test_serialize(self, es):
features = dfs(entityset=es,
target_dataframe_name="log",
trans_primitives=[self.primitive],
max_features=-1,
max_depth=3,
features_only=True)

feat_to_serialize = None
for feature in features:
if feature.primitive.__class__ == self.primitive:
feat_to_serialize = feature
break
for base_feature in feature.get_dependencies(deep=True):
if base_feature.primitive.__class__ == self.primitive:
feat_to_serialize = base_feature
break
assert feat_to_serialize is not None

df1 = calculate_feature_matrix([feat_to_serialize], entityset=es)

new_feat = load_features(save_features([feat_to_serialize]))[0]

df2 = calculate_feature_matrix([new_feat], entityset=es)

assert df1.equals(df2)
40 changes: 39 additions & 1 deletion nlp_primitives/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

import featuretools as ft
import pytest
from featuretools import dfs, list_primitives
from featuretools import (
calculate_feature_matrix,
dfs,
list_primitives,
load_features,
save_features
)
from featuretools.tests.testing_utils import make_ecommerce_entityset

ft.primitives._load_primitives()
Expand Down Expand Up @@ -50,6 +56,38 @@ def test_arg_init(self):
if parameter.default is not parameter.empty:
assert hasattr(primitive_, name)

def test_serialize(self, es):
features = dfs(entityset=es,
target_dataframe_name="log",
trans_primitives=[self.primitive],
max_features=-1,
max_depth=3,
features_only=True)

feat_to_serialize = None
for feature in features:
if feature.primitive.__class__ == self.primitive:
feat_to_serialize = feature
break
for base_feature in feature.get_dependencies(deep=True):
if base_feature.primitive.__class__ == self.primitive:
feat_to_serialize = base_feature
break
assert feat_to_serialize is not None

# Skip calculating feature matrix for long running primitives
skip_primitives = ["elmo"]

if self.primitive.name not in skip_primitives:
df1 = calculate_feature_matrix([feat_to_serialize], entityset=es)

new_feat = load_features(save_features([feat_to_serialize]))[0]
assert isinstance(new_feat, ft.FeatureBase)

if self.primitive.name not in skip_primitives:
df2 = calculate_feature_matrix([new_feat], entityset=es)
assert df1.equals(df2)


def find_applicable_primitives(primitive):
from featuretools.primitives.utils import (
Expand Down

0 comments on commit ee6dc7b

Please sign in to comment.