diff --git a/.github/workflows/ReceivePR.yml b/.github/workflows/ReceivePR.yml index 89d058ad31..7749056b8a 100644 --- a/.github/workflows/ReceivePR.yml +++ b/.github/workflows/ReceivePR.yml @@ -27,6 +27,7 @@ jobs: - name: Test run: | pip install .[dev] + pip install xarray pre-commit run --all-files python -m unittest diff --git a/.github/workflows/benchmark_main.yml b/.github/workflows/benchmark_main.yml new file mode 100644 index 0000000000..c226d36086 --- /dev/null +++ b/.github/workflows/benchmark_main.yml @@ -0,0 +1,46 @@ +name: Benchmark main and save +on: + push: + branches: + - main + +jobs: + benchmark-main: + name: Benchmark main and save + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup MPI + uses: mpi4py/setup-mpi@v1 + - name: Use Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: 3.10.11 # Perun only supports 3.8 and ahead + architecture: x64 + - name: Test + run: | + pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 -f https://download.pytorch.org/whl/torch_stable.html + pip install xarray + pip install .[cb] + PERUN_RUN_ID=N4 mpirun -n 4 python benchmarks/cb/main.py + jq -s flatten bench_data/*.json > bench_data/all_benchmarks.json + - name: Save benchmark result and update gh-pages-chart + if: ${{github.ref == 'refs/heads/main'}} + uses: benchmark-action/github-action-benchmark@v1 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + # Benchmark action input and output + tool: 'customSmallerIsBetter' + output-file-path: bench_data/all_benchmarks.json + # external-data-json-path: ./cache/benchmark-data.json + # Alert configuration + fail-on-alert: false # Don't fail on main branch + comment-on-alert: true + # Save benchmarks from the main branch + save-data-file: true + # Pages configuration + auto-push: true + gh-pages-branch: gh-pages + benchmark-data-dir-path: dev/bench + # Upload the updated cache file for the next job by actions/cache diff --git 
a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml new file mode 100644 index 0000000000..db2271982f --- /dev/null +++ b/.github/workflows/benchmark_pr.yml @@ -0,0 +1,46 @@ +name: Benchmark PR +on: + pull_request: + types: [opened, synchronize, reopened, labeled] + branches: [main] + +jobs: + benchmark-pr: + name: Benchmark PR + if: contains(github.event.pull_request.labels.*.name, 'benchmark PR') + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup MPI + uses: mpi4py/setup-mpi@v1 + - name: Use Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: 3.10.11 # Perun only supports 3.8 and ahead + architecture: x64 + - name: Test + run: | + pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 -f https://download.pytorch.org/whl/torch_stable.html + pip install xarray + pip install .[cb] + PERUN_RUN_ID=N4 mpirun -n 4 python benchmarks/cb/main.py + jq -s flatten bench_data/*.json > bench_data/all_benchmarks.json + - name: Compare benchmark result + if: ${{github.ref != 'refs/heads/main'}} + uses: benchmark-action/github-action-benchmark@v1 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + # Benchmark action input and output + tool: 'customSmallerIsBetter' + output-file-path: bench_data/all_benchmarks.json + # external-data-json-path: ./cache/benchmark-data.json + # Alert configuration + fail-on-alert: true + comment-on-alert: true + # Ignore results from non main branches. 
def from_numpy(
    x: np.ndarray,
    split: Optional[int] = None,
    device: Optional[Union[str, devices.Device]] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Creates a :class:`DNDarray` from a given NumPy array. The heat data type is derived from the data
    type of the NumPy array. Split dimension, device and communicator can be prescribed as usual.
    Inverse of :meth:`DNDarray.numpy()`.

    Parameters
    ----------
    x : np.ndarray
        The array to convert. It is first wrapped as a non-distributed array (``split=None``) of
        global shape ``x.shape``, so every process is expected to pass the complete data.
    split : int, optional
        Dimension along which the result is distributed afterwards. ``None`` (default) keeps the
        result non-distributed.
    device : str or Device, optional
        Device on which the result is placed; passed through ``devices.sanitize_device``.
    comm : Communication, optional
        Communicator of the result; passed through ``sanitize_comm``.

    Returns
    -------
    DNDarray
        The converted array.

    Raises
    ------
    TypeError
        If ``x`` is not a NumPy array.

    Notes
    -----
    ``torch.from_numpy`` wraps the memory of ``x`` instead of copying it.
    NOTE(review): whether the resulting DNDarray still aliases ``x`` afterwards depends on
    ``Tensor.to`` and on the DNDarray constructor - confirm before relying on either behaviour.
    """
    if not isinstance(x, np.ndarray):
        raise TypeError("Input `x` must be a numpy.ndarray, but is {}.".format(type(x)))

    # torch.from_numpy cannot wrap views with negative strides (e.g. ``a[::-1]``), and writing through
    # a tensor that wraps a read-only array is undefined; in both cases work on a private copy
    if not x.flags.writeable or any(stride < 0 for stride in x.strides):
        x = x.copy()

    dtype = types.canonical_heat_type(x.dtype)
    device = devices.sanitize_device(device)
    comm = sanitize_comm(comm)
    xht = DNDarray(
        torch.from_numpy(x).to(device.torch_device),
        x.shape,
        dtype=dtype,
        split=None,
        device=device,
        comm=comm,
        balanced=True,
    )
    # the array is complete on every process at this point; distribute it only on request
    if split is not None:
        xht.resplit_(split)
    return xht
def dim_name_to_idx(
    dims: list, names: Union[str, tuple, list, None]
) -> Union[int, tuple, list, None]:
    """
    Converts a dimension name (or a tuple/list of names) into the position (or tuple/list of
    positions) of that dimension within ``dims``. Inverse of :func:`dim_idx_to_name`.

    Parameters
    ----------
    dims : list
        Names of all dimensions; the position of a name in this list is its index.
    names : str, tuple, list or None
        Name(s) to look up. ``None`` is passed through unchanged.

    Returns
    -------
    int, tuple, list or None
        Index/indices in the same kind of container as ``names``.

    Raises
    ------
    TypeError
        If ``names`` is neither None, a string, a tuple nor a list.
    ValueError
        If a name is not contained in ``dims`` (raised by ``list.index``).
    """
    if names is None:
        return None
    if isinstance(names, str):
        return dims.index(names)
    if isinstance(names, tuple):
        return tuple(dims.index(name) for name in names)
    if isinstance(names, list):
        return [dims.index(name) for name in names]
    raise TypeError("Input names must be None, string, list of strings, or tuple of strings.")


def dim_idx_to_name(
    dims: list, idxs: Union[int, tuple, list, None]
) -> Union[str, tuple, list, None]:
    """
    Converts a numeric index (or a tuple/list of indices) into the name (or tuple/list of names) of
    the corresponding dimension stored in ``dims``. Inverse of :func:`dim_name_to_idx`.

    Parameters
    ----------
    dims : list
        Names of all dimensions.
    idxs : int, tuple, list or None
        Index/indices to look up. ``None`` is passed through unchanged.

    Returns
    -------
    str, tuple, list or None
        Name(s) in the same kind of container as ``idxs``.

    Raises
    ------
    TypeError
        If ``idxs`` is neither None, an int, a tuple nor a list.
    IndexError
        If an index is out of range for ``dims``.
    """
    if idxs is None:
        return None
    if isinstance(idxs, int):
        return dims[idxs]
    if isinstance(idxs, tuple):
        return tuple(dims[idx] for idx in idxs)
    if isinstance(idxs, list):
        return [dims[idx] for idx in idxs]
    raise TypeError("Input idxs must be None, int, list of ints, or tuple of ints.")


class DXarray:
    """
    Distributed counterpart of :class:`xarray.DataArray`, built on top of Heat's DNDarray.

    Parameters
    ----------
    values : DNDarray
        Data entries of the DXarray.
    dims : list, optional
        Names of the dimensions, one per dimension of ``values``. If None, the generic names
        ``"dim_0", "dim_1", ...`` are used.
    coords : dict, optional
        Coordinates. A key is either the name of a single dimension, whose value must be a
        one-dimensional DNDarray of coordinate labels ("logical coordinates"), or a tuple of
        dimension names, whose value must be a DXarray with exactly these dimensions
        ("physical coordinates"). A coordinate array must be split if and only if it refers to the
        split dimension of ``values``.
    name : str, optional
        Name of the DXarray.
    attrs : dict, optional
        Free-form attributes. ``None`` (default) results in an empty dictionary.

    Notes
    -----
    The split dimension of a DXarray (attribute ``split``) is the *name* of the dimension along
    which ``values`` is split, or None.

    Some attributes of DNDarray are not included in DXarray, e.g., gshape, lshape, larray etc., and
    need to be accessed by DXarray.values.gshape etc. This avoids confusion, because a DXarray is
    built of possibly several DNDarrays and it would be unclear to which of them a global attribute
    DXarray.gshape refers.

    Currently, it is checked whether values and coords are on the same `device`; in principle, this
    is unnecessary.
    """

    # annotations naming heat/xarray types are written as strings so that defining this class does
    # not need to resolve them while the package is still being imported

    def __init__(
        self,
        values: "ht.DNDarray",
        dims: Union[list, None] = None,
        coords: Union[dict, None] = None,
        name: Union[str, None] = None,
        attrs: Union[dict, None] = None,
    ):
        """
        Constructor for DXarray class
        """
        # generic names must exist *before* the validation, because the coordinate checks resolve
        # dimension names against `dims`; a wrong type of `values` is reported by the validation
        if dims is None and isinstance(values, ht.DNDarray):
            dims = ["dim_%d" % k for k in range(values.ndim)]

        dxarray_sanitation.check_compatibility_values_dims_coords(values, dims, coords)
        dxarray_sanitation.check_name(name)
        dxarray_sanitation.check_attrs(attrs)

        # directly given attributes ...
        self.__values = values
        self.__dims = dims
        self.__coords = coords
        self.__name = name
        # a fresh dict per instance: a shared default dict would leak attributes between instances
        self.__attrs = {} if attrs is None else attrs

        # ... and those derived from them
        self.__refresh_derived_attributes()

    def __refresh_derived_attributes(self):
        """
        Recomputes every attribute that is derived from ``values``, ``dims`` and ``coords``:
        device, communicator, split name, dims with/without coordinates and the balanced-flag.
        """
        self.__device = self.__values.device
        self.__comm = self.__values.comm
        # inside DXarray the split dimension is referred to by its name
        self.__split = dim_idx_to_name(self.__dims, self.__values.split)

        dims_with_coords = []
        if self.__coords is not None:
            for key in self.__coords:
                # tuple keys ("physical coordinates") cover several dimensions at once
                dims_with_coords.extend(key if isinstance(key, tuple) else [key])
        self.__dims_with_coords = dims_with_coords
        self.__dims_without_coords = [dim for dim in self.__dims if dim not in dims_with_coords]

        # the DXarray is balanced if and only if all underlying arrays are balanced
        self.__balanced = dxarray_sanitation.check_if_balanced(
            self.__values, self.__coords, force_check=False
        )

    """
    Attribute getters and setters for the DXarray class
    """

    @property
    def values(self) -> "ht.DNDarray":
        """
        Get values from DXarray
        """
        return self.__values

    @values.setter
    def values(self, newvalues: "ht.DNDarray"):
        """
        Set value array of DXarray. The new array must be compatible with the existing dims and
        coords; split, device, communicator and balanced-flag are updated accordingly.
        """
        dxarray_sanitation.check_compatibility_values_dims_coords(
            newvalues, self.__dims, self.__coords
        )
        self.__values = newvalues
        self.__refresh_derived_attributes()

    @property
    def dims(self) -> list:
        """
        Get dims from DXarray
        """
        return self.__dims

    @property
    def coords(self) -> Union[dict, None]:
        """
        Get coords from DXarray
        """
        return self.__coords

    @coords.setter
    def coords(self, newcoords: Union[dict, None]):
        """
        Set coordinates of DXarray. The new coordinates must be compatible with values and dims.
        """
        dxarray_sanitation.check_compatibility_values_dims_coords(
            self.__values, self.__dims, newcoords
        )
        self.__coords = newcoords
        self.__refresh_derived_attributes()

    # the setter used to be reachable only under this misspelled name; kept as an alias
    coors = coords

    @property
    def split(self) -> Union[str, None]:
        """
        Get split dimension (as dimension name) from DXarray
        """
        return self.__split

    @property
    def device(self) -> "ht.Device":
        """
        Get device from DXarray
        """
        return self.__device

    @property
    def comm(self) -> "ht.Communication":
        """
        Get communicator from DXarray
        """
        return self.__comm

    @property
    def name(self) -> Union[str, None]:
        """
        Get name from DXarray
        """
        return self.__name

    @name.setter
    def name(self, newname: Union[str, None]):
        """
        Set name of DXarray
        """
        dxarray_sanitation.check_name(newname)
        self.__name = newname

    @property
    def attrs(self) -> dict:
        """
        Get attributes from DXarray
        """
        return self.__attrs

    @attrs.setter
    def attrs(self, newattrs: Union[dict, None]):
        """
        Set attributes of DXarray; None results in an empty dictionary.
        """
        dxarray_sanitation.check_attrs(newattrs)
        self.__attrs = {} if newattrs is None else newattrs

    @property
    def dims_with_coords(self) -> list:
        """
        Get list of dims with coordinates from DXarray
        """
        return self.__dims_with_coords

    @property
    def dims_without_coords(self) -> list:
        """
        Get list of dims without coordinates from DXarray
        """
        return self.__dims_without_coords

    @property
    def balanced(self) -> Union[bool, None]:
        """
        Get the attribute `balanced` of DXarray (None means "unknown").
        Does not check whether the current value of this attribute is consistent!
        (This can be ensured by calling :meth:`DXarray.is_balanced(force_check=True)` first.)
        """
        return self.__balanced

    """
    Private methods of DXarray class
    """

    def __dim_name_to_idx(
        self, names: Union[str, tuple, list, None]
    ) -> Union[int, tuple, list, None]:
        """
        Converts a name (or tuple/list of names) of dimensions of the DXarray to the corresponding
        numeric index (tuple/list of indices). Inverse of :meth:`__dim_idx_to_name`.
        """
        return dim_name_to_idx(self.__dims, names)

    def __dim_idx_to_name(
        self, idxs: Union[int, tuple, list, None]
    ) -> Union[str, tuple, list, None]:
        """
        Converts a numeric index (or tuple/list of indices) of dimensions of the DXarray to the
        corresponding name (tuple/list of names). Inverse of :meth:`__dim_name_to_idx`.
        """
        return dim_idx_to_name(self.__dims, idxs)

    def __repr__(self) -> str:
        """
        Representation of DXarray as string. Required for printing.
        Only the process with rank 0 returns text; all other processes return an empty string.
        """
        shown_name = "" if self.__name is None else self.__name

        # all pieces are built on every process, only the final text is restricted to rank 0
        # NOTE(review): presumably repr() of a distributed DNDarray communicates, so every process
        # has to take part in it - confirm before moving these calls behind the rank check
        values_text = repr(self.__values)
        dims_text = ", ".join(str(dim) for dim in self.__dims)
        split_text = "None (not split)" if self.__split is None else str(self.__split)

        if self.__coords is None:
            coords_text = ""
        else:
            coords_text = 'Coordinates of "' + shown_name + '": '
            coords_text += "\n".join(
                repr(key) + ": \t" + repr(coord) for key, coord in self.__coords.items()
            )

        attrs_text = "\n".join(
            "\t" + repr(key) + ": \t" + repr(value) for key, value in self.__attrs.items()
        )

        if self.__dims_without_coords:
            no_coords_text = "".join(
                [
                    'The remaining dimensions of "',
                    shown_name,
                    '", ',
                    ", ".join(str(dim) for dim in self.__dims_without_coords),
                    ", do not have coordinates. \n",
                ]
            )
        else:
            no_coords_text = ""

        if self.__comm.rank != 0:
            return ""
        return "".join(
            [
                'DXarray with name "' + shown_name + '"\n',
                'Dimensions of "' + shown_name + '": ' + dims_text + "\n",
                'Split dimension of "' + shown_name + '": ' + split_text + "\n",
                'Values of "' + shown_name + '": ' + values_text + "\n",
                coords_text + "\n",
                no_coords_text,
                'Attributes of "' + shown_name + '":' + attrs_text,
                "\n\n",
            ]
        )

    """
    Public Methods of DXarray
    """

    def is_balanced(self, force_check: bool = False) -> Union[bool, None]:
        """
        Checks if DXarray is balanced. If `force_check = False` (default), the current value of the
        attribute `balanced` is returned unless this value is None (i.e. no information is
        available); only in the latter case, or if `force_check = True`, the attribute `balanced`
        is re-determined from the underlying arrays before being returned.
        """
        if self.__balanced is None or force_check:
            self.__balanced = dxarray_sanitation.check_if_balanced(
                self.__values, self.__coords, force_check=True
            )
        return self.__balanced

    def resplit_(self, dim: Union[str, None] = None):
        """
        In-place option for resplitting a :class:`DXarray`.

        Parameters
        ----------
        dim : str, optional
            Name of the new split dimension; None gathers the DXarray on every process.

        Returns
        -------
        DXarray
            The DXarray itself.

        Raises
        ------
        ValueError
            If ``dim`` is neither None nor a dimension of the DXarray.
        """
        if dim is not None and dim not in self.__dims:
            raise ValueError(
                "Input `dim` in resplit_ must be either None or a dimension of the underlying DXarray."
            )
        # early out if nothing is to do
        if self.__split == dim:
            return self

        self.__values.resplit_(self.__dim_name_to_idx(dim))
        if self.__coords is not None:
            for coord_dims, coord in self.__coords.items():
                # a coordinate array has to be split if and only if it refers to the split
                # dimension; hence coordinates of the *previous* split dimension are gathered
                if isinstance(coord_dims, tuple):
                    coord.resplit_(dim if dim in coord_dims else None)
                elif isinstance(coord_dims, str):
                    coord.resplit_(0 if coord_dims == dim else None)
        self.__split = dim
        # pick up the balanced-flags of the redistributed arrays (no forced re-check)
        self.__balanced = dxarray_sanitation.check_if_balanced(
            self.__values, self.__coords, force_check=False
        )
        return self

    def balance_(self):
        """
        In-place option for balancing a :class:`DXarray`. Returns the DXarray itself.
        """
        if self.is_balanced(force_check=True):
            return self
        self.__values.balance_()
        if self.__coords is not None:
            for coord in self.__coords.values():
                coord.balance_()
        self.__balanced = True
        return self

    def xarray(self) -> "DataArray":
        """
        Convert given DXarray (possibly distributed over some processes) to a non-distributed xarray
        (:class:`xarray.DataArray`) on all processes. The distribution of the DXarray itself is left
        untouched: gathering is done on out-of-place copies.
        """
        # NOTE(review): `.cpu()` is applied to the result of `ht.resplit`; if `ht.resplit` hands
        # back its input for already non-distributed arrays, `.cpu()` acts on the original - confirm
        if self.__coords is None:
            xarray_coords = None
        else:
            xarray_coords = {}
            for key, coord in self.__coords.items():
                if isinstance(coord, ht.DNDarray):
                    xarray_coords[key] = ht.resplit(coord, None).cpu().numpy()
                else:
                    xarray_coords[key] = coord.xarray()
        return DataArray(
            ht.resplit(self.__values, None).cpu().numpy(),
            dims=self.__dims,
            coords=xarray_coords,
            name=self.__name,
            attrs=self.__attrs,
        )


def from_xarray(
    xarray: "xr.DataArray",
    split: Union[str, None] = None,
    device: "ht.Device" = None,
    comm: "ht.Communication" = None,
) -> DXarray:
    """
    Generates a DXarray from a given xarray (:class:`xarray.DataArray`).

    Parameters
    ----------
    xarray : xarray.DataArray
        The array to convert.
    split : str, optional
        Name of the dimension along which the result is split; None (default) yields a
        non-distributed DXarray.
    device : Device, optional
        Device of the result, forwarded to ``ht.from_numpy``.
    comm : Communication, optional
        Communicator of the result, forwarded to ``ht.from_numpy``.

    Raises
    ------
    ValueError
        If ``split`` is not a dimension of ``xarray``.

    Notes
    -----
    A DXarray holds at most one coordinate array per dimension (or per tuple of dimensions) and no
    scalar coordinates. Coordinates of ``xarray`` that do not fit into this scheme are dropped with
    a warning; for a single dimension, the coordinate named like the dimension takes precedence.
    """
    import warnings

    dims = list(xarray.dims)
    if split is not None and split not in dims:
        raise ValueError(
            'split dimension "{}" is not a dimension of input array.'.format(split)
        )

    coords_dict = {}
    kept_names = {}
    dropped = []
    for coord_name, coord in xarray.coords.items():
        # the kind of coordinate is decided by the dimensions it spans, not by its name
        coord_dims = tuple(coord.dims)
        if not coord_dims:
            dropped.append(coord_name)
            continue
        is_single = len(coord_dims) == 1
        key = coord_dims[0] if is_single else coord_dims
        if key in coords_dict:
            if not (is_single and coord_name == key):
                dropped.append(coord_name)
                continue
            # the coordinate named like its dimension replaces an auxiliary one stored earlier
            dropped.append(kept_names[key])

        coord_values = ht.from_numpy(coord.values, device=device, comm=comm)
        if is_single:
            coords_dict[key] = coord_values
        else:
            coords_dict[key] = DXarray(
                coord_values,
                dims=list(coord_dims),
                coords=None,
                name=str(coord_name),
                attrs=coord.attrs,
            )
        kept_names[key] = coord_name

    if dropped:
        warnings.warn(
            "The following coordinates cannot be represented in a DXarray and are dropped: {}".format(
                ", ".join(repr(entry) for entry in dropped)
            )
        )

    dxarray = DXarray(
        ht.from_numpy(xarray.values, device=device, comm=comm),
        dims=dims,
        coords=coords_dict,
        name=xarray.name,
        attrs=xarray.attrs,
    )
    if split is not None:
        dxarray.resplit_(split)
    return dxarray
def _dim_position(dims: list, name: Any) -> int:
    """Returns the position of ``name`` in ``dims``; raises a descriptive ValueError if it is missing."""
    try:
        return dims.index(name)
    except ValueError:
        raise ValueError(
            "`coords` refers to dimension {!r}, which is not contained in `dims`.".format(name)
        ) from None


def _check_logical_coords(values, dims: list, split_dim, dim: str, coord_array):
    """Validates the coordinate array of a single dimension ("logical coordinates")."""
    # the coordinates must be given by a one-dimensional DNDarray ...
    if not isinstance(coord_array, ht.DNDarray):
        raise TypeError(
            "Coordinate arrays (i.e. entries of `coords`) for single dimension must be DNDarray. "
            "Here, type {} is given for dimension {!r}.".format(type(coord_array), dim)
        )
    if not coord_array.ndim == 1:
        raise ValueError(
            "Coordinate arrays for a single dimension must have dimension 1, but coordinate array "
            "for dimension {!r} has dimension {}.".format(dim, coord_array.ndim)
        )
    # ... with matching device and communicator, ...
    if not coord_array.device == values.device:
        raise RuntimeError(
            "Device of coordinate array for dimension {!r} does not coincide with device "
            "for `values`.".format(dim)
        )
    if not coord_array.comm == values.comm:
        raise RuntimeError(
            "Communicator of coordinate array for dimension {!r} does not coincide with "
            "communicator for `values`.".format(dim)
        )
    # ... correct size, and ...
    if not coord_array.gshape[0] == values.gshape[_dim_position(dims, dim)]:
        raise ValueError(
            "Size of `values` in dimension {!r} does not coincide with size of coordinate array "
            "in this dimension.".format(dim)
        )
    # ... that is split if and only if the coordinates refer to the split dimension
    if dim == split_dim:
        if coord_array.split != 0:
            raise ValueError(
                "`values` array is split along dimension {!r}, but corresponding coordinate array "
                "is not split along this dimension.".format(dim)
            )
    elif coord_array.split is not None:
        raise ValueError(
            "`values` array is not split along dimension {!r}, but corresponding coordinate array "
            "is split along this dimension.".format(dim)
        )


def _check_physical_coords(values, dims: list, split_dim, coord_dims: tuple, coord_array):
    """Validates a coordinate array that covers several dimensions at once ("physical coordinates")."""
    # the coordinates must be given as a DXarray ...
    if not isinstance(coord_array, DXarray):
        raise TypeError(
            "Coordinate arrays (i.e. entries of `coords`) must be DXarrays. "
            "Here, type {} is given for dimensions {!r}.".format(type(coord_array), coord_dims)
        )
    # ... with matching dimension names, ...
    if coord_array.dims != list(coord_dims):
        raise ValueError(
            "Dimension names of coordinate-DXarray and the corresponding dimension names in "
            "`coords` must be equal."
        )
    # ... shape, ...
    expected_shape = tuple(values.gshape[_dim_position(dims, dim)] for dim in coord_dims)
    if tuple(coord_array.values.gshape) != expected_shape:
        raise ValueError(
            "Size of `values` in dimensions {!r} does not coincide with size of coordinate array "
            "in these dimensions.".format(coord_dims)
        )
    # ... device and communicator, ...
    if not coord_array.device == values.device:
        raise RuntimeError(
            "Device of coordinate array for dimensions {!r} does not coincide with device "
            "for `values`.".format(coord_dims)
        )
    if not coord_array.comm == values.comm:
        raise RuntimeError(
            "Communicator of coordinate array for dimensions {!r} does not coincide with "
            "communicator for `values`.".format(coord_dims)
        )
    # ... and split dimension.
    if split_dim in coord_dims:
        if not coord_array.split == split_dim:
            raise ValueError(
                "`values` array is split along dimension {!r}, but corresponding coordinate array "
                "for dimensions {!r} is split along {!r}.".format(
                    split_dim, coord_dims, coord_array.split
                )
            )
    elif coord_array.split is not None:
        raise ValueError(
            "`values` array is not split along any of the dimensions {!r}, but corresponding "
            "coordinate array is split.".format(coord_dims)
        )


def check_compatibility_values_dims_coords(
    values: "ht.DNDarray", dims: Union[list, None], coords: Union[dict, None]
):
    """
    Checks whether input values, dims, and coords are valid and compatible inputs for a DXarray.

    Parameters
    ----------
    values : DNDarray
        Value array of the DXarray.
    dims : list or None
        Dimension names. If None, the generic names ``"dim_0", "dim_1", ...`` (the ones the
        DXarray constructor assigns in this case) are used to resolve the keys of ``coords``.
    coords : dict or None
        Coordinates; keys are dimension names (value: one-dimensional DNDarray) or tuples of
        dimension names (value: DXarray with exactly these dimensions).

    Raises
    ------
    TypeError
        If an input, a key of ``coords`` or a coordinate array has the wrong type.
    ValueError
        If names are not unique, numbers of dimensions, sizes or split dimensions do not match, or a
        key of ``coords`` refers to an unknown dimension.
    RuntimeError
        If device or communicator of a coordinate array differ from those of ``values``.
    """
    if not isinstance(values, ht.DNDarray):
        raise TypeError("Input `values` must be a DNDarray, but is {}.".format(type(values)))
    if not (isinstance(dims, list) or dims is None):
        raise TypeError("Input `dims` must be a list or None, but is {}.".format(type(dims)))
    if not (isinstance(coords, dict) or coords is None):
        raise TypeError(
            "Input `coords` must be a dictionary or None, but is {}.".format(type(coords))
        )

    if dims is None:
        # without explicit names, coordinates can only refer to the generic names
        dims = ["dim_%d" % k for k in range(values.ndim)]
    else:
        if len(set(dims)) != len(dims):
            raise ValueError("Entries of `dims` must be unique.")
        if len(dims) != values.ndim:
            raise ValueError(
                "Number of dimension names in `dims` (=%d) must be equal to number of dimensions of `values` array (=%d)."
                % (len(dims), values.ndim)
            )

    if coords is None:
        return

    split_dim = dim_idx_to_name(dims, values.split)
    for coord_dims, coord_array in coords.items():
        if isinstance(coord_dims, str):
            _check_logical_coords(values, dims, split_dim, coord_dims, coord_array)
        elif isinstance(coord_dims, tuple):
            _check_physical_coords(values, dims, split_dim, coord_dims, coord_array)
        else:
            # anything else would bypass all checks above
            raise TypeError(
                "Keys of `coords` must be dimension names (str) or tuples of dimension names, "
                "but {!r} is of type {}.".format(coord_dims, type(coord_dims))
            )


def check_name(name: Any):
    """
    Checks whether input is appropriate for attribute `name` of `DXarray`, i.e. a string or None.
    Raises a TypeError otherwise.
    """
    if not (isinstance(name, str) or name is None):
        raise TypeError("`name` must be a string or None, but is {}.".format(type(name)))


def check_attrs(attrs: Any):
    """
    Checks whether input is appropriate for attribute `attrs` of `DXarray`, i.e. a dictionary or
    None. Raises a TypeError otherwise.
    """
    if not (isinstance(attrs, dict) or attrs is None):
        raise TypeError("`attrs` must be a dictionary or None, but is {}.".format(type(attrs)))


def check_if_balanced(
    values: "ht.DNDarray", coords: Union[dict, None], force_check: bool = False
) -> Union[bool, None]:
    """
    Checks if a DXarray with values and coords is balanced, i.e., equally distributed on each process.
    A DXarray is balanced if and only if all underlying arrays (values and coordinates) are balanced.

    Parameters
    ----------
    values : DNDarray
        Value array.
    coords : dict or None
        Coordinate arrays; each entry must offer ``balanced`` and ``is_balanced(force_check=...)``.
    force_check : bool
        If False (default), only the stored ``balanced`` flags are evaluated. If True, the arrays
        are asked via ``is_balanced(force_check=True)``.

    Returns
    -------
    bool or None
        False as soon as one array is known to be unbalanced, None if no array is known to be
        unbalanced but the status of at least one is unknown, True otherwise.
    """
    if force_check:
        values_balanced = values.is_balanced(force_check=True)
        if values_balanced is False or coords is None:
            return values_balanced
        # ask every coordinate array (no short-circuit), as before
        coords_balanced = [coord.is_balanced(force_check=True) for coord in coords.values()]
        return values_balanced and all(coords_balanced)

    flags = [values.balanced]
    if coords is not None:
        flags.extend(coord.balanced for coord in coords.values())
    # one array known to be unbalanced settles the question, whatever is unknown about the others
    if any(flag is not None and not flag for flag in flags):
        return False
    if any(flag is None for flag in flags):
        return None
    return True
class TestHelpers(TestCase):
    """Tests for the conversion helpers between dimension names and dimension indices."""

    def test_dim_name_idx_conversion(self):
        dims = ["x", "y", "z-axis", "time", None]

        for names in ("z-axis", ("time", "x"), ["x", "y"]):
            idxs = ht.dxarray.dim_name_to_idx(dims, names)

            # the container type is kept (tuple -> tuple, list -> list); a single name gives an int
            same_type = type(idxs) is type(names)
            single_name_to_int = isinstance(names, str) and isinstance(idxs, int)
            self.assertTrue(same_type or single_name_to_int)

            # converting back must give the names we started with
            self.assertEqual(ht.dxarray.dim_idx_to_name(dims, idxs), names)

        # unsupported input types are rejected by both directions
        unsupported = 3.14
        for convert in (ht.dxarray.dim_name_to_idx, ht.dxarray.dim_idx_to_name):
            with self.assertRaises(TypeError):
                convert(dims, unsupported)
class TestDXarray(TestCase):
    """Tests for construction, validation, balancing and resplitting of ``ht.dxarray.DXarray``."""

    def _make_inputs(self, m, n, k, ell):
        """
        Builds a consistent set of constructor inputs: values of shape ``(m, n, k, ell)`` split
        along "y", physical coordinates for ("x", "y"), logical coordinates for "time" and no
        coordinates for "no_measurements". Returned as a dictionary.
        """
        xy = ht.random.rand(m, n, split=1)
        t = ht.linspace(-1, 1, k, split=None)
        attrs_xy = {"units_xy": "meters"}
        xy_coords = ht.dxarray.DXarray(
            xy, dims=["x", "y"], attrs=attrs_xy, name="coordinates of space"
        )
        data = ht.random.randn(m, n, k, ell, split=1)
        return {
            "xy": xy,
            "t": t,
            "attrs_xy": attrs_xy,
            "xy_coords": xy_coords,
            "data": data,
            "name": "mytestarray",
            "attrs": {
                "units time": "seconds",
                "measured data": "something really random and meaningless",
            },
            "dims": ["x", "y", "time", "no_measurements"],
            "coords": {("x", "y"): xy_coords, "time": t},
        }

    def test_constructor_and_attributes(self):
        m = 2
        n = 3 * nprocs
        k = 10
        ell = 2

        # a case that should work: a dimension ("no_measurements") without coordinates and two
        # dimensions ("x", "y") with physical instead of logical coordinates
        inputs = self._make_inputs(m, n, k, ell)
        data = inputs["data"]
        dims = inputs["dims"]
        coords = inputs["coords"]
        name = inputs["name"]
        attrs = inputs["attrs"]

        dxarray = ht.dxarray.DXarray(data, dims=dims, coords=coords, name=name, attrs=attrs)

        self.assertEqual(dxarray.name, name)
        self.assertEqual(dxarray.attrs, attrs)
        self.assertEqual(dxarray.dims, dims)
        self.assertEqual(dxarray.coords, coords)
        self.assertTrue(ht.allclose(dxarray.values, data))
        self.assertEqual(dxarray.device, data.device)
        self.assertEqual(dxarray.comm, data.comm)
        self.assertEqual(dxarray.dims_with_coords, ["x", "y", "time"])
        self.assertEqual(dxarray.dims_without_coords, ["no_measurements"])
        self.assertEqual(dxarray.split, "y")
        self.assertEqual(dxarray.balanced, True)

        # test print
        print(dxarray)

        # special case: dim names have to be set automatically and there are no coords at all
        dxarray = ht.dxarray.DXarray(data)
        dims = ["dim_0", "dim_1", "dim_2", "dim_3"]
        self.assertEqual(dxarray.dims, dims)
        self.assertEqual(dxarray.dims_with_coords, [])
        self.assertEqual(dxarray.dims_without_coords, dims)
        self.assertEqual(dxarray.split, "dim_1")
        self.assertEqual(dxarray.balanced, True)

        # test print
        print(dxarray)

    def test_sanity_checks(self):
        m = 2
        n = 3 * nprocs
        k = 5 * nprocs
        ell = 2

        # here comes the "correct" data
        inputs = self._make_inputs(m, n, k, ell)
        xy = inputs["xy"]
        t = inputs["t"]
        attrs_xy = inputs["attrs_xy"]
        xy_coords = inputs["xy_coords"]
        data = inputs["data"]

        def construct(values=data, **overrides):
            # constructor call with the correct inputs, except for the overridden ones
            kwargs = {
                "dims": inputs["dims"],
                "coords": inputs["coords"],
                "name": inputs["name"],
                "attrs": inputs["attrs"],
            }
            kwargs.update(overrides)
            return ht.dxarray.DXarray(values, **kwargs)

        # wrong data type for name
        with self.assertRaises(TypeError):
            construct(name=3.14)

        # wrong data type for attrs
        with self.assertRaises(TypeError):
            construct(attrs=3.14)

        # wrong data type for value
        with self.assertRaises(TypeError):
            construct(3.14)

        # wrong data type for dims
        with self.assertRaises(TypeError):
            construct(dims=3.14)

        # wrong data type for coords
        with self.assertRaises(TypeError):
            construct(coords=3.14)

        # length of dims and number of dimensions of value array do not match
        with self.assertRaises(ValueError):
            construct(dims=["x", "y", "time"])

        # entries of dims are not unique
        with self.assertRaises(ValueError):
            construct(dims=["x", "y", "x", "no_measurements"])

        # coordinate array for single dimension is not a DNDarray
        with self.assertRaises(TypeError):
            construct(coords={("x", "y"): xy_coords, "time": 3.14})

        # coordinate array for single dimension has wrong dimensionality
        with self.assertRaises(ValueError):
            construct(coords={("x", "y"): xy_coords, "time": ht.ones((k, 2))})

        # device of a coordinate array does not coincide with device of value array
        # TBD - how to test this?

        # communicator of a coordinate array does not coincide with communicator of value array
        # TBD - how to test this?

        # size of value array in a dimension does not coincide with size of coordinate array
        with self.assertRaises(ValueError):
            construct(coords={("x", "y"): xy_coords, "time": ht.ones(nprocs * k + 1)})

        # value array is split along a dimension, but coordinate array is not split accordingly
        with self.assertRaises(ValueError):
            construct(ht.resplit(data, 2))

        # value array is not split along a dimension, but the coordinate array of it is split
        with self.assertRaises(ValueError):
            construct(coords={("x", "y"): xy_coords, "time": ht.resplit(t, 0)})

        # coordinate array in the case of "physical coordinates" is not a DXarray
        with self.assertRaises(TypeError):
            construct(coords={("x", "y"): 3.14, "time": t})

        # dimension names of a physical coordinate DXarray do not coincide with those in `coords`
        wrong_coords_xy = ht.dxarray.DXarray(
            xy, dims=["xx", "yy"], attrs=attrs_xy, name="coordinates of space"
        )
        with self.assertRaises(ValueError):
            construct(coords={("x", "y"): wrong_coords_xy, "time": t})

        # size of values for physical coordinates does not coincide with size of coordinate array
        wrong_xy = ht.random.rand(m + 1, n, split=1)
        wrong_coords_xy = ht.dxarray.DXarray(
            wrong_xy, dims=["x", "y"], attrs=attrs_xy, name="coordinates of space"
        )
        with self.assertRaises(ValueError):
            construct(coords={("x", "y"): wrong_coords_xy, "time": t})

        # communicator of coordinate array for physical coordinates does not coincide with
        # communicator of value array
        # TBD - how to test this?

        # device of coordinate array for physical coordinates does not coincide with device of
        # value array
        # TBD - how to test this?

        # physical coordinate array and value array disagree on the split dimension (two cases):
        # values not split but coordinates split ...
        with self.assertRaises(ValueError):
            construct(ht.random.randn(m, n, k, ell))
        # ... and values split but coordinates not split
        wrong_xy = ht.random.rand(m, n)
        wrong_coords_xy = ht.dxarray.DXarray(
            wrong_xy, dims=["x", "y"], attrs=attrs_xy, name="coordinates of space"
        )
        with self.assertRaises(ValueError):
            construct(coords={("x", "y"): wrong_coords_xy, "time": t})

    def test_balanced_and_balancing(self):
        m = 2
        n = 5 * nprocs
        k = 2
        ell = 2

        # create a highly unbalanced array for the values but not for the coordinates
        xy = ht.random.rand(m, n, split=1)
        xy = xy[:, 4:]
        xy.balance_()
        t = ht.linspace(-1, 1, k, split=None)
        attrs_xy = {"units_xy": "meters"}
        xy_coords = ht.dxarray.DXarray(
            xy, dims=["x", "y"], attrs=attrs_xy, name="coordinates of space"
        )
        data = ht.random.randn(m, n, k, ell, split=1)
        data = data[:, 4:, :, :]
        name = "mytestarray"
        attrs = {
            "units time": "seconds",
            "measured data": "something really random and meaningless",
        }
        dims = ["x", "y", "time", "no_measurements"]
        coords = {("x", "y"): xy_coords, "time": t}

        dxarray = ht.dxarray.DXarray(data, dims=dims, coords=coords, name=name, attrs=attrs)

        # balancedness-status is unknown at first, is determined to be False by the explicit
        # check, and is stored as False afterwards
        # NOTE(review): presumably this only holds for more than one process - confirm
        self.assertEqual(dxarray.balanced, None)
        self.assertEqual(dxarray.is_balanced(), False)
        self.assertEqual(dxarray.balanced, False)

        # rebalancing should work
        dxarray.balance_()
        self.assertEqual(dxarray.balanced, True)
        self.assertEqual(dxarray.is_balanced(force_check=True), True)

        # rebalanced array should be equal to original one
        self.assertTrue(ht.allclose(dxarray.values, data))
        self.assertEqual(dxarray.dims, dims)
        self.assertEqual(dxarray.dims_with_coords, ["x", "y", "time"])
        self.assertEqual(dxarray.dims_without_coords, ["no_measurements"])
        self.assertEqual(dxarray.name, name)
        self.assertEqual(dxarray.attrs, attrs)
        # TBD: check for equality of coordinate arrays

    def test_resplit_(self):
        m = 2 * nprocs
        n = 3 * nprocs
        k = 5 * nprocs
        ell = 2

        inputs = self._make_inputs(m, n, k, ell)
        dims = inputs["dims"]

        dxarray = ht.dxarray.DXarray(
            inputs["data"],
            dims=dims,
            coords=inputs["coords"],
            name=inputs["name"],
            attrs=inputs["attrs"],
        )
        for newsplit in ["x", "time", None, "y"]:
            dxarray.resplit_(newsplit)
            self.assertEqual(dxarray.split, newsplit)
            # the name stored in the DXarray and the split axis of the value array must agree
            expected_axis = None if newsplit is None else dims.index(newsplit)
            self.assertEqual(dxarray.values.split, expected_axis)
self.assertTrue(ht.allclose(dxarray.values, data)) + self.assertEqual(dxarray.dims, dims) + self.assertEqual(dxarray.dims_with_coords, ["x", "y", "time"]) + self.assertEqual(dxarray.dims_without_coords, ["no_measurements"]) + self.assertEqual(dxarray.name, name) + self.assertEqual(dxarray.attrs, attrs) + # TBD: check for equality of coordinate arrays + + def test_to_and_from_xarray(self): + m = 2 + n = 3 * nprocs + k = 10 + ell = 2 + + # test constructor in a case that should work and also test if all attributes of the DXarray are set correctly + # here we include a dimension ("no_measurements") without coordinates and two dimensions ("x", "y") with physical instead of logical coordinates + xy = ht.random.rand(m, n, split=1) + t = ht.linspace(-1, 1, k, split=None) + attrs_xy = {"units_xy": "meters"} + xy_coords = ht.dxarray.DXarray( + xy, dims=["x", "y"], attrs=attrs_xy, name="coordinates of space" + ) + data = ht.random.randn(m, n, k, ell, split=1) + name = "mytestarray" + attrs = { + "units time": "seconds", + "measured data": "something really random and meaningless", + } + dims = ["x", "y", "time", "no_measurements"] + coords = {("x", "y"): xy_coords, "time": t} + + dxarray = ht.dxarray.DXarray(data, dims=dims, coords=coords, name=name, attrs=attrs) + + xarray = dxarray.xarray() + print(xarray) + # TBD convert back and check for equality (or the other way round?) + # dxarray_from_xarray = ht.dxarray.from_xarray(xarray,split=dxarray.split,device=dxarray.device)