Skip to content

Commit

Permalink
adding the lamindb as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Oct 10, 2024
1 parent c6acb5a commit b35c6eb
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_linux_custom_dataloader.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
GH_TOKEN: ${{ secrets.GH_TOKEN }}
run: |
git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/"
git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git
git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/chanzuckerberg/cellxgene-census.git
git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git
- name: Run specific custom dataloader pytest
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ dependencies = [
"torch",
"torchmetrics>=0.11.0",
"tqdm",
"cellxgene-census",
"torchdata",
"xarray>=2023.2.0",
]

Expand Down Expand Up @@ -83,6 +81,8 @@ docsbuild = ["scvi-tools[docs,optional]"]
autotune = ["hyperopt>=0.2", "ray[tune]>=2.5.0"]
# scvi.hub.HubModel.pull_from_s3
aws = ["boto3"]
# scvi.data.cellxgene
census = ["cellxgene-census"]
# scvi.hub dependencies
hub = ["huggingface_hub"]
# scvi.model.utils.mde dependencies
Expand Down Expand Up @@ -112,6 +112,9 @@ tutorials = [
"scvi-tools[optional]",
"squidpy",
]
dataloaders = [
"scdataloader"
]

all = ["scvi-tools[dev,docs,tutorials]"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

import os
from pprint import pprint

import numpy as np
import pandas as pd
import pytest
import scanpy as sc

import scvi
from scvi.data import synthetic_iid


@pytest.mark.custom_dataloader
def test_czi_custom_dataloader(save_path):
# local bracnh with fix only for this test
def test_czi_custom_dataloader(save_path="."):
# local branch with fix only for this test
import sys

# should be ready for importing the cloned branch on a remote machine that runs github action
Expand All @@ -26,6 +26,7 @@ def test_czi_custom_dataloader(save_path):
import tiledbsoma as soma
from cellxgene_census.experimental.ml import experiment_dataloader
from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule
# from cellxgene_census.experimental.pp import highly_variable_genes

# this test checks the local custom dataloder made by CZI and run several tests with it
census = cellxgene_census.open_soma(census_version="stable")
Expand Down Expand Up @@ -202,11 +203,15 @@ def test_czi_custom_dataloader(save_path):
dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
)

# Create a dataloder of a CZI module
datapipe = datamodule_inference.datapipe
dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False)
mapped_dataloader = (
datamodule_inference.on_before_batch_transfer(tensor, None) for tensor in dataloader
)
_ = model_census.get_elbo(dataloader=mapped_dataloader)
_ = model_census.get_marginal_ll(dataloader=mapped_dataloader)
_ = model_census.get_reconstruction_error(dataloader=mapped_dataloader)
latent = model_census.get_latent_representation(dataloader=mapped_dataloader)

emb_idx = datapipe._obs_joinids
Expand All @@ -218,3 +223,179 @@ def test_czi_custom_dataloader(save_path):
# Reindexing is necessary to ensure that the cells in the embedding match the ones in
# the anndata object.
adata.obsm["scvi"] = latent[idx]

# #We can now generate the neighbors and the UMAP (tutorials)
# sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
# sc.tl.umap(adata, neighbors_key="scvi")
# sc.pl.umap(adata, color="dataset_id", title="SCVI")
#
# sc.pl.umap(adata, color="tissue_general", title="SCVI")
#
# sc.pl.umap(adata, color="cell_type", title="SCVI")


@pytest.mark.custom_dataloader
def test_lamindb_custom_dataloader(save_path="."):
# initialize a local lamin database
os.system("lamin init --storage ~/scdataloader2 --schema bionty")
# os.system("lamin close")
# os.system("lamin load scdataloader")

# local branch with fix only for this test
import sys

# should be ready for importing the cloned branch on a remote machine that runs github action
sys.path.insert(
0,
"/home/runner/work/scvi-tools/scvi-tools/" "scDataLoader/",
)
sys.path.insert(0, "src")
import lamindb as ln
import tqdm
from scdataloader import Collator, DataModule, SimpleAnnDataset

# import bionty as bt
# from scdataloader import utils
# from scdataloader.preprocess import (
# LaminPreprocessor,
# additional_postprocess,
# additional_preprocess,
# )
# import numpy as np
# import tiledbsoma as soma
from scdataloader.utils import populate_my_ontology
from torch.utils.data import DataLoader
# from scdataloader.base import NAME
# from cellxgene_census.experimental.ml import experiment_dataloader

