Skip to content

Commit

Permalink
updates to support catboost benchmark sources
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgard committed Jan 26, 2024
1 parent aab18c3 commit 4553f8c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
1 change: 1 addition & 0 deletions tableshift/core/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ def _load_data(self) -> pd.DataFrame:
class KddCup2009DataSource(DataSource):
def __init__(self, task_name: str, **kwargs):
self.task_name = task_name
self.name = task_name
_resources = [
"https://kdd.org/cupfiles/KDDCupData/2009/orange_small_train.data.zip",
f"http://www.kdd.org/cupfiles/KDDCupData/2009/orange_small_train_{task_name}.labels",
Expand Down
12 changes: 7 additions & 5 deletions tableshift/datasets/catboost_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@
Feature('ROLE_CODE', int,
'Company role code; this code is unique to each role (e.g. Manager)',
name_extended='company role code'),
], documentation="https://www.kaggle.com/c/amazon-employee-access-challenge")
], documentation="https://www.kaggle.com/c/amazon-employee-access-challenge/overview , "
"https://www.kaggle.com/c/amazon-employee-access-challenge/data")

APPETENCY_FEATURES = FeatureList(features=[
Feature('label', float, name_extended='class label', is_target=True),
Feature('label', float, name_extended='user will buy new products or services (appetency)', is_target=True),
Feature('Var202', cat_dtype), # importance: 0.0881
Feature('Var220', cat_dtype), # importance: 0.0622
Feature('Var218', cat_dtype), # importance: 0.0532
Expand Down Expand Up @@ -175,7 +176,7 @@
'https://medium.com/@kushaldps1996/customer-relationship-prediction-kdd-cup-2009-6b57d08ffb0')

CHURN_FEATURES = FeatureList(features=[
Feature('label', float, name_extended='class label', is_target=True),
Feature('label', float, name_extended='customer will switch provider (churn)', is_target=True),
Feature('Var202', cat_dtype), # importance: 0.1061
Feature('Var222', cat_dtype), # importance: 0.0707
Feature('Var220', cat_dtype), # importance: 0.0699
Expand Down Expand Up @@ -303,7 +304,7 @@
'https://medium.com/@kushaldps1996/customer-relationship-prediction-kdd-cup-2009-6b57d08ffb0')

UPSELLING_FEATURES = FeatureList(features=[
Feature('label', float, name_extended='class label', is_target=True),
Feature('label', float, name_extended='customer will buy upgrades or add-ons proposed to them to make the sale more profitable (up-selling)', is_target=True),
Feature('Var126', float), # importance: 0.1205
Feature('Var202', cat_dtype), # importance: 0.0812
Feature('Var198', cat_dtype), # importance: 0.0687
Expand Down Expand Up @@ -507,7 +508,8 @@
name_extended='vehicle was originally purchased online'),
Feature('WarrantyCost', int,
name_extended='Warranty price (with term=36 month and mileage=36K)'),
], documentation="https://www.kaggle.com/competitions/DontGetKicked/data")
], documentation="https://www.kaggle.com/competitions/DontGetKicked/ , "
"https://www.kaggle.com/competitions/DontGetKicked/data")


def preprocess_kick(df: DataFrame) -> DataFrame:
Expand Down

0 comments on commit 4553f8c

Please sign in to comment.