From f8811ad999d470e9d589520496905ae0328b1402 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Mon, 9 Sep 2024 19:06:21 +0300 Subject: [PATCH] feat(train) Add support for categorial covariates in scArches (#2936) closes https://github.com/scverse/scvi-tools/issues/2583 --------- Co-authored-by: Can Ergen --- CHANGELOG.md | 1 + src/scvi/model/base/_archesmixin.py | 7 +- tests/model/test_scanvi.py | 5 +- tests/model/test_scvi.py | 184 ++++++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d3fee51385..277817beca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- Add support for categorial covariates in scArches in `scvi.model.archesmixin` {pr}`2936`. - Add assertion error in cellAssign for checking duplicates in celltype markers {pr}`2951`. - Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`. - {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 117188f3cf..363be6c0ed 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -11,7 +11,7 @@ from mudata import MuData from scipy.sparse import csr_matrix -from scvi import REGISTRY_KEYS, settings +from scvi import settings from scvi._types import AnnOrMuData from scvi.data import _constants from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME @@ -132,11 +132,6 @@ def load_query_data( model = _initialize_model(cls, adata, attr_dict) adata_manager = model.get_anndata_manager(adata, required=True) - if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry: - raise NotImplementedError( - "scArches currently does not support models with extra categorical covariates." 
- ) - version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: warnings.warn( diff --git a/tests/model/test_scanvi.py b/tests/model/test_scanvi.py index 8a07962716..2bf791e3e2 100644 --- a/tests/model/test_scanvi.py +++ b/tests/model/test_scanvi.py @@ -441,7 +441,7 @@ def test_scanvi_online_update(save_path): model.get_latent_representation() model.predict() - # Test error on extra categoricals + # Test on extra categoricals as well adata1 = synthetic_iid() new_labels = adata1.obs.labels.to_numpy() new_labels[0] = "Unknown" @@ -474,8 +474,7 @@ def test_scanvi_online_update(save_path): adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],)) adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],)) adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],)) - with pytest.raises(NotImplementedError): - SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True) + SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True) # ref has fully-observed labels n_latent = 5 diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 62ddad42f3..6276693e19 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -737,6 +737,190 @@ def test_scarches_data_prep(save_path): SCVI.load_query_data(adata5, dir_path) +def test_scarches_data_prep_with_categorial_covariates(save_path): + n_latent = 5 + num_categ_orig = 5 + adata1 = synthetic_iid() + adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],)) + adata1.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata1.shape[0],)) + SCVI.setup_anndata( + adata1, + batch_key="batch", + labels_key="labels", + continuous_covariate_keys=["cont1"], + categorical_covariate_keys=["cat1"], + ) + model = SCVI(adata1, n_latent=n_latent) + model.train(1, check_val_every_n_epoch=1) + dir_path = os.path.join(save_path, "saved_model/") + model.save(dir_path, overwrite=True) + 
+ # adata2 has more genes and a perfect subset of adata1, but missing the categ cov + adata2 = synthetic_iid(n_genes=110) + adata2.layers["counts"] = adata2.X.copy() + new_var_names_init = [f"Random {i}" for i in range(10)] + new_var_names = new_var_names_init + adata2.var_names[10:].to_list() + adata2.var_names = new_var_names + adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + # adata2 has more genes and missing 10 genes from adata1 + SCVI.prepare_query_anndata(adata2, dir_path) # see here how those extra genes were removed + with pytest.raises(KeyError): + SCVI.load_query_data(adata2, dir_path) + # model2 = SCVI(adata2, n_latent=n_latent) + # model2.train(1, check_val_every_n_epoch=1) + + adata3 = SCVI.prepare_query_anndata(adata2, dir_path, inplace=False) + with pytest.raises(KeyError): + SCVI.load_query_data(adata3, dir_path) + # model3 = SCVI(adata3, n_latent=n_latent) + # model3.train(1, check_val_every_n_epoch=1) + + # try the opposite - with the categ covariate - it works + # adata4 has more genes and a perfect subset of adata1 + adata4 = synthetic_iid(n_genes=110) + adata4.obs["batch"] = adata4.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata4.obs["cont1"] = np.random.normal(size=(adata4.shape[0],)) + adata4.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata4.shape[0],)) + SCVI.prepare_query_anndata(adata4, dir_path) + SCVI.load_query_data(adata4, dir_path) + model4 = SCVI(adata4, n_latent=n_latent) + model4.train(1, check_val_every_n_epoch=1) + model4.get_latent_representation() + model4.get_elbo() + + adata5 = SCVI.prepare_query_anndata(adata4, dir_path, inplace=False) + SCVI.load_query_data(adata5, dir_path) + model5 = SCVI(adata5, n_latent=n_latent) + model5.train(1, check_val_every_n_epoch=1) + model5.get_latent_representation() + model5.get_elbo() + + # try also different categ - it expects cat1 + adata6 = synthetic_iid(n_genes=110) + adata6.obs["batch"] =
adata6.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata6.obs["cont2"] = np.random.normal(size=(adata6.shape[0],)) + adata6.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata6.shape[0],)) + SCVI.prepare_query_anndata(adata6, dir_path) + with pytest.raises(KeyError): + SCVI.load_query_data(adata6, dir_path) + # model6 = SCVI(adata6, n_latent=n_latent) + # model6.train(1, check_val_every_n_epoch=1) + + # try only cont - missing the categ cov + adata7 = synthetic_iid(n_genes=110) + adata7.obs["batch"] = adata7.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata7.obs["cont2"] = np.random.normal(size=(adata7.shape[0],)) + SCVI.prepare_query_anndata(adata7, dir_path) + with pytest.raises(KeyError): + SCVI.load_query_data(adata7, dir_path) + # model7 = SCVI(adata7, n_latent=n_latent) + # model7.train(1, check_val_every_n_epoch=1) + + # try also additional categ cov - it expects cont1 + adata8 = synthetic_iid(n_genes=110) + adata8.obs["batch"] = adata8.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata8.obs["cont2"] = np.random.normal(size=(adata8.shape[0],)) + adata8.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata8.shape[0],)) + adata8.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata8.shape[0],)) + SCVI.prepare_query_anndata(adata8, dir_path) + with pytest.raises(KeyError): + SCVI.load_query_data(adata8, dir_path) + # model8 = SCVI(adata8, n_latent=n_latent) + # model8.train(1, check_val_every_n_epoch=1) + + # try also additional categ cov - it works + adata9 = synthetic_iid(n_genes=110) + adata9.obs["batch"] = adata9.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata9.obs["cont1"] = np.random.normal(size=(adata9.shape[0],)) + adata9.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata9.shape[0],)) + adata9.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata9.shape[0],)) + SCVI.prepare_query_anndata(adata9, dir_path) + SCVI.load_query_data(adata9, dir_path) + 
model9 = SCVI(adata9, n_latent=n_latent) + model9.train(1, check_val_every_n_epoch=1) + model9.get_latent_representation() + model9.get_elbo() + + # try also additional cont/categ cov - it works + adata10 = synthetic_iid(n_genes=110) + adata10.obs["batch"] = adata10.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata10.obs["cont1"] = np.random.normal(size=(adata10.shape[0],)) + adata10.obs["cont2"] = np.random.normal(size=(adata10.shape[0],)) + adata10.obs["cat1"] = np.random.randint(0, num_categ_orig, size=(adata10.shape[0],)) + adata10.obs["cat2"] = np.random.randint(0, num_categ_orig, size=(adata10.shape[0],)) + SCVI.prepare_query_anndata(adata10, dir_path) + SCVI.load_query_data(adata10, dir_path) + model10 = SCVI(adata10, n_latent=n_latent) + model10.train(1, check_val_every_n_epoch=1) + attr_dict, var_names, load_state_dict = scvi.model.base._archesmixin._get_loaded_data(model10) + registry = attr_dict.pop("registry_") + # we validate only relevant covariates were passed - cat2 and cont2 are not used + assert ( + len(registry["field_registries"]["extra_categorical_covs"]["state_registry"]["field_keys"]) + == 1 + ) + assert ( + len(registry["field_registries"]["extra_continuous_covs"]["state_registry"]["columns"]) + == 1 + ) + assert ( + registry["field_registries"]["extra_categorical_covs"]["state_registry"]["field_keys"][0] + == "cat1" + ) + assert ( + registry["field_registries"]["extra_continuous_covs"]["state_registry"]["columns"][0] + == "cont1" + ) + model10.get_latent_representation() + model10.get_elbo() + + # try also running with fewer categories than needed + num_categ = 4 + adata11 = synthetic_iid(n_genes=110) + adata11.obs["batch"] = adata11.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata11.obs["cont1"] = np.random.normal(size=(adata11.shape[0],)) + adata11.obs["cat1"] = np.random.randint(0, num_categ, size=(adata11.shape[0],)) + SCVI.prepare_query_anndata(adata11, dir_path) + SCVI.load_query_data(adata11, dir_path) +
model11 = SCVI(adata11, n_latent=n_latent) + model11.train(1, check_val_every_n_epoch=1) + attr_dict, var_names, load_state_dict = scvi.model.base._archesmixin._get_loaded_data(model11) + registry = attr_dict.pop("registry_") + assert ( + registry["field_registries"]["extra_categorical_covs"]["state_registry"]["n_cats_per_key"][ + 0 + ] + == ( # parenthesized so the comparison applies to either branch + num_categ if num_categ > num_categ_orig else num_categ_orig + ) + ) + model11.get_latent_representation() + model11.get_elbo() + + # try also running with more categories than needed + num_categ = 6 + adata12 = synthetic_iid(n_genes=110) + adata12.obs["batch"] = adata12.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata12.obs["cont1"] = np.random.normal(size=(adata12.shape[0],)) + adata12.obs["cat1"] = np.random.randint(0, num_categ, size=(adata12.shape[0],)) + SCVI.prepare_query_anndata(adata12, dir_path) + SCVI.load_query_data(adata12, dir_path) + model12 = SCVI(adata12, n_latent=n_latent) + model12.train(1, check_val_every_n_epoch=1) + attr_dict, var_names, load_state_dict = scvi.model.base._archesmixin._get_loaded_data(model12) + registry = attr_dict.pop("registry_") + assert ( + registry["field_registries"]["extra_categorical_covs"]["state_registry"]["n_cats_per_key"][ + 0 + ] + == ( # parenthesized so the comparison applies to either branch + num_categ if num_categ > num_categ_orig else num_categ_orig + ) + ) + model12.get_latent_representation() + model12.get_elbo() + + def test_scarches_data_prep_layer(save_path): n_latent = 5 adata1 = synthetic_iid()