import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import (
    MultiProcessWeightSyncScheme,
    SharedMemWeightSyncScheme,
)


def example_single_collector_multiprocess():
    """Example 1: Single collector with multiprocess scheme."""
    print("\n" + "=" * 70)
    print("Example 1: Single Collector with Multiprocess Scheme")
    print("=" * 70)

    # Create environment and policy
    env = GymEnv("CartPole-v1")
    policy = TensorDictModule(
        nn.Linear(
            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
        ),
        in_keys=["observation"],
        out_keys=["action"],
    )
    env.close()

    # Create weight sync scheme
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
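    # Note: "state_dict" exchanges weights as a plain PyTorch state_dict;
    # Example 2 below uses the "tensordict" strategy, which exchanges a
    # TensorDict instead.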

    print("Creating collector with multiprocess weight sync...")
    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,  # assumed value; this argument is elided in the source
        total_frames=200,
        weight_sync_schemes={"policy": scheme},
    )

    # Collect data and update weights periodically
    print("Collecting data...")
    for i, data in enumerate(collector):
        print(f"Iteration {i}: Collected {data.numel()} transitions")

        # Update policy weights every 2 iterations
        if i % 2 == 0:
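            # With the "state_dict" strategy, the new weights travel as a plain
            # state_dict and overwrite the collector's copy of the policy.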
            new_weights = policy.state_dict()
            collector.update_policy_weights_(new_weights)
            print("  → Updated policy weights")

        if i >= 2:  # Just run a few iterations for demo
            break

    collector.shutdown()
    print("✓ Single collector example completed!\n")


def example_multi_collector_shared_memory():
    """Example 2: Multiple collectors with shared memory."""
    print("\n" + "=" * 70)
    print("Example 2: Multiple Collectors with Shared Memory")
    print("=" * 70)

    # Create environment and policy
    env = GymEnv("CartPole-v1")
    policy = TensorDictModule(
        nn.Linear(
            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
        ),
        in_keys=["observation"],
        out_keys=["action"],
    )
    env.close()

    # Shared memory is more efficient for frequent updates
    scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
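    # auto_register=True is taken to mean the policy's weights are mapped into
    # shared memory automatically on first sync (an assumption from the flag name).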

    print("Creating multi-collector with shared memory...")
    collector = MultiSyncDataCollector(
        create_env_fn=[
            # Two workers assumed here; the exact arguments are elided in the source.
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
        ],
        policy=policy,
        frames_per_batch=100,  # assumed value
        total_frames=400,
        weight_sync_schemes={"policy": scheme},
    )

    # Workers automatically see weight updates via shared memory
    print("Collecting data...")
    for i, data in enumerate(collector):
        print(f"Iteration {i}: Collected {data.numel()} transitions")

        # Update weights frequently (shared memory makes this very fast)
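        # TensorDict.from_module(policy) snapshots the module's parameters as a
        # TensorDict, matching the scheme's "tensordict" strategy.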
        collector.update_policy_weights_(TensorDict.from_module(policy))
        print("  → Updated policy weights via shared memory")

        if i >= 1:  # Just run a couple iterations for demo
            break

    collector.shutdown()
    print("✓ Multi-collector with shared memory example completed!\n")


def main():
    """Run all examples."""
    print("\n" + "=" * 70)
    print("Weight Synchronization Schemes - Collector Integration Examples")
    print("=" * 70)

    # Set multiprocessing start method
    import torch.multiprocessing as mp

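    # "spawn" is the safe start method for collector workers (fork can break
    # CUDA and threaded state); a RuntimeError just means it was already set.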
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass  # Already set

    # Run examples
    example_single_collector_multiprocess()
    example_multi_collector_shared_memory()

    print("\n" + "=" * 70)
    print("All examples completed successfully!")
    print("=" * 70)
    print("\nKey takeaways:")
    print("  • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios")
    print(
        "  • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers"
    )
    print("=" * 70 + "\n")


if __name__ == "__main__":
    main()