Skip to content

Commit

Permalink
Fix io.Raster.__array__() for NumPy>=2 (#90)
Browse files Browse the repository at this point in the history
The new NumPy major version makes some interface changes to the
`__array__` protocol: https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword

The `__array__` method now must take `dtype=None` and `copy=None`
parameters. This change is backwards-compatible with older NumPy
versions.

`copy=False` should always fail for `Raster`, since converting a
`Raster` to an array always involves creating a new array (rather than
creating a view of an existing array).

I would've preferred to make both parameters keyword-only, but the
`dtype` parameter seems to be passed positionally by `np.asarray`, at
least with my current NumPy version (v2.1.1), so it seems best to just
allow both parameters to be positional.
  • Loading branch information
gmgunter authored Sep 25, 2024
1 parent 3769e67 commit e0a30f7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/snaphu/io/_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,15 @@ def close(self) -> None:
def __exit__(self, exc_type, exc_value, traceback): # type: ignore[no-untyped-def]
self.close()

def __array__(self) -> np.ndarray:
return self.dataset.read(self.band)
def __array__(
self, dtype: DTypeLike | None = None, copy: bool | None = None
) -> np.ndarray:
if not copy and (copy is not None):
errmsg = "unable to avoid copy while creating an array as requested"
raise ValueError(errmsg)

data = self.dataset.read(self.band)
return data if (dtype is None) else data.astype(dtype)

def _window_from_slices(
self, key: slice | tuple[slice, ...]
Expand Down
20 changes: 20 additions & 0 deletions test/io/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def has_rasterio() -> bool:
return importlib.util.find_spec("rasterio") is not None


def numpy_version() -> np.lib.NumpyVersion:
"""Get the version identifier of the imported NumPy package."""
return np.lib.NumpyVersion(np.__version__)


@contextmanager
def make_geotiff_raster(
fp: str | os.PathLike[str],
Expand Down Expand Up @@ -226,6 +231,21 @@ def test_arraylike(self, geotiff_raster: snaphu.io.Raster):
assert arr.shape == geotiff_raster.shape
assert arr.dtype == geotiff_raster.dtype

@pytest.mark.parametrize("dtype", [np.int32, np.float64])
def test_asarray_dtype(self, geotiff_raster: snaphu.io.Raster, dtype: DTypeLike):
arr = np.asarray(geotiff_raster, dtype=dtype)
assert arr.dtype == dtype

@pytest.mark.skipif(numpy_version() < "2.0.0", reason="requires numpy>=2")
def test_asarray_copy(self, geotiff_raster: snaphu.io.Raster):
# Check that these statements don't raise exceptions.
np.asarray(geotiff_raster, copy=True)
np.asarray(geotiff_raster, copy=None)

regex = "^unable to avoid copy while creating an array as requested$"
with pytest.raises(ValueError, match=regex):
np.asarray(geotiff_raster, copy=False)

def test_setitem_getitem_roundtrip(self, geotiff_raster: snaphu.io.Raster):
data = np.arange(20, dtype=np.int32).reshape(4, 5)
idx = np.s_[100:104, 200:205]
Expand Down

0 comments on commit e0a30f7

Please sign in to comment.