Skip to content

Commit

Permalink
Rechunk input arrays before deriving
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela committed Jan 24, 2024
1 parent adca725 commit 963eceb
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lib/iris/aux_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,33 @@
import dask.array as da
import numpy as np

from iris._lazy_data import _optimum_chunksize
from iris.common import CFVariableMixin, CoordMetadata, metadata_manager_factory
import iris.coords
from iris.exceptions import IrisIgnoringBoundsWarning


def rechunk_args(func):
def wrapped(*args):
data = func(*args)
chunks = _optimum_chunksize(
data.chunksize,
shape=data.shape,
dtype=data.dtype,
)
rechunked_args = []
for arg in args:
if isinstance(arg, da.Array):
new_chunks = [
1 if arg.shape[i] == 1 else chunk for i, chunk in enumerate(chunks)
]
arg = arg.rechunk(new_chunks)
rechunked_args.append(arg)
return func(*rechunked_args)

return wrapped


class AuxCoordFactory(CFVariableMixin, metaclass=ABCMeta):
"""Represents a "factory" which can manufacture additional auxiliary coordinate.
Expand Down Expand Up @@ -813,6 +835,7 @@ def dependencies(self):
"surface_air_pressure": self.surface_air_pressure,
}

@rechunk_args
def _derive(self, delta, sigma, surface_air_pressure):
return delta + sigma * surface_air_pressure

Expand Down

0 comments on commit 963eceb

Please sign in to comment.