Skip to content

Commit ef54ee5

Browse files
authored
[RLlib] Cleanup examples folder #10: Add custom_rl_module.py example script and matching RLModule example class (tiny CNN). (ray-project#45774)
1 parent 641f0fa commit ef54ee5

File tree

12 files changed

+362
-108
lines changed

12 files changed

+362
-108
lines changed

rllib/BUILD

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,7 +2119,6 @@ py_test(
21192119

21202120
# subdirectory: checkpoints/
21212121
# ....................................
2122-
21232122
py_test(
21242123
name = "examples/checkpoints/checkpoint_by_custom_criteria",
21252124
main = "examples/checkpoints/checkpoint_by_custom_criteria.py",
@@ -2283,7 +2282,6 @@ py_test(
22832282

22842283
# subdirectory: curriculum/
22852284
# ....................................
2286-
22872285
py_test(
22882286
name = "examples/curriculum/curriculum_learning",
22892287
main = "examples/curriculum/curriculum_learning.py",
@@ -2295,7 +2293,6 @@ py_test(
22952293

22962294
# subdirectory: debugging/
22972295
# ....................................
2298-
22992296
#@OldAPIStack
23002297
py_test(
23012298
name = "examples/debugging/deterministic_training_torch",
@@ -2308,7 +2305,6 @@ py_test(
23082305

23092306
# subdirectory: envs/
23102307
# ....................................
2311-
23122308
py_test(
23132309
name = "examples/envs/custom_gym_env",
23142310
main = "examples/envs/custom_gym_env.py",
@@ -2449,7 +2445,6 @@ py_test(
24492445

24502446
# subdirectory: gpus/
24512447
# ....................................
2452-
24532448
py_test(
24542449
name = "examples/gpus/fractional_0.5_gpus_per_learner",
24552450
main = "examples/gpus/fractional_gpus_per_learner.py",
@@ -2469,7 +2464,6 @@ py_test(
24692464

24702465
# subdirectory: hierarchical/
24712466
# ....................................
2472-
24732467
#@OldAPIStack
24742468
py_test(
24752469
name = "examples/hierarchical/hierarchical_training_tf",
@@ -2492,7 +2486,6 @@ py_test(
24922486

24932487
# subdirectory: inference/
24942488
# ....................................
2495-
24962489
#@OldAPIStack
24972490
py_test(
24982491
name = "examples/inference/policy_inference_after_training_tf",
@@ -2905,6 +2898,15 @@ py_test(
29052898

29062899
# subdirectory: rl_modules/
29072900
# ....................................
2901+
py_test(
2902+
name = "examples/rl_modules/custom_rl_module",
2903+
main = "examples/rl_modules/custom_rl_module.py",
2904+
tags = ["team:rllib", "examples"],
2905+
size = "medium",
2906+
srcs = ["examples/rl_modules/custom_rl_module.py"],
2907+
args = ["--enable-new-api-stack", "--stop-iters=3"],
2908+
)
2909+
29082910
#@OldAPIStack @HybridAPIStack
29092911
py_test(
29102912
name = "examples/rl_modules/classes/mobilenet_rlm_hybrid_api_stack",

rllib/algorithms/ppo/torch/ppo_torch_rl_module.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def setup(self):
2222
super().setup()
2323

2424
# If not an inference-only module (e.g., for evaluation), set up the
25-
# parameter names to be removed or renamed when syncing from the state dict
26-
# when synching.
25+
# parameter names to be removed or renamed when syncing from the state dict.
2726
if not self.inference_only:
2827
# Set the expected and unexpected keys for the inference-only module.
2928
self._set_inference_only_state_dict_keys()

rllib/core/rl_module/rl_module.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import datetime
33
import json
44
import pathlib
5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
66
from typing import Mapping, Any, TYPE_CHECKING, Optional, Type, Dict, Union
77

88
import gymnasium as gym
@@ -203,7 +203,7 @@ class RLModuleConfig:
203203

204204
observation_space: gym.Space = None
205205
action_space: gym.Space = None
206-
model_config_dict: Dict[str, Any] = None
206+
model_config_dict: Dict[str, Any] = field(default_factory=dict)
207207
catalog_class: Type["Catalog"] = None
208208

209209
def get_catalog(self) -> "Catalog":
@@ -456,22 +456,23 @@ def setup(self):
456456
457457
This is called automatically during the __init__ method of this class,
458458
therefore, the subclass should call super().__init__() in its constructor. This
459-
abstraction can be used to create any component that your RLModule needs.
459+
abstraction can be used to create any components (e.g. NN layers) that your
460+
RLModule needs.
460461
"""
461462
return None
462463

463464
@OverrideToImplementCustomLogic
464465
def get_train_action_dist_cls(self) -> Type[Distribution]:
465466
"""Returns the action distribution class for this RLModule used for training.
466467
467-
This class is used to create action distributions from outputs of the
468-
forward_train method. If the case that no action distribution class is needed,
468+
This class is used to get the correct action distribution class to be used by
469+
the training components. If no action distribution class is needed,
469470
this method can return None.
470471
471472
Note that RLlib's distribution classes all implement the `Distribution`
472473
interface. This requires two special methods: `Distribution.from_logits()` and
473-
`Distribution.to_deterministic()`. See the documentation for `Distribution`
474-
for more detail.
474+
`Distribution.to_deterministic()`. See the documentation of the
475+
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
475476
"""
476477
raise NotImplementedError
477478

@@ -485,8 +486,8 @@ def get_exploration_action_dist_cls(self) -> Type[Distribution]:
485486
486487
Note that RLlib's distribution classes all implement the `Distribution`
487488
interface. This requires two special methods: `Distribution.from_logits()` and
488-
`Distribution.to_deterministic()`. See the documentation for `Distribution`
489-
for more detail.
489+
`Distribution.to_deterministic()`. See the documentation of the
490+
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
490491
"""
491492
raise NotImplementedError
492493

@@ -500,8 +501,8 @@ def get_inference_action_dist_cls(self) -> Type[Distribution]:
500501
501502
Note that RLlib's distribution classes all implement the `Distribution`
502503
interface. This requires two special methods: `Distribution.from_logits()` and
503-
`Distribution.to_deterministic()`. See the documentation for `Distribution`
504-
for more detail.
504+
`Distribution.to_deterministic()`. See the documentation of the
505+
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
505506
"""
506507
raise NotImplementedError
507508

@@ -596,9 +597,7 @@ def output_specs_inference(self) -> SpecType:
596597
a dict that has `action_dist` key and its value is an instance of
597598
`Distribution`.
598599
"""
599-
# TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this
600-
# is what most algos will do.
601-
return {"action_dist": Distribution}
600+
return [Columns.ACTION_DIST_INPUTS]
602601

603602
@OverrideToImplementCustomLogic_CallToSuperRecommended
604603
def output_specs_exploration(self) -> SpecType:
@@ -609,9 +608,7 @@ def output_specs_exploration(self) -> SpecType:
609608
a dict that has `action_dist` key and its value is an instance of
610609
`Distribution`.
611610
"""
612-
# TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this
613-
# is what most algos will do.
614-
return {"action_dist": Distribution}
611+
return [Columns.ACTION_DIST_INPUTS]
615612

616613
def output_specs_train(self) -> SpecType:
617614
"""Returns the output specs of the forward_train method."""

rllib/core/rl_module/torch/torch_rl_module.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,47 +21,6 @@
2121
torch, nn = try_import_torch()
2222

2323

24-
def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig):
25-
"""A wrapper that compiles the forward methods of a TorchRLModule."""
26-
27-
# TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0
28-
# Check if torch framework supports torch.compile.
29-
if (
30-
torch is not None
31-
and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
32-
):
33-
raise ValueError("torch.compile is only supported from torch 2.0.0")
34-
35-
compiled_forward_train = torch.compile(
36-
rl_module._forward_train,
37-
backend=compile_config.torch_dynamo_backend,
38-
mode=compile_config.torch_dynamo_mode,
39-
**compile_config.kwargs
40-
)
41-
42-
rl_module._forward_train = compiled_forward_train
43-
44-
compiled_forward_inference = torch.compile(
45-
rl_module._forward_inference,
46-
backend=compile_config.torch_dynamo_backend,
47-
mode=compile_config.torch_dynamo_mode,
48-
**compile_config.kwargs
49-
)
50-
51-
rl_module._forward_inference = compiled_forward_inference
52-
53-
compiled_forward_exploration = torch.compile(
54-
rl_module._forward_exploration,
55-
backend=compile_config.torch_dynamo_backend,
56-
mode=compile_config.torch_dynamo_mode,
57-
**compile_config.kwargs
58-
)
59-
60-
rl_module._forward_exploration = compiled_forward_exploration
61-
62-
return rl_module
63-
64-
6524
class TorchRLModule(nn.Module, RLModule):
6625
"""A base class for RLlib PyTorch RLModules.
6726
@@ -234,3 +193,44 @@ class TorchDDPRLModuleWithTargetNetworksInterface(
234193
@override(RLModuleWithTargetNetworksInterface)
235194
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
236195
return self.module.get_target_network_pairs()
196+
197+
198+
def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig):
199+
"""A wrapper that compiles the forward methods of a TorchRLModule."""
200+
201+
# TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0
202+
# Check if torch framework supports torch.compile.
203+
if (
204+
torch is not None
205+
and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
206+
):
207+
raise ValueError("torch.compile is only supported from torch 2.0.0")
208+
209+
compiled_forward_train = torch.compile(
210+
rl_module._forward_train,
211+
backend=compile_config.torch_dynamo_backend,
212+
mode=compile_config.torch_dynamo_mode,
213+
**compile_config.kwargs,
214+
)
215+
216+
rl_module._forward_train = compiled_forward_train
217+
218+
compiled_forward_inference = torch.compile(
219+
rl_module._forward_inference,
220+
backend=compile_config.torch_dynamo_backend,
221+
mode=compile_config.torch_dynamo_mode,
222+
**compile_config.kwargs,
223+
)
224+
225+
rl_module._forward_inference = compiled_forward_inference
226+
227+
compiled_forward_exploration = torch.compile(
228+
rl_module._forward_exploration,
229+
backend=compile_config.torch_dynamo_backend,
230+
mode=compile_config.torch_dynamo_mode,
231+
**compile_config.kwargs,
232+
)
233+
234+
rl_module._forward_exploration = compiled_forward_exploration
235+
236+
return rl_module

rllib/env/single_agent_env_runner.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,9 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
9191
try:
9292
module_spec: SingleAgentRLModuleSpec = self.config.rl_module_spec
9393
module_spec.observation_space = self._env_to_module.observation_space
94-
# TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should
95-
# actually hold the spaces for a single env, but for boxes the
96-
# shape is (1, 1) which brings a problem with the action dists.
97-
# shape=(1,) is expected.
9894
module_spec.action_space = self.env.envs[0].action_space
99-
module_spec.model_config_dict = self.config.model_config
95+
if module_spec.model_config_dict is None:
96+
module_spec.model_config_dict = self.config.model_config
10097
# Only load a light version of the module, if available. This is useful
10198
# if the module has target or critic networks not needed in sampling
10299
# or inference.

rllib/examples/rl_modules/action_masking_rlm.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)