# populate_my_ontology() #to populate everything (recommended) (can take 2-10mns)

populate_my_ontology(
organisms=["NCBITaxon:10090", "NCBITaxon:9606"],
sex=["PATO:0000384", "PATO:0000383"],
)

# preprocess datasets - do we need this part?
# DESCRIPTION = "preprocessed by scDataLoader"

cx_dataset = (
ln.Collection.using(instance="laminlabs/cellxgene")
.filter(name="cellxgene-census", version="2023-12-15")
.one()
)
cx_dataset, len(cx_dataset.artifacts.all())

# do_preprocess = LaminPreprocessor(
# additional_postprocess=additional_postprocess,
# additional_preprocess=additional_preprocess,
# skip_validate=True,
# subset_hvg=0,
# )

# preprocessed_dataset = do_preprocess(
# cx_dataset, name=DESCRIPTION, description=DESCRIPTION, start_at=1, version="2"
# )

# create dataloaders

datamodule = DataModule(
collection_name="preprocessed dataset",
organisms=["NCBITaxon:9606"], # organism that we will work on
how="most expr", # for the collator (most expr genes only will be selected) / "some"
max_len=1000, # only the 1000 most expressed
batch_size=64,
num_workers=1,
validation_split=0.1,
test_split=0,
)

# we setup the datamodule (as exemplified in lightning's good practices, b
# ut there might be some things to improve here)
# testfiles = datamodule.setup()

for i in tqdm.tqdm(datamodule.train_dataloader()):
# pass #or do pass
print(i)
break

# with lightning:
# Trainer(model, datamodule)

# Read adata and create lamindb dataloader
adata_orig = sc.read_h5ad(
"/Users/orikr/PycharmProjects/scvi-tools/scDataLoader/tests/test.h5ad"
)
# preprocessor = Preprocessor(do_postp=False)
# adata = preprocessor(adata_orig)
adataset = SimpleAnnDataset(adata_orig, obs_to_output=["organism_ontology_term_id"])
col = Collator(
organisms=["NCBITaxon:9606"],
max_len=1000,
how="random expr",
)
dataloader = DataLoader(
adataset,
collate_fn=col,
batch_size=4,
num_workers=1,
shuffle=False,
)

# We will now create the SCVI model object:
# Its parameters:
# n_layers = 1
# n_latent = 10
# batch_size = 1024
# train_size = 0.9
# max_epochs = 1

# def on_before_batch_transfer(
# batch: tuple[torch.Tensor, torch.Tensor],
# ) -> dict[str, torch.Tensor | None]:
# """Format the datapipe output with registry keys for scvi-tools."""
# X, obs = batch
# X_KEY: str = "X"
# BATCH_KEY: str = "batch"
# LABELS_KEY: str = "labels"
# return {
# X_KEY: X,
# BATCH_KEY: obs,
# LABELS_KEY: None,
# }

# Try the lamindb dataloder on a trained scvi-model with adata
# adata = adata.copy()
scvi.model.SCVI.setup_anndata(adata_orig, batch_key="cell_type_ontology_term_id")
model = scvi.model.SCVI(adata_orig, n_latent=10)
model.train(max_epochs=1)
# dataloader2 = experiment_dataloader(dataloader, num_workers=0, persistent_workers=False)
# mapped_dataloader = (
# on_before_batch_transfer(tensor, None) for tensor in dataloader2
# )
# dataloader = model._make_data_loader(mapped_dataloader)
_ = model.get_elbo(dataloader=dataloader)
_ = model.get_marginal_ll(dataloader=dataloader)
_ = model.get_reconstruction_error(dataloader=dataloader)
_ = model.get_latent_representation(dataloader=dataloader)

# scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time
# model_lamindb = scvi.model.SCVI(
# registry=datamodule.registry,
# n_layers=n_layers,
# n_latent=n_latent,
# gene_likelihood="nb",
# encode_covariates=False,
# )
#
# pprint(datamodule.registry)
#
# model_lamindb.train(
# datamodule=datamodule,
# max_epochs=max_epochs,
# batch_size=batch_size,
# train_size=train_size,
# early_stopping=False,
# )
# We have to create a registry without setup_anndata that contains the same elements
# The other way will be to fill the model ,LIKE IN CELLXGENE NOTEBOOK
# need to pass here new object of registry taht contains everything we will need

0 comments on commit b35c6eb

Please sign in to comment.