Drop scivision #1
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,7 @@ | ||
[flake8] | ||
max-line-length=120 | ||
exclude = | ||
venv | ||
__pycache__ | ||
tests | ||
vectors |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ | |
**/__pycache__/ | ||
vectors/ | ||
*.ipynb | ||
*.egg-info/ | ||
venv/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.12 |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,29 @@ | ||
name: cyto_39 | ||
name: cyto_ml | ||
channels: | ||
- pytorch | ||
- conda-forge | ||
- defaults | ||
channel_priority: flexible | ||
dependencies: | ||
- python=3.9 | ||
- pytorch=1.10.0 | ||
- mkl=2024.0 | ||
- chromadb=0.5.3 | ||
- python=3.12 | ||
- pytorch | ||
- black | ||
- chromadb | ||
- flake8 | ||
- intake-xarray | ||
- scikit-image | ||
- scikit-learn | ||
- intake=0.7 | ||
- isort | ||
- jupyterlab | ||
- jupytext | ||
- matplotlib | ||
- pandas | ||
- pytest | ||
- python-dotenv | ||
- s3fs | ||
- jupyterlab | ||
- jupytext | ||
- scikit-image | ||
- scikit-learn | ||
- xarray | ||
- pip | ||
- streamlit | ||
- plotly | ||
- pip: | ||
- scivision | ||
- git+https://github.com/alan-turing-institute/plankton-cefas-scivision@main | ||
- git+https://github.com/jmarshrossney/resnet50-cefas |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,14 +12,13 @@ jupyter: | |
name: python3 | ||
--- | ||
|
||
Use this with the `cyto_39` environment (the scivision model needs a specific version of `pytorch` that isn't packaged for >3.9, i have raised a Github issue asking if they plan to update it) | ||
Use this with the `cyto_ml` environment. | ||
|
||
`conda env create -f environment.yml` | ||
`conda activate cyto_39` | ||
`conda activate cyto_ml` | ||
|
||
```python | ||
import os | ||
from scivision import load_pretrained_model, load_dataset | ||
from dotenv import load_dotenv | ||
import torch | ||
import torchvision | ||
|
@@ -29,58 +28,37 @@ sys.path.append('../') | |
from cyto_ml.models.scivision import prepare_image | ||
from intake_xarray import ImageSource | ||
load_dotenv() # sets our object store endpoint and credentials from the .env file | ||
|
||
from intake import open_catalog | ||
|
||
from resnet50_cefas import load_model | ||
``` | ||
|
||
```python | ||
dataset = load_dataset(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml") | ||
model = load_pretrained_model("https://github.com/alan-turing-institute/plankton-cefas-scivision") | ||
dataset = open_catalog(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml") | ||
Review comment: huh.. it looks like |
||
model = load_model() | ||
dataset.test_image().to_dask() | ||
``` | ||
|
||
The scivision wrapper depends on this being an xarray Dataset with settable attributes, rather than a DataArray | ||
|
||
Setting exif_tags: True (Dataset) or False (DataArray) is what controls this | ||
https://docs.xarray.dev/en/stable/generated/xarray.DataArray.to_dataset.html | ||
|
||
https://github.com/alan-turing-institute/scivision/blob/07fb74e5231bc1d56cf39df38c19ef40e3265e4c/src/scivision/io/reader.py#L183 | ||
https://github.com/intake/intake/blob/29c8878aa7bf6e93185e2c9639f8739445dff22b/intake/__init__.py#L101 | ||
|
||
But now we're dependent on image height and width metadata being set in the EXIF tags to use the `predict` interface, this is set in the model description through `scivision`, this is brittle | ||
|
||
https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/model.py#L71 | ||
|
||
|
||
|
||
A quick look at the example dataset that comes with the model, for reference | ||
|
||
|
||
In this case we don't want to use the `predict` interface anyway (one of N class labels) - we want the features that go into the last fully-connected layer (as described here https://stackoverflow.com/a/52548419) | ||
|
||
```python | ||
network = torch.nn.Sequential(*(list(model._plumbing.model.pretrained_model.children())[:-1])) | ||
network = load_model(strip_final_layer=True) | ||
``` | ||
|
||
```python | ||
imgs = dataset.test_image().to_dask() | ||
i= imgs.to_numpy() | ||
i.shape | ||
|
||
imgs.to_numpy().shape | ||
``` | ||
|
||
https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/data.py | ||
|
||
|
||
|
||
Pass the image through our truncated network and get some embeddings out | ||
|
||
```python | ||
o = torch.stack([torchvision.transforms.ToTensor()(i)]) | ||
o = prepare_image(imgs) | ||
feats = network(o) | ||
feats.shape | ||
``` | ||
|
||
```python | ||
embeddings = list(feats[0].squeeze(1).squeeze(1).detach().numpy().astype(float)) | ||
embeddings = feats[0].tolist() | ||
``` | ||
|
||
```python | ||
|
@@ -129,7 +107,7 @@ index | |
|
||
```python | ||
def flat_embeddings(features: torch.Tensor): | ||
return list(features[0].squeeze(1).squeeze(1).detach().numpy().astype(float)) | ||
return features[0].tolist() | ||
``` | ||
|
||
```python | ||
|
@@ -158,6 +136,8 @@ This scales ok at 8000 or so images | |
collection.count() | ||
``` | ||
|
||
This is _really_ slow - joe | ||
|
||
```python | ||
res = index.apply(file_embeddings, axis=1) | ||
``` | ||
|
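The notebook change above swaps the manual `torch.nn.Sequential(*list(...children())[:-1])` truncation for `load_model(strip_final_layer=True)`. As a minimal sketch of what that truncation does, assuming a toy stand-in model (`toy_model` is hypothetical and is not the CEFAS ResNet50), dropping the last child module makes the forward pass return the features that would otherwise feed the classification layer. How `strip_final_layer` is actually implemented in `resnet50-cefas` is not shown in this diff.

```python
import torch
from torch import nn

# Hypothetical stand-in for a pretrained classifier: a small feature
# extractor followed by a final fully-connected classification layer.
toy_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 4 * 4, 8),  # produces an 8-dimensional feature vector
    nn.ReLU(),
    nn.Linear(8, 5),  # final layer: one score per class (5 classes)
)

# Keep every child module except the last one, as in the removed notebook line.
truncated = nn.Sequential(*list(toy_model.children())[:-1])

batch = torch.rand(2, 3, 4, 4)  # two fake 3-channel 4x4 images
with torch.no_grad():
    logits = toy_model(batch)  # class scores
    feats = truncated(batch)  # features that feed the final layer

print(logits.shape)
print(feats.shape)
```

The features, not the class scores, are what gets stored as embeddings in the vector database.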
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,39 @@ | ||
[build-system] | ||
requires = ["setuptools >= 61.0"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[project] | ||
name = "cyto_ml" | ||
version = "0.1" | ||
version = "0.2.0" | ||
requires-python = ">=3.12" | ||
description = "This package supports the processing and analysis of plankton sample data" | ||
readme = "README.md" | ||
requires-python = "==3.9.*" | ||
dependencies = [ | ||
"chromadb", | ||
"intake==0.7.0", | ||
"intake-xarray", | ||
"pandas", | ||
"python-dotenv", | ||
"s3fs", | ||
"scikit-image", # secretly required by intake-xarray as default reader | ||
"torch", | ||
"xarray", | ||
"resnet50-cefas@git+https://github.com/jmarshrossney/resnet50-cefas", | ||
] | ||
|
||
[tool.setuptools] | ||
py-modules = [] | ||
[project.optional-dependencies] | ||
jupyter = ["jupyterlab", "jupytext", "matplotlib"] | ||
dev = ["pytest", "black", "flake8", "isort"] | ||
all = ["cyto_ml[jupyter,dev]"] | ||
|
||
[tool.jupytext] | ||
formats = "ipynb,md" | ||
|
||
[tool.pytest.ini_options] | ||
filterwarnings = [ | ||
"ignore::DeprecationWarning", | ||
] | ||
|
||
[tool.black] | ||
target-version = ["py312"] | ||
line-length = 88 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from torchvision.transforms.v2.functional import to_image, to_dtype | ||
from xarray import DataArray | ||
|
||
|
||
def prepare_image(image: DataArray): | ||
""" | ||
Take an xarray of image data and prepare it to pass through the model | ||
a) Converts the image data to a PyTorch tensor | ||
b) Accepts a single image or batch (no need for torch.stack) | ||
""" | ||
# Computes the DataArray and returns a numpy array | ||
image_numpy = image.to_numpy() | ||
|
||
# Convert the image data to a PyTorch tensor | ||
tensor_image = to_dtype( | ||
to_image(image_numpy), # permutes HWC -> CHW | ||
torch.float32, | ||
scale=True, # rescales [0, 255] -> [0, 1] | ||
) | ||
|
||
assert torch.all((tensor_image >= 0.0) & (tensor_image <= 1.0)) | ||
|
||
if tensor_image.dim() == 3: | ||
# Single image, add a batch dimension | ||
tensor_image = tensor_image.unsqueeze(0) | ||
|
||
assert tensor_image.dim() == 4 | ||
|
||
return tensor_image | ||
|
||
|
||
def flat_embeddings(features: torch.Tensor): | ||
"""Utility function that takes the features returned by the model in truncate_model | ||
And flattens them into a list suitable for storing in a vector database""" | ||
# TODO: this only returns the 0th tensor in the batch...why? | ||
return features[0].detach().tolist() |
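On the TODO in `flat_embeddings`: one possible batch-aware variant is sketched below. This is only an illustration, not part of the PR. It assumes the features arrive with a leading batch dimension and that any trailing dimensions (for example `1, 1` from a pooling layer) can be collapsed; `flat_embeddings_batch` is a hypothetical name.

```python
import torch


def flat_embeddings_batch(features: torch.Tensor) -> list[list[float]]:
    """Hypothetical variant of flat_embeddings that returns one list per image."""
    # flatten(start_dim=1) keeps the batch dimension and collapses the rest,
    # so both (N, F) and (N, F, 1, 1) inputs become (N, F).
    return features.detach().flatten(start_dim=1).tolist()


# Fake features for a batch of two images, six features each
features = torch.arange(12, dtype=torch.float32).reshape(2, 6, 1, 1)
embeddings = flat_embeddings_batch(features)
print(len(embeddings), len(embeddings[0]))
```

Callers that pass a single image would then index the result with `[0]` themselves, which keeps the batch handling explicit.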
Review comment: Because it fails for me with `strict` priority
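For reference, the conversion that `prepare_image` performs on a single image (HWC to CHW, rescale `[0, 255]` to `[0, 1]`, add a batch dimension) can be spelled out with plain tensor operations. This is a rough sketch under the assumption of a `uint8` image in height, width, channel layout; the array below is made up for illustration and the PR itself uses `to_image` and `to_dtype` from `torchvision.transforms.v2.functional`.

```python
import numpy as np
import torch

# Hypothetical 2x2 RGB image in HWC layout with uint8 pixel values
image_hwc = np.array(
    [
        [[0, 128, 255], [255, 0, 0]],
        [[0, 255, 0], [0, 0, 255]],
    ],
    dtype=np.uint8,
)

# HWC -> CHW, as the comment on to_image describes
tensor_chw = torch.from_numpy(image_hwc).permute(2, 0, 1)

# Rescale [0, 255] -> [0, 1], as to_dtype(..., scale=True) does for uint8 input
tensor_chw = tensor_chw.to(torch.float32) / 255.0

# Single image: add a batch dimension so the model sees (N, C, H, W)
batch = tensor_chw.unsqueeze(0)
print(batch.shape)
```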