Skip to content

Commit

Permalink
🥋 Use same CategoricalFeatures interface for binary and one hot featu…
Browse files Browse the repository at this point in the history
…rization
  • Loading branch information
codeKgu committed Jul 27, 2021
1 parent e30514c commit 8f5b64d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion multimodal_transformers/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def change_name_func(x):
def _one_hot(self):
ohe = preprocessing.OneHotEncoder(sparse=False)
ohe.fit(self.df[self.cat_feats].values)
self.feat_names = ohe.get_feature_names(self.cat_feats)
self.feat_names = list(ohe.get_feature_names(self.cat_feats))
return ohe.transform(self.df[self.cat_feats].values)

def fit_transform(self):
Expand Down
16 changes: 5 additions & 11 deletions multimodal_transformers/data/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,11 @@ def load_train_val_test_helper(train_df,
if categorical_encode_type == 'ohe' or categorical_encode_type == 'binary':
dfs = [df for df in [train_df, val_df, test_df] if df is not None]
data_df = pd.concat(dfs, axis=0)
if categorical_encode_type == 'ohe':
data_df = pd.get_dummies(data_df, columns=categorical_cols,
dummy_na=True)
categorical_cols = [col for col in data_df.columns for old_col in categorical_cols
if col.startswith(old_col) and len(col) > len(old_col)]
elif categorical_encode_type == 'binary':
cat_feat_processor = CategoricalFeatures(data_df, categorical_cols, 'binary')
vals = cat_feat_processor.fit_transform()
cat_df = pd.DataFrame(vals, columns=cat_feat_processor.feat_names)
data_df = pd.concat([data_df, cat_df], axis=1)
categorical_cols = cat_feat_processor.feat_names
cat_feat_processor = CategoricalFeatures(data_df, categorical_cols, categorical_encode_type)
vals = cat_feat_processor.fit_transform()
cat_df = pd.DataFrame(vals, columns=cat_feat_processor.feat_names)
data_df = pd.concat([data_df, cat_df], axis=1)
categorical_cols = cat_feat_processor.feat_names

len_train = len(train_df)
len_val = len(val_df) if val_df is not None else 0
Expand Down

0 comments on commit 8f5b64d

Please sign in to comment.