From 562a39be8ac97f6bce1efc5e5f10ea91ba8e63fa Mon Sep 17 00:00:00 2001 From: Egor Achkasov Date: Mon, 22 Jul 2024 14:09:10 +0000 Subject: [PATCH] Signal getitem flattened --- hermespy/core/signal_model.py | 55 +++++++++++++++++++++------- tests/unit_tests/core/test_signal.py | 6 +++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/hermespy/core/signal_model.py b/hermespy/core/signal_model.py index 833c357d..07201288 100644 --- a/hermespy/core/signal_model.py +++ b/hermespy/core/signal_model.py @@ -766,7 +766,7 @@ def __parse_slice(s: slice, dim_size: int) -> Tuple[int, int, int]: s1 = s1 if s1 >= 0 else s1 % dim_size return s0, s1, s2 - def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, int, bool]: + def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, int, bool, bool, bool]: """Parse and validate key in __getitem__ and __setitem__. Raises: @@ -785,11 +785,13 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s11 (int): samples stop s12 (int): samples step isboolmask (bool): True if key is a boolean mask, False otherwise. + should_flatten_streams (bool): True if numpy's getitem would flatten the streams (1) dimension with this key + should_flatten_samples (bool): True if numpy's getitem would flatten the samples (2) dimension with this key Note that if isboolmask is True, then all s?? take the following values: (0, self.num_streams, 1, 0, self.num_samples, 1). - Note that if the key references any dimansion with an integer index, + Note that if the key references any dimension with an integer index, then the corresponding result start will be the index, and stop is start+1. For example, if key is 1, then only stream 1 is need. Then s00 is 1 and s01 is 2. Numpy getitem of [1] and [1:2] differ in dimensions. Flattening of the second variant should be considered. @@ -798,6 +800,8 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in self_num_streams = self.num_streams self_num_samples = self.num_samples isboolmask = False + should_flatten_streams = False + should_flatten_samples = False # Key is a tuple of two # ====================================================================== @@ -819,6 +823,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s00 = key[0] % self_num_streams s01 = s00 + 1 s02 = 1 + should_flatten_streams = True else: raise TypeError( f"Expected to get streams index as an integer or a slice, but got {type(key[0])}" @@ -834,10 +839,11 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s10 = key[1] % self_num_samples s11 = s10 + 1 s12 = 1 + should_flatten_samples = True else: - raise TypeError(f"Samples key is of an unsupported type ({type(key[1])})") + raise TypeError(f"Samples key is ofan unsupported type ({type(key[1])})") - return s00, s01, s02, s10, s11, s12, False + return s00, s01, s02, s10, s11, s12, False, should_flatten_streams, should_flatten_samples # ====================================================================== # done Key is a tuple of two @@ -861,6 +867,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s00 = key % self_num_streams s01 = s00 + 1 s02 = 1 + should_flatten_streams = True # Key is a boolean mask or something unsupported else: try: @@ -874,7 +881,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in except ValueError: raise TypeError(f"Unsupported key type {type(key)}") - return s00, s01, s02, s10, s11, s12, isboolmask + return s00, s01, s02, s10, s11, s12, isboolmask, should_flatten_streams, should_flatten_samples def _find_affected_blocks(self, s10: int, s11: int) -> Tuple[int, int]: """Find indices of blocks that are affected by the given samples slice. @@ -917,7 +924,7 @@ def _find_affected_blocks(self, s10: int, s11: int) -> Tuple[int, int]: return b_start, b_stop - def getitem(self, key: Any = slice(None, None)) -> np.ndarray: + def getitem(self, key: Any = slice(None, None), unflatten : bool = True) -> np.ndarray: """Get specified samples. Works like np.ndarray.__getitem__, but de-sparsifies the signal. @@ -927,6 +934,9 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: a tuple (int, int), (int, slice), (slice, int), (slice, slice) or a boolean mask. Defaults to slice(None, None) (same as [:, :]) + unflatten (bool): + Set to True to ensure the result ndim to be 2 even if only one stream is selected. + Set to False to allow the numpy-like degenerate dimensions reduction. Examples: getitem(slice(None, None)): @@ -934,13 +944,17 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: Warning: can cause memory overflow if used with a sparse signal. getitem(0): Select and de-sparsify the first stream. + Result shape is (1, num_samples) + getitem(0, False): + Same, but allow the numpy flattening. + Result shape is (num_samples,) getitem((slice(None, 2), slice(50, 100))): Select streams 0, 1 and samples 50-99. Same as samples_matrix[:2, 50:100] - Returns: np.ndarray with ndim 2 and dtype np.complex_""" + Returns: np.ndarray with ndim 2 or less and dtype dtype np.complex_""" - s00, s01, s02, s10, s11, s12, isboolmask = self._parse_validate_itemkey(key) + s00, s01, s02, s10, s11, s12, isboolmask, should_flatten_streams, should_flatten_samples = self._parse_validate_itemkey(key) num_streams = -((s01 - s00) // -s02) num_samples = -((s11 - s10) // -s12) if self.num_samples == 0 or self.num_streams == 0: # if this signal is empty @@ -968,6 +982,9 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: # ^b previous^^gap^^b_start/b_stop^ elif s11 > b.offset: res[:, b.offset - s10 :] = b[s00:s01:s02, :] + # Apply numpy-like flattening + if not unflatten and (should_flatten_streams or should_flatten_samples): + res = res.flatten() return res # assemble the result @@ -1012,7 +1029,13 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: ) if is_streams_step_reversing: return res[::s02, ::s12] - return res[s00:s01:s02, ::s12] + res = res[s00:s01:s02, ::s12] + + # Apply numpy-like flattening + if not unflatten and (should_flatten_streams or should_flatten_samples): + res = res.flatten() + + return res def getstreams(self, streams_key: int | slice | Sequence[int]) -> Signal: """Create a new signal like this, but with only the selected streams. @@ -1693,15 +1716,21 @@ def Create( ) -> DenseSignal: return DenseSignal(samples, sampling_rate, carrier_frequency, noise_power, delay, offsets) - def getitem(self, key: Any = slice(None, None)) -> np.ndarray: + def getitem(self, key: Any = slice(None, None), unflatten: bool = True) -> np.ndarray: """Reroutes the argument to the single block of this model. Refer the numpy.ndarray.__getitem__ documentation. The result is always a 2D ndarray.""" res = self._blocks[0].view(np.ndarray)[key] # de-flatten - if res.ndim == 1: - return res.reshape((1, res.size)) + if unflatten and res.ndim == 1: + streams_flattened, samples_flattened = self._parse_validate_itemkey(key)[-2:] + if streams_flattened and samples_flattened: + return res.reshape(()) + elif streams_flattened: + return res.reshape((1, res.size)) + elif samples_flattened: + return res.reshape((res.size, 1)) return res def __setitem__(self, key: Any, value: Any) -> None: @@ -1932,7 +1961,7 @@ def __from_dense(block: np.ndarray) -> List[SignalBlock]: def __setitem__(self, key: Any, value: Any) -> None: # parse and validate key - s00, s01, s02, s10, s11, s12, isboolmask = self._parse_validate_itemkey(key) + s00, s01, s02, s10, s11, s12, isboolmask, _, _ = self._parse_validate_itemkey(key) if s02 <= 0 or s12 <= 0: raise NotImplementedError("Only positive steps are implemented") if s12 != 1: diff --git a/tests/unit_tests/core/test_signal.py b/tests/unit_tests/core/test_signal.py index 2a2fcc0b..8ea859b4 100644 --- a/tests/unit_tests/core/test_signal.py +++ b/tests/unit_tests/core/test_signal.py @@ -234,6 +234,12 @@ def test_setgetitem(self) -> None: for key in keys: assert_array_equal(self.samples_dense[key].flatten(), self.signal.getitem(key).flatten()) + # getitem with unflatten=False + assert_array_equal(self.signal.getitem(0, False).shape, + (self.signal.num_samples,)) + assert_array_equal(self.signal.getitem((slice_full, 0), False).shape, + (self.signal.num_streams,)) + # __setitem__ dummy_value = 13.37 + 73.31j dummy_samples_full = np.full((self.num_streams, self.num_samples),