
Commit f262594

[Quality] Fix flaky test
ghstack-source-id: 839143f
Pull-Request: #3211
1 parent 4424ad7 commit f262594

File tree

3 files changed: +94 -85 lines


examples/collectors/weight_sync_collectors.py

Lines changed: 34 additions & 34 deletions

@@ -17,7 +17,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
-from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
+from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
 from torchrl.envs import GymEnv
 from torchrl.weight_update import (
     MultiProcessWeightSyncScheme,
@@ -27,25 +27,24 @@
 
 def example_single_collector_multiprocess():
     """Example 1: Single collector with multiprocess scheme."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 1: Single Collector with Multiprocess Scheme")
-    print("="*70)
-
+    print("=" * 70)
+
     # Create environment and policy
     env = GymEnv("CartPole-v1")
     policy = TensorDictModule(
         nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
+            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
         ),
         in_keys=["observation"],
         out_keys=["action"],
     )
     env.close()
-
+
     # Create weight sync scheme
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-
+
     print("Creating collector with multiprocess weight sync...")
     collector = SyncDataCollector(
         create_env_fn=lambda: GymEnv("CartPole-v1"),
@@ -54,46 +53,45 @@ def example_single_collector_multiprocess():
         total_frames=200,
         weight_sync_schemes={"policy": scheme},
     )
-
+
     # Collect data and update weights periodically
     print("Collecting data...")
     for i, data in enumerate(collector):
         print(f"Iteration {i}: Collected {data.numel()} transitions")
-
+
         # Update policy weights every 2 iterations
         if i % 2 == 0:
             new_weights = policy.state_dict()
             collector.update_policy_weights_(new_weights)
             print(" → Updated policy weights")
-
+
         if i >= 2:  # Just run a few iterations for demo
             break
-
+
     collector.shutdown()
     print("✓ Single collector example completed!\n")
 
 
 def example_multi_collector_shared_memory():
     """Example 2: Multiple collectors with shared memory."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 2: Multiple Collectors with Shared Memory")
-    print("="*70)
-
+    print("=" * 70)
+
     # Create environment and policy
     env = GymEnv("CartPole-v1")
     policy = TensorDictModule(
         nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
+            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
         ),
         in_keys=["observation"],
         out_keys=["action"],
     )
     env.close()
-
+
     # Shared memory is more efficient for frequent updates
     scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
-
+
     print("Creating multi-collector with shared memory...")
     collector = MultiSyncDataCollector(
         create_env_fn=[
@@ -106,49 +104,51 @@ def example_multi_collector_shared_memory():
         total_frames=400,
         weight_sync_schemes={"policy": scheme},
     )
-
+
     # Workers automatically see weight updates via shared memory
     print("Collecting data...")
     for i, data in enumerate(collector):
         print(f"Iteration {i}: Collected {data.numel()} transitions")
-
+
         # Update weights frequently (shared memory makes this very fast)
         collector.update_policy_weights_(TensorDict.from_module(policy))
         print(" → Updated policy weights via shared memory")
-
+
         if i >= 1:  # Just run a couple iterations for demo
             break
-
+
     collector.shutdown()
     print("✓ Multi-collector with shared memory example completed!\n")
 
 
 def main():
     """Run all examples."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Weight Synchronization Schemes - Collector Integration Examples")
-    print("="*70)
-
+    print("=" * 70)
+
     # Set multiprocessing start method
     import torch.multiprocessing as mp
+
     try:
-        mp.set_start_method('spawn')
+        mp.set_start_method("spawn")
     except RuntimeError:
         pass  # Already set
-
+
     # Run examples
     example_single_collector_multiprocess()
     example_multi_collector_shared_memory()
-
-    print("\n" + "="*70)
+
+    print("\n" + "=" * 70)
     print("All examples completed successfully!")
-    print("="*70)
+    print("=" * 70)
     print("\nKey takeaways:")
     print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios")
-    print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers")
-    print("="*70 + "\n")
+    print(
+        " • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers"
+    )
+    print("=" * 70 + "\n")
 
 
 if __name__ == "__main__":
     main()
-

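For quick reference, the core pattern the example file exercises is sketched below. It is assembled from the calls visible in the hunks above; the frames_per_batch value and the policy= keyword are assumptions, since those lines fall outside the diff shown, and the sketch is illustrative rather than a definitive rendering of the file.

# Minimal sketch of the multiprocess weight-sync pattern, assuming the
# imports shown in the diff; frames_per_batch and policy= are assumptions.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme

# Build a one-layer policy sized from a throwaway env's specs.
env = GymEnv("CartPole-v1")
policy = TensorDictModule(
    nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]),
    in_keys=["observation"],
    out_keys=["action"],
)
env.close()

collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("CartPole-v1"),
    policy=policy,  # assumption: keyword not visible in the hunks
    frames_per_batch=50,  # assumption: illustrative value
    total_frames=200,
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme(strategy="state_dict")},
)
for data in collector:
    # Push the trainer's current weights to the collector workers.
    collector.update_policy_weights_(policy.state_dict())
collector.shutdown()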