diff --git a/src/snaphu/io/_raster.py b/src/snaphu/io/_raster.py index 567dd61..13aa7ed 100644 --- a/src/snaphu/io/_raster.py +++ b/src/snaphu/io/_raster.py @@ -267,8 +267,15 @@ def close(self) -> None: def __exit__(self, exc_type, exc_value, traceback): # type: ignore[no-untyped-def] self.close() - def __array__(self) -> np.ndarray: - return self.dataset.read(self.band) + def __array__( + self, dtype: DTypeLike | None = None, copy: bool | None = None + ) -> np.ndarray: + if not copy and (copy is not None): + errmsg = "unable to avoid copy while creating an array as requested" + raise ValueError(errmsg) + + data = self.dataset.read(self.band) + return data if (dtype is None) else data.astype(dtype) def _window_from_slices( self, key: slice | tuple[slice, ...] diff --git a/test/io/test_raster.py b/test/io/test_raster.py index 468d5cb..50c9bb9 100644 --- a/test/io/test_raster.py +++ b/test/io/test_raster.py @@ -23,6 +23,11 @@ def has_rasterio() -> bool: return importlib.util.find_spec("rasterio") is not None +def numpy_version() -> np.lib.NumpyVersion: + """Get the version identifier of the imported NumPy package.""" + return np.lib.NumpyVersion(np.__version__) + + @contextmanager def make_geotiff_raster( fp: str | os.PathLike[str], @@ -226,6 +231,21 @@ def test_arraylike(self, geotiff_raster: snaphu.io.Raster): assert arr.shape == geotiff_raster.shape assert arr.dtype == geotiff_raster.dtype + @pytest.mark.parametrize("dtype", [np.int32, np.float64]) + def test_asarray_dtype(self, geotiff_raster: snaphu.io.Raster, dtype: DTypeLike): + arr = np.asarray(geotiff_raster, dtype=dtype) + assert arr.dtype == dtype + + @pytest.mark.skipif(numpy_version() < "2.0.0", reason="requires numpy>=2") + def test_asarray_copy(self, geotiff_raster: snaphu.io.Raster): + # Check that these statements don't raise exceptions. + np.asarray(geotiff_raster, copy=True) + np.asarray(geotiff_raster, copy=None) + + regex = "^unable to avoid copy while creating an array as requested$" + with pytest.raises(ValueError, match=regex): + np.asarray(geotiff_raster, copy=False) + def test_setitem_getitem_roundtrip(self, geotiff_raster: snaphu.io.Raster): data = np.arange(20, dtype=np.int32).reshape(4, 5) idx = np.s_[100:104, 200:205]