You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/reference/collectors.rst
+242Lines changed: 242 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -417,6 +417,248 @@ transformed, and applied, ensuring seamless integration with their existing infr
417
417
RPCWeightUpdater
418
418
DistributedWeightUpdater
419
419
420
+
Weight Synchronization API
421
+
~~~~~~~~~~~~~~~~~~~~~~~~~~
422
+
423
+
The weight synchronization API provides a simple, modular approach to updating model weights across
424
+
distributed collectors. This system is designed to handle the complexities of modern RL setups where multiple
425
+
models may need to be synchronized independently.
426
+
427
+
Overview
428
+
^^^^^^^^
429
+
430
+
In reinforcement learning, particularly with multi-process data collection, it's essential to keep the inference
431
+
policies synchronized with the latest trained weights. The API addresses this challenge through a clean
432
+
separation of concerns, where four classes are involved:
433
+
434
+
- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is
435
+
your main entrypoint to configure the weight synchronization.
436
+
- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers.
437
+
- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes.
438
+
- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC)
439
+
440
+
The following diagram shows the different classes involved in the weight synchronization process:
441
+
442
+
.. aafig::
443
+
:aspect: 60
444
+
:scale: 130
445
+
:proportional:
446
+
447
+
INITIALIZATION PHASE
448
+
====================
449
+
450
+
WeightSyncScheme
451
+
+------------------+
452
+
||
453
+
| Configuration: |
454
+
| - strategy |
455
+
| - transport_type |
456
+
||
457
+
+--------+---------+
458
+
|
459
+
+------------+-------------+
460
+
||
461
+
creates creates
462
+
||
463
+
v v
464
+
Main Process Worker Process
465
+
+--------------++---------------+
466
+
| WeightSender || WeightReceiver|
467
+
||||
468
+
| - strategy || - strategy |
469
+
| - transports || - transport |
470
+
| - model_ref || - model_ref |
471
+
||||
472
+
| Registers: || Registers: |
473
+
| - model || - model |
474
+
| - workers || - transport |
475
+
+--------------++---------------+
476
+
||
477
+
| Transport Layer |
478
+
|+----------------+|
479
+
+-->+ MPTransport |<------+
480
+
|| (pipes) ||
481
+
|+----------------+|
482
+
|+----------------+|
483
+
+-->+ SharedMemTrans |<------+
484
+
|| (shared mem) ||
485
+
|+----------------+|
486
+
|+----------------+|
487
+
+-->+ RayTransport |<------+
488
+
| (Ray store) |
489
+
+----------------+
490
+
491
+
492
+
SYNCHRONIZATION PHASE
493
+
=====================
494
+
495
+
Main Process Worker Process
496
+
497
+
+-------------------+ +-------------------+
498
+
|WeightSender | | WeightReceiver |
499
+
|| | |
500
+
|1. Extract | | 4. Poll transport |
501
+
|weights from | | for weights |
502
+
|model using | | |
503
+
|strategy | | |
504
+
|| 2. Send via | |
505
+
|+-------------+ | Transport | +--------------+ |
506
+
|| Strategy | | +------------+ | | Strategy | |
507
+
|| extract() | | | | | | apply() | |
508
+
|+-------------+ +----+ Transport +-------->+ +--------------+ |
0 commit comments