Skip to content

Commit

Permalink
Signal getitem flattened
Browse files Browse the repository at this point in the history
  • Loading branch information
egor-achkasov authored and adlerjan committed Jul 22, 2024
1 parent 13c2ced commit 562a39b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
55 changes: 42 additions & 13 deletions hermespy/core/signal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def __parse_slice(s: slice, dim_size: int) -> Tuple[int, int, int]:
s1 = s1 if s1 >= 0 else s1 % dim_size
return s0, s1, s2

def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, int, bool]:
def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, int, bool, bool, bool]:
"""Parse and validate key in __getitem__ and __setitem__.
Raises:
Expand All @@ -785,11 +785,13 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in
s11 (int): samples stop
s12 (int): samples step
isboolmask (bool): True if key is a boolean mask, False otherwise.
should_flatten_streams (bool): True if numpy's getitem would flatten the streams (1) dimension with this key
should_flatten_samples (bool): True if numpy's getitem would flatten the samples (2) dimension with this key
Note that if isboolmask is True, then all s?? take the following values:
(0, self.num_streams, 1, 0, self.num_samples, 1).
Note that if the key references any dimansion with an integer index,
Note that if the key references any dimension with an integer index,
then the corresponding result start will be the index, and stop is start+1.
For example, if key is 1, then only stream 1 is need. Then s00 is 1 and s01 is 2.
Numpy getitem of [1] and [1:2] differ in dimensions. Flattening of the second variant should be considered.
Expand All @@ -798,6 +800,8 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in
self_num_streams = self.num_streams
self_num_samples = self.num_samples
isboolmask = False
should_flatten_streams = False
should_flatten_samples = False

# Key is a tuple of two
# ======================================================================
Expand All @@ -819,6 +823,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in
s00 = key[0] % self_num_streams
s01 = s00 + 1
s02 = 1
should_flatten_streams = True
else:
raise TypeError(
f"Expected to get streams index as an integer or a slice, but got {type(key[0])}"
Expand All @@ -834,10 +839,11 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in
s10 = key[1] % self_num_samples
s11 = s10 + 1
s12 = 1
should_flatten_samples = True
else:
raise TypeError(f"Samples key is of an unsupported type ({type(key[1])})")
raise TypeError(f"Samples key is ofan unsupported type ({type(key[1])})")

return s00, s01, s02, s10, s11, s12, False
return s00, s01, s02, s10, s11, s12, False, should_flatten_streams, should_flatten_samples
# ======================================================================
# done Key is a tuple of two

Expand All @@ -861,6 +867,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in
s00 = key % self_num_streams
s01 = s00 + 1
s02 = 1
should_flatten_streams = True
# Key is a boolean mask or something unsupported
else:
try:
Expand All @@ -874,7 +881,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in
except ValueError:
raise TypeError(f"Unsupported key type {type(key)}")

return s00, s01, s02, s10, s11, s12, isboolmask
return s00, s01, s02, s10, s11, s12, isboolmask, should_flatten_streams, should_flatten_samples

def _find_affected_blocks(self, s10: int, s11: int) -> Tuple[int, int]:
"""Find indices of blocks that are affected by the given samples slice.
Expand Down Expand Up @@ -917,7 +924,7 @@ def _find_affected_blocks(self, s10: int, s11: int) -> Tuple[int, int]:

return b_start, b_stop

