Skip to content

Commit a6808c1

Browse files
committed
Update
[ghstack-poisoned]
1 parent f6738c7 commit a6808c1

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

torchrl/collectors/_constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch
1313
from torch import multiprocessing as mp
1414

15-
from torchrl._utils import WEIGHT_SYNC_TIMEOUT
1615
from torchrl.envs.utils import ExplorationType
1716

1817
try:
@@ -40,6 +39,9 @@ def cudagraph_mark_step_begin():
4039
_TIMEOUT = 1.0
4140
INSTANTIATE_TIMEOUT = 20
4241
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
42+
# Timeout for weight synchronization during collector init.
43+
# Increase this when using many collectors across different CUDA devices.
44+
WEIGHT_SYNC_TIMEOUT = float(os.environ.get("TORCHRL_WEIGHT_SYNC_TIMEOUT", 120.0))
4345
# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue.
4446
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))
4547

torchrl/weight_update/_shared.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
from torch import multiprocessing as mp, nn
1212

13-
from torchrl._utils import logger as torchrl_logger, WEIGHT_SYNC_TIMEOUT
13+
from torchrl._utils import logger as torchrl_logger
14+
from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT
1415

1516
from torchrl.weight_update.utils import _resolve_model
1617
from torchrl.weight_update.weight_sync_schemes import (

0 commit comments

Comments
 (0)