From b35c6eb8fd0eaa50fffc40c19b9a44b730e4e9e6 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 10 Oct 2024 16:24:01 +0300 Subject: [PATCH] adding the lamindb as well --- .../test_linux_custom_dataloader.yml | 2 +- pyproject.toml | 7 +- ...ataloader.py => test_custom_dataloader.py} | 189 +++++++++++++++++- 3 files changed, 191 insertions(+), 7 deletions(-) rename tests/dataloaders/{test_czi_custom_dataloader.py => test_custom_dataloader.py} (55%) diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml index 30e2e1d222..53d6de3353 100644 --- a/.github/workflows/test_linux_custom_dataloader.yml +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -66,7 +66,7 @@ jobs: GH_TOKEN: ${{ secrets.GH_TOKEN }} run: | git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/chanzuckerberg/cellxgene-census.git git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git - name: Run specific custom dataloader pytest diff --git a/pyproject.toml b/pyproject.toml index bd8e42ca22..031467a0d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,8 +53,6 @@ dependencies = [ "torch", "torchmetrics>=0.11.0", "tqdm", - "cellxgene-census", - "torchdata", "xarray>=2023.2.0", ] @@ -83,6 +81,8 @@ docsbuild = ["scvi-tools[docs,optional]"] autotune = ["hyperopt>=0.2", "ray[tune]>=2.5.0"] # scvi.hub.HubModel.pull_from_s3 aws = ["boto3"] +# scvi.data.cellxgene +census = ["cellxgene-census"] # scvi.hub dependencies hub = ["huggingface_hub"] # scvi.model.utils.mde dependencies @@ -112,6 +112,9 @@ tutorials = [ "scvi-tools[optional]", "squidpy", ] +dataloaders = [ + "scdataloader" +] all = ["scvi-tools[dev,docs,tutorials]"] diff --git a/tests/dataloaders/test_czi_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py similarity index 55% rename from tests/dataloaders/test_czi_custom_dataloader.py rename to tests/dataloaders/test_custom_dataloader.py index 662c6d6166..b8034bb6af 100644 --- a/tests/dataloaders/test_czi_custom_dataloader.py +++ b/tests/dataloaders/test_custom_dataloader.py @@ -1,18 +1,18 @@ -from __future__ import annotations - +import os from pprint import pprint import numpy as np import pandas as pd import pytest +import scanpy as sc import scvi from scvi.data import synthetic_iid @pytest.mark.custom_dataloader -def test_czi_custom_dataloader(save_path): - # local bracnh with fix only for this test +def test_czi_custom_dataloader(save_path="."): + # local branch with fix only for this test import sys # should be ready for importing the cloned branch on a remote machine that runs github action @@ -26,6 +26,7 @@ def test_czi_custom_dataloader(save_path): import tiledbsoma as soma from cellxgene_census.experimental.ml import experiment_dataloader from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule + # from cellxgene_census.experimental.pp import highly_variable_genes # this test checks the local custom dataloder made by CZI and run several tests with it census = cellxgene_census.open_soma(census_version="stable") @@ -202,11 +203,15 @@ def test_czi_custom_dataloader(save_path): dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, ) + # Create a dataloder of a CZI module datapipe = datamodule_inference.datapipe dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) mapped_dataloader = ( datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader ) + _ = model_census.get_elbo(dataloader=mapped_dataloader) + _ = model_census.get_marginal_ll(dataloader=mapped_dataloader) + _ = model_census.get_reconstruction_error(dataloader=mapped_dataloader) latent = model_census.get_latent_representation(dataloader=mapped_dataloader) emb_idx = datapipe._obs_joinids @@ -218,3 +223,179 @@ def test_czi_custom_dataloader(save_path): # Reindexing is necessary to ensure that the cells in the embedding match the ones in # the anndata object. adata.obsm["scvi"] = latent[idx] + + # #We can now generate the neighbors and the UMAP (tutorials) + # sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi") + # sc.tl.umap(adata, neighbors_key="scvi") + # sc.pl.umap(adata, color="dataset_id", title="SCVI") + # + # sc.pl.umap(adata, color="tissue_general", title="SCVI") + # + # sc.pl.umap(adata, color="cell_type", title="SCVI") + + +@pytest.mark.custom_dataloader +def test_lamindb_custom_dataloader(save_path="."): + # initialize a local lamin database + os.system("lamin init --storage ~/scdataloader2 --schema bionty") + # os.system("lamin close") + # os.system("lamin load scdataloader") + + # local branch with fix only for this test + import sys + + # should be ready for importing the cloned branch on a remote machine that runs github action + sys.path.insert( + 0, + "/home/runner/work/scvi-tools/scvi-tools/" "scDataLoader/", + ) + sys.path.insert(0, "src") + import lamindb as ln + import tqdm + from scdataloader import Collator, DataModule, SimpleAnnDataset + + # import bionty as bt + # from scdataloader import utils + # from scdataloader.preprocess import ( + # LaminPreprocessor, + # additional_postprocess, + # additional_preprocess, + # ) + # import numpy as np + # import tiledbsoma as soma + from scdataloader.utils import populate_my_ontology + from torch.utils.data import DataLoader + # from scdataloader.base import NAME + # from cellxgene_census.experimental.ml import experiment_dataloader + + # populate_my_ontology() #to populate everything (recommended) (can take 2-10mns) + + populate_my_ontology( + organisms=["NCBITaxon:10090", "NCBITaxon:9606"], + sex=["PATO:0000384", "PATO:0000383"], + ) + + # preprocess datasets - do we need this part? + # DESCRIPTION = "preprocessed by scDataLoader" + + cx_dataset = ( + ln.Collection.using(instance="laminlabs/cellxgene") + .filter(name="cellxgene-census", version="2023-12-15") + .one() + ) + cx_dataset, len(cx_dataset.artifacts.all()) + + # do_preprocess = LaminPreprocessor( + # additional_postprocess=additional_postprocess, + # additional_preprocess=additional_preprocess, + # skip_validate=True, + # subset_hvg=0, + # ) + + # preprocessed_dataset = do_preprocess( + # cx_dataset, name=DESCRIPTION, description=DESCRIPTION, start_at=1, version="2" + # ) + + # create dataloaders + + datamodule = DataModule( + collection_name="preprocessed dataset", + organisms=["NCBITaxon:9606"], # organism that we will work on + how="most expr", # for the collator (most expr genes only will be selected) / "some" + max_len=1000, # only the 1000 most expressed + batch_size=64, + num_workers=1, + validation_split=0.1, + test_split=0, + ) + + # we setup the datamodule (as exemplified in lightning's good practices, b + # ut there might be some things to improve here) + # testfiles = datamodule.setup() + + for i in tqdm.tqdm(datamodule.train_dataloader()): + # pass #or do pass + print(i) + break + + # with lightning: + # Trainer(model, datamodule) + + # Read adata and create lamindb dataloader + adata_orig = sc.read_h5ad( + "/Users/orikr/PycharmProjects/scvi-tools/scDataLoader/tests/test.h5ad" + ) + # preprocessor = Preprocessor(do_postp=False) + # adata = preprocessor(adata_orig) + adataset = SimpleAnnDataset(adata_orig, obs_to_output=["organism_ontology_term_id"]) + col = Collator( + organisms=["NCBITaxon:9606"], + max_len=1000, + how="random expr", + ) + dataloader = DataLoader( + adataset, + collate_fn=col, + batch_size=4, + num_workers=1, + shuffle=False, + ) + + # We will now create the SCVI model object: + # Its parameters: + # n_layers = 1 + # n_latent = 10 + # batch_size = 1024 + # train_size = 0.9 + # max_epochs = 1 + + # def on_before_batch_transfer( + # batch: tuple[torch.Tensor, torch.Tensor], + # ) -> dict[str, torch.Tensor | None]: + # """Format the datapipe output with registry keys for scvi-tools.""" + # X, obs = batch + # X_KEY: str = "X" + # BATCH_KEY: str = "batch" + # LABELS_KEY: str = "labels" + # return { + # X_KEY: X, + # BATCH_KEY: obs, + # LABELS_KEY: None, + # } + + # Try the lamindb dataloder on a trained scvi-model with adata + # adata = adata.copy() + scvi.model.SCVI.setup_anndata(adata_orig, batch_key="cell_type_ontology_term_id") + model = scvi.model.SCVI(adata_orig, n_latent=10) + model.train(max_epochs=1) + # dataloader2 = experiment_dataloader(dataloader, num_workers=0, persistent_workers=False) + # mapped_dataloader = ( + # on_before_batch_transfer(tensor, None) for tensor in dataloader2 + # ) + # dataloader = model._make_data_loader(mapped_dataloader) + _ = model.get_elbo(dataloader=dataloader) + _ = model.get_marginal_ll(dataloader=dataloader) + _ = model.get_reconstruction_error(dataloader=dataloader) + _ = model.get_latent_representation(dataloader=dataloader) + + # scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time + # model_lamindb = scvi.model.SCVI( + # registry=datamodule.registry, + # n_layers=n_layers, + # n_latent=n_latent, + # gene_likelihood="nb", + # encode_covariates=False, + # ) + # + # pprint(datamodule.registry) + # + # model_lamindb.train( + # datamodule=datamodule, + # max_epochs=max_epochs, + # batch_size=batch_size, + # train_size=train_size, + # early_stopping=False, + # ) + # We have to create a registry without setup_anndata that contains the same elements + # The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK + # need to pass here new object of registry taht contains everything we will need