
Commit ee41800

[RLlib] Preparatory PR for multi-agent, multi-GPU learning agent (alpha-star style) #2. (ray-project#21649)
1 parent 8ebc50f commit ee41800

File tree

12 files changed (+596, -170 lines)

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
    """This buffer adds replayed samples to a stream of new experiences.

    - Any newly added batch (`add_batch()`) is immediately returned upon
      the next `replay` call (close to on-policy), as well as being moved
      into the buffer.
    - Additionally, a certain number of old samples is mixed into the
      returned sample according to a given "replay ratio".
    - If >1 calls to `add_batch()` are made without any `replay()` calls
      in between, all newly added batches are returned (plus some older
      samples according to the "replay ratio").

    Examples:
        # replay ratio 0.66 (2/3 replayed, 1/3 new samples):
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.66)
        >>> buffer.add_batch(<A>)
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<A>, <B>, <B>]
        >>> buffer.add_batch(<C>)
        >>> buffer.replay()
        ... [<C>, <A>, <B>]
        >>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
        >>> # is the newest sample.
        >>> buffer.add_batch(<D>)
        >>> buffer.replay()
        ... [<D>, <A>, <C>]

        # replay ratio 0.0 -> replay disabled:
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.0)
        >>> buffer.add_batch(<A>)
        >>> buffer.replay()
        ... [<A>]
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<B>]
    """

    def __init__(self, capacity: int, replay_ratio: float):
        """Initializes a MixInMultiAgentReplayBuffer instance.

        Args:
            capacity (int): Number of batches to store in total.
            replay_ratio (float): Ratio of replayed samples in the returned
                batches. E.g. a ratio of 0.0 means only return new samples
                (no replay), a ratio of 0.5 means always return the newest
                sample plus one old one (1:1), a ratio of 0.66 means always
                return the newest sample plus 2 old (replayed) ones (1:2),
                etc.
        """
        self.capacity = capacity
        self.replay_ratio = replay_ratio
        self.replay_proportion = None
        if self.replay_ratio != 1.0:
            self.replay_proportion = self.replay_ratio / (
                1.0 - self.replay_ratio)

        def new_buffer():
            return SimpleReplayBuffer(num_slots=capacity)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()

        # Added timesteps over lifetime.
        self.num_added = 0

        # Last added batch(es).
        self.last_added_batches = collections.defaultdict(list)

    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            for policy_id, sample_batch in batch.policy_batches.items():
                self.replay_buffers[policy_id].add_batch(sample_batch)
                self.last_added_batches[policy_id].append(sample_batch)
        self.num_added += batch.count

    def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
            Optional[SampleBatchType]:
        buffer = self.replay_buffers[policy_id]
        # Return None, if:
        # - Buffer is empty, or
        # - `replay_ratio` < 1.0 (new samples required in the returned batch)
        #   and there are no new samples to mix with replayed ones.
        if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
                                and self.replay_ratio < 1.0):
            return None

        # Mix the buffer's last added batches with older replayed batches.
        with self.replay_timer:
            output_batches = self.last_added_batches[policy_id]
            self.last_added_batches[policy_id] = []

            # No replay desired -> Return here.
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired -> Return a (replayed) sample from the
            # buffer.
            elif self.replay_ratio == 1.0:
                return buffer.replay()

            # Replay ratio = old / (old + new)
            # Replay proportion = old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(buffer.replay())
            return SampleBatch.concat_samples(output_batches)

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
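A minimal usage sketch of the new buffer, assuming `MixInMultiAgentReplayBuffer` as defined above is importable (the new module's path is not shown in this view) and using toy `SampleBatch` contents:

from ray.rllib.policy.sample_batch import SampleBatch

# Assumption: MixInMultiAgentReplayBuffer (defined in the new file above) is
# already imported; its module path is not shown in this diff view.
buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.66)

# Two fresh rollout batches arrive (2 timesteps each, toy data).
buffer.add_batch(SampleBatch({"obs": [1, 2], "rewards": [0.0, 1.0]}))
buffer.add_batch(SampleBatch({"obs": [3, 4], "rewards": [0.5, 0.5]}))

# `replay()` returns every batch added since the last call, plus - on
# average - replay_ratio / (1 - replay_ratio) ~= 2 older batches per new
# one, drawn uniformly from the buffer.
mixed = buffer.replay()
print(mixed.count)  # >= 4 timesteps: the two new batches plus replayed ones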

rllib/execution/buffers/multi_agent_replay_buffer.py

Lines changed: 8 additions & 3 deletions
@@ -1,6 +1,6 @@
 import collections
 import platform
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import numpy as np
 import ray
@@ -13,7 +13,7 @@
 from ray.rllib.utils import deprecation_warning
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.timer import TimerStat
-from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.util.iter import ParallelIteratorWorker


@@ -195,7 +195,7 @@ def add_batch(self, batch: SampleBatchType) -> None:
                         time_slice, weight=weight)
         self.num_added += batch.count

-    def replay(self) -> SampleBatchType:
+    def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
         """If this buffer was given a fake batch, return it, otherwise return
         a MultiAgentBatch with samples.
         """
@@ -211,8 +211,13 @@ def replay(self) -> SampleBatchType:
             # Lockstep mode: Sample from all policies at the same time an
             # equal amount of steps.
             if self.replay_mode == "lockstep":
+                assert policy_id is None, \
+                    "`policy_id` specifier not allowed in `lockstep` mode!"
                 return self.replay_buffers[_ALL_POLICIES].sample(
                     self.replay_batch_size, beta=self.prioritized_replay_beta)
+            elif policy_id is not None:
+                return self.replay_buffers[policy_id].sample(
+                    self.replay_batch_size, beta=self.prioritized_replay_beta)
             else:
                 samples = {}
                 for policy_id, replay_buffer in self.replay_buffers.items():
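A hedged sketch of how the extended `replay()` signature might be used; the constructor keyword arguments and the "default_policy" ID are illustrative assumptions, not taken from this diff:

from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
    MultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

# Constructor kwargs below are illustrative assumptions.
buffer = MultiAgentReplayBuffer(
    capacity=1000, learning_starts=0, replay_batch_size=4)
buffer.add_batch(SampleBatch({"obs": [1], "actions": [0], "rewards": [1.0]}))

# Unchanged behavior: sample from all policies' buffers.
all_policies_sample = buffer.replay()

# New in this commit: sample only one policy's buffer
# (not allowed when replay_mode="lockstep").
one_policy_sample = buffer.replay(policy_id="default_policy")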

rllib/execution/buffers/replay_buffer.py

Lines changed: 16 additions & 6 deletions
@@ -132,19 +132,25 @@ def add(self, item: SampleBatchType, weight: float) -> None:
 
     @DeveloperAPI
     def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
-        """Sample a batch of experiences.
+        """Sample a batch of size `num_items` from this buffer.
+
+        If less than `num_items` records are in this buffer, some samples in
+        the results may be repeated to fulfil the batch size (`num_items`)
+        request.
 
         Args:
             num_items: Number of items to sample from this buffer.
-            beta: This is ignored (only used by prioritized replay buffers).
+            beta: The prioritized replay beta value. Only relevant if this
+                ReplayBuffer is a PrioritizedReplayBuffer.
 
         Returns:
             Concatenated batch of items.
         """
-        idxes = [
-            random.randint(0,
-                           len(self._storage) - 1) for _ in range(num_items)
-        ]
+        # If we don't have any samples yet in this buffer, return None.
+        if len(self) == 0:
+            return None
+
+        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
         sample = self._encode_sample(idxes)
         # Update our timesteps counters.
         self._num_timesteps_sampled += len(sample)
@@ -282,6 +288,10 @@ def sample(self, num_items: int, beta: float) -> SampleBatchType:
                 "batch_indexes" fields denoting IS of each sampled
                 transition and original idxes in buffer of sampled experiences.
         """
+        # If we don't have any samples yet in this buffer, return None.
+        if len(self) == 0:
+            return None
+
         assert beta >= 0.0
 
         idxes = self._sample_proportional(num_items)
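A small sketch of what the new empty-buffer guard means for callers, assuming the `ReplayBuffer(capacity=...)` constructor and the `add(item, weight)` signature shown in the hunk header:

from ray.rllib.execution.buffers.replay_buffer import ReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buffer = ReplayBuffer(capacity=10)

# New behavior: sampling from an empty buffer now returns None instead of
# failing on an invalid random index.
assert buffer.sample(num_items=2) is None

buffer.add(SampleBatch({"obs": [1], "rewards": [0.5]}), weight=1.0)
# With fewer stored items than `num_items`, samples may be repeated to
# fill the requested batch size.
batch = buffer.sample(num_items=4)
assert batch.count == 4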
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
        remote_requests_in_flight: DefaultDict[ActorHandle, Set[
            ray.ObjectRef]],
        actors: List[ActorHandle],
        ray_wait_timeout_s: Optional[float] = None,
        max_remote_requests_in_flight_per_actor: int = 2,
        remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
        remote_args: Optional[List[List[Any]]] = None,
        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Runs parallel and asynchronous rollouts on all remote workers.

    May use a timeout (if provided) on `ray.wait()` and returns only those
    samples that could be gathered in the timeout window. Allows a maximum
    of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
    per remote actor.

    Instead of calling `actor.sample.remote()`, the user can provide a
    `remote_fn()`, which will be applied to the actor(s) instead.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently-in-flight pending requests (those we expect to
            ray.get results for next). If you have an RLlib Trainer that calls
            this function, you can use its `self.remote_requests_in_flight`
            property here.
        actors: The list of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
            `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].

    Returns:
        A dict mapping actor handles to the results received by sending
        requests to these actors. An empty dict, if no results are ready
        within the given timeout.

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> batches = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> print(len(batches))
        ... 2
        >>> # Expect a timeout to have happened.
        >>> batches[0] is None and batches[1] is None
        ... True
    """

    if remote_args is not None:
        assert len(remote_args) == len(actors)
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(actors)

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, set_ in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes |= set_
            for r in set_:
                remote_to_actor[r] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if len(remote_requests_in_flight[actor]) < \
                max_remote_requests_in_flight_per_actor:
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First, try a `ray.wait()` without timeout for efficiency.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait()` call with timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remotes),
            timeout=ray_wait_timeout_s)

        # Return empty results if nothing is ready after the timeout.
        if not ready:
            return {}

    # Remove in-flight records for ready refs.
    for obj_ref in ready:
        remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

    # Do one ray.get().
    results = ray.get(ready)
    assert len(ready) == len(results)

    # Return mapping from (ready) actors to their results.
    ret = {}
    for obj_ref, result in zip(ready, results):
        ret[remote_to_actor[obj_ref]] = result

    return ret
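A hedged sketch of driving `asynchronous_parallel_requests`; the `Worker` actor is a hypothetical stand-in that only mimics the `apply()` interface used when `remote_fn` is given, and the function is assumed to be imported from the new module above:

import collections
import time

import ray


@ray.remote
class Worker:
    # Hypothetical stand-in mimicking the `apply()` interface of RLlib
    # rollout workers: run a function on the actor and return its result.
    def apply(self, fn, *args, **kwargs):
        return fn(self, *args, **kwargs)


ray.init()
actors = [Worker.remote() for _ in range(2)]
in_flight = collections.defaultdict(set)

# Launch up to 2 requests per actor, then collect whatever finishes
# within the 100ms `ray.wait()` window.
results = asynchronous_parallel_requests(
    remote_requests_in_flight=in_flight,
    actors=actors,
    ray_wait_timeout_s=0.1,
    remote_fn=lambda w: time.sleep(0.01) or "done",
)
print(len(results))  # 0, 1 or 2, depending on what finished in time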

0 commit comments
