|
| 1 | +import logging |
| 2 | +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set |
| 3 | + |
| 4 | +import ray |
| 5 | +from ray.actor import ActorHandle |
| 6 | +from ray.rllib.utils.annotations import ExperimentalAPI |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | + |
@ExperimentalAPI
def asynchronous_parallel_requests(
        remote_requests_in_flight: DefaultDict[ActorHandle, Set[
            ray.ObjectRef]],
        actors: List[ActorHandle],
        ray_wait_timeout_s: Optional[float] = None,
        max_remote_requests_in_flight_per_actor: int = 2,
        remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
        remote_args: Optional[List[List[Any]]] = None,
        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Sends asynchronous requests to remote actors and collects ready results.

    On each call, at most one new request is sent to every actor in `actors`
    that currently has fewer than `max_remote_requests_in_flight_per_actor`
    requests in flight. Then `ray.wait()` is used on all pending requests of
    these actors (old and new ones) and the results of those requests that
    are ready are returned.

    By default, each request is `actor.sample.remote()`. Alternatively, the
    user can provide a `remote_fn`, which will be applied to the actor(s)
    via `actor.apply.remote(remote_fn, ...)` instead.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently-in-flight pending requests (those we expect to
            ray.get results for next). This dict is updated in place: New
            requests are added to it and requests whose results are returned
            are removed from it. Entries for actors not in `actors` are
            ignored. If you have an RLlib Trainer that calls this function,
            you can use its `self.remote_requests_in_flight` property here.
        actors: The List of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` call. If None (default), never time out: First
            collect everything that is ready right away and - if nothing
            is - block until at least one request returns something.
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests in flight per actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
            `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
            Ignored if `remote_fn` is None.
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
            Ignored if `remote_fn` is None.

    Returns:
        A dict mapping actor handles to the results received by sending
        requests to these actors. An empty dict, if no results are ready
        within the timeout window. If more than one request of the same
        actor is ready, only the last one of that actor's results (in the
        order returned by `ray.wait()`) is kept in the returned dict.

    Raises:
        AssertionError: If `remote_args` or `remote_kwargs` are provided, but
            their length does not match the length of `actors`, or if there
            is no pending request at all to wait for (e.g. `actors` is
            empty).

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> results = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> # Expect a timeout to have happened (assuming no earlier
        >>> # requests were still in flight): Nothing is ready within
        >>> # 0.1sec, so the returned dict is empty.
        >>> print(len(results))
        0
    """

    if remote_args is not None:
        assert len(remote_args) == len(actors), \
            "`remote_args` must contain exactly one list per actor!"
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(actors), \
            "`remote_kwargs` must contain exactly one dict per actor!"

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, refs in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes.update(refs)
            for ref in refs:
                remote_to_actor[ref] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    # Note that at most one new request per actor is added per call.
    for actor_idx, actor in enumerate(actors):
        # No more room for another request to this actor.
        if len(remote_requests_in_flight[actor]) >= \
                max_remote_requests_in_flight_per_actor:
            continue

        if remote_fn is None:
            req = actor.sample.remote()
        else:
            args = remote_args[actor_idx] if remote_args else []
            kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
            req = actor.apply.remote(remote_fn, *args, **kwargs)
        # Add to our set to send to ray.wait().
        pending_remotes.add(req)
        # Keep our mappings properly updated.
        remote_requests_in_flight[actor].add(req)
        remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0, \
        "No pending remote requests to wait for!"
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First do a non-blocking `ray.wait` (timeout=0) to collect
        # everything that is already done in one go.
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remote_list),
            timeout=0)
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait()` call w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remote_list),
            timeout=ray_wait_timeout_s)

    # Return empty results if nothing ready after the timeout.
    if not ready:
        return {}

    # Remove in-flight records for ready refs.
    for obj_ref in ready:
        remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

    # Do one ray.get().
    results = ray.get(ready)
    assert len(ready) == len(results)

    # Return mapping from (ready) actors to their results.
    # NOTE: The dict holds a single result per actor. If several requests of
    # the same actor are ready at once, later results overwrite earlier ones
    # here, although all of these requests were already removed from
    # `remote_requests_in_flight` above.
    ret = {}
    for obj_ref, result in zip(ready, results):
        ret[remote_to_actor[obj_ref]] = result

    return ret
0 commit comments