
Commit 806701e

[RLlib] New ConnectorV2 API #4: Changes to Learner/LearnerGroup API to allow updating from Episodes. (ray-project#41235)
1 parent 65478d4 commit 806701e

18 files changed: +611 additions, -301 deletions

doc/source/rllib/package_ref/learner.rst

Lines changed: 2 additions & 1 deletion

@@ -72,7 +72,8 @@ Performing Updates
    :nosignatures:
    :toctree: doc/

-   Learner.update
+   Learner.update_from_batch
+   Learner.update_from_episodes
    Learner._update
    Learner.additional_update
    Learner.additional_update_for_module
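The two entries added above register the new Learner entry points in the API reference. Below is a minimal, hypothetical sketch of how they might be called side by side. The `batch=` keyword matches the call sites changed elsewhere in this commit, while the `episodes=` keyword and the placeholder names (`learner`, `DUMMY_BATCH`, `SAMPLED_EPISODES`) are assumptions, since no `update_from_episodes` call site appears in the hunks shown here.

# Hypothetical sketch only; not part of this diff.
# `learner` is a built Learner; `DUMMY_BATCH` is a (multi-agent) sample batch;
# `SAMPLED_EPISODES` is a list of episodes from the new ConnectorV2 sampling path.
results_from_batch = learner.update_from_batch(batch=DUMMY_BATCH)
# The `episodes=` keyword is an assumption; only the method name is confirmed above.
results_from_episodes = learner.update_from_episodes(episodes=SAMPLED_EPISODES)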

doc/source/rllib/rllib-learner.rst

Lines changed: 13 additions & 9 deletions

@@ -229,19 +229,23 @@ Updates

 .. testcode::

-    # This is a blocking update
-    results = learner_group.update(DUMMY_BATCH)
+    # This is a blocking update.
+    results = learner_group.update_from_batch(batch=DUMMY_BATCH)

     # This is a non-blocking update. The results are returned in a future
-    # call to `async_update`
-    _ = learner_group.async_update(DUMMY_BATCH)
+    # call to `update_from_batch(..., async_update=True)`
+    _ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)

     # Artificially wait for async request to be done to get the results
-    # in the next call to `LearnerGroup.async_update()`.
+    # in the next call to
+    # `LearnerGroup.update_from_batch(..., async_update=True)`.
     time.sleep(5)
-    results = learner_group.async_update(DUMMY_BATCH)
+    results = learner_group.update_from_batch(
+        batch=DUMMY_BATCH, async_update=True
+    )
     # `results` is a list of results dict. The items in the list represent the different
-    # remote results from the different calls to `async_update()`.
+    # remote results from the different calls to
+    # `update_from_batch(..., async_update=True)`.
     assert len(results) > 0
     # Each item is a results dict, already reduced over the n Learner workers.
     assert isinstance(results[0], dict), results[0]

@@ -256,8 +260,8 @@ Updates

 .. testcode::

-    # This is a blocking update.
-    result = learner.update(DUMMY_BATCH)
+    # This is a blocking update (given a training batch).
+    result = learner.update_from_batch(batch=DUMMY_BATCH)

     # This is an additional non-gradient based update.
     learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)
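The updated testcode waits a fixed five seconds before polling for async results. Assuming the semantics described in the comments above (each `async_update=True` call both enqueues a batch and returns whatever results have finished since the previous call), a polling loop is another way to write the same thing. This is a sketch, not part of the doc change, reusing `learner_group`, `DUMMY_BATCH`, and `time` from the testcode.

# Sketch only: poll until at least one async result has come back.
results = []
while not results:
    # Each call enqueues DUMMY_BATCH and returns results of finished requests.
    results = learner_group.update_from_batch(
        batch=DUMMY_BATCH, async_update=True
    )
    time.sleep(0.1)  # brief pause between polls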

rllib/algorithms/algorithm.py

Lines changed: 2 additions & 2 deletions

@@ -5,12 +5,12 @@
 import functools
 import gymnasium as gym
 import importlib
+import importlib.metadata
 import json
 import logging
 import numpy as np
 import os
 from packaging import version
-import importlib.metadata
 import re
 import tempfile
 import time

@@ -1607,7 +1607,7 @@ def training_step(self) -> ResultDict:
         # TODO: (sven) rename MultiGPUOptimizer into something more
         # meaningful.
         if self.config._enable_new_api_stack:
-            train_results = self.learner_group.update(train_batch)
+            train_results = self.learner_group.update_from_batch(batch=train_batch)
         elif self.config.get("simple_optimizer") is True:
             train_results = train_one_step(self, train_batch)
         else:

rllib/algorithms/appo/tests/test_appo_learner.py

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ def test_appo_loss(self):
             env=algo.workers.local_worker().env
         )
         learner_group.set_weights(algo.get_weights())
-        learner_group.update(train_batch.as_multi_agent())
+        learner_group.update_from_batch(batch=train_batch.as_multi_agent())

         algo.stop()

rllib/algorithms/bc/bc.py

Lines changed: 1 addition & 1 deletion

@@ -171,7 +171,7 @@ def training_step(self) -> ResultDict:
         self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

         # Updating the policy.
-        train_results = self.learner_group.update(train_batch)
+        train_results = self.learner_group.update_from_batch(batch=train_batch)

         # Synchronize weights.
         # As the results contain for each policy the loss and in addition the

rllib/algorithms/dreamerv3/dreamerv3.py

Lines changed: 2 additions & 2 deletions

@@ -606,8 +606,8 @@ def training_step(self) -> ResultDict:
             )

             # Perform the actual update via our learner group.
-            train_results = self.learner_group.update(
-                SampleBatch(sample).as_multi_agent(),
+            train_results = self.learner_group.update_from_batch(
+                batch=SampleBatch(sample).as_multi_agent(),
                 reduce_fn=self._reduce_results,
             )
             self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps
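The DreamerV3 call above forwards a `reduce_fn` to the new `update_from_batch()` API. The real `self._reduce_results` is not part of this diff; as a rough illustration, a reduction of this shape takes the list of per-Learner result dicts and collapses them into a single dict (the element-wise mean here is an assumption):

import numpy as np
import tree  # dm_tree; also used in the impala.py hunk below

# Hypothetical stand-in for `self._reduce_results` (not shown in this diff):
# element-wise mean over structurally identical per-Learner result dicts.
def reduce_results(results):
    return tree.map_structure(lambda *xs: float(np.mean(xs)), *results)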

rllib/algorithms/dreamerv3/utils/summaries.py

Lines changed: 2 additions & 1 deletion

@@ -133,7 +133,8 @@ def report_predicted_vs_sampled_obs(
     Continues: Compute MSE (sampled vs predicted).

     Args:
-        results: The results dict that was returned by `LearnerGroup.update()`.
+        results: The results dict that was returned by
+            `LearnerGroup.update_from_batch()`.
         sample: The sampled data (dict) from the replay buffer. Already tf-tensor
             converted.
         batch_size_B: The batch size (B). This is the number of trajectories sampled

rllib/algorithms/impala/impala.py

Lines changed: 12 additions & 18 deletions

@@ -946,24 +946,18 @@ def learn_on_processed_samples(self) -> ResultDict:
         self.batches_to_place_on_learner.clear()
         # If there are no learner workers and learning is directly on the driver
         # Then we can't do async updates, so we need to block.
-        blocking = self.config.num_learner_workers == 0
+        async_update = self.config.num_learner_workers > 0
         results = []
         for batch in batches:
-            if blocking:
-                result = self.learner_group.update(
-                    batch,
-                    reduce_fn=_reduce_impala_results,
-                    num_iters=self.config.num_sgd_iter,
-                    minibatch_size=self.config.minibatch_size,
-                )
+            result = self.learner_group.update_from_batch(
+                batch=batch,
+                async_update=async_update,
+                reduce_fn=_reduce_impala_results,
+                num_iters=self.config.num_sgd_iter,
+                minibatch_size=self.config.minibatch_size,
+            )
+            if not async_update:
                 results = [result]
-            else:
-                results = self.learner_group.async_update(
-                    batch,
-                    reduce_fn=_reduce_impala_results,
-                    num_iters=self.config.num_sgd_iter,
-                    minibatch_size=self.config.minibatch_size,
-                )

         for r in results:
             self._counters[NUM_ENV_STEPS_TRAINED] += r[ALL_MODULES].pop(

@@ -973,14 +967,14 @@ def learn_on_processed_samples(self) -> ResultDict:
                 NUM_AGENT_STEPS_TRAINED
             )

-        self._counters.update(self.learner_group.get_in_queue_stats())
+        self._counters.update(self.learner_group.get_stats())
         # If there are results, reduce-mean over each individual value and return.
         if results:
             return tree.map_structure(lambda *x: np.mean(x), *results)

         # Nothing on the queue -> Don't send requests to learner group
-        # or no results ready (from previous `self.learner_group.update()` calls) for
-        # reducing.
+        # or no results ready (from previous `self.learner_group.update_from_batch()`
+        # calls) for reducing.
         return {}

     def place_processed_samples_on_learner_thread_queue(self) -> None:
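The rewritten hunk folds the old blocking/async branches into a single call whose `async_update` flag is derived from `num_learner_workers`. A condensed sketch of the resulting control flow follows; names mirror the diff, and the return-type difference (one result dict when blocking, a list of already-finished results when async) is inferred from how `results` is handled above.

# Condensed sketch (inferred from the hunk above, not verbatim code).
async_update = config.num_learner_workers > 0
returned = learner_group.update_from_batch(
    batch=batch,
    async_update=async_update,
    num_iters=config.num_sgd_iter,
    minibatch_size=config.minibatch_size,
)
# Blocking call -> a single result dict; async call -> a (possibly empty)
# list of results from previously submitted requests.
results = [returned] if not async_update else returned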

rllib/algorithms/impala/tests/test_impala_learner.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ def test_impala_loss(self):
             env=algo.workers.local_worker().env
         )
         learner_group.set_weights(algo.get_weights())
-        learner_group.update(train_batch.as_multi_agent())
+        learner_group.update_from_batch(batch=train_batch.as_multi_agent())

         algo.stop()
rllib/algorithms/ppo/ppo.py

Lines changed: 2 additions & 2 deletions

@@ -424,8 +424,8 @@ def training_step(self) -> ResultDict:
         if self.config._enable_new_api_stack:
             # TODO (Kourosh) Clearly define what train_batch_size
             # vs. sgd_minibatch_size and num_sgd_iter is in the config.
-            train_results = self.learner_group.update(
-                train_batch,
+            train_results = self.learner_group.update_from_batch(
+                batch=train_batch,
                 minibatch_size=self.config.sgd_minibatch_size,
                 num_iters=self.config.num_sgd_iter,
             )
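The PPO hunk forwards `sgd_minibatch_size` and `num_sgd_iter` into the new API. As a rough interpretation of those two knobs (the standard PPO reading, not spelled out in this diff), the train batch is split into minibatches and iterated over several times:

import math

# Rough arithmetic sketch (interpretation, not from the diff): with a train
# batch of 4000 timesteps, minibatch_size=128 and num_iters=30, one
# update_from_batch() call would take roughly ceil(4000 / 128) * 30 = 960
# minibatch gradient steps before returning.
train_batch_size = 4000
minibatch_size = 128
num_iters = 30
gradient_steps = math.ceil(train_batch_size / minibatch_size) * num_iters
print(gradient_steps)  # 960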
