Commit 107ab58
[Feature] Documentation
ghstack-source-id: bb9f0e2 Pull-Request: #3192
1 parent cb706e8 commit 107ab58

1 file changed: +242 -0 lines

docs/source/reference/collectors.rst

Lines changed: 242 additions & 0 deletions
@@ -417,6 +417,248 @@ transformed, and applied, ensuring seamless integration with their existing infr
    RPCWeightUpdater
    DistributedWeightUpdater

Weight Synchronization API
~~~~~~~~~~~~~~~~~~~~~~~~~~
The weight synchronization API provides a simple, modular approach to updating model weights across
distributed collectors. It is designed to handle the complexities of modern RL setups, where multiple
models may need to be synchronized independently.

Overview
^^^^^^^^

In reinforcement learning, particularly with multi-process data collection, it is essential to keep the
inference policies synchronized with the latest trained weights. The API addresses this challenge through
a clean separation of concerns involving four classes:
- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is
  your main entrypoint for configuring weight synchronization.
- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers.
- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes.
- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC).
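To make this division of responsibilities concrete, here is a minimal in-process sketch of the four roles. All names below (``ToyScheme``, ``ToySender``, ``ToyReceiver``) are hypothetical illustrations, not TorchRL classes, and a plain queue stands in for the real transports:

```python
from queue import Queue


class ToyScheme:
    """Configuration: decides the strategy and builds the transport."""

    def __init__(self, strategy="state_dict"):
        self.strategy = strategy

    def create_transport(self):
        # A plain in-process queue in place of a pipe / shared memory / Ray.
        return Queue()


class ToySender:
    """Sending: extracts weights from the source model and pushes them."""

    def __init__(self, model, transport):
        self.model, self.transport = model, transport

    def push(self):
        weights = dict(self.model)  # "extract" the current weights
        self.transport.put(weights)


class ToyReceiver:
    """Receiving: polls the transport and applies weights to its model."""

    def __init__(self, model, transport):
        self.model, self.transport = model, transport

    def poll_and_apply(self):
        weights = self.transport.get()
        self.model.update(weights)  # "apply" to the destination model


# Wire the pieces together the way a collector would.
scheme = ToyScheme()
transport = scheme.create_transport()
trainer_model = {"weight": 1.0}   # source model in the main process
worker_model = {"weight": 0.0}    # stale copy in a worker

sender = ToySender(trainer_model, transport)
receiver = ToyReceiver(worker_model, transport)

trainer_model["weight"] = 2.0     # a training step updates the source
sender.push()
receiver.poll_and_apply()
print(worker_model["weight"])     # -> 2.0
```

The point of the split is that the sender/receiver pair never needs to know which transport it is running over; swapping the queue for a pipe or the Ray object store leaves the rest untouched.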
The following diagram shows the different classes involved in the weight synchronization process:

.. aafig::
    :aspect: 60
    :scale: 130
    :proportional:

    INITIALIZATION PHASE
    ====================

                WeightSyncScheme
              +------------------+
              |                  |
              | Configuration:   |
              | - strategy       |
              | - transport_type |
              |                  |
              +--------+---------+
                       |
             +---------+---------+
             |                   |
          creates             creates
             |                   |
             v                   v
       Main Process       Worker Process
      +--------------+   +---------------+
      | WeightSender |   | WeightReceiver|
      |              |   |               |
      | - strategy   |   | - strategy    |
      | - transports |   | - transport   |
      | - model_ref  |   | - model_ref   |
      |              |   |               |
      | Registers:   |   | Registers:    |
      | - model      |   | - model       |
      | - workers    |   | - transport   |
      +------+-------+   +-------+-------+
             |                   |
             |  Transport Layer  |
             | +---------------+ |
             +>+ MPTransport   +<+
             | | (pipes)       | |
             | +---------------+ |
             | +---------------+ |
             +>+ SharedMemTrans+<+
             | | (shared mem)  | |
             | +---------------+ |
             | +---------------+ |
             +>+ RayTransport  +<+
               | (Ray store)   |
               +---------------+

    SYNCHRONIZATION PHASE
    =====================

           Main Process                                  Worker Process

      +-------------------+                         +-------------------+
      |   WeightSender    |                         |  WeightReceiver   |
      |                   |                         |                   |
      | 1. Extract        |                         | 4. Poll transport |
      |    weights from   |                         |    for weights    |
      |    model using    |                         |                   |
      |    strategy       |      2. Send via        |                   |
      |                   |      Transport          |                   |
      | +-------------+   |    +------------+       | +--------------+  |
      | | Strategy    |   |    |            |       | | Strategy     |  |
      | | extract()   |   |    |            |       | | apply()      |  |
      | +-------------+   +--->+ Transport  +------>+ +--------------+  |
      |        |          |    |            |       |        |          |
      |        v          |    +------------+       |        v          |
      | +-------------+   |                         | +--------------+  |
      | | Model       |   |   3. Ack (optional)     | | Model        |  |
      | | (source)    |   | <---------------------+ | | (dest)       |  |
      | +-------------+   |                         | +--------------+  |
      |                   |                         |                   |
      +-------------------+                         | 5. Apply weights  |
                                                    |    to model using |
                                                    |    strategy       |
                                                    +-------------------+
Key Challenges Addressed
^^^^^^^^^^^^^^^^^^^^^^^^

Modern RL training often involves multiple models that need independent synchronization:

1. **Multiple Models Per Collector**: A collector might need to update:

   - The main policy network
   - A value network in a Ray actor within the replay buffer
   - Models embedded in the environment itself
   - Separate world models or auxiliary networks

2. **Different Update Strategies**: Each model may require a different synchronization approach:

   - Full state_dict transfer vs. TensorDict-based updates
   - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
   - Varied update frequencies

3. **Worker-Agnostic Updates**: Some models (such as those in shared Ray actors) should not be tied to
   specific worker indices, requiring a more flexible update mechanism.
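The distinction between the two extraction strategies in point 2 can be sketched with plain dicts standing in for torch modules (illustrative only, not TorchRL's strategy classes): a state_dict is a flat mapping of dotted parameter names, while a TensorDict-style extraction mirrors the nested module tree.

```python
# A stand-in for a module's parameters, keyed by dotted names.
flat_model = {"layer.weight": [1.0, 2.0], "layer.bias": [0.5]}


def extract_state_dict(model):
    # state_dict strategy: flat mapping of dotted names to tensors.
    return dict(model)


def extract_tensordict(model):
    # tensordict strategy: nested mapping mirroring the module tree.
    nested = {}
    for dotted, value in model.items():
        node = nested
        *path, leaf = dotted.split(".")
        for key in path:
            node = node.setdefault(key, {})
        node[leaf] = value
    return nested


print(extract_state_dict(flat_model))
# -> {'layer.weight': [1.0, 2.0], 'layer.bias': [0.5]}
print(extract_tensordict(flat_model))
# -> {'layer': {'weight': [1.0, 2.0], 'bias': [0.5]}}
```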
Architecture
^^^^^^^^^^^^

The API follows a scheme-based design where users specify synchronization requirements upfront,
and the collector handles the orchestration transparently:

.. aafig::
    :aspect: 60
    :scale: 130
    :proportional:

       Main Process       Worker Process 1     Worker Process 2

    +-----------------+   +---------------+    +---------------+
    | Collector       |   | Collector     |    | Collector     |
    |                 |   |               |    |               |
    | Models:         |   | Models:       |    | Models:       |
    | +----------+    |   | +--------+    |    | +--------+    |
    | | Policy A |    |   | |Policy A|    |    | |Policy A|    |
    | +----------+    |   | +--------+    |    | +--------+    |
    | +----------+    |   | +--------+    |    | +--------+    |
    | | Model B  |    |   | |Model B |    |    | |Model B |    |
    | +----------+    |   | +--------+    |    | +--------+    |
    |                 |   |               |    |               |
    | Weight Senders: |   | Weight        |    | Weight        |
    | +----------+    |   | Receivers:    |    | Receivers:    |
    | | Sender A +----+-->+ Receiver A    |    | Receiver A    |
    | +----------+    |   |               |    |               |
    | +----------+    |   | +--------+    |    | +--------+    |
    | | Sender B +----+-->+ Receiver B    |    | Receiver B    |
    | +----------+    |   |     Pipes     |    |     Pipes     |
    +-----------------+   +-------+-------+    +-------+-------+
             ^                    ^                    ^
             |                    |                    |
             | update_policy_     |   Apply weights    |
             | weights_()         |                    |
    +--------+-------+            |                    |
    |   User Code    |            |                    |
    |   (Training)   |            +--------------------+
    +----------------+
The weight synchronization flow:

1. **Initialization**: The user creates a ``weight_sync_schemes`` dict mapping model IDs to schemes
2. **Registration**: The collector creates a ``WeightSender`` for each model in the main process
3. **Worker Setup**: Each worker creates the corresponding ``WeightReceiver`` instances
4. **Synchronization**: Calling ``update_policy_weights_()`` triggers all senders to push weights
5. **Application**: Receivers automatically apply the weights to their registered models
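The fan-out in steps 4-5 can be sketched as follows: one update call pushes every registered model's weights to every worker. The names here (``update_all_weights``, the dict-based workers) are illustrative stand-ins, not TorchRL's internals:

```python
# Two workers, each holding stale copies of two models ("policy", "value_net").
workers = [
    {"policy": {"w": 0.0}, "value_net": {"w": 0.0}},
    {"policy": {"w": 0.0}, "value_net": {"w": 0.0}},
]

# Freshly trained weights in the main process.
main_models = {"policy": {"w": 1.5}, "value_net": {"w": -0.5}}


def update_all_weights():
    # Conceptually: one sender per model ID, each pushing to all workers,
    # with the transport + receiver apply collapsed into a dict update.
    for model_id, weights in main_models.items():
        for worker in workers:
            worker[model_id].update(weights)


update_all_weights()
print(workers[1]["value_net"]["w"])  # -> -0.5
```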
Available Classes
^^^^^^^^^^^^^^^^^

**Synchronization Schemes** (User-Facing Configuration):

- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme`: Base class for synchronization schemes
- :class:`~torchrl.weight_update.weight_sync_schemes.MultiProcessWeightSyncScheme`: For multiprocessing with pipes
- :class:`~torchrl.weight_update.weight_sync_schemes.SharedMemWeightSyncScheme`: For shared-memory synchronization
- :class:`~torchrl.weight_update.weight_sync_schemes.RayWeightSyncScheme`: For Ray-based distribution
- :class:`~torchrl.weight_update.weight_sync_schemes.NoWeightSyncScheme`: No-op scheme that disables synchronization

**Internal Classes** (Automatically Managed):

- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender`: Sends weights to all workers for one model
- :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver`: Receives and applies weights in a worker
- :class:`~torchrl.weight_update.weight_sync_schemes.TransportBackend`: Communication-layer abstraction
Usage Example
^^^^^^^^^^^^^

.. code-block:: python

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    # Define synchronization for multiple models
    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
        "value_net": MultiProcessWeightSyncScheme(strategy="state_dict"),
    }

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env] * 4,
        policy=policy,
        frames_per_batch=1000,
        weight_sync_schemes=weight_sync_schemes,  # Pass the schemes dict
    )

    # A single call updates all registered models across all workers
    for i, batch in enumerate(collector):
        # Training step
        loss = train(batch)

        # Sync all models with one call
        collector.update_policy_weights_(policy)
The collector automatically:

- Creates ``WeightSender`` instances in the main process for each model
- Creates ``WeightReceiver`` instances in each worker process
- Resolves models by ID (e.g., ``"policy"`` → ``collector.policy``)
- Handles transport setup and communication
- Applies weights using the appropriate strategy (state_dict vs. tensordict)
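Resolution of model IDs to objects can be pictured as attribute lookup on the collector. This is a purely illustrative sketch (``ToyCollector`` and ``resolve`` are hypothetical names, not TorchRL's resolution logic):

```python
class ToyCollector:
    """Toy container mimicking a collector that owns several models."""

    def __init__(self, policy, value_net):
        self.policy = policy
        self.value_net = value_net

    def resolve(self, model_id):
        # "policy" -> collector.policy, "value_net" -> collector.value_net
        return getattr(self, model_id)


collector = ToyCollector(policy={"w": 1}, value_net={"w": 2})
print(collector.resolve("value_net"))  # -> {'w': 2}
```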
API Reference
^^^^^^^^^^^^^

.. currentmodule:: torchrl.weight_update.weight_sync_schemes

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    RayWeightSyncScheme
    NoWeightSyncScheme
    WeightSender
    WeightReceiver

Collectors and replay buffers interoperability
----------------------------------------------
