
Policy copy decision tree in Collectors.

Weight Synchronization using Weight Update Schemes
--------------------------------------------------

RL pipelines are typically split into two big computational buckets: training and inference.
While the inference pipeline sends data to the training one, the training pipeline occasionally needs to
synchronize its weights with the inference one.
In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
used in both instances. From there, anything can happen:

- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
  `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
  for its instance of the policy.
- In some cases, the environment or the postprocessing hooks can rely on a model which itself needs
  synchronization. This means that there can be multiple endpoints in the data transfer API, and one needs to think
  beyond policy-to-policy weight synchronization strategies.
- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different
  (quantized vs. non-quantized). This makes weight synchronization much more complex, as one cannot simply dump and
  load a state dict on both ends.
- One typically also has to choose who initiates a transfer: should the inference engine actively ask for new weights,
  or should the trainer push its weights to the workers? An intermediate approach is to store the weights on some
  intermediary server and let the workers fetch them when necessary.

TorchRL tries to account for each of these problems in a flexible manner. We identify four basic components in a weight
transfer:

- A `Sender` class that somehow gets the weights (or a reference to them) and initiates the transfer;
- A `Receiver` class that casts the weights to the destination module (policy or other utility module);
- A `Transport` class that implements the actual transfer of the weights (through shared memory, NCCL or anything else);
- A `Scheme` class that defines which sender, receiver and transport have to be used and how to initialize them.

Each of these classes is detailed below.

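To make the division of labor concrete, here is a minimal, framework-agnostic sketch of how these four components fit
together over a multiprocessing pipe. The class names (``PipeTransport``, ``SimpleSender``, ``SimpleReceiver``,
``SimpleScheme``) are illustrative only and are **not** part of the TorchRL API; the actual classes are listed in the
API summaries at the end of this section.

.. code-block:: python

    import torch.nn as nn
    from torch import multiprocessing as mp


    class PipeTransport:
        """Transport: moves a serialized payload from one end of a pipe to the other."""

        def __init__(self, connection):
            self.connection = connection

        def send(self, payload):
            self.connection.send(payload)

        def receive(self):
            return self.connection.recv()


    class SimpleSender:
        """Sender: grabs the current weights and pushes them through a transport."""

        def __init__(self, transport):
            self.transport = transport

        def update_weights(self, module):
            self.transport.send(module.state_dict())


    class SimpleReceiver:
        """Receiver: applies incoming weights to the destination module."""

        def __init__(self, transport, module):
            self.transport = transport
            self.module = module

        def poll(self):
            self.module.load_state_dict(self.transport.receive())


    class SimpleScheme:
        """Scheme: decides which sender, receiver and transport to use and wires them up."""

        def connect(self, worker_module):
            parent_end, child_end = mp.Pipe()
            sender = SimpleSender(PipeTransport(parent_end))
            receiver = SimpleReceiver(PipeTransport(child_end), worker_module)
            return sender, receiver


    # Trainer-side and worker-side copies of the same policy
    trainer_policy, worker_policy = nn.Linear(4, 2), nn.Linear(4, 2)
    sender, receiver = SimpleScheme().connect(worker_policy)
    sender.update_weights(trainer_policy)  # trainer pushes its weights
    receiver.poll()                        # worker applies them
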
Usage Examples
~~~~~~~~~~~~~~

.. note::
    **Runnable versions** of these examples are available in the repository:

    - `examples/collectors/weight_sync_standalone.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_standalone.py>`_: Standalone weight synchronization
    - `examples/collectors/weight_sync_collectors.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_collectors.py>`_: Collector integration

Using Weight Update Schemes Independently
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here is a basic example:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch import multiprocessing as mp
    from tensordict import TensorDict
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create a simple policy
    policy = nn.Linear(4, 2)

    # Example 1: Multiprocess weight synchronization with state_dict
    # ---------------------------------------------------------------
    # On the main process side (trainer):
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()

    # Register worker pipes
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

    # Send weights to workers
    weights = policy.state_dict()
    sender.update_weights(weights)

    # On the worker process side:
    # receiver = scheme.create_receiver()
    # receiver.register_model(policy)
    # receiver.register_worker_transport(child_pipe)
    # # Receive and apply weights
    # result = receiver._transport.receive_weights(timeout=5.0)
    # if result is not None:
    #     model_id, weights = result
    #     receiver.apply_weights(weights)

    # Example 2: Shared memory weight synchronization
    # ------------------------------------------------
    # Create shared memory scheme with auto-registration
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
    shared_sender = shared_scheme.create_sender()

    # Register worker pipes for lazy registration
    parent_pipe2, child_pipe2 = mp.Pipe()
    shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)

    # Send weights (automatically creates shared buffer on first send)
    weights_td = TensorDict.from_module(policy)
    shared_sender.update_weights(weights_td)

    # Workers automatically see updates via shared memory!

Using Weight Update Schemes with Collectors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization
across multiple inference workers:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create environment and policy
    env = GymEnv("CartPole-v1")
    policy = TensorDictModule(
        nn.Linear(env.observation_spec["observation"].shape[-1],
                  env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )

    # Example 1: Single collector with multiprocess scheme
    # -----------------------------------------------------
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={"policy": scheme},
    )

    # Collect data and update weights periodically
    for i, data in enumerate(collector):
        # ... training step with data ...

        # Update policy weights every N iterations
        if i % 10 == 0:
            new_weights = policy.state_dict()
            collector.update_policy_weights_(new_weights)

    collector.shutdown()

    # Example 2: Multiple collectors with shared memory
    # --------------------------------------------------
    # Shared memory is more efficient for frequent updates
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)

    collector = MultiSyncDataCollector(
        create_env_fn=[
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
        ],
        policy=policy,
        frames_per_batch=192,
        total_frames=10000,
        weight_sync_schemes={"policy": shared_scheme},
    )

    # Workers automatically see weight updates via shared memory
    for data in collector:
        # ... training ...
        collector.update_policy_weights_(TensorDict.from_module(policy))

    collector.shutdown()

.. note::
    When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all
    processes share the same memory buffers. This is ideal for frequent weight updates but requires all
    processes to be on the same machine.

.. note::
    The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state
    dictionaries, while ``"tensordict"`` uses the TensorDict format, which can be more efficient for structured
    models and supports advanced features like lazy initialization.

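For illustration, both formats can be built from the same module using only public PyTorch and ``tensordict`` APIs
(``state_dict()`` and ``TensorDict.from_module``):

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

    # "state_dict" strategy: a flat mapping from dotted parameter names to tensors
    sd = policy.state_dict()
    print(list(sd.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

    # "tensordict" strategy: a nested TensorDict mirroring the module tree
    td = TensorDict.from_module(policy)
    print(td)  # entries "0" and "2", each holding "weight" and "bias"
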
Weight Senders
~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSender
    RayModuleTransformSender

Weight Receivers
~~~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightReceiver
    RayModuleTransformReceiver

Transports
~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    TransportBackend
    MPTransport
    SharedMemTransport
    RayTransport
    RayActorTransport
    RPCTransport
    DistributedTransport

Schemes
~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    NoWeightSyncScheme
    RayWeightSyncScheme
    RayModuleTransformScheme
    RPCWeightSyncScheme
    DistributedWeightSyncScheme

Legacy: Weight Synchronization in Distributed Environments
----------------------------------------------------------

.. warning::
    The `WeightUpdater` is considered legacy as of the 0.11 release and will be deprecated soon.
    The weight update schemes, which provide more flexibility and better compatibility with heavy
    weight transfers (e.g., LLMs), are to be preferred.

In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.