
Commit 35ab0db
updates to fix various small model-specific bugs so that all models can be trained with run_expt.py
jpgard committed Aug 20, 2023
1 parent de58de0 commit 35ab0db
Showing 5 changed files with 37 additions and 12 deletions.
5 changes: 4 additions & 1 deletion tableshift/models/compat.py
@@ -140,7 +140,10 @@ def is_domain_adaptation_model_name(model_name: str) -> bool:
 def is_pytorch_model_name(model: str) -> bool:
     """Helper function to determine whether a model name is a pytorch model.
-    ISee description of is_pytorch_model() above."""
+    See description of is_pytorch_model() above."""
+    if model == "catboost":
+        logging.warning("Catboost models are not supported in Ray hyperparameter training."
+                        " Instead, use the provided catboost-specific script.")
     is_sklearn = model in SKLEARN_MODEL_NAMES
     is_pt = model in PYTORCH_MODEL_NAMES
     assert is_sklearn or is_pt, f"unknown model name {model}"
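For context, a minimal sketch of how the new guard behaves (an illustration, not part of the commit; it assumes tableshift is importable and that "catboost" is registered in SKLEARN_MODEL_NAMES, so the assert passes):

    import logging
    from tableshift.models.compat import is_pytorch_model_name

    logging.basicConfig(level=logging.WARNING)

    # Emits the new warning, then falls through to the membership checks;
    # presumably returns False, since catboost is an sklearn-style model.
    is_pytorch_model_name("catboost")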
17 changes: 9 additions & 8 deletions tableshift/models/default_hparams.py
@@ -140,22 +140,23 @@


 def get_default_config(model: str, dset: TabularDataset) -> dict:
-    """Get a default config for a model by name."""
+    """Get a default config for a model, by name."""
     config = _DEFAULT_CONFIGS.get(model, {})
+    model_is_pt = is_pytorch_model_name(model)

     d_in = dset.X_shape[1]
-    if is_pytorch_model_name(model) and model != "ft_transformer":
+    if model_is_pt and model != "ft_transformer":
         config.update({"d_in": d_in,
                        "activation": "ReLU"})
-    elif is_pytorch_model_name(model):
+    elif model_is_pt:
         config.update({"n_num_features": d_in})

     if model in ("tabtransformer", "saint"):
         cat_idxs = dset.cat_idxs
         config["cat_idxs"] = cat_idxs
         config["categories"] = [2] * len(cat_idxs)

-    # Models that use non-cross-entropy training objectives.
+    # Set the training objective and any associated hyperparameters.
     if model == "dro":
         config["criterion"] = DROLoss(size=config["size"],
                                       reg=config["reg"],
@@ -170,10 +171,10 @@ def get_default_config(model: str, dset: TabularDataset) -> dict:
config["criterion"] = GroupDROLoss(n_groups=2)


else:
elif model_is_pt:
config["criterion"] = F.binary_cross_entropy_with_logits

if is_pytorch_model_name(model) and model != "dann":
if model_is_pt and model != "dann":
# Note: for DANN model, lr and weight decay are set separately for D
# and G.
config.update({"lr": 0.01,
@@ -182,9 +183,9 @@ def get_default_config(model: str, dset: TabularDataset) -> dict:

# Do not overwrite batch size or epochs if they are set in the default
# config for the model.
if "batch_size" not in config:
if "batch_size" not in config and model_is_pt:
config["batch_size"] = DEFAULT_BATCH_SIZE
if "n_epochs" not in config:
if "n_epochs" not in config and model_is_pt:
config["n_epochs"] = 1

if model == "saint" and d_in > 100:
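The net effect is that the training criterion, batch size, and epoch defaults are now injected only for PyTorch models. A minimal sketch of the new behavior (not from the commit; `dset` must be a TabularDataset built elsewhere, and the membership of "mlp" and "xgb" in the PyTorch/sklearn name lists is an assumption):

    from tableshift.models.default_hparams import get_default_config

    def show_gating(dset) -> None:
        pt_config = get_default_config("mlp", dset)
        # PyTorch models still receive the training defaults:
        assert "criterion" in pt_config and "batch_size" in pt_config

        sk_config = get_default_config("xgb", dset)
        # Non-PyTorch models no longer have criterion/batch_size/n_epochs
        # injected (assuming xgb's _DEFAULT_CONFIGS entry does not set them):
        assert "criterion" not in sk_config and "batch_size" not in sk_config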
6 changes: 5 additions & 1 deletion tableshift/models/torchutils.py
@@ -69,12 +69,16 @@ def apply_model(model: torch.nn.Module, x):


 @torch.no_grad()
-def get_predictions_and_labels(model, loader, device, as_logits=False) -> Tuple[
+def get_predictions_and_labels(model, loader, device=None, as_logits=False) -> Tuple[
     np.ndarray, np.ndarray]:
     """Get the predictions (as logits, or probabilities) and labels."""
     prediction = []
     label = []

+    if not device:
+        device = f"cuda:{torch.cuda.current_device()}" \
+            if torch.cuda.is_available() else "cpu"
+
     modelname = model.__class__.__name__
     for batch in tqdm(loader, desc=f"{modelname}:getpreds"):
         batch_x, batch_y, _, _ = unpack_batch(batch)
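The fallback added here (and mirrored in _train_pytorch in training.py below) is the usual "current CUDA device if available, else CPU" idiom. Extracted as a standalone helper for illustration (the helper name is not part of the codebase):

    import torch

    def default_device() -> str:
        # Current CUDA device if one is available, otherwise the CPU.
        if torch.cuda.is_available():
            return f"cuda:{torch.cuda.current_device()}"
        return "cpu"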
5 changes: 4 additions & 1 deletion tableshift/models/training.py
@@ -120,10 +120,13 @@ def get_eval_loaders(


 def _train_pytorch(estimator: SklearnStylePytorchModel, dset: TabularDataset,
-                   device: str,
                    config=PYTORCH_DEFAULTS,
+                   device: str = None,
                    tune_report_split: str = None):
     """Helper function to train a pytorch estimator."""
+    if not device:
+        device = f"cuda:{torch.cuda.current_device()}" \
+            if torch.cuda.is_available() else "cpu"
     logging.debug(f"config is {config}")
     logging.debug(f"estimator is of type {type(estimator)}")
     logging.debug(f"dset name is {dset.name}")
16 changes: 15 additions & 1 deletion tableshift/models/utils.py
@@ -22,7 +22,21 @@
 from tableshift.models.wcs import WeightedCovariateShiftClassifier


-def get_estimator(model, d_out=1, **kwargs):
+def get_estimator(model: str, d_out=1, **kwargs):
+    """
+    Fetch an estimator for training.
+    Args:
+        model: the string name of the model to use.
+        d_out: output dimension of the model (set to 1 for binary
+            classification).
+        kwargs: named arguments to pass to the model's class constructor.
+            These vary by model; see below for details. Only the kwargs
+            accepted by the model's class constructor are used; any others
+            are ignored.
+    Returns:
+        An instance of the class specified by the `model` string, with any
+        hyperparameters set according to kwargs.
+    """
     if model == "aldro":
         assert d_out == 1, "assume binary classification."
         return AdversarialLabelDROModel(
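Taken together with get_default_config above, the documented flow is roughly the following (a sketch based only on the signatures shown in this diff; the pairing of the two helpers and the "mlp" model name are assumptions):

    from tableshift.models.default_hparams import get_default_config
    from tableshift.models.utils import get_estimator

    def build_estimator(dset, model: str = "mlp"):
        config = get_default_config(model, dset)
        # Per the new docstring, kwargs the constructor does not accept are
        # ignored, so the full default config can be passed straight through.
        return get_estimator(model, d_out=1, **config)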
