diff --git a/tableshift/models/compat.py b/tableshift/models/compat.py index 5e8712ebc7..7f4f2e3d28 100644 --- a/tableshift/models/compat.py +++ b/tableshift/models/compat.py @@ -140,7 +140,10 @@ def is_domain_adaptation_model_name(model_name: str) -> bool: def is_pytorch_model_name(model: str) -> bool: """Helper function to determine whether a model name is a pytorch model. - ISee description of is_pytorch_model() above.""" + See description of is_pytorch_model() above.""" + if model=="catboost": + logging.warning("Catboost models are not suported in Ray hyperparameter training." + " Instead, use the provided catboost-specific script.") is_sklearn = model in SKLEARN_MODEL_NAMES is_pt = model in PYTORCH_MODEL_NAMES assert is_sklearn or is_pt, f"unknown model name {model}" diff --git a/tableshift/models/default_hparams.py b/tableshift/models/default_hparams.py index d31782ccf1..27e8d79a47 100644 --- a/tableshift/models/default_hparams.py +++ b/tableshift/models/default_hparams.py @@ -140,14 +140,15 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: - """Get a default config for a model by name.""" + """Get a default config for a model, by name.""" config = _DEFAULT_CONFIGS.get(model, {}) + model_is_pt = is_pytorch_model_name(model) d_in = dset.X_shape[1] - if is_pytorch_model_name(model) and model != "ft_transformer": + if model_is_pt and model != "ft_transformer": config.update({"d_in": d_in, "activation": "ReLU"}) - elif is_pytorch_model_name(model): + elif model_is_pt: config.update({"n_num_features": d_in}) if model in ("tabtransformer", "saint"): @@ -155,7 +156,7 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: config["cat_idxs"] = cat_idxs config["categories"] = [2] * len(cat_idxs) - # Models that use non-cross-entropy training objectives. + # Set the training objective and any associated hyperparameters. 
if model == "dro": config["criterion"] = DROLoss(size=config["size"], reg=config["reg"], @@ -170,10 +171,10 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: config["criterion"] = GroupDROLoss(n_groups=2) - else: + elif model_is_pt: config["criterion"] = F.binary_cross_entropy_with_logits - if is_pytorch_model_name(model) and model != "dann": + if model_is_pt and model != "dann": # Note: for DANN model, lr and weight decay are set separately for D # and G. config.update({"lr": 0.01, @@ -182,9 +183,9 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: # Do not overwrite batch size or epochs if they are set in the default # config for the model. - if "batch_size" not in config: + if "batch_size" not in config and model_is_pt: config["batch_size"] = DEFAULT_BATCH_SIZE - if "n_epochs" not in config: + if "n_epochs" not in config and model_is_pt: config["n_epochs"] = 1 if model == "saint" and d_in > 100: diff --git a/tableshift/models/torchutils.py b/tableshift/models/torchutils.py index c960d5d152..a2d30e7c88 100644 --- a/tableshift/models/torchutils.py +++ b/tableshift/models/torchutils.py @@ -69,12 +69,16 @@ def apply_model(model: torch.nn.Module, x): @torch.no_grad() -def get_predictions_and_labels(model, loader, device, as_logits=False) -> Tuple[ +def get_predictions_and_labels(model, loader, device=None, as_logits=False) -> Tuple[ np.ndarray, np.ndarray]: """Get the predictions (as logits, or probabilities) and labels.""" prediction = [] label = [] + if not device: + device = f"cuda:{torch.cuda.current_device()}" \ + if torch.cuda.is_available() else "cpu" + modelname = model.__class__.__name__ for batch in tqdm(loader, desc=f"{modelname}:getpreds"): batch_x, batch_y, _, _ = unpack_batch(batch) diff --git a/tableshift/models/training.py b/tableshift/models/training.py index 7f20f5f9aa..f129481e0b 100644 --- a/tableshift/models/training.py +++ b/tableshift/models/training.py @@ -120,10 +120,13 @@ def get_eval_loaders( def 
_train_pytorch(estimator: SklearnStylePytorchModel, dset: TabularDataset, - device: str, config=PYTORCH_DEFAULTS, + device: str=None, tune_report_split: str = None): """Helper function to train a pytorch estimator.""" + if not device: + device = f"cuda:{torch.cuda.current_device()}" \ + if torch.cuda.is_available() else "cpu" logging.debug(f"config is {config}") logging.debug(f"estimator is of type {type(estimator)}") logging.debug(f"dset name is {dset.name}") diff --git a/tableshift/models/utils.py b/tableshift/models/utils.py index 95dcb87357..cda392692b 100644 --- a/tableshift/models/utils.py +++ b/tableshift/models/utils.py @@ -22,7 +22,21 @@ from tableshift.models.wcs import WeightedCovariateShiftClassifier -def get_estimator(model, d_out=1, **kwargs): +def get_estimator(model:str, d_out=1, **kwargs): + """ + Fetch an estimator for training. + + Args: + model: the string name of the model to use. + d_out: output dimension of the model (set to 1 for binary classification). + kwargs: named arguments to pass to the model's class constructor. These + vary by model; for more details see below. Note that only a specific + subset of the kwargs will be used; passing arbitrary kwargs not accepted by + the model's class constructor will result in those kwargs being ignored. + Returns: + An instance of the class specified by the `model` string, with + any hyperparameters set according to kwargs. + """ if model == "aldro": assert d_out == 1, "assume binary classification." return AdversarialLabelDROModel(