Skip to content

Commit

Permalink
fix again for from __future__ import annotations
Browse files Browse the repository at this point in the history
and fix the test for custom dataloaders
  • Loading branch information
ori-kron-wis committed Aug 11, 2024
1 parent c0889d8 commit 8fe043c
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 44 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test_linux_custom_dataloader.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ jobs:
python -m uv pip install --system "scvi-tools[tests] @ ."
- name: Install Specific Branch of Repository
env:
GH_TOKEN: ${{ secrets.GH_TOKEN }}
run: |
git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/"
pip install git+https://github.com/ebezzi/cellxgene-census.git@census-scvi-datamodule
- name: Run specific custom dataloader pytest
Expand Down
29 changes: 15 additions & 14 deletions src/scvi/model/_amortizedlda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import collections.abc
import logging
from collections.abc import Sequence
from typing import Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,8 +61,8 @@ def __init__(
registry: dict | None = None,
n_topics: int = 20,
n_hidden: int = 128,
cell_topic_prior: Optional[Union[float, Sequence[float]]] = None,
topic_feature_prior: Optional[Union[float, Sequence[float]]] = None,
cell_topic_prior: float | Sequence[float] = None,
topic_feature_prior: float | Sequence[float] = None,
):
# in case any other model was created before that shares the same parameter names.
pyro.clear_param_store()
Expand Down Expand Up @@ -110,9 +111,9 @@ def __init__(
def setup_anndata(
cls,
adata: AnnData,
layer: Optional[str] = None,
layer: None | str = None,
**kwargs,
) -> Optional[AnnData]:
) -> None | AnnData:
"""%(summary)s.
Parameters
Expand Down Expand Up @@ -155,9 +156,9 @@ def get_feature_by_topic(self, n_samples=5000) -> pd.DataFrame:

def get_latent_representation(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
adata: None | AnnData = None,
indices: None | Sequence[int] = None,
batch_size: None | int = None,
n_samples: int = 5000,
) -> pd.DataFrame:
"""Converts a count matrix to an inferred topic distribution.
Expand Down Expand Up @@ -198,9 +199,9 @@ def get_latent_representation(

def get_elbo(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
adata: None | AnnData = None,
indices: None | Sequence[int] = None,
batch_size: None | int = None,
) -> float:
"""Return the ELBO for the data.
Expand Down Expand Up @@ -235,9 +236,9 @@ def get_elbo(

def get_perplexity(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
adata: None | AnnData = None,
indices: None | Sequence[int] = None,
batch_size: None | int = None,
) -> float:
"""Computes approximate perplexity for `adata`.
Expand Down
24 changes: 12 additions & 12 deletions src/scvi/model/_autozi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import Literal, Optional, Union
from typing import Literal

import numpy as np
import torch
Expand Down Expand Up @@ -106,8 +108,8 @@ def __init__(
dropout_rate: float = 0.1,
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
latent_distribution: Literal["normal", "ln"] = "normal",
alpha_prior: Optional[float] = 0.5,
beta_prior: Optional[float] = 0.5,
alpha_prior: None | float = 0.5,
beta_prior: None | float = 0.5,
minimal_dropout: float = 0.01,
zero_inflation: str = "gene",
use_observed_lib_size: bool = True,
Expand Down Expand Up @@ -147,19 +149,17 @@ def __init__(
)
self.init_params_ = self._get_init_params(locals())

def get_alphas_betas(
self, as_numpy: bool = True
) -> dict[str, Union[torch.Tensor, np.ndarray]]:
def get_alphas_betas(self, as_numpy: bool = True) -> dict[str, torch.Tensor | np.ndarray]:
"""Return parameters of Bernoulli Beta distributions in a dictionary."""
return self.module.get_alphas_betas(as_numpy=as_numpy)

@torch.inference_mode()
def get_marginal_ll(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
adata: None | AnnData = None,
indices: None | Sequence[int] = None,
n_mc_samples: int = 1000,
batch_size: Optional[int] = None,
batch_size: None | int = None,
) -> float:
"""Return the marginal LL for the data.
Expand Down Expand Up @@ -262,9 +262,9 @@ def get_marginal_ll(
def setup_anndata(
cls,
adata: AnnData,
batch_key: Optional[str] = None,
labels_key: Optional[str] = None,
layer: Optional[str] = None,
batch_key: None | str = None,
labels_key: None | str = None,
layer: None | str = None,
**kwargs,
):
"""%(summary)s.
Expand Down
14 changes: 8 additions & 6 deletions src/scvi/model/_jaxscvi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import Literal, Optional
from typing import Literal

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -83,8 +85,8 @@ def __init__(
def setup_anndata(
cls,
adata: AnnData,
layer: Optional[str] = None,
batch_key: Optional[str] = None,
layer: None | str = None,
batch_key: None | str = None,
**kwargs,
):
"""%(summary)s.
Expand All @@ -106,11 +108,11 @@ def setup_anndata(

def get_latent_representation(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
adata: None | AnnData = None,
indices: None | Sequence[int] = None,
give_mean: bool = True,
n_samples: int = 1,
batch_size: Optional[int] = None,
batch_size: None | int = None,
) -> np.ndarray:
r"""Return the latent representation for each cell.
Expand Down
10 changes: 6 additions & 4 deletions src/scvi/model/_linear_scvi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import logging
from typing import Literal, Optional
from typing import Literal

import pandas as pd
from anndata import AnnData
Expand Down Expand Up @@ -128,9 +130,9 @@ def get_loadings(self) -> pd.DataFrame:
def setup_anndata(
cls,
adata: AnnData,
batch_key: Optional[str] = None,
labels_key: Optional[str] = None,
layer: Optional[str] = None,
batch_key: None | str = None,
labels_key: None | str = None,
layer: None | str = None,
**kwargs,
):
"""%(summary)s.
Expand Down
15 changes: 8 additions & 7 deletions src/scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
import warnings
from copy import deepcopy
from typing import Optional, Union

import anndata
import numpy as np
Expand Down Expand Up @@ -39,11 +40,11 @@ class ArchesMixin:
def load_query_data(
cls,
adata: AnnOrMuData = None,
reference_model: Union[str, BaseModelClass] = None,
reference_model: str | BaseModelClass = None,
registry: dict = None,
inplace_subset_query_vars: bool = False,
accelerator: str = "auto",
device: Union[int, str] = "auto",
device: int | str = "auto",
unfrozen: bool = False,
freeze_dropout: bool = False,
freeze_expression: bool = True,
Expand Down Expand Up @@ -187,10 +188,10 @@ def load_query_data(
@staticmethod
def prepare_query_anndata(
adata: AnnData,
reference_model: Union[str, BaseModelClass],
reference_model: str | BaseModelClass,
return_reference_var_names: bool = False,
inplace: bool = True,
) -> Optional[Union[AnnData, pd.Index]]:
) -> AnnData | pd.Index:
"""Prepare data for query integration.
This function will return a new AnnData object with padded zeros
Expand Down Expand Up @@ -226,10 +227,10 @@ def prepare_query_anndata(
@staticmethod
def prepare_query_mudata(
mdata: MuData,
reference_model: Union[str, BaseModelClass],
reference_model: str | BaseModelClass,
return_reference_var_names: bool = False,
inplace: bool = True,
) -> Optional[Union[MuData, dict[str, pd.Index]]]:
) -> None | MuData | dict[str, pd.Index]:
"""Prepare multimodal dataset for query integration.
This function will return a new MuData object such that the
Expand Down
2 changes: 1 addition & 1 deletion tests/dataloaders/test_custom_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.custom.dataloader
def test_custom_dataloader(save_path):
def custom_dataloader_test(save_path):
# this test checks the local custom dataloder made by CZI and run several tests with it
census = cellxgene_census.open_soma(census_version="stable")

Expand Down

0 comments on commit 8fe043c

Please sign in to comment.