From 963eceb684ceb02baff3edf9050442c7be6fb2b8 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Wed, 24 Jan 2024 16:26:11 +0100 Subject: [PATCH] Rechunk input arrays before deriving --- lib/iris/aux_factory.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/lib/iris/aux_factory.py b/lib/iris/aux_factory.py index f447537b7d..e3e24b4a90 100644 --- a/lib/iris/aux_factory.py +++ b/lib/iris/aux_factory.py @@ -11,11 +11,33 @@ import dask.array as da import numpy as np +from iris._lazy_data import _optimum_chunksize from iris.common import CFVariableMixin, CoordMetadata, metadata_manager_factory import iris.coords from iris.exceptions import IrisIgnoringBoundsWarning +def rechunk_args(func): + def wrapped(*args): + data = func(*args) + chunks = _optimum_chunksize( + data.chunksize, + shape=data.shape, + dtype=data.dtype, + ) + rechunked_args = [] + for arg in args: + if isinstance(arg, da.Array): + new_chunks = [ + 1 if arg.shape[i] == 1 else chunk for i, chunk in enumerate(chunks) + ] + arg = arg.rechunk(new_chunks) + rechunked_args.append(arg) + return func(*rechunked_args) + + return wrapped + + class AuxCoordFactory(CFVariableMixin, metaclass=ABCMeta): """Represents a "factory" which can manufacture additional auxiliary coordinate. @@ -813,6 +835,7 @@ def dependencies(self): "surface_air_pressure": self.surface_air_pressure, } + @rechunk_args def _derive(self, delta, sigma, surface_air_pressure): return delta + sigma * surface_air_pressure