diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 7741697758..3b9b7708a6 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -242,6 +242,7 @@ def _get_summary_stats_from_registry(registry: dict) -> attrdict: def setup_datamodule( cls, datamodule, # TODO: what to put here? + source_registry=None, layer: str | None = None, batch_key: list[str] | None = None, labels_key: str | None = None, @@ -262,6 +263,26 @@ def setup_datamodule( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ + # TODO: from adata (czi)? + if datamodule.__class__.__name__ == "CensusSCVIDataModule": + # CZI + categorical_mapping = datamodule.datapipe.obs_encoders["batch"].classes_ + column_names = list( + datamodule.datapipe.var_query.coords[0] + if datamodule.datapipe.var_query is not None + else range(datamodule.n_vars) + ) + n_batch = datamodule.n_batch + else: + # Anndata -> CZI + # if we are here and datamodule is actually an AnnData object + # it means we init the custom dataloder model with anndata + categorical_mapping = source_registry["field_registries"]["batch"]["state_registry"][ + "categorical_mapping" + ] + column_names = datamodule.var.soma_joinid.values + n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"] + datamodule.registry = { "scvi_version": scvi.__version__, "model_name": "SCVI", @@ -279,17 +300,17 @@ def setup_datamodule( "state_registry": { "n_obs": datamodule.n_obs, "n_vars": datamodule.n_vars, - "column_names": [str(i) for i in datamodule.vars], + "column_names": [str(i) for i in column_names], # TODO: from adata (czi)? }, "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, }, "batch": { "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, "state_registry": { - "categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_, + "categorical_mapping": categorical_mapping, "original_key": "batch", }, - "summary_stats": {"n_batch": datamodule.n_batch}, + "summary_stats": {"n_batch": n_batch}, }, "labels": { "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_czi_custom_dataloader.py similarity index 55% rename from tests/dataloaders/test_custom_dataloader.py rename to tests/dataloaders/test_czi_custom_dataloader.py index ebc13af852..060b9f6fed 100644 --- a/tests/dataloaders/test_custom_dataloader.py +++ b/tests/dataloaders/test_czi_custom_dataloader.py @@ -1,35 +1,28 @@ from __future__ import annotations +from pprint import pprint + +import cellxgene_census import numpy as np import pandas as pd import pytest +import tiledbsoma as soma +from cellxgene_census.experimental.ml import experiment_dataloader +from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule import scvi from scvi.data import synthetic_iid @pytest.mark.custom_dataloader -def test_custom_dataloader(save_path): - # local bracnh with fix only for this test - import sys - - # should be ready for importing the cloned branch on a remote machine that runs github action - sys.path.insert( - 0, - "/home/runner/work/scvi-tools/scvi-tools/" - "cellxgene-census/api/python/cellxgene_census/src", - ) - sys.path.insert(0, "src") - import cellxgene_census - import tiledbsoma as soma - from cellxgene_census.experimental.ml import experiment_dataloader - from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule - +def test_czi_custom_dataloader(save_path="."): # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") experiment_name = "mus_musculus" obs_value_filter = 'is_primary_data == True and tissue_general in ["kidney"] and nnz >= 3000' + + # This is under comments just to save time (selecting highly varkable genes): # top_n_hvg = 8000 # hvg_batch = ["assay", "suspension_type"] # @@ -44,8 +37,10 @@ def test_custom_dataloader(save_path): # hv = hvgs_df.highly_variable # hv_idx = hv[hv].index - hv_idx = np.arange(100) + hv_idx = np.arange(100) # just ot make it smaller and faster for debug + # this is CZI part to be taken once all is ready + batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"] datamodule = CensusSCVIDataModule( census["census_data"][experiment_name], measurement_name="RNA", @@ -54,32 +49,41 @@ def test_custom_dataloader(save_path): var_query=soma.AxisQuery(coords=(list(hv_idx),)), batch_size=1024, shuffle=True, - batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"], + batch_keys=batch_keys, dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) - datamodule.vars = hv_idx + # table of genes should be filtered by soma_joinid - but we should keep the encoded indexes + # This is nice to have and might be uses in the downstream analysis + # var_df = census["census_data"][experiment_name].ms["RNA"].var.read().concat().to_pandas() + # var_df = var_df.loc[var_df.soma_joinid.isin( + # list(datamodule.datapipe.var_query.coords[0] if datamodule.datapipe.var_query is not None + # else range(datamodule.n_vars)))] - scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time - - adata = synthetic_iid() - scvi.model.SCVI.setup_anndata(adata, batch_key="batch") + # basicaly we should mimin everything below to any model census in scvi + adata_orig = synthetic_iid() + scvi.model.SCVI.setup_anndata(adata_orig, batch_key="batch") - model = scvi.model.SCVI(adata, n_latent=10) + model = scvi.model.SCVI(adata_orig, n_latent=10) model.train(max_epochs=1) - dataloader = model._make_data_loader(adata) + # TODO: do we need to apply those functions to any census model as is? + dataloader = model._make_data_loader(adata_orig) _ = model.get_elbo(dataloader=dataloader) _ = model.get_marginal_ll(dataloader=dataloader) _ = model.get_reconstruction_error(dataloader=dataloader) _ = model.get_latent_representation(dataloader=dataloader) - scvi.model.SCVI.prepare_query_anndata(adata, reference_model=model) - scvi.model.SCVI.load_query_data(adata, reference_model=model) + scvi.model.SCVI.prepare_query_anndata(adata_orig, reference_model=model) + scvi.model.SCVI.load_query_data(adata_orig, reference_model=model) + + user_attributes = model._get_user_attributes() + pprint(user_attributes) n_layers = 1 n_latent = 50 + scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time model_census = scvi.model.SCVI( registry=datamodule.registry, n_layers=n_layers, @@ -88,6 +92,8 @@ def test_custom_dataloader(save_path): encode_covariates=False, ) + pprint(datamodule.registry) + batch_size = 1024 train_size = 0.9 max_epochs = 1 @@ -100,8 +106,17 @@ def test_custom_dataloader(save_path): early_stopping=False, ) - user_attributes = model_census._get_user_attributes() - user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} + user_attributes_model_census = model_census._get_user_attributes() + # # TODO: do we need to put inside + # user_attributes_model_census = \ + # {a[0]: a[1] for a in user_attributes_model_census if a[0][-1] == "_"} + pprint(user_attributes_model_census) + # dataloader_census = model_census._make_data_loader(datamodule.datapipe) + # # this casus errors + # _ = model_census.get_elbo(dataloader=dataloader_census) + # _ = model_census.get_marginal_ll(dataloader=dataloader_census) + # _ = model_census.get_reconstruction_error(dataloader=dataloader_census) + # _ = model_census.get_latent_representation(dataloader=dataloader_census) model_census.save(save_path, overwrite=True) model_census2 = scvi.model.SCVI.load(save_path, adata=False) @@ -114,6 +129,15 @@ def test_custom_dataloader(save_path): early_stopping=False, ) + user_attributes_model_census2 = model_census2._get_user_attributes() + pprint(user_attributes_model_census2) + # dataloader_census2 = model_census2._make_data_loader() + # this casus errors + # _ = model_census2.get_elbo() + # _ = model_census2.get_marginal_ll() + # _ = model_census2.get_reconstruction_error() + # _ = model_census2.get_latent_representation() + # takes time adata = cellxgene_census.get_anndata( census, @@ -122,18 +146,32 @@ def test_custom_dataloader(save_path): var_coords=hv_idx, ) - adata.obs["batch"] = ( - "batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str) - ).astype("category") - # adata.var_names = 'gene_'+adata.var_names #not sure we need it + # TODO: do we need to put inside (or is it alrady pre-made) - perhaps need to tell CZI + adata.obs["batch"] = adata.obs[batch_keys].agg("".join, axis=1).astype("category") scvi.model.SCVI.prepare_query_anndata(adata, save_path) scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model=save_path) scvi.model.SCVI.prepare_query_anndata(adata, model_census2) + scvi.model.SCVI.setup_anndata(adata, batch_key="batch") # needed? model_census3 = scvi.model.SCVI.load(save_path, adata=adata) + model_census3.train( + datamodule=datamodule, + max_epochs=max_epochs, + batch_size=batch_size, + train_size=train_size, + early_stopping=False, + ) + + user_attributes_model_census3 = model_census3._get_user_attributes() + pprint(user_attributes_model_census3) + _ = model_census3.get_elbo() + _ = model_census3.get_marginal_ll() + _ = model_census3.get_reconstruction_error() + _ = model_census3.get_latent_representation() + scvi.model.SCVI.prepare_query_anndata(adata, save_path, return_reference_var_names=True) scvi.model.SCVI.load_query_data(adata, save_path) @@ -158,7 +196,7 @@ def test_custom_dataloader(save_path): mapped_dataloader = ( datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader ) - latent = model.get_latent_representation(dataloader=mapped_dataloader) + latent = model_census.get_latent_representation(dataloader=mapped_dataloader) emb_idx = datapipe._obs_joinids