|
22 | 22 | from torchrl.trainers.algorithms.ddpg import DDPGTrainer |
23 | 23 | from torchrl.trainers.algorithms.dqn import DQNTrainer |
24 | 24 | from torchrl.trainers.algorithms.iql import IQLTrainer |
| 25 | +from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer |
25 | 26 | from torchrl.trainers.algorithms.ppo import PPOTrainer |
26 | 27 | from torchrl.trainers.algorithms.sac import SACTrainer |
27 | 28 | from torchrl.trainers.algorithms.td3 import TD3Trainer |
@@ -218,6 +219,147 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: |
218 | 219 | return trainer |
219 | 220 |
|
220 | 221 |
|
@dataclass
class OfflineToOnlineTrainerConfig(SACTrainerConfig):
    """Hydra configuration for :class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`.

    Inherits every field of :class:`SACTrainerConfig` and adds
    ``anneal_frames``. The ``_target_`` field holds the dotted path of the
    ``_make_offline_to_online_trainer`` factory.
    """

    # Optional frame count, popped by the factory and forwarded to the trainer
    # as ``anneal_frames``; ``None`` is passed through unchanged.
    anneal_frames: int | None = None

    # Dotted import path of the factory used to build the trainer.
    _target_: str = (
        "torchrl.trainers.algorithms.configs.trainers"
        "._make_offline_to_online_trainer"
    )

    def __post_init__(self) -> None:
        """Run the parent configuration's post-initialization hook."""
        super().__post_init__()
| 242 | + |
| 243 | + |
def _make_offline_to_online_trainer(*args, **kwargs) -> OfflineToOnlineTrainer:
    """Build an :class:`OfflineToOnlineTrainer` from keyword arguments.

    Components may be passed either as ready instances or as zero/partial
    callables; callables are invoked here in dependency order (networks,
    collector, loss module, target-net updater, optimizer).

    Positional ``*args`` are ignored, and any keyword not listed below is
    silently dropped.

    Keyword Args:
        collector: A ``BaseCollector`` or a callable returning one.
        total_frames: Frame budget; when ``None``, ``collector.total_frames``
            of the *instantiated* collector is used.
        loss_module: A ``LossModule`` or a callable accepting
            ``actor_network`` and ``critic_network`` keyword arguments.
        optimizer: A ``torch.optim.Optimizer`` or a callable accepting
            ``params``.
        logger: A ``Logger`` instance or ``None``.
        clip_norm, replay_buffer, save_trainer_file, seed: Forwarded to the
            trainer unchanged (required keys).
        actor_network, critic_network: ``torch.nn.Module`` instances,
            callables returning one, or ``None``.
        create_env_fn: Required key; popped and discarded.
        target_net_updater: A ``TargetNetUpdater``, a callable taking the
            loss module, or ``None``.
        frame_skip (default 1), optim_steps_per_batch (default 1),
        clip_grad_norm (default True), progress_bar (default True),
        save_trainer_interval (default 10000), log_interval (default 10000),
        log_timings (default False), auto_log_optim_steps (default True),
        batch_size (default None), anneal_frames (default None),
        enable_logging (default True), log_rewards (default True),
        log_actions (default True), log_observations (default False):
            Forwarded to the trainer.
        async_collection (default False): Must be falsy.
        done_key, terminated_key, reward_key, episode_reward_key, action_key,
        observation_key: Passed through ``_normalize_hydra_key`` before
            being forwarded.
        hooks (default None): Handed to ``_register_trainer_hooks``.

    Returns:
        The constructed trainer, after hook registration.

    Raises:
        KeyError: If a required keyword is missing.
        ValueError: If ``async_collection`` is truthy, or if the collector,
            loss module, optimizer or logger has the wrong type after
            instantiation.
    """
    from torchrl.trainers.trainers import Logger

    collector = kwargs.pop("collector")
    # NOTE: a ``None`` total_frames is resolved further down, once the
    # collector is guaranteed to be an instance rather than a factory.
    total_frames = kwargs.pop("total_frames")
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm")
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed")
    actor_network = kwargs.pop("actor_network")
    critic_network = kwargs.pop("critic_network")
    # Required in the config but not consumed by this factory.
    kwargs.pop("create_env_fn")
    target_net_updater = kwargs.pop("target_net_updater")
    async_collection = kwargs.pop("async_collection", False)
    if async_collection:
        raise ValueError(
            "OfflineToOnlineTrainer does not support async_collection."
        )
    log_timings = kwargs.pop("log_timings", False)
    auto_log_optim_steps = kwargs.pop("auto_log_optim_steps", True)
    batch_size = kwargs.pop("batch_size", None)
    anneal_frames = kwargs.pop("anneal_frames", None)
    enable_logging = kwargs.pop("enable_logging", True)
    log_rewards = kwargs.pop("log_rewards", True)
    log_actions = kwargs.pop("log_actions", True)
    log_observations = kwargs.pop("log_observations", False)
    done_key = _normalize_hydra_key(kwargs.pop("done_key", "done"))
    terminated_key = _normalize_hydra_key(kwargs.pop("terminated_key", "terminated"))
    reward_key = _normalize_hydra_key(kwargs.pop("reward_key", "reward"))
    episode_reward_key = _normalize_hydra_key(
        kwargs.pop("episode_reward_key", "reward_sum")
    )
    action_key = _normalize_hydra_key(kwargs.pop("action_key", "action"))
    observation_key = _normalize_hydra_key(kwargs.pop("observation_key", "observation"))
    hooks = kwargs.pop("hooks", None)

    # Instantiate networks first: the loss module factory needs them.
    if actor_network is not None and not isinstance(actor_network, torch.nn.Module):
        actor_network = actor_network()
    if critic_network is not None and not isinstance(critic_network, torch.nn.Module):
        critic_network = critic_network()

    if not isinstance(collector, BaseCollector):
        # then it's a partial config
        collector = collector()

    if not isinstance(loss_module, LossModule):
        # then it's a partial config
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if target_net_updater is not None and not isinstance(
        target_net_updater, TargetNetUpdater
    ):
        # target_net_updater must be a partial taking the loss as input
        target_net_updater = target_net_updater(loss_module)
    if not isinstance(optimizer, torch.optim.Optimizer):
        # then it's a partial config
        optimizer = optimizer(params=loss_module.parameters())

    # Quick instance checks
    if not isinstance(collector, BaseCollector):
        raise ValueError(f"collector must be a BaseCollector, got {type(collector)}")
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )
    if not isinstance(logger, Logger) and logger is not None:
        raise ValueError(f"logger must be a Logger, got {type(logger)}")

    # Resolve the fallback only now: reading ``total_frames`` off a
    # not-yet-instantiated collector factory would raise AttributeError.
    if total_frames is None:
        total_frames = collector.total_frames

    trainer = OfflineToOnlineTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        replay_buffer=replay_buffer,
        anneal_frames=anneal_frames,
        batch_size=batch_size,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        enable_logging=enable_logging,
        log_rewards=log_rewards,
        log_actions=log_actions,
        log_observations=log_observations,
        target_net_updater=target_net_updater,
        async_collection=async_collection,
        log_timings=log_timings,
        auto_log_optim_steps=auto_log_optim_steps,
        done_key=done_key,
        terminated_key=terminated_key,
        reward_key=reward_key,
        episode_reward_key=episode_reward_key,
        action_key=action_key,
        observation_key=observation_key,
    )
    _register_trainer_hooks(trainer, hooks)
    return trainer
| 361 | + |
| 362 | + |
221 | 363 | @dataclass |
222 | 364 | class PPOTrainerConfig(TrainerConfig): |
223 | 365 | """Hydra configuration for :class:`~torchrl.trainers.algorithms.PPOTrainer`. |
|
0 commit comments