Skip to content

Commit

Permalink
Special handling for flow cytometer images (#41)
Browse files Browse the repository at this point in the history
* add a stub function to normalise flow cytometer images

* normalise the flow cytometer image. it's still greyscale

* heavy-handed conversion to 3 band, will it work for display?

* more test coverage, small bug related to model input data type
  • Loading branch information
metazool authored Oct 8, 2024
1 parent c8422e9 commit 4b63f9a
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 46 deletions.
2 changes: 1 addition & 1 deletion scripts/params.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cluster:
n_clusters: 5

collection: untagged-images-lana
collection: test-upload-alba
46 changes: 40 additions & 6 deletions src/cyto_ml/data/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image
Expand All @@ -27,16 +28,49 @@ def prepare_image(image: Image) -> torch.Tensor:
a) Converts the image data to a PyTorch tensor
b) Accepts a single image or batch (no need for torch.stack)
"""
# Flow Cytometer images are 16-bit greyscale
# https://stackoverflow.com/questions/18522295/python-pil-change-greyscale-tif-to-rgb
# TODO revisit

if image.mode == "I;16":
image.point(lambda p: p * 0.0039063096, mode="RGB")
image = image.convert("RGB")
if hasattr(image, "mode") and image.mode == "I;16":
# Flow Cytometer images are 16-bit greyscale, in a low range
# Note - tried this and variants, does not have expected result
# https://stackoverflow.com/questions/18522295/python-pil-change-greyscale-tif-to-rgb
#
# Convert to 3 bands because our model has 3 channel input
image = convert_3_band(normalise_flowlr(image))

tensor_image = transforms.ToTensor()(image)

# Single image, add a batch dimension
tensor_image = tensor_image.unsqueeze(0)
return tensor_image


def normalise_flowlr(image: Image) -> np.array:
"""Utility function to normalise flow cytometer images.
As output from the flow cytometer, they are 16 bit greyscale,
but all the values are in a low range (max value 1018 across the set)
As recommended by @Kzra, normalise all values by the maximum
Both for display, and before handing to a model.
Image.point(lambda...) should do this, but the values stay integers
So roundtrip this through numpy
"""
pix = np.array(image)
max_val = max(pix.flatten())
pix = pix / max_val
return pix


def convert_3_band(image: np.array) -> np.array:
"""
Given a 1-band image normalised between 0 and 1, convert to 3 band
https://stackoverflow.com/a/57723482
This seems very brute-force, but PIL is not converting our odd-format
greyscale images from the Flow Cytometer well. Improvements appreciated
"""
img2 = np.zeros((image.shape[0], image.shape[1], 3))
img2[:, :, 0] = image # same value in each channel
img2[:, :, 1] = image
img2[:, :, 2] = image
# Cast to float32 as this is what the model layers expect
return img2.astype(np.float32)
12 changes: 6 additions & 6 deletions src/cyto_ml/visualisation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dotenv import load_dotenv
from PIL import Image

from cyto_ml.data.image import normalise_flowlr
from cyto_ml.data.vectorstore import client, embeddings, vector_store

load_dotenv()
Expand Down Expand Up @@ -73,13 +74,12 @@ def cached_image(url: str) -> Image:
"""
response = requests.get(url)
image = Image.open(BytesIO(response.content))

# Special handling for Flow Cytometer images,
# All 16 bit greyscale in low range of values
if image.mode == "I;16":
# 16 bit greyscale - divide by 255, convert RGB for display
(_, max_val) = image.getextrema()
image.point(lambda p: p * 1 / max_val)
# image.point(lambda p: p * (1/255))#.convert('RGB')
# image.mode = 'I'#, mode="RGB")
image = image.convert("RGB")
image = normalise_flowlr(image)

return image


Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def image_batch(image_dir):


@pytest.fixture
def scivision_model():
def resnet_model():
return load_model(strip_final_layer=True)


Expand Down
31 changes: 25 additions & 6 deletions tests/test_image_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@

from torch import Tensor
import torch
from PIL import Image
from cyto_ml.models.utils import flat_embeddings
from cyto_ml.data.image import load_image
from cyto_ml.data.image import load_image, normalise_flowlr, prepare_image


def test_embeddings(scivision_model, single_image):
features = scivision_model(load_image(single_image))
def test_embeddings(resnet_model, single_image, greyscale_image):
features = resnet_model(load_image(single_image))

assert isinstance(features, Tensor)
assert isinstance(features, torch.Tensor)

embeddings = flat_embeddings(features)

assert len(embeddings) == features.size()[1]

features = resnet_model(load_image(greyscale_image))

assert isinstance(features, torch.Tensor)

embeddings = flat_embeddings(features)
assert len(embeddings) == features.size()[1]


def test_normalise_flowlr(greyscale_image):
# Normalise first, hand the tensorize function an array
image = normalise_flowlr(Image.open(greyscale_image))
prepared_image = prepare_image(image)

assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0))

# Do it all at once
prepared_image = prepare_image(Image.open(greyscale_image))

assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0))
54 changes: 33 additions & 21 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import numpy as np
from skimage.io import imsave
import luigi
from cyto_ml.pipeline.pipeline_decollage import ReadMetadata, DecollageImages, UploadDecollagedImagesToS3
from cyto_ml.pipeline.pipeline_decollage import (
ReadMetadata,
DecollageImages,
UploadDecollagedImagesToS3,
)


@pytest.fixture
Expand All @@ -20,14 +24,13 @@ def test_read_metadata(temp_dir):
for i in range(1, 54):
lst_file_content += f"field{i}|val{i}\n"
# Also assumes at least one row of data! Add two of them
for i in [0,1]:
lst_file_content += "|".join([str(i) for i in range(1,54)]) + "\n"
for i in [0, 1]:
lst_file_content += "|".join([str(i) for i in range(1, 54)]) + "\n"

lst_file_path = os.path.join(temp_dir, 'test.lst')
with open(lst_file_path, 'w') as f:
lst_file_path = os.path.join(temp_dir, "test.lst")
with open(lst_file_path, "w") as f:
f.write(lst_file_content)


# Run the ReadMetadata task
task = ReadMetadata(directory=str(temp_dir))
luigi.build([task], local_scheduler=True)
Expand All @@ -42,13 +45,15 @@ def test_read_metadata(temp_dir):

def test_decollage_images(temp_dir):
# Create mock metadata
metadata = pd.DataFrame({
"collage_file": ["test_collage.tif"],
"image_x": [0],
"image_y": [0],
"image_h": [100],
"image_w": [100]
})
metadata = pd.DataFrame(
{
"collage_file": ["test_collage.tif"],
"image_x": [0],
"image_y": [0],
"image_h": [100],
"image_w": [100],
}
)
metadata.to_csv(os.path.join(temp_dir, "metadata.csv"), index=False)

# Create a mock TIFF image
Expand All @@ -57,31 +62,38 @@ def test_decollage_images(temp_dir):
imsave(img_path, img)

# Run the DecollageImages task
task = DecollageImages(directory=str(temp_dir), output_directory=str(temp_dir), experiment_name="test_experiment")
task = DecollageImages(
directory=str(temp_dir),
output_directory=str(temp_dir),
experiment_name="test_experiment",
)
luigi.build([task], local_scheduler=True)

# Check if the output image was created
output_image = os.path.join(temp_dir, "test_experiment_0.tif")
assert os.path.exists(output_image), "Decollaged image should be created."


class MockTask(luigi.Task):
directory = luigi.Parameter()
check_unfulfilled_deps = False

def output(self) -> luigi.Target:
# "The output() method returns one or more Target objects.""
return luigi.LocalTarget(f'{self.directory}/out.txt')

return luigi.LocalTarget(f"{self.directory}/out.txt")


def test_upload_to_api(temp_dir, mocker):
# Write a tmp file to serve as our upstream task's output
with open(os.path.join(temp_dir, 'out.txt'), 'w') as out:
with open(os.path.join(temp_dir, "out.txt"), "w") as out:
out.write("blah")
# The task `requires` DecollageImages, but that requires other tasks, which run first
# Rather than mock its output, or the whole chain, require a mock task that replaces it
mock_output = mocker.patch(f'cyto_ml.pipeline.pipeline_decollage.UploadDecollagedImagesToS3.requires')
# Rather than mock its output, or the whole chain, require a mock task that replaces it
mock_output = mocker.patch(f"cyto_ml.pipeline.pipeline_decollage.UploadDecollagedImagesToS3.requires")
mock_output.return_value = MockTask(directory=temp_dir)

# Mock the requests.post to simulate the API response
mock_post = mocker.patch('cyto_ml.pipeline.pipeline_decollage.requests.post')
mock_post = mocker.patch("cyto_ml.pipeline.pipeline_decollage.requests.post")
mock_post.return_value.status_code = 200

task = UploadDecollagedImagesToS3(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_prepare_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


def test_single_image(single_image):

# Tensorise the image (potentially normalise if we have useful values)
prepared_image = load_image(single_image)

Expand All @@ -17,8 +16,8 @@ def test_single_image(single_image):

assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0))

def test_greyscale_image(greyscale_image):

def test_greyscale_image(greyscale_image):
# Tensorise the image (potentially normalise if we have useful values)
prepared_image = load_image(greyscale_image)
assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0))
Expand Down
1 change: 1 addition & 0 deletions tests/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_store():
assert record
assert len(record["embeddings"][0]) == 2048


def test_embeddings():
store = vector_store()
assert len(embeddings(store))
4 changes: 1 addition & 3 deletions tests/test_visualisation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def test_app_starts(self):
"cyto_ml.visualisation.app.image_ids",
return_value=self.data,
):
AppTest.from_file("src/cyto_ml/visualisation/app.py").run(
timeout=30
)
AppTest.from_file("src/cyto_ml/visualisation/app.py").run(timeout=30)

def test_create_figure(self):
"""
Expand Down

0 comments on commit 4b63f9a

Please sign in to comment.