Skip to content

Commit

Permalink
Backport PR #2870 on branch 1.1.x (feat(external): add return_logits …
Browse files Browse the repository at this point in the history
…to solo predict for reproducibility) (#2871)
  • Loading branch information
martinkim0 authored Jun 30, 2024
1 parent 714e19a commit 05dc15b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
6 changes: 6 additions & 0 deletions docs/release_notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/

### 1.1.4 (unreleased)

#### Added

- Add argument `return_logits` to {meth}`scvi.external.SOLO.predict` that allows returning logits
    instead of probabilities when passing in `soft=True`, to replicate the buggy behavior prior
    to v1.1.3 {pr}`2870`.

### 1.1.3 (2024-06-26)

#### Fixed
Expand Down
16 changes: 12 additions & 4 deletions scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,12 @@ def train(
return runner()

@torch.inference_mode()
def predict(self, soft: bool = True, include_simulated_doublets: bool = False) -> pd.DataFrame:
def predict(
self,
soft: bool = True,
include_simulated_doublets: bool = False,
return_logits: bool = False,
) -> pd.DataFrame:
"""Return doublet predictions.
Parameters
Expand All @@ -395,6 +400,8 @@ def predict(self, soft: bool = True, include_simulated_doublets: bool = False) -
Return probabilities instead of class label.
include_simulated_doublets
Return probabilities for simulated doublets as well.
return_logits
Whether to return logits instead of probabilities when ``soft`` is ``True``.
Returns
-------
Expand All @@ -403,7 +410,8 @@ def predict(self, soft: bool = True, include_simulated_doublets: bool = False) -
warnings.warn(
"Prior to scvi-tools 1.1.3, `SOLO.predict` with `soft=True` (the default option) "
"returned logits instead of probabilities. This behavior has since been corrected to "
"return probabiltiies.",
"return probabiltiies. The previous behavior can be replicated by passing in "
"`return_logits=True`.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
Expand All @@ -417,8 +425,8 @@ def auto_forward(module, x):

y_pred = []
for _, tensors in enumerate(scdl):
x = tensors[REGISTRY_KEYS.X_KEY]
pred = torch.nn.functional.softmax(auto_forward(self.module, x), dim=-1)
pred = auto_forward(self.module, tensors[REGISTRY_KEYS.X_KEY])
pred = torch.nn.functional.softmax(pred, dim=-1) if not return_logits else pred
y_pred.append(pred.cpu())

y_pred = torch.cat(y_pred).numpy()
Expand Down
10 changes: 7 additions & 3 deletions tests/external/test_solo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@


@pytest.mark.parametrize("soft", [True, False])
def test_solo(soft: bool):
@pytest.mark.parametrize("return_logits", [True, False])
def test_solo(soft: bool, return_logits: bool):
n_latent = 5
adata = synthetic_iid()
SCVI.setup_anndata(adata)
Expand All @@ -17,11 +18,14 @@ def test_solo(soft: bool):
solo = SOLO.from_scvi_model(model)
solo.train(1, check_val_every_n_epoch=1, train_size=0.9)
assert "validation_loss" in solo.history.keys()
predictions = solo.predict(soft=soft)
predictions = solo.predict(soft=soft, return_logits=return_logits)
if soft:
preds = predictions.to_numpy()
assert preds.shape == (adata.n_obs, 2)
assert np.allclose(preds.sum(axis=-1), 1)
if not return_logits:
assert np.allclose(preds.sum(axis=-1), 1)
else:
assert not np.allclose(preds.sum(axis=-1), 1)

bdata = synthetic_iid()
solo = SOLO.from_scvi_model(model, bdata)
Expand Down

0 comments on commit 05dc15b

Please sign in to comment.