Skip to content

Commit 440aa81

Browse files
authored
[RLlib] Cleanup examples folder #14: Add example script for how to resume a tune.Tuner.fit() experiment from a checkpoint. (ray-project#45681)
1 parent 0e6864a commit 440aa81

File tree

11 files changed

+356
-46
lines changed

11 files changed

+356
-46
lines changed

rllib/BUILD

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,15 +2120,6 @@ py_test(
21202120
# subdirectory: checkpoints/
21212121
# ....................................
21222122

2123-
#@OldAPIStack
2124-
py_test(
2125-
name = "examples/checkpoints/cartpole_dqn_export",
2126-
main = "examples/checkpoints/cartpole_dqn_export.py",
2127-
tags = ["team:rllib", "exclusive", "examples"],
2128-
size = "small",
2129-
srcs = ["examples/checkpoints/cartpole_dqn_export.py"],
2130-
)
2131-
21322123
py_test(
21332124
name = "examples/checkpoints/checkpoint_by_custom_criteria",
21342125
main = "examples/checkpoints/checkpoint_by_custom_criteria.py",
@@ -2138,6 +2129,42 @@ py_test(
21382129
args = ["--enable-new-api-stack", "--stop-reward=150.0", "--num-cpus=8"]
21392130
)
21402131

2132+
py_test(
2133+
name = "examples/checkpoints/continue_training_from_checkpoint",
2134+
main = "examples/checkpoints/continue_training_from_checkpoint.py",
2135+
tags = ["team:rllib", "exclusive", "examples"],
2136+
size = "large",
2137+
srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
2138+
args = ["--enable-new-api-stack", "--as-test"]
2139+
)
2140+
2141+
py_test(
2142+
name = "examples/checkpoints/continue_training_from_checkpoint_multi_agent",
2143+
main = "examples/checkpoints/continue_training_from_checkpoint.py",
2144+
tags = ["team:rllib", "exclusive", "examples"],
2145+
size = "large",
2146+
srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
2147+
args = ["--enable-new-api-stack", "--as-test", "--num-agents=2", "--stop-reward-crash=400.0", "--stop-reward=900.0"]
2148+
)
2149+
2150+
#@OldAPIStack
2151+
py_test(
2152+
name = "examples/checkpoints/continue_training_from_checkpoint_old_api_stack",
2153+
main = "examples/checkpoints/continue_training_from_checkpoint.py",
2154+
tags = ["team:rllib", "exclusive", "examples"],
2155+
size = "large",
2156+
srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
2157+
args = ["--as-test"]
2158+
)
2159+
2160+
py_test(
2161+
name = "examples/checkpoints/cartpole_dqn_export",
2162+
main = "examples/checkpoints/cartpole_dqn_export.py",
2163+
tags = ["team:rllib", "exclusive", "examples"],
2164+
size = "small",
2165+
srcs = ["examples/checkpoints/cartpole_dqn_export.py"],
2166+
)
2167+
21412168
#@OldAPIStack
21422169
py_test(
21432170
name = "examples/checkpoints/onnx_tf2",

rllib/algorithms/dqn/dqn.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -630,22 +630,27 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
630630
)
631631

632632
self.metrics.log_dict(
633-
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}),
633+
self.metrics.peek(
634+
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED), default={}
635+
),
634636
key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
635637
reduce="sum",
636638
)
637639
self.metrics.log_value(
638640
NUM_ENV_STEPS_SAMPLED_LIFETIME,
639-
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED, default=0),
641+
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
640642
reduce="sum",
641643
)
642644
self.metrics.log_value(
643645
NUM_EPISODES_LIFETIME,
644-
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_EPISODES, default=0),
646+
self.metrics.peek((ENV_RUNNER_RESULTS, NUM_EPISODES), default=0),
645647
reduce="sum",
646648
)
647649
self.metrics.log_dict(
648-
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED, default={}),
650+
self.metrics.peek(
651+
(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED),
652+
default={},
653+
),
649654
key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
650655
reduce="sum",
651656
)
@@ -708,7 +713,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
708713
self.metrics.log_value(
709714
NUM_ENV_STEPS_TRAINED_LIFETIME,
710715
self.metrics.peek(
711-
LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED
716+
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
712717
),
713718
reduce="sum",
714719
)
@@ -725,7 +730,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
725730
# TODO (sven): Uncomment this once agent steps are available in the
726731
# Learner stats.
727732
# self.metrics.log_dict(self.metrics.peek(
728-
# LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED, default={}
733+
# (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED), default={}
729734
# ), key=NUM_AGENT_STEPS_TRAINED_LIFETIME, reduce="sum")
730735

731736
# Update replay buffer priorities.

rllib/algorithms/dreamerv3/dreamerv3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,13 +582,13 @@ def training_step(self) -> ResultDict:
582582
self.metrics.log_dict(
583583
{
584584
NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
585-
ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED
585+
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED)
586586
),
587587
NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
588-
ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED
588+
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED)
589589
),
590590
NUM_EPISODES_LIFETIME: self.metrics.peek(
591-
ENV_RUNNER_RESULTS, NUM_EPISODES
591+
(ENV_RUNNER_RESULTS, NUM_EPISODES)
592592
),
593593
},
594594
reduce="sum",

rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def compute_gradients(
158158
# Take individual loss term from the registered metrics for
159159
# the main module.
160160
self.metrics.peek(
161-
DEFAULT_MODULE_ID, component.upper() + "_L_total"
161+
(DEFAULT_MODULE_ID, component.upper() + "_L_total")
162162
),
163163
self.filter_param_dict_for_optimizer(
164164
self._params, self.get_optimizer(optimizer_name=component)

rllib/algorithms/dreamerv3/utils/summaries.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,7 @@ def report_dreamed_eval_trajectory_vs_samples(
217217
the report/videos.
218218
"""
219219
dream_data = metrics.peek(
220-
LEARNER_RESULTS,
221-
DEFAULT_MODULE_ID,
222-
"dream_data",
220+
(LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data"),
223221
default={},
224222
)
225223
metrics.delete(LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data", key_error=False)

rllib/algorithms/ppo/ppo.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,13 @@ def _training_step_new_api_stack(self) -> ResultDict:
463463
self.metrics.log_dict(
464464
{
465465
NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
466-
ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED
466+
(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED)
467467
),
468468
NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
469-
ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED
469+
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED)
470470
),
471471
NUM_EPISODES_LIFETIME: self.metrics.peek(
472-
ENV_RUNNER_RESULTS, NUM_EPISODES
472+
(ENV_RUNNER_RESULTS, NUM_EPISODES)
473473
),
474474
},
475475
reduce="sum",
@@ -494,10 +494,10 @@ def _training_step_new_api_stack(self) -> ResultDict:
494494
self.metrics.log_dict(
495495
{
496496
NUM_ENV_STEPS_TRAINED_LIFETIME: self.metrics.peek(
497-
LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED
497+
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
498498
),
499499
# NUM_MODULE_STEPS_TRAINED_LIFETIME: self.metrics.peek(
500-
# LEARNER_RESULTS, NUM_MODULE_STEPS_TRAINED
500+
# (LEARNER_RESULTS, NUM_MODULE_STEPS_TRAINED)
501501
# ),
502502
},
503503
reduce="sum",
@@ -531,7 +531,9 @@ def _training_step_new_api_stack(self) -> ResultDict:
531531
if self.config.use_kl_loss:
532532
for mid in modules_to_update:
533533
kl = convert_to_numpy(
534-
self.metrics.peek(LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
534+
self.metrics.peek(
535+
(LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
536+
)
535537
)
536538
if np.isnan(kl):
537539
logger.warning(

rllib/algorithms/sac/torch/sac_torch_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def compute_gradients(
314314
for component in (
315315
["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
316316
):
317-
self.metrics.peek(module_id, component + "_loss").backward(
317+
self.metrics.peek((module_id, component + "_loss")).backward(
318318
retain_graph=True
319319
)
320320
grads.update(

rllib/examples/checkpoints/checkpoint_by_custom_criteria.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Example extracting a checkpoint from n trials using one or more custom criteria.
22
33
This example:
4-
- runs a simple CartPole experiment with three different learning rates (three tune
4+
- runs a CartPole experiment with three different learning rates (three tune
55
"trials"). During the experiment, for each trial, we create a checkpoint at each
66
iteration.
77
- at the end of the experiment, we compare the trials and pick the one that performed

0 commit comments

Comments (0)