Commit

fix merge conflicts
jmarshrossney committed Aug 29, 2024
2 parents e1f6d58 + a21719c commit 59531eb
Showing 11 changed files with 104 additions and 99 deletions.
5 changes: 5 additions & 0 deletions .flake8
@@ -1,2 +1,7 @@
[flake8]
max-line-length=120
exclude =
venv
__pycache__
tests
vectors
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@
vectors/
*.ipynb
*.egg-info/
venv/
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.12
38 changes: 29 additions & 9 deletions README.md
@@ -6,26 +6,46 @@ It's a companion project to an R-shiny based image annotation app that is not ye

## Installation

### Python environment setup
### Environment and package installation

Use anaconda or miniconda to create a python environment using the included `environment.yml`
#### Using pip

Create a fresh virtual environment in the repository root using Python >=3.12 and (e.g.) `venv`:

```
conda env create -f environment.yml
python -m venv venv
```

Please note that this is specifically pinned to python 3.9 due to dependency versions; we make experimental use of the [CEFAS plankton model available through SciVision](https://sci.vision/#/model/resnet50-plankton), which in turn uses an older version of pytorch that isn't packaged above python 3.9.
Next, install the package using `pip`:

### Object store connection
```
python -m pip install .
```

`.env` contains environment variable names for S3 connection details for the [JASMIN object store](https://github.com/NERC-CEH/object_store_tutorial/). Fill these in with your own credentials. If you're not sure what the `ENDPOINT` should be, please reach out to one of the project contributors listed below.
If you plan to develop or experiment, you will probably want to install the package in 'editable' mode (`-e`), along with the dev tools and Jupyter notebook extras:

```
python -m pip install -e .[all]
```

### Package installation
#### Using conda

Get started by cloning this repository and running
Use Anaconda or Miniconda to create a Python environment from the included `environment.yml`:

`python -m pip install -e .`
```
conda env create -f environment.yml
conda activate cyto_ml
```

Next install this package _without dependencies_:

```
python -m pip install --no-deps -e .
```

### Object store connection

`.env` contains environment variable names for S3 connection details for the [JASMIN object store](https://github.com/NERC-CEH/object_store_tutorial/). Fill these in with your own credentials. If you're not sure what the `ENDPOINT` should be, please reach out to one of the project contributors listed below.
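
As a rough illustration of how the `ENDPOINT` value ends up being used, the sketch below builds the intake catalogue URL in the same way as the notebook in this commit (`f"{ENDPOINT}/metadata/intake.yml"`). The helper name `catalog_url` and the example endpoint are hypothetical, and a plain mapping stands in for the environment that `load_dotenv()` would normally populate from `.env`.

```python
import os


def catalog_url(environ=os.environ) -> str:
    """Build the intake catalogue URL from the ENDPOINT variable.

    `catalog_url` is a hypothetical helper for illustration; the notebook
    builds the same string inline.
    """
    endpoint = environ.get("ENDPOINT", "")
    if not endpoint:
        # Fail early rather than requesting "/metadata/intake.yml" from nowhere.
        raise RuntimeError("ENDPOINT is not set; fill in .env first")
    return f"{endpoint.rstrip('/')}/metadata/intake.yml"


# Hypothetical endpoint, used only to show the shape of the result.
example_url = catalog_url({"ENDPOINT": "https://object-store.example"})
print(example_url)
```

Failing early on a missing `ENDPOINT` is a design choice in this sketch, not something the repository does; the notebook falls back to an empty string.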

### Running tests

28 changes: 16 additions & 12 deletions environment.yml
@@ -1,25 +1,29 @@
name: cyto_39
name: cyto_ml
channels:
- pytorch
- conda-forge
- defaults
channel_priority: flexible
dependencies:
- python=3.9
- pytorch=1.10.0
- mkl=2024.0
- chromadb=0.5.3
- python=3.12
- pytorch
- black
- chromadb
- flake8
- intake-xarray
- scikit-image
- scikit-learn
- intake=0.7
- isort
- jupyterlab
- jupytext
- matplotlib
- pandas
- pytest
- python-dotenv
- s3fs
- jupyterlab
- jupytext
- scikit-image
- scikit-learn
- xarray
- pip
- streamlit
- plotly
- pip:
- scivision
- git+https://github.com/alan-turing-institute/plankton-cefas-scivision@main
- git+https://github.com/jmarshrossney/resnet50-cefas
50 changes: 15 additions & 35 deletions notebooks/ImageEmbeddings.md
@@ -12,14 +12,13 @@ jupyter:
name: python3
---

Use this with the `cyto_39` environment (the scivision model needs a specific version of `pytorch` that isn't packaged for >3.9, i have raised a Github issue asking if they plan to update it)
Use this with the `cyto_ml` environment.

`conda env create -f environment.yml`
`conda activate cyto_39`
`conda activate cyto_ml`

```python
import os
from scivision import load_pretrained_model, load_dataset
from dotenv import load_dotenv
import torch
import torchvision
@@ -29,58 +28,37 @@ sys.path.append('../')
from cyto_ml.models.scivision import prepare_image
from intake_xarray import ImageSource
load_dotenv() # sets our object store endpoint and credentials from the .env file

from intake import open_catalog

from resnet50_cefas import load_model
```

```python
dataset = load_dataset(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml")
model = load_pretrained_model("https://github.com/alan-turing-institute/plankton-cefas-scivision")
dataset = open_catalog(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml")
model = load_model()
dataset.test_image().to_dask()
```

The scivision wrapper depends on this being an xarray Dataset with settable attributes, rather than a DataArray

Setting exif_tags: True (Dataset) or False (DataArray) is what controls this
https://docs.xarray.dev/en/stable/generated/xarray.DataArray.to_dataset.html

https://github.com/alan-turing-institute/scivision/blob/07fb74e5231bc1d56cf39df38c19ef40e3265e4c/src/scivision/io/reader.py#L183
https://github.com/intake/intake/blob/29c8878aa7bf6e93185e2c9639f8739445dff22b/intake/__init__.py#L101

But now we're dependent on image height and width metadata being set in the EXIF tags to use the `predict` interface, this is set in the model description through `scivision`, this is brittle

https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/model.py#L71



A quick look at the example dataset that comes with the model, for reference


In this case we don't want to use the `predict` interface anyway (one of N class labels) - we want the features that go into the last fully-connected layer (as described here https://stackoverflow.com/a/52548419)

```python
network = torch.nn.Sequential(*(list(model._plumbing.model.pretrained_model.children())[:-1]))
network = load_model(strip_final_layer=True)
```

```python
imgs = dataset.test_image().to_dask()
i= imgs.to_numpy()
i.shape

imgs.to_numpy().shape
```

https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/data.py



Pass the image through our truncated network and get some embeddings out

```python
o = torch.stack([torchvision.transforms.ToTensor()(i)])
o = prepare_image(imgs)
feats = network(o)
feats.shape
```

```python
embeddings = list(feats[0].squeeze(1).squeeze(1).detach().numpy().astype(float))
embeddings = feats[0].tolist()
```

```python
@@ -129,7 +107,7 @@ index

```python
def flat_embeddings(features: torch.Tensor):
return list(features[0].squeeze(1).squeeze(1).detach().numpy().astype(float))
return features[0].tolist()
```

```python
@@ -158,6 +136,8 @@ This scales ok at 8000 or so images
collection.count()
```

This is _really_ slow - joe

```python
res = index.apply(file_embeddings, axis=1)
```
23 changes: 20 additions & 3 deletions pyproject.toml
@@ -4,10 +4,27 @@ build-backend = "setuptools.build_meta"

[project]
name = "cyto_ml"
version = "0.1.0"
requires-python = "==3.9.*"
version = "0.2.0"
requires-python = ">=3.12"
description = "This package supports the processing and analysis of plankton sample data"
readme = "README.md"
dependencies = [
"chromadb",
"intake==0.7.0",
"intake-xarray",
"pandas",
"python-dotenv",
"s3fs",
"scikit-image", # secretly required by intake-xarray as default reader
"torch",
"xarray",
"resnet50-cefas@git+https://github.com/jmarshrossney/resnet50-cefas",
]

[project.optional-dependencies]
jupyter = ["jupyterlab", "jupytext", "matplotlib"]
dev = ["pytest", "black", "flake8", "isort"]
all = ["cyto_ml[jupyter,dev]"]
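
Reviewer note: the `all` extra refers back to the package's own `jupyter` and `dev` extras. The sketch below shows what that self-reference expands to. It is a simplified model written for illustration, not pip's actual resolution logic, and the `expand` helper is hypothetical.

```python
import re

# The optional-dependency table from this commit's pyproject.toml, copied by hand.
EXTRAS = {
    "jupyter": ["jupyterlab", "jupytext", "matplotlib"],
    "dev": ["pytest", "black", "flake8", "isort"],
    "all": ["cyto_ml[jupyter,dev]"],
}


def expand(extra, project="cyto_ml", extras=EXTRAS):
    """Flatten an extra, following self-references like 'cyto_ml[jupyter,dev]'."""
    resolved = []
    for requirement in extras[extra]:
        match = re.fullmatch(rf"{re.escape(project)}\[([^\]]+)\]", requirement)
        if match:
            # Self-reference: recurse into each named extra of the same project.
            for nested in match.group(1).split(","):
                resolved.extend(expand(nested.strip(), project, extras))
        else:
            resolved.append(requirement)
    return resolved


all_requirements = expand("all")
print(all_requirements)
```

Under this model, installing `.[all]` pulls in the union of the `jupyter` and `dev` lists, which avoids repeating the package names in a third place.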

[tool.jupytext]
formats = "ipynb,md"
@@ -18,5 +35,5 @@ filterwarnings = [
]

[tool.black]
target-version = ["py39"]
target-version = ["py312"]
line-length = 88
2 changes: 1 addition & 1 deletion src/cyto_ml/data/vectorstore.py
@@ -9,7 +9,7 @@

logging.basicConfig(level=logging.INFO)
# TODO make this sensibly configurable, not confusingly hardcoded
STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../vectors")
STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../../vectors")

client = chromadb.PersistentClient(
path=STORE,
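
Reviewer note: the extra `../` matters because this module lives at `src/cyto_ml/data/vectorstore.py`, three directories below the repository root. The sketch below resolves both the old and new relative paths against a hypothetical checkout at `/repo`; the `store_path` helper is illustrative and not part of the package.

```python
import posixpath


def store_path(module_file: str, relative: str) -> str:
    """Resolve `relative` against the directory containing `module_file`."""
    module_dir = posixpath.dirname(module_file)
    return posixpath.normpath(posixpath.join(module_dir, relative))


# Hypothetical checkout location, used only to make the paths concrete.
module_file = "/repo/src/cyto_ml/data/vectorstore.py"

old_store = store_path(module_file, "../../vectors")
new_store = store_path(module_file, "../../../vectors")
print(old_store)  # /repo/src/vectors
print(new_store)  # /repo/vectors
```

With the new value the store lands in the top-level `vectors/` directory, which matches the `vectors/` entry in `.gitignore`.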
43 changes: 12 additions & 31 deletions src/cyto_ml/models/scivision.py
@@ -1,33 +1,7 @@
from scivision import load_pretrained_model
from scivision.io import PretrainedModel
import torch
import torchvision
from torchvision.transforms.v2.functional import to_image, to_dtype
from xarray import DataArray

SCIVISION_URL = (
"https://github.com/alan-turing-institute/plankton-cefas-scivision" # noqa: E501
)


def load_model(url: str):
"""Load a scivision model from a URL, for example
https://github.com/alan-turing-institute/plankton-cefas-scivision
"""
model = load_pretrained_model(url)
return model


def truncate_model(model: PretrainedModel):
"""
Accepts a scivision model wrapper and returns the pytorch model,
with its last fully-connected layer removed so that it returns
2048 features rather than a handle of label predictions
"""
network = torch.nn.Sequential(
*(list(model._plumbing.model.pretrained_model.children())[:-1])
)
return network


def prepare_image(image: DataArray):
"""
@@ -39,17 +13,24 @@ def prepare_image(image: DataArray):
image_numpy = image.to_numpy()

# Convert the image data to a PyTorch tensor
tensor_image = torchvision.transforms.ToTensor()(image_numpy)
tensor_image = to_dtype(
to_image(image_numpy), # permutes HWC -> CHW
torch.float32,
scale=True, # rescales [0, 255] -> [0, 1]
)
assert torch.all((tensor_image >= 0.0) & (tensor_image <= 1.0))

# Check if the input is a single image or a batch
if len(tensor_image.shape) == 3:
if tensor_image.dim() == 3:
# Single image, add a batch dimension
tensor_image = tensor_image.unsqueeze(0)

assert tensor_image.dim() == 4

return tensor_image


def flat_embeddings(features: torch.Tensor):
"""Utility function that takes the features returned by the model in truncate_model
And flattens them into a list suitable for storing in a vector database"""
return list(features[0].squeeze(1).squeeze(1).detach().numpy().astype(float))
# TODO: this only returns the 0th tensor in the batch...why?
return features[0].detach().tolist()
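
Reviewer note: the comments in `prepare_image` describe three steps: permuting HWC to CHW, rescaling `[0, 255]` to `[0, 1]`, and adding a batch dimension. The sketch below spells those steps out on plain nested lists so the shape changes are easy to follow. It is a pure-Python illustration with hypothetical helper names, not a substitute for the torchvision calls used in the module.

```python
def hwc_to_chw_scaled(image):
    """Permute a nested-list image from HWC to CHW and rescale [0, 255] -> [0, 1]."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


def add_batch_dim(chw):
    """Wrap a single CHW image in a batch of size one (NCHW)."""
    return [chw]


# A hypothetical 1x2 RGB image: one red pixel followed by one white pixel.
image = [[[255, 0, 0], [255, 255, 255]]]

chw = hwc_to_chw_scaled(image)
batch = add_batch_dim(chw)

print(len(chw), len(chw[0]), len(chw[0][0]))  # channels, height, width
print(chw[0])  # red channel
print(chw[1])  # green channel
```

Every value in `chw` lies in `[0, 1]`, which is the property the new `assert torch.all(...)` in `prepare_image` checks on the real tensor.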
9 changes: 3 additions & 6 deletions tests/conftest.py
@@ -1,10 +1,7 @@
import os
import pytest
from cyto_ml.models.scivision import (
load_model,
truncate_model,
SCIVISION_URL,
)

from resnet50_cefas import load_model


@pytest.fixture
@@ -30,7 +27,7 @@ def image_batch(image_dir):

@pytest.fixture
def scivision_model():
return truncate_model(load_model(SCIVISION_URL))
return load_model(strip_final_layer=True)


@pytest.fixture
3 changes: 1 addition & 2 deletions tests/test_prepare_image.py
@@ -9,8 +9,8 @@


def test_single_image(single_image):

image_data = ImageSource(single_image).to_dask()

# Tensorise the image (potentially normalise if we have useful values)
prepared_image = prepare_image(image_data)

@@ -25,7 +25,6 @@ def test_image_batch(image_batch):
We either pad them (and process a lot of blank space) or stick to single image input
"""
# Load a batch of plankton images

image_data = ImageSource(image_batch).to_dask()

with pytest.raises(ValueError) as err: