Skip to content

Commit a6672ba

Browse files
Add __get_attr__ for experimental wrappers for generic solution to optimise extra module imports (#392)
1 parent 24a5518 commit a6672ba

13 files changed

+389
-378
lines changed

gymnasium/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Root `__init__` of the gymnasium module setting the `__all__` of gymnasium modules."""
22
# isort: skip_file
3+
# pyright: reportUnsupportedDunderAll=false
34

45
from gymnasium.core import (
56
Env,
@@ -17,7 +18,9 @@
1718
pprint_registry,
1819
make_vec,
1920
)
20-
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
21+
22+
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
23+
from gymnasium import envs
2124

2225

2326
__all__ = [
@@ -37,6 +40,7 @@
3740
"pprint_registry",
3841
# module folders
3942
"envs",
43+
"experimental",
4044
"spaces",
4145
"utils",
4246
"vector",

gymnasium/experimental/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Root __init__ of the gym experimental wrappers."""
22

33

4-
from gymnasium.experimental import functional
4+
from gymnasium.experimental import functional, wrappers
55
from gymnasium.experimental.functional import FuncEnv
66
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
77
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
@@ -17,4 +17,6 @@
1717
"VectorWrapper",
1818
"SyncVectorEnv",
1919
"AsyncVectorEnv",
20+
# wrappers
21+
"wrappers",
2022
]

gymnasium/experimental/functional_jax_env.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import gymnasium as gym
1212
from gymnasium.envs.registration import EnvSpec
1313
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
14-
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
14+
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
1515
from gymnasium.utils import seeding
1616
from gymnasium.vector.utils import batch_space
1717

+80-49
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,11 @@
1-
"""Experimental Wrappers."""
2-
# isort: skip_file
1+
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
2+
# pyright: reportUnsupportedDunderAll=false
33

4-
from gymnasium.experimental.wrappers.lambda_action import (
5-
LambdaActionV0,
6-
ClipActionV0,
7-
RescaleActionV0,
8-
)
9-
from gymnasium.experimental.wrappers.lambda_observations import (
10-
LambdaObservationV0,
11-
FilterObservationV0,
12-
FlattenObservationV0,
13-
GrayscaleObservationV0,
14-
ResizeObservationV0,
15-
ReshapeObservationV0,
16-
RescaleObservationV0,
17-
DtypeObservationV0,
18-
PixelObservationV0,
19-
NormalizeObservationV0,
20-
)
21-
from gymnasium.experimental.wrappers.lambda_reward import (
22-
ClipRewardV0,
23-
LambdaRewardV0,
24-
NormalizeRewardV0,
25-
)
26-
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
27-
from gymnasium.experimental.wrappers.stateful_observation import (
28-
TimeAwareObservationV0,
29-
DelayObservationV0,
30-
FrameStackObservationV0,
31-
)
32-
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
33-
from gymnasium.experimental.wrappers.common import (
34-
PassiveEnvCheckerV0,
35-
OrderEnforcingV0,
36-
AutoresetV0,
37-
RecordEpisodeStatisticsV0,
38-
)
39-
from gymnasium.experimental.wrappers.rendering import (
40-
RenderCollectionV0,
41-
RecordVideoV0,
42-
HumanRenderingV0,
43-
)
4+
import importlib
445

45-
from gymnasium.experimental.wrappers.vector import (
46-
VectorRecordEpisodeStatistics,
47-
VectorListInfo,
48-
)
496

507
__all__ = [
8+
"vector",
519
# --- Observation wrappers ---
5210
"LambdaObservationV0",
5311
"FilterObservationV0",
@@ -82,7 +40,80 @@
8240
"RenderCollectionV0",
8341
"RecordVideoV0",
8442
"HumanRenderingV0",
85-
# --- Vector ---
86-
"VectorRecordEpisodeStatistics",
87-
"VectorListInfo",
43+
# --- Conversion ---
44+
"JaxToNumpyV0",
45+
"JaxToTorchV0",
46+
"NumpyToTorchV0",
8847
]
48+
49+
50+
_wrapper_to_class = {
51+
# lambda_action.py
52+
"LambdaActionV0": "lambda_action",
53+
"ClipActionV0": "lambda_action",
54+
"RescaleActionV0": "lambda_action",
55+
# lambda_observations.py
56+
"LambdaObservationV0": "lambda_observations",
57+
"FilterObservationV0": "lambda_observations",
58+
"FlattenObservationV0": "lambda_observations",
59+
"GrayscaleObservationV0": "lambda_observations",
60+
"ResizeObservationV0": "lambda_observations",
61+
"ReshapeObservationV0": "lambda_observations",
62+
"RescaleObservationV0": "lambda_observations",
63+
"DtypeObservationV0": "lambda_observations",
64+
"PixelObservationV0": "lambda_observations",
65+
"NormalizeObservationV0": "lambda_observations",
66+
# lambda_reward.py
67+
"ClipRewardV0": "lambda_reward",
68+
"LambdaRewardV0": "lambda_reward",
69+
"NormalizeRewardV0": "lambda_reward",
70+
# stateful_action
71+
"StickyActionV0": "stateful_action",
72+
# stateful_observation
73+
"TimeAwareObservationV0": "stateful_observation",
74+
"DelayObservationV0": "stateful_observation",
75+
"FrameStackObservationV0": "stateful_observation",
76+
# atari_preprocessing
77+
"AtariPreprocessingV0": "atari_preprocessing",
78+
# common
79+
"PassiveEnvCheckerV0": "common",
80+
"OrderEnforcingV0": "common",
81+
"AutoresetV0": "common",
82+
"RecordEpisodeStatisticsV0": "common",
83+
# rendering
84+
"RenderCollectionV0": "rendering",
85+
"RecordVideoV0": "rendering",
86+
"HumanRenderingV0": "rendering",
87+
# jax_to_numpy
88+
"JaxToNumpyV0": "jax_to_numpy",
89+
# "jax_to_numpy": "jax_to_numpy",
90+
# "numpy_to_jax": "jax_to_numpy",
91+
# jax_to_torch
92+
"JaxToTorchV0": "jax_to_torch",
93+
# "jax_to_torch": "jax_to_torch",
94+
# "torch_to_jax": "jax_to_torch",
95+
# numpy_to_torch
96+
"NumpyToTorchV0": "numpy_to_torch",
97+
# "torch_to_numpy": "numpy_to_torch",
98+
# "numpy_to_torch": "numpy_to_torch",
99+
}
100+
101+
102+
def __getattr__(name: str):
103+
"""To avoid having to load all wrappers on `import gymnasium` with all of their extra modules.
104+
105+
This optimises the loading of gymnasium.
106+
107+
Args:
108+
name: The name of a wrapper to load
109+
110+
Returns:
111+
Wrapper
112+
"""
113+
if name in _wrapper_to_class:
114+
import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}"
115+
module = importlib.import_module(import_stmt)
116+
return getattr(module, name)
117+
# add helpful error message if version number has changed
118+
else:
119+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

gymnasium/experimental/wrappers/atari_preprocessing.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
try:
99
import cv2
1010
except ImportError:
11-
cv2 = None
11+
raise gym.error.DependencyNotInstalled(
12+
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
13+
)
1214

1315

1416
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
@@ -72,10 +74,6 @@ def __init__(
7274
)
7375
gym.Wrapper.__init__(self, env)
7476

75-
if cv2 is None:
76-
raise gym.error.DependencyNotInstalled(
77-
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
78-
)
7977
assert frame_skip > 0
8078
assert screen_size > 0
8179
assert noop_max >= 0

gymnasium/experimental/wrappers/conversion/__init__.py

-1
This file was deleted.

0 commit comments

Comments
 (0)