Skip to content

Commit

Permalink
Merge branch 'main' into ori-2907-custom-dataloader-registry
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis authored Sep 15, 2024
2 parents 083c76e + f8811ad commit 70bba69
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Add support for categorial covariates in scArches in `scvi.model.archesmixin` {pr}`2936`.
- Add assertion error in cellAssign for checking duplicates in celltype markers {pr}`2951`.
- Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`.
- {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mudata import MuData
from scipy.sparse import csr_matrix

from scvi import REGISTRY_KEYS, settings
from scvi import settings
from scvi._types import AnnOrMuData
from scvi.data import _constants
from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME
Expand Down Expand Up @@ -146,6 +146,7 @@ def load_query_data(
)

version_split = model.registry[_constants._SCVI_VERSION_KEY].split(".")

if int(version_split[1]) < 8 and int(version_split[0]) == 0:
warnings.warn(
"Query integration should be performed using models trained with "
Expand Down
5 changes: 2 additions & 3 deletions tests/model/test_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def test_scanvi_online_update(save_path):
model.get_latent_representation()
model.predict()

# Test error on extra categoricals
# Test on extra categoricals as well
adata1 = synthetic_iid()
new_labels = adata1.obs.labels.to_numpy()
new_labels[0] = "Unknown"
Expand Down Expand Up @@ -474,8 +474,7 @@ def test_scanvi_online_update(save_path):
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],))
adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],))
with pytest.raises(NotImplementedError):
SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)

# ref has fully-observed labels
n_latent = 5
Expand Down
184 changes: 184 additions & 0 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,190 @@ def test_scarches_data_prep(save_path):
SCVI.load_query_data(adata5, dir_path)


def test_scarches_data_prep_with_categorial_covariates(save_path):
n_latent = 5
num_categ_orig = 5
adata1 = synthetic_iid()
adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata1.shape[0],))
SCVI.setup_anndata(
adata1,
batch_key="batch",
labels_key="labels",
continuous_covariate_keys=["cont1"],
categorical_covariate_keys=["cat1"],
)
model = SCVI(adata1, n_latent=n_latent)
model.train(1, check_val_every_n_epoch=1)
dir_path = os.path.join(save_path, "saved_model/")
model.save(dir_path, overwrite=True)

# adata2 has more genes and a perfect subset of adata1, buא missing the categ cov
adata2 = synthetic_iid(n_genes=110)
adata2.layers["counts"] = adata2.X.copy()
new_var_names_init = [f"Random {i}" for i in range(10)]
new_var_names = new_var_names_init + adata2.var_names[10:].to_list()
adata2.var_names = new_var_names
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
# adata2 has more genes and missing 10 genes from adata1
SCVI.prepare_query_anndata(adata2, dir_path) # see here how those extra genes were removed
with pytest.raises(KeyError):
SCVI.load_query_data(adata2, dir_path)
# model2 = SCVI(adata2, n_latent=n_latent)
# model2.train(1, check_val_every_n_epoch=1)

adata3 = SCVI.prepare_query_anndata(adata2, dir_path, inplace=False)
with pytest.raises(KeyError):
SCVI.load_query_data(adata3, dir_path)
# model3 = SCVI(adata3, n_latent=n_latent)
# model3.train(1, check_val_every_n_epoch=1)

# try the opposite - with a the categ covariate - raise the error
# adata4 has more genes and a perfect subset of adata1
adata4 = synthetic_iid(n_genes=110)
adata4.obs["batch"] = adata4.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata4.obs["cont1"] = np.random.normal(size=(adata4.shape[0],))
adata4.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata4.shape[0],))
SCVI.prepare_query_anndata(adata4, dir_path)
SCVI.load_query_data(adata4, dir_path)
model4 = SCVI(adata4, n_latent=n_latent)
model4.train(1, check_val_every_n_epoch=1)
model4.get_latent_representation()
model4.get_elbo()

adata5 = SCVI.prepare_query_anndata(adata4, dir_path, inplace=False)
SCVI.load_query_data(adata5, dir_path)
model5 = SCVI(adata5, n_latent=n_latent)
model5.train(1, check_val_every_n_epoch=1)
model5.get_latent_representation()
model5.get_elbo()

# try also different categ - it expects cat1
adata6 = synthetic_iid(n_genes=110)
adata6.obs["batch"] = adata6.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata6.obs["cont2"] = np.random.normal(size=(adata6.shape[0],))
adata6.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata6.shape[0],))
SCVI.prepare_query_anndata(adata6, dir_path)
with pytest.raises(KeyError):
SCVI.load_query_data(adata6, dir_path)
# model6 = SCVI(adata6, n_latent=n_latent)
# model6.train(1, check_val_every_n_epoch=1)

# try only cont - missing the categ cov
adata7 = synthetic_iid(n_genes=110)
adata7.obs["batch"] = adata7.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata7.obs["cont2"] = np.random.normal(size=(adata7.shape[0],))
SCVI.prepare_query_anndata(adata7, dir_path)
with pytest.raises(KeyError):
SCVI.load_query_data(adata7, dir_path)
# model7 = SCVI(adata7, n_latent=n_latent)
# model7.train(1, check_val_every_n_epoch=1)

# try also additional categ cov - it expects cont1
adata8 = synthetic_iid(n_genes=110)
adata8.obs["batch"] = adata8.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata8.obs["cont2"] = np.random.normal(size=(adata8.shape[0],))
adata8.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata8.shape[0],))
adata8.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata8.shape[0],))
SCVI.prepare_query_anndata(adata8, dir_path)
with pytest.raises(KeyError):
SCVI.load_query_data(adata8, dir_path)
# model8 = SCVI(adata8, n_latent=n_latent)
# model8.train(1, check_val_every_n_epoch=1)

# try also additional categ cov - it works
adata9 = synthetic_iid(n_genes=110)
adata9.obs["batch"] = adata9.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata9.obs["cont1"] = np.random.normal(size=(adata9.shape[0],))
adata9.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata9.shape[0],))
adata9.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata9.shape[0],))
SCVI.prepare_query_anndata(adata9, dir_path)
SCVI.load_query_data(adata9, dir_path)
model9 = SCVI(adata9, n_latent=n_latent)
model9.train(1, check_val_every_n_epoch=1)
model9.get_latent_representation()
model9.get_elbo()

# try also additional cont/categ cov - it works
adata10 = synthetic_iid(n_genes=110)
adata10.obs["batch"] = adata10.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata10.obs["cont1"] = np.random.normal(size=(adata10.shape[0],))
adata10.obs["cont2"] = np.random.normal(size=(adata10.shape[0],))
adata10.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata10.shape[0],))
adata10.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata10.shape[0],))
SCVI.prepare_query_anndata(adata10, dir_path)
SCVI.load_query_data(adata10, dir_path)
model10 = SCVI(adata10, n_latent=n_latent)
model10.train(1, check_val_every_n_epoch=1)
attr_dict, var_names, load_state_dict = scvi.model.base._archesmixin._get_loaded_data(model10)
registry = attr_dict.pop("registry_")
# we validate only relevant covariates were passed - cat2 and cont2 are not used
assert (
len(registry["field_registries"]["extra_categorical_covs"]["state_registry"]["field_keys"])
== 1
)
assert (
len(registry["field_registries"]["extra_continuous_covs"]["state_registry"]["columns"])
== 1
)
assert (
registry["field_registries"]["extra_categorical_covs"]["state_registry"]["field_keys"][0]
== "cat1"
)
assert (
registry["field_registries"]["extra_continuous_covs"]["state_registry"]["columns"][0]
== "cont1"
)
model10.get_latent_representation()
model10.get_elbo()

# try also runing with less categories than needed
num_categ = 4
adata11 = synthetic_iid(n_genes=110)
adata11.obs["batch"] = adata11.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata11.obs["cont1"] = np.random.normal(size=(adata11.shape[0],))
adata11.obs["cat1"] = np.random.randint(0, num_categ, size=(adata11.shape[0],))
SCVI.prepare_query_anndata(adata11, dir_path)
SCVI.load_query_data(adata11, dir_path)
model11 = SCVI(adata11, n_latent=n_latent)
model11.train(1, check_val_every_n_epoch=1)
attr_dict, var_names, load_state_dict = scvi.model.base._archesmixin._get_loaded_data(model11)
registry = attr_dict.pop("registry_")
assert (
registry["field_registries"]["extra_categorical_covs"]["state_registry"]["n_cats_per_key"][
0
]
== num_categ
if num_categ > num_categ_orig
else num_categ_orig
)
model11.get_latent_representation()
model11.get_elbo()

# try also runing with more categories than needed
num_categ = 6
adata12 = synthetic_iid(n_genes=110)
adata12.obs["batch"] = adata12.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata12.obs["cont1"] = np.random.normal(size=(adata12.shape[0],))
adata12.obs["cat1"] = np.random.randint(0, num_categ, size=(adata12.shape[0],))
SCVI.prepare_query_anndata(adata12, dir_path)
SCVI.load_query_data(adata12, dir_path)
model12 = SCVI(adata12, n_latent=n_latent)
model12.train(1, check_val_every_n_epoch=1)
attr_dict, var_names, load_state_dict = scvi.model.base._archesmixin._get_loaded_data(model12)
registry = attr_dict.pop("registry_")
assert (
registry["field_registries"]["extra_categorical_covs"]["state_registry"]["n_cats_per_key"][
0
]
== num_categ
if num_categ > num_categ_orig
else num_categ_orig
)
model12.get_latent_representation()
model12.get_elbo()


def test_scarches_data_prep_layer(save_path):
n_latent = 5
adata1 = synthetic_iid()
Expand Down

0 comments on commit 70bba69

Please sign in to comment.