diff --git a/transforms/images/apply-flatfield-plugin/pyproject.toml b/transforms/images/apply-flatfield-plugin/pyproject.toml
new file mode 100644
index 000000000..d4e72a227
--- /dev/null
+++ b/transforms/images/apply-flatfield-plugin/pyproject.toml
@@ -0,0 +1,38 @@
+[tool.poetry]
+name = "polus-plugins-transforms-images-apply-flatfield"
+version = "2.0.0-dev7"
+description = ""
+authors = [
+    "Nick Schaub",
+    "Najib Ishaq"
+]
+readme = "README.md"
+packages = [{include = "polus", from = "src"}]
+
+[tool.poetry.dependencies]
+python = "^3.9"
+bfio = { version = "2.1.9", extras = ["all"] }
+filepattern = [
+    { version = "^2.0.0", platform = "linux" },
+    { version = "^2.0.0", platform = "win32" },
+    # { git = "https://github.com/PolusAI/filepattern", rev = "c07bf543c435cbc4cf264effd5a178868e9eaf19", platform = "darwin" },
+    { git = "https://github.com/JesseMckinzie/filepattern-1", rev = "c27cf04ba3a1946b87c0c43d5720ba394c340894", platform = "darwin" },
+]
+typer = { version = "^0.7.0", extras = ["all"] }
+numpy = "^1.24.3"
+tqdm = "^4.65.0"
+
+[tool.poetry.group.dev.dependencies]
+pytest = "^7.2.1"
+pytest-cov = "^4.0.0"
+pytest-sugar = "^0.9.6"
+pytest-xdist = "^3.2.0"
+pytest-benchmark = "^4.0.0"
+bump2version = "^1.0.1"
+pre-commit = "^3.0.4"
+black = "^23.1.0"
+ruff = "^0.0.265"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
diff --git a/transforms/images/apply-flatfield-plugin/src/polus/plugins/transforms/images/apply_flatfield/__init__.py b/transforms/images/apply-flatfield-plugin/src/polus/plugins/transforms/images/apply_flatfield/__init__.py
new file mode 100644
index 000000000..dd697c5a6
--- /dev/null
+++ b/transforms/images/apply-flatfield-plugin/src/polus/plugins/transforms/images/apply_flatfield/__init__.py
@@ -0,0 +1,160 @@
+"""Provides the apply_flatfield module."""
+
+import concurrent.futures
+import logging
+import operator
+import pathlib
+import sys
+import typing
+
+import bfio
+import numpy
+import tqdm
+from filepattern import FilePattern
+
+from . import utils
+
+logger = logging.getLogger(__name__)
+logger.setLevel(utils.POLUS_LOG)
+
+
+def apply(  # noqa: PLR0913
+    img_dir: pathlib.Path,
+    img_pattern: str,
+    ff_dir: pathlib.Path,
+    ff_pattern: str,
+    df_pattern: typing.Optional[str],
+    out_dir: pathlib.Path,
+) -> None:
+    """Run batch-wise flatfield correction on the image collection."""
+    img_fp = FilePattern(str(img_dir), img_pattern)
+    img_variables = img_fp.get_variables()
+
+    ff_fp = FilePattern(str(ff_dir), ff_pattern)
+    ff_variables = ff_fp.get_variables()
+
+    # check that ff_variables are a subset of img_variables
+    if set(ff_variables) - set(img_variables):
+        msg = (
+            f"Flatfield variables are not a subset of image variables: "
+            f"{ff_variables} - {img_variables}"
+        )
+        raise ValueError(msg)
+
+    if (df_pattern is None) or (not df_pattern):
+        df_fp = None
+    else:
+        df_fp = FilePattern(str(ff_dir), df_pattern)
+        df_variables = df_fp.get_variables()
+        if set(df_variables) != set(ff_variables):
+            msg = (
+                f"Flatfield and darkfield variables do not match: "
+                f"{ff_variables} != {df_variables}"
+            )
+            raise ValueError(msg)
+
+    for group, files in img_fp(group_by=ff_variables):
+        img_paths = [p for _, [p] in files]
+        variables = dict(group)
+
+        ff_path: pathlib.Path = ff_fp.get_matching(**variables)[0][1][0]
+
+        df_path = None if df_fp is None else df_fp.get_matching(**variables)[0][1][0]
+
+        _unshade_images(img_paths, out_dir, ff_path, df_path)
+
+
+def _unshade_images(
+    img_paths: list[pathlib.Path],
+    out_dir: pathlib.Path,
+    ff_path: pathlib.Path,
+    df_path: typing.Optional[pathlib.Path],
+) -> None:
+    """Remove the given flatfield components from all images and save outputs.
+
+    Args:
+        img_paths: list of paths to images to be processed
+        out_dir: directory to save the corrected images
+        ff_path: path to the flatfield image
+        df_path: path to the darkfield image
+    """
+    with bfio.BioReader(ff_path, max_workers=2) as bf:
+        ff_image = bf[:, :, :, 0, 0].squeeze()
+
+    if df_path is not None:
+        with bfio.BioReader(df_path, max_workers=2) as df:
+            df_image = df[:, :, :, 0, 0].squeeze()
+    else:
+        df_image = None
+
+    batch_indices = list(range(0, len(img_paths), 16))
+    if batch_indices[-1] != len(img_paths):
+        batch_indices.append(len(img_paths))
+
+    for i_start, i_end in tqdm.tqdm(
+        zip(batch_indices[:-1], batch_indices[1:]),
+        total=len(batch_indices) - 1,
+    ):
+        _unshade_batch(
+            img_paths[i_start:i_end],
+            out_dir,
+            ff_image,
+            df_image,
+        )
+
+
+def _unshade_batch(
+    batch_paths: list[pathlib.Path],
+    out_dir: pathlib.Path,
+    ff_image: numpy.ndarray,
+    df_image: typing.Optional[numpy.ndarray] = None,
+) -> None:
+    """Apply flatfield correction to a batch of images.
+
+    Args:
+        batch_paths: list of paths to images to be processed
+        out_dir: directory to save the corrected images
+        ff_image: component to be used for flatfield correction
+        df_image: component to be used for darkfield correction
+    """
+    # Load images
+    images = []
+    with concurrent.futures.ProcessPoolExecutor(
+        max_workers=utils.MAX_WORKERS,
+    ) as load_executor:
+        load_futures = []
+        for i, inp_path in enumerate(batch_paths):
+            load_futures.append(load_executor.submit(utils.load_img, inp_path, i))
+
+        for lf in tqdm.tqdm(
+            concurrent.futures.as_completed(load_futures),
+            total=len(load_futures),
+            desc="Loading batch",
+        ):
+            images.append(lf.result())
+
+    images = [img for _, img in sorted(images, key=operator.itemgetter(0))]
+    img_stack = numpy.stack(images, axis=0)
+
+    # Apply flatfield correction
+    if df_image is not None:
+        img_stack -= df_image
+
+    img_stack /= ff_image
+
+    # Save outputs
+    with concurrent.futures.ProcessPoolExecutor(
+        max_workers=utils.MAX_WORKERS,
+    ) as save_executor:
+        save_futures = []
+        for inp_path, img in zip(batch_paths, img_stack):
+            save_futures.append(
+                save_executor.submit(utils.save_img, inp_path, img, out_dir),
+            )
+
+        for sf in tqdm.tqdm(
+            concurrent.futures.as_completed(save_futures),
+            total=len(save_futures),
+            desc="Saving batch",
+        ):
+            sf.result()
diff --git a/transforms/images/apply-flatfield-plugin/tests/test_plugin.py b/transforms/images/apply-flatfield-plugin/tests/test_plugin.py
new file mode 100644
index 000000000..886773807
--- /dev/null
+++ b/transforms/images/apply-flatfield-plugin/tests/test_plugin.py
@@ -0,0 +1,151 @@
+"""Tests for the plugin."""
+
+import itertools
+import logging
+import pathlib
+import shutil
+import tempfile
+
+import bfio
+import numpy
+import pytest
+import typer.testing
+from polus.plugins.transforms.images.apply_flatfield import apply
+from polus.plugins.transforms.images.apply_flatfield.__main__ import app
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def _make_random_image(
+    path: pathlib.Path,
+    rng: numpy.random.Generator,
+    size: int,
+) -> None:
+    with bfio.BioWriter(path) as writer:
+        writer.X = size
+        writer.Y = size
+        writer.dtype = numpy.float32
+
+        writer[:] = rng.random(size=(size, size), dtype=writer.dtype)
+
+
+FixtureReturnType = tuple[pathlib.Path, str, pathlib.Path, str]
+
+
+def gen_once(num_groups: int, img_size: int) -> FixtureReturnType:
+    """Generate a set of random images for testing."""
+
+    img_pattern = "img_x{x}_c{c}.ome.tif"
+    ff_pattern = "img_x(1-10)_c{c}"
+
+    img_dir = pathlib.Path(tempfile.mkdtemp(suffix="img_dir"))
+    ff_dir = pathlib.Path(tempfile.mkdtemp(suffix="ff_dir"))
+
+    rng = numpy.random.default_rng(42)
+
+    for i in range(num_groups):
+        ff_path = ff_dir.joinpath(f"{ff_pattern.format(c=i + 1)}_flatfield.ome.tif")
+        _make_random_image(ff_path, rng, img_size)
+
+        df_path = ff_dir.joinpath(f"{ff_pattern.format(c=i + 1)}_darkfield.ome.tif")
+        _make_random_image(df_path, rng, img_size)
+
+        for j in range(10):  # 10 images in each group
+            img_path = img_dir.joinpath(img_pattern.format(x=j + 1, c=i + 1))
+            _make_random_image(img_path, rng, img_size)
+
+    image_names = list(sorted(p.name for p in img_dir.iterdir()))
+    logger.debug(f"Generated {image_names} images in {img_dir}")
+
+    ff_names = list(sorted(p.name for p in ff_dir.iterdir()))
+    logger.debug(f"Generated {ff_names} flatfield images in {ff_dir}")
+
+    img_pattern = "img_x{x:d+}_c{c:d}.ome.tif"
+    ff_pattern = "img_x\\(1-10\\)_c{c:d}"
+    return img_dir, img_pattern, ff_dir, ff_pattern
+
+
+NUM_GROUPS = [2**i for i in range(3)]
+IMG_SIZES = [1024 * 2**i for i in range(3)]
+PARAMS = list(itertools.product(NUM_GROUPS, IMG_SIZES))
+IDS = [f"{num_groups}_{img_size}" for num_groups, img_size in PARAMS]
+
+
+@pytest.fixture(params=PARAMS, ids=IDS)
+def gen_images(request: pytest.FixtureRequest) -> FixtureReturnType:
+    """Generate a set of random images for testing."""
+    num_groups: int
+    img_size: int
+    num_groups, img_size = request.param
+    img_dir, img_pattern, ff_dir, ff_pattern = gen_once(num_groups, img_size)
+
+    yield img_dir, img_pattern, ff_dir, ff_pattern
+
+    # Cleanup
+    shutil.rmtree(img_dir)
+    shutil.rmtree(ff_dir)
+
+
+def test_estimate(gen_images: FixtureReturnType) -> None:
+    """Test the `apply` function."""
+
+    img_dir, img_pattern, ff_dir, ff_pattern = gen_images
+    out_dir = pathlib.Path(tempfile.mkdtemp(suffix="out_dir"))
+
+    apply(
+        img_dir,
+        img_pattern,
+        ff_dir,
+        f"{ff_pattern}_flatfield.ome.tif",
+        f"{ff_pattern}_darkfield.ome.tif",
+        out_dir,
+    )
+
+    img_names = [p.name for p in img_dir.iterdir()]
+    out_names = [p.name for p in out_dir.iterdir()]
+
+    for name in img_names:
+        assert name in out_names, f"{name} not in {out_names}"
+
+    shutil.rmtree(out_dir)
+
+
+def test_cli() -> None:
+    """Test the CLI."""
+
+    img_dir, img_pattern, ff_dir, ff_pattern = gen_once(2, 2_048)
+    out_dir = pathlib.Path(tempfile.mkdtemp(suffix="out_dir"))
+
+    runner = typer.testing.CliRunner()
+
+    result = runner.invoke(
+        app,
+        [
+            "--imgDir",
+            str(img_dir),
+            "--imgPattern",
+            img_pattern,
+            "--ffDir",
+            str(ff_dir),
+            "--brightPattern",
+            f"{ff_pattern}_flatfield.ome.tif",
+            "--darkPattern",
+            f"{ff_pattern}_darkfield.ome.tif",
+            "--outDir",
+            str(out_dir),
+        ],
+    )
+
+    assert result.exit_code == 0, result.stdout
+
+    img_names = set(p.name for p in img_dir.iterdir() if p.name.endswith(".ome.tif"))
+
+    out_names = set(p.name for p in out_dir.iterdir() if p.name.endswith(".ome.tif"))
+
+    assert img_names == out_names, f"{img_names} != {out_names}"
+
+    # Cleanup
+    shutil.rmtree(img_dir)
+    shutil.rmtree(ff_dir)
+    shutil.rmtree(out_dir)
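For context, a minimal sketch of calling the new `apply` entry point directly from Python, assuming the package is installed. The directory names below (`inp/`, `ff/`, `out/`) are hypothetical placeholders, not paths defined by this plugin; the patterns mirror the ones exercised in `test_plugin.py`, and the `--imgDir`/`--ffDir`/`--outDir` flags used in `test_cli` are the equivalent command-line route.

import pathlib

from polus.plugins.transforms.images.apply_flatfield import apply

# Hypothetical layout: `inp/` holds the image collection, `ff/` holds the
# flatfield/darkfield components, and `out/` receives the corrected images.
apply(
    img_dir=pathlib.Path("inp"),
    img_pattern="img_x{x:d+}_c{c:d}.ome.tif",
    ff_dir=pathlib.Path("ff"),
    ff_pattern="img_x\\(1-10\\)_c{c:d}_flatfield.ome.tif",
    df_pattern="img_x\\(1-10\\)_c{c:d}_darkfield.ome.tif",
    out_dir=pathlib.Path("out"),
)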