7979 RandomPolicy ,
8080)
8181from torchrl .modules import Actor , OrnsteinUhlenbeckProcessModule , SafeModule
82- from torchrl .weight_update import SharedMemWeightSyncScheme
82+ from torchrl .weight_update import (
83+ MultiProcessWeightSyncScheme ,
84+ SharedMemWeightSyncScheme ,
85+ )
8386
8487if os .getenv ("PYTORCH_TEST_FBCODE" ):
8588 IS_FB = True
@@ -1485,12 +1488,12 @@ def env_fn(seed):
14851488
14861489 @pytest .mark .parametrize ("use_async" , [False , True ])
14871490 @pytest .mark .parametrize ("cudagraph" , [False , True ])
1491+ @pytest .mark .parametrize (
1492+ "weight_sync_scheme" ,
1493+ [None , MultiProcessWeightSyncScheme , SharedMemWeightSyncScheme ],
1494+ )
14881495 @pytest .mark .skipif (not torch .cuda .is_available (), reason = "no cuda device found" )
1489- def test_update_weights (self , use_async , cudagraph ):
1490- from torchrl .weight_update .weight_sync_schemes import (
1491- MultiProcessWeightSyncScheme ,
1492- )
1493-
1496+ def test_update_weights (self , use_async , cudagraph , weight_sync_scheme ):
14941497 def create_env ():
14951498 return ContinuousActionVecMockEnv ()
14961499
@@ -1503,6 +1506,9 @@ def create_env():
15031506 collector_class = (
15041507 MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
15051508 )
1509+ kwargs = {}
1510+ if weight_sync_scheme is not None :
1511+ kwargs ["weight_sync_schemes" ] = {"policy" : weight_sync_scheme ()}
15061512 collector = collector_class (
15071513 [create_env ] * 3 ,
15081514 policy = policy ,
@@ -1511,7 +1517,7 @@ def create_env():
15111517 frames_per_batch = 20 ,
15121518 cat_results = "stack" ,
15131519 cudagraph_policy = cudagraph ,
1514- weight_sync_schemes = { "policy" : MultiProcessWeightSyncScheme ()} ,
1520+ ** kwargs ,
15151521 )
15161522 assert "policy" in collector ._weight_senders , collector ._weight_senders .keys ()
15171523 try :
@@ -2857,23 +2863,28 @@ def forward(self, td):
28572863 # ["cuda:0", "cuda"],
28582864 ],
28592865 )
2860- def test_param_sync (self , give_weights , collector , policy_device , env_device ):
2861- from torchrl .weight_update .weight_sync_schemes import (
2862- MultiProcessWeightSyncScheme ,
2863- )
2864-
2866+ @pytest .mark .parametrize (
2867+ "weight_sync_scheme" ,
2868+ [None , MultiProcessWeightSyncScheme , SharedMemWeightSyncScheme ],
2869+ )
2870+ def test_param_sync (
2871+ self , give_weights , collector , policy_device , env_device , weight_sync_scheme
2872+ ):
28652873 policy = TestUpdateParams .Policy ().to (policy_device )
28662874
28672875 env = EnvCreator (lambda : TestUpdateParams .DummyEnv (device = env_device ))
28682876 device = env ().device
28692877 env = [env ]
2878+ kwargs = {}
2879+ if weight_sync_scheme is not None :
2880+ kwargs ["weight_sync_schemes" ] = {"policy" : weight_sync_scheme ()}
28702881 col = collector (
28712882 env ,
28722883 policy ,
28732884 device = device ,
28742885 total_frames = 200 ,
28752886 frames_per_batch = 10 ,
2876- weight_sync_schemes = { "policy" : MultiProcessWeightSyncScheme ()} ,
2887+ ** kwargs ,
28772888 )
28782889 try :
28792890 for i , data in enumerate (col ):
@@ -2918,13 +2929,13 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
29182929 # ["cuda:0", "cuda"],
29192930 ],
29202931 )
2932+ @pytest .mark .parametrize (
2933+ "weight_sync_scheme" ,
2934+ [None , MultiProcessWeightSyncScheme , SharedMemWeightSyncScheme ],
2935+ )
29212936 def test_param_sync_mixed_device (
2922- self , give_weights , collector , policy_device , env_device
2937+ self , give_weights , collector , policy_device , env_device , weight_sync_scheme
29232938 ):
2924- from torchrl .weight_update .weight_sync_schemes import (
2925- MultiProcessWeightSyncScheme ,
2926- )
2927-
29282939 with torch .device ("cpu" ):
29292940 policy = TestUpdateParams .Policy ()
29302941 policy .param = nn .Parameter (policy .param .data .to (policy_device ))
@@ -2933,13 +2944,16 @@ def test_param_sync_mixed_device(
29332944 env = EnvCreator (lambda : TestUpdateParams .DummyEnv (device = env_device ))
29342945 device = env ().device
29352946 env = [env ]
2947+ kwargs = {}
2948+ if weight_sync_scheme is not None :
2949+ kwargs ["weight_sync_schemes" ] = {"policy" : weight_sync_scheme ()}
29362950 col = collector (
29372951 env ,
29382952 policy ,
29392953 device = device ,
29402954 total_frames = 200 ,
29412955 frames_per_batch = 10 ,
2942- weight_sync_schemes = { "policy" : MultiProcessWeightSyncScheme ()} ,
2956+ ** kwargs ,
29432957 )
29442958 try :
29452959 for i , data in enumerate (col ):
@@ -3851,7 +3865,7 @@ def test_weight_update(self, weight_updater):
38513865 if weight_updater == "scheme_shared" :
38523866 kwargs = {"weight_sync_schemes" : {"policy" : SharedMemWeightSyncScheme ()}}
38533867 elif weight_updater == "scheme_pipe" :
3854- kwargs = {"weight_sync_schemes" : {"policy" : SharedMemWeightSyncScheme ()}}
3868+ kwargs = {"weight_sync_schemes" : {"policy" : MultiProcessWeightSyncScheme ()}}
38553869 elif weight_updater == "weight_updater" :
38563870 kwargs = {"weight_updater" : self .MPSWeightUpdaterBase (policy_weights , 2 )}
38573871 else :
@@ -3973,11 +3987,11 @@ def test_start_multi(self, total_frames, cls):
39733987 @pytest .mark .parametrize (
39743988 "cls" , [SyncDataCollector , MultiaSyncDataCollector , MultiSyncDataCollector ]
39753989 )
3976- def test_start_update_policy ( self , total_frames , cls ):
3977- from torchrl . weight_update . weight_sync_schemes import (
3978- MultiProcessWeightSyncScheme ,
3979- )
3980-
3990+ @ pytest . mark . parametrize (
3991+ "weight_sync_scheme" ,
3992+ [ None , MultiProcessWeightSyncScheme , SharedMemWeightSyncScheme ] ,
3993+ )
3994+ def test_start_update_policy ( self , total_frames , cls , weight_sync_scheme ):
39813995 rb = ReplayBuffer (storage = LazyMemmapStorage (max_size = 1000 ))
39823996 env = CountingEnv ()
39833997 m = nn .Linear (env .observation_spec ["observation" ].shape [- 1 ], 1 )
@@ -3998,8 +4012,8 @@ def test_start_update_policy(self, total_frames, cls):
39984012
39994013 # Add weight sync schemes for multi-process collectors
40004014 kwargs = {}
4001- if cls != SyncDataCollector :
4002- kwargs ["weight_sync_schemes" ] = {"policy" : MultiProcessWeightSyncScheme ()}
4015+ if cls != SyncDataCollector and weight_sync_scheme is not None :
4016+ kwargs ["weight_sync_schemes" ] = {"policy" : weight_sync_scheme ()}
40034017
40044018 collector = cls (
40054019 env ,
0 commit comments