diff --git a/datasets/Melbourne_Airbnb_Open_Data/column_info_all_text.json b/datasets/Melbourne_Airbnb_Open_Data/column_info_all_text.json
new file mode 100644
index 0000000..98f1fab
--- /dev/null
+++ b/datasets/Melbourne_Airbnb_Open_Data/column_info_all_text.json
@@ -0,0 +1,99 @@
+{
+    "text_cols": [
+        "name",
+        "summary",
+        "host_about",
+        "host_location",
+        "host_is_superhost",
+        "host_neighborhood",
+        "host_identity_verified",
+        "neighborhood",
+        "city",
+        "suburb",
+        "state",
+        "property_type",
+        "room_type",
+        "bed_type",
+        "cancellation_policy",
+        "instant_bookable",
+        "internet",
+        "wifi",
+        "pets_live_on_this_property",
+        "dog(s)",
+        "tv",
+        "air_conditioning",
+        "wifi",
+        "kitchen",
+        "heating",
+        "family/kid_friendly",
+        "washer",
+        "smoke_detector",
+        "first_aid_kit",
+        "fire_extinguisher",
+        "essentials",
+        "shampoo",
+        "lock_on_bedroom_door",
+        "hair_dryer",
+        "iron",
+        "laptop_friendly_workspace",
+        "private_entrance",
+        "hot_water",
+        "microwave",
+        "coffee_maker",
+        "refrigerator",
+        "dishes_and_silverware",
+        "cooking_basics",
+        "oven",
+        "stove",
+        "garden_or_backyard",
+        "luggage_dropoff_allowed",
+        "long_term_stays_allowed",
+        "free_parking_on_premises",
+        "elevator",
+        "buzzer/wireless_intercom",
+        "dryer",
+        "hangers",
+        "self_check_in",
+        "lockbox",
+        "bed_linens",
+        "extra_pillows_and_blankets",
+        "dishwasher",
+        "patio_or_balcony",
+        "breakfast",
+        "pool",
+        "bathtub",
+        "host_greets_you",
+        "smoking_allowed",
+        "gym",
+        "room_darkening_shades",
+        "single_level_home",
+        "bbq_grill",
+        "step_free_access",
+        "flat_path_to_front_door",
+        "well_lit_path_to_entrance",
+        "paid_parking_on_premises",
+        "free_street_parking",
+        "carbon_monoxide_detector",
+        "private_living_room",
+        "cable_tv",
+        "step_free_access",
+        "accommodates",
+        "bathrooms",
+        "beds",
+        "security_deposit",
+        "cleaning_fee",
+        "guests_included",
+        "extra_people",
+        "minimum_nights",
+        "availability_30",
+        "availability_60",
+        "availability_90",
+        "availability_365",
+        "number_of_reviews",
+        "review_scores_rating",
+        "reviews_per_month"
+    ],
+    "label_col": "price",
+    "text_col_sep_token": " ",
+    "label_list": []
+}
\ No newline at end of file
diff --git a/datasets/Womens_Clothing_E-Commerce_Reviews/column_info.json b/datasets/Womens_Clothing_E-Commerce_Reviews/column_info.json
index 8dbb3dc..f4b76de 100644
--- a/datasets/Womens_Clothing_E-Commerce_Reviews/column_info.json
+++ b/datasets/Womens_Clothing_E-Commerce_Reviews/column_info.json
@@ -4,7 +4,6 @@
         "Review Text"
     ],
     "cat_cols": [
-        "Clothing ID",
         "Division Name",
         "Department Name",
         "Class Name"
diff --git a/datasets/Womens_Clothing_E-Commerce_Reviews/column_info_all_text.json b/datasets/Womens_Clothing_E-Commerce_Reviews/column_info_all_text.json
new file mode 100644
index 0000000..b1ffc90
--- /dev/null
+++ b/datasets/Womens_Clothing_E-Commerce_Reviews/column_info_all_text.json
@@ -0,0 +1,15 @@
+{
+    "text_cols": [
+        "Title",
+        "Review Text",
+        "Division Name",
+        "Department Name",
+        "Class Name",
+        "Rating",
+        "Age",
+        "Positive Feedback Count"
+    ],
+    "text_col_sep_token": " ",
+    "label_col": "Recommended IND",
+    "label_list": ["Not Recommended", "Recommended"]
+}
\ No newline at end of file
diff --git a/docs/build/doctrees/environment.pickle b/docs/build/doctrees/environment.pickle
index 441603c..995f622 100644
Binary files a/docs/build/doctrees/environment.pickle and b/docs/build/doctrees/environment.pickle differ
diff --git a/docs/build/doctrees/modules/model.doctree b/docs/build/doctrees/modules/model.doctree
index cf9496f..751ea8a 100644
Binary files a/docs/build/doctrees/modules/model.doctree and b/docs/build/doctrees/modules/model.doctree differ
diff --git a/main.py b/main.py
index 6ce6eda..f30ad89 100644
--- a/main.py
+++ b/main.py
@@ -76,7 +76,7 @@ def main():
         numerical_cols=data_args.column_info['num_cols'],
         categorical_encode_type=data_args.categorical_encode_type,
         numerical_transformer_method=data_args.numerical_transformer_method,
-        sep_text_token_str=tokenizer.sep_token,
+        sep_text_token_str=tokenizer.sep_token if not data_args.column_info['text_col_sep_token'] else data_args.column_info['text_col_sep_token'],
         max_token_length=training_args.max_token_length,
         debug=training_args.debug_dataset,
     )
@@ -93,8 +93,8 @@ def main():
         cache_dir=model_args.cache_dir,
     )
     tabular_config = TabularConfig(num_labels=num_labels,
-                                   cat_feat_dim=train_dataset.cat_feats.shape[1],
-                                   numerical_feat_dim=train_dataset.numerical_feats.shape[1],
+                                   cat_feat_dim=train_dataset.cat_feats.shape[1] if train_dataset.cat_feats is not None else 0,
+                                   numerical_feat_dim=train_dataset.numerical_feats.shape[1] if train_dataset.numerical_feats is not None else 0,
                                    **vars(data_args))
     config.tabular_config = tabular_config
     logger.info(tabular_config)
diff --git a/multimodal_exp_args.py b/multimodal_exp_args.py
index a65176c..0c3fd72 100644
--- a/multimodal_exp_args.py
+++ b/multimodal_exp_args.py
@@ -110,7 +110,15 @@ def __post_init__(self):
         if self.column_info is None and self.column_info_path:
             with open(self.column_info_path, 'r') as f:
                 self.column_info = json.load(f)
-
+                assert 'text_cols' in self.column_info and 'label_col' in self.column_info
+                if 'cat_cols' not in self.column_info:
+                    self.column_info['cat_cols'] = None
+                    self.categorical_encode_type = 'none'
+                if 'num_cols' not in self.column_info:
+                    self.column_info['num_cols'] = None
+                    self.numerical_transformer_method = 'none'
+                if 'text_col_sep_token' not in self.column_info:
+                    self.column_info['text_col_sep_token'] = None

 @dataclass
 class OurTrainingArguments(TrainingArguments):
diff --git a/multimodal_transformers/model/tabular_modeling_auto.py b/multimodal_transformers/model/tabular_modeling_auto.py
index fe63a0a..c95420c 100644
--- a/multimodal_transformers/model/tabular_modeling_auto.py
+++ b/multimodal_transformers/model/tabular_modeling_auto.py
@@ -3,20 +3,35 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.configuration_auto import (
     AutoConfig,
+    AlbertConfig,
     BertConfig,
     DistilBertConfig,
     RobertaConfig,
-
+    XLNetConfig,
+    XLMConfig,
+    XLMRobertaConfig
 )
-from .tabular_transformers import RobertaWithTabular, BertWithTabular, DistilBertWithTabular
+from .tabular_transformers import (
+    RobertaWithTabular,
+    BertWithTabular,
+    DistilBertWithTabular,
+    AlbertWithTabular,
+    XLNetWithTabular,
+    XLMWithTabular,
+    XLMRobertaWithTabular
+)


 MODEL_FOR_SEQUENCE_W_TABULAR_CLASSIFICATION_MAPPING = OrderedDict(
     [
         (RobertaConfig, RobertaWithTabular),
         (BertConfig, BertWithTabular),
-        (DistilBertConfig, DistilBertWithTabular)
+        (DistilBertConfig, DistilBertWithTabular),
+        (AlbertConfig, AlbertWithTabular),
+        (XLNetConfig, XLNetWithTabular),
+        (XLMConfig, XLMWithTabular),
+        (XLMRobertaConfig, XLMRobertaWithTabular)
     ]
 )
diff --git a/multimodal_transformers/model/tabular_transformers.py b/multimodal_transformers/model/tabular_transformers.py
index 4e30111..1d2a934 100644
--- a/multimodal_transformers/model/tabular_transformers.py
+++ b/multimodal_transformers/model/tabular_transformers.py
@@ -3,10 +3,17 @@
     BertForSequenceClassification,
     RobertaForSequenceClassification,
     DistilBertForSequenceClassification,
+    AlbertForSequenceClassification,
+    XLNetForSequenceClassification,
+    XLMForSequenceClassification
 )
 from transformers.modeling_bert import BERT_INPUTS_DOCSTRING
 from transformers.modeling_roberta import ROBERTA_INPUTS_DOCSTRING
 from transformers.modeling_distilbert import DISTILBERT_INPUTS_DOCSTRING
+from transformers.modeling_albert import ALBERT_INPUTS_DOCSTRING
+from transformers.modeling_xlnet import XLNET_INPUTS_DOCSTRING
+from transformers.modeling_xlm import XLM_INPUTS_DOCSTRING
+from transformers.configuration_xlm_roberta import XLMRobertaConfig
 from transformers.file_utils import add_start_docstrings_to_callable

 from .tabular_combiner import TabularFeatCombiner
@@ -220,6 +227,14 @@ def forward(
         return loss, logits, classifier_layer_outputs


+class XLMRobertaWithTabular(RobertaWithTabular):
+    """
+    This class overrides :class:`~RobertaWithTabular`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+    config_class = XLMRobertaConfig
+
+
 class DistilBertWithTabular(DistilBertForSequenceClassification):
     """
     DistilBert Model transformer with a sequence classification/regression head as well as
@@ -317,4 +332,289 @@ def forward(
                                                                labels,
                                                                self.num_labels,
                                                                class_weights)
+        return loss, logits, classifier_layer_outputs
+
+
+class AlbertWithTabular(AlbertForSequenceClassification):
+    """
+    ALBERT Model transformer with a sequence classification/regression head as well as
+    a TabularFeatCombiner module to combine categorical and numerical features
+    with the ALBERT pooled output
+
+    Parameters:
+        hf_model_config (:class:`~transformers.AlbertConfig`):
+            Model configuration class with all the parameters of the model.
+            This object must also have a tabular_config member variable that is a
+            :obj:`TabularConfig` instance specifying the configs for :obj:`TabularFeatCombiner`
+    """
+
+    def __init__(self, hf_model_config):
+        super().__init__(hf_model_config)
+        tabular_config = hf_model_config.tabular_config
+        if type(tabular_config) is dict:  # when loading from saved model
+            tabular_config = TabularConfig(**tabular_config)
+        else:
+            self.config.tabular_config = tabular_config.__dict__
+
+        tabular_config.text_feat_dim = hf_model_config.hidden_size
+        tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob
+        self.tabular_combiner = TabularFeatCombiner(tabular_config)
+        self.num_labels = tabular_config.num_labels
+        combined_feat_dim = self.tabular_combiner.final_out_dim
+        if tabular_config.use_simple_classifier:
+            self.tabular_classifier = nn.Linear(combined_feat_dim,
+                                                tabular_config.num_labels)
+        else:
+            dims = calc_mlp_dims(combined_feat_dim,
+                                 division=tabular_config.mlp_division,
+                                 output_dim=tabular_config.num_labels)
+            self.tabular_classifier = MLP(combined_feat_dim,
+                                          tabular_config.num_labels,
+                                          num_hidden_lyr=len(dims),
+                                          dropout_prob=tabular_config.mlp_dropout,
+                                          hidden_channels=dims,
+                                          bn=True)
+
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        class_weights=None,
+        cat_feats=None,
+        numerical_feats=None
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
+            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.albert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        combined_feats = self.tabular_combiner(pooled_output,
+                                               cat_feats,
+                                               numerical_feats)
+        loss, logits, classifier_layer_outputs = hf_loss_func(combined_feats,
+                                                              self.tabular_classifier,
+                                                              labels,
+                                                              self.num_labels,
+                                                              class_weights)
+        return loss, logits, classifier_layer_outputs
+
+
+class XLNetWithTabular(XLNetForSequenceClassification):
+    """
+    XLNet Model transformer with a sequence classification/regression head as well as
+    a TabularFeatCombiner module to combine categorical and numerical features
+    with the XLNet pooled output
+
+    Parameters:
+        hf_model_config (:class:`~transformers.XLNetConfig`):
+            Model configuration class with all the parameters of the model.
+            This object must also have a tabular_config member variable that is a
+            :obj:`TabularConfig` instance specifying the configs for :obj:`TabularFeatCombiner`
+    """
+    def __init__(self, hf_model_config):
+        super().__init__(hf_model_config)
+        tabular_config = hf_model_config.tabular_config
+        if type(tabular_config) is dict:  # when loading from saved model
+            tabular_config = TabularConfig(**tabular_config)
+        else:
+            self.config.tabular_config = tabular_config.__dict__
+
+        tabular_config.text_feat_dim = hf_model_config.hidden_size
+        self.tabular_combiner = TabularFeatCombiner(tabular_config)
+        self.num_labels = tabular_config.num_labels
+        combined_feat_dim = self.tabular_combiner.final_out_dim
+        if tabular_config.use_simple_classifier:
+            self.tabular_classifier = nn.Linear(combined_feat_dim,
+                                                tabular_config.num_labels)
+        else:
+            dims = calc_mlp_dims(combined_feat_dim,
+                                 division=tabular_config.mlp_division,
+                                 output_dim=tabular_config.num_labels)
+            self.tabular_classifier = MLP(combined_feat_dim,
+                                          tabular_config.num_labels,
+                                          num_hidden_lyr=len(dims),
+                                          dropout_prob=tabular_config.mlp_dropout,
+                                          hidden_channels=dims,
+                                          bn=True)
+
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        mems=None,
+        perm_mask=None,
+        target_mapping=None,
+        token_type_ids=None,
+        input_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        class_weights=None,
+        cat_feats=None,
+        numerical_feats=None
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
+            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + mems=mems, + perm_mask=perm_mask, + target_mapping=target_mapping, + token_type_ids=token_type_ids, + input_mask=input_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + output = transformer_outputs[0] + + output = self.sequence_summary(output) + combined_feats = self.tabular_combiner(output, + cat_feats, + numerical_feats) + loss, logits, classifier_layer_outputs = hf_loss_func(combined_feats, + self.tabular_classifier, + labels, + self.num_labels, + class_weights) + return loss, logits, classifier_layer_outputs + + +class XLMWithTabular(XLMForSequenceClassification): + """ + XLM Model transformer with a sequence classification/regression head as well as + a TabularFeatCombiner module to combine categorical and numerical features + with the Roberta pooled output + + Parameters: + hf_model_config (:class:`~transformers.XLMConfig`): + Model configuration class with all the parameters of the model. + This object must also have a tabular_config member variable that is a + :obj:`TabularConfig` instance specifying the configs for :obj:`TabularFeatCombiner` + """ + def __init__(self, hf_model_config): + super().__init__(hf_model_config) + tabular_config = hf_model_config.tabular_config + if type(tabular_config) is dict: # when loading from saved model + tabular_config = TabularConfig(**tabular_config) + else: + self.config.tabular_config = tabular_config.__dict__ + + tabular_config.text_feat_dim = hf_model_config.hidden_size + self.tabular_combiner = TabularFeatCombiner(tabular_config) + self.num_labels = tabular_config.num_labels + combined_feat_dim = self.tabular_combiner.final_out_dim + if tabular_config.use_simple_classifier: + self.tabular_classifier = nn.Linear(combined_feat_dim, + tabular_config.num_labels) + else: + dims = calc_mlp_dims(combined_feat_dim, + division=tabular_config.mlp_division, + output_dim=tabular_config.num_labels) + self.tabular_classifier = MLP(combined_feat_dim, + tabular_config.num_labels, + num_hidden_lyr=len(dims), + dropout_prob=tabular_config.mlp_dropout, + hidden_channels=dims, + bn=True) + + @ add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + langs=None, + token_type_ids=None, + position_ids=None, + lengths=None, + cache=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + class_weights=None, + cat_feats=None, + numerical_feats=None + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + output = transformer_outputs[0] + output = self.sequence_summary(output) + combined_feats = self.tabular_combiner(output, + cat_feats, + numerical_feats) + loss, logits, classifier_layer_outputs = hf_loss_func(combined_feats, + self.tabular_classifier, + labels, + self.num_labels, + class_weights) return loss, logits, classifier_layer_outputs \ No newline at end of file