Skip to content

Commit

Permalink
documentation: replace "Jax" with "NumPy" and make capitalization uni…
Browse files Browse the repository at this point in the history
…form (#1206)
  • Loading branch information
enjoh authored Oct 12, 2024
1 parent f3fb8a5 commit d571ed6
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions gymnasium/wrappers/numpy_to_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,25 @@ def torch_to_numpy(value: Any) -> Any:

@torch_to_numpy.register(numbers.Number)
def _number_to_numpy(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a numpy array."""
"""Convert a python number (int, float, complex) to a NumPy array."""
return np.array(value)


@torch_to_numpy.register(torch.Tensor)
def _torch_to_numpy(value: torch.Tensor) -> Any:
"""Convert a torch.Tensor to a numpy array."""
"""Convert a torch.Tensor to a NumPy array."""
return value.numpy(force=True)


@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
"""Converts a mapping of PyTorch Tensors into a Dictionary of NumPy Array."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})


@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
"""Converts an Iterable from PyTorch Tensors to an iterable of NumPy Array."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
Expand All @@ -66,7 +66,7 @@ def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:

@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax Array into a PyTorch Tensor."""
"""Converts a NumPy Array into a PyTorch Tensor."""
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
)
Expand All @@ -75,7 +75,7 @@ def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
@numpy_to_torch.register(numbers.Number)
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
"""Converts a Jax Array into a PyTorch Tensor."""
"""Converts a NumPy Array into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
Expand All @@ -87,15 +87,15 @@ def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Te
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
"""Converts a mapping of NumPy Array into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})


@numpy_to_torch.register(abc.Iterable)
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
"""Converts an Iterable from NumPy Array to an iterable of PyTorch Tensors."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
Expand Down Expand Up @@ -140,7 +140,7 @@ def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The Jax-based environment to wrap
env: The NumPy-based environment to wrap
device: The device the torch Tensors should be moved to
"""
gym.utils.RecordConstructorArgs.__init__(self, device=device)
Expand Down

0 comments on commit d571ed6

Please sign in to comment.