def getitem(self, key: Any = slice(None, None)) -> np.ndarray:
def getitem(self, key: Any = slice(None, None), unflatten : bool = True) -> np.ndarray:
"""Get specified samples.
Works like np.ndarray.__getitem__, but de-sparsifies the signal.
Expand All @@ -927,20 +934,27 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray:
a tuple (int, int), (int, slice), (slice, int), (slice, slice)
or a boolean mask.
Defaults to slice(None, None) (same as [:, :])
unflatten (bool):
Set to True to ensure the result ndim to be 2 even if only one stream is selected.
Set to False to allow the numpy-like degenerate dimensions reduction.
Examples:
getitem(slice(None, None)):
Select all samples from the signal.
Warning: can cause memory overflow if used with a sparse signal.
getitem(0):
Select and de-sparsify the first stream.
Result shape is (1, num_samples)
getitem(0, False):
Same, but allow the numpy flattening.
Result shape is (num_samples,)
getitem((slice(None, 2), slice(50, 100))):
Select streams 0, 1 and samples 50-99.
Same as samples_matrix[:2, 50:100]
Returns: np.ndarray with ndim 2 and dtype np.complex_"""
Returns: np.ndarray with ndim 2 or less and dtype dtype np.complex_"""

s00, s01, s02, s10, s11, s12, isboolmask = self._parse_validate_itemkey(key)
s00, s01, s02, s10, s11, s12, isboolmask, should_flatten_streams, should_flatten_samples = self._parse_validate_itemkey(key)
num_streams = -((s01 - s00) // -s02)
num_samples = -((s11 - s10) // -s12)
if self.num_samples == 0 or self.num_streams == 0: # if this signal is empty
Expand Down Expand Up @@ -968,6 +982,9 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray:
# ^b previous^^gap^^b_start/b_stop^
elif s11 > b.offset:
res[:, b.offset - s10 :] = b[s00:s01:s02, :]
# Apply numpy-like flattening
if not unflatten and (should_flatten_streams or should_flatten_samples):
res = res.flatten()
return res

# assemble the result
Expand Down Expand Up @@ -1012,7 +1029,13 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray:
)
if is_streams_step_reversing:
return res[::s02, ::s12]
return res[s00:s01:s02, ::s12]
res = res[s00:s01:s02, ::s12]

# Apply numpy-like flattening
if not unflatten and (should_flatten_streams or should_flatten_samples):
res = res.flatten()

return res

def getstreams(self, streams_key: int | slice | Sequence[int]) -> Signal:
"""Create a new signal like this, but with only the selected streams.
Expand Down Expand Up @@ -1693,15 +1716,21 @@ def Create(
) -> DenseSignal:
return DenseSignal(samples, sampling_rate, carrier_frequency, noise_power, delay, offsets)

def getitem(self, key: Any = slice(None, None)) -> np.ndarray:
def getitem(self, key: Any = slice(None, None), unflatten: bool = True) -> np.ndarray:
"""Reroutes the argument to the single block of this model.
Refer the numpy.ndarray.__getitem__ documentation.
The result is always a 2D ndarray."""

res = self._blocks[0].view(np.ndarray)[key]
# de-flatten
if res.ndim == 1:
return res.reshape((1, res.size))
if unflatten and res.ndim == 1:
streams_flattened, samples_flattened = self._parse_validate_itemkey(key)[-2:]
if streams_flattened and samples_flattened:
return res.reshape(())
elif streams_flattened:
return res.reshape((1, res.size))
elif samples_flattened:
return res.reshape((res.size, 1))
return res

def __setitem__(self, key: Any, value: Any) -> None:
Expand Down Expand Up @@ -1932,7 +1961,7 @@ def __from_dense(block: np.ndarray) -> List[SignalBlock]:

def __setitem__(self, key: Any, value: Any) -> None:
# parse and validate key
s00, s01, s02, s10, s11, s12, isboolmask = self._parse_validate_itemkey(key)
s00, s01, s02, s10, s11, s12, isboolmask, _, _ = self._parse_validate_itemkey(key)
if s02 <= 0 or s12 <= 0:
raise NotImplementedError("Only positive steps are implemented")
if s12 != 1:
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/core/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def test_setgetitem(self) -> None:
for key in keys:
assert_array_equal(self.samples_dense[key].flatten(), self.signal.getitem(key).flatten())

# getitem with unflatten=False
assert_array_equal(self.signal.getitem(0, False).shape,
(self.signal.num_samples,))
assert_array_equal(self.signal.getitem((slice_full, 0), False).shape,
(self.signal.num_streams,))

# __setitem__
dummy_value = 13.37 + 73.31j
dummy_samples_full = np.full((self.num_streams, self.num_samples),
Expand Down

0 comments on commit 562a39b

Please sign in to comment.