Skip to content

Commit

Permalink
fix: fixed all tests after filepattern PR
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Aug 2, 2023
1 parent 856a4de commit 41e0d99
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 0 deletions.
38 changes: 38 additions & 0 deletions transforms/images/apply-flatfield-plugin/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Poetry project definition for the apply-flatfield plugin package.
[tool.poetry]
name = "polus-plugins-transforms-images-apply-flatfield"
version = "2.0.0-dev7"
description = ""
authors = [
"Nick Schaub <nicholas.schaub@nih.gov>",
"Najib Ishaq <najib.ishaq@nih.gov>"
]
readme = "README.md"
# The importable `polus` package is located under the `src` directory.
packages = [{include = "polus", from = "src"}]

# Runtime dependencies.
[tool.poetry.dependencies]
python = "^3.9"
bfio = { version = "2.1.9", extras = ["all"] }
# filepattern is resolved per platform: a 2.x release on linux and win32,
# and a pinned git revision on darwin.
filepattern = [
{ version = "^2.0.0", platform = "linux" },
{ version = "^2.0.0", platform = "win32" },
# { git = "https://github.com/PolusAI/filepattern", rev = "c07bf543c435cbc4cf264effd5a178868e9eaf19", platform = "darwin" },
{ git = "https://github.com/JesseMckinzie/filepattern-1", rev = "c27cf04ba3a1946b87c0c43d5720ba394c340894", platform = "darwin" },
]
typer = { version = "^0.7.0", extras = ["all"] }
numpy = "^1.24.3"
tqdm = "^4.65.0"

