diff --git a/docs/src/whatsnew/latest.rst b/docs/src/whatsnew/latest.rst index dfc1019683..7af6e708c7 100644 --- a/docs/src/whatsnew/latest.rst +++ b/docs/src/whatsnew/latest.rst @@ -104,6 +104,8 @@ This document explains the changes made to Iris for this release lazy data from file. This will also speed up coordinate comparison. (:pull:`5610`) +#. `@bouweandela`_ added the option to specify the Dask chunks of the target + array in :func:`iris.util.broadcast_to_shape`. (:pull:`5620`) 🔥 Deprecations =============== diff --git a/lib/iris/tests/unit/util/test_broadcast_to_shape.py b/lib/iris/tests/unit/util/test_broadcast_to_shape.py index 6e32d6389d..a5d571a527 100644 --- a/lib/iris/tests/unit/util/test_broadcast_to_shape.py +++ b/lib/iris/tests/unit/util/test_broadcast_to_shape.py @@ -74,6 +74,31 @@ def test_lazy_masked(self, mocked_compute): for j in range(4): self.assertMaskedArrayEqual(b[i, :, j, :].compute().T, m.compute()) + @mock.patch.object(dask.base, "compute", wraps=dask.base.compute) + def test_lazy_chunks(self, mocked_compute): + # chunks can be specified along with the target shape and are only used + # along new dimensions or on dimensions that have size 1 in the source + # array. + m = da.ma.masked_array( + data=[[1, 2, 3, 4, 5]], + mask=[[0, 1, 0, 0, 0]], + ).rechunk((1, 2)) + b = broadcast_to_shape( + m, + dim_map=(1, 2), + shape=(3, 4, 5), + chunks=( + 1, # used because target is new dim + 2, # used because input size 1 + 3, # not used because broadcast does not rechunk + ), + ) + mocked_compute.assert_not_called() + for i in range(3): + for j in range(4): + self.assertMaskedArrayEqual(b[i, j, :].compute(), m[0].compute()) + assert b.chunks == ((1, 1, 1), (2, 2), (2, 2, 1)) + def test_masked_degenerate(self): # masked arrays can have degenerate masks too a = np.random.random([2, 3]) diff --git a/lib/iris/util.py b/lib/iris/util.py index 9ae1ceb919..ba99c7a985 100644 --- a/lib/iris/util.py +++ b/lib/iris/util.py @@ -25,7 +25,7 @@ import iris.exceptions -def broadcast_to_shape(array, shape, dim_map): +def broadcast_to_shape(array, shape, dim_map, chunks=None): """Broadcast an array to a given shape. Each dimension of the array must correspond to a dimension in the @@ -46,6 +46,13 @@ def broadcast_to_shape(array, shape, dim_map): the index in *shape* which the dimension of *array* corresponds to, so the first element of *dim_map* gives the index of *shape* that corresponds to the first dimension of *array* etc. + chunks : :class:`tuple`, optional + If the source array is a :class:`dask.array.Array` and a value is + provided, then the result will use these chunks instead of the same + chunks as the source array. Setting chunks explicitly as part of + broadcast_to_shape is more efficient than rechunking afterwards. The + values provided here will only be used along dimensions that are new on + the result or have size 1 on the source array. Examples -------- @@ -68,13 +75,25 @@ def broadcast_to_shape(array, shape, dim_map): See more at :doc:`/userguide/real_and_lazy_data`. """ + if isinstance(array, da.Array): + if chunks is not None: + chunks = list(chunks) + for src_idx, tgt_idx in enumerate(dim_map): + # Only use the specified chunks along new dimensions or on + # dimensions that have size 1 in the source array. + if array.shape[src_idx] != 1: + chunks[tgt_idx] = array.chunks[src_idx] + broadcast = functools.partial(da.broadcast_to, shape=shape, chunks=chunks) + else: + broadcast = functools.partial(np.broadcast_to, shape=shape) + n_orig_dims = len(array.shape) n_new_dims = len(shape) - n_orig_dims array = array.reshape(array.shape + (1,) * n_new_dims) # Get dims in required order. array = np.moveaxis(array, range(n_orig_dims), dim_map) - new_array = np.broadcast_to(array, shape) + new_array = broadcast(array) if ma.isMA(array): # broadcast_to strips masks so we need to handle them explicitly. @@ -82,13 +101,13 @@ def broadcast_to_shape(array, shape, dim_map): if mask is ma.nomask: new_mask = ma.nomask else: - new_mask = np.broadcast_to(mask, shape) + new_mask = broadcast(mask) new_array = ma.array(new_array, mask=new_mask) elif is_lazy_masked_data(array): # broadcast_to strips masks so we need to handle them explicitly. mask = da.ma.getmaskarray(array) - new_mask = da.broadcast_to(mask, shape) + new_mask = broadcast(mask) new_array = da.ma.masked_array(new_array, new_mask) return new_array