File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1212import torch
1313from torch import multiprocessing as mp
1414
15- from torchrl ._utils import WEIGHT_SYNC_TIMEOUT
1615from torchrl .envs .utils import ExplorationType
1716
1817try :
@@ -40,6 +39,9 @@ def cudagraph_mark_step_begin():
4039_TIMEOUT = 1.0
4140INSTANTIATE_TIMEOUT = 20
4241_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude smaller than the time spent collecting a trajectory
42+ # Timeout for weight synchronization during collector init.
43+ # Increase this when using many collectors across different CUDA devices.
44+ WEIGHT_SYNC_TIMEOUT = float (os .environ .get ("TORCHRL_WEIGHT_SYNC_TIMEOUT" , 120.0 ))
4345# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can time out on its queue.
4446_MAX_IDLE_COUNT = int (os .environ .get ("MAX_IDLE_COUNT" , torch .iinfo (torch .int64 ).max ))
4547
Original file line number Diff line number Diff line change 1010
1111from torch import multiprocessing as mp , nn
1212
13- from torchrl ._utils import logger as torchrl_logger , WEIGHT_SYNC_TIMEOUT
13+ from torchrl ._utils import logger as torchrl_logger
14+ from torchrl .collectors ._constants import WEIGHT_SYNC_TIMEOUT
1415
1516from torchrl .weight_update .utils import _resolve_model
1617from torchrl .weight_update .weight_sync_schemes import (
You can’t perform that action at this time.
0 commit comments