Skip to content

Commit 64eb94a

Browse files
committed
process_group: add PG timeouts + automatically assign manager port
1 parent 2ae42a0 commit 64eb94a

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

torchft/manager.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from torchft.process_group import ProcessGroup
4747

4848
MANAGER_ADDR_KEY: str = "manager_addr"
49-
MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
49+
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
5050
REPLICA_ID_KEY: str = "replica_id"
5151

5252
T = TypeVar("T")
@@ -84,7 +84,6 @@ def __init__(
8484
load_state_dict: Callable[[T], None],
8585
state_dict: Callable[[], T],
8686
min_replica_size: int,
87-
port: int = MANAGER_DEFAULT_PORT,
8887
use_async_quorum: bool = True,
8988
timeout: timedelta = timedelta(seconds=60),
9089
rank: Optional[int] = None,
@@ -94,6 +93,7 @@ def __init__(
9493
store_port: Optional[int] = None,
9594
lighthouse_addr: Optional[str] = None,
9695
replica_id: Optional[str] = None,
96+
port: Optional[int] = None,
9797
) -> None:
9898
"""
9999
Args:
@@ -147,6 +147,13 @@ def _manager_state_dict() -> Dict[str, T]:
147147

148148
if rank == 0:
149149
hostname = socket.gethostname()
150+
151+
if port is None:
152+
port_str = os.environ.get(MANAGER_PORT_ENV)
153+
if port_str is None:
154+
port = 0
155+
else:
156+
port = int(port_str)
150157
addr = f"http://{hostname}:{port}"
151158
bind = f"[::]:{port}"
152159
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
@@ -163,7 +170,7 @@ def _manager_state_dict() -> Dict[str, T]:
163170
world_size=world_size,
164171
)
165172

166-
self._store.set(MANAGER_ADDR_KEY, addr)
173+
self._store.set(MANAGER_ADDR_KEY, self._manager.address())
167174
self._store.set(REPLICA_ID_KEY, replica_id)
168175

169176
addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")

torchft/process_group.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def __repr__(self) -> str:
187187

188188

189189
class ProcessGroupWrapper(ProcessGroup):
190-
PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized
191190
"""
192191
This is a wrapper around any ProcessGroup with a reconfiguration method.
193192
"""
@@ -209,9 +208,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
209208

210209
store = create_store_client(store_addr)
211210

212-
# TODO: set global timeout
213-
# pyre-fixme[20]: expects argument options
214-
self._pg = self.PG_CLASS(store, rank, world_size)
211+
self._pg = self._create_pg(store, rank, world_size)
212+
213+
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
214+
raise NotImplementedError("not implemented")
215215

216216
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
217217
return self.parent.allreduce(tensors, opts)
@@ -244,9 +244,12 @@ class ProcessGroupGloo(ProcessGroupWrapper):
244244
This is a reconfigurable version of ProcessGroupGloo.
245245
"""
246246

247-
PG_CLASS: Type[BaseProcessGroup] = (
248-
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
249-
)
247+
def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
248+
super().__init__()
249+
self._timeout = timeout
250+
251+
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
252+
return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
250253

251254
def getBackendName(self) -> str:
252255
return "torchft-gloo"
@@ -263,9 +266,12 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
263266
abort when reconfiguring, we need to ensure this is safe.
264267
"""
265268

266-
PG_CLASS: Type[BaseProcessGroup] = (
267-
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
268-
)
269+
def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
270+
super().__init__()
271+
self._timeout = timeout
272+
273+
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
274+
return BaseProcessGroupNCCL(store, rank, world_size, self._timeout)
269275

270276
def getBackendName(self) -> str:
271277
return "torchft-nccl"

0 commit comments

Comments
 (0)