From 4b63f9ac4657e61bc2e51848ebfa11a7367d06f4 Mon Sep 17 00:00:00 2001 From: Jo Walsh Date: Tue, 8 Oct 2024 08:41:34 +0100 Subject: [PATCH] Special handling for flow cytometer images (#41) * add a stub function to normalise flow cytometer images * normalise the flow cytometer image. it's still greyscale * heavy-handed conversion to 3 band, will it work for display? * more test coverage, small bug related to model input data type --- scripts/params.yaml | 2 +- src/cyto_ml/data/image.py | 46 +++++++++++++++++++++++---- src/cyto_ml/visualisation/app.py | 12 +++---- tests/conftest.py | 2 +- tests/test_image_embeddings.py | 31 ++++++++++++++---- tests/test_pipeline.py | 54 +++++++++++++++++++------------- tests/test_prepare_image.py | 3 +- tests/test_vector_store.py | 1 + tests/test_visualisation_app.py | 4 +-- 9 files changed, 109 insertions(+), 46 deletions(-) diff --git a/scripts/params.yaml b/scripts/params.yaml index 56dbd55..e4023b3 100644 --- a/scripts/params.yaml +++ b/scripts/params.yaml @@ -1,4 +1,4 @@ cluster: n_clusters: 5 -collection: untagged-images-lana +collection: test-upload-alba diff --git a/src/cyto_ml/data/image.py b/src/cyto_ml/data/image.py index 5274032..8e9c9c9 100644 --- a/src/cyto_ml/data/image.py +++ b/src/cyto_ml/data/image.py @@ -1,6 +1,7 @@ import logging from io import BytesIO +import numpy as np import requests import torch from PIL import Image @@ -27,16 +28,49 @@ def prepare_image(image: Image) -> torch.Tensor: a) Converts the image data to a PyTorch tensor b) Accepts a single image or batch (no need for torch.stack) """ - # Flow Cytometer images are 16-bit greyscale - # https://stackoverflow.com/questions/18522295/python-pil-change-greyscale-tif-to-rgb - # TODO revisit - if image.mode == "I;16": - image.point(lambda p: p * 0.0039063096, mode="RGB") - image = image.convert("RGB") + if hasattr(image, "mode") and image.mode == "I;16": + # Flow Cytometer images are 16-bit greyscale, in a low range + # Note - tried this and variants, does not have expected result + # https://stackoverflow.com/questions/18522295/python-pil-change-greyscale-tif-to-rgb + # + # Convert to 3 bands because our model has 3 channel input + image = convert_3_band(normalise_flowlr(image)) tensor_image = transforms.ToTensor()(image) # Single image, add a batch dimension tensor_image = tensor_image.unsqueeze(0) return tensor_image + + +def normalise_flowlr(image: Image) -> np.array: + """Utility function to normalise flow cytometer images. + As output from the flow cytometer, they are 16 bit greyscale, + but all the values are in a low range (max value 1018 across the set) + + As recommended by @Kzra, normalise all values by the maximum + Both for display, and before handing to a model. + + Image.point(lambda...) should do this, but the values stay integers + So roundtrip this through numpy + """ + pix = np.array(image) + max_val = max(pix.flatten()) + pix = pix / max_val + return pix + + +def convert_3_band(image: np.array) -> np.array: + """ + Given a 1-band image normalised between 0 and 1, convert to 3 band + https://stackoverflow.com/a/57723482 + This seems very brute-force, but PIL is not converting our odd-format + greyscale images from the Flow Cytometer well. Improvements appreciated + """ + img2 = np.zeros((image.shape[0], image.shape[1], 3)) + img2[:, :, 0] = image # same value in each channel + img2[:, :, 1] = image + img2[:, :, 2] = image + # Cast to float32 as this is what the model layers expect + return img2.astype(np.float32) diff --git a/src/cyto_ml/visualisation/app.py b/src/cyto_ml/visualisation/app.py index 41a4df1..d8e8204 100644 --- a/src/cyto_ml/visualisation/app.py +++ b/src/cyto_ml/visualisation/app.py @@ -20,6 +20,7 @@ from dotenv import load_dotenv from PIL import Image +from cyto_ml.data.image import normalise_flowlr from cyto_ml.data.vectorstore import client, embeddings, vector_store load_dotenv() @@ -73,13 +74,12 @@ def cached_image(url: str) -> Image: """ response = requests.get(url) image = Image.open(BytesIO(response.content)) + + # Special handling for Flow Cytometer images, + # All 16 bit greyscale in low range of values if image.mode == "I;16": - # 16 bit greyscale - divide by 255, convert RGB for display - (_, max_val) = image.getextrema() - image.point(lambda p: p * 1 / max_val) - # image.point(lambda p: p * (1/255))#.convert('RGB') - # image.mode = 'I'#, mode="RGB") - image = image.convert("RGB") + image = normalise_flowlr(image) + return image diff --git a/tests/conftest.py b/tests/conftest.py index e1f2590..b9dd997 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,7 +36,7 @@ def image_batch(image_dir): @pytest.fixture -def scivision_model(): +def resnet_model(): return load_model(strip_final_layer=True) diff --git a/tests/test_image_embeddings.py b/tests/test_image_embeddings.py index 33d0b5f..f9bcaf0 100644 --- a/tests/test_image_embeddings.py +++ b/tests/test_image_embeddings.py @@ -1,15 +1,34 @@ - -from torch import Tensor +import torch +from PIL import Image from cyto_ml.models.utils import flat_embeddings -from cyto_ml.data.image import load_image +from cyto_ml.data.image import load_image, normalise_flowlr, prepare_image -def test_embeddings(scivision_model, single_image): - features = scivision_model(load_image(single_image)) +def test_embeddings(resnet_model, single_image, greyscale_image): + features = resnet_model(load_image(single_image)) - assert isinstance(features, Tensor) + assert isinstance(features, torch.Tensor) embeddings = flat_embeddings(features) assert len(embeddings) == features.size()[1] + features = resnet_model(load_image(greyscale_image)) + + assert isinstance(features, torch.Tensor) + + embeddings = flat_embeddings(features) + assert len(embeddings) == features.size()[1] + + +def test_normalise_flowlr(greyscale_image): + # Normalise first, hand the tensorize function an array + image = normalise_flowlr(Image.open(greyscale_image)) + prepared_image = prepare_image(image) + + assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0)) + + # Do it all at once + prepared_image = prepare_image(Image.open(greyscale_image)) + + assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 19ba83f..9555882 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,7 +4,11 @@ import numpy as np from skimage.io import imsave import luigi -from cyto_ml.pipeline.pipeline_decollage import ReadMetadata, DecollageImages, UploadDecollagedImagesToS3 +from cyto_ml.pipeline.pipeline_decollage import ( + ReadMetadata, + DecollageImages, + UploadDecollagedImagesToS3, +) @pytest.fixture @@ -20,14 +24,13 @@ def test_read_metadata(temp_dir): for i in range(1, 54): lst_file_content += f"field{i}|val{i}\n" # Also assumes at least one row of data! Add two of them - for i in [0,1]: - lst_file_content += "|".join([str(i) for i in range(1,54)]) + "\n" + for i in [0, 1]: + lst_file_content += "|".join([str(i) for i in range(1, 54)]) + "\n" - lst_file_path = os.path.join(temp_dir, 'test.lst') - with open(lst_file_path, 'w') as f: + lst_file_path = os.path.join(temp_dir, "test.lst") + with open(lst_file_path, "w") as f: f.write(lst_file_content) - # Run the ReadMetadata task task = ReadMetadata(directory=str(temp_dir)) luigi.build([task], local_scheduler=True) @@ -42,13 +45,15 @@ def test_read_metadata(temp_dir): def test_decollage_images(temp_dir): # Create mock metadata - metadata = pd.DataFrame({ - "collage_file": ["test_collage.tif"], - "image_x": [0], - "image_y": [0], - "image_h": [100], - "image_w": [100] - }) + metadata = pd.DataFrame( + { + "collage_file": ["test_collage.tif"], + "image_x": [0], + "image_y": [0], + "image_h": [100], + "image_w": [100], + } + ) metadata.to_csv(os.path.join(temp_dir, "metadata.csv"), index=False) # Create a mock TIFF image @@ -57,31 +62,38 @@ def test_decollage_images(temp_dir): imsave(img_path, img) # Run the DecollageImages task - task = DecollageImages(directory=str(temp_dir), output_directory=str(temp_dir), experiment_name="test_experiment") + task = DecollageImages( + directory=str(temp_dir), + output_directory=str(temp_dir), + experiment_name="test_experiment", + ) luigi.build([task], local_scheduler=True) # Check if the output image was created output_image = os.path.join(temp_dir, "test_experiment_0.tif") assert os.path.exists(output_image), "Decollaged image should be created." + class MockTask(luigi.Task): directory = luigi.Parameter() check_unfulfilled_deps = False + def output(self) -> luigi.Target: # "The output() method returns one or more Target objects."" - return luigi.LocalTarget(f'{self.directory}/out.txt') - + return luigi.LocalTarget(f"{self.directory}/out.txt") + + def test_upload_to_api(temp_dir, mocker): # Write a tmp file to serve as our upstream task's output - with open(os.path.join(temp_dir, 'out.txt'), 'w') as out: + with open(os.path.join(temp_dir, "out.txt"), "w") as out: out.write("blah") # The task `requires` DecollageImages, but that requires other tasks, which run first - # Rather than mock its output, or the whole chain, require a mock task that replaces it - mock_output = mocker.patch(f'cyto_ml.pipeline.pipeline_decollage.UploadDecollagedImagesToS3.requires') + # Rather than mock its output, or the whole chain, require a mock task that replaces it + mock_output = mocker.patch(f"cyto_ml.pipeline.pipeline_decollage.UploadDecollagedImagesToS3.requires") mock_output.return_value = MockTask(directory=temp_dir) - + # Mock the requests.post to simulate the API response - mock_post = mocker.patch('cyto_ml.pipeline.pipeline_decollage.requests.post') + mock_post = mocker.patch("cyto_ml.pipeline.pipeline_decollage.requests.post") mock_post.return_value.status_code = 200 task = UploadDecollagedImagesToS3( diff --git a/tests/test_prepare_image.py b/tests/test_prepare_image.py index 91c2b7f..b628b4e 100644 --- a/tests/test_prepare_image.py +++ b/tests/test_prepare_image.py @@ -8,7 +8,6 @@ def test_single_image(single_image): - # Tensorise the image (potentially normalise if we have useful values) prepared_image = load_image(single_image) @@ -17,8 +16,8 @@ def test_single_image(single_image): assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0)) -def test_greyscale_image(greyscale_image): +def test_greyscale_image(greyscale_image): # Tensorise the image (potentially normalise if we have useful values) prepared_image = load_image(greyscale_image) assert torch.all((prepared_image >= 0.0) & (prepared_image <= 1.0)) diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 6f37537..490d836 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -21,6 +21,7 @@ def test_store(): assert record assert len(record["embeddings"][0]) == 2048 + def test_embeddings(): store = vector_store() assert len(embeddings(store)) diff --git a/tests/test_visualisation_app.py b/tests/test_visualisation_app.py index a285142..088e24f 100644 --- a/tests/test_visualisation_app.py +++ b/tests/test_visualisation_app.py @@ -40,9 +40,7 @@ def test_app_starts(self): "cyto_ml.visualisation.app.image_ids", return_value=self.data, ): - AppTest.from_file("src/cyto_ml/visualisation/app.py").run( - timeout=30 - ) + AppTest.from_file("src/cyto_ml/visualisation/app.py").run(timeout=30) def test_create_figure(self): """