Skip to content

Commit

Permalink
(fix): fix ruff issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Oct 14, 2024
1 parent 861e60d commit 57a4e6c
Show file tree
Hide file tree
Showing 16 changed files with 61 additions and 71 deletions.
21 changes: 12 additions & 9 deletions src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import MutableMapping, Sequence
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar, Union
from typing import TYPE_CHECKING, Generic, TypeVar

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -33,10 +33,10 @@
from .raw import Raw


OneDIdx = Union[Sequence[int], Sequence[bool], slice]
OneDIdx = Sequence[int] | Sequence[bool] | slice
TwoDIdx = tuple[OneDIdx, OneDIdx]
# TODO: pd.DataFrame only allowed in AxisArrays?
Value = Union[pd.DataFrame, spmatrix, np.ndarray]
Value = pd.DataFrame | spmatrix | np.ndarray

P = TypeVar("P", bound="AlignedMappingBase")
"""Parent mapping an AlignedView is based on."""
Expand Down Expand Up @@ -376,9 +376,14 @@ class PairwiseArraysView(AlignedView[PairwiseArraysBase, OneDIdx], PairwiseArray
PairwiseArraysBase._actual_class = PairwiseArrays


AlignedMapping = Union[
AxisArrays, AxisArraysView, Layers, LayersView, PairwiseArrays, PairwiseArraysView
]
AlignedMapping = (
AxisArrays
| AxisArraysView
| Layers
| LayersView
| PairwiseArrays
| PairwiseArraysView
)
T = TypeVar("T", bound=AlignedMapping)
"""Pair of types to be aligned."""

Expand Down Expand Up @@ -408,9 +413,7 @@ def fget(self) -> Callable[[], None]:

def fake(): ...

fake.__annotations__ = {
"return": Union[self.cls._actual_class, self.cls._view_class]
}
fake.__annotations__ = {"return": self.cls._actual_class | self.cls._view_class}
return fake

def __get__(self, obj: None | AnnData, objtype: type | None = None) -> T:
Expand Down
14 changes: 6 additions & 8 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,12 @@ def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index):
"that is, you cannot make a view of a view."
)
self._is_view = True
if isinstance(oidx, (int, np.integer)):
if isinstance(oidx, int | np.integer):
if not (-adata_ref.n_obs <= oidx < adata_ref.n_obs):
raise IndexError(f"Observation index `{oidx}` is out of range.")
oidx += adata_ref.n_obs * (oidx < 0)
oidx = slice(oidx, oidx + 1, 1)
if isinstance(vidx, (int, np.integer)):
if isinstance(vidx, int | np.integer):
if not (-adata_ref.n_vars <= vidx < adata_ref.n_vars):
raise IndexError(f"Variable index `{vidx}` is out of range.")
vidx += adata_ref.n_vars * (vidx < 0)
Expand Down Expand Up @@ -406,7 +406,7 @@ def _init_as_actual(
# as in readwrite.read_10x_h5
if X.dtype != np.dtype(dtype):
X = X.astype(dtype)
elif isinstance(X, (ZarrArray, DaskArray)):
elif isinstance(X, ZarrArray | DaskArray):
X = X.astype(dtype)
else: # is np.ndarray or a subclass, convert to true np.ndarray
X = np.asarray(X, dtype)
Expand Down Expand Up @@ -763,16 +763,14 @@ def _prep_dim_index(self, value, attr: str) -> pd.Index:
raise ValueError(
f"Length of passed value for {attr}_names is {len(value)}, but this AnnData has shape: {self.shape}"
)
if isinstance(value, pd.Index) and not isinstance(
value.name, (str, type(None))
):
if isinstance(value, pd.Index) and not isinstance(value.name, str | type(None)):
raise ValueError(
f"AnnData expects .{attr}.index.name to be a string or None, "
f"but you passed a name of type {type(value.name).__name__!r}"
)
else:
value = pd.Index(value)
if not isinstance(value.name, (str, type(None))):
if not isinstance(value.name, str | type(None)):
value.name = None
if (
len(value) > 0
Expand Down Expand Up @@ -1976,7 +1974,7 @@ def chunk_X(
if isinstance(select, int):
select = select if select < self.n_obs else self.n_obs
choice = np.random.choice(self.n_obs, select, replace)
elif isinstance(select, (np.ndarray, Sequence)):
elif isinstance(select, np.ndarray | Sequence):
choice = np.asarray(select)
else:
raise ValueError("select should be int or array")
Expand Down
10 changes: 5 additions & 5 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,25 @@ def name_idx(i):
stop = None if stop is None else stop + 1
step = indexer.step
return slice(start, stop, step)
elif isinstance(indexer, (np.integer, int)):
elif isinstance(indexer, np.integer | int):
return indexer
elif isinstance(indexer, str):
return index.get_loc(indexer) # int
elif isinstance(
indexer, (Sequence, np.ndarray, pd.Index, spmatrix, np.matrix, SpArray)
indexer, Sequence | np.ndarray | pd.Index | spmatrix | np.matrix | SpArray
):
if hasattr(indexer, "shape") and (
(indexer.shape == (index.shape[0], 1))
or (indexer.shape == (1, index.shape[0]))
):
if isinstance(indexer, (spmatrix, SpArray)):
if isinstance(indexer, spmatrix | SpArray):
indexer = indexer.toarray()
indexer = np.ravel(indexer)
if not isinstance(indexer, (np.ndarray, pd.Index)):
if not isinstance(indexer, np.ndarray | pd.Index):
indexer = np.array(indexer)
if len(indexer) == 0:
indexer = indexer.astype(int)
if issubclass(indexer.dtype.type, (np.integer, np.floating)):
if issubclass(indexer.dtype.type, np.integer | np.floating):
return indexer # Might not work for range indexes
elif issubclass(indexer.dtype.type, np.bool_):
if indexer.shape != index.shape:
Expand Down
14 changes: 7 additions & 7 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def equal_sparse(a, b) -> bool:

xp = array_api_compat.array_namespace(a.data)

if isinstance(b, (CupySparseMatrix, sparse.spmatrix, SpArray)):
if isinstance(b, CupySparseMatrix | sparse.spmatrix | SpArray):
if isinstance(a, CupySparseMatrix):
# Comparison broken for CSC matrices
# https://github.com/cupy/cupy/issues/7757
Expand Down Expand Up @@ -206,7 +206,7 @@ def equal_awkward(a, b) -> bool:


def as_sparse(x, use_sparse_array=False):
if not isinstance(x, (sparse.spmatrix, SpArray)):
if not isinstance(x, sparse.spmatrix | SpArray):
if CAN_USE_SPARSE_ARRAY and use_sparse_array:
return sparse.csr_array(x)
return sparse.csr_matrix(x)
Expand Down Expand Up @@ -536,7 +536,7 @@ def apply(self, el, *, axis, fill_value=None):
return el
if isinstance(el, pd.DataFrame):
return self._apply_to_df(el, axis=axis, fill_value=fill_value)
elif isinstance(el, (sparse.spmatrix, SpArray, CupySparseMatrix)):
elif isinstance(el, sparse.spmatrix | SpArray | CupySparseMatrix):
return self._apply_to_sparse(el, axis=axis, fill_value=fill_value)
elif isinstance(el, AwkArray):
return self._apply_to_awkward(el, axis=axis, fill_value=fill_value)
Expand Down Expand Up @@ -723,7 +723,7 @@ def default_fill_value(els):
This is largely due to backwards compat, and might not be the ideal solution.
"""
if any(isinstance(el, (sparse.spmatrix, SpArray)) for el in els):
if any(isinstance(el, sparse.spmatrix | SpArray) for el in els):
return 0
else:
return np.nan
Expand Down Expand Up @@ -794,7 +794,7 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
import cupyx.scipy.sparse as cpsparse

if not all(
isinstance(a, (CupySparseMatrix, CupyArray)) or 0 in a.shape for a in arrays
isinstance(a, CupySparseMatrix | CupyArray) or 0 in a.shape for a in arrays
):
raise NotImplementedError(
"Cannot concatenate a cupy array with other array types."
Expand All @@ -821,7 +821,7 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
],
axis=axis,
)
elif any(isinstance(a, (sparse.spmatrix, SpArray)) for a in arrays):
elif any(isinstance(a, sparse.spmatrix | SpArray) for a in arrays):
sparse_stack = (sparse.vstack, sparse.hstack)[axis]
use_sparse_array = any(issubclass(type(a), SpArray) for a in arrays)
return sparse_stack(
Expand Down Expand Up @@ -980,7 +980,7 @@ def concat_pairwise_mapping(
els = [
m.get(k, sparse_class((s, s), dtype=bool)) for m, s in zip(mappings, shapes)
]
if all(isinstance(el, (CupySparseMatrix, CupyArray)) for el in els):
if all(isinstance(el, CupySparseMatrix | CupyArray) for el in els):
result[k] = _cp_block_diag(els, format="csr")
elif all(isinstance(el, DaskArray) for el in els):
result[k] = _dask_block_diag(els)
Expand Down
8 changes: 4 additions & 4 deletions src/anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
# construct manually
if adata.isbacked == (X is None):
# Move from GPU to CPU since it's large and not always used
if isinstance(X, (CupyArray, CupySparseMatrix)):
if isinstance(X, CupyArray | CupySparseMatrix):
self._X = X.get()
else:
self._X = X
Expand All @@ -51,7 +51,7 @@ def __init__(
self.varm = varm
elif X is None: # construct from adata
# Move from GPU to CPU since it's large and not always used
if isinstance(adata.X, (CupyArray, CupySparseMatrix)):
if isinstance(adata.X, CupyArray | CupySparseMatrix):
self._X = adata.X.get()
else:
self._X = adata.X.copy()
Expand Down Expand Up @@ -124,9 +124,9 @@ def __getitem__(self, index):
oidx, vidx = self._normalize_indices(index)

# To preserve two dimensional shape
if isinstance(vidx, (int, np.integer)):
if isinstance(vidx, int | np.integer):
vidx = slice(vidx, vidx + 1, 1)
if isinstance(oidx, (int, np.integer)):
if isinstance(oidx, int | np.integer):
oidx = slice(oidx, oidx + 1, 1)

if not self._adata.isbacked:
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def _get_group_format(group: GroupStorageType) -> str:
def is_sparse_indexing_overridden(format: Literal["csr", "csc"], row, col):
major_indexer, minor_indexer = (row, col) if format == "csr" else (col, row)
return isinstance(minor_indexer, slice) and (
(isinstance(major_indexer, (int, np.integer)))
(isinstance(major_indexer, int | np.integer))
or (isinstance(major_indexer, slice))
or (isinstance(major_indexer, np.ndarray) and major_indexer.ndim == 1)
)
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_io/h5ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ def write_h5ad(
f.attrs.setdefault("encoding-version", "0.1.0")

if "X" in as_dense and isinstance(
adata.X, (sparse.spmatrix, BaseCompressedSparseDataset)
adata.X, sparse.spmatrix | BaseCompressedSparseDataset
):
write_sparse_as_dense(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
elif not (adata.isbacked and Path(adata.filename) == Path(filepath)):
# If adata.isbacked, X should already be up to date
write_elem(f, "X", adata.X, dataset_kwargs=dataset_kwargs)
if "raw/X" in as_dense and isinstance(
adata.raw.X, (sparse.spmatrix, BaseCompressedSparseDataset)
adata.raw.X, sparse.spmatrix | BaseCompressedSparseDataset
):
write_sparse_as_dense(
f, "raw/X", adata.raw.X, dataset_kwargs=dataset_kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/_io/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def read_text(
dtype
Numpy data type.
"""
if not isinstance(filename, (PathLike, str, bytes)):
if not isinstance(filename, PathLike | str | bytes):
return _read_text(filename, delimiter, first_column_names, dtype)

filename = Path(filename)
Expand Down
7 changes: 4 additions & 3 deletions src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
from __future__ import annotations

from functools import wraps
from itertools import pairwise
from typing import TYPE_CHECKING, cast
from warnings import warn

import h5py
from packaging.version import Version

from .._core.sparse_dataset import BaseCompressedSparseDataset
from ..compat import add_note, pairwise
from ..compat import add_note

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal, Union
from typing import Literal

from .._types import StorageType
from ..compat import H5Group, ZarrGroup

Storage = Union[StorageType, BaseCompressedSparseDataset]
Storage = StorageType | BaseCompressedSparseDataset

# For allowing h5py v3
# https://github.com/scverse/anndata/issues/442
Expand Down
22 changes: 5 additions & 17 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from importlib.util import find_spec
from inspect import Parameter, signature
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, Union
from typing import TYPE_CHECKING, TypeVar
from warnings import warn

import h5py
Expand Down Expand Up @@ -46,8 +46,8 @@ class Empty:
pass


Index1D = Union[slice, int, str, np.int64, np.ndarray]
Index = Union[Index1D, tuple[Index1D, Index1D], scipy.sparse.spmatrix, SpArray]
Index1D = slice | int | str | np.int64 | np.ndarray
Index = Index1D | tuple[Index1D, Index1D] | scipy.sparse.spmatrix | SpArray
H5Group = h5py.Group
H5Array = h5py.Dataset
H5File = h5py.File
Expand Down Expand Up @@ -75,18 +75,6 @@ def __exit__(self, *_exc_info) -> None:
os.chdir(self._old_cwd.pop())


if sys.version_info >= (3, 10):
from itertools import pairwise
else:

def pairwise(iterable):
from itertools import tee

a, b = tee(iterable)
next(b, None)
return zip(a, b)


#############################
# Optional deps
#############################
Expand Down Expand Up @@ -319,7 +307,7 @@ def _clean_uns(adata: AnnData): # noqa: F821
continue
name = cats_name.replace("_categories", "")
# fix categories with a single category
if isinstance(cats, (str, int)):
if isinstance(cats, str | int):
cats = [cats]
for ann in [adata.obs, adata.var]:
if name not in ann:
Expand All @@ -344,7 +332,7 @@ def _move_adj_mtx(d):
for k in ("distances", "connectivities"):
if (
(k in n)
and isinstance(n[k], (scipy.sparse.spmatrix, np.ndarray))
and isinstance(n[k], scipy.sparse.spmatrix | np.ndarray)
and len(n[k].shape) == 2
):
warn(
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/experimental/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _write_concat_sequence(
)
write_elem(output_group, output_path, df)
elif all(
isinstance(a, (pd.DataFrame, BaseCompressedSparseDataset, H5Array, ZarrArray))
isinstance(a, pd.DataFrame | BaseCompressedSparseDataset | H5Array | ZarrArray)
for a in arrays
):
_write_concat_arrays(
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/experimental/multi_files/_anncollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from collections.abc import Callable, Mapping
from functools import reduce
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -584,7 +584,7 @@ def attrs_keys(self):


DictCallable = dict[str, Callable]
ConvertType = Union[Callable, dict[str, Callable | DictCallable]]
ConvertType = Callable | dict[str, Callable | DictCallable]


class AnnCollection(_ConcatViewMixin, _IterateViewMixin):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_backed_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def test_backed_raw_subset(tmp_path, array_type, subset_func, subset_func2):
var_idx = subset_func2(mem_adata.var_names)
if (
array_type is asarray
and isinstance(obs_idx, (list, np.ndarray, sparse.spmatrix, SpArray))
and isinstance(var_idx, (list, np.ndarray, sparse.spmatrix, SpArray))
and isinstance(obs_idx, list | np.ndarray | sparse.spmatrix | SpArray)
and isinstance(var_idx, list | np.ndarray | sparse.spmatrix | SpArray)
):
pytest.xfail(
"Fancy indexing does not work with multiple arrays on a h5py.Dataset"
Expand Down
Loading

0 comments on commit 57a4e6c

Please sign in to comment.