process_group: add PG init timeouts + automatically assign manager port #60

Merged
merged 1 commit into from Jan 9, 2025
22 changes: 18 additions & 4 deletions torchft/manager.py
@@ -46,7 +46,7 @@
from torchft.process_group import ProcessGroup

MANAGER_ADDR_KEY: str = "manager_addr"
MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
REPLICA_ID_KEY: str = "replica_id"

T = TypeVar("T")
@@ -74,6 +74,12 @@ class Manager:
"""
Manager manages the full fault tolerant training loop.

This requires that the TCPStore specified by store_addr and store_port, or
by the MASTER_ADDR and MASTER_PORT environment variables, be started prior
to creating this manager. If using a modern version of torchelastic this
will already be the case. Otherwise, it should be started via
torch.distributed.init_process_group prior to creating this manager.

NOTE: when saving periodic checkpoints you must save and restore the
Manager's state_dict as well to avoid synchronization issues.
"""
@@ -84,7 +90,6 @@ def __init__(
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
min_replica_size: int,
port: int = MANAGER_DEFAULT_PORT,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
rank: Optional[int] = None,
@@ -94,13 +99,18 @@
store_port: Optional[int] = None,
lighthouse_addr: Optional[str] = None,
replica_id: Optional[str] = None,
port: Optional[int] = None,
) -> None:
"""
Args:
load_state_dict: function to load the state dict when recovering
state_dict: function to save the state dict when recovering
min_replica_size: minimum number of replicas on each step
port: if rank==0, the port to run the manager server on
port: if rank==0, the port to run the manager server on.
Port assignment priority:
1. this argument
2. TORCHFT_MANAGER_PORT env var
3. arbitrary port assigned via 0
use_async_quorum: whether to run the quorum asynchronously during the forward pass
timeout: timeout for all operations
rank: the replica group local rank
@@ -147,6 +157,10 @@ def _manager_state_dict() -> Dict[str, T]:

if rank == 0:
hostname = socket.gethostname()

if port is None:
port = int(os.environ.get(MANAGER_PORT_ENV, 0))

addr = f"http://{hostname}:{port}"
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
@@ -163,7 +177,7 @@
world_size=world_size,
)

self._store.set(MANAGER_ADDR_KEY, addr)
self._store.set(MANAGER_ADDR_KEY, self._manager.address())
self._store.set(REPLICA_ID_KEY, replica_id)

addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
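For reference, the port handling above boils down to a small resolution rule. A minimal standalone sketch of that rule (illustrative only; resolve_manager_port is a hypothetical helper, not part of this PR):

import os
from typing import Optional

MANAGER_PORT_ENV = "TORCHFT_MANAGER_PORT"

def resolve_manager_port(port: Optional[int] = None) -> int:
    """Mirrors the priority order Manager applies on rank 0."""
    if port is not None:
        # 1. an explicit `port` argument always wins
        return port
    # 2. otherwise use TORCHFT_MANAGER_PORT if set; 3. else 0, so the OS assigns a free port
    return int(os.environ.get(MANAGER_PORT_ENV, 0))

Because binding to port 0 means the real port is only known after the server starts, the store entry is now written from self._manager.address() instead of the precomputed addr string.
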
94 changes: 69 additions & 25 deletions torchft/process_group.py
@@ -17,10 +17,11 @@
"""

import logging
import queue
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

import torch
import torch.distributed as dist
@@ -53,8 +54,23 @@
_FUTURE_EXCEPTION = "fut_exception"


def _get(queue: mp.Queue, timeout: float) -> object:
v = queue.get(timeout=timeout)
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
"""
Gets an item from a queue with a timeout. If the timeout is exceeded then
a TimeoutError is raised.

If an exception is returned from the queue then it is raised.

Args:
q: queue to get from
timeout: timeout in seconds
"""
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()
try:
v = q.get(timeout=timeout)
except queue.Empty as e:
raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e
if isinstance(v, Exception):
raise v
return v
@@ -95,6 +111,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
Every time this is called it must be provided with a unique prefixed
store address, e.g. localhost:1234/my/prefix/1

This function will block until the underlying ProcessGroup is created.
If an error occurs this will throw.

Args:
store_addr: address of the store to use
rank: rank of this process
@@ -187,7 +206,6 @@ def __repr__(self) -> str:


class ProcessGroupWrapper(ProcessGroup):
PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized
"""
This is a wrapper around any ProcessGroup with a reconfiguration method.
"""
@@ -209,9 +227,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

store = create_store_client(store_addr)

# TODO: set global timeout
# pyre-fixme[20]: expects argument options
self._pg = self.PG_CLASS(store, rank, world_size)
self._pg = self._create_pg(store, rank, world_size)

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
raise NotImplementedError("not implemented")

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return self.parent.allreduce(tensors, opts)
@@ -244,9 +263,13 @@ class ProcessGroupGloo(ProcessGroupWrapper):
This is a reconfigurable version of ProcessGroupGloo.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
)
def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
super().__init__()
self._timeout = timeout

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupGloo
return BaseProcessGroupGloo(store, rank, world_size, self._timeout)

def getBackendName(self) -> str:
return "torchft-gloo"
@@ -263,9 +286,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
abort when reconfiguring, we need to ensure this is safe.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
)
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
return BaseProcessGroupNCCL(store, rank, world_size)

def getBackendName(self) -> str:
return "torchft-nccl"
@@ -546,10 +569,9 @@ class ProcessGroupBaby(ProcessGroup):

"""

PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized
WORK_CLASS: Type[_BabyWork] = _BabyWork

def __init__(self, timeout: float = 60.0) -> None:
def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
super().__init__(0, 1)

self._world_size = -1
@@ -562,7 +584,10 @@ def __init__(self, timeout: float = 60.0) -> None:
self._futures: Dict[int, Future[object]] = {}
self._futures_lock = threading.Lock()

self._timeout = timeout
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()

self._timeout: float = timeout

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
if self._p is not None:
@@ -581,7 +606,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

ctx = mp.get_context("spawn")
self._tx = ctx.Queue()
self._rx = ctx.Queue()
self._rx = rx = ctx.Queue()

# futures need thread to fire callbacks
self._future_queue = ctx.Queue()
@@ -602,6 +627,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
)
self._p.start()

# fetch the status of the PG init
# if an exception was returned _get will throw
assert _get(rx, self._timeout) is None

Comment on lines +630 to +633

Collaborator: If _get throws the exception, why do we need this assertion?

Member Author: Just to make sure we consume the value -- it doesn't really matter what the output is as long as it's consistent.

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
"""
This is a class method to avoid pickling the class.
"""
raise NotImplementedError("not implemented")

@classmethod
def _worker(
cls,
@@ -615,8 +651,13 @@ def _worker(
try:
store = create_store_client(store_addr)

# pyre-fixme[20]: expects argument options
pg = cls.PG_CLASS(store, rank, world_size)
try:
pg = cls._create_pg(store, rank, world_size)
except Exception as e:
logger.exception(f"got exception in worker: {e}")
tx.put(e)
return
tx.put(None)

work = {}
next_op_id: int = 0
@@ -737,9 +778,10 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
ProcessGroupBabyNCCL.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
)
@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupGloo
return BaseProcessGroupGloo(store, rank, world_size)

def getBackendName(self) -> str:
return "torchft-baby-gloo"
@@ -761,11 +803,13 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
tensors may leak in the current PyTorch implementation. TODO fix
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
)
WORK_CLASS = _BabyWorkNCCL

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
return BaseProcessGroupNCCL(store, rank, world_size)

def getBackendName(self) -> str:
return "torchft-baby-nccl"

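The baby process group changes add an explicit init handshake: the worker reports either None (success) or the exception it hit over the queue, and the parent blocks on that value with the configured timeout. A self-contained sketch of the pattern, assuming spawn multiprocessing (the helper names here are illustrative, not the library's API):

import multiprocessing as mp
import queue
from datetime import timedelta

def _worker(tx: mp.Queue) -> None:
    try:
        pg = object()  # stand-in for creating the real ProcessGroup
    except Exception as e:
        tx.put(e)      # ship the init failure back to the parent ...
        return         # ... and exit instead of entering the op loop
    tx.put(None)       # success marker; the parent consumes this value

def wait_for_init(rx: mp.Queue, timeout: timedelta) -> None:
    seconds = timeout.total_seconds()
    try:
        v = rx.get(timeout=seconds)
    except queue.Empty as e:
        raise TimeoutError(f"queue.get() timed out after {seconds} seconds") from e
    if isinstance(v, Exception):
        raise v        # surface the worker's exception in the parent process
    assert v is None   # consume the marker; only the handshake matters

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    rx = ctx.Queue()
    p = ctx.Process(target=_worker, args=(rx,), daemon=True)
    p.start()
    wait_for_init(rx, timedelta(seconds=60))

This is also why configure() asserts that the fetched value is None: as noted in the review thread above, the assertion simply consumes the success marker.
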
49 changes: 38 additions & 11 deletions torchft/process_group_test.py
@@ -6,6 +6,7 @@

import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Any, Dict, Tuple
from unittest import TestCase, skipUnless
from unittest.mock import Mock
@@ -122,6 +123,16 @@ def test_gloo(self) -> None:
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
m(torch.rand(2, 3))

def test_gloo_timeout(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"
pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(RuntimeError, "timeout after 10ms"):
pg.configure(store_addr, 0, 2)

Comment on lines +132 to +135

Collaborator: Would an assert that the time taken to raise the timeout is much smaller than the default be good here? It would make sure that future code does not break the wrapper's ability to set the underlying PG creation timeout.

Member Author: I don't think we need to compare to the default -- the message includes the configured time (in this case 10ms).

Collaborator: Ah, I see. All good then.

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.is_available(), "needs CUDA")
def test_nccl(self) -> None:
@@ -155,28 +166,44 @@ def test_baby_gloo(self) -> None:
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"
store_addr: str = f"localhost:{store.port}/prefix"

def run(rank: int) -> Tuple[torch.Tensor, Work]:
a = ProcessGroupBabyGloo()
a.configure(store_addr, rank, 2)

a = ProcessGroupBabyGloo()
b = ProcessGroupBabyGloo()
self.assertEqual(a.size(), 2)

a.configure(store_addr, 0, 2)
b.configure(store_addr, 1, 2)
at = torch.tensor([rank + 1])

self.assertEqual(a.size(), 2)
a_work = a.allreduce([at], ReduceOp.SUM)
return at, a_work

at = torch.tensor([1])
bt = torch.tensor([2])
with ThreadPoolExecutor(max_workers=2) as executor:
a_fut = executor.submit(run, 0)
b_fut = executor.submit(run, 1)

a_work = a.allreduce([at], ReduceOp.SUM)
b_work = b.allreduce([bt], ReduceOp.SUM)
at, a_work = a_fut.result()
bt, b_work = b_fut.result()

a_work.wait()
fut = b_work.get_future()

fut.wait()

torch.testing.assert_close(at, bt)
torch.testing.assert_close(at, torch.tensor([3]))
torch.testing.assert_close(bt, torch.tensor([3]))

def test_baby_gloo_timeout(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
a.configure(store_addr, 0, 2)

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
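Taken together, a caller can now bound PG initialization time and skip manual port bookkeeping. A rough usage sketch, assuming a Manager constructed with a pg argument as suggested by the import in manager.py, with MASTER_ADDR/MASTER_PORT and TORCHFT_LIGHTHOUSE already set by the launcher (the state-dict callbacks are placeholders):

from datetime import timedelta

from torchft.manager import Manager
from torchft.process_group import ProcessGroupBabyGloo

state = {"step": 0}

# configure() now fails fast with TimeoutError instead of hanging if peers never join.
pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=30))

manager = Manager(
    pg=pg,
    load_state_dict=state.update,
    state_dict=lambda: dict(state),
    min_replica_size=1,
    # no port argument: rank 0 uses TORCHFT_MANAGER_PORT if set, otherwise an
    # OS-assigned port, and publishes the bound address via the TCPStore.
)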