# Dependencies of the optional `dev` group (testing, formatting, linting,
# version bumping); not required at runtime.
[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
pytest-cov = "^4.0.0"
pytest-sugar = "^0.9.6"
pytest-xdist = "^3.2.0"
pytest-benchmark = "^4.0.0"
bump2version = "^1.0.1"
pre-commit = "^3.0.4"
black = "^23.1.0"
ruff = "^0.0.265"

# The package is built with poetry-core as the build backend.
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Provides the apply_flatfield module."""

import concurrent.futures
import logging
import operator
import pathlib
import sys
import typing

import bfio
import numpy
import tqdm
from filepattern import FilePattern

from . import utils

# Module-level logger; its level is set from `utils.POLUS_LOG`.
logger = logging.getLogger(__name__)
logger.setLevel(utils.POLUS_LOG)


def apply(  # noqa: PLR0913
    img_dir: pathlib.Path,
    img_pattern: str,
    ff_dir: pathlib.Path,
    ff_pattern: str,
    df_pattern: typing.Optional[str],
    out_dir: pathlib.Path,
) -> None:
    """Run batch-wise flatfield correction on the image collection.

    Images matching `img_pattern` are grouped by the variables of the
    flatfield pattern. For every group, the matching flatfield (and, when a
    darkfield pattern is given, the matching darkfield) is looked up in
    `ff_dir` and the whole group is corrected and written to `out_dir`.

    Args:
        img_dir: directory containing the images to correct
        img_pattern: filepattern describing the images in `img_dir`
        ff_dir: directory containing the flatfield and darkfield components
        ff_pattern: filepattern describing the flatfield images in `ff_dir`
        df_pattern: filepattern describing the darkfield images in `ff_dir`;
            `None` or an empty string disables darkfield subtraction
        out_dir: directory in which the corrected images are saved

    Raises:
        ValueError: if the flatfield variables are not a subset of the image
            variables, or if the darkfield and flatfield variables differ.
    """
    img_fp = FilePattern(str(img_dir), img_pattern)
    img_variables = img_fp.get_variables()

    ff_fp = FilePattern(str(ff_dir), ff_pattern)
    ff_variables = ff_fp.get_variables()

    # Every flatfield variable must also exist in the image pattern, otherwise
    # the images cannot be grouped by those variables.
    if not set(ff_variables).issubset(img_variables):
        msg = (
            f"Flatfield variables are not a subset of image variables: "
            f"{ff_variables} - {img_variables}"
        )
        raise ValueError(msg)

    df_fp = None
    if df_pattern:
        # Darkfield components are looked up in the flatfield directory.
        df_fp = FilePattern(str(ff_dir), df_pattern)
        df_variables = df_fp.get_variables()
        if set(df_variables) != set(ff_variables):
            msg = (
                f"Flatfield and darkfield variables do not match: "
                f"{ff_variables} != {df_variables}"
            )
            raise ValueError(msg)

    for group, files in img_fp(group_by=ff_variables):
        # Each file entry is expected to carry exactly one path; the
        # single-element unpacking raises otherwise.
        img_paths = [path for _, (path,) in files]
        variables = dict(group)

        ff_path: pathlib.Path = _first_matching_path(ff_fp, variables)
        df_path = None
        if df_fp is not None:
            df_path = _first_matching_path(df_fp, variables)

        _unshade_images(img_paths, out_dir, ff_path, df_path)


def _first_matching_path(pattern, variables):  # noqa: ANN001, ANN202
    """Return the first path of the first file in `pattern` matching `variables`."""
    # NOTE(review): the indexing presumes `get_matching` returns a sequence of
    # (variables, [paths]) entries, as the grouped iteration in `apply` does.
    return pattern.get_matching(**variables)[0][1][0]


def _unshade_images(
    img_paths: list[pathlib.Path],
    out_dir: pathlib.Path,
    ff_path: pathlib.Path,
    df_path: typing.Optional[pathlib.Path],
    batch_size: int = 16,
) -> None:
    """Remove the given flatfield components from all images and save outputs.

    The flatfield (and optional darkfield) component is read once and then
    applied to the images in consecutive batches.

    Args:
        img_paths: list of paths to images to be processed
        out_dir: directory to save the corrected images
        ff_path: path to the flatfield image
        df_path: path to the darkfield image, or `None` to skip darkfield
            subtraction
        batch_size: maximum number of images corrected together in one batch

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        msg = f"batch_size must be a positive integer, got {batch_size}"
        raise ValueError(msg)

    if not img_paths:
        # Nothing to correct: skip reading the components instead of failing
        # on an empty batch list.
        return

    with bfio.BioReader(ff_path, max_workers=2) as ff_reader:
        ff_image = ff_reader[:, :, :, 0, 0].squeeze()

    df_image = None
    if df_path is not None:
        with bfio.BioReader(df_path, max_workers=2) as df_reader:
            df_image = df_reader[:, :, :, 0, 0].squeeze()

    # Stride over the paths; the final slice is simply shorter when the number
    # of images is not a multiple of the batch size.
    for start in tqdm.tqdm(range(0, len(img_paths), batch_size)):
        _unshade_batch(
            img_paths[start : start + batch_size],
            out_dir,
            ff_image,
            df_image,
        )


def _unshade_batch(
    batch_paths: list[pathlib.Path],
    out_dir: pathlib.Path,
    ff_image: numpy.ndarray,
    df_image: typing.Optional[numpy.ndarray] = None,
) -> None:
    """Apply flatfield correction to a batch of images.

    The images are loaded in parallel, stacked along a new leading axis,
    corrected as `(image - darkfield) / flatfield` (the subtraction is skipped
    when no darkfield is given) and saved in parallel.

    Args:
        batch_paths: list of paths to images to be processed
        out_dir: directory to save the corrected images
        ff_image: component to be used for flatfield correction
        df_image: component to be used for darkfield subtraction, if any
    """
    if not batch_paths:
        # An empty batch has nothing to stack, correct or save.
        return

    # Load images
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=utils.MAX_WORKERS,
    ) as load_executor:
        load_futures = [
            load_executor.submit(utils.load_img, inp_path, i)
            for i, inp_path in enumerate(batch_paths)
        ]

        indexed_images = [
            lf.result()
            for lf in tqdm.tqdm(
                concurrent.futures.as_completed(load_futures),
                total=len(load_futures),
                desc="Loading batch",
            )
        ]

    # Futures complete in arbitrary order; restore the order of `batch_paths`
    # using the index returned alongside each image.
    indexed_images.sort(key=operator.itemgetter(0))
    img_stack = numpy.stack([img for _, img in indexed_images], axis=0)

    # The correction below works in place, which NumPy refuses to do when the
    # result has to be cast back into an integer array. Promote non-float
    # stacks to float32 first; float stacks keep their dtype.
    if not numpy.issubdtype(img_stack.dtype, numpy.floating):
        img_stack = img_stack.astype(numpy.float32)

    # Apply flatfield correction
    if df_image is not None:
        img_stack -= df_image

    img_stack /= ff_image

    # Save outputs
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=utils.MAX_WORKERS,
    ) as save_executor:
        save_futures = [
            save_executor.submit(utils.save_img, inp_path, img, out_dir)
            for inp_path, img in zip(batch_paths, img_stack)
        ]

        for sf in tqdm.tqdm(
            concurrent.futures.as_completed(save_futures),
            total=len(save_futures),
            desc="Saving batch",
        ):
            # Propagate any exception raised while saving.
            sf.result()
151 changes: 151 additions & 0 deletions transforms/images/apply-flatfield-plugin/tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Tests for the plugin."""

import itertools
import logging
import pathlib
import shutil
import tempfile

import bfio
import numpy
import pytest
import typer.testing
from polus.plugins.transforms.images.apply_flatfield import apply
from polus.plugins.transforms.images.apply_flatfield.__main__ import app

# Logger for this test module, set to emit DEBUG-level messages.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def _make_random_image(
    path: pathlib.Path,
    rng: numpy.random.Generator,
    size: int,
) -> None:
    """Write a `size` x `size` float32 image of random values to `path`.

    Args:
        path: destination of the image file
        rng: random generator the pixel values are drawn from
        size: width and height of the image in pixels
    """
    shape = (size, size)
    with bfio.BioWriter(path) as image_writer:
        image_writer.X = size
        image_writer.Y = size
        image_writer.dtype = numpy.float32

        # Draw the pixels in the dtype the writer reports after configuration.
        pixels = rng.random(size=shape, dtype=image_writer.dtype)
        image_writer[:] = pixels


# Value handed out by the image generators, in the order
# (img_dir, img_pattern, ff_dir, ff_pattern).
FixtureReturnType = tuple[pathlib.Path, str, pathlib.Path, str]


def gen_once(num_groups: int, img_size: int) -> FixtureReturnType:
    """Create temporary directories filled with random test images.

    For each of the `num_groups` groups, one flatfield image, one darkfield
    image and ten input images are written, all drawn from a generator seeded
    with 42.

    Args:
        num_groups: number of groups (values of the `c` variable) to create
        img_size: width and height of every generated image

    Returns:
        The image directory, the image filepattern, the flatfield directory
        and the filepattern prefix shared by flatfield and darkfield files.
    """
    img_name_template = "img_x{x}_c{c}.ome.tif"
    ff_stem_template = "img_x(1-10)_c{c}"

    img_dir = pathlib.Path(tempfile.mkdtemp(suffix="img_dir"))
    ff_dir = pathlib.Path(tempfile.mkdtemp(suffix="ff_dir"))

    rng = numpy.random.default_rng(42)

    # The generation order (flatfield, darkfield, then the ten images of the
    # group) determines which random values end up in which file.
    for channel in range(1, num_groups + 1):
        ff_stem = ff_stem_template.format(c=channel)

        _make_random_image(
            ff_dir.joinpath(f"{ff_stem}_flatfield.ome.tif"),
            rng,
            img_size,
        )
        _make_random_image(
            ff_dir.joinpath(f"{ff_stem}_darkfield.ome.tif"),
            rng,
            img_size,
        )

        for position in range(1, 11):  # 10 images in each group
            img_name = img_name_template.format(x=position, c=channel)
            _make_random_image(img_dir.joinpath(img_name), rng, img_size)

    image_names = sorted(entry.name for entry in img_dir.iterdir())
    logger.debug(f"Generated {image_names} images in {img_dir}")

    ff_names = sorted(entry.name for entry in ff_dir.iterdir())
    logger.debug(f"Generated {ff_names} flatfield images in {ff_dir}")

    # The returned strings are filepatterns, not the format templates used
    # above; the literal parentheses in the flatfield names are escaped.
    return (
        img_dir,
        "img_x{x:d+}_c{c:d}.ome.tif",
        ff_dir,
        "img_x\\(1-10\\)_c{c:d}",
    )


# Parametrization grid for the `gen_images` fixture: every combination of
# group count and image size, each with a "<groups>_<size>" test id.
NUM_GROUPS = [1, 2, 4]
IMG_SIZES = [1024, 2048, 4096]
PARAMS = list(itertools.product(NUM_GROUPS, IMG_SIZES))
IDS = ["_".join(map(str, combo)) for combo in PARAMS]


@pytest.fixture(params=PARAMS, ids=IDS)
def gen_images(request: pytest.FixtureRequest) -> FixtureReturnType:
    """Yield freshly generated test images for one parameter combination.

    The image and flatfield directories are removed again once the test that
    requested the fixture has finished.
    """
    groups, size = request.param
    generated = gen_once(groups, size)
    img_dir, _, ff_dir, _ = generated

    yield generated

    # Cleanup
    for directory in (img_dir, ff_dir):
        shutil.rmtree(directory)


def test_estimate(gen_images: FixtureReturnType) -> None:
    """Test that `apply` writes a corrected output for every input image.

    The flatfield and darkfield patterns are built by appending the respective
    file suffix to the shared pattern prefix provided by the fixture.
    """
    img_dir, img_pattern, ff_dir, ff_pattern = gen_images
    out_dir = pathlib.Path(tempfile.mkdtemp(suffix="out_dir"))

    # Remove the output directory even when `apply` or an assertion fails, so
    # failing runs do not leave large temporary images behind.
    try:
        apply(
            img_dir,
            img_pattern,
            ff_dir,
            f"{ff_pattern}_flatfield.ome.tif",
            f"{ff_pattern}_darkfield.ome.tif",
            out_dir,
        )

        img_names = [p.name for p in img_dir.iterdir()]
        out_names = [p.name for p in out_dir.iterdir()]

        for name in img_names:
            assert name in out_names, f"{name} not in {out_names}"
    finally:
        shutil.rmtree(out_dir)


def test_cli() -> None:
    """Test the CLI.

    Runs the plugin's typer app on a generated image collection and checks
    that it exits successfully and that the output directory holds exactly
    the same `.ome.tif` file names as the input directory.
    """
    img_dir, img_pattern, ff_dir, ff_pattern = gen_once(2, 2_048)
    out_dir = pathlib.Path(tempfile.mkdtemp(suffix="out_dir"))

    # Clean up all temporary directories even when the CLI run or an assertion
    # fails, so failing runs do not leave large temporary images behind.
    try:
        runner = typer.testing.CliRunner()

        result = runner.invoke(
            app,
            [
                "--imgDir",
                str(img_dir),
                "--imgPattern",
                img_pattern,
                "--ffDir",
                str(ff_dir),
                "--brightPattern",
                f"{ff_pattern}_flatfield.ome.tif",
                "--darkPattern",
                f"{ff_pattern}_darkfield.ome.tif",
                "--outDir",
                str(out_dir),
            ],
        )

        assert result.exit_code == 0, result.stdout

        img_names = {
            p.name for p in img_dir.iterdir() if p.name.endswith(".ome.tif")
        }
        out_names = {
            p.name for p in out_dir.iterdir() if p.name.endswith(".ome.tif")
        }

        assert img_names == out_names, f"{img_names} != {out_names}"
    finally:
        # Cleanup
        for directory in (img_dir, ff_dir, out_dir):
            shutil.rmtree(directory)

0 comments on commit 41e0d99

Please sign in to comment.