Skip to content

Commit

Permalink
add more transformers to toolkit
Browse files Browse the repository at this point in the history
  • Loading branch information
codeKgu committed Dec 18, 2020
1 parent 64d26e6 commit a37c571
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 8 deletions.
99 changes: 99 additions & 0 deletions datasets/Melbourne_Airbnb_Open_Data/column_info_all_text.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{
"text_cols": [
"name",
"summary",
"host_about",
"host_location",
"host_is_superhost",
"host_neighborhood",
"host_identity_verified",
"neighborhood",
"city",
"suburb",
"state",
"property_type",
"room_type",
"bed_type",
"cancellation_policy",
"instant_bookable",
"internet",
"wifi",
"pets_live_on_this_property",
"dog(s)",
"tv",
"air_conditioning",
"wifi",
"kitchen",
"heating",
"family/kid_friendly",
"washer",
"smoke_detector",
"first_aid_kit",
"fire_extinguisher",
"essentials",
"shampoo",
"lock_on_bedroom_door",
"hair_dryer",
"iron",
"laptop_friendly_workspace",
"private_entrance",
"hot_water",
"microwave",
"coffee_maker",
"refrigerator",
"dishes_and_silverware",
"cooking_basics",
"oven",
"stove",
"garden_or_backyard",
"luggage_dropoff_allowed",
"long_term_stays_allowed",
"free_parking_on_premises",
"elevator",
"buzzer/wireless_intercom",
"dryer",
"hangers",
"self_check_in",
"lockbox",
"bed_linens",
"extra_pillows_and_blankets",
"dishwasher",
"patio_or_balcony",
"breakfast",
"pool",
"bathtub",
"host_greets_you",
"smoking_allowed",
"gym",
"room_darkening_shades",
"single_level_home",
"bbq_grill",
"step_free_access",
"flat_path_to_front_door",
"well_lit_path_to_entrance",
"paid_parking_on_premises",
"free_street_parking",
"carbon_monoxide_detector",
"private_living_room",
"cable_tv",
"step_free_access",
"accommodates",
"bathrooms",
"beds",
"security_deposit",
"cleaning_fee",
"guests_included",
"extra_people",
"minimum_nights",
"availability_30",
"availability_60",
"availability_90",
"availability_365",
"number_of_reviews",
"review_scores_rating",
"reviews_per_month"
],
"label_col": "price",
"text_col_sep_token": " ",
"label_list": []
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"Review Text"
],
"cat_cols": [
"Clothing ID",
"Division Name",
"Department Name",
"Class Name"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"text_cols": [
"Title",
"Review Text",
"Division Name",
"Department Name",
"Class Name",
"Rating",
"Age",
"Positive Feedback Count"
],
"text_col_sep_token": " ",
"label_col": "Recommended IND",
"label_list": ["Not Recommended", "Recommended"]
}
Binary file modified docs/build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/build/doctrees/modules/model.doctree
Binary file not shown.
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main():
numerical_cols=data_args.column_info['num_cols'],
categorical_encode_type=data_args.categorical_encode_type,
numerical_transformer_method=data_args.numerical_transformer_method,
sep_text_token_str=tokenizer.sep_token,
sep_text_token_str=tokenizer.sep_token if not data_args.column_info['text_col_sep_token'] else data_args.column_info['text_col_sep_token'],
max_token_length=training_args.max_token_length,
debug=training_args.debug_dataset,
)
Expand All @@ -93,8 +93,8 @@ def main():
cache_dir=model_args.cache_dir,
)
tabular_config = TabularConfig(num_labels=num_labels,
cat_feat_dim=train_dataset.cat_feats.shape[1],
numerical_feat_dim=train_dataset.numerical_feats.shape[1],
cat_feat_dim=train_dataset.cat_feats.shape[1] if train_dataset.cat_feats is not None else 0,
numerical_feat_dim=train_dataset.numerical_feats.shape[1] if train_dataset.numerical_feats is not None else 0,
**vars(data_args))
config.tabular_config = tabular_config
logger.info(tabular_config)
Expand Down
10 changes: 9 additions & 1 deletion multimodal_exp_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,15 @@ def __post_init__(self):
if self.column_info is None and self.column_info_path:
with open(self.column_info_path, 'r') as f:
self.column_info = json.load(f)

assert 'text_cols' in self.column_info and 'label_col' in self.column_info
if 'cat_cols' not in self.column_info:
self.column_info['cat_cols'] = None
self.categorical_encode_type = 'none'
if 'num_cols' not in self.column_info:
self.column_info['num_cols'] = None
self.numerical_transformer_method = 'none'
if 'text_col_sep_token' not in self.column_info:
self.column_info['text_col_sep_token'] = None

@dataclass
class OurTrainingArguments(TrainingArguments):
Expand Down
21 changes: 18 additions & 3 deletions multimodal_transformers/model/tabular_modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,35 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.configuration_auto import (
AutoConfig,
AlbertConfig,
BertConfig,
DistilBertConfig,
RobertaConfig,

XLNetConfig,
XLMConfig,
XLMRobertaConfig
)

from .tabular_transformers import RobertaWithTabular, BertWithTabular, DistilBertWithTabular
from .tabular_transformers import (
RobertaWithTabular,
BertWithTabular,
DistilBertWithTabular,
AlbertWithTabular,
XLNetWithTabular,
XLMWithTabular,
XLMRobertaWithTabular
)


MODEL_FOR_SEQUENCE_W_TABULAR_CLASSIFICATION_MAPPING = OrderedDict(
[
(RobertaConfig, RobertaWithTabular),
(BertConfig, BertWithTabular),
(DistilBertConfig, DistilBertWithTabular)
(DistilBertConfig, DistilBertWithTabular),
(AlbertConfig, AlbertWithTabular),
(XLNetConfig, XLNetWithTabular),
(XLMConfig, XLMWithTabular),
(XLMRobertaConfig, XLMRobertaWithTabular)
]
)

Expand Down
Loading

0 comments on commit a37c571

Please sign in to comment.