diff --git a/.flake8 b/.flake8 index aa079ec..ffe76ee 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,7 @@ [flake8] max-line-length=120 +exclude = + venv + __pycache__ + tests + vectors diff --git a/.gitignore b/.gitignore index 3cd8b01..c5a8c5f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ vectors/ *.ipynb *.egg-info/ +venv/ diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/README.md b/README.md index d20fec0..43e3bbc 100644 --- a/README.md +++ b/README.md @@ -6,26 +6,46 @@ It's a companion project to an R-shiny based image annotation app that is not ye ## Installation -### Python environment setup +### Environment and package installation -Use anaconda or miniconda to create a python environment using the included `environment.yml` +#### Using pip + +Create a fresh virtual environment in the repository root using Python >=3.12 and (e.g.) `venv`: ``` -conda env create -f environment.yml +python -m venv venv ``` -Please note that this is specifically pinned to python 3.9 due to dependency versions; we make experimental use of the [CEFAS plankton model available through SciVision](https://sci.vision/#/model/resnet50-plankton), which in turn uses an older version of pytorch that isn't packaged above python 3.9. +Next, install the package using `pip`: -### Object store connection +``` +python -m pip install . +``` -`.env` contains environment variable names for S3 connection details for the [JASMIN object store](https://github.com/NERC-CEH/object_store_tutorial/). Fill these in with your own credentials. If you're not sure what the `ENDPOINT` should be, please reach out to one of the project contributors listed below. +Most likely you are interested in developing and/or experimenting, so you will probably want to install the package in 'editable' mode (`-e`), along with dev tools and jupyter notebook functionality +``` +python -m pip install -e .[all] +``` -### Package installation +#### Using conda -Get started by cloning this repository and running +Use anaconda or miniconda to create a python environment using the included `environment.yml` -`python -m pip install -e .` +``` +conda env create -f environment.yml +conda activate cyto_ml +``` + +Next install this package _without dependencies_: + +``` +python -m pip install --no-deps -e . +``` + +### Object store connection + +`.env` contains environment variable names for S3 connection details for the [JASMIN object store](https://github.com/NERC-CEH/object_store_tutorial/). Fill these in with your own credentials. If you're not sure what the `ENDPOINT` should be, please reach out to one of the project contributors listed below. ### Running tests diff --git a/environment.yml b/environment.yml index 75d90c1..59ee514 100644 --- a/environment.yml +++ b/environment.yml @@ -1,25 +1,29 @@ -name: cyto_39 +name: cyto_ml channels: - pytorch - conda-forge - - defaults +channel_priority: flexible dependencies: - - python=3.9 - - pytorch=1.10.0 - - mkl=2024.0 - - chromadb=0.5.3 + - python=3.12 + - pytorch + - black + - chromadb + - flake8 - intake-xarray - - scikit-image - - scikit-learn + - intake=0.7 + - isort + - jupyterlab + - jupytext + - matplotlib - pandas - pytest - python-dotenv - s3fs - - jupyterlab - - jupytext + - scikit-image + - scikit-learn + - xarray - pip - streamlit - plotly - pip: - - scivision - - git+https://github.com/alan-turing-institute/plankton-cefas-scivision@main + - git+https://github.com/jmarshrossney/resnet50-cefas diff --git a/notebooks/ImageEmbeddings.md b/notebooks/ImageEmbeddings.md index 889a9f2..19362ab 100644 --- a/notebooks/ImageEmbeddings.md +++ b/notebooks/ImageEmbeddings.md @@ -12,14 +12,13 @@ jupyter: name: python3 --- -Use this with the `cyto_39` environment (the scivision model needs a specific version of `pytorch` that isn't packaged for >3.9, i have raised a Github issue asking if they plan to update it) +Use this with the `cyto_ml` environment. `conda env create -f environment.yml` -`conda activate cyto_39` +`conda activate cyto_ml` ```python import os -from scivision import load_pretrained_model, load_dataset from dotenv import load_dotenv import torch import torchvision @@ -29,58 +28,37 @@ sys.path.append('../') from cyto_ml.models.scivision import prepare_image from intake_xarray import ImageSource load_dotenv() # sets our object store endpoint and credentials from the .env file + +from intake import open_catalog + +from resnet50_cefas import load_model ``` ```python -dataset = load_dataset(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml") -model = load_pretrained_model("https://github.com/alan-turing-institute/plankton-cefas-scivision") +dataset = open_catalog(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml") +model = load_model() dataset.test_image().to_dask() ``` -The scivision wrapper depends on this being an xarray Dataset with settable attributes, rather than a DataArray - -Setting exif_tags: True (Dataset) or False (DataArray) is what controls this -https://docs.xarray.dev/en/stable/generated/xarray.DataArray.to_dataset.html - -https://github.com/alan-turing-institute/scivision/blob/07fb74e5231bc1d56cf39df38c19ef40e3265e4c/src/scivision/io/reader.py#L183 -https://github.com/intake/intake/blob/29c8878aa7bf6e93185e2c9639f8739445dff22b/intake/__init__.py#L101 - -But now we're dependent on image height and width metadata being set in the EXIF tags to use the `predict` interface, this is set in the model description through `scivision`, this is brittle - -https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/model.py#L71 - - - -A quick look at the example dataset that comes with the model, for reference - - -In this case we don't want to use the `predict` interface anyway (one of N class labels) - we want the features that go into the last fully-connected layer (as described here https://stackoverflow.com/a/52548419) - ```python -network = torch.nn.Sequential(*(list(model._plumbing.model.pretrained_model.children())[:-1])) +network = load_model(strip_final_layer=True) ``` ```python imgs = dataset.test_image().to_dask() -i= imgs.to_numpy() -i.shape - +imgs.to_numpy().shape ``` -https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/data.py - - - Pass the image through our truncated network and get some embeddings out ```python -o = torch.stack([torchvision.transforms.ToTensor()(i)]) +o = prepare_image(imgs) feats = network(o) feats.shape ``` ```python -embeddings = list(feats[0].squeeze(1).squeeze(1).detach().numpy().astype(float)) +embeddings = feats[0].tolist() ``` ```python @@ -129,7 +107,7 @@ index ```python def flat_embeddings(features: torch.Tensor): - return list(features[0].squeeze(1).squeeze(1).detach().numpy().astype(float)) + return features[0].tolist() ``` ```python @@ -158,6 +136,8 @@ This scales ok at 8000 or so images collection.count() ``` +This is _really_ slow - joe + ```python res = index.apply(file_embeddings, axis=1) ``` diff --git a/pyproject.toml b/pyproject.toml index ffd6870..bb89f2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,10 +4,27 @@ build-backend = "setuptools.build_meta" [project] name = "cyto_ml" -version = "0.1.0" -requires-python = "==3.9.*" +version = "0.2.0" +requires-python = ">=3.12" description = "This package supports the processing and analysis of plankton sample data" readme = "README.md" +dependencies = [ + "chromadb", + "intake==0.7.0", + "intake-xarray", + "pandas", + "python-dotenv", + "s3fs", + "scikit-image", # secretly required by intake-xarray as default reader + "torch", + "xarray", + "resnet50-cefas@git+https://github.com/jmarshrossney/resnet50-cefas", +] + +[project.optional-dependencies] +jupyter = ["jupyterlab", "jupytext", "matplotlib"] +dev = ["pytest", "black", "flake8", "isort"] +all = ["cyto_ml[jupyter,dev]"] [tool.jupytext] formats = "ipynb,md" @@ -18,5 +35,5 @@ filterwarnings = [ ] [tool.black] -target-version = ["py39"] +target-version = ["py312"] line-length = 88 diff --git a/src/cyto_ml/data/vectorstore.py b/src/cyto_ml/data/vectorstore.py index f6e5bcc..08cc531 100644 --- a/src/cyto_ml/data/vectorstore.py +++ b/src/cyto_ml/data/vectorstore.py @@ -9,7 +9,7 @@ logging.basicConfig(level=logging.INFO) # TODO make this sensibly configurable, not confusingly hardcoded -STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../vectors") +STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../../vectors") client = chromadb.PersistentClient( path=STORE, diff --git a/src/cyto_ml/models/scivision.py b/src/cyto_ml/models/scivision.py index efba243..ec022e1 100644 --- a/src/cyto_ml/models/scivision.py +++ b/src/cyto_ml/models/scivision.py @@ -1,33 +1,7 @@ -from scivision import load_pretrained_model -from scivision.io import PretrainedModel import torch -import torchvision +from torchvision.transforms.v2.functional import to_image, to_dtype from xarray import DataArray -SCIVISION_URL = ( - "https://github.com/alan-turing-institute/plankton-cefas-scivision" # noqa: E501 -) - - -def load_model(url: str): - """Load a scivision model from a URL, for example - https://github.com/alan-turing-institute/plankton-cefas-scivision - """ - model = load_pretrained_model(url) - return model - - -def truncate_model(model: PretrainedModel): - """ - Accepts a scivision model wrapper and returns the pytorch model, - with its last fully-connected layer removed so that it returns - 2048 features rather than a handle of label predictions - """ - network = torch.nn.Sequential( - *(list(model._plumbing.model.pretrained_model.children())[:-1]) - ) - return network - def prepare_image(image: DataArray): """ @@ -39,17 +13,24 @@ def prepare_image(image: DataArray): image_numpy = image.to_numpy() # Convert the image data to a PyTorch tensor - tensor_image = torchvision.transforms.ToTensor()(image_numpy) + tensor_image = to_dtype( + to_image(image_numpy), # permutes HWC -> CHW + torch.float32, + scale=True, # rescales [0, 255] -> [0, 1] + ) + assert torch.all((tensor_image >= 0.0) & (tensor_image <= 1.0)) - # Check if the input is a single image or a batch - if len(tensor_image.shape) == 3: + if tensor_image.dim() == 3: # Single image, add a batch dimension tensor_image = tensor_image.unsqueeze(0) + assert tensor_image.dim() == 4 + return tensor_image def flat_embeddings(features: torch.Tensor): """Utility function that takes the features returned by the model in truncate_model And flattens them into a list suitable for storing in a vector database""" - return list(features[0].squeeze(1).squeeze(1).detach().numpy().astype(float)) + # TODO: this only returns the 0th tensor in the batch...why? + return features[0].detach().tolist() diff --git a/tests/conftest.py b/tests/conftest.py index 1b092ae..9cda323 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,7 @@ import os import pytest -from cyto_ml.models.scivision import ( - load_model, - truncate_model, - SCIVISION_URL, -) + +from resnet50_cefas import load_model @pytest.fixture @@ -30,7 +27,7 @@ def image_batch(image_dir): @pytest.fixture def scivision_model(): - return truncate_model(load_model(SCIVISION_URL)) + return load_model(strip_final_layer=True) @pytest.fixture diff --git a/tests/test_prepare_image.py b/tests/test_prepare_image.py index 459496b..2c602da 100644 --- a/tests/test_prepare_image.py +++ b/tests/test_prepare_image.py @@ -9,8 +9,8 @@ def test_single_image(single_image): - image_data = ImageSource(single_image).to_dask() + # Tensorise the image (potentially normalise if we have useful values) prepared_image = prepare_image(image_data) @@ -25,7 +25,6 @@ def test_image_batch(image_batch): We either pad them (and process a lot of blank space) or stick to single image input """ # Load a batch of plankton images - image_data = ImageSource(image_batch).to_dask() with pytest.raises(ValueError) as err: