
Commit

Merge pull request #84 from MannLabs/fix_linting
implement pre-commit rules
sophiamaedler authored Sep 26, 2024
2 parents 0a10ecd + c120cc2 commit b3adf54
Showing 32 changed files with 300 additions and 269 deletions.
22 changes: 11 additions & 11 deletions pyproject.toml
@@ -52,19 +52,19 @@ docstring-code-format = true

[tool.ruff.lint]
select = [
#"F", # Errors detected by Pyflakes
#"E", # Error detected by Pycodestyle
#"W", # Warning detected by Pycodestyle
"F", # Errors detected by Pyflakes
"E", # Error detected by Pycodestyle
"W", # Warning detected by Pycodestyle
"I", # isort
#"D", # pydocstyle
#"B", # flake8-bugbear
#"TID", # flake8-tidy-imports
#"C4", # flake8-comprehensions
#"BLE", # flake8-blind-except
#"UP", # pyupgrade
#"RUF100", # Report unused noqa directives
#"TCH", # Typing imports
#"NPY", # Numpy specific rules
"B", # flake8-bugbear
"TID", # flake8-tidy-imports
"C4", # flake8-comprehensions
"BLE", # flake8-blind-except
"UP", # pyupgrade
"RUF100", # Report unused noqa directives
"TCH", # Typing imports
"NPY", # Numpy specific rules
#"PTH" # Use pathlib
]
ignore = [
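With the rule families above uncommented, ruff now enforces them repository-wide. As a rough, hypothetical illustration (not code from this commit), the "F" (Pyflakes) and "E" (Pycodestyle) families report patterns like an unused import or an `== None` comparison:

import os  # F401 (Pyflakes): imported but unused


def count_missing(values):
    # E711 (Pycodestyle): comparison to None should use `is None`
    bad = sum(1 for v in values if v == None)  # noqa: E711
    good = sum(1 for v in values if v is None)
    return bad, good


print(count_missing([1, None, 3, None]))  # (2, 2)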
20 changes: 14 additions & 6 deletions src/scportrait/io/convert.py
@@ -127,7 +127,7 @@ def _load_input_image(self):

self.log("Finished loading input image to memory mapped temp array.")

def write_input_image_to_spatialdata(self, scale_factors=[2, 4, 8]):
def write_input_image_to_spatialdata(self, scale_factors=None):
"""
Write the input image found under the label "channels" in the segmentation.h5 file to a spatialdata object.
@@ -138,6 +138,8 @@ def write_input_image_to_spatialdata(self, scale_factors=[2, 4, 8]):
"""

# reconnect to temporary image as a dask array
if scale_factors is None:
scale_factors = [2, 4, 8]
temp_image = daskmmap.dask_array_from_path(self.temp_image_path)

if self.channel_names is None:
@@ -186,7 +188,7 @@ def _load_segmentation(self):
self.log("No segmentation found in project.")
self.segmentation_status = False

def write_segmentation_to_spatialdata(self, scale_factors=[]):
def write_segmentation_to_spatialdata(self, scale_factors=None):
"""
Write the segmentation masks found under the label "labels" in the segmentation.h5 file to a spatialdata object.
@@ -197,6 +199,8 @@ def write_segmentation_to_spatialdata(self, scale_factors=[]):
In the future this behaviour may be changed but at the moment scPortrait is not designed to handle multiple resolutions for segmentation masks.
"""

if scale_factors is None:
scale_factors = []
if self.segmentation_status is None:
# reconnect to temporary image as a dask array
temp_segmentation = daskmmap.dask_array_from_path(self.temp_segmentation_path)
@@ -249,10 +253,14 @@ def _lookup_region_annotations(self):

return region_lookup

def add_multiscale_segmentation(self, region_keys=["seg_all_nucleus", "seg_all_cytosol"], scale_factors=[2, 4, 8]):
def add_multiscale_segmentation(self, region_keys=None, scale_factors=None):
"""
Add multiscale segmentation to the spatialdata object.
"""
if scale_factors is None:
scale_factors = [2, 4, 8]
if region_keys is None:
region_keys = ["seg_all_nucleus", "seg_all_cytosol"]
region_lookup = self._lookup_region_annotations()
sdata = SpatialData.read(self._get_sdata_path())

@@ -360,9 +368,9 @@ def _read_classification_results(self, classification_result):
table = anndata.AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs)
return table

def write_classification_result_to_spatialdata(
self, classification_result, segmentation_regions=["seg_all_nucleus", "seg_all_cytosol"]
):
def write_classification_result_to_spatialdata(self, classification_result, segmentation_regions=None):
if segmentation_regions is None:
segmentation_regions = ["seg_all_nucleus", "seg_all_cytosol"]
class_result = self._read_classification_results(classification_result)
sdata = SpatialData.read(self._get_sdata_path())

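The `scale_factors`, `region_keys`, and `segmentation_regions` changes in this file all follow the same None-sentinel pattern for mutable default arguments (flake8-bugbear B006). A standalone sketch with a hypothetical helper, not scPortrait API:

def append_scale_buggy(extra, scale_factors=[2, 4, 8]):
    scale_factors.append(extra)   # mutates the single default list shared by all calls
    return scale_factors


def append_scale(extra, scale_factors=None):
    if scale_factors is None:     # a fresh list is created on every call
        scale_factors = [2, 4, 8]
    scale_factors.append(extra)
    return scale_factors


print(append_scale_buggy(16))  # [2, 4, 8, 16]
print(append_scale_buggy(16))  # [2, 4, 8, 16, 16] -- the default keeps growing
print(append_scale(16))        # [2, 4, 8, 16]
print(append_scale(16))        # [2, 4, 8, 16]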
13 changes: 7 additions & 6 deletions src/scportrait/pipeline/_base.py
@@ -11,7 +11,7 @@
import torch


class Logable(object):
class Logable:
"""
Object which can create log entries.
@@ -55,9 +55,10 @@ def log(self, message):
else:
try:
lines = [str(message)]
except Exception:
self.log("unknown type during logging")
return
except (TypeError, ValueError):
raise TypeError(
"Message must be a string, list of strings or a dictionary, but received type: ", type(message)
) from None

for line in lines:
log_path = os.path.join(self.directory, self.DEFAULT_LOG_NAME)
@@ -215,7 +216,7 @@ def __call__(self, *args, debug=None, overwrite=None, **kwargs):
return x
else:
self.clear_temp_dir() # also ensure clearing if not callable just to make sure everything is cleaned up
warnings.warn("no process method defined")
Warning("no process method defined.")

def __call_empty__(self, *args, debug=None, overwrite=None, **kwargs):
"""Call the empty processing step.
@@ -242,7 +243,7 @@ def __call_empty__(self, *args, debug=None, overwrite=None, **kwargs):
x = self.return_empty_mask(*args, **kwargs)
return x
else:
warnings.warn("no return_empty_mask method defined")
Warning("no return_empty_mask method defined")

# also clear empty temp directory here
self.clear_temp_dir()
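The `log()` change above replaces a silent log-and-return fallback with an explicit `raise ... from None`. A minimal sketch of that re-raise pattern, using a hypothetical helper rather than the Logable class itself:

def to_lines(message):
    if isinstance(message, str):
        return [message]
    if isinstance(message, (list, dict)):
        return [str(m) for m in message]
    try:
        return [str(message)]
    except (TypeError, ValueError):
        # `from None` suppresses the chained "during handling of the above exception" traceback
        raise TypeError(f"Message must be a string, list or dictionary, got {type(message)}") from None


print(to_lines("done"), to_lines([1, 2]))  # ['done'] ['1', '2']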
9 changes: 4 additions & 5 deletions src/scportrait/pipeline/_utils/sdata_io.py
@@ -1,6 +1,5 @@
import os
import shutil
from typing import List, Tuple

import datatree
import xarray
@@ -134,7 +133,7 @@ def _write_segmentation_sdata(
segmentation,
segmentation_label: str,
classes: set = None,
chunks: Tuple[int, int] = (1000, 1000),
chunks: tuple[int, int] = (1000, 1000),
overwrite: bool = False,
):
transform_original = Identity()
@@ -248,7 +247,7 @@ def _load_input_image_to_memmap(self, tmp_dir_abs_path: str, image=None):

def _load_seg_to_memmap(
self,
seg_name: List[str],
seg_name: list[str],
tmp_dir_abs_path: str,
):
"""
@@ -259,7 +258,7 @@ def _load_seg_to_memmap(
Parameters
----------
seg_name : List[str]
seg_name : list[str]
List of segmentation element names that should be loaded found in the sdata object.
The segmentation elements need to have the same size.
tmp_dir_abs_path : str
@@ -276,7 +275,7 @@ def _load_seg_to_memmap(
_sdata = self._check_sdata_status(return_sdata=True)

assert all(
[seg in _sdata.labels for seg in seg_name]
seg in _sdata.labels for seg in seg_name
), "Not all passed segmentation elements found in sdata object."

seg_objects = [_sdata.labels[seg] for seg in seg_name]
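Two recurring edits in this file: built-in generics (`list[str]`, `tuple[int, int]`) replace the deprecated `typing.List`/`typing.Tuple` aliases (pyupgrade), and `all()` receives a generator instead of a list comprehension (flake8-comprehensions). A hypothetical sketch of both, assuming Python >= 3.9:

def all_labels_present(seg_names: list[str], labels: dict[str, object], chunks: tuple[int, int] = (1000, 1000)) -> bool:
    # generator expression: no intermediate list is built just to hand it to all()
    return all(name in labels for name in seg_names)


print(all_labels_present(["nucleus", "cytosol"], {"nucleus": 1, "cytosol": 2}))  # True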
6 changes: 4 additions & 2 deletions src/scportrait/pipeline/_utils/segmentation.py
@@ -373,7 +373,7 @@ def _return_edge_labels_2d(input_map):
.union(set(last_column.flatten()))
)

full_union = set([np.uint64(i) for i in full_union])
full_union = set([np.uint64(i) for i in full_union]) # noqa: C403 set comprehensions are not supported by numba
full_union.discard(0)

return list(full_union)
@@ -780,7 +780,7 @@ def _class_size(mask, debug=False, background=0):
return mean_arr, length.flatten()


def size_filter(label, limits=[0, 100000], background=0, reindex=False):
def size_filter(label, limits=None, background=0, reindex=False):
"""
Filter classes in a labeled array based on their size (number of pixels).
@@ -817,6 +817,8 @@ def size_filter(label, limits=[0, 100000], background=0, reindex=False):
"""

# Calculate the number of pixels for each class in the labeled array
if limits is None:
limits = [0, 100000]
_, points_class = _class_size(label)

# Find the classes with size below the lower limit and above the upper limit
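Besides another None-sentinel default in size_filter, this file adds a targeted suppression: the list-comprehension-inside-set() call stays because, per the inline comment, numba does not support set comprehensions, and `# noqa: C403` limits the exemption to that single line. A small, numba-free sketch of the suppression syntax:

import numpy as np

values = np.array([3, 0, 3, 7], dtype=np.uint64)
# kept as set([...]) rather than a set comprehension; C403 is silenced only on this line
labels = set([np.uint64(v) for v in values])  # noqa: C403
labels.discard(np.uint64(0))
print(sorted(labels))  # [3, 7]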
6 changes: 3 additions & 3 deletions src/scportrait/pipeline/_utils/spatialdata_classes.py
@@ -1,5 +1,5 @@
from functools import singledispatchmethod
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any

from dask.array import Array as DaskArray
from dask.array import unique as DaskUnique
@@ -19,7 +19,7 @@

class spLabels2DModel(Labels2DModel):
# add an additional attribute that always contains the unique classes in a labels image
attrs = AttrsSchema({"transform": Transform_s}, {"cell_ids": Set})
attrs = AttrsSchema({"transform": Transform_s}, {"cell_ids": set})

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@@ -44,7 +44,7 @@ def _get_cell_ids(data, remove_background=True):
return data

@singledispatchmethod
def convert(self, data: Union[DataTree, DataArray], classes: set = None) -> Union[DataTree, DataArray]:
def convert(self, data: DataTree | DataArray, classes: set = None) -> DataTree | DataArray:
""" """
raise ValueError(f"Unsupported data type: {type(data)}. Please use .convert() from Labels2DModel instead.")

12 changes: 5 additions & 7 deletions src/scportrait/pipeline/_utils/spatialdata_helper.py
@@ -1,5 +1,3 @@
from typing import Any, Dict, List, Set, Tuple, Union

import datatree
import numpy as np
import pandas as pd
@@ -24,7 +22,7 @@ def check_memory(item):
return array_size < available_memory


def generate_region_annotation_lookuptable(sdata: SpatialData) -> Dict:
def generate_region_annotation_lookuptable(sdata: SpatialData) -> dict:
"""Generate a lookup table for the region annotation tables contained in a SpatialData object ordered according to the region they annotate.
Parameters
@@ -80,7 +78,7 @@ def remap_region_annotation_table(table: TableModel, region_name: str) -> TableM
return table


def get_chunk_size(element: Union[datatree.DataTree, xarray.DataArray]) -> Union[Tuple, List[Tuple]]:
def get_chunk_size(element: datatree.DataTree | xarray.DataArray) -> tuple | list[tuple]:
"""Get the chunk size of the image data.
Parameters
@@ -148,8 +146,8 @@ def get_chunk_size(element: Union[datatree.DataTree, xarray.DataArray]) -> Union


def rechunk_image(
element: Union[datatree.DataTree, xarray.DataArray], chunk_size: Tuple
) -> Union[datatree.DataTree, xarray.DataArray]:
element: datatree.DataTree | xarray.DataArray, chunk_size: tuple
) -> datatree.DataTree | xarray.DataArray:
"""
Rechunk the image data to the desired chunksize. This is useful for ensuring that the data is chunked in a regular manner.
@@ -183,7 +181,7 @@ def rechunk_image(

def make_centers_object(
centers: np.ndarray,
ids: List,
ids: list,
transformation: str,
coordinate_system="global",
):
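Here and in spatialdata_classes.py, `Union[X, Y]` and `Optional[X]` give way to the PEP 604 spelling `X | Y` (pyupgrade). A hypothetical sketch; on Python < 3.10 the `__future__` import keeps such annotations importable:

from __future__ import annotations


def first_chunk(chunks: tuple | list[tuple]) -> tuple | None:
    # `tuple | None` replaces typing.Optional[tuple] / typing.Union[tuple, None]
    if isinstance(chunks, list):
        return chunks[0] if chunks else None
    return chunks


print(first_chunk([(1000, 1000), (512, 512)]))  # (1000, 1000)
print(first_chunk((256, 256)))                  # (256, 256)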
33 changes: 14 additions & 19 deletions src/scportrait/pipeline/classification.py
@@ -4,7 +4,6 @@
import shutil
from contextlib import redirect_stdout
from functools import partial as func_partial
from typing import List, Union

import numpy as np
import pandas as pd
@@ -20,11 +19,9 @@


class _ClassificationBase(ProcessingStep):
PRETRAINED_MODEL_NAMES = list(
[
"autophagy_classifier",
]
)
PRETRAINED_MODEL_NAMES = [
"autophagy_classifier",
]
MASK_NAMES = ["nucleus", "cytosol"]

def __init__(self, *args, **kwargs):
@@ -219,7 +216,7 @@ def _get_gpu_memory_usage(self):
memory_usage.append(gpu_memory)
results = {f"GPU_{i}": f"{memory_usage[i]} MiB" for i in range(len(memory_usage))}
return results
except Exception as e:
except (RuntimeError, ValueError) as e:
print("Error:", e)
return None

@@ -228,7 +225,7 @@ def _get_gpu_memory_usage(self):
used_memory = torch.mps.driver_allocated_memory() + torch.mps.driver_allocated_memory()
used_memory = used_memory / 1024**2 # Convert bytes to MiB
return {"MPS": f"{memory_usage} MiB"}
except Exception as e:
except (RuntimeError, ValueError) as e:
print("Error:", e)
return None

@@ -303,8 +300,8 @@ def _load_pretrained_model(self, model_name: str):
def _load_model(
self,
ckpt_path,
hparams_path: Union[str, None] = None,
model_type: Union[str, None] = None,
hparams_path: str | None = None,
model_type: str | None = None,
) -> pl.LightningModule:
"""Load a model from a checkpoint file and transfer it to the inference device.
@@ -373,14 +370,14 @@ def _load_model(
def load_model(
self,
ckpt_path,
hparams_path: Union[str, None] = None,
model_type: Union[str, None] = None,
hparams_path: str | None = None,
model_type: str | None = None,
):
model = self._load_model(ckpt_path, hparams_path, model_type)
self._assign_model(model)

### Functions regarding dataloading and transforms ####
def configure_transforms(self, selected_transforms: List):
def configure_transforms(self, selected_transforms: list):
self.transforms = transforms.Compose(selected_transforms)
self.log(f"The following transforms were applied: {self.transforms}")

@@ -389,7 +386,7 @@ def generate_dataloader(
extraction_dir: str,
selected_transforms: transforms.Compose = transforms.Compose([]),
size: int = 0,
seed: Union[int, None] = 42,
seed: int | None = 42,
dataset_class=HDF5SingleCellDataset,
) -> torch.utils.data.DataLoader:
"""Create a pytorch dataloader from the provided single-cell image dataset.
@@ -1022,7 +1019,7 @@ class based on the previous single-cell extraction. Therefore, no parameters nee
)

# perform inference
for model_name, model in zip(self.model_names, self.model):
for model_name, model in zip(self.model_names, self.model, strict=False):
self.log(f"Starting inference for model {model_name}")
results = self.inference(self.dataloader, model)

@@ -1080,7 +1077,7 @@ def _generate_column_names(
self,
n_masks: int = 2,
n_channels: int = 3,
channel_names: Union[List, None] = None,
channel_names: list | None = None,
) -> None:
column_names = []

@@ -1217,9 +1214,7 @@ def _write_results_sdata(self, results, mask_type="seg_all"):
)

# define name to save table under
label = self.label.replace(
"CellFeaturizer_", ""
) # remove class name from label to ensure we dont have duplicates
self.label.replace("CellFeaturizer_", "") # remove class name from label to ensure we dont have duplicates

if self.channel_classification is not None:
table_name = f"{self.__class__.__name__ }_{self.config['channel_classification']}_{self.MASK_NAMES[0]}"
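The `zip(self.model_names, self.model)` call gains an explicit `strict=False` (flake8-bugbear's zip rule, Python 3.10+), keeping the old behaviour of truncating to the shorter iterable instead of raising. A sketch with made-up model names:

model_names = ["autophagy_classifier", "second_model"]
models = ["loaded_model_object"]

print(list(zip(model_names, models, strict=False)))  # [('autophagy_classifier', 'loaded_model_object')]

try:
    list(zip(model_names, models, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1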