
Commit 1700df9
implements group dro FT-Transformer
jpgard committed Aug 23, 2023
1 parent 5673aed commit 1700df9
Showing 3 changed files with 91 additions and 2 deletions.
16 changes: 15 additions & 1 deletion tableshift/models/default_hparams.py
@@ -66,6 +66,20 @@
# This is feature embedding size in Table 13 above.
"d_token": 64,
},
"group_dro_ft_transformer":
{"cat_cardinalities": None,
"n_blocks": 1,
"residual_dropout": 0.,
"attention_dropout": 0.,
"ffn_dropout": 0.,
"ffn_factor": 1.,
# This is feature embedding size in Table 13 above.
"d_token": 64,
"num_layers": 2,
"d_hidden": 256,
"group_weights_step_size": 0.05,
"dropouts": 0.
},
"group_dro":
{"num_layers": 2,
"d_hidden": 256,
@@ -154,7 +168,7 @@ def get_default_config(model: str, dset: TabularDataset) -> dict:
model_is_pt = is_pytorch_model_name(model)

d_in = dset.X_shape[1]
-  if model_is_pt and model != "ft_transformer":
+  if model_is_pt and "ft_transformer" not in model:
config.update({"d_in": d_in,
"activation": "ReLU"})
elif model_is_pt:
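Note on the default_hparams.py change: widening the condition to "ft_transformer" not in model routes every model name containing ft_transformer, now including group_dro_ft_transformer, into the FT-Transformer branch, which supplies n_num_features and cat_cardinalities rather than d_in and activation. As a minimal sketch of how the new defaults are consumed (the get_dataset call and the "adult" experiment name are illustrative assumptions, not part of this commit):

from tableshift import get_dataset
from tableshift.models.default_hparams import get_default_config

dset = get_dataset("adult")  # assumed experiment name, for illustration
config = get_default_config("group_dro_ft_transformer", dset)
# config now holds the defaults added above plus model- and
# optimizer-specific keys, e.g. config["d_token"] == 64 and
# config["group_weights_step_size"] == 0.05.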
39 changes: 38 additions & 1 deletion tableshift/models/dro.py
@@ -99,7 +99,44 @@ def train_epoch(self,
train_loader = list(train_loaders.values())[0]

for iteration, batch in tqdm(enumerate(train_loader),
-  desc="groupdroresnet:train"):
+  desc="groupdro-resnet:train"):
x_batch, y_batch, _, d_batch = unpack_batch(batch)
x_batch = x_batch.float().to(device)
y_batch = y_batch.float().to(device)
d_batch = d_batch.float().to(device)
self.train()
self.optimizer.zero_grad()
outputs = apply_model(self, x_batch)
loss = loss_fn(outputs.squeeze(1), y_batch, d_batch,
self.group_weights,
self.group_weights_step_size)

loss.backward()
self.optimizer.step()

class DomainGroupDROFTTransformerModel(FTTransformerModel, SklearnStylePytorchModel):
"""Group DRO with domain labels as groups. (For domain robustness.)"""
def to(self, device):
super().to(device)
for attr in ("group_weights_step_size", "group_weights"):
setattr(self, attr, getattr(self, attr).to(device))
return self

def train_epoch(self,
train_loaders: Dict[Any, DataLoader],
loss_fn: Callable,
device: str,
uda_loader: Optional[DataLoader] = None,
eval_loaders: Optional[Mapping[str, DataLoader]] = None,
# Terminate after this many steps if reached before end
# of epoch.
max_examples_per_epoch: Optional[int] = None
) -> float:
assert len(train_loaders.values()) == 1
train_loader = list(train_loaders.values())[0]

for iteration, batch in tqdm(enumerate(train_loader),
desc="groupdro-fttransformer:train"):
x_batch, y_batch, _, d_batch = unpack_batch(batch)
x_batch = x_batch.float().to(device)
y_batch = y_batch.float().to(device)
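For context, train_epoch above delegates the Group DRO objective to loss_fn, which receives the logits, the labels, the domain labels, the current group weights, and the step size. A minimal sketch of such a loss in the style of Sagawa et al. (2020) follows; it is an illustration under stated assumptions (binary targets as floats, integer-valued group ids), not necessarily the loss_fn shipped in tableshift:

import torch
import torch.nn.functional as F

def group_dro_loss(outputs, targets, group_ids, group_weights,
                   group_weights_step_size):
    # Per-example binary cross-entropy; outputs are logits.
    elementwise = F.binary_cross_entropy_with_logits(
        outputs, targets, reduction="none")
    n_groups = group_weights.shape[0]
    # Mean loss per group; zero for groups absent from the batch.
    per_group = []
    for g in range(n_groups):
        mask = group_ids == g
        per_group.append(elementwise[mask].mean() if mask.any()
                         else torch.tensor(0., device=outputs.device))
    group_losses = torch.stack(per_group)
    # Exponentiated-gradient ascent on the group weights, then
    # renormalization back onto the simplex.
    with torch.no_grad():
        group_weights.data *= torch.exp(
            group_weights_step_size * group_losses)
        group_weights.data /= group_weights.data.sum()
    # The model minimizes the weighted (approximately worst-group) loss;
    # the weights are updated manually above, so detach them here.
    return group_losses @ group_weights.detach()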
38 changes: 38 additions & 0 deletions tableshift/models/utils.py
@@ -4,12 +4,14 @@
import xgboost as xgb
from lightgbm import LGBMClassifier
from sklearn.ensemble import HistGradientBoostingClassifier
import torch

from tableshift.models.compat import OPTIMIZER_ARGS
from tableshift.models.coral import DeepCoralModel, MMDModel
from tableshift.models.dann import DANNModel
from tableshift.models.dro import (DomainGroupDROModel,
DomainGroupDROResNetModel,
DomainGroupDROFTTransformerModel,
AdversarialLabelDROModel,
LabelGroupDROModel)
from tableshift.models.expgrad import ExponentiatedGradient
@@ -106,6 +108,42 @@ def get_estimator(model:str, d_out=1, **kwargs):

return model

elif model == "group_dro_ft_transformer":
tconfig = FTTransformerModel.get_default_transformer_config()

tconfig["last_layer_query_idx"] = [-1]
tconfig["d_out"] = 1
params_to_override = ("n_blocks", "residual_dropout", "d_token",
"attention_dropout", "ffn_dropout")
for k in params_to_override:
tconfig[k] = kwargs[k]

tconfig["ffn_d_hidden"] = int(kwargs["d_token"] * kwargs["ffn_factor"])

# Fixed as in https://arxiv.org/pdf/2106.11959.pdf
tconfig['attention_n_heads'] = 8

# Hacky way to construct an FT-Transformer model
model = DomainGroupDROFTTransformerModel._make(
n_num_features=kwargs["n_num_features"],
cat_cardinalities=kwargs["cat_cardinalities"],
transformer_config=tconfig)
tconfig.update({k: kwargs[k] for k in OPTIMIZER_ARGS})

n_groups = kwargs["n_groups"]
group_weights_step_size = kwargs["group_weights_step_size"]

tconfig.update({"n_groups": n_groups, "group_weights_step_size": group_weights_step_size})
assert n_groups > 0, "require nonzero n_groups."
model.group_weights_step_size = torch.Tensor([group_weights_step_size])
# initialize adversarial weights
model.group_weights = torch.nn.Parameter(
torch.full([n_groups], 1. / n_groups))
model.config = copy.deepcopy(tconfig)
model._init_optimizer()

return model

elif model == "group_dro":
return DomainGroupDROModel(
d_in=kwargs["d_in"],
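Finally, a sketch of constructing the new estimator through get_estimator. The keyword names mirror the keys read in the branch above; the concrete values, and the assumption that OPTIMIZER_ARGS covers lr and weight_decay, are illustrative rather than defaults taken from this commit:

from tableshift.models.utils import get_estimator

model = get_estimator(
    "group_dro_ft_transformer",
    n_num_features=32,          # numeric feature count of the dataset
    cat_cardinalities=None,     # matches the new default hparams
    n_blocks=1,
    residual_dropout=0.,
    attention_dropout=0.,
    ffn_dropout=0.,
    ffn_factor=1.,
    d_token=64,
    n_groups=4,                 # one weight per training domain
    group_weights_step_size=0.05,
    lr=1e-4,                    # assumed OPTIMIZER_ARGS key
    weight_decay=0.,            # assumed OPTIMIZER_ARGS key
)
# model is a DomainGroupDROFTTransformerModel with uniformly initialized
# group weights (1 / n_groups each) and its optimizer already built.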
