diff --git a/.github/workflows/linux-test.yml b/.github/workflows/linux-test.yml
index 7ee262d..91ec66a 100644
--- a/.github/workflows/linux-test.yml
+++ b/.github/workflows/linux-test.yml
@@ -21,4 +21,4 @@ jobs:
pip install -e .
pip install torch
pip install pytest
- pytest -v -rw test/test_replay_buffer.py
+ # pytest -v -rw test/test_*.py
diff --git a/README.md b/README.md
index 18a79fd..5ed680c 100644
--- a/README.md
+++ b/README.md
@@ -240,57 +240,6 @@ The Neural Blocks module also has functions that can generate single modules, re
-### `from wingman.replay_buffer import FlatReplayBuffer as ReplayBuffer`
-
-This is a replay buffer designed around Torch's DataLoader class for reinforcement learning projects.
-This allows easy bootstrapping of the DataLoader's excellent shuffling and pre-batching capabilities.
-In addition, all the data is stored as numpy arrays in contiguous blocks of memory, allowing very fast retrieval.
-ReplayBuffer also doesn't put any limits on tuple length per transition; some people prefer to store $\{S, A, R, S'\}$, some prefer to store $\{S, A, R, S', A'\}$ - ReplayBuffer doesn't care!
-The tuple can be as long or as short as you want, as long as every tuple fed in has the same length and each element of the tuple keeps the same shape.
-There is no need to predefine the shapes of the inputs you want to put in the ReplayBuffer; it automatically infers the shapes and computes memory usage when the first tuple is stored.
-The basic usage of the ReplayBuffer is as follows:
-
-```python
-import torch
-
-from wingman import gpuize
-from wingman.replay_buffer import FlatReplayBuffer as ReplayBuffer
-
-# we define the replay buffer to be able to store 1000 tuples of information
-memory = ReplayBuffer(mem_size=1000)
-
-# get the first observation from the environment
-next_obs = env.reset()
-
-# iterate until the environment is complete
-while not env.done:
- # rollover the observation
- obs = next_obs
-
- # get an action from the policy
- act = policy(obs)
-
- # sample a new transition
- next_obs, rew, done, next_lbl = env.step(act)
-
- # store stuff in the replay buffer
- memory.push((obs, act, rew, next_obs, done))
-
-# perform training using the buffer
-dataloader = torch.utils.data.DataLoader(
- memory, batch_size=32, shuffle=True, drop_last=False
-)
-
-# treat the replay buffer as an iterable and loop over batches
-for batch_num, stuff in enumerate(dataloader):
- observations = gpuize(stuff[0], "cuda:0")
- actions = gpuize(stuff[1], "cuda:0")
- rewards = gpuize(stuff[2], "cuda:0")
- next_states = gpuize(stuff[3], "cuda:0")
- dones = gpuize(stuff[4], "cuda:0")
-```
-
-
-
### `from wingman import gpuize, cpuize`
diff --git a/pyproject.toml b/pyproject.toml
index f454259..ddbc87e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "jj_wingman"
-version = "0.19.3"
+version = "0.20.0"
authors = [
{ name="Jet", email="taijunjet@hotmail.com" },
]
diff --git a/test/test_replay_buffer.py b/test/test_replay_buffer.py
deleted file mode 100644
index 1cb8b30..0000000
--- a/test/test_replay_buffer.py
+++ /dev/null
@@ -1,178 +0,0 @@
-"""Tests the replay buffer module."""
-
-from __future__ import annotations
-
-from copy import deepcopy
-from itertools import product
-from pprint import pformat
-from typing import Literal
-
-import pytest
-import torch
-from utils import (
- are_equivalent,
- create_memory,
- create_shapes,
- element_to_bulk_dim_swap,
- generate_random_dict_data,
- generate_random_flat_data,
-)
-
-# define the test configurations
-_random_rollovers = [True, False]
-_modes = ["numpy", "torch"]
-_devices = [torch.device("cpu")]
-if torch.cuda.is_available():
- _devices.append(torch.device("cuda:0"))
-_store_on_devices = [True, False]
-_use_dict = [False, True]
-ALL_CONFIGURATIONS = list(
- product(
- _random_rollovers,
- _modes,
- _devices,
- _store_on_devices,
- _use_dict,
- )
-)
-
-
-@pytest.mark.parametrize(
- "random_rollover, mode, device, store_on_device, use_dict",
- ALL_CONFIGURATIONS,
-)
-def test_bulk(
- random_rollover: bool,
- mode: Literal["numpy", "torch"],
- device: torch.device,
- store_on_device: bool,
- use_dict: bool,
-):
- """Tests repeatedly bulking the buffer and whether it rollovers correctly."""
- bulk_size = 7
- mem_size = 11
- shapes = create_shapes(use_dict=use_dict, bulk_size=bulk_size)
- memory = create_memory(
- mem_size=mem_size,
- mode=mode,
- device=device,
- store_on_device=store_on_device,
- random_rollover=random_rollover,
- use_dict=use_dict,
- )
-
- for iteration in range(10):
-        # try to push data of shape (bulk_size, *shape) for each shape spec, e.g.:
-        # a) a (bulk_size, 3) array
-        # b) a (bulk_size,) array
- data = []
- for shape in shapes:
- if isinstance(shape, (list, tuple)):
- data.append(
- generate_random_flat_data(shape=(bulk_size, *shape), mode=mode)
- )
- elif isinstance(shape, dict):
- data.append(generate_random_dict_data(shapes=shape, mode=mode))
- else:
- raise ValueError
- memory.push(data, bulk=True)
-
- # reverse the data
- # on insertion we had [element_dim, bulk_dim, *data_shapes]
- # on comparison we want [bulk_dim, element_dim, *data_shapes]
- serialized_data = element_to_bulk_dim_swap(
- element_first_data=data,
- bulk_size=bulk_size,
- )
-
-        # if random rollover is on and the buffer is full, use a different matching technique
- if random_rollover and memory.is_full:
- num_matches = 0
- # match according to meshgrid
- for item1 in serialized_data:
- for item2 in memory:
- num_matches += int(are_equivalent(item1, item2))
-
- assert (
- num_matches == bulk_size
- ), f"""Expected {bulk_size} matches inside the memory, got {num_matches}."""
-
- continue
-
- for step in range(bulk_size):
- item1 = serialized_data[step]
- item2 = memory[(iteration * bulk_size + step) % mem_size]
- assert are_equivalent(
- item1, item2
- ), f"""Something went wrong with rollover at iteration {iteration},
- step {step}, expected \n{pformat(item1)}, got \n{pformat(item2)}."""
-
-
-@pytest.mark.parametrize(
- "random_rollover, mode, device, store_on_device, use_dict",
- ALL_CONFIGURATIONS,
-)
-def test_non_bulk(
- random_rollover: bool,
- mode: Literal["numpy", "torch"],
- device: torch.device,
- store_on_device: bool,
- use_dict: bool,
-):
- """Tests the replay buffer generically."""
- mem_size = 11
- shapes = create_shapes(use_dict=use_dict)
- memory = create_memory(
- mem_size=mem_size,
- mode=mode,
- device=device,
- store_on_device=store_on_device,
- random_rollover=random_rollover,
- use_dict=use_dict,
- )
-
- previous_data = []
- for iteration in range(20):
- current_data = []
- for shape in shapes:
- if isinstance(shape, (list, tuple)):
- current_data.append(generate_random_flat_data(shape=shape, mode=mode))
- elif isinstance(shape, dict):
- current_data.append(generate_random_dict_data(shapes=shape, mode=mode))
- else:
- raise ValueError
- memory.push(current_data)
-
-        # if random rollover is on and the buffer is full, use a different matching method
- if random_rollover and memory.is_full:
- num_current_matches = 0
- num_previous_matches = 0
- for item in memory:
- num_current_matches += int(are_equivalent(item, current_data))
- num_previous_matches += int(are_equivalent(item, previous_data))
-
- assert (
- num_current_matches == 1
- ), f"""Expected 1 match for current_data, got {num_current_matches}."""
- assert (
- num_previous_matches <= 1
- ), f"""Expected 1 or 0 match for previous_data, got {num_previous_matches}."""
-
- continue
-
- # check the current data
-        output = memory[iteration % mem_size]
- assert are_equivalent(
- output, current_data
- ), f"""Something went wrong with rollover at iteration {iteration},
- expected \n{pformat(current_data)}, got \n{pformat(output)}."""
-
- # check the previous data
- if iteration > 0:
- output = memory[(iteration - 1) % mem_size]
- assert are_equivalent(
- output, previous_data
- ), f"""Something went wrong with rollover at iteration {iteration},
- expected \n{pformat(previous_data)}, got \n{pformat(output)}."""
-
- previous_data = deepcopy(current_data)
diff --git a/test/test_replay_buffer_utils.py b/test/test_replay_buffer_utils.py
deleted file mode 100644
index 9d17645..0000000
--- a/test/test_replay_buffer_utils.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""Tests the replay buffer utilities."""
-
-from __future__ import annotations
-
-import numpy as np
-import pytest
-from utils import are_equivalent
-
-from wingman.replay_buffer.utils import listed_dict_to_dicted_list
-
-
-@pytest.mark.parametrize(
- "stack",
- [True, False],
-)
-def test_dicted_list_to_listed_dict(stack: bool) -> None:
- """Tests the dicted_list_to_listed_dict function."""
- # the input
- listed_dict = [
- {"a": {"x": 1, "y": 2}, "b": [3, 4]},
- {"a": {"x": 5, "y": 6}, "b": [7, 8]},
- {"a": {"x": 9, "y": 10}, "b": [11, 12]},
- ]
-
- # the target output
- first = np.asarray([1, 5, 9])
- second = np.asarray([2, 6, 10])
- third = np.asarray(
- [
- [3, 4],
- [7, 8],
- [11, 12],
- ]
- )
- if stack:
- first = np.expand_dims(first, axis=0)
- second = np.expand_dims(second, axis=0)
- third = np.expand_dims(third, axis=0)
- target_dicted_list = {
- "a": {
- "x": first,
- "y": second,
- },
- "b": third,
- }
-
- # convert and check
- created_dicted_list = listed_dict_to_dicted_list(listed_dict, stack=stack)
-
- assert are_equivalent(
- target_dicted_list, created_dicted_list
- ), f"Expected {target_dicted_list=} to equal {created_dicted_list=}."
diff --git a/test/utils.py b/test/utils.py
deleted file mode 100644
index ebeb317..0000000
--- a/test/utils.py
+++ /dev/null
@@ -1,269 +0,0 @@
-"""Utilities used during testing."""
-
-from __future__ import annotations
-
-from typing import Any, Literal
-
-import numpy as np
-import torch
-
-from wingman.replay_buffer import FlatReplayBuffer
-from wingman.replay_buffer.core import ReplayBuffer
-from wingman.replay_buffer.wrappers.dict_wrapper import DictReplayBufferWrapper
-
-
-def create_shapes(
- use_dict: bool, bulk_size: int = 0
-) -> list[tuple[int, ...] | dict[str, Any]]:
- """create_shapes.
-
- Args:
- ----
- use_dict (bool): use_dict
- bulk_size (int): bulk_size
-
- Returns:
- -------
- list[tuple[int, ...] | dict[str, Any]]:
-
- """
- if bulk_size:
- bulk_shape = (bulk_size,)
- else:
- bulk_shape = ()
-
- if use_dict:
- return [
- (*bulk_shape, 3, 3),
- (
- *bulk_shape,
- 3,
- ),
- (*bulk_shape,),
- {
- "a": (*bulk_shape, 4, 3),
- "b": (*bulk_shape,),
- "c": {
- "d": (*bulk_shape, 11, 2),
- },
- },
- {
- "e": (*bulk_shape, 3, 2),
- },
- (*bulk_shape, 4),
- ]
- else:
- return [
- (*bulk_shape, 3, 3),
- (*bulk_shape, 3),
- (*bulk_shape,),
- ]
-
-
-def create_memory(
- mem_size: int,
- mode: Literal["numpy", "torch"],
- device: torch.device,
- store_on_device: bool,
- random_rollover: bool,
- use_dict: bool,
-) -> ReplayBuffer:
- """create_memory.
-
- Args:
- ----
- mem_size (int): mem_size
- mode (Literal["numpy", "torch"]): mode
- device (torch.device): device
- store_on_device (bool): store_on_device
- random_rollover (bool): random_rollover
- use_dict (bool): use_dict
-
- Returns:
- -------
- ReplayBuffer:
-
- """
- memory = FlatReplayBuffer(
- mem_size=mem_size,
- mode=mode,
- device=device,
- store_on_device=store_on_device,
- random_rollover=random_rollover,
- )
-
- if use_dict:
- memory = DictReplayBufferWrapper(
- replay_buffer=memory,
- )
-
- return memory
-
-
-def generate_random_flat_data(
- shape: tuple[int, ...], mode: Literal["numpy", "torch"]
-) -> np.ndarray | torch.Tensor:
- """Generates random data given a shapes specification.
-
- Args:
- ----
- shape (tuple[int, ...]): shape
- mode (Literal["numpy", "torch"]): mode
-
- Returns:
- -------
- np.ndarray | torch.Tensor:
-
- """
- if mode == "numpy":
- return np.asarray(np.random.randn(*shape))
- elif mode == "torch":
- if len(shape) == 0:
- return torch.randn(())
- else:
- return torch.randn(*shape)
- else:
- raise ValueError("Unknown mode.")
-
-
-def generate_random_dict_data(
- shapes: dict[str, Any], mode: Literal["numpy", "torch"]
-) -> dict[str, Any]:
- """Generates a random dictionary of data given a shapes specification.
-
- Args:
- ----
- shapes (dict[str, Any]): shapes
- mode (Literal["numpy", "torch"]): mode
-
- Returns:
- -------
- dict[str, Any]:
-
- """
- data = dict()
- for key, val in shapes.items():
- if isinstance(val, dict):
- data[key] = generate_random_dict_data(shapes=val, mode=mode)
- else:
- data[key] = generate_random_flat_data(shape=val, mode=mode)
-
- return data
-
-
-def _dict_element_to_bulk_dim_swap(
- data_dict: dict[str, Any],
- bulk_size: int,
-) -> list[dict[str, Any]]:
- """Given a nested dictionary where each leaf is an n-long array, returns an n-long sequence where each item is the same nested dictionary structure.
-
- Args:
- ----
- data_dict (dict[str, Any]): data_dict
- bulk_size (int): bulk_size
-
- Returns:
- -------
- list[dict[str, Any]]:
-
- """
- bulk_first_dicts: list[dict[str, Any]] = [dict() for _ in range(bulk_size)]
- for key, value in data_dict.items():
- if isinstance(value, dict):
- for i, element in enumerate(
- _dict_element_to_bulk_dim_swap(data_dict=value, bulk_size=bulk_size)
- ):
- bulk_first_dicts[i][key] = element
- else:
- for i, element in enumerate(value):
- bulk_first_dicts[i][key] = element
-
- return bulk_first_dicts
-
-
-def element_to_bulk_dim_swap(
- element_first_data: list[Any],
- bulk_size: int,
-) -> list[Any]:
- """Given a tuple of elements, each with `bulk_size` items, returns a `bulk_size` sequence, with each item being the size of the tuple.
-
- Args:
- ----
- element_first_data (list[Any]): element_first_data
- bulk_size (int): bulk_size
-
- Returns:
- -------
- list[Any]:
-
- """
- bulk_first_data = [[] for _ in range(bulk_size)]
- for element in element_first_data:
- # if not a dictionary, can do a plain axis extract
- if not isinstance(element, dict):
- for i in range(bulk_size):
- bulk_first_data[i].append(element[i])
-
- # if it's a dictionary, then we need to unpack the dictionary into each item
- else:
- for i, dict_element in enumerate(
- _dict_element_to_bulk_dim_swap(
- data_dict=element,
- bulk_size=bulk_size,
- )
- ):
- bulk_first_data[i].append(dict_element)
-
- return bulk_first_data
-
-
-def _cast(array: np.ndarray | torch.Tensor | float | int) -> np.ndarray:
- """_cast.
-
- Args:
- ----
- array (np.ndarray | torch.Tensor | float | int): array
-
- Returns:
- -------
- np.ndarray:
-
- """
- if isinstance(array, np.ndarray):
- return array
- elif isinstance(array, torch.Tensor):
- return array.cpu().numpy() # pyright: ignore[reportAttributeAccessIssue]
- else:
- return np.asarray(array)
-
-
-def are_equivalent(
- item1: Any,
- item2: Any,
-):
- """Check if two pieces of data are equivalent.
-
- Args:
- ----
- item1 (Any): item1
- item2 (Any): item2
-
- """
- # comparison for array-able types
- if isinstance(item1, (int, float, bool, torch.Tensor, np.ndarray)) or item1 is None:
- return np.isclose(_cast(item1), _cast(item2)).all()
-
- # comparison for lists and tuples
- if isinstance(item1, (list, tuple)) and isinstance(item2, (list, tuple)):
- return len(item1) == len(item2) and all(
- are_equivalent(d1, d2) for d1, d2 in zip(item1, item2)
- )
-
- # comparison for dictionaries
- if isinstance(item1, dict) and isinstance(item2, dict):
- return item1.keys() == item2.keys() and all(
- are_equivalent(item1[key], item2[key]) for key in item1.keys()
- )
-
-    # none of the checks passed
- return False
diff --git a/wingman/__init__.py b/wingman/__init__.py
index 6212a69..495cbe0 100644
--- a/wingman/__init__.py
+++ b/wingman/__init__.py
@@ -12,7 +12,6 @@
warnings.warn(
"Could not import torch, "
- "this is not bundled as part of Wingman and has to be installed manually, "
- "as a result, `NeuralBlocks` and `ReplayBuffer` are unavailable.",
+ "this is not bundled as part of Wingman and has to be installed manually.",
category=RuntimeWarning,
)
diff --git a/wingman/exceptions.py b/wingman/exceptions.py
index 3857654..720f1ae 100644
--- a/wingman/exceptions.py
+++ b/wingman/exceptions.py
@@ -33,19 +33,3 @@ def __init__(self, message: str = ""):
message = cstr(message, "FAIL")
super().__init__(message)
self.message = message
-
-
-class ReplayBufferException(Exception):
- """ReplayBufferException."""
-
- def __init__(self, message: str = ""):
- """__init__.
-
- Args:
- ----
- message (str): the message
-
- """
- message = cstr(message, "FAIL")
- super().__init__(message)
- self.message = message
diff --git a/wingman/replay_buffer/__init__.py b/wingman/replay_buffer/__init__.py
deleted file mode 100644
index 8200318..0000000
--- a/wingman/replay_buffer/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-"""Replay Buffer module."""
-
-from .core import ReplayBuffer as ReplayBuffer
-from .flat_replay_buffer import FlatReplayBuffer as FlatReplayBuffer
diff --git a/wingman/replay_buffer/core.py b/wingman/replay_buffer/core.py
deleted file mode 100644
index 3308527..0000000
--- a/wingman/replay_buffer/core.py
+++ /dev/null
@@ -1,307 +0,0 @@
-"""Base replay buffer class."""
-
-from __future__ import annotations
-
-import warnings
-from abc import abstractmethod
-from typing import Any, Generator, Sequence
-
-from prefetch_generator import prefetch
-
-
-class ReplayBuffer:
- """Base replay buffer class."""
-
- def __init__(self, mem_size: int):
- """__init__.
-
- Args:
- ----
- mem_size (int): mem_size
-
- """
- self.mem_size = mem_size
- self.count = 0
- self.memory = []
-
- def __len__(self) -> int:
- """The number of memory items this replay buffer is holding.
-
- Returns
- -------
- int:
-
- """
- return min(self.mem_size, self.count)
-
- def __repr__(self) -> str:
- """Printouts parameters of this replay buffer.
-
- Returns
- -------
- str:
-
- """
- return f"""ReplayBuffer of size {self.mem_size} with {len(self.memory)} elements. \n
- A brief view of the memory: \n
- {self.memory}
- """
-
- @property
- def is_full(self) -> bool:
- """Whether or not the replay buffer has reached capacity.
-
- Returns
- -------
- bool: whether the buffer is full
-
- """
- return self.count >= self.mem_size
-
- def merge(self, other: ReplayBuffer) -> None:
- """Merges another replay buffer into this replay buffer via the `push` method.
-
- Args:
- ----
- other (ReplayBuffer): other
-
- Returns:
- -------
- None:
-
- """
- self.push(
- [m[: min(other.mem_size, other.count)] for m in other.memory],
- bulk=True,
- )
-
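-    # prefetch_generator runs the generator in a background thread, keeping up to one batch ready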
- @prefetch(max_prefetch=1)
- def iter_sample(
- self, batch_size: int, num_iter: int
- ) -> Generator[Sequence[Any], None, None]:
- """iter_sample.
-
- Args:
- ----
- batch_size (int): batch_size
- num_iter (int): num_iter
-
- Returns:
- -------
- Generator[Sequence[np.ndarray | torch.Tensor], None, None]:
-
- """
- for _ in range(num_iter):
- yield (self.sample(batch_size=batch_size))
-
- @abstractmethod
- def sample(self, batch_size: int) -> Sequence[Any]:
- """sample.
-
- Args:
- ----
- batch_size (int): batch_size
-
- Returns:
- -------
- Sequence[Any]:
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def push(
- self,
- data: Sequence[Any],
- bulk: bool = False,
- ) -> None:
- """push.
-
- Args:
- ----
- data (Sequence[Any]): data
- bulk (bool): bulk
-
- Returns:
- -------
- None:
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def __getitem__(self, idx: int) -> Sequence[Any]:
- """__getitem__.
-
- Args:
- ----
- idx (int): idx
-
- Returns:
- -------
- Sequence[Any]:
-
- """
- raise NotImplementedError
-
-
-class ReplayBufferWrapper(ReplayBuffer):
- """ReplayBufferWrapper."""
-
- def __init__(self, base_buffer: ReplayBuffer):
- """__init__.
-
- Args:
- ----
- base_buffer (ReplayBuffer): base_buffer
-
- """
- self.base_buffer = base_buffer
- self.mem_size = base_buffer.mem_size
-
- @property
- def count(self) -> int:
- """The number of transitions that's been through this buffer."""
- return self.base_buffer.count
-
- @property
- def memory(self) -> list[Any]:
- """The core memory of this buffer."""
- warnings.warn(
- "Accessing the core of `ReplayBufferWrapper` returns the "
- "memory of the base buffer, not the wrapped buffer",
- category=RuntimeWarning,
- )
- return self.base_buffer.memory
-
- def merge(self, other: ReplayBuffer) -> None:
- """Merges another replay buffer into this replay buffer via the `push` method.
-
- Args:
- ----
- other (ReplayBuffer): other
-
- Returns:
- -------
- None:
-
- """
- # can't merge when the other one has 0 items
- assert other.count > 0
-
- # if no count, we need to manually push one item first to build the index
- if self.count == 0:
- self.push(other[0])
-
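-            # the single item pushed above built the index; if that was everything, stop here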
-            if other.mem_size == 1:
-                return
-
-            self.base_buffer.push(
-                [m[1 : min(other.mem_size, other.count)] for m in other.memory],
-                bulk=True,
-            )
- else:
- self.base_buffer.push(
- [m[: min(other.mem_size, other.count)] for m in other.memory],
- bulk=True,
- )
-
-
- def __len__(self) -> int:
- """The number of memory items this replay buffer is holding."""
- return len(self.base_buffer)
-
- def __repr__(self) -> str:
- """Printouts parameters of this replay buffer."""
- return f"""ReplayBuffer of size {self.mem_size} with {len(self.memory)} elements. \n
- A brief view of the memory: \n
- {self.base_buffer}
- """
-
- def __getitem__(self, idx: int) -> Sequence[Any]:
- """__getitem__.
-
- Args:
- ----
- idx (int): idx
-
- Returns:
- -------
- Sequence[Any]:
-
- """
- return self.wrap_data(self.base_buffer[idx])
-
- def push(
- self,
- data: Sequence[Any],
- bulk: bool = False,
- ) -> None:
- """push.
-
- Args:
- ----
- data (Sequence[Any]): data
- bulk (bool): bulk
-
- Returns:
- -------
- None:
-
- """
- self.base_buffer.push(
- data=self.unwrap_data(
- wrapped_data=data,
- bulk=bulk,
- ),
- bulk=bulk,
- )
-
- def sample(self, batch_size: int) -> Sequence[Any]:
- """sample.
-
- Args:
- ----
- batch_size (int): batch_size
-
- Returns:
- -------
- Sequence[Any]:
-
- """
- return self.wrap_data(self.base_buffer.sample(batch_size=batch_size))
-
- @abstractmethod
- def unwrap_data(self, wrapped_data: Sequence[Any], bulk: bool) -> Sequence[Any]:
- """Unwraps data from the underlying data into an unwrapped format.
-
- This is called when packing the data into the `base_buffer`.
-
- Args:
- ----
- wrapped_data (Sequence[Any]): wrapped_data
- bulk (bool): bulk
-
- Returns:
- -------
- Sequence[Any]:
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def wrap_data(self, unwrapped_data: Sequence[Any]) -> Sequence[Any]:
- """Wraps data from the underlying data into a wrapped format.
-
- This is called when sampling data from `base_buffer`.
-
- Args:
- ----
- unwrapped_data (Sequence[Any]): unwrapped_data
-
- Returns:
- -------
- Sequence[Any]:
-
- """
- raise NotImplementedError
diff --git a/wingman/replay_buffer/flat_replay_buffer.py b/wingman/replay_buffer/flat_replay_buffer.py
deleted file mode 100644
index e113ecd..0000000
--- a/wingman/replay_buffer/flat_replay_buffer.py
+++ /dev/null
@@ -1,267 +0,0 @@
-"""Replay buffer implementation with push, automatic overflow, and automatic torch dataset functionality."""
-
-from __future__ import annotations
-
-from enum import Enum
-from typing import Literal, Sequence
-
-import numpy as np
-
-from wingman.exceptions import ReplayBufferException
-from wingman.replay_buffer.core import ReplayBuffer
-
-try:
- import torch
-except ImportError as e:
- raise ImportError(
- "Could not import torch, this is not bundled as part of Wingman and has to be installed manually."
- ) from e
-
-from wingman.print_utils import cstr, wm_print
-
-
-class _Mode(Enum):
- """_Mode."""
-
- TORCH = 1
- NUMPY = 2
-
-
-class FlatReplayBuffer(ReplayBuffer):
- """Replay Buffer implementation of a Torch or Numpy dataset."""
-
- def __init__(
- self,
- mem_size: int,
- mode: Literal["numpy", "torch"] = "numpy",
- device: torch.device = torch.device("cpu"),
- store_on_device: bool = False,
- random_rollover: bool = False,
- ):
- """__init__.
-
- Args:
- ----
- mem_size (int): number of transitions the replay buffer aims to hold
- mode (Literal["numpy", "torch"]): Whether to store data as "torch" or "numpy".
- device (torch.device): The target device that data will be retrieved to if "torch".
- store_on_device (bool): Whether to store the entire replay on the specified device, otherwise stored on CPU.
- random_rollover (bool): whether to rollover the data in the replay buffer once full or to randomly insert
-
- """
- super().__init__(mem_size=mem_size)
-
- # store the device
- self.device = device
- self.storage_device = self.device if store_on_device else torch.device("cpu")
-
- # random rollover
- self.random_rollover = random_rollover
-
- # store the mode
- if mode == "numpy":
- self.mode = _Mode.NUMPY
- self.mode_type = np.ndarray
- self.mode_caller = np
- self.mode_dtype = np.float32
- elif mode == "torch":
- self.mode = _Mode.TORCH
- self.mode_type = torch.Tensor
- self.mode_caller = torch
- self.mode_dtype = torch.float32
- else:
- raise ReplayBufferException(
- f"Unknown mode {mode}. Only `'numpy'` and `'torch'` are allowed."
- )
-
- def __getitem__(self, idx: int) -> list[np.ndarray | torch.Tensor]:
- """__getitem__.
-
- Args:
- ----
- idx (int): idx
-
- Returns:
- -------
- list[np.ndarray | torch.Tensor]:
-
- """
- return list(d[idx] for d in self.memory)
-
- def _format_data(
- self, thing: np.ndarray | torch.Tensor | float | int | bool, bulk: bool
- ) -> np.ndarray | torch.Tensor:
- """_format_data.
-
- Args:
- ----
- thing (np.ndarray | torch.Tensor | float | int | bool): thing
- bulk (bool): bulk
-
- Returns:
- -------
- np.ndarray | torch.Tensor:
-
- """
- if self.mode == _Mode.NUMPY:
- # cast to the right dtype
- data = np.asarray(
- thing,
- dtype=self.mode_dtype, # pyright: ignore[reportArgumentType, reportCallIssue]
- )
-
- # dim check
- if bulk and len(data.shape) < 1:
- data = np.expand_dims(data, axis=-1)
- elif self.mode == _Mode.TORCH:
- # cast to the right dtype, store on CPU intentionally
- data = torch.asarray(
- thing,
- device=self.storage_device,
- dtype=self.mode_dtype, # pyright: ignore[reportArgumentType]
- )
- data.requires_grad_(False)
-
- # dim check
- if bulk and len(data.shape) < 1:
- data = data.unsqueeze(-1)
- else:
- raise ReplayBufferException(
- f"Unknown mode {self.mode}. Only `'numpy'` and `'torch'` are allowed."
- )
-
- return data
-
- def push(
- self,
- data: Sequence[torch.Tensor | np.ndarray | float | int | bool],
- bulk: bool = False,
- ) -> None:
- """Adds transition tuples into the replay buffer.
-
- The data must be either:
- - an n-long tuple of a single transition
-          - an n-long tuple of m transitions, i.e. a list of [m, ...] np arrays with the `bulk` flag set to True
-
- Args:
- ----
- data (Sequence[torch.Tensor | np.ndarray | float | int | bool]): data
- bulk (bool): whether to bulk add stuff into the replay buffer
-
- """
- # cast to dtype and conditionally add batch dim
- array_data = [self._format_data(item, bulk=True) for item in data]
-
- # if nothing in the array, we can safely return
- # this can occur for example, when we do `memory.push([np.array([])])`
- if all([len(item) == 0 for item in array_data]):
- return
-
- if not bulk:
- bulk_size = 1
- else:
- bulk_size = data[0].shape[0] # pyright: ignore
- # assert all items have same length
- if not all([len(item) == bulk_size for item in array_data]): # pyright: ignore[reportArgumentType]
- raise ReplayBufferException(
- "All things in data must have same len for the first dimension for bulk data. "
- f"Received data with {[len(item) for item in array_data]} items respectively.",
- )
-
- # assert on memory lengths
- if self.mem_size < bulk_size:
- raise ReplayBufferException(
- f"Bulk size ({bulk_size}) should be less than or equal to memory size ({self.mem_size}).",
- )
-
- # instantiate the memory if it does not exist
- if self.count == 0:
- self.memory = []
- if not bulk:
- self.memory.extend(
- [
- self.mode_caller.zeros(
- (self.mem_size, *item.shape),
- dtype=self.mode_dtype, # pyright: ignore[reportArgumentType, reportCallIssue]
- )
- for item in array_data
- ]
- )
- else:
- self.memory.extend(
- [
- self.mode_caller.zeros(
- (self.mem_size, *item.shape[1:]),
- dtype=self.mode_dtype, # pyright: ignore[reportArgumentType, reportCallIssue]
- )
- for item in array_data
- ]
- )
-
- # move everything to the storage device if torch
- if self.mode == _Mode.TORCH:
- self.memory = [array.to(self.storage_device) for array in self.memory]
-
- mem_size_bytes = sum([d.nbytes for d in self.memory])
- wm_print(
- cstr(f"Replay Buffer Size: {mem_size_bytes / 1e9} gigabytes.", "OKCYAN")
- )
-
- # assert that the number of lists in memory is same as data to push
- if len(array_data) != len(self.memory):
- raise ReplayBufferException(
- f"Data length not similar to memory buffer length. Replay buffer has {len(self.memory)} items. "
- f"But received {len(array_data)} items.",
- )
-
- # indexing for memory positions
- start = self.count % self.mem_size
- stop = min(start + bulk_size, self.mem_size)
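-        # number of items that wrap past the end of the buffer (zero or negative means no wraparound)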
- rollover = -(self.mem_size - start - bulk_size)
- if self.random_rollover:
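-            # while filling up: write sequentially, scattering any overflow into random filled slots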
- if not self.is_full:
- idx_front = np.arange(start, stop)
- idx_back = np.random.choice(
- start,
- size=np.maximum(rollover, 0),
- replace=False,
- )
- else:
- idx_front = np.random.choice(
- self.mem_size,
- size=bulk_size,
- replace=False,
- )
- idx_back = np.array([], dtype=np.int64)
- else:
- idx_front = np.arange(start, stop)
- idx_back = np.arange(0, rollover)
- idx = np.concatenate((idx_front, idx_back), axis=0)
-
- # put things in memory
- for memory, item in zip(self.memory, array_data):
- memory[idx] = item
-
- self.count += bulk_size
-
- def sample(self, batch_size: int) -> list[np.ndarray | torch.Tensor]:
- """sample.
-
- Args:
- ----
- batch_size (int): batch_size
-
- Returns:
- -------
- list[np.ndarray | torch.Tensor]:
-
- """
- idx = np.random.randint(
- 0,
- len(self),
- size=np.minimum(len(self), batch_size),
- )
- if self.mode == _Mode.TORCH:
- return [item[idx].to(self.device) for item in self.memory]
- else:
- return [item[idx] for item in self.memory]
diff --git a/wingman/replay_buffer/utils/__init__.py b/wingman/replay_buffer/utils/__init__.py
deleted file mode 100644
index dd4b378..0000000
--- a/wingman/replay_buffer/utils/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""Utilities for helping with replay buffers."""
-
-from .listed_dict_to_dicted_list import (
- listed_dict_to_dicted_list as listed_dict_to_dicted_list,
-)
diff --git a/wingman/replay_buffer/utils/listed_dict_to_dicted_list.py b/wingman/replay_buffer/utils/listed_dict_to_dicted_list.py
deleted file mode 100644
index e536515..0000000
--- a/wingman/replay_buffer/utils/listed_dict_to_dicted_list.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""Helper for converting a list of nested dictionaries to a nested dictionary of lists."""
-
-from __future__ import annotations
-
-import functools
-from typing import Any, Generator
-
-import numpy as np
-import torch
-
-
-def _iter_nested_keys(base_dict: dict[str, Any]) -> Generator[list[str], None, None]:
- """Given a nested dictionary, yields a list of keys to each element.
-
- Args:
- ----
- base_dict (dict[str, Any]): base_dict
-
- Returns:
- -------
- Generator[list[str], None, None]:
-
- """
- for key, value in base_dict.items():
- if isinstance(value, dict):
- for sub_keys in _iter_nested_keys(value):
- yield [key, *sub_keys]
- else:
- yield [key]
-
-
-def listed_dict_to_dicted_list(
- list_dict: list[dict[str, Any]], stack: bool
-) -> dict[str, Any]:
- """Given a list of nested dicts, returns a nested dict of lists.
-
- Args:
- ----
- list_dict (list[dict[str, Any]]): list_dict
- stack (bool): stack
-
- Returns:
- -------
- dict[str, Any]:
-
- """
- result = {}
- for key_list in _iter_nested_keys(list_dict[0]):
- # for each element in the expected dictionary
- # scaffold to that point
- ptr = result
- for key in key_list[:-1]:
- ptr = ptr.setdefault(key, {})
-
- # this goes through the main list of dicts
- # and collects elements at the position determined by the key_list
- dicted_list: list = [
- functools.reduce(lambda d, k: d[k], key_list, dict_item)
- for dict_item in list_dict
- ]
-
- # we can't call `concatenate` on a list of non-np.ndarray or non-torch.Tensor items
- if isinstance(dicted_list[0], np.ndarray) and len(dicted_list[0].shape) > 0:
- if stack:
- ptr[key_list[-1]] = np.stack(dicted_list, axis=0)
- else:
- ptr[key_list[-1]] = np.concatenate(dicted_list, axis=0)
- elif isinstance(dicted_list[0], torch.Tensor):
- if stack:
- ptr[key_list[-1]] = torch.stack(dicted_list, dim=0)
- else:
- ptr[key_list[-1]] = torch.concatenate(dicted_list, dim=0)
- else:
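-            # leaf values are plain scalars; pack the collected list into a numpy array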
- if stack:
- ptr[key_list[-1]] = np.expand_dims(
- np.asarray(dicted_list),
- axis=0,
- )
- else:
- ptr[key_list[-1]] = np.asarray(dicted_list)
- return result
diff --git a/wingman/replay_buffer/wrappers/__init__.py b/wingman/replay_buffer/wrappers/__init__.py
deleted file mode 100644
index f5dec4e..0000000
--- a/wingman/replay_buffer/wrappers/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Replay buffer wrappers."""
-
-from .dict_wrapper import DictReplayBufferWrapper as DictReplayBufferWrapper
diff --git a/wingman/replay_buffer/wrappers/dict_wrapper.py b/wingman/replay_buffer/wrappers/dict_wrapper.py
deleted file mode 100644
index 26998fe..0000000
--- a/wingman/replay_buffer/wrappers/dict_wrapper.py
+++ /dev/null
@@ -1,324 +0,0 @@
-"""Wrapper to convert a FlatReplayBuffer into one that accepts nested dicts."""
-
-from __future__ import annotations
-
-from typing import Any, Mapping, Sequence, Union
-
-import numpy as np
-
-from wingman.exceptions import ReplayBufferException
-from wingman.replay_buffer.core import ReplayBufferWrapper
-from wingman.replay_buffer.flat_replay_buffer import FlatReplayBuffer
-
-try:
- import torch
-except ImportError as e:
- raise ImportError(
- "Could not import torch, this is not bundled as part of Wingman and has to be installed manually."
- ) from e
-
-_NestedDict = Mapping[str, Union[int, "_NestedDict"]]
-
-
-class DictReplayBufferWrapper(ReplayBufferWrapper):
- """Replay Buffer Wrapper that allows the underlying replay buffer to take in nested dicts."""
-
- def __init__(self, replay_buffer: FlatReplayBuffer) -> None:
- """__init__.
-
- If bulk adding items, this expects dictionary items to be a dictionary of lists, NOT a list of dictionaries.
-
- Args:
- ----
- self:
- replay_buffer (FlatReplayBuffer): replay_buffer
-
- Returns:
- -------
- None:
-
- """
- super().__init__(replay_buffer)
-
- # this is a list where:
- # - if the element is an integer, specifies the location in the unwrapped flat data that the wrapped element in this index should be
- # - if it's a dictionary, specifies that the data at this index in the wrapped data should be a dictionary, where
- # - if the value is an int, is the location of this data in the flat dict
- # - if the value is a dict, means that this data holds a nested dict
- self.mapping: list[int | _NestedDict] = []
- self.total_elements = 0
-
- @staticmethod
- def _recursive_unpack_dict_mapping(
- data_dict: dict[str, np.ndarray | torch.Tensor],
- start_idx: int,
- ) -> tuple[_NestedDict, int]:
- """Recursively unpacks a dictionary into a mapping.
-
- Args:
- ----
- data_dict (dict[str, np.ndarray | torch.Tensor]): data_dict
- start_idx (int): start_idx
-
- Returns:
- -------
- tuple[_NestedDict, int]:
-
- """
- mapping: _NestedDict = dict()
- idx = start_idx
-
- for key, value in data_dict.items():
- if isinstance(value, dict):
- mapping[key], idx = (
- DictReplayBufferWrapper._recursive_unpack_dict_mapping(
- value, start_idx=idx
- )
- )
- else:
- mapping[key] = idx
- idx += 1
-
- return mapping, idx
-
- @staticmethod
- def _generate_mapping(
- wrapped_data: Sequence[
- dict[str, np.ndarray | torch.Tensor]
- | np.ndarray
- | torch.Tensor
- | float
- | int
- | bool
- ],
- ) -> tuple[list[int | _NestedDict], int]:
- """Generates a mapping from wrapped data.
-
- For example:
- data = [
- 32,
- 65,
- {
- "a": 5,
- "b": {
- "c": 6,
- "d": 7,
- "e": {
- "f": 8
- }
- },
- "g": {
- "h": 9
- }
- },
- 100,
- ]
-
- becomes:
-
- [
- 0,
- 1,
- {
- 'a': 2,
- 'b': {
- 'c': 3,
- 'd': 4,
- 'e': {'f': 5}
- },
- 'g': {'h': 6}
- },
- 7
- ]
-
- Args:
- ----
- wrapped_data (Sequence[dict[str, np.ndarray | torch.Tensor] | np.ndarray | torch.Tensor | float | int | bool]): wrapped_data
-
- Returns:
- -------
- tuple[list[int | _NestedDict], int]:
-
- """
- mapping: list[int | _NestedDict] = []
- idx = 0
-
- for item in wrapped_data:
- if isinstance(item, dict):
- dict_mapping, idx = (
- DictReplayBufferWrapper._recursive_unpack_dict_mapping(
- item, start_idx=idx
- )
- )
- mapping.append(dict_mapping)
- else:
- mapping.append(idx)
- idx += 1
-
- return mapping, idx
-
- @staticmethod
- def _recursive_unpack_dict_data(
- data_dict: dict[str, np.ndarray | torch.Tensor],
- unwrapped_data_target: list[Any],
- mapping: _NestedDict,
- ) -> list[Any]:
- """Recursively unpacks dictionary data into a sequence of items that FlatReplayBuffer can use.
-
- Args:
- ----
- data_dict (dict[str, np.ndarray | torch.Tensor]): data_dict
- unwrapped_data_target (list[Any]): unwrapped_data_target
- mapping (_NestedDict): mapping
-
- Returns:
- -------
- list[Any]:
-
- """
- for key, value in data_dict.items():
- if isinstance((idx_map := mapping[key]), int):
- unwrapped_data_target[idx_map] = value
-
- elif isinstance((idx_map := mapping[key]), dict):
- if not isinstance(value, dict):
- raise ReplayBufferException(
- "Something went wrong with data unwrapping.\n"
- f"Expected a dictionary for key {key} within the data, but got {type(value)}."
- )
-
- unwrapped_data_target = (
- DictReplayBufferWrapper._recursive_unpack_dict_data(
- data_dict=value,
- unwrapped_data_target=unwrapped_data_target,
- mapping=idx_map,
- )
- )
- else:
- raise ValueError("Not supposed to be here")
-
- return unwrapped_data_target
-
- def unwrap_data(
- self,
- wrapped_data: Sequence[
- dict[str, Any] | np.ndarray | torch.Tensor | float | int | bool
- ],
- bulk: bool,
- ) -> Sequence[np.ndarray | torch.Tensor | float | int | bool]:
- """Unwraps dictionary data into a sequence of items that FlatReplayBuffer can use.
-
- If bulk adding items, this expects dictionary items to be a dictionary of lists, NOT a list of dictionaries.
-
- Args:
- ----
- wrapped_data (Sequence[dict[str, Any] | np.ndarray | torch.Tensor | float | int | bool]): wrapped_data
- bulk (bool): bulk
-
- Returns:
- -------
- Sequence[np.ndarray | torch.Tensor | float | int | bool]:
-
- """
- if len(self) == 0:
- self.mapping, self.total_elements = self._generate_mapping(
- wrapped_data=wrapped_data
- )
-
- if len(self.mapping) != len(wrapped_data):
- raise ReplayBufferException(
- "Something went wrong with data unwrapping.\n"
- f"Expected `wrapped_data` to have {len(self.mapping)} items, but got {len(wrapped_data)}."
- )
-
- # holder for the unwrapped data
- unwrapped_data: list[Any] = [None] * self.total_elements
-
- for i, (mapping, data_item) in enumerate(zip(self.mapping, wrapped_data)):
- # if the mapping says it's an int, then just set the data without nesting
- if isinstance(mapping, int):
- unwrapped_data[mapping] = data_item
-
- # if it's a dict, then we need to recursively unpack
- elif isinstance(mapping, dict):
- if not isinstance(data_item, dict):
- raise ReplayBufferException(
- "Something went wrong with data unwrapping.\n"
- f"Expected `wrapped_data` at element {i} to be a dict, but got {type(data_item)}."
- )
-
- unwrapped_data = self._recursive_unpack_dict_data(
- data_dict=data_item,
- unwrapped_data_target=unwrapped_data,
- mapping=mapping,
- )
-
- return unwrapped_data
-
- @staticmethod
- def _recursive_pack_dict_data(
- unwrapped_data: Sequence[Any],
- mapping: _NestedDict,
- ) -> dict[str, Any]:
- """Packs back a sequence of items into a dictionary structure according to a mapping.
-
- Args:
- ----
- unwrapped_data (Sequence[Any]): unwrapped_data
- mapping (_NestedDict): mapping
-
- Returns:
- -------
- dict[str, Any]:
-
- """
- data_dict = dict()
- for key, idx_map in mapping.items():
- if isinstance(idx_map, int):
- data_dict[key] = unwrapped_data[idx_map]
- elif isinstance(idx_map, dict):
- data_dict[key] = DictReplayBufferWrapper._recursive_pack_dict_data(
- unwrapped_data=unwrapped_data,
- mapping=idx_map,
- )
- else:
- raise ValueError("Not supposed to be here")
-
- return data_dict
-
- def wrap_data(
- self, unwrapped_data: Sequence[np.ndarray | torch.Tensor | float | int | bool]
- ) -> Sequence[
- dict[str, np.ndarray | torch.Tensor]
- | np.ndarray
- | torch.Tensor
- | float
- | int
- | bool
- ]:
- """Converts a sequence of items into a dictionary structure that is similar to that used during `push`.
-
- Args:
- ----
- self:
- unwrapped_data (Sequence[np.ndarray | torch.Tensor | float | int | bool]): unwrapped_data
-
- Returns:
- -------
- Sequence[dict[str, np.ndarray | torch.Tensor] | np.ndarray | torch.Tensor | float | int | bool]:
-
- """
- wrapped_data: list[Any] = [None] * len(self.mapping)
-
- for i, idx_map in enumerate(self.mapping):
- if isinstance(idx_map, int):
- wrapped_data[i] = unwrapped_data[idx_map]
- elif isinstance(idx_map, dict):
- wrapped_data[i] = DictReplayBufferWrapper._recursive_pack_dict_data(
- unwrapped_data=unwrapped_data,
- mapping=idx_map,
- )
- else:
- raise ValueError("Not supposed to be here")
-
- return wrapped_data