-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
process_group: add PG init timeouts + automatically assign manager port #60
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
import os | ||
from concurrent.futures import ThreadPoolExecutor | ||
from datetime import timedelta | ||
from typing import Any, Dict, Tuple | ||
from unittest import TestCase, skipUnless | ||
from unittest.mock import Mock | ||
|
@@ -122,6 +123,16 @@ def test_gloo(self) -> None: | |
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg) | ||
m(torch.rand(2, 3)) | ||
|
||
def test_gloo_timeout(self) -> None: | ||
store = TCPStore( | ||
host_name="localhost", port=0, is_master=True, wait_for_workers=False | ||
) | ||
|
||
store_addr = f"localhost:{store.port}/prefix" | ||
pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01)) | ||
with self.assertRaisesRegex(RuntimeError, "timeout after 10ms"): | ||
pg.configure(store_addr, 0, 2) | ||
|
||
Comment on lines
+132
to
+135
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would an assert that the time taken to raise the timeout is now much smaller than the default be good here? Would make sure that future code does not break the ability for the wrapper to set the underlying PG creation timeout. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to compare to the default -- the message includes the time (in this case 10ms) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah ok I see. All good then |
||
# pyre-fixme[56]: Pyre was not able to infer the type of argument | ||
@skipUnless(torch.cuda.is_available(), "needs CUDA") | ||
def test_nccl(self) -> None: | ||
|
@@ -155,28 +166,44 @@ def test_baby_gloo(self) -> None: | |
host_name="localhost", port=0, is_master=True, wait_for_workers=False | ||
) | ||
|
||
store_addr = f"localhost:{store.port}/prefix" | ||
store_addr: str = f"localhost:{store.port}/prefix" | ||
|
||
def run(rank: int) -> Tuple[torch.Tensor, Work]: | ||
a = ProcessGroupBabyGloo() | ||
a.configure(store_addr, rank, 2) | ||
|
||
a = ProcessGroupBabyGloo() | ||
b = ProcessGroupBabyGloo() | ||
self.assertEqual(a.size(), 2) | ||
|
||
a.configure(store_addr, 0, 2) | ||
b.configure(store_addr, 1, 2) | ||
at = torch.tensor([rank + 1]) | ||
|
||
self.assertEqual(a.size(), 2) | ||
a_work = a.allreduce([at], ReduceOp.SUM) | ||
return at, a_work | ||
|
||
at = torch.tensor([1]) | ||
bt = torch.tensor([2]) | ||
with ThreadPoolExecutor(max_workers=2) as executor: | ||
a_fut = executor.submit(run, 0) | ||
b_fut = executor.submit(run, 1) | ||
|
||
a_work = a.allreduce([at], ReduceOp.SUM) | ||
b_work = b.allreduce([bt], ReduceOp.SUM) | ||
at, a_work = a_fut.result() | ||
bt, b_work = b_fut.result() | ||
|
||
a_work.wait() | ||
fut = b_work.get_future() | ||
|
||
fut.wait() | ||
|
||
torch.testing.assert_close(at, bt) | ||
torch.testing.assert_close(at, torch.tensor([3])) | ||
torch.testing.assert_close(bt, torch.tensor([3])) | ||
|
||
def test_baby_gloo_timeout(self) -> None: | ||
store = TCPStore( | ||
host_name="localhost", port=0, is_master=True, wait_for_workers=False | ||
) | ||
|
||
store_addr = f"localhost:{store.port}/prefix" | ||
|
||
a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01)) | ||
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"): | ||
a.configure(store_addr, 0, 2) | ||
|
||
def test_dummy(self) -> None: | ||
pg = ProcessGroupDummy(0, 1) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
_get
will throw the Exception, why do we need this assertion?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just to make sure we consume the value -- doesn't really matter what the output is as long as it's consistent