Skip to content

Commit

Permalink
Backport PR #2673: chore(train): deprecate savebeststate callback and…
Browse files Browse the repository at this point in the history
… save_best arg
  • Loading branch information
martinkim0 authored and meeseeksmachine committed Apr 3, 2024
1 parent d2e402e commit dbe68e1
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
[cellxgene-census](https://chanzuckerberg.github.io/cellxgene-census/) instead {pr}`2542`.
- Deprecate {func}`scvi.nn.one_hot`, to be removed in v1.3. Please directly use the
`one_hot` function in PyTorch instead {pr}`2608`.
- Deprecate {class}`scvi.train.SaveBestState`, to be removed in v1.3. Please use
{class}`scvi.train.SaveCheckpoint` instead {pr}`2673`.
- Deprecate `save_best` argument in {meth}`scvi.model.PEAKVI.train` and
{meth}`scvi.model.MULTIVI.train`, to be removed in v1.3. Please pass in `enable_checkpointing`
or specify a custom checkpointing procedure with {class}`scvi.train.SaveCheckpoint` instead
{pr}`2673`.

#### Removed

Expand Down
16 changes: 14 additions & 2 deletions scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def train(
early_stopping
Whether to perform early stopping with respect to the validation set.
save_best
Save the best model state with respect to the validation loss, or use the final
state in the training procedure
``DEPRECATED`` Save the best model state with respect to the validation loss, or use
the final state in the training procedure.
check_val_every_n_epoch
Check val every n train epochs. By default, val is not checked, unless `early_stopping`
is `True`. If so, val is checked every epoch.
Expand All @@ -306,6 +306,11 @@ def train(
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
**kwargs
Other keyword args for :class:`~scvi.train.Trainer`.
Notes
-----
``save_best`` is deprecated in v1.2 and will be removed in v1.3. Please use
``enable_checkpointing`` instead.
"""
update_dict = {
"lr": lr,
Expand All @@ -325,6 +330,13 @@ def train(
datasplitter_kwargs = datasplitter_kwargs or {}

if save_best:
warnings.warn(
"`save_best` is deprecated in v1.2 and will be removed in v1.3. Please use "
"`enable_checkpointing` instead. See "
"https://github.com/scverse/scvi-tools/issues/2568 for more details.",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
if "callbacks" not in kwargs.keys():
kwargs["callbacks"] = []
kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))
Expand Down
19 changes: 17 additions & 2 deletions scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterable, Sequence
from functools import partial
from typing import Literal
Expand All @@ -11,6 +12,7 @@
from anndata import AnnData
from scipy.sparse import csr_matrix, vstack

from scvi import settings
from scvi._constants import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
Expand Down Expand Up @@ -198,8 +200,8 @@ def train(
early_stopping_patience
How many epochs to wait for improvement before early stopping
save_best
Save the best model state with respect to the validation loss (default), or use the
final state in the training procedure
``DEPRECATED`` Save the best model state with respect to the validation loss (default),
or use the final state in the training procedure
check_val_every_n_epoch
Check val every n train epochs. By default, val is not checked, unless `early_stopping`
is `True`. If so, val is checked every epoch.
Expand All @@ -217,6 +219,11 @@ def train(
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
**kwargs
Other keyword args for :class:`~scvi.train.Trainer`.
Notes
-----
``save_best`` is deprecated in v1.2 and will be removed in v1.3. Please use
``enable_checkpointing`` instead.
"""
update_dict = {
"lr": lr,
Expand All @@ -231,6 +238,14 @@ def train(
else:
plan_kwargs = update_dict
if save_best:
warnings.warn(
"`save_best` is deprecated in v1.2 and will be removed in v1.3. Please use "
"`enable_checkpointing` instead. See "
"https://github.com/scverse/scvi-tools/issues/2568 for more details.",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)

if "callbacks" not in kwargs.keys():
kwargs["callbacks"] = []
kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))
Expand Down
15 changes: 14 additions & 1 deletion scvi/train/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def on_train_epoch_start(self, trainer, pl_module):


class SaveBestState(Callback):
r"""Save the best module state and restore into model.
r"""``DEPRECATED`` Save the best module state and restore into model.
Parameters
----------
Expand All @@ -170,6 +170,11 @@ class SaveBestState(Callback):
--------
from scvi.train import Trainer
from scvi.train import SaveBestState
Notes
-----
Lifecycle: deprecated in v1.2 and to be removed in v1.3. Please use
:class:`~scvi.train.callbacks.SaveCheckpoint` instead.
"""

def __init__(
Expand All @@ -181,6 +186,14 @@ def __init__(
):
super().__init__()

warnings.warn(
"`SaveBestState` is deprecated in v1.2 and will be removed in v1.3. Please use "
"`SaveCheckpoint` instead. See https://github.com/scverse/scvi-tools/issues/2568 "
"for more details.",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)

self.monitor = monitor
self.verbose = verbose
self.period = period
Expand Down

0 comments on commit dbe68e1

Please sign in to comment.