"Message must be a string, list of strings or a dictionary, but received type: ", type(message)
results = {f"GPU_{i}": f"{memory_usage[i]} MiB" for i in range(len(memory_usage))} return results - except Exception as e: + except (RuntimeError, ValueError) as e: print("Error:", e) return None @@ -228,7 +228,7 @@ def _get_gpu_memory_usage(self): used_memory = torch.mps.driver_allocated_memory() + torch.mps.driver_allocated_memory() used_memory = used_memory / 1024**2 # Convert bytes to MiB return {"MPS": f"{memory_usage} MiB"} - except Exception as e: + except (RuntimeError, ValueError) as e: print("Error:", e) return None diff --git a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py index 76fa67f7..ce697625 100644 --- a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py +++ b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py @@ -113,13 +113,15 @@ def call_as_tile(self): try: self.log(f"Beginning filtering on tile in position [{self.window[0]}, {self.window[1]}]") super().__call__(input_image) - except Exception: + except (IOError, ValueError, RuntimeError) as e: + self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) else: print(f"Tile in position [{self.window[0]}, {self.window[1]}] only contained zeroes.") try: super().__call_empty__(input_image) - except Exception: + except (IOError, ValueError, RuntimeError) as e: + self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) del input_image @@ -303,7 +305,8 @@ def execute_tile_list(self, tile_list, n_cpu=None): def f(x): try: x.call_as_tile() - except Exception: + except (IOError, ValueError, RuntimeError) as e: + self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) return x.get_output() diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index 33fd27cb..9645b96e 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -388,7 +388,8 
@@ def _call_as_shard(self): try: self._execute_segmentation(input_image) self.clear_temp_dir() - except Exception: + except (RuntimeError, ValueError, TypeError) as e: + self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) self.clear_temp_dir() else: @@ -396,7 +397,8 @@ def _call_as_shard(self): try: super().__call_empty__(input_image) self.clear_temp_dir() - except Exception: + except (RuntimeError, ValueError, TypeError) as e: + self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) self.clear_temp_dir() diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py index 23ab5b02..2cc0bb93 100644 --- a/src/scportrait/pipeline/segmentation/workflows.py +++ b/src/scportrait/pipeline/segmentation/workflows.py @@ -1205,7 +1205,7 @@ def _check_gpu_status(self): self.gpu_id = gpu_id_list[cpu_id] self.status = "multi_GPU" - except Exception: + except (AttributeError, ValueError): # default to single GPU self.gpu_id = 0 self.status = "potentially_single_GPU" diff --git a/src/scportrait/pipeline/selection.py b/src/scportrait/pipeline/selection.py index a2eab406..1fcaa09c 100644 --- a/src/scportrait/pipeline/selection.py +++ b/src/scportrait/pipeline/selection.py @@ -32,7 +32,8 @@ def _setup_selection(self): if self.name is None: try: name = "_".join([cell_set["name"] for cell_set in self.cell_sets]) - except Exception: + except KeyError: + Warning("No name provided for the selection. 
Will use default name.") name = "selected_cells" # create savepath diff --git a/src/scportrait/tools/ml/datasets.py b/src/scportrait/tools/ml/datasets.py index e248efaa..9e99723d 100644 --- a/src/scportrait/tools/ml/datasets.py +++ b/src/scportrait/tools/ml/datasets.py @@ -160,7 +160,8 @@ def _add_hdf_to_index( for row in index_handle: self.data_locator.append([label, handle_id] + list(row)) - except Exception: + except (FileNotFoundError, KeyError, OSError) as e: + print(f"Error: {e}") return def _add_dataset( diff --git a/src/scportrait/tools/ml/plmodels.py b/src/scportrait/tools/ml/plmodels.py index cdb99f59..5b7c807f 100644 --- a/src/scportrait/tools/ml/plmodels.py +++ b/src/scportrait/tools/ml/plmodels.py @@ -474,7 +474,7 @@ def training_step(self, batch, batch_idx): try: if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)): print(type(obj), obj.size()) - except Exception: + except (AttributeError, TypeError): pass def validation_step(self, batch, batch_idx): diff --git a/tests/processing_test.py b/tests/processing_test.py index 24feba6f..125a6769 100644 --- a/tests/processing_test.py +++ b/tests/processing_test.py @@ -288,7 +288,7 @@ def test_visualize_class(): # Since this function does not return anything, we just check if it produces any exceptions try: visualize_class(class_ids, seg_map, background) - except Exception as e: + except (ValueError, TypeError) as e: pytest.fail(f"visualize_class raised exception: {str(e)}") @@ -299,7 +299,7 @@ def test_plot_image(tmpdir): # Since this function does not return anything, we just check if it produces any exceptions try: plot_image(array, size=(5, 5), save_name=save_name) - except Exception as e: + except (ValueError, TypeError, IOError) as e: pytest.fail(f"plot_image raised exception: {str(e)}") assert os.path.isfile(str(save_name) + ".png") From 139fac901f8945b38c8cfc622fee8e600c67de6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= 
<15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:55:24 +0200 Subject: [PATCH 02/12] implement RUF100 linting rule --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 759c96c7..bf9626b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ select = [ #"C4", # flake8-comprehensions "BLE", # flake8-blind-except #"UP", # pyupgrade - #"RUF100", # Report unused noqa directives + "RUF100", # Report unused noqa directives #"TCH", # Typing imports #"NPY", # Numpy specific rules #"PTH" # Use pathlib From f6916a1032c99929fe831d1b0d3566ed49dab54f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:04:12 +0200 Subject: [PATCH 03/12] activate TCH --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bf9626b8..81adbe35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ select = [ "BLE", # flake8-blind-except #"UP", # pyupgrade "RUF100", # Report unused noqa directives - #"TCH", # Typing imports + "TCH", # Typing imports #"NPY", # Numpy specific rules #"PTH" # Use pathlib ] From d1fa35c0a14bfe5667b5ab860ca1d350bf08be24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:06:18 +0200 Subject: [PATCH 04/12] activate flake8-comprehensions rule and fix all linting issues --- pyproject.toml | 2 +- src/scportrait/pipeline/_utils/sdata_io.py | 2 +- .../pipeline/_utils/segmentation.py | 2 +- src/scportrait/pipeline/classification.py | 8 ++--- src/scportrait/pipeline/project.py | 4 +-- .../pipeline/segmentation/segmentation.py | 4 +-- .../pipeline/segmentation/workflows.py | 24 +++++++------- src/scportrait/plotting/vis.py | 2 +- .../processing/masks/mask_filtering.py | 4 +-- src/scportrait/tools/stitch/_stitch.py | 20 
++++++------ .../tools/stitch/_utils/ashlar_plotting.py | 2 +- .../tools/stitch/_utils/filewriters.py | 2 +- .../stitch/_utils/parallelized_ashlar.py | 32 +++++++++---------- tests/e2e_tests/segmentation_workflow.py | 4 +-- 14 files changed, 55 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 81adbe35..68ffdfc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ select = [ #"D", # pydocstyle #"B", # flake8-bugbear #"TID", # flake8-tidy-imports - #"C4", # flake8-comprehensions + "C4", # flake8-comprehensions "BLE", # flake8-blind-except #"UP", # pyupgrade "RUF100", # Report unused noqa directives diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index e508fa73..52e7713d 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -276,7 +276,7 @@ def _load_seg_to_memmap( _sdata = self._check_sdata_status(return_sdata=True) assert all( - [seg in _sdata.labels for seg in seg_name] + seg in _sdata.labels for seg in seg_name ), "Not all passed segmentation elements found in sdata object." 
seg_objects = [_sdata.labels[seg] for seg in seg_name] diff --git a/src/scportrait/pipeline/_utils/segmentation.py b/src/scportrait/pipeline/_utils/segmentation.py index 8db5af19..76eda2a7 100644 --- a/src/scportrait/pipeline/_utils/segmentation.py +++ b/src/scportrait/pipeline/_utils/segmentation.py @@ -373,7 +373,7 @@ def _return_edge_labels_2d(input_map): .union(set(last_column.flatten())) ) - full_union = set([np.uint64(i) for i in full_union]) + full_union = {np.uint64(i) for i in full_union} full_union.discard(0) return list(full_union) diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index f8dc48e9..4e8efb12 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -20,11 +20,9 @@ class _ClassificationBase(ProcessingStep): - PRETRAINED_MODEL_NAMES = list( - [ - "autophagy_classifier", - ] - ) + PRETRAINED_MODEL_NAMES = [ + "autophagy_classifier", + ] MASK_NAMES = ["nucleus", "cytosol"] def __init__(self, *args, **kwargs): diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 11e569b2..da1b1c30 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -291,7 +291,7 @@ def _check_chunk_size(self, elem): if isinstance(chunk_size, list): # check if all chunk sizes are the same otherwise rechunking needs to occur anyways - if not all([x == chunk_size[0] for x in chunk_size]): + if not all(x == chunk_size[0] for x in chunk_size): elem = rechunk_image(elem, chunks=self.DEFAULT_CHUNK_SIZE) else: # ensure that the chunk size is the default chunk size @@ -513,7 +513,7 @@ def _load_seg_to_memmap(self, seg_name: List[str], tmp_dir_abs_path: str): # get the segmentation object assert all( - [seg in self.sdata.labels for seg in seg_name] + seg in self.sdata.labels for seg in seg_name ), "Not all passed segmentation elements found in sdata object." 
seg_objects = [self.sdata.labels[seg] for seg in seg_name] diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index 9645b96e..8d174503 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -790,7 +790,7 @@ def _resolve_sharding(self, sharding_plan): # if this is not the case it is worth investigating and there it can be helpful to see which classes are contained in the mask but not in the classes file and vice versa if self.deep_debug: masks = hdf_labels[:, :, :] - unique_ids = set(np.unique(masks[0])) - set([0]) + unique_ids = set(np.unique(masks[0])) - {0} self.log(f"Total number of classes in final segmentation after processing: {len(unique_ids)}") difference_classes = filtered_classes_combined.difference(unique_ids) @@ -817,7 +817,7 @@ def _resolve_sharding(self, sharding_plan): self.log(f"Finished stitching tile {i} in {time.time() - timer} seconds.") # remove background class - filtered_classes_combined = filtered_classes_combined - set([0]) + filtered_classes_combined = filtered_classes_combined - {0} self.log(f"Number of filtered classes in Dataset: {len(filtered_classes_combined)}") diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py index 2cc0bb93..fc75c393 100644 --- a/src/scportrait/pipeline/segmentation/workflows.py +++ b/src/scportrait/pipeline/segmentation/workflows.py @@ -842,7 +842,7 @@ def _nucleus_segmentation(self, input_image, debug: bool = False): if self.contact_filter_nuclei: if self.debug: - n_classes = len(set(np.unique(self.maps["nucleus_segmentation"])) - set([0])) + n_classes = len(set(np.unique(self.maps["nucleus_segmentation"])) - {0}) self.maps["nucleus_segmentation"] = contact_filter( self.maps["nucleus_segmentation"], @@ -851,7 +851,7 @@ def _nucleus_segmentation(self, input_image, debug: bool = False): ) if self.debug: - 
n_classes_post = len(set(np.unique(self.maps["nucleus_segmentation"])) - set([0])) + n_classes_post = len(set(np.unique(self.maps["nucleus_segmentation"])) - {0}) self.log(f"Filtered out {n_classes - n_classes_post} nuclei due to contact filtering.") def _cytosol_segmentation(self, input_image, debug: bool = False): @@ -975,7 +975,7 @@ def _cytosol_segmentation(self, input_image, debug: bool = False): if self.contact_filter_cytosol: if self.debug: - n_classes = len(set(np.unique(self.maps["cytosol_segmentation"])) - set([0])) + n_classes = len(set(np.unique(self.maps["cytosol_segmentation"])) - {0}) self.maps["cytosol_segmentation"] = contact_filter( self.maps["cytosol_segmentation"], @@ -984,10 +984,10 @@ def _cytosol_segmentation(self, input_image, debug: bool = False): ) if self.debug: - n_classes_post = len(set(np.unique(self.maps["cytosol_segmentation"])) - set([0])) + n_classes_post = len(set(np.unique(self.maps["cytosol_segmentation"])) - {0}) self.log(f"Filtered out {n_classes - n_classes_post} cytosols due to contact filtering.") - unique_cytosol_ids = set(np.unique(self.maps["cytosol_segmentation"])) - set([0]) + unique_cytosol_ids = set(np.unique(self.maps["cytosol_segmentation"])) - {0} # remove any ids from nucleus mask that dont have a cytosol mask self.maps["nucleus_segmentation"][~np.isin(self.maps["nucleus_segmentation"], list(unique_cytosol_ids))] = 0 @@ -1050,7 +1050,7 @@ def process(self, input_image): if self.debug: self._visualize_final_masks() - all_classes = list(set(np.unique(self.maps["nucleus_segmentation"])) - set([0])) + all_classes = list(set(np.unique(self.maps["nucleus_segmentation"])) - {0}) segmentation = self._finalize_segmentation_results() # type: ignore print("Channels shape: ", segmentation.shape) @@ -1110,7 +1110,7 @@ def process(self, input_image): if self.segment_nuclei: self._nucleus_segmentation(input_image[0], debug=self.debug) - all_classes = list(set(np.unique(self.maps["nucleus_segmentation"])) - set([0])) + 
all_classes = list(set(np.unique(self.maps["nucleus_segmentation"])) - {0}) segmentation = self._finalize_segmentation_results() results = self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) @@ -1323,7 +1323,7 @@ def process(self, input_image): self.cellpose_segmentation(input_image) # finalize classes list - all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - set([0]) + all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - {0} segmentation = self._finalize_segmentation_results() self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) @@ -1481,7 +1481,7 @@ def process(self, input_image): self.cellpose_segmentation(input_image) # finalize segmentation classes ensuring that background is removed - all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - set([0]) + all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - {0} segmentation = self._finalize_segmentation_results() self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) @@ -1562,7 +1562,7 @@ def process(self, input_image): self.cellpose_segmentation(input_image) # finalize classes list - all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - set([0]) + all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - {0} segmentation = self._finalize_segmentation_results() @@ -1674,7 +1674,7 @@ def _execute_segmentation(self, input_image) -> None: self.segmentation_time = stop_segmentation - start_segmentation # get final classes list - all_classes = set(np.unique(self.maps["cytosol_segmentation"])) - set([0]) + all_classes = set(np.unique(self.maps["cytosol_segmentation"])) - {0} segmentation = self._finalize_segmentation_results() self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) @@ -1774,7 +1774,7 @@ def process(self, input_image) -> None: self.cellpose_segmentation(input_image) # currently no implemented filtering steps to remove nuclei outside 
of specific thresholds - all_classes = set(np.unique(self.maps["cytosol_segmentation"])) - set([0]) + all_classes = set(np.unique(self.maps["cytosol_segmentation"])) - {0} segmentation = self._finalize_segmentation_results() # type: ignore diff --git a/src/scportrait/plotting/vis.py b/src/scportrait/plotting/vis.py index 04929438..67a52256 100644 --- a/src/scportrait/plotting/vis.py +++ b/src/scportrait/plotting/vis.py @@ -76,7 +76,7 @@ def visualize_class(class_ids, seg_map, image, all_ids=None, return_fig=False, * class_ids = list(class_ids) if all_ids is None: - all_ids = set(np.unique(seg_map)) - set([0]) + all_ids = set(np.unique(seg_map)) - {0} # get the ids to keep keep_ids = list(all_ids - set(class_ids)) diff --git a/src/scportrait/processing/masks/mask_filtering.py b/src/scportrait/processing/masks/mask_filtering.py index 271d401f..998d1c4d 100644 --- a/src/scportrait/processing/masks/mask_filtering.py +++ b/src/scportrait/processing/masks/mask_filtering.py @@ -884,8 +884,8 @@ def visualize_filtering_results(self, return_fig=True, return_maps=False, plot_f nuc_mask = self.nucleus_mask.copy() cyto_mask = self.cytosol_mask.copy() - class_ids_nuc = set(np.unique(nuc_mask)) - set([0]) - class_ids_cyto = set(np.unique(cyto_mask)) - set([0]) + class_ids_nuc = set(np.unique(nuc_mask)) - {0} + class_ids_cyto = set(np.unique(cyto_mask)) - {0} # get the ids to visualize as red for discarded ids_discard_nuc = self.nuclei_discard_list diff --git a/src/scportrait/tools/stitch/_stitch.py b/src/scportrait/tools/stitch/_stitch.py index a8688eee..dd7da9e1 100644 --- a/src/scportrait/tools/stitch/_stitch.py +++ b/src/scportrait/tools/stitch/_stitch.py @@ -631,11 +631,11 @@ def assemble_mosaic(self): for i, channel in enumerate(self.channels): args.append((channel, i, hdf5_path)) - tqdm_args = dict( - file=sys.stdout, - desc="assembling mosaic", - total=len(self.channels), - ) + tqdm_args = { + "file": sys.stdout, + "desc": "assembling mosaic", + "total": 
len(self.channels), + } # threading over channels is safe as the channels are written to different postions in the hdf5 file and do not interact with one another # threading over the writing of a single channel is not safe and leads to inconsistent results @@ -664,11 +664,11 @@ def write_tif_parallel(self, export_xml=True): filenames.append(filename) args.append((filename, i)) - tqdm_args = dict( - file=sys.stdout, - desc="writing tif files", - total=len(self.channels), - ) + tqdm_args = { + "file": sys.stdout, + "desc": "writing tif files", + "total": len(self.channels), + } # define helper function to execute in threadpooler def _write_tif(args): diff --git a/src/scportrait/tools/stitch/_utils/ashlar_plotting.py b/src/scportrait/tools/stitch/_utils/ashlar_plotting.py index b5d710f5..38f7a64f 100644 --- a/src/scportrait/tools/stitch/_utils/ashlar_plotting.py +++ b/src/scportrait/tools/stitch/_utils/ashlar_plotting.py @@ -33,7 +33,7 @@ def plot_edge_quality(aligner, outdir, img=None, show_tree=True, pos="metadata", im_kwargs = {} if nx_kwargs is None: nx_kwargs = {} - final_nx_kwargs = dict(width=2, node_size=100, font_size=6) + final_nx_kwargs = {"width": 2, "node_size": 100, "font_size": 6} final_nx_kwargs.update(nx_kwargs) if show_tree: nrows, ncols = 1, 2 diff --git a/src/scportrait/tools/stitch/_utils/filewriters.py b/src/scportrait/tools/stitch/_utils/filewriters.py index 4d07467a..dd0417d5 100644 --- a/src/scportrait/tools/stitch/_utils/filewriters.py +++ b/src/scportrait/tools/stitch/_utils/filewriters.py @@ -117,7 +117,7 @@ def write_ome_zarr( max_layer=n_downscaling_layers, method="nearest", ) # increase downscale so that large slides can also be opened in napari - write_image(image, group=group, axes=axes, storage_options=dict(chunks=chunk_size), scaler=scaler) + write_image(image, group=group, axes=axes, storage_options={"chunks": chunk_size}, scaler=scaler) def write_xml(image_paths: List[str], channels: List[str], slidename: str, outdir: str = None): 
diff --git a/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py b/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py index 5e5c684f..ea50028b 100644 --- a/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py +++ b/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py @@ -151,11 +151,11 @@ def register(t1, t2, offset1, offset2): errors = execute_indexed_parallel( register, args=args, - tqdm_args=dict( - file=sys.stdout, - disable=not self.verbose, - desc=" quantifying alignment error", - ), + tqdm_args={ + "file": sys.stdout, + "disable": not self.verbose, + "desc": " quantifying alignment error", + }, n_threads=self.n_threads, ) @@ -172,11 +172,11 @@ def register_all(self): execute_parallel( self.register_pair, args=args, - tqdm_args=dict( - file=sys.stdout, - disable=not self.verbose, - desc=" aligning edge", - ), + tqdm_args={ + "file": sys.stdout, + "disable": not self.verbose, + "desc": " aligning edge", + }, n_threads=self.n_threads, ) @@ -386,12 +386,12 @@ def assemble_channel_parallel(self, channel, ch_index, out=None, hdf5_path=None, "if specifying an out array, you also need to pass the HDF5 path of the memory mapped temparray" ) - tqdm_args = dict( - file=sys.stdout, - disable=not self.verbose, - desc=f"assembling channel {ch_index}", - total=len(self.aligner.positions), - ) + tqdm_args = { + "file": sys.stdout, + "disable": not self.verbose, + "desc": f"assembling channel {ch_index}", + "total": len(self.aligner.positions), + } # this can not be multi-threaded as it leads to inconsistent results in the overlap array # threading over the channels was the easiest and most robust way to implement diff --git a/tests/e2e_tests/segmentation_workflow.py b/tests/e2e_tests/segmentation_workflow.py index a81a8ee9..9092b0e2 100644 --- a/tests/e2e_tests/segmentation_workflow.py +++ b/tests/e2e_tests/segmentation_workflow.py @@ -40,8 +40,8 @@ with h5py.File(f"{project.seg_directory}/segmentation.h5", "r") as hf: masks = hf["labels"][:] - 
nucleus_ids = set(np.unique(masks[0])) - set([0]) - cytosol_ids = set(np.unique(masks[1])) - set([0]) + nucleus_ids = set(np.unique(masks[0])) - {0} + cytosol_ids = set(np.unique(masks[1])) - {0} classes = set( pd.read_csv(f"{project.seg_directory}/classes.csv", header=None)[0] .astype(project.DEFAULT_SEGMENTATION_DTYPE) From f02319069036054e7e1211cd30c2af74bd965aa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:08:46 +0200 Subject: [PATCH 05/12] activate rule errors detected by Pyflakes and fix linting issues --- pyproject.toml | 2 +- src/scportrait/pipeline/classification.py | 4 +--- src/scportrait/pipeline/segmentation/segmentation.py | 2 +- src/scportrait/tools/ml/datasets.py | 2 +- src/scportrait/tools/parse/_parse_phenix.py | 8 ++++---- src/scportrait/tools/stitch/_stitch.py | 4 ++-- tests/e2e_tests/segmentation_workflow.py | 2 +- 7 files changed, 11 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 68ffdfc9..481942d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ docstring-code-format = true [tool.ruff.lint] select = [ - #"F", # Errors detected by Pyflakes + "F", # Errors detected by Pyflakes #"E", # Error detected by Pycodestyle #"W", # Warning detected by Pycodestyle "I", # isort diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index 4e8efb12..72d5e44c 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -1215,9 +1215,7 @@ def _write_results_sdata(self, results, mask_type="seg_all"): ) # define name to save table under - label = self.label.replace( - "CellFeaturizer_", "" - ) # remove class name from label to ensure we dont have duplicates + self.label.replace("CellFeaturizer_", "") # remove class name from label to ensure we dont have duplicates if self.channel_classification is not None: table_name = 
f"{self.__class__.__name__ }_{self.config['channel_classification']}_{self.MASK_NAMES[0]}" diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index 8d174503..3b2aac63 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -861,7 +861,7 @@ def _perform_segmentation(self, shard_list): initializer=self._initializer_function, initargs=[self.gpu_id_list], ) as pool: - results = list( + list( tqdm( pool.imap(self.method._call_as_shard, shard_list), total=len(shard_list), diff --git a/src/scportrait/tools/ml/datasets.py b/src/scportrait/tools/ml/datasets.py index 9e99723d..441367c4 100644 --- a/src/scportrait/tools/ml/datasets.py +++ b/src/scportrait/tools/ml/datasets.py @@ -275,7 +275,7 @@ def _add_all_datasets(self, read_label_from_dataset: bool = False): current_index_list = self.index_list[i] # get current label - current_label = self._get_dataset_label(i) + self._get_dataset_label(i) # check if "directory" is a path to specific hdf5 filetype = directory.split(".")[-1] diff --git a/src/scportrait/tools/parse/_parse_phenix.py b/src/scportrait/tools/parse/_parse_phenix.py index 4ba09090..e3a27237 100644 --- a/src/scportrait/tools/parse/_parse_phenix.py +++ b/src/scportrait/tools/parse/_parse_phenix.py @@ -509,12 +509,12 @@ def sort_wells(self, sort_tiles=False): print("\t Tiles: ", tiles) # only print if these folders should be created # update metadata to include destination for each tile metadata["dest"] = [ - os.path.join(getattr(self, f"outdir_sorted_wells"), f"row{row}_well{well}", tile) + os.path.join(getattr(self, "outdir_sorted_wells"), f"row{row}_well{well}", tile) for row, well, tile in zip(metadata.Row, metadata.Well, metadata.tiles) ] else: metadata["dest"] = [ - os.path.join(getattr(self, f"outdir_sorted_wells"), f"row{row}_well{well}") + os.path.join(getattr(self, "outdir_sorted_wells"), f"row{row}_well{well}") for 
row, well in zip(metadata.Row, metadata.Well) ] @@ -566,12 +566,12 @@ def sort_timepoints(self, sort_wells=False): if sort_wells: # update metadata to include destination for each tile metadata["dest"] = [ - os.path.join(getattr(self, f"outdir_sorted_timepoints"), timepoint, f"{row}_{well}") + os.path.join(getattr(self, "outdir_sorted_timepoints"), timepoint, f"{row}_{well}") for row, well, timepoint in zip(metadata.Row, metadata.Well, metadata.Timepoint) ] else: metadata["dest"] = [ - os.path.join(getattr(self, f"outdir_sorted_timepoints"), timepoint) for timepoint in metadata.Timepoint + os.path.join(getattr(self, "outdir_sorted_timepoints"), timepoint) for timepoint in metadata.Timepoint ] # unique directories for each tile diff --git a/src/scportrait/tools/stitch/_stitch.py b/src/scportrait/tools/stitch/_stitch.py index dd7da9e1..e49c1843 100644 --- a/src/scportrait/tools/stitch/_stitch.py +++ b/src/scportrait/tools/stitch/_stitch.py @@ -306,8 +306,8 @@ def plot_qc(self): """ Plot quality control (QC) figures for the alignment. 
""" - fig = plot_edge_scatter(self.aligner, self.outdir) - fig = plot_edge_quality(self.aligner, self.outdir) + plot_edge_scatter(self.aligner, self.outdir) + plot_edge_quality(self.aligner, self.outdir) def perform_alignment(self): """ diff --git a/tests/e2e_tests/segmentation_workflow.py b/tests/e2e_tests/segmentation_workflow.py index 9092b0e2..9e123d56 100644 --- a/tests/e2e_tests/segmentation_workflow.py +++ b/tests/e2e_tests/segmentation_workflow.py @@ -13,7 +13,7 @@ if __name__ == "__main__": print(os.getcwd()) - project_location = f"example_data/example_4/benchmark" + project_location = "example_data/example_4/benchmark" project = Project( os.path.abspath(project_location), From 6713a3eaa4e43f6a98eff644238f9819f5a010f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:14:59 +0200 Subject: [PATCH 06/12] activate error detect rule by pycodestyle --- pyproject.toml | 2 +- src/scportrait/pipeline/extraction.py | 2 +- src/scportrait/pipeline/segmentation/workflows.py | 4 ++-- src/scportrait/tools/parse/_parse_phenix.py | 3 ++- src/scportrait/tools/stitch/_utils/filereaders.py | 9 ++++++--- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 481942d3..a3efe2c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ docstring-code-format = true [tool.ruff.lint] select = [ "F", # Errors detected by Pyflakes - #"E", # Error detected by Pycodestyle + "E", # Error detected by Pycodestyle #"W", # Warning detected by Pycodestyle "I", # isort #"D", # pydocstyle diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py index e5c12edb..4817ca8b 100644 --- a/src/scportrait/pipeline/extraction.py +++ b/src/scportrait/pipeline/extraction.py @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs): Warning("Windows detected. 
Multithreading not supported on windows so setting threads to 1.") self.config["threads"] = 1 - if not "overwrite_run_path" in self.__dict__.keys(): + if "overwrite_run_path" not in self.__dict__.keys(): self.overwrite_run_path = self.overwrite def _get_compression_type(self): diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py index fc75c393..5f2a8fa0 100644 --- a/src/scportrait/pipeline/segmentation/workflows.py +++ b/src/scportrait/pipeline/segmentation/workflows.py @@ -707,9 +707,9 @@ def _get_processing_parameters(self): ) # check that the normalization ranges are of the same type otherwise this will result in issues - assert type(self.lower_quantile_normalization_input_image) == type( + assert type(self.lower_quantile_normalization_input_image) == type( # noqa: E721 self.upper_quantile_normalization_input_image - ) + ) # these need to be the same types! So we need to circumvent the ruff linting rules here # check if median filtering is required if "median_filter_size" in self.config.keys(): diff --git a/src/scportrait/tools/parse/_parse_phenix.py b/src/scportrait/tools/parse/_parse_phenix.py index e3a27237..8da4d59d 100644 --- a/src/scportrait/tools/parse/_parse_phenix.py +++ b/src/scportrait/tools/parse/_parse_phenix.py @@ -399,7 +399,8 @@ def define_copy_functions(self): def copyfunction(input, output): try: os.symlink(input, output) - except: + except OSError as e: + print("Error: ", e) return () else: diff --git a/src/scportrait/tools/stitch/_utils/filereaders.py b/src/scportrait/tools/stitch/_utils/filereaders.py index 853ce961..bb95fa1c 100644 --- a/src/scportrait/tools/stitch/_utils/filereaders.py +++ b/src/scportrait/tools/stitch/_utils/filereaders.py @@ -27,11 +27,14 @@ def __init__( ): try: super().__init__(path, pattern, overlap, pixel_size=pixel_size) - except: + except (FileNotFoundError, IOError): print( f"Error: Could not read images with the given pattern {pattern}. 
Please check the path {path} and pattern."
             )
-            print(f"At the provided location the following files could be found:{os.listdir(path)} ")
+            found_files = os.listdir(path)
+            print(
+                f"At the provided location the the files follow the naming convention:{found_files[0:max(5, len(found_files))]} "
+            )
 
         self.do_rescale = do_rescale
         self.WGAchannel = WGAchannel
@@ -145,7 +148,7 @@ def read(self, series, c):
         else:
             if c not in self.no_rescale_channel:
                 # get rescale_range for channel c
-                if type(self.rescale_range) is dict:
+                if isinstance(self.rescale_range, dict):
                     rescale_range = self.rescale_range[c]
                 else:
                     rescale_range = self.rescale_range

From 197df05eba27a37b00c0fefa6744008333364105 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sophia=20M=C3=A4dler?=
 <15019107+sophiamaedler@users.noreply.github.com>
Date: Thu, 26 Sep 2024 15:15:34 +0200
Subject: [PATCH 07/12] activate linting rule warnings detected by Pycodestyle

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index a3efe2c7..03c1164c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,7 @@ docstring-code-format = true
 select = [
     "F", # Errors detected by Pyflakes
     "E", # Error detected by Pycodestyle
-    #"W", # Warning detected by Pycodestyle
+    "W", # Warning detected by Pycodestyle
     "I", # isort
     #"D", # pydocstyle
     #"B", # flake8-bugbear

From e7d41da9def7a0297a3693a743c938ebcd900e99 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sophia=20M=C3=A4dler?=
 <15019107+sophiamaedler@users.noreply.github.com>
Date: Thu, 26 Sep 2024 15:25:02 +0200
Subject: [PATCH 08/12] activate flake8-bugbear rule

---
 pyproject.toml                             |  2 +-
 src/scportrait/io/convert.py               | 20 +++++++++-----
 src/scportrait/pipeline/_base.py           |  9 +++----
 .../pipeline/_utils/segmentation.py        |  4 ++-
 src/scportrait/pipeline/classification.py  |  2 +-
 src/scportrait/pipeline/extraction.py      |  3 ++-
 .../mask_filtering/filter_segmentation.py  |  2 +-
 src/scportrait/pipeline/project.py         | 11 +++++---
.../pipeline/segmentation/segmentation.py | 6 +++-- .../pipeline/segmentation/workflows.py | 6 ++++- src/scportrait/plotting/vis.py | 4 +-- .../processing/images/_image_processing.py | 2 +- src/scportrait/tools/ml/datasets.py | 2 +- src/scportrait/tools/ml/plmodels.py | 2 +- src/scportrait/tools/ml/transforms.py | 14 +++++++--- src/scportrait/tools/ml/utils.py | 10 +++---- src/scportrait/tools/parse/_parse_phenix.py | 27 ++++++++++--------- src/scportrait/tools/stitch/_stitch.py | 12 ++++++--- .../tools/stitch/_utils/ashlar_plotting.py | 4 +-- .../tools/stitch/_utils/filereaders.py | 2 +- .../tools/stitch/_utils/filewriters.py | 4 ++- src/scportrait/tools/stitch/_utils/graphs.py | 2 +- .../stitch/_utils/parallelized_ashlar.py | 4 +-- 23 files changed, 94 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03c1164c..ed8f6e20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ select = [ "W", # Warning detected by Pycodestyle "I", # isort #"D", # pydocstyle - #"B", # flake8-bugbear + "B", # flake8-bugbear #"TID", # flake8-tidy-imports "C4", # flake8-comprehensions "BLE", # flake8-blind-except diff --git a/src/scportrait/io/convert.py b/src/scportrait/io/convert.py index 12b7e565..44e93c37 100644 --- a/src/scportrait/io/convert.py +++ b/src/scportrait/io/convert.py @@ -127,7 +127,7 @@ def _load_input_image(self): self.log("Finished loading input image to memory mapped temp array.") - def write_input_image_to_spatialdata(self, scale_factors=[2, 4, 8]): + def write_input_image_to_spatialdata(self, scale_factors=None): """ Write the input image found under the label "channels" in the segmentation.h5 file to a spatialdata object. 
@@ -138,6 +138,8 @@ def write_input_image_to_spatialdata(self, scale_factors=[2, 4, 8]): """ # reconnect to temporary image as a dask array + if scale_factors is None: + scale_factors = [2, 4, 8] temp_image = daskmmap.dask_array_from_path(self.temp_image_path) if self.channel_names is None: @@ -186,7 +188,7 @@ def _load_segmentation(self): self.log("No segmentation found in project.") self.segmentation_status = False - def write_segmentation_to_spatialdata(self, scale_factors=[]): + def write_segmentation_to_spatialdata(self, scale_factors=None): """ Write the segmentation masks found under the label "labels" in the segmentation.h5 file to a spatialdata object. @@ -197,6 +199,8 @@ def write_segmentation_to_spatialdata(self, scale_factors=[]): In the future this behaviour may be changed but at the moment scPortrait is not designed to handle multiple resolutions for segmentation masks. """ + if scale_factors is None: + scale_factors = [] if self.segmentation_status is None: # reconnect to temporary image as a dask array temp_segmentation = daskmmap.dask_array_from_path(self.temp_segmentation_path) @@ -249,10 +253,14 @@ def _lookup_region_annotations(self): return region_lookup - def add_multiscale_segmentation(self, region_keys=["seg_all_nucleus", "seg_all_cytosol"], scale_factors=[2, 4, 8]): + def add_multiscale_segmentation(self, region_keys=None, scale_factors=None): """ Add multiscale segmentation to the spatialdata object. 
""" + if scale_factors is None: + scale_factors = [2, 4, 8] + if region_keys is None: + region_keys = ["seg_all_nucleus", "seg_all_cytosol"] region_lookup = self._lookup_region_annotations() sdata = SpatialData.read(self._get_sdata_path()) @@ -360,9 +368,9 @@ def _read_classification_results(self, classification_result): table = anndata.AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) return table - def write_classification_result_to_spatialdata( - self, classification_result, segmentation_regions=["seg_all_nucleus", "seg_all_cytosol"] - ): + def write_classification_result_to_spatialdata(self, classification_result, segmentation_regions=None): + if segmentation_regions is None: + segmentation_regions = ["seg_all_nucleus", "seg_all_cytosol"] class_result = self._read_classification_results(classification_result) sdata = SpatialData.read(self._get_sdata_path()) diff --git a/src/scportrait/pipeline/_base.py b/src/scportrait/pipeline/_base.py index afd01e19..bc29e98e 100644 --- a/src/scportrait/pipeline/_base.py +++ b/src/scportrait/pipeline/_base.py @@ -57,9 +57,8 @@ def log(self, message): lines = [str(message)] except (TypeError, ValueError): raise TypeError( - "Message must be a string, list of strings or a dictionary, but recieved type: ", type(message) - ) - return + "Message must be a string, list of strings or a dictionary, but received type: ", type(message) + ) from None for line in lines: log_path = os.path.join(self.directory, self.DEFAULT_LOG_NAME) @@ -217,7 +216,7 @@ def __call__(self, *args, debug=None, overwrite=None, **kwargs): return x else: self.clear_temp_dir() # also ensure clearing if not callable just to make sure everything is cleaned up - warnings.warn("no process method defined") + Warning("no process method defined.") def __call_empty__(self, *args, debug=None, overwrite=None, **kwargs): """Call the empty processing step. 
@@ -244,7 +243,7 @@ def __call_empty__(self, *args, debug=None, overwrite=None, **kwargs): x = self.return_empty_mask(*args, **kwargs) return x else: - warnings.warn("no return_empty_mask method defined") + Warning("no return_empty_mask method defined") # also clear empty temp directory here self.clear_temp_dir() diff --git a/src/scportrait/pipeline/_utils/segmentation.py b/src/scportrait/pipeline/_utils/segmentation.py index 76eda2a7..8fd8b16e 100644 --- a/src/scportrait/pipeline/_utils/segmentation.py +++ b/src/scportrait/pipeline/_utils/segmentation.py @@ -780,7 +780,7 @@ def _class_size(mask, debug=False, background=0): return mean_arr, length.flatten() -def size_filter(label, limits=[0, 100000], background=0, reindex=False): +def size_filter(label, limits=None, background=0, reindex=False): """ Filter classes in a labeled array based on their size (number of pixels). @@ -817,6 +817,8 @@ def size_filter(label, limits=[0, 100000], background=0, reindex=False): """ # Calculate the number of pixels for each class in the labeled array + if limits is None: + limits = [0, 100000] _, points_class = _class_size(label) # Find the classes with size below the lower limit and above the upper limit diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index 72d5e44c..28ebddf7 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -1020,7 +1020,7 @@ class based on the previous single-cell extraction. 
Therefore, no parameters nee ) # perform inference - for model_name, model in zip(self.model_names, self.model): + for model_name, model in zip(self.model_names, self.model, strict=False): self.log(f"Starting inference for model {model_name}") results = self.inference(self.dataloader, model) diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py index 4817ca8b..7d030781 100644 --- a/src/scportrait/pipeline/extraction.py +++ b/src/scportrait/pipeline/extraction.py @@ -337,6 +337,7 @@ def _get_arg(self, cell_ids): range(len(cell_ids)), [self.save_index_lookup.index.get_loc(x) for x in cell_ids], cell_ids, + strict=False, ) ) return args @@ -620,7 +621,7 @@ def _transfer_tempmmap_to_hdf5(self): with h5py.File(self.output_path, "w") as hf: hf.create_dataset( "single_cell_index", - data=list(zip(list(range(len(cell_ids))), cell_ids)), + data=list(zip(list(range(len(cell_ids))), cell_ids, strict=False)), dtype=self.DEFAULT_SEGMENTATION_DTYPE, ) # increase to 64 bit otherwise information may become truncated diff --git a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py index ce697625..0ec74ee3 100644 --- a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py +++ b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py @@ -207,7 +207,7 @@ def initialize_tile_list_incomplete(self, tileing_plan, incomplete_indexes, inpu _tile_list = [] self.input_path = input_path - for i, window in zip(incomplete_indexes, tileing_plan): + for i, window in zip(incomplete_indexes, tileing_plan, strict=False): local_tile_directory = os.path.join(self.tile_directory, str(i)) current_tile = self.method( self.config, diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index da1b1c30..49c18276 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -102,7 +102,7 @@ def __init__( if not 
os.path.isdir(self.project_location): os.makedirs(self.project_location) else: - warnings.warn("There is already a directory in the location path") + Warning("There is already a directory in the location path") # === setup sdata reader/writer === self.filehandler = sdata_filehandler( @@ -437,7 +437,7 @@ def _write_image_sdata( image, image_name, channel_names=None, - scale_factors=[2, 4, 8], + scale_factors=None, chunks=(1, 1000, 1000), overwrite=False, ): @@ -452,6 +452,8 @@ def _write_image_sdata( List of scale factors for the image. Default is [2, 4, 8]. This will load the image at 4 different resolutions to allow for fluid visualization. """ + if scale_factors is None: + scale_factors = [2, 4, 8] if self.sdata is None: self._read_sdata() @@ -668,7 +670,7 @@ def load_input_from_tif_files( self, file_paths, channel_names=None, - crop=[(0, -1), (0, -1)], + crop=None, overwrite=None, remap=None, cache=None, @@ -692,6 +694,9 @@ def load_input_from_tif_files( """ + if crop is None: + crop = [(0, -1), (0, -1)] + def extract_unique_parts(paths: List[str]): """helper function to get unique channel names from filepaths diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index 3b2aac63..f51a2840 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -261,7 +261,9 @@ def _save_segmentation(self, labels: np.array, classes: List) -> None: self.log("=== Finished segmentation of shard ===") - def _save_segmentation_sdata(self, labels, classes, masks=["nuclei", "cytosol"]): + def _save_segmentation_sdata(self, labels, classes, masks=None): + if masks is None: + masks = ["nuclei", "cytosol"] if self.is_shard: self._save_segmentation(labels, classes) else: @@ -604,7 +606,7 @@ def _cleanup_shards(self, sharding_plan, keep_plots=False): if keep_plots: self.log("Moving generated plots from shard directory to main directory.") - for i, window 
in enumerate(sharding_plan): + for i, _window in enumerate(sharding_plan): local_shard_directory = os.path.join(self.shard_directory, str(i)) for file in os.listdir(local_shard_directory): if file.endswith(tuple(file_identifiers_plots)): diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py index 5f2a8fa0..5fd4a920 100644 --- a/src/scportrait/pipeline/segmentation/workflows.py +++ b/src/scportrait/pipeline/segmentation/workflows.py @@ -340,12 +340,14 @@ def _median_correct_image(self, input_image, median_filter_size: int, debug: boo ##### Filtering Functions ##### # 1. Size Filtering - def _check_for_size_filtering(self, mask_types=["nucleus", "cytosol"]) -> None: + def _check_for_size_filtering(self, mask_types=None) -> None: """ Check if size filtering should be performed on the masks. If size filtering is turned on, the thresholds for filtering are loaded from the config file. """ + if mask_types is None: + mask_types = ["nucleus", "cytosol"] if "filter_masks_size" in self.config.keys(): self.filter_size = self.config["filter_masks_size"] else: @@ -677,6 +679,7 @@ def _get_processing_parameters(self): zip( range(self.input_image.shape[0]), self.config["lower_quantile_normalization"], + strict=False, ) ) else: @@ -698,6 +701,7 @@ def _get_processing_parameters(self): zip( range(self.input_image.shape[0]), self.config["upper_quantile_normalization"], + strict=False, ) ) else: diff --git a/src/scportrait/plotting/vis.py b/src/scportrait/plotting/vis.py index 67a52256..3d7a8485 100644 --- a/src/scportrait/plotting/vis.py +++ b/src/scportrait/plotting/vis.py @@ -47,7 +47,7 @@ def plot_image(array, size=(10, 10), save_name="", cmap="magma", return_fig=Fals plt.close() -def visualize_class(class_ids, seg_map, image, all_ids=None, return_fig=False, *args, **kwargs): +def visualize_class(class_ids, seg_map, image, all_ids=None, return_fig=False, **kwargs): """ Visualize specific classes in a segmentation map by 
highlighting them on top of a background image. @@ -92,7 +92,7 @@ def visualize_class(class_ids, seg_map, image, all_ids=None, return_fig=False, * vis_map = label2rgb(outmap, image=image, colors=["red", "blue"], alpha=0.4, bg_label=0) - fig = plot_image(vis_map, return_fig=True, *args, **kwargs) + fig = plot_image(vis_map, return_fig=True, **kwargs) if return_fig: return fig diff --git a/src/scportrait/processing/images/_image_processing.py b/src/scportrait/processing/images/_image_processing.py index 4de1c36f..7aefbac5 100644 --- a/src/scportrait/processing/images/_image_processing.py +++ b/src/scportrait/processing/images/_image_processing.py @@ -128,7 +128,7 @@ def percentile_normalization(im, lower_percentile=0.001, upper_percentile=0.999, im = _percentile_norm(im, lower_percentile, upper_percentile) elif len(im.shape) == 3: - for i, channel in enumerate(im): + for i, _channel in enumerate(im): im[i] = _percentile_norm(im[i], lower_percentile, upper_percentile) else: diff --git a/src/scportrait/tools/ml/datasets.py b/src/scportrait/tools/ml/datasets.py index 441367c4..24e20745 100644 --- a/src/scportrait/tools/ml/datasets.py +++ b/src/scportrait/tools/ml/datasets.py @@ -149,7 +149,7 @@ def _add_hdf_to_index( # generate identifiers for all single-cells # iterate over rows in index handle, i.e. 
over all cells - for current_target, row in zip(label_col, index_handle): + for current_target, row in zip(label_col, index_handle, strict=False): # append target, handle id, and row to data locator self.data_locator.append([current_target, handle_id] + list(row)) diff --git a/src/scportrait/tools/ml/plmodels.py b/src/scportrait/tools/ml/plmodels.py index 5b7c807f..8cedf4a8 100644 --- a/src/scportrait/tools/ml/plmodels.py +++ b/src/scportrait/tools/ml/plmodels.py @@ -535,7 +535,7 @@ def sample_plot(self, input_sample, output_sample, label_sample): rows = ["Cell Mask", "Nucleus Mask", "Nucleus", "TGOLN2-mCherry", "WGA", "CAE"] - for ax, row in zip(axs[:, 0], rows): + for ax, row in zip(axs[:, 0], rows, strict=False): ax.set_ylabel(row, rotation=0, size="large", ha="right") plt.subplots_adjust(wspace=0.1, hspace=0.1) diff --git a/src/scportrait/tools/ml/transforms.py b/src/scportrait/tools/ml/transforms.py index 074272fc..c66a0dd7 100644 --- a/src/scportrait/tools/ml/transforms.py +++ b/src/scportrait/tools/ml/transforms.py @@ -31,7 +31,9 @@ class GaussianNoise(object): Add gaussian noise to the input image. """ - def __init__(self, sigma=0.1, channels_to_exclude=[]): + def __init__(self, sigma=0.1, channels_to_exclude=None): + if channels_to_exclude is None: + channels_to_exclude = [] self.sigma = sigma self.channels = channels_to_exclude @@ -55,7 +57,11 @@ class GaussianBlur(object): Apply a gaussian blur to the input image. """ - def __init__(self, kernel_size=[1, 1, 1, 1, 5, 5, 7, 9], sigma=(0.1, 2), channels=[]): + def __init__(self, kernel_size=None, sigma=(0.1, 2), channels=None): + if channels is None: + channels = [] + if kernel_size is None: + kernel_size = [1, 1, 1, 1, 5, 5, 7, 9] self.kernel_size = kernel_size self.sigma = sigma self.channels = channels @@ -97,7 +103,9 @@ class ChannelSelector(object): select the channel used for prediction. 
""" - def __init__(self, channels=[0, 1, 2, 3, 4], num_channels=5): + def __init__(self, channels=None, num_channels=5): + if channels is None: + channels = [0, 1, 2, 3, 4] if not np.max(channels) < num_channels: raise ValueError("highest channel index exceeds channel numb") self.channels = channels diff --git a/src/scportrait/tools/ml/utils.py b/src/scportrait/tools/ml/utils.py index 7bdac769..cfa0fd06 100644 --- a/src/scportrait/tools/ml/utils.py +++ b/src/scportrait/tools/ml/utils.py @@ -51,8 +51,9 @@ def combine_datasets_balanced( # check to make sure we have more than one occurance of a dataset (otherwise it will throw an error) if np.sum(pd.Series(class_labels).value_counts() > 1) == 0: - for dataset, label, fraction in zip(list_of_datasets, class_labels, dataset_fraction): - print(dataset, label, 1) + for dataset, label, fraction in zip(list_of_datasets, class_labels, dataset_fraction, strict=False): + print(dataset, label, fraction) + train_size = floor(train_per_class) test_size = floor(test_per_class) val_size = floor(val_per_class) @@ -79,10 +80,7 @@ def combine_datasets_balanced( test_dataset.append(test) val_dataset.append(val) else: - for dataset, label, fraction in zip(list_of_datasets, class_labels, dataset_fraction): - # train_size = floor(train_per_class * fraction) - # test_size = floor(test_per_class * fraction) - # val_size = floor(val_per_class * fraction) + for dataset, fraction in zip(list_of_datasets, dataset_fraction, strict=False): train_size = int(np.round(train_per_class * fraction)) test_size = int(np.round(test_per_class * fraction)) val_size = int(np.round(val_per_class * fraction)) diff --git a/src/scportrait/tools/parse/_parse_phenix.py b/src/scportrait/tools/parse/_parse_phenix.py index 8da4d59d..877ccc5b 100644 --- a/src/scportrait/tools/parse/_parse_phenix.py +++ b/src/scportrait/tools/parse/_parse_phenix.py @@ -133,8 +133,8 @@ def read_phenix_xml(self, xml_path): if _get_child_name(child.tag) == "Images": images = root[i] - for 
i, image in enumerate(images): - for ix, child in enumerate(image): + for image in images: + for _ix, child in enumerate(image): tag = _get_child_name(child.tag) if tag == "Row": rows.append(child.text) @@ -167,7 +167,7 @@ def read_phenix_xml(self, xml_path): image_names = [] for row, col, field, plane, channel_id, timepoint, flim_id in zip( - rows, cols, fields, planes, channel_ids, timepoints, flim_ids + rows, cols, fields, planes, channel_ids, timepoints, flim_ids, strict=False ): image_names.append(f"r{row}c{col}f{field}p{plane}-ch{channel_id}sk{timepoint}fk1fl{flim_id}.tiff") @@ -178,7 +178,7 @@ def read_phenix_xml(self, xml_path): dates = [x.split("T")[0] for x in times] _times = [x.split("T")[1] for x in times] _times = [(x.split("+")[0].split(".")[0] + "+" + x.split("+")[1].replace(":", "")) for x in _times] - time_final = [x + " " + y for x, y in zip(dates, _times)] + time_final = [x + " " + y for x, y in zip(dates, _times, strict=False)] datetime_format = "%Y-%m-%d %H:%M:%S%z" time_unix = [datetime.strptime(x, datetime_format) for x in time_final] @@ -434,6 +434,7 @@ def copy_files(self, metadata): metadata.new_file_name.tolist(), metadata.source.tolist(), metadata.dest.tolist(), + strict=False, ), total=len(metadata.new_file_name.tolist()), desc="Copying files", @@ -463,7 +464,7 @@ def parse(self): metadata = self.generate_metadata() # set destination for copying - metadata["dest"] = getattr(self, "outdir_parsed_images") + metadata["dest"] = self.outdir_parsed_images # copy/link the images to their new names self.copy_files(metadata=metadata) @@ -510,13 +511,13 @@ def sort_wells(self, sort_tiles=False): print("\t Tiles: ", tiles) # only print if these folders should be created # update metadata to include destination for each tile metadata["dest"] = [ - os.path.join(getattr(self, "outdir_sorted_wells"), f"row{row}_well{well}", tile) - for row, well, tile in zip(metadata.Row, metadata.Well, metadata.tiles) + os.path.join(self.outdir_sorted_wells, 
f"row{row}_well{well}", tile) + for row, well, tile in zip(metadata.Row, metadata.Well, metadata.tiles, strict=False) ] else: metadata["dest"] = [ - os.path.join(getattr(self, "outdir_sorted_wells"), f"row{row}_well{well}") - for row, well in zip(metadata.Row, metadata.Well) + os.path.join(self.outdir_sorted_wells, f"row{row}_well{well}") + for row, well in zip(metadata.Row, metadata.Well, strict=False) ] # unique directories for each tile @@ -567,12 +568,12 @@ def sort_timepoints(self, sort_wells=False): if sort_wells: # update metadata to include destination for each tile metadata["dest"] = [ - os.path.join(getattr(self, "outdir_sorted_timepoints"), timepoint, f"{row}_{well}") - for row, well, timepoint in zip(metadata.Row, metadata.Well, metadata.Timepoint) + os.path.join(self.outdir_sorted_timepoints, timepoint, f"{row}_{well}") + for row, well, timepoint in zip(metadata.Row, metadata.Well, metadata.Timepoint, strict=False) ] else: metadata["dest"] = [ - os.path.join(getattr(self, "outdir_sorted_timepoints"), timepoint) for timepoint in metadata.Timepoint + os.path.join(self.outdir_sorted_timepoints, timepoint) for timepoint in metadata.Timepoint ] # unique directories for each tile @@ -640,7 +641,7 @@ def get_datasets_to_combine(self): dates_times = [re.search(pattern, file_name).group() for file_name in phenix_dirs] # Sort the file names based on the extracted date and time information - sorted_phenix_dirs = [file_name for _, file_name in sorted(zip(dates_times, phenix_dirs))] + sorted_phenix_dirs = [file_name for _, file_name in sorted(zip(dates_times, phenix_dirs, strict=False))] self.phenix_dirs = [f"{input_path}/{phenix_dir}" for phenix_dir in sorted_phenix_dirs] diff --git a/src/scportrait/tools/stitch/_stitch.py b/src/scportrait/tools/stitch/_stitch.py index e49c1843..769902d7 100644 --- a/src/scportrait/tools/stitch/_stitch.py +++ b/src/scportrait/tools/stitch/_stitch.py @@ -41,7 +41,7 @@ def __init__( rescale_range: tuple = (1, 99), channel_order: 
List[str] = None, reader_type=FilePatternReaderRescale, - orientation: dict = {"flip_x": False, "flip_y": True}, + orientation: dict = None, plot_QC: bool = True, overwrite: bool = False, cache: str = None, @@ -85,6 +85,8 @@ def __init__( cache : str, optional Directory to store temporary files during stitching (default is None). If set to none this directory will be created in the outdir. """ + if orientation is None: + orientation = {"flip_x": False, "flip_y": True} self.input_dir = input_dir self.slidename = slidename self.outdir = outdir @@ -457,7 +459,7 @@ def write_thumbnail(self): ) write_tif(filename, self.thumbnail) - def write_spatialdata(self, scale_factors=[2, 4, 8]): + def write_spatialdata(self, scale_factors=None): """ Write the assembled mosaic as a SpatialData object. @@ -468,6 +470,8 @@ def write_spatialdata(self, scale_factors=[2, 4, 8]): The scale factors are used to generate downsampled versions of the image for faster visualization at lower resolutions. """ + if scale_factors is None: + scale_factors = [2, 4, 8] filepath = os.path.join(self.outdir, f"{self.slidename}.spatialdata") # create spatialdata object @@ -540,10 +544,12 @@ def __init__( channel_order: list[str] = None, overwrite: bool = False, reader_type=FilePatternReaderRescale, - orientation={"flip_x": False, "flip_y": True}, + orientation=None, cache: str = None, threads: int = 20, ) -> None: + if orientation is None: + orientation = {"flip_x": False, "flip_y": True} super().__init__( input_dir, slidename, diff --git a/src/scportrait/tools/stitch/_utils/ashlar_plotting.py b/src/scportrait/tools/stitch/_utils/ashlar_plotting.py index 38f7a64f..4a502b5d 100644 --- a/src/scportrait/tools/stitch/_utils/ashlar_plotting.py +++ b/src/scportrait/tools/stitch/_utils/ashlar_plotting.py @@ -12,8 +12,6 @@ except ImportError: gtGraph = None -nx.graph.Graph - def draw_mosaic_image(ax, aligner, img, **kwargs): if img is None: @@ -133,7 +131,7 @@ def plot_edge_scatter(aligner, outdir, 
annotate=True): g.ax_joint.set_yscale("log") g.set_axis_labels("error", "shift") if annotate: - for pair, x, y in zip(aligner.neighbors_graph.edges, xdata, ydata): + for pair, x, y in zip(aligner.neighbors_graph.edges, xdata, ydata, strict=False): plt.annotate(str(pair), (x, y), alpha=0.1) plt.tight_layout() diff --git a/src/scportrait/tools/stitch/_utils/filereaders.py b/src/scportrait/tools/stitch/_utils/filereaders.py index bb95fa1c..56cceb68 100644 --- a/src/scportrait/tools/stitch/_utils/filereaders.py +++ b/src/scportrait/tools/stitch/_utils/filereaders.py @@ -102,7 +102,7 @@ def channel_map(self): for id in range(n_channels): channel_names.append(self._metadata.getChannelName(0, id)) - channel_map = dict(zip(list(range(n_channels)), channel_names)) + channel_map = dict(zip(list(range(n_channels)), channel_names, strict=False)) return channel_map diff --git a/src/scportrait/tools/stitch/_utils/filewriters.py b/src/scportrait/tools/stitch/_utils/filewriters.py index dd0417d5..934a0384 100644 --- a/src/scportrait/tools/stitch/_utils/filewriters.py +++ b/src/scportrait/tools/stitch/_utils/filewriters.py @@ -184,10 +184,12 @@ def write_spatialdata( image_path: str, image: np.array, channel_names: List[str] = None, - scale_factors: List[int] = [2, 4, 8], + scale_factors: List[int] = None, overwrite: bool = False, ): # check if the file exists and delete if overwrite is set to True + if scale_factors is None: + scale_factors = [2, 4, 8] if os.path.exists(image_path): if overwrite: shutil.rmtree(image_path) diff --git a/src/scportrait/tools/stitch/_utils/graphs.py b/src/scportrait/tools/stitch/_utils/graphs.py index 81af7ba6..6cb7ade2 100644 --- a/src/scportrait/tools/stitch/_utils/graphs.py +++ b/src/scportrait/tools/stitch/_utils/graphs.py @@ -62,7 +62,7 @@ def nx2gt(nxG): gtG = Graph(edge_list, eprops=[("weight", "float")], directed=nxG.is_directed()) vertices = gtG.get_vertices() - for missing_node in [x for x in node_list if x not in vertices]: + for 
_missing_node in [x for x in node_list if x not in vertices]: gtG.add_vertex() return gtG diff --git a/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py b/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py index ea50028b..34c3ccff 100644 --- a/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py +++ b/src/scportrait/tools/stitch/_utils/parallelized_ashlar.py @@ -107,7 +107,7 @@ def compute_threshold(self): random_state = np.random.RandomState() for i in range(n): # Limit tries to avoid infinite loop in pathological cases. - for current_try in range(max_tries): + for _current_try in range(max_tries): t1, t2 = random_state.randint(self.metadata.num_images, size=2) o1, o2 = random_state.randint(max_offset, size=2) # Check for non-overlapping strips and abort the retry loop. @@ -144,7 +144,7 @@ def register(t1, t2, offset1, offset2): # prepare arguments for executor args = [] - for (t1, t2), (offset1, offset2) in zip(pairs, offsets): + for (t1, t2), (offset1, offset2) in zip(pairs, offsets, strict=False): arg = (t1, t2, offset1, offset2) args.append(copy.deepcopy(arg)) From 9d253a97e1e527352354919a5a33d04b62a6310a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:25:44 +0200 Subject: [PATCH 09/12] activate rule tidy imports --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ed8f6e20..abbd6869 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ select = [ "I", # isort #"D", # pydocstyle "B", # flake8-bugbear - #"TID", # flake8-tidy-imports + "TID", # flake8-tidy-imports "C4", # flake8-comprehensions "BLE", # flake8-blind-except #"UP", # pyupgrade From 46d79f11d6341a7efed6c09f58e2c15983defa71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:31:46 +0200 Subject: [PATCH 10/12] activate 
pyupgrade rule to check typing --- pyproject.toml | 2 +- src/scportrait/pipeline/_base.py | 2 +- src/scportrait/pipeline/_utils/sdata_io.py | 7 +++-- .../pipeline/_utils/spatialdata_classes.py | 6 ++--- .../pipeline/_utils/spatialdata_helper.py | 12 ++++----- src/scportrait/pipeline/classification.py | 15 +++++------ .../mask_filtering/filter_segmentation.py | 10 +++---- src/scportrait/pipeline/project.py | 26 +++++++++---------- .../pipeline/segmentation/segmentation.py | 21 +++++++-------- .../pipeline/segmentation/workflows.py | 17 +++++------- src/scportrait/pipeline/selection.py | 7 +++-- src/scportrait/tools/ml/datasets.py | 17 +++++------- src/scportrait/tools/ml/models.py | 14 +++++----- src/scportrait/tools/ml/transforms.py | 10 +++---- src/scportrait/tools/parse/_parse_phenix.py | 10 +++---- src/scportrait/tools/stitch/_stitch.py | 5 ++-- .../tools/stitch/_utils/filereaders.py | 2 +- .../tools/stitch/_utils/filewriters.py | 13 +++++----- .../tools/stitch/_utils/parallelilzation.py | 2 +- tests/processing_test.py | 4 +-- 20 files changed, 91 insertions(+), 111 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index abbd6869..e3c923af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ select = [ "TID", # flake8-tidy-imports "C4", # flake8-comprehensions "BLE", # flake8-blind-except - #"UP", # pyupgrade + "UP", # pyupgrade "RUF100", # Report unused noqa directives "TCH", # Typing imports #"NPY", # Numpy specific rules diff --git a/src/scportrait/pipeline/_base.py b/src/scportrait/pipeline/_base.py index bc29e98e..30e49d2e 100644 --- a/src/scportrait/pipeline/_base.py +++ b/src/scportrait/pipeline/_base.py @@ -11,7 +11,7 @@ import torch -class Logable(object): +class Logable: """ Object which can create log entries. 
diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index 52e7713d..ff51994d 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -1,6 +1,5 @@ import os import shutil -from typing import List, Tuple import datatree import xarray @@ -134,7 +133,7 @@ def _write_segmentation_sdata( segmentation, segmentation_label: str, classes: set = None, - chunks: Tuple[int, int] = (1000, 1000), + chunks: tuple[int, int] = (1000, 1000), overwrite: bool = False, ): transform_original = Identity() @@ -248,7 +247,7 @@ def _load_input_image_to_memmap(self, tmp_dir_abs_path: str, image=None): def _load_seg_to_memmap( self, - seg_name: List[str], + seg_name: list[str], tmp_dir_abs_path: str, ): """ @@ -259,7 +258,7 @@ def _load_seg_to_memmap( Parameters ---------- - seg_name : List[str] + seg_name : list[str] List of segmentation element names that should be loaded found in the sdata object. The segmentation elments need to have the same size. 
tmp_dir_abs_path : str diff --git a/src/scportrait/pipeline/_utils/spatialdata_classes.py b/src/scportrait/pipeline/_utils/spatialdata_classes.py index da717aa5..f5494d17 100644 --- a/src/scportrait/pipeline/_utils/spatialdata_classes.py +++ b/src/scportrait/pipeline/_utils/spatialdata_classes.py @@ -1,5 +1,5 @@ from functools import singledispatchmethod -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any from dask.array import Array as DaskArray from dask.array import unique as DaskUnique @@ -19,7 +19,7 @@ class spLabels2DModel(Labels2DModel): # add an additional attribute that always contains the unique classes in a labels image - attrs = AttrsSchema({"transform": Transform_s}, {"cell_ids": Set}) + attrs = AttrsSchema({"transform": Transform_s}, {"cell_ids": set}) def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -44,7 +44,7 @@ def _get_cell_ids(data, remove_background=True): return data @singledispatchmethod - def convert(self, data: Union[DataTree, DataArray], classes: set = None) -> Union[DataTree, DataArray]: + def convert(self, data: DataTree | DataArray, classes: set = None) -> DataTree | DataArray: """ """ raise ValueError(f"Unsupported data type: {type(data)}. 
Please use .convert() from Labels2DModel instead.") diff --git a/src/scportrait/pipeline/_utils/spatialdata_helper.py b/src/scportrait/pipeline/_utils/spatialdata_helper.py index fac0ca56..edc45607 100644 --- a/src/scportrait/pipeline/_utils/spatialdata_helper.py +++ b/src/scportrait/pipeline/_utils/spatialdata_helper.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, List, Set, Tuple, Union - import datatree import numpy as np import pandas as pd @@ -24,7 +22,7 @@ def check_memory(item): return array_size < available_memory -def generate_region_annotation_lookuptable(sdata: SpatialData) -> Dict: +def generate_region_annotation_lookuptable(sdata: SpatialData) -> dict: """Generate a lookup table for the region annotation tables contained in a SpatialData object ordered according to the region they annotate. Parameters @@ -80,7 +78,7 @@ def remap_region_annotation_table(table: TableModel, region_name: str) -> TableM return table -def get_chunk_size(element: Union[datatree.DataTree, xarray.DataArray]) -> Union[Tuple, List[Tuple]]: +def get_chunk_size(element: datatree.DataTree | xarray.DataArray) -> tuple | list[tuple]: """Get the chunk size of the image data. Parameters @@ -148,8 +146,8 @@ def get_chunk_size(element: Union[datatree.DataTree, xarray.DataArray]) -> Union def rechunk_image( - element: Union[datatree.DataTree, xarray.DataArray], chunk_size: Tuple -) -> Union[datatree.DataTree, xarray.DataArray]: + element: datatree.DataTree | xarray.DataArray, chunk_size: tuple +) -> datatree.DataTree | xarray.DataArray: """ Rechunk the image data to the desired chunksize. This is useful for ensuring that the data is chunked in a regular manner. 
@@ -183,7 +181,7 @@ def rechunk_image( def make_centers_object( centers: np.ndarray, - ids: List, + ids: list, transformation: str, coordinate_system="global", ): diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index 28ebddf7..685f88dc 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -4,7 +4,6 @@ import shutil from contextlib import redirect_stdout from functools import partial as func_partial -from typing import List, Union import numpy as np import pandas as pd @@ -301,8 +300,8 @@ def _load_pretrained_model(self, model_name: str): def _load_model( self, ckpt_path, - hparams_path: Union[str, None] = None, - model_type: Union[str, None] = None, + hparams_path: str | None = None, + model_type: str | None = None, ) -> pl.LightningModule: """Load a model from a checkpoint file and transfer it to the inference device. @@ -371,14 +370,14 @@ def _load_model( def load_model( self, ckpt_path, - hparams_path: Union[str, None] = None, - model_type: Union[str, None] = None, + hparams_path: str | None = None, + model_type: str | None = None, ): model = self._load_model(ckpt_path, hparams_path, model_type) self._assign_model(model) ### Functions regarding dataloading and transforms #### - def configure_transforms(self, selected_transforms: List): + def configure_transforms(self, selected_transforms: list): self.transforms = transforms.Compose(selected_transforms) self.log(f"The following transforms were applied: {self.transforms}") @@ -387,7 +386,7 @@ def generate_dataloader( extraction_dir: str, selected_transforms: transforms.Compose = transforms.Compose([]), size: int = 0, - seed: Union[int, None] = 42, + seed: int | None = 42, dataset_class=HDF5SingleCellDataset, ) -> torch.utils.data.DataLoader: """Create a pytorch dataloader from the provided single-cell image dataset. 
@@ -1078,7 +1077,7 @@ def _generate_column_names( self, n_masks: int = 2, n_channels: int = 3, - channel_names: Union[List, None] = None, + channel_names: list | None = None, ) -> None: column_names = [] diff --git a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py index 0ec74ee3..10f3e6e7 100644 --- a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py +++ b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py @@ -113,14 +113,14 @@ def call_as_tile(self): try: self.log(f"Beginning filtering on tile in position [{self.window[0]}, {self.window[1]}]") super().__call__(input_image) - except (IOError, ValueError, RuntimeError) as e: + except (OSError, ValueError, RuntimeError) as e: self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) else: print(f"Tile in position [{self.window[0]}, {self.window[1]}] only contained zeroes.") try: super().__call_empty__(input_image) - except (IOError, ValueError, RuntimeError) as e: + except (OSError, ValueError, RuntimeError) as e: self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) @@ -245,7 +245,7 @@ def calculate_tileing_plan(self, mask_size): os.remove(tileing_plan_path) else: self.log("Reading existing tileing plan from file.") - with open(tileing_plan_path, "r") as f: + with open(tileing_plan_path) as f: _tileing_plan = [eval(line) for line in f.readlines()] return _tileing_plan @@ -305,7 +305,7 @@ def execute_tile_list(self, tile_list, n_cpu=None): def f(x): try: x.call_as_tile() - except (IOError, ValueError, RuntimeError) as e: + except (OSError, ValueError, RuntimeError) as e: self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) return x.get_output() @@ -368,7 +368,7 @@ def collect_results(self): output_image = np.zeros((c, y, x), dtype=np.uint16) classes = defaultdict(list) - with open(f"{self.directory}/window.csv", "r") as f: + with open(f"{self.directory}/window.csv") as f: 
_window_locations = [eval(line.strip()) for line in f.readlines()] self.log(f"Expecting {len(_window_locations)} tiles") diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 49c18276..b0d9806b 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- import os import re import shutil import tempfile import warnings from time import time -from typing import Dict, List, Union import dask.array as darray import datatree @@ -143,13 +141,13 @@ def _load_config_from_file(self, file_path): if not os.path.isfile(file_path): raise ValueError(f"Your config path {file_path} is invalid.") - with open(file_path, "r") as stream: + with open(file_path) as stream: try: self.config = yaml.safe_load(stream) except yaml.YAMLError as exc: print(exc) - def _get_config_file(self, config_path: Union[str, None] = None): + def _get_config_file(self, config_path: str | None = None): # load config file self.config_path = os.path.join(self.project_location, self.DEFAULT_CONFIG_NAME) @@ -487,7 +485,7 @@ def _write_image_sdata( #### Functions for getting elements from sdata object ##### - def _load_seg_to_memmap(self, seg_name: List[str], tmp_dir_abs_path: str): + def _load_seg_to_memmap(self, seg_name: list[str], tmp_dir_abs_path: str): """ Helper function to load segmentation masks from sdata to memory mapped temp arrays for faster access. Loading happens in a chunked manner to avoid memory issues. 
@@ -630,7 +628,7 @@ def _load_input_image_to_memmap(self, tmp_dir_abs_path: str): return path_input_image #### Functions to load input data #### - def load_input_from_array(self, array: np.ndarray, channel_names: List[str] = None, overwrite=None): + def load_input_from_array(self, array: np.ndarray, channel_names: list[str] = None, overwrite=None): # check if an input image was already loaded if so throw error if overwrite = False # setup overwrite @@ -697,7 +695,7 @@ def load_input_from_tif_files( if crop is None: crop = [(0, -1), (0, -1)] - def extract_unique_parts(paths: List[str]): + def extract_unique_parts(paths: list[str]): """helper function to get unique channel names from filepaths Parameters @@ -970,7 +968,7 @@ def load_input_from_sdata( #### Functions to perform processing #### - def segment(self, overwrite: Union[bool, None] = None): + def segment(self, overwrite: bool | None = None): # check to ensure a method has been assigned if self.segmentation_f is None: raise ValueError("No segmentation method defined") @@ -996,7 +994,7 @@ def segment(self, overwrite: Union[bool, None] = None): self.segmentation_f.overwrite = original_overwrite # reset to original value self.sdata = self.filehandler.get_sdata() # update - def complete_segmentation(self, overwrite: Union[bool, None] = None): + def complete_segmentation(self, overwrite: bool | None = None): # check to ensure a method has been assigned if self.segmentation_f is None: raise ValueError("No segmentation method defined") @@ -1021,7 +1019,7 @@ def complete_segmentation(self, overwrite: Union[bool, None] = None): self._check_sdata_status() self.segmentation_f.overwrite = original_overwrite # reset to original value - def extract(self, partial=False, n_cells=None, overwrite: Union[bool, None] = None): + def extract(self, partial=False, n_cells=None, overwrite: bool | None = None): if self.extraction_f is None: raise ValueError("No extraction method defined") @@ -1043,7 +1041,7 @@ def classify( n_cells=0, 
data_type="complete", partial_seed=None, - overwrite: Union[bool, None] = None, + overwrite: bool | None = None, ): if self.classification_f is None: raise ValueError("No classification method defined") @@ -1098,10 +1096,10 @@ def classify( def select( self, - cell_sets: List[Dict], - calibration_marker: Union[np.array, None] = None, + cell_sets: list[dict], + calibration_marker: np.array | None = None, segmentation_name: str = "seg_all_nucleus", - name: Union[str, None] = None, + name: str | None = None, ): """ Select specified classes using the defined selection method. diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index f51a2840..ad17c91e 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -6,9 +6,7 @@ import timeit import traceback from multiprocessing import current_process -from typing import List -import datatree import h5py import matplotlib.pyplot as plt import numpy as np @@ -18,7 +16,6 @@ from alphabase.io import tempmmap from dask.array.core import Array as daskArray from PIL import Image -from spatialdata import SpatialData from tqdm.auto import tqdm from scportrait.pipeline._base import ProcessingStep @@ -223,7 +220,7 @@ def _transform_input_image(self, input_image): input_image = input_image.data return input_image - def _save_segmentation(self, labels: np.array, classes: List) -> None: + def _save_segmentation(self, labels: np.array, classes: list) -> None: """Helper function to save the results of a segmentation to file when generating a segmentation of a shard. 
Args: @@ -306,20 +303,20 @@ def save_map(self, map_name): """ if self.maps[map_name] is None: - self.log("Error saving map {}, map is None".format(map_name)) + self.log(f"Error saving map {map_name}, map is None") else: map_index = list(self.maps.keys()).index(map_name) # check if map contains more than one channel (3, 1024, 1024) vs (1024, 1024) if len(self.maps[map_name].shape) > 2: for i, channel in enumerate(self.maps[map_name]): - channel_name = "{}_{}_{}_map".format(map_index, map_name, i) + channel_name = f"{map_index}_{map_name}_{i}_map" channel_path = os.path.join(self.directory, channel_name) if self.debug and self.PRINT_MAPS_ON_DEBUG: self.save_image(channel, save_name=channel_path) else: - channel_name = "{}_{}_map".format(map_index, map_name) + channel_name = f"{map_index}_{map_name}_map" channel_path = os.path.join(self.directory, channel_name) if self.debug and self.PRINT_MAPS_ON_DEBUG: @@ -414,7 +411,7 @@ def _call_as_shard(self): self.log(f"Segmentation of Shard with the slicing {self.window} finished") - def _save_classes(self, classes: List) -> None: + def _save_classes(self, classes: list) -> None: """Helper function to save classes to a file when generating a segmentation of a shard.""" # define path where classes should be saved filtered_path = os.path.join(self.directory, self.DEFAULT_CLASSES_FILE) @@ -485,7 +482,7 @@ def __init__(self, *args, **kwargs): if not hasattr(self, "method"): raise AttributeError("No Segmentation method defined, please set attribute ``method``") - def _calculate_sharding_plan(self, image_size) -> List: + def _calculate_sharding_plan(self, image_size) -> list: """Calculate the sharding plan for the given input image size.""" _sharding_plan = [] @@ -533,7 +530,7 @@ def _calculate_sharding_plan(self, image_size) -> List: return _sharding_plan - def _get_sharding_plan(self, overwrite, force_read: bool = False) -> List: + def _get_sharding_plan(self, overwrite, force_read: bool = False) -> list: # check if a sharding 
plan already exists sharding_plan_path = f"{self.directory}/sharding_plan.csv" @@ -544,7 +541,7 @@ def _get_sharding_plan(self, overwrite, force_read: bool = False) -> List: os.remove(sharding_plan_path) else: self.log("Reading existing sharding plan from file.") - with open(sharding_plan_path, "r") as f: + with open(sharding_plan_path) as f: sharding_plan = [eval(line) for line in f.readlines()] return sharding_plan @@ -669,7 +666,7 @@ def _resolve_sharding(self, sharding_plan): ) # check to make sure windows match - with open(f"{local_shard_directory}/window.csv", "r") as f: + with open(f"{local_shard_directory}/window.csv") as f: window_local = eval(f.read()) if window_local != window: diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py index 5fd4a920..c7ec48f7 100644 --- a/src/scportrait/pipeline/segmentation/workflows.py +++ b/src/scportrait/pipeline/segmentation/workflows.py @@ -3,7 +3,6 @@ import sys import time import timeit -from typing import List, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -265,8 +264,8 @@ def _rescale_downsampled_mask(self, mask: np.array, mask_name: str) -> np.array: def _normalize_image( self, input_image: np.array, - lower: Union[float, dict], - upper: Union[float, dict], + lower: float | dict, + upper: float | dict, debug: bool = False, ) -> np.array: # check that both are specified as the same type @@ -363,7 +362,7 @@ def _check_for_size_filtering(self, mask_types=None) -> None: setattr(self, f"{mask_type}_thresholds", thresholds) setattr(self, f"{mask_type}_confidence_interval", confidence_interval) - def _get_params_cellsize_filtering(self, type) -> Tuple[Union[Tuple[float], None], Union[float, None]]: + def _get_params_cellsize_filtering(self, type) -> tuple[tuple[float] | None, float | None]: absolute_filter_status = False if "min_size" in self.config[f"{type}_segmentation"].keys(): @@ -397,7 +396,7 @@ def _get_params_cellsize_filtering(self, 
type) -> Tuple[Union[Tuple[float], None def _perform_size_filtering( self, mask: np.array, - thresholds: Union[Tuple[float], None], + thresholds: tuple[float] | None, confidence_interval: float, mask_name: str, log: bool = True, @@ -524,7 +523,7 @@ def _perform_mask_matching_filtering( filtering_threshold: float, debug: bool = False, input_image: np.array = None, - ) -> Tuple[np.array, np.array]: + ) -> tuple[np.array, np.array]: """ Match the nuclei and cytosol masks to ensure that the same cells are present in both masks. @@ -603,9 +602,7 @@ def _perform_mask_matching_filtering( self._clear_cache(vars_to_delete=[mask_nuc, mask_cyto, plot_results]) self.log( - "Total time to perform nucleus and cytosol mask matching filtering: {:.2f} seconds".format( - time.time() - start_time - ) + f"Total time to perform nucleus and cytosol mask matching filtering: {time.time() - start_time:.2f} seconds" ) return masks_nucleus, masks_cytosol @@ -1151,7 +1148,7 @@ def _read_cellpose_model(self, modeltype: str, name: str, gpu: str, device) -> m model = models.CellposeModel(pretrained_model=name, gpu=gpu, device=device) return model - def _load_model(self, model_type: str, gpu: str, device) -> Tuple[float, models.Cellpose]: + def _load_model(self, model_type: str, gpu: str, device) -> tuple[float, models.Cellpose]: """ Loads cellpose model diff --git a/src/scportrait/pipeline/selection.py b/src/scportrait/pipeline/selection.py index 1fcaa09c..0afb36a8 100644 --- a/src/scportrait/pipeline/selection.py +++ b/src/scportrait/pipeline/selection.py @@ -1,5 +1,4 @@ import os -from typing import Dict, List, Union import numpy as np from alphabase.io import tempmmap @@ -40,7 +39,7 @@ def _setup_selection(self): savename = name.replace(" ", "_") + ".xml" self.savepath = os.path.join(self.directory, savename) - def _post_processing_cleanup(self, vars_to_delete: Union[List, None] = None): + def _post_processing_cleanup(self, vars_to_delete: list | None = None): if vars_to_delete is not None: 
self._clear_cache(vars_to_delete=vars_to_delete) @@ -53,9 +52,9 @@ def _post_processing_cleanup(self, vars_to_delete: Union[List, None] = None): def process( self, segmentation_name: str, - cell_sets: List[Dict], + cell_sets: list[dict], calibration_marker: np.array, - name: Union[str, None] = None, + name: str | None = None, ): """ Process function for selecting cells and generating their XML. diff --git a/src/scportrait/tools/ml/datasets.py b/src/scportrait/tools/ml/datasets.py index 24e20745..56a0a09b 100644 --- a/src/scportrait/tools/ml/datasets.py +++ b/src/scportrait/tools/ml/datasets.py @@ -1,9 +1,6 @@ import os from collections.abc import Iterable -# type checking functions -from typing import List, Union - import h5py import numpy as np import torch @@ -85,9 +82,9 @@ def __init__( def _add_hdf_to_index( self, path: str, - index_list: Union[List, None] = None, - label: Union[int, None] = None, - label_column: Union[int, None] = None, + index_list: list | None = None, + label: int | None = None, + label_column: int | None = None, dtype_label_column=float, label_column_transform=None, read_label: bool = False, @@ -167,7 +164,7 @@ def _add_hdf_to_index( def _add_dataset( self, path: str, - current_index_list: List, + current_index_list: list, id: int, read_label_from_dataset: bool, ): @@ -202,7 +199,7 @@ def _scan_directory( self, path: str, levels_left: int, - current_index_list: Union[List, None] = None, + current_index_list: list | None = None, read_label_from_dataset: bool = False, ): """ @@ -300,10 +297,10 @@ def stats(self): # print dataset statistics """Print dataset statistics.""" labels = [el[0] for el in self.data_locator] - print("Total: {}".format(len(labels))) + print(f"Total: {len(labels)}") for label in set(labels): - print("{}: {}".format(label, labels.count(label))) + print(f"{label}: {labels.count(label)}") def __len__(self): """get number of elements contained in the dataset""" diff --git a/src/scportrait/tools/ml/models.py 
b/src/scportrait/tools/ml/models.py index 4a74f74b..9f188f21 100644 --- a/src/scportrait/tools/ml/models.py +++ b/src/scportrait/tools/ml/models.py @@ -138,7 +138,7 @@ def __init__( num_classes=2, image_size_factor=2, ): - super(VGG1, self).__init__() + super().__init__() # save num_classes for use in making MLP head self.num_classes = num_classes @@ -171,7 +171,7 @@ def __init__( num_classes=2, image_size_factor=2, ): - super(VGG2, self).__init__() + super().__init__() # save num_classes for use in making MLP head self.num_classes = num_classes @@ -203,7 +203,7 @@ def __init__( num_classes=1, image_size_factor=2, ): - super(VGG2_regression, self).__init__() + super().__init__() self.num_classes = num_classes self.image_size_factor = image_size_factor self.norm = nn.BatchNorm2d(in_channels) @@ -275,7 +275,7 @@ def __init__( in_channels=5, out_channels=5, ): - super(CAEBase, self).__init__() + super().__init__() self.norm = nn.BatchNorm2d(in_channels) @@ -341,7 +341,7 @@ class VAEBase(nn.Module): """ def __init__(self, in_channels, out_channels, latent_dim, hidden_dims=None, **kwargs): - super(VAEBase, self).__init__() + super().__init__() self.latent_dim = latent_dim @@ -511,7 +511,7 @@ def __init__( in_channels=5, num_classes=2, ): - super(_VGG1, self).__init__() + super().__init__() self.norm = nn.BatchNorm2d(in_channels) @@ -614,7 +614,7 @@ def __init__( in_channels=5, num_classes=2, ): - super(_VGG2, self).__init__() + super().__init__() self.norm = nn.BatchNorm2d(in_channels) # self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) diff --git a/src/scportrait/tools/ml/transforms.py b/src/scportrait/tools/ml/transforms.py index c66a0dd7..b9769895 100644 --- a/src/scportrait/tools/ml/transforms.py +++ b/src/scportrait/tools/ml/transforms.py @@ -6,7 +6,7 @@ import torchvision.transforms.functional as TF -class RandomRotation(object): +class RandomRotation: """ Randomly rotate input image in 90 degree steps. 
""" @@ -26,7 +26,7 @@ def __call__(self, tensor): return TF.rotate(tensor, angle) -class GaussianNoise(object): +class GaussianNoise: """ Add gaussian noise to the input image. """ @@ -52,7 +52,7 @@ def __call__(self, tensor): return tensor -class GaussianBlur(object): +class GaussianBlur: """ Apply a gaussian blur to the input image. """ @@ -78,7 +78,7 @@ def __call__(self, tensor): return blur(tensor) -class ChannelReducer(object): +class ChannelReducer: """ can reduce an imaging dataset dataset to 5, 3 or 1 channel 5: nuclei_mask, cell_mask, channel_nucleus, channel_cellmask, channel_of_interest @@ -98,7 +98,7 @@ def __call__(self, tensor): return tensor -class ChannelSelector(object): +class ChannelSelector: """ select the channel used for prediction. """ diff --git a/src/scportrait/tools/parse/_parse_phenix.py b/src/scportrait/tools/parse/_parse_phenix.py index 877ccc5b..e6e2b30d 100644 --- a/src/scportrait/tools/parse/_parse_phenix.py +++ b/src/scportrait/tools/parse/_parse_phenix.py @@ -246,7 +246,7 @@ def generate_new_filenames(self, metadata): else: max_y = metadata.loc[((metadata.Well == well) & (metadata.Row == rows[0]))].Y_pos.max() metadata.loc[(metadata.Well == well) & (metadata.Row == row), "Y_pos"] = ( - metadata.loc[(metadata.Well == well) & (metadata.Row == row), "Y_pos"] + int(max_y) + int(1) + metadata.loc[(metadata.Well == well) & (metadata.Row == row), "Y_pos"] + int(max_y) + 1 ) metadata.loc[(metadata.Well == well) & (metadata.Row == row), "Row"] = rows[0] @@ -257,7 +257,7 @@ def generate_new_filenames(self, metadata): else: max_x = metadata.loc[(metadata.Well == wells[0])].X_pos.max() metadata.loc[(metadata.Well == well), "X_pos"] = ( - metadata.loc[(metadata.Well == well), "X_pos"] + int(max_x) + int(1) + metadata.loc[(metadata.Well == well), "X_pos"] + int(max_x) + 1 ) metadata.loc[(metadata.Well == well), "Well"] = wells[0] @@ -269,9 +269,7 @@ def generate_new_filenames(self, metadata): # generate new file names for i in 
range(metadata.shape[0]): _row = metadata.loc[i, :] - name = "Timepoint{}_Row{}_Well{}_{}_zstack{}_r{}_c{}.tif".format( - _row.Timepoint, _row.Row, _row.Well, _row.Channel, _row.Zstack, _row.Y_pos, _row.X_pos - ) + name = f"Timepoint{_row.Timepoint}_Row{_row.Row}_Well{_row.Well}_{_row.Channel}_zstack{_row.Zstack}_r{_row.Y_pos}_c{_row.X_pos}.tif" name = name metadata.loc[i, "new_file_name"] = name @@ -363,7 +361,7 @@ def _generate_missing_file_names(x_positions, y_positions, timepoint, row, well, else: # get size of missing images that need to be replaced image = imread(os.path.join(metadata["source"][0], metadata["filename"][0])) - image[:] = int(0) + image[:] = 0 self.black_image = image print(f"The found missing tiles need to be replaced with black images of the size {image.shape}.") diff --git a/src/scportrait/tools/stitch/_stitch.py b/src/scportrait/tools/stitch/_stitch.py index 769902d7..35855730 100644 --- a/src/scportrait/tools/stitch/_stitch.py +++ b/src/scportrait/tools/stitch/_stitch.py @@ -2,7 +2,6 @@ import shutil import sys from concurrent.futures import ThreadPoolExecutor -from typing import List, Union import numpy as np from alphabase.io.tempmmap import ( @@ -37,9 +36,9 @@ def __init__( overlap: float = 0.1, max_shift: float = 30, filter_sigma: int = 0, - do_intensity_rescale: Union[bool, str] = True, + do_intensity_rescale: bool | str = True, rescale_range: tuple = (1, 99), - channel_order: List[str] = None, + channel_order: list[str] = None, reader_type=FilePatternReaderRescale, orientation: dict = None, plot_QC: bool = True, diff --git a/src/scportrait/tools/stitch/_utils/filereaders.py b/src/scportrait/tools/stitch/_utils/filereaders.py index 56cceb68..40c059b2 100644 --- a/src/scportrait/tools/stitch/_utils/filereaders.py +++ b/src/scportrait/tools/stitch/_utils/filereaders.py @@ -27,7 +27,7 @@ def __init__( ): try: super().__init__(path, pattern, overlap, pixel_size=pixel_size) - except (FileNotFoundError, IOError): + except (OSError, 
FileNotFoundError): print( f"Error: Could not read images with the given pattern {pattern}. Please check the path {path} and pattern." ) diff --git a/src/scportrait/tools/stitch/_utils/filewriters.py b/src/scportrait/tools/stitch/_utils/filewriters.py index 934a0384..6a55ac0f 100644 --- a/src/scportrait/tools/stitch/_utils/filewriters.py +++ b/src/scportrait/tools/stitch/_utils/filewriters.py @@ -1,6 +1,5 @@ import os import shutil -from typing import List, Tuple import numpy as np import zarr @@ -33,12 +32,12 @@ def write_tif(image_path: str, image: np.array, dtype="uint16"): def write_ome_zarr( filepath: str, image: np.array, - channels: List[str], + channels: list[str], slidename: str, - channel_colors: List[str] = None, + channel_colors: list[str] = None, downscaling_size: int = 4, n_downscaling_layers: int = 4, - chunk_size: Tuple[int, int, int] = (1, 1024, 1024), + chunk_size: tuple[int, int, int] = (1, 1024, 1024), overwrite: bool = False, ): """write out an image as an OME-Zarr file compatible with napari @@ -120,7 +119,7 @@ def write_ome_zarr( write_image(image, group=group, axes=axes, storage_options={"chunks": chunk_size}, scaler=scaler) -def write_xml(image_paths: List[str], channels: List[str], slidename: str, outdir: str = None): +def write_xml(image_paths: list[str], channels: list[str], slidename: str, outdir: str = None): """Helper function to generate an XML for import of stitched .tifs into BIAS. 
Parameters @@ -183,8 +182,8 @@ def write_xml(image_paths: List[str], channels: List[str], slidename: str, outdi def write_spatialdata( image_path: str, image: np.array, - channel_names: List[str] = None, - scale_factors: List[int] = None, + channel_names: list[str] = None, + scale_factors: list[int] = None, overwrite: bool = False, ): # check if the file exists and delete if overwrite is set to True diff --git a/src/scportrait/tools/stitch/_utils/parallelilzation.py b/src/scportrait/tools/stitch/_utils/parallelilzation.py index b3141268..1f1ffab0 100644 --- a/src/scportrait/tools/stitch/_utils/parallelilzation.py +++ b/src/scportrait/tools/stitch/_utils/parallelilzation.py @@ -1,6 +1,6 @@ # helper functions for paralellization +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Callable from tqdm.auto import tqdm diff --git a/tests/processing_test.py b/tests/processing_test.py index 125a6769..059b646c 100644 --- a/tests/processing_test.py +++ b/tests/processing_test.py @@ -299,7 +299,7 @@ def test_plot_image(tmpdir): # Since this function does not return anything, we just check if it produces any exceptions try: plot_image(array, size=(5, 5), save_name=save_name) - except (ValueError, TypeError, IOError) as e: + except (OSError, ValueError, TypeError) as e: pytest.fail(f"plot_image raised exception: {str(e)}") assert os.path.isfile(str(save_name) + ".png") @@ -327,7 +327,7 @@ def test_logable_log(): log_path = os.path.join(temp_dir, logable.DEFAULT_LOG_NAME) assert os.path.isfile(log_path) - with open(log_path, "r") as f: + with open(log_path) as f: log_content = f.read() assert "Testing" in log_content From 2b9f4a7573025364b01236e5ba36534724634a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:35:52 +0200 Subject: [PATCH 11/12] activate numpy specific rules and fix any issues --- pyproject.toml | 2 +- 
src/scportrait/pipeline/extraction.py | 9 +++++---- tests/processing_test.py | 21 ++++++++++++++------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3c923af..fec996bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ select = [ "UP", # pyupgrade "RUF100", # Report unused noqa directives "TCH", # Typing imports - #"NPY", # Numpy specific rules + "NPY", # Numpy specific rules #"PTH" # Use pathlib ] ignore = [ diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py index 7d030781..733d6186 100644 --- a/src/scportrait/pipeline/extraction.py +++ b/src/scportrait/pipeline/extraction.py @@ -306,9 +306,9 @@ def _get_classes_to_extract(self): ) # randomly sample n_cells from the centers - np.random.seed(self.seed) - chosen_ids = np.random.choice(list(range(len(self.centers_cell_ids))), self.n_cells, replace=False) - print(self.centers_cell_ids) + rng = np.random.default_rng(self.seed) + chosen_ids = rng.choice(list(range(len(self.centers_cell_ids))), self.n_cells, replace=False) + self.classes = self.centers_cell_ids[chosen_ids] self.px_centers = self.centers[chosen_ids] else: @@ -601,7 +601,8 @@ def _transfer_tempmmap_to_hdf5(self): n_cells = 100 n_cells_to_visualize = len(keep_index) // n_cells - random_indexes = np.random.choice(keep_index, n_cells_to_visualize, replace=False) + rng = np.random.default_rng() + random_indexes = rng.choice(keep_index, n_cells_to_visualize, replace=False) for index in random_indexes: stack = _tmp_single_cell_data[index] diff --git a/tests/processing_test.py b/tests/processing_test.py index 059b646c..c29bb918 100644 --- a/tests/processing_test.py +++ b/tests/processing_test.py @@ -228,14 +228,16 @@ def test_numba_mask_centroid(): def test_percentile_norm(): - img = np.random.rand(4, 4) + rng = np.random.default_rng() + img = rng.random((4, 4)) norm_img = _percentile_norm(img, 0.1, 0.9) assert np.min(norm_img) == pytest.approx(0) 
assert np.max(norm_img) == pytest.approx(1) def test_percentile_normalization_C_H_W(): - test_array = np.random.randint(2, size=(3, 100, 100)) + rng = np.random.default_rng() + test_array = rng.integers(2, size=(3, 100, 100)) test_array[:, 10:11, 10:11] = -1 test_array[:, 12:13, 12:13] = 3 @@ -245,7 +247,8 @@ def test_percentile_normalization_C_H_W(): def test_percentile_normalization_H_W(): - test_array = np.random.randint(2, size=(100, 100)) + rng = np.random.default_rng() + test_array = rng.integers(2, size=(100, 100)) test_array[10:11, 10:11] = -1 test_array[12:13, 12:13] = 3 @@ -255,13 +258,15 @@ def test_percentile_normalization_H_W(): def test_rolling_window_mean(): - array = np.random.rand(10, 10) + rng = np.random.default_rng() + array = rng.random((10, 10)) rolling_array = rolling_window_mean(array, size=5, scaling=False) assert np.all(array.shape == rolling_array.shape) def test_MinMax(): - array = np.random.rand(10, 10) + rng = np.random.default_rng() + array = rng.random((10, 10)) normalized_array = MinMax(array) assert np.min(normalized_array) == 0 assert np.max(normalized_array) == 1 @@ -284,7 +289,8 @@ def test_flatten(): def test_visualize_class(): class_ids = [1, 2] seg_map = np.array([[0, 1, 0], [1, 2, 1], [2, 0, 1]]) - background = np.random.random((3, 3)) * 255 + rng = np.random.default_rng() + background = rng.random((3, 3)) * 255 # Since this function does not return anything, we just check if it produces any exceptions try: visualize_class(class_ids, seg_map, background) @@ -293,7 +299,8 @@ def test_visualize_class(): def test_plot_image(tmpdir): - array = np.random.rand(10, 10) + rng = np.random.default_rng() + array = rng.random((10, 10)) save_name = tmpdir.join("test_plot_image") # Since this function does not return anything, we just check if it produces any exceptions From c120cc2662d69dbb44c35eec98bf169eee457c94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 
26 Sep 2024 15:55:09 +0200 Subject: [PATCH 12/12] fix numba error resulting from linting --- src/scportrait/pipeline/_utils/segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scportrait/pipeline/_utils/segmentation.py b/src/scportrait/pipeline/_utils/segmentation.py index 8fd8b16e..977e520f 100644 --- a/src/scportrait/pipeline/_utils/segmentation.py +++ b/src/scportrait/pipeline/_utils/segmentation.py @@ -373,7 +373,7 @@ def _return_edge_labels_2d(input_map): .union(set(last_column.flatten())) ) - full_union = {np.uint64(i) for i in full_union} + full_union = set([np.uint64(i) for i in full_union]) # noqa: C403 set comprehensions are not supported by numba full_union.discard(0) return list(full_union)