
Commit 54fd748

[BugFix] Fix schemes and refactor collectors to make them readable
ghstack-source-id: 8345a75
Pull-Request: #3226

1 parent: 7652000

File tree: 2 files changed (+451, -228 lines)

test/test_collector.py (45 additions, 29 deletions)

@@ -79,7 +79,10 @@
     RandomPolicy,
 )
 from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
-from torchrl.weight_update import SharedMemWeightSyncScheme
+from torchrl.weight_update import (
+    MultiProcessWeightSyncScheme,
+    SharedMemWeightSyncScheme,
+)
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     IS_FB = True
@@ -1485,12 +1488,12 @@ def env_fn(seed):
 
     @pytest.mark.parametrize("use_async", [False, True])
     @pytest.mark.parametrize("cudagraph", [False, True])
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
-    def test_update_weights(self, use_async, cudagraph):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
         def create_env():
             return ContinuousActionVecMockEnv()
 
@@ -1503,6 +1506,9 @@ def create_env():
         collector_class = (
            MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
         )
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         collector = collector_class(
             [create_env] * 3,
             policy=policy,
@@ -1511,7 +1517,7 @@ def create_env():
             frames_per_batch=20,
             cat_results="stack",
             cudagraph_policy=cudagraph,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         assert "policy" in collector._weight_senders, collector._weight_senders.keys()
         try:
@@ -2857,23 +2863,28 @@ def forward(self, td):
             # ["cuda:0", "cuda"],
         ],
     )
-    def test_param_sync(self, give_weights, collector, policy_device, env_device):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
+    def test_param_sync(
+        self, give_weights, collector, policy_device, env_device, weight_sync_scheme
+    ):
         policy = TestUpdateParams.Policy().to(policy_device)
 
         env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
         device = env().device
         env = [env]
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         col = collector(
             env,
             policy,
             device=device,
             total_frames=200,
             frames_per_batch=10,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         try:
             for i, data in enumerate(col):
@@ -2918,13 +2929,13 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
             # ["cuda:0", "cuda"],
         ],
     )
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
     def test_param_sync_mixed_device(
-        self, give_weights, collector, policy_device, env_device
+        self, give_weights, collector, policy_device, env_device, weight_sync_scheme
     ):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
         with torch.device("cpu"):
             policy = TestUpdateParams.Policy()
             policy.param = nn.Parameter(policy.param.data.to(policy_device))
@@ -2933,13 +2944,16 @@ def test_param_sync_mixed_device(
         env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
         device = env().device
         env = [env]
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         col = collector(
             env,
             policy,
             device=device,
             total_frames=200,
             frames_per_batch=10,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         try:
             for i, data in enumerate(col):
@@ -3851,7 +3865,7 @@ def test_weight_update(self, weight_updater):
         if weight_updater == "scheme_shared":
             kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
         elif weight_updater == "scheme_pipe":
-            kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
+            kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}}
         elif weight_updater == "weight_updater":
             kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)}
         else:
@@ -3870,14 +3884,16 @@ def test_weight_update(self, weight_updater):
             **kwargs,
         )
 
-        collector.update_policy_weights_()
+        # When using policy_factory, must pass weights explicitly
+        collector.update_policy_weights_(policy_weights)
         try:
             for i, data in enumerate(collector):
                 if i == 2:
                     assert (data["action"] != 0).any()
                     # zero the policy
                     policy_weights.data.zero_()
-                    collector.update_policy_weights_()
+                    # When using policy_factory, must pass weights explicitly
+                    collector.update_policy_weights_(policy_weights)
                 elif i == 3:
                     assert (data["action"] == 0).all(), data["action"]
                     break
@@ -3973,11 +3989,11 @@ def test_start_multi(self, total_frames, cls):
     @pytest.mark.parametrize(
         "cls", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
     )
-    def test_start_update_policy(self, total_frames, cls):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
+    def test_start_update_policy(self, total_frames, cls, weight_sync_scheme):
         rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))
         env = CountingEnv()
         m = nn.Linear(env.observation_spec["observation"].shape[-1], 1)
@@ -3998,8 +4014,8 @@ def test_start_update_policy(self, total_frames, cls):
 
         # Add weight sync schemes for multi-process collectors
         kwargs = {}
-        if cls != SyncDataCollector:
-            kwargs["weight_sync_schemes"] = {"policy": MultiProcessWeightSyncScheme()}
+        if cls != SyncDataCollector and weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
 
         collector = cls(
             env,
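
For readers skimming the diff, the pattern these parametrized tests exercise can be summarized by the rough sketch below. It is only an illustration under assumptions: the environment (GymEnv("Pendulum-v1")) and the tiny TensorDictModule policy are placeholders rather than the mock envs and fixtures used in the test file, and it assumes a TorchRL build that exposes the torchrl.weight_update schemes imported in the diff.

# Hedged sketch of the usage pattern exercised by the tests above.
# Assumptions: gymnasium is installed; the env and policy below are
# illustrative stand-ins, not the test fixtures from test_collector.py.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme


def make_env():
    return GymEnv("Pendulum-v1")


if __name__ == "__main__":
    # Placeholder policy mapping a 3-dim observation to a 1-dim action.
    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )

    # The scheme is optional: when None, the collector falls back to its
    # default weight-update path (the `None` case in the parametrization).
    scheme = MultiProcessWeightSyncScheme  # or SharedMemWeightSyncScheme, or None
    kwargs = {}
    if scheme is not None:
        kwargs["weight_sync_schemes"] = {"policy": scheme()}

    collector = MultiSyncDataCollector(
        [make_env] * 3,
        policy=policy,
        total_frames=200,
        frames_per_batch=20,
        cat_results="stack",
        **kwargs,
    )
    try:
        for data in collector:
            # ... optimize the policy here, then push the refreshed weights to
            # the worker processes. Per the diff, collectors built from a
            # policy_factory need the weights passed explicitly, e.g.
            # collector.update_policy_weights_(policy_weights).
            collector.update_policy_weights_()
    finally:
        collector.shutdown()

The conditional kwargs construction mirrors the tests: the same collector call is reused whether or not a weight sync scheme is supplied, which is what lets the None parametrization cover the default code path alongside the two scheme variants.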
