Skip to content

Commit 583f6a7

Browse files
authored
Refer to numpy instead of jax [for vars and docs] in vector NumpyToTorch (#1319)
1 parent 9ff8bf4 commit 583f6a7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

Diff for: gymnasium/wrappers/vector/numpy_to_torch.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, env: VectorEnv, device: Device | None = None):
4242
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
4343
4444
Args:
45-
env: The Jax-based vector environment to wrap
45+
env: The NumPy-based vector environment to wrap
4646
device: The device the torch Tensors should be moved to
4747
"""
4848
super().__init__(env)
@@ -60,8 +60,8 @@ def step(
6060
Returns:
6161
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
6262
"""
63-
jax_action = torch_to_numpy(actions)
64-
obs, reward, terminated, truncated, info = self.env.step(jax_action)
63+
numpy_action = torch_to_numpy(actions)
64+
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
6565

6666
return (
6767
numpy_to_torch(obs, self.device),
@@ -81,7 +81,7 @@ def reset(
8181
8282
Args:
8383
seed: The seed for resetting the environment
84-
options: The options for resetting the environment, these are converted to jax arrays.
84+
options: The options for resetting the environment, these are converted to NumPy arrays.
8585
8686
Returns:
8787
PyTorch-based observations and info

0 commit comments

Comments
 (0)