
Commit e03dd6e

[RLlib] New ConnectorV3 API #5: PPO runs in single-agent mode in this API stack. (ray-project#42272)
1 parent 88a35bc commit e03dd6e


42 files changed (+1,140 / -605 lines)

rllib/BUILD

Lines changed: 12 additions & 15 deletions
@@ -150,16 +150,6 @@ py_test(
 # --------------------------------------------------------------------

 # APPO
-py_test(
-    name = "learning_tests_cartpole_appo_no_vtrace",
-    main = "tests/run_regression_tests.py",
-    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
-    size = "medium", # bazel may complain about it being too long sometimes - medium is on purpose as some frameworks take longer
-    srcs = ["tests/run_regression_tests.py"],
-    data = ["tuned_examples/appo/cartpole-appo.yaml"],
-    args = ["--dir=tuned_examples/appo"]
-)
-
 py_test(
     name = "learning_tests_cartpole_appo_w_rl_modules_and_learner",
     main = "tests/run_regression_tests.py",
@@ -177,7 +167,7 @@ py_test(
     size = "medium",
     srcs = ["tests/run_regression_tests.py"],
     data = [
-        "tuned_examples/appo/cartpole-appo-vtrace-separate-losses.py"
+        "tuned_examples/appo/cartpole-appo-separate-losses.py"
     ],
     args = ["--dir=tuned_examples/appo"]
 )
@@ -208,17 +198,17 @@ py_test(
     tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
     size = "medium",
     srcs = ["tests/run_regression_tests.py"],
-    data = ["tuned_examples/appo/cartpole-appo-vtrace-fake-gpus.yaml"],
+    data = ["tuned_examples/appo/cartpole-appo-fake-gpus.yaml"],
     args = ["--dir=tuned_examples/appo"]
 )

 py_test(
-    name = "learning_tests_stateless_cartpole_appo_vtrace",
+    name = "learning_tests_stateless_cartpole_appo",
     main = "tests/run_regression_tests.py",
     tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
     size = "enormous",
     srcs = ["tests/run_regression_tests.py"],
-    data = ["tuned_examples/appo/stateless-cartpole-appo-vtrace.py"],
+    data = ["tuned_examples/appo/stateless_cartpole_appo.py"],
     args = ["--dir=tuned_examples/appo"]
 )

@@ -1453,6 +1443,13 @@ py_test(
     srcs = ["utils/exploration/tests/test_explorations.py"]
 )

+py_test(
+    name = "test_value_predictions",
+    tags = ["team:rllib", "utils"],
+    size = "small",
+    srcs = ["utils/postprocessing/tests/test_value_predictions.py"]
+)
+
 py_test(
     name = "test_random_encoder",
     tags = ["team:rllib", "utils"],
@@ -1461,7 +1458,7 @@ py_test(
 )

 py_test(
-    name = "utils/tests/test_torch_utils",
+    name = "test_torch_utils",
     tags = ["team:rllib", "utils", "gpu"],
     size = "medium",
     srcs = ["utils/tests/test_torch_utils.py"]

rllib/algorithms/algorithm.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Set,
2626
Tuple,
2727
Type,
28+
TYPE_CHECKING,
2829
Union,
2930
)
3031

@@ -46,7 +47,6 @@
     collect_metrics,
     summarize_episodes,
 )
-from ray.rllib.evaluation.postprocessing_v2 import postprocess_episodes_to_sample_batch
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
 from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
@@ -129,6 +129,8 @@
 from ray.util.timer import _Timer
 from ray.tune.registry import get_trainable_cls

+if TYPE_CHECKING:
+    from ray.rllib.core.learner.learner_group import LearnerGroup

 try:
     from ray.rllib.extensions import AlgorithmBase
@@ -449,6 +451,9 @@ def __init__(
         # Placeholder for a local replay buffer instance.
         self.local_replay_buffer = None

+        # Placeholder for our LearnerGroup responsible for updating the RLModule(s).
+        self.learner_group: Optional["LearnerGroup"] = None
+
         # Create a default logger creator if no logger_creator is specified
         if logger_creator is None:
             # Default logdir prefix containing the agent's name and the
@@ -1410,7 +1415,12 @@ def remote_fn(worker):
             worker.set_weights(
                 weights=ray.get(weights_ref), weights_seq_no=weights_seq_no
             )
-            episodes = worker.sample(explore=False)
+            # By episode: Run always only one episode per remote call.
+            # By timesteps: By default EnvRunner runs for the configured number of
+            # timesteps (based on `rollout_fragment_length` and `num_envs_per_worker`).
+            episodes = worker.sample(
+                explore=False, num_episodes=1 if unit == "episodes" else None
+            )
             metrics = worker.get_metrics()
             return episodes, metrics, weights_seq_no

@@ -1449,11 +1459,13 @@ def remote_fn(worker):
             rollout_metrics.extend(metrics)
             i += 1

-        # Convert our list of Episodes to a single SampleBatch.
-        batch = postprocess_episodes_to_sample_batch(episodes)
         # Collect steps stats.
-        _agent_steps = batch.agent_steps()
-        _env_steps = batch.env_steps()
+        # TODO (sven): Solve for proper multi-agent env/agent steps counting.
+        # Once we have multi-agent support on EnvRunner stack, we can simply do:
+        # `len(episode)` for env steps and `episode.num_agent_steps()` for agent
+        # steps.
+        _agent_steps = sum(len(e) for e in episodes)
+        _env_steps = sum(len(e) for e in episodes)

         # Only complete episodes done by eval workers.
         if unit == "episodes":
@@ -1467,6 +1479,7 @@ def remote_fn(worker):
         )

         if self.reward_estimators:
+            batch = concat_samples([e.get_sample_batch() for e in episodes])
             all_batches.append(batch)

         agent_steps_this_iter += _agent_steps

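For readers following the evaluation change in algorithm.py above: the new code samples whole episodes (exactly one per remote call when evaluating by episodes) and counts steps directly from the returned episodes, instead of first converting them into a SampleBatch. Below is a minimal, illustrative sketch of that pattern. It assumes an EnvRunner-style `sample()` that accepts `num_episodes` and Episode objects whose `len()` equals their env-step count (both taken from the diff); the helper name itself is hypothetical, not RLlib API.

def sample_for_evaluation(worker, unit):
    # Hypothetical helper mirroring the pattern in the diff above.
    # By "episodes": request exactly one episode per remote call.
    # By "timesteps": let the EnvRunner sample its configured number of
    # timesteps (rollout_fragment_length * num_envs_per_worker).
    episodes = worker.sample(
        explore=False,
        num_episodes=1 if unit == "episodes" else None,
    )
    # Single-agent counting for now (see the TODO in the diff): one agent
    # step per env step, so both counts are the summed episode lengths.
    env_steps = sum(len(e) for e in episodes)
    agent_steps = env_steps
    return episodes, env_steps, agent_steps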
rllib/algorithms/algorithm_config.py

Lines changed: 48 additions & 23 deletions
@@ -363,6 +363,8 @@ def __init__(self, algo_class=None):
         self.grad_clip = None
         self.grad_clip_by = "global_norm"
         self.train_batch_size = 32
+        # Simple logic for now: If None, use `train_batch_size`.
+        self.train_batch_size_per_learner = None
         # TODO (sven): Unsolved problem with RLModules sometimes requiring settings from
         # the main AlgorithmConfig. We should not require the user to provide those
         # settings in both, the AlgorithmConfig (as property) AND the model config
@@ -871,6 +873,7 @@ def build_env_to_module_connector(self, env):
         return pipeline

     def build_module_to_env_connector(self, env):
+
         from ray.rllib.connectors.module_to_env import (
             DefaultModuleToEnv,
             ModuleToEnvPipeline,
@@ -1333,11 +1336,11 @@ def environment(
                 Tuple[value1, value2]: Clip at value1 and value2.
             normalize_actions: If True, RLlib will learn entirely inside a normalized
                 action space (0.0 centered with small stddev; only affecting Box
-                components). We will unsquash actions (and clip, just in case) to the
+                components). RLlib will unsquash actions (and clip, just in case) to the
                 bounds of the env's action space before sending actions back to the env.
-            clip_actions: If True, RLlib will clip actions according to the env's bounds
-                before sending them back to the env.
-                TODO: (sven) This option should be deprecated and always be False.
+            clip_actions: If True, the RLlib default ModuleToEnv connector will clip
+                actions according to the env's bounds (before sending them into the
+                `env.step()` call).
             disable_env_checking: If True, disable the environment pre-checking module.
             is_atari: This config can be used to explicitly specify whether the env is
                 an Atari env or not. If not specified, RLlib will try to auto-detect
@@ -1678,6 +1681,7 @@ def training(
         grad_clip: Optional[float] = NotProvided,
         grad_clip_by: Optional[str] = NotProvided,
         train_batch_size: Optional[int] = NotProvided,
+        train_batch_size_per_learner: Optional[int] = NotProvided,
         model: Optional[dict] = NotProvided,
         optimizer: Optional[dict] = NotProvided,
         max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
@@ -1726,7 +1730,16 @@ def training(
                 the shapes of these tensors are).
             grad_clip_by: See `grad_clip` for the effect of this setting on gradient
                 clipping. Allowed values are `value`, `norm`, and `global_norm`.
-            train_batch_size: Training batch size, if applicable.
+            train_batch_size_per_learner: Train batch size per individual Learner
+                worker. This setting only applies to the new API stack. The number
+                of Learner workers can be set via `config.resources(
+                num_learner_workers=...)`. The total effective batch size is then
+                `num_learner_workers` x `train_batch_size_per_learner` and can
+                be accessed via the property `AlgorithmConfig.total_train_batch_size`.
+            train_batch_size: Training batch size, if applicable. When on the new API
+                stack, this setting should no longer be used. Instead, use
+                `train_batch_size_per_learner` (in combination with
+                `num_learner_workers`).
             model: Arguments passed into the policy model. See models/catalog.py for a
                 full list of the available model options.
                 TODO: Provide ModelConfig objects instead of dicts.
@@ -1766,6 +1779,8 @@ def training(
                     "or 'global_norm'!"
                 )
             self.grad_clip_by = grad_clip_by
+        if train_batch_size_per_learner is not NotProvided:
+            self.train_batch_size_per_learner = train_batch_size_per_learner
         if train_batch_size is not NotProvided:
             self.train_batch_size = train_batch_size
         if model is not NotProvided:
@@ -2716,20 +2731,29 @@ def uses_new_env_runners(self):
             self.env_runner_cls, RolloutWorker
         )

+    @property
+    def total_train_batch_size(self):
+        if self.train_batch_size_per_learner is not None:
+            return self.train_batch_size_per_learner * (self.num_learner_workers or 1)
+        else:
+            return self.train_batch_size
+
+    # TODO: Make rollout_fragment_length as read-only property and replace the current
+    # self.rollout_fragment_length a private variable.
     def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
         """Automatically infers a proper rollout_fragment_length setting if "auto".

         Uses the simple formula:
-        `rollout_fragment_length` = `train_batch_size` /
+        `rollout_fragment_length` = `total_train_batch_size` /
         (`num_envs_per_worker` * `num_rollout_workers`)

         If result is a fraction AND `worker_index` is provided, will make
         those workers add additional timesteps, such that the overall batch size (across
-        the workers) will add up to exactly the `train_batch_size`.
+        the workers) will add up to exactly the `total_train_batch_size`.

         Returns:
             The user-provided `rollout_fragment_length` or a computed one (if user
-            provided value is "auto"), making sure `train_batch_size` is reached
+            provided value is "auto"), making sure `total_train_batch_size` is reached
             exactly in each iteration.
         """
         if self.rollout_fragment_length == "auto":
@@ -2739,11 +2763,11 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
             # 4 workers, 3 envs per worker, 2500 train batch size:
             # -> 2500 / 12 -> 208.333 -> diff=4 (208 * 12 = 2496)
             # -> worker 1: 209, workers 2-4: 208
-            rollout_fragment_length = self.train_batch_size / (
+            rollout_fragment_length = self.total_train_batch_size / (
                 self.num_envs_per_worker * (self.num_rollout_workers or 1)
             )
             if int(rollout_fragment_length) != rollout_fragment_length:
-                diff = self.train_batch_size - int(
+                diff = self.total_train_batch_size - int(
                     rollout_fragment_length
                 ) * self.num_envs_per_worker * (self.num_rollout_workers or 1)
                 if (worker_index * self.num_envs_per_worker) <= diff:
@@ -3095,36 +3119,38 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:

         Raises:
             ValueError: If there is a mismatch between user provided
-                `rollout_fragment_length` and `train_batch_size`.
+                `rollout_fragment_length` and `total_train_batch_size`.
         """
         if (
             self.rollout_fragment_length != "auto"
             and not self.in_evaluation
-            and self.train_batch_size > 0
+            and self.total_train_batch_size > 0
         ):
             min_batch_size = (
                 max(self.num_rollout_workers, 1)
                 * self.num_envs_per_worker
                 * self.rollout_fragment_length
             )
             batch_size = min_batch_size
-            while batch_size < self.train_batch_size:
+            while batch_size < self.total_train_batch_size:
                 batch_size += min_batch_size
-            if (
-                batch_size - self.train_batch_size > 0.1 * self.train_batch_size
-                or batch_size - min_batch_size - self.train_batch_size
-                > (0.1 * self.train_batch_size)
+            if batch_size - self.total_train_batch_size > (
+                0.1 * self.total_train_batch_size
+            ) or batch_size - min_batch_size - self.total_train_batch_size > (
+                0.1 * self.total_train_batch_size
             ):
-                suggested_rollout_fragment_length = self.train_batch_size // (
+                suggested_rollout_fragment_length = self.total_train_batch_size // (
                     self.num_envs_per_worker * (self.num_rollout_workers or 1)
                 )
                 raise ValueError(
-                    f"Your desired `train_batch_size` ({self.train_batch_size}) or a "
-                    "value 10% off of that cannot be achieved with your other "
+                    "Your desired `total_train_batch_size` "
+                    f"({self.total_train_batch_size}={self.num_learner_workers} "
+                    f"learners x {self.train_batch_size_per_learner}) "
+                    "or a value 10% off of that cannot be achieved with your other "
                     f"settings (num_rollout_workers={self.num_rollout_workers}; "
                     f"num_envs_per_worker={self.num_envs_per_worker}; "
                     f"rollout_fragment_length={self.rollout_fragment_length})! "
-                    "Try setting `rollout_fragment_length` to 'auto' OR "
+                    "Try setting `rollout_fragment_length` to 'auto' OR to a value of "
                     f"{suggested_rollout_fragment_length}."
                 )


@@ -3580,8 +3606,7 @@ def _validate_evaluation_settings(self):
         """Checks, whether evaluation related settings make sense."""
         if (
             self.evaluation_interval
-            and self.env_runner_cls is not None
-            and not issubclass(self.env_runner_cls, RolloutWorker)
+            and self.uses_new_env_runners
             and not self.enable_async_evaluation
         ):
             raise ValueError(

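To see how the batch-size changes in algorithm_config.py compose, here is a small sketch based on the docstrings in this diff. The builder calls shown (`resources()`, `rollouts()`, `environment()`) are assumed to be available as in current RLlib; `train_batch_size_per_learner`, `num_learner_workers`, and `total_train_batch_size` are the names introduced or referenced above.

from ray.rllib.algorithms.ppo import PPOConfig

# Illustrative only: 2 Learner workers x 2000 samples each.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .resources(num_learner_workers=2)
    .training(train_batch_size_per_learner=2000)
    .rollouts(num_rollout_workers=4, num_envs_per_worker=1)
)

# New property from this commit: total effective batch size across Learners.
assert config.total_train_batch_size == 2 * 2000

# With rollout_fragment_length="auto", the fragment length is derived as
# total_train_batch_size / (num_envs_per_worker * num_rollout_workers),
# here 4000 / (1 * 4) == 1000 timesteps per rollout worker per round.

The validation change above then compares this total against `rollout_fragment_length * num_envs_per_worker * num_rollout_workers` and, if the two cannot line up within 10%, suggests a concrete fragment length in the error message.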
rllib/algorithms/appo/tests/test_appo.py

Lines changed: 0 additions & 13 deletions
@@ -26,19 +26,6 @@ def test_appo_compilation(self):
         num_iterations = 2

         for _ in framework_iterator(config):
-            print("w/o v-trace")
-            config.vtrace = False
-            algo = config.build(env="CartPole-v1")
-            for i in range(num_iterations):
-                results = algo.train()
-                print(results)
-                check_train_results(results)
-
-            check_compute_single_action(algo)
-            algo.stop()
-
-            print("w/ v-trace")
-            config.vtrace = True
             algo = config.build(env="CartPole-v1")
             for i in range(num_iterations):
                 results = algo.train()

rllib/algorithms/appo/tf/appo_tf_learner.py

Lines changed: 13 additions & 10 deletions
@@ -8,9 +8,9 @@
     OLD_ACTION_DIST_LOGITS_KEY,
 )
 from ray.rllib.algorithms.appo.appo_learner import AppoLearner
+from ray.rllib.algorithms.impala.tf.impala_tf_learner import ImpalaTfLearner
 from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2
 from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
-from ray.rllib.core.learner.tf.tf_learner import TfLearner
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.nested_dict import NestedDict
@@ -19,10 +19,10 @@
 _, tf, _ = try_import_tf()


-class APPOTfLearner(AppoLearner, TfLearner):
+class APPOTfLearner(AppoLearner, ImpalaTfLearner):
     """Implements APPO loss / update logic on top of ImpalaTfLearner."""

-    @override(TfLearner)
+    @override(ImpalaTfLearner)
     def compute_loss_for_module(
         self,
         *,
@@ -72,12 +72,15 @@ def compute_loss_for_module(
             trajectory_len=rollout_frag_or_episode_len,
             recurrent_seq_len=recurrent_seq_len,
         )
-        bootstrap_values_time_major = make_time_major(
-            batch[SampleBatch.VALUES_BOOTSTRAPPED],
-            trajectory_len=rollout_frag_or_episode_len,
-            recurrent_seq_len=recurrent_seq_len,
-        )
-        bootstrap_value = bootstrap_values_time_major[-1]
+        if self.config.uses_new_env_runners:
+            bootstrap_values = batch[SampleBatch.VALUES_BOOTSTRAPPED]
+        else:
+            bootstrap_values_time_major = make_time_major(
+                batch[SampleBatch.VALUES_BOOTSTRAPPED],
+                trajectory_len=rollout_frag_or_episode_len,
+                recurrent_seq_len=recurrent_seq_len,
+            )
+            bootstrap_values = bootstrap_values_time_major[-1]

         # The discount factor that is used should be gamma except for timesteps where
         # the episode is terminated. In that case, the discount factor should be 0.
@@ -100,7 +103,7 @@ def compute_loss_for_module(
             discounts=discounts_time_major,
             rewards=rewards_time_major,
             values=values_time_major,
-            bootstrap_value=bootstrap_value,
+            bootstrap_values=bootstrap_values,
             clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
             clip_rho_threshold=config.vtrace_clip_rho_threshold,
         )

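The loss change above comes down to where the bootstrap values live: on the new EnvRunner stack the batch already carries one bootstrap value per sequence, while the old stack stores a time-major [T, B] tensor whose last row holds them. A rough NumPy sketch of that branch follows; only the key names and the branching condition come from the diff, the function and its arguments are illustrative.

import numpy as np

def select_bootstrap_values(values_bootstrapped, uses_new_env_runners, T, B):
    # Illustrative stand-in for the branch added in appo_tf_learner.py.
    if uses_new_env_runners:
        # New stack: already one bootstrap value per sequence -> shape [B].
        return values_bootstrapped
    # Old stack: flat batch-major values of shape [B * T]; emulate
    # make_time_major() to get [T, B], then take the last time step.
    time_major = values_bootstrapped.reshape(B, T).swapaxes(0, 1)
    return time_major[-1]

# Old-stack example: 2 sequences of length 3; bootstrap values sit at t=2.
vb = np.array([0.1, 0.2, 0.9, 0.3, 0.4, 0.8])
print(select_bootstrap_values(vb, uses_new_env_runners=False, T=3, B=2))
# -> [0.9 0.8]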
0 commit comments
