updated for CZI custom dataloader test and backend
ori-kron-wis committed Oct 9, 2024
1 parent bf4d3bf commit 2cc8ff9
Showing 2 changed files with 96 additions and 37 deletions.
27 changes: 24 additions & 3 deletions src/scvi/model/_scvi.py
@@ -242,6 +242,7 @@ def _get_summary_stats_from_registry(registry: dict) -> attrdict:
def setup_datamodule(
cls,
datamodule, # TODO: what to put here?
source_registry=None,
layer: str | None = None,
batch_key: list[str] | None = None,
labels_key: str | None = None,
@@ -262,6 +263,26 @@ def setup_datamodule(
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
"""
# TODO: from adata (czi)?
if datamodule.__class__.__name__ == "CensusSCVIDataModule":
# CZI
categorical_mapping = datamodule.datapipe.obs_encoders["batch"].classes_
column_names = list(
datamodule.datapipe.var_query.coords[0]
if datamodule.datapipe.var_query is not None
else range(datamodule.n_vars)
)
n_batch = datamodule.n_batch
else:
# AnnData -> CZI
# if we are here and datamodule is actually an AnnData object,
# it means we initialized the custom dataloader model with AnnData
categorical_mapping = source_registry["field_registries"]["batch"]["state_registry"][
"categorical_mapping"
]
column_names = datamodule.var.soma_joinid.values
n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"]

datamodule.registry = {
"scvi_version": scvi.__version__,
"model_name": "SCVI",
@@ -279,17 +300,17 @@ def setup_datamodule(
"state_registry": {
"n_obs": datamodule.n_obs,
"n_vars": datamodule.n_vars,
"column_names": [str(i) for i in datamodule.vars],
"column_names": [str(i) for i in column_names], # TODO: from adata (czi)?
},
"summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs},
},
"batch": {
"data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"},
"state_registry": {
"categorical_mapping": datamodule.datapipe.obs_encoders["batch"].classes_,
"categorical_mapping": categorical_mapping,
"original_key": "batch",
},
"summary_stats": {"n_batch": datamodule.n_batch},
"summary_stats": {"n_batch": n_batch},
},
"labels": {
"data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"},
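A minimal usage sketch of the two branches added to setup_datamodule above (illustrative only and not part of this commit; census_datamodule, query_adata, and reference_registry are assumed objects):

import scvi

# Hedged sketch, not part of the diff: exercising both branches of setup_datamodule.
# Branch 1: a CZI CensusSCVIDataModule -- the batch mapping, column names, and n_batch
# are derived from its datapipe.
scvi.model._scvi.SCVI.setup_datamodule(census_datamodule)

# Branch 2: an AnnData standing in for the datamodule -- the batch mapping and n_batch
# come from a previously built registry passed as source_registry, and the column names
# come from query_adata.var["soma_joinid"].
scvi.model._scvi.SCVI.setup_datamodule(query_adata, source_registry=reference_registry)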
@@ -1,35 +1,28 @@
from __future__ import annotations

from pprint import pprint

import cellxgene_census
import numpy as np
import pandas as pd
import pytest
import tiledbsoma as soma
from cellxgene_census.experimental.ml import experiment_dataloader
from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule

import scvi
from scvi.data import synthetic_iid


@pytest.mark.custom_dataloader
def test_custom_dataloader(save_path):
# local branch with a fix only for this test
import sys

# should be ready for importing the cloned branch on a remote machine that runs the GitHub Action
sys.path.insert(
0,
"/home/runner/work/scvi-tools/scvi-tools/"
"cellxgene-census/api/python/cellxgene_census/src",
)
sys.path.insert(0, "src")
import cellxgene_census
import tiledbsoma as soma
from cellxgene_census.experimental.ml import experiment_dataloader
from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule

def test_czi_custom_dataloader(save_path="."):
# this test checks the local custom dataloader made by CZI and runs several checks with it
census = cellxgene_census.open_soma(census_version="stable")

experiment_name = "mus_musculus"
obs_value_filter = 'is_primary_data == True and tissue_general in ["kidney"] and nnz >= 3000'

# This is commented out just to save time (selecting highly variable genes):
# top_n_hvg = 8000
# hvg_batch = ["assay", "suspension_type"]
#
@@ -44,8 +37,10 @@ def test_custom_dataloader(save_path):
# hv = hvgs_df.highly_variable
# hv_idx = hv[hv].index

hv_idx = np.arange(100)
hv_idx = np.arange(100)  # just to make it smaller and faster for debugging

# this is the CZI part, to be taken once everything is ready
batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"]
datamodule = CensusSCVIDataModule(
census["census_data"][experiment_name],
measurement_name="RNA",
@@ -54,32 +49,41 @@ def test_custom_dataloader(save_path):
var_query=soma.AxisQuery(coords=(list(hv_idx),)),
batch_size=1024,
shuffle=True,
batch_keys=["dataset_id", "assay", "suspension_type", "donor_id"],
batch_keys=batch_keys,
dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
)

datamodule.vars = hv_idx
# the table of genes should be filtered by soma_joinid - but we should keep the encoded indexes
# This is nice to have and might be used in the downstream analysis
# var_df = census["census_data"][experiment_name].ms["RNA"].var.read().concat().to_pandas()
# var_df = var_df.loc[var_df.soma_joinid.isin(
# list(datamodule.datapipe.var_query.coords[0] if datamodule.datapipe.var_query is not None
# else range(datamodule.n_vars)))]

scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time

adata = synthetic_iid()
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
# basically we should mimic everything below for any census model in scvi
adata_orig = synthetic_iid()
scvi.model.SCVI.setup_anndata(adata_orig, batch_key="batch")

model = scvi.model.SCVI(adata, n_latent=10)
model = scvi.model.SCVI(adata_orig, n_latent=10)
model.train(max_epochs=1)

dataloader = model._make_data_loader(adata)
# TODO: do we need to apply those functions to any census model as is?
dataloader = model._make_data_loader(adata_orig)
_ = model.get_elbo(dataloader=dataloader)
_ = model.get_marginal_ll(dataloader=dataloader)
_ = model.get_reconstruction_error(dataloader=dataloader)
_ = model.get_latent_representation(dataloader=dataloader)

scvi.model.SCVI.prepare_query_anndata(adata, reference_model=model)
scvi.model.SCVI.load_query_data(adata, reference_model=model)
scvi.model.SCVI.prepare_query_anndata(adata_orig, reference_model=model)
scvi.model.SCVI.load_query_data(adata_orig, reference_model=model)

user_attributes = model._get_user_attributes()
pprint(user_attributes)

n_layers = 1
n_latent = 50

scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time
model_census = scvi.model.SCVI(
registry=datamodule.registry,
n_layers=n_layers,
@@ -88,6 +92,8 @@ def test_custom_dataloader(save_path):
encode_covariates=False,
)

pprint(datamodule.registry)

batch_size = 1024
train_size = 0.9
max_epochs = 1
@@ -100,8 +106,17 @@ def test_custom_dataloader(save_path):
early_stopping=False,
)

user_attributes = model_census._get_user_attributes()
user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"}
user_attributes_model_census = model_census._get_user_attributes()
# TODO: do we need to apply this filtering?
# user_attributes_model_census = \
# {a[0]: a[1] for a in user_attributes_model_census if a[0][-1] == "_"}
pprint(user_attributes_model_census)
# dataloader_census = model_census._make_data_loader(datamodule.datapipe)
# # this causes errors
# _ = model_census.get_elbo(dataloader=dataloader_census)
# _ = model_census.get_marginal_ll(dataloader=dataloader_census)
# _ = model_census.get_reconstruction_error(dataloader=dataloader_census)
# _ = model_census.get_latent_representation(dataloader=dataloader_census)

model_census.save(save_path, overwrite=True)
model_census2 = scvi.model.SCVI.load(save_path, adata=False)
@@ -114,6 +129,15 @@ def test_custom_dataloader(save_path):
early_stopping=False,
)

user_attributes_model_census2 = model_census2._get_user_attributes()
pprint(user_attributes_model_census2)
# dataloader_census2 = model_census2._make_data_loader()
# this causes errors
# _ = model_census2.get_elbo()
# _ = model_census2.get_marginal_ll()
# _ = model_census2.get_reconstruction_error()
# _ = model_census2.get_latent_representation()

# takes time
adata = cellxgene_census.get_anndata(
census,
@@ -122,18 +146,32 @@ def test_custom_dataloader(save_path):
var_coords=hv_idx,
)

adata.obs["batch"] = (
"batch_" + adata.obs[datamodule.batch_keys[0]].cat.codes.astype(str)
).astype("category")
# adata.var_names = 'gene_'+adata.var_names #not sure we need it
# TODO: do we need to do this here (or is it already pre-made)? perhaps we need to tell CZI
adata.obs["batch"] = adata.obs[batch_keys].agg("".join, axis=1).astype("category")

scvi.model.SCVI.prepare_query_anndata(adata, save_path)
scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model=save_path)

scvi.model.SCVI.prepare_query_anndata(adata, model_census2)

scvi.model.SCVI.setup_anndata(adata, batch_key="batch") # needed?
model_census3 = scvi.model.SCVI.load(save_path, adata=adata)

model_census3.train(
datamodule=datamodule,
max_epochs=max_epochs,
batch_size=batch_size,
train_size=train_size,
early_stopping=False,
)

user_attributes_model_census3 = model_census3._get_user_attributes()
pprint(user_attributes_model_census3)
_ = model_census3.get_elbo()
_ = model_census3.get_marginal_ll()
_ = model_census3.get_reconstruction_error()
_ = model_census3.get_latent_representation()

scvi.model.SCVI.prepare_query_anndata(adata, save_path, return_reference_var_names=True)
scvi.model.SCVI.load_query_data(adata, save_path)

@@ -158,7 +196,7 @@ def test_custom_dataloader(save_path):
mapped_dataloader = (
datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader
)
latent = model.get_latent_representation(dataloader=mapped_dataloader)
latent = model_census.get_latent_representation(dataloader=mapped_dataloader)

emb_idx = datapipe._obs_joinids

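For context, a hedged sketch of how the inference-time objects referenced in the last hunk (datamodule_inference, datapipe, dataloader) are typically assembled; this is illustrative, not the code elided from the hunk, and shuffle=False is an assumption made so that row order lines up with datapipe._obs_joinids:

# Hedged sketch, not part of the diff: mirrors the training datamodule above,
# but without shuffling, for deterministic inference.
datamodule_inference = CensusSCVIDataModule(
    census["census_data"][experiment_name],
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter=obs_value_filter),
    var_query=soma.AxisQuery(coords=(list(hv_idx),)),
    batch_size=1024,
    shuffle=False,
    batch_keys=batch_keys,
    dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
)
datapipe = datamodule_inference.datapipe
dataloader = experiment_dataloader(datapipe)

The generator expression shown in the hunk then maps each batch through on_before_batch_transfer before handing the dataloader to model_census.get_latent_representation.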
