Merge pull request #31 from georgian-io/akash/update
0.2-alpha release
akashsaravanan-georgian authored Mar 10, 2023
2 parents 8f5b64d + d839824 commit c341715
Showing 18 changed files with 353 additions and 83 deletions.
4 changes: 2 additions & 2 deletions datasets/Melbourne_Airbnb_Open_Data/train_config.json
@@ -14,8 +14,8 @@
"num_train_epochs": 5,
"overwrite_output_dir": true,
"learning_rate": 3e-3,
"per_device_train_batch_size": 12,
"per_device_eval_batch_size": 12,
"per_device_train_batch_size": 16,
"per_device_eval_batch_size": 16,
"logging_steps": 50,
"eval_steps": 500,
"save_steps": 3000,
21 changes: 11 additions & 10 deletions datasets/PetFindermy_Adoption_Prediction/train_config.json
@@ -1,24 +1,25 @@
{
"output_dir": "./logs_petfinder/",
"output_dir": "./logs_petfinder/bertmultilingual_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3",
"debug_dataset": false,
"task": "classification",
"num_labels": 5,
"combine_feat_method": "text_only",
"experiment_name": "bert-base-multilingual-uncased",
"model_name_or_path": "bert-base-multilingual-uncased",
"do_train": true,
"categorical_encode_type": "ohe",
"numerical_transformer_method": "quantile_normal",
"tokenizer_name": "bert-base-multilingual-uncased",
"per_device_train_batch_size": 12,
"gpu_num": 0,
"use_simple_classifier": false,
"logging_dir": "./logs_petfinder/bertmultilingual_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3/",
"num_train_epochs": 5,
"categorical_encode_type": "ohe",
"use_class_weights": false,
"overwrite_output_dir": true,
"learning_rate": 1e-4,
"per_device_train_batch_size": 16,
"per_device_eval_batch_size": 16,
"logging_steps": 50,
"eval_steps": 750,
"save_steps": 3000,
"learning_rate": 1e-4,
"data_path": "./datasets/PetFindermy_Adoption_Prediction/",
"column_info_path": "./datasets/PetFindermy_Adoption_Prediction/column_info_all_text.json",
"overwrite_output_dir": true
"data_path": "./datasets/PetFindermy_Adoption_Prediction",
"column_info_path": "./datasets/PetFindermy_Adoption_Prediction/column_info_all_text.json"
}

16 changes: 11 additions & 5 deletions datasets/Womens_Clothing_E-Commerce_Reviews/train_config.json
@@ -1,19 +1,25 @@
{
"output_dir": "./logs_clothing_review/gating_on_cat_and_num_feats_then_sum/",
"output_dir": "./logs_clothing_review/bertbase_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3/",
"debug_dataset": false,
"task": "classification",
"combine_feat_method": "text_only",
"experiment_name": "Unimodal Bert Base Uncased",
"model_name_or_path": "bert-base-uncased",
"gpu_num": 0,
"do_train": true,
"categorical_encode_type": "binary",
"numerical_transformer_method": "quantile_normal",
"tokenizer_name": "bert-base-uncased",
"per_device_train_batch_size": 12,
"use_simple_classifier": false,
"logging_dir": "./logs_clothing_review/bertbase_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3/",
"num_train_epochs": 5,
"overwrite_output_dir": true,
"learning_rate": 3e-3,
"per_device_train_batch_size": 16,
"per_device_eval_batch_size": 16,
"logging_steps": 50,
"eval_steps": 750,
"save_steps": 3000,
"data_path": "./datasets/Womens_Clothing_E-Commerce_Reviews",
"column_info_path": "./datasets/Womens_Clothing_E-Commerce_Reviews/column_info_all_text.json",
"overwrite_output_dir": true
"column_info_path": "./datasets/Womens_Clothing_E-Commerce_Reviews/column_info_all_text.json"
}

18 changes: 10 additions & 8 deletions docs/requirements.txt
@@ -1,10 +1,12 @@
joblib==0.15.1
numpy==1.18.5
pandas==1.0.4
networkx~=2.6.3
numpy~=1.21.6
pandas~=1.3.5
pytest~=7.2.2
sacremoses~=0.0.53
scikit-image==0.17.2
scikit-learn==0.23.1
scipy==1.4.1
sklearn==0.0
scikit-learn~=1.0.2
scipy~=1.7.3
Sphinx==3.2.1
sphinx-markdown-tables==0.0.15
sphinx-rtd-theme==0.5.0
@@ -16,6 +18,6 @@ sphinxcontrib-napoleon==0.7
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.4
threadpoolctl==2.1.0
torch==1.5.0
torchvision==0.6.0
transformers==3.0.0
torch>=1.13.1
tqdm~=4.64.1
transformers>=4.26.1
2 changes: 1 addition & 1 deletion docs/source/notes/introduction.rst
@@ -95,7 +95,7 @@ The following example shows a forward pass on two data examples
labels = torch.tensor([1, 0])
model_inputs['cat_feats'] = categorical_feat
model_inputs['num_feats'] = numerical_feat
model_inputs['numerical_feats'] = numerical_feat
model_inputs['labels'] = labels
loss, logits, layer_outs = model(**model_inputs)
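The documentation fix above renames the numerical-feature key from num_feats to numerical_feats. The sketch below restates that forward pass end to end; the model construction follows the library's README-style usage, and the texts, feature dimensions, and TabularConfig arguments are illustrative assumptions rather than part of this commit.

    import torch
    from transformers import AutoConfig, AutoTokenizer
    from multimodal_transformers.model import AutoModelWithTabular, TabularConfig

    # Build a small multimodal model (assumed dimensions: 5 categorical, 2 numerical features).
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    hf_config = AutoConfig.from_pretrained("bert-base-uncased")
    hf_config.tabular_config = TabularConfig(
        num_labels=2,
        cat_feat_dim=5,
        numerical_feat_dim=2,
        combine_feat_method="gating_on_cat_and_num_feats_then_sum",
    )
    model = AutoModelWithTabular.from_pretrained("bert-base-uncased", config=hf_config)

    texts = ["Great fit and fabric.", "Ripped at the seam after one wash."]
    model_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    model_inputs["cat_feats"] = torch.rand(2, 5).round()   # toy one-hot-style categorical features
    model_inputs["numerical_feats"] = torch.rand(2, 2)     # renamed key; "num_feats" no longer works
    model_inputs["labels"] = torch.tensor([1, 0])

    loss, logits, layer_outs = model(**model_inputs)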
21 changes: 13 additions & 8 deletions main.py
@@ -83,6 +83,7 @@ def main():
sep_text_token_str=tokenizer.sep_token if not data_args.column_info['text_col_sep_token'] else data_args.column_info['text_col_sep_token'],
max_token_length=training_args.max_token_length,
debug=training_args.debug_dataset,
debug_dataset_size=training_args.debug_dataset_size
)
train_datasets = [train_dataset]
val_datasets = [val_dataset]
@@ -104,6 +105,7 @@
data_args.column_info['text_col_sep_token'],
max_token_length=training_args.max_token_length,
debug=training_args.debug_dataset,
debug_dataset_size=training_args.debug_dataset_size
)
train_dataset = train_datasets[0]

@@ -116,16 +118,19 @@

def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction):
# p.predictions is now a list of objects
# The first entry is the actual predictions
predictions = p.predictions[0]
if task_name == "classification":
preds_labels = np.argmax(p.predictions, axis=1)
if p.predictions.shape[-1] == 2:
pred_scores = softmax(p.predictions, axis=1)[:, 1]
preds_labels = np.argmax(predictions, axis=1)
if predictions.shape[-1] == 2:
pred_scores = softmax(predictions, axis=1)[:, 1]
else:
pred_scores = softmax(p.predictions, axis=1)
pred_scores = softmax(predictions, axis=1)
return calc_classification_metrics(pred_scores, preds_labels,
p.label_ids)
elif task_name == "regression":
preds = np.squeeze(p.predictions)
preds = np.squeeze(predictions)
return calc_regression_metrics(preds, p.label_ids)
else:
return {}
@@ -178,7 +183,7 @@ def compute_metrics_fn(p: EvalPrediction):
output_eval_file = os.path.join(
training_args.output_dir, f"eval_metric_results_{task}_fold_{i+1}.txt"
)
if trainer.is_world_master():
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(task))
for key, value in eval_result.items():
@@ -190,13 +195,13 @@ def compute_metrics_fn(p: EvalPrediction):
if training_args.do_predict:
logging.info("*** Test ***")

predictions = trainer.predict(test_dataset=test_dataset).predictions
predictions = trainer.predict(test_dataset=test_dataset).predictions[0]
output_test_file = os.path.join(
training_args.output_dir, f"test_results_{task}_fold_{i+1}.txt"
)
eval_result = trainer.evaluate(eval_dataset=test_dataset)
logger.info(pformat(eval_result, indent=4))
if trainer.is_world_master():
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
logger.info("***** Test results {} *****".format(task))
writer.write("index\tprediction\n")
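The pattern behind these main.py edits is that the model output now carries predictions as a list/tuple, so both compute_metrics_fn and trainer.predict(...) index the first entry before post-processing. A standalone sketch of the same indexing follows; it swaps the repository's calc_classification_metrics helper for plain scikit-learn metrics so it runs on its own, and the binary-vs-multiclass branch mirrors the diff.

    import numpy as np
    from scipy.special import softmax
    from sklearn.metrics import accuracy_score, roc_auc_score
    from transformers import EvalPrediction

    def compute_metrics(p: EvalPrediction) -> dict:
        # p.predictions is a list/tuple of outputs; the first entry holds the logits.
        logits = p.predictions[0]
        preds_labels = np.argmax(logits, axis=1)
        metrics = {"acc": accuracy_score(p.label_ids, preds_labels)}
        if logits.shape[-1] == 2:
            # Binary task: score the positive class, as in the diff's pred_scores branch.
            metrics["roc_auc"] = roc_auc_score(p.label_ids, softmax(logits, axis=1)[:, 1])
        return metrics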
26 changes: 18 additions & 8 deletions multimodal_exp_args.py
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
import json
import logging
from typing import Optional, Tuple
from typing import Optional, Tuple, List

import torch
from transformers.training_args import TrainingArguments, torch_required, cached_property
from transformers.training_args import TrainingArguments, requires_backends, cached_property


logger = logging.getLogger(__name__)
@@ -159,6 +159,11 @@ class OurTrainingArguments(TrainingArguments):
metadata={'help': 'Whether we are training in debug mode (smaller model)'}
)

debug_dataset_size: int = field(
default=100,
metadata={'help': 'Size of the dataset in debug mode. Only used when debug_dataset = True.'}
)

do_eval: bool = field(default=True, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=True, metadata={"help": "Whether to run predictions on the test set."})

@@ -178,6 +183,10 @@ class OurTrainingArguments(TrainingArguments):

learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})

report_to: Optional[List[str]] = field(
default_factory=list, metadata={"help": "The list of integrations to report the results and logs to."}
)

def __post_init__(self):
if self.debug_dataset:
self.max_token_length = 16
@@ -186,12 +195,12 @@ def __post_init__(self):


@cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]:
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
self._n_gpu = 0
elif self.local_rank == -1:
# if n_gpu is > 1 we'll use nn.DataParallel.
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -200,15 +209,16 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
# GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
# will use the first GPU in that env, i.e. GPU#1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
self._n_gpu = torch.cuda.device_count()
else:
# Here, we'll use torch.distributed.
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend="nccl")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
device = torch.device("cuda", self.local_rank)
n_gpu = 1
self._n_gpu = 1

if device.type == "cuda":
torch.cuda.set_device(device)

return device, n_gpu
return device
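The new debug_dataset_size field pairs with the existing debug_dataset flag, and __post_init__ still shortens max_token_length in debug mode. A hypothetical construction is sketched below; the output directory and sizes are made up, and it assumes output_dir is the only required base TrainingArguments field.

    from multimodal_exp_args import OurTrainingArguments

    training_args = OurTrainingArguments(
        output_dir="./logs_debug",    # assumed path
        debug_dataset=True,           # run on a truncated dataset with a smaller model
        debug_dataset_size=50,        # new in 0.2-alpha: cap each split at 50 rows
        report_to=[],                 # no third-party logging integrations
    )
    print(training_args.max_token_length)  # __post_init__ sets this to 16 in debug mode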
2 changes: 1 addition & 1 deletion multimodal_transformers/__init__.py
@@ -1,6 +1,6 @@
import multimodal_transformers.data
import multimodal_transformers.model

__version__ = '0.1.2-alpha'
__version__ = '0.2-alpha'

__all__ = ['multimodal_transformers', '__version__']
2 changes: 1 addition & 1 deletion multimodal_transformers/data/data_utils.py
@@ -66,7 +66,7 @@ def change_name_func(x):
def _one_hot(self):
ohe = preprocessing.OneHotEncoder(sparse=False)
ohe.fit(self.df[self.cat_feats].values)
self.feat_names = list(ohe.get_feature_names(self.cat_feats))
self.feat_names = list(ohe.get_feature_names_out(self.cat_feats))
return ohe.transform(self.df[self.cat_feats].values)

def fit_transform(self):
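get_feature_names was dropped from scikit-learn's OneHotEncoder in favor of get_feature_names_out, which the relaxed scikit-learn~=1.0.2 pin already provides. A self-contained sketch with a toy DataFrame (the columns are illustrative only):

    import pandas as pd
    from sklearn import preprocessing

    df = pd.DataFrame({"color": ["red", "blue", "red"], "size": ["S", "M", "L"]})
    cat_feats = ["color", "size"]

    ohe = preprocessing.OneHotEncoder(sparse=False)   # newer releases rename this to sparse_output
    ohe.fit(df[cat_feats].values)
    feat_names = list(ohe.get_feature_names_out(cat_feats))  # was get_feature_names(cat_feats)
    encoded = ohe.transform(df[cat_feats].values)
    print(feat_names)  # ['color_blue', 'color_red', 'size_L', 'size_M', 'size_S']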
25 changes: 17 additions & 8 deletions multimodal_transformers/data/load_data.py
@@ -35,7 +35,8 @@ def load_data_into_folds(data_csv_path,
empty_text_values=None,
replace_empty_text=None,
max_token_length=None,
debug=False
debug=False,
debug_dataset_size=100
):
"""
Function to load tabular and text data from a specified folder into folds
@@ -114,7 +115,8 @@
empty_text_values,
replace_empty_text,
max_token_length,
debug)
debug,
debug_dataset_size)
train_splits.append(train)
val_splits.append(val)
test_splits.append(test)
@@ -136,6 +138,7 @@ def load_data_from_folder(folder_path,
replace_empty_text=None,
max_token_length=None,
debug=False,
debug_dataset_size=100
):
"""
Function to load tabular and text data from a specified folder
@@ -205,7 +208,8 @@
empty_text_values,
replace_empty_text,
max_token_length,
debug)
debug,
debug_dataset_size)


def load_train_val_test_helper(train_df,
@@ -223,7 +227,8 @@ def load_train_val_test_helper(train_df,
empty_text_values=None,
replace_empty_text=None,
max_token_length=None,
debug=False):
debug=False,
debug_dataset_size=100):
if categorical_encode_type == 'ohe' or categorical_encode_type == 'binary':
dfs = [df for df in [train_df, val_df, test_df] if df is not None]
data_df = pd.concat(dfs, axis=0)
@@ -272,7 +277,8 @@ def load_train_val_test_helper(train_df,
empty_text_values,
replace_empty_text,
max_token_length,
debug
debug,
debug_dataset_size
)
test_dataset = load_data(test_df,
text_cols,
@@ -287,7 +293,8 @@ def load_train_val_test_helper(train_df,
empty_text_values,
replace_empty_text,
max_token_length,
debug
debug,
debug_dataset_size
)

if val_df is not None:
@@ -304,7 +311,8 @@ def load_train_val_test_helper(train_df,
empty_text_values,
replace_empty_text,
max_token_length,
debug
debug,
debug_dataset_size
)
else:
val_dataset = None
@@ -326,6 +334,7 @@ def load_data(data_df,
replace_empty_text=None,
max_token_length=None,
debug=False,
debug_dataset_size=100
):
"""Function to load a single dataset given a pandas DataFrame
@@ -370,7 +379,7 @@
:obj:`tabular_torch_dataset.TorchTextDataset`: The converted dataset
"""
if debug:
data_df = data_df[:500]
data_df = data_df[:debug_dataset_size]
if empty_text_values is None:
empty_text_values = ['nan', 'None']

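debug_dataset_size is threaded from the fold loader down to load_data, which previously hard-coded a 500-row cut. A hypothetical call is sketched below; only debug and debug_dataset_size come from this change, while the folder layout, column names, and the remaining keyword names follow the library's usual loader interface and should be treated as assumptions.

    from transformers import AutoTokenizer
    from multimodal_transformers.data.load_data import load_data_from_folder

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    train_ds, val_ds, test_ds = load_data_from_folder(
        "./datasets/Womens_Clothing_E-Commerce_Reviews",  # assumed to contain train/val/test CSVs
        text_cols=["Title", "Review Text"],               # assumed column names
        tokenizer=tokenizer,
        label_col="Recommended IND",                      # assumed label column
        categorical_cols=["Division Name"],               # assumed
        numerical_cols=["Age"],                           # assumed
        debug=True,
        debug_dataset_size=50,                            # each split truncated to its first 50 rows
    )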
4 changes: 2 additions & 2 deletions multimodal_transformers/model/tabular_combiner.py
@@ -265,7 +265,7 @@ def __init__(self, tabular_config):
self.numerical_feat_dim,
division=self.mlp_division,
output_dim=output_dim_num)
self.cat_mlp = MLP(
self.num_mlp = MLP(
self.numerical_feat_dim,
output_dim_num,
num_hidden_lyr=len(dims),
@@ -406,7 +406,7 @@ def forward(self, text_feats, cat_feats=None, numerical_feats=None):
if self.numerical_feat_dim > self.text_out_dim:
numerical_feats = self.num_mlp(numerical_feats)
w_num = torch.mm(numerical_feats, self.weight_num)
g_num = (torch.cat([w_text, w_cat], dim=-1) * self.weight_a).sum(dim=1).unsqueeze(0).T
g_num = (torch.cat([w_text, w_num], dim=-1) * self.weight_a).sum(dim=1).unsqueeze(0).T
else:
w_num = None
g_num = torch.zeros(0, device=g_text.device)
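The second tabular_combiner.py fix makes the numerical gate attend over [w_text, w_num] instead of [w_text, w_cat] (and builds w_num from a dedicated num_mlp). The toy sketch below reproduces just that line with made-up dimensions and random weights to show the shapes involved.

    import torch

    batch, text_dim, num_dim = 4, 8, 8
    w_text = torch.randn(batch, text_dim)          # projected text features
    numerical_feats = torch.randn(batch, num_dim)
    weight_num = torch.randn(num_dim, text_dim)    # learned projection for numerical features
    weight_a = torch.randn(2 * text_dim)           # learned attention vector over [text, num]

    w_num = torch.mm(numerical_feats, weight_num)
    # Gate for the numerical branch, now computed from w_num rather than w_cat:
    g_num = (torch.cat([w_text, w_num], dim=-1) * weight_a).sum(dim=1).unsqueeze(0).T
    print(g_num.shape)  # torch.Size([4, 1]) -- one gating scalar per example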
2 changes: 1 addition & 1 deletion multimodal_transformers/model/tabular_modeling_auto.py
@@ -1,7 +1,7 @@
from collections import OrderedDict

from transformers.configuration_utils import PretrainedConfig
from transformers.configuration_auto import (
from transformers import (
AutoConfig,
AlbertConfig,
BertConfig,
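The old transformers.configuration_auto module path was removed in the 4.x reorganization (it now lives under transformers.models.auto), so the config classes are imported from the package root instead. A quick check of the new import path:

    from transformers import AutoConfig, AlbertConfig, BertConfig

    config = AutoConfig.from_pretrained("bert-base-uncased")
    print(type(config).__name__)  # BertConfig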