process_group: add PG init timeouts + automatically assign manager port #60
torchft/manager.py
Outdated
if port is None:
    port_str = os.environ.get(MANAGER_PORT_ENV)
    if port_str is None:
        port = 0
    else:
        port = int(port_str)
Would it make sense to set the default value to 0? It would save these few lines and looks cleaner imo.
done, thanks for the suggestion!
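A minimal sketch of the simplification discussed above, not the merged code. It assumes `MANAGER_PORT_ENV` names an environment variable; the value `TORCHFT_MANAGER_PORT` and the helpers `resolve_port` and `bind_and_get_port` are hypothetical and exist only for illustration. Defaulting to 0 is safe because binding a socket to port 0 asks the OS to pick a free port.

```python
import os
import socket
from typing import Optional

# Assumed value for illustration; the real constant is defined in torchft/manager.py.
MANAGER_PORT_ENV = "TORCHFT_MANAGER_PORT"


def resolve_port(port: Optional[int] = None) -> int:
    """Return the explicit port, else the env var, else 0 (OS-assigned)."""
    if port is None:
        # The default of 0 replaces the nested if/else from the original diff.
        port = int(os.environ.get(MANAGER_PORT_ENV, 0))
    return port


def bind_and_get_port(port: int) -> int:
    """Bind a TCP socket; with port 0 the OS picks a free port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", port))
        return sock.getsockname()[1]
```

Calling `bind_and_get_port(resolve_port())` with the variable unset returns whichever port the OS assigned, which is the "automatically assign manager port" behavior named in the PR title.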
64eb94a to b1c3692
pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(RuntimeError, "timeout after 10ms"):
    pg.configure(store_addr, 0, 2)
Would an assert that the time taken to raise the timeout is now much smaller than the default be good here? Would make sure that future code does not break the ability for the wrapper to set the underlying PG creation timeout.
I don't think we need to compare to the default -- the message includes the time (in this case 10ms)
Ah ok I see. All good then
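To illustrate both positions in this thread, here is a self-contained sketch that does not use torchft or Gloo. `wait_for_peer` is a hypothetical stand-in for a process group init that never finds its peer. Matching the timeout value in the error message confirms the configured timeout was applied; the elapsed-time bound is the optional extra check the reviewer proposed, kept deliberately loose to avoid flaky tests.

```python
import threading
import time
import unittest
from datetime import timedelta


def wait_for_peer(timeout: timedelta) -> None:
    """Hypothetical stand-in for PG init: waits for a peer that never arrives."""
    never_set = threading.Event()
    if not never_set.wait(timeout.total_seconds()):
        raise TimeoutError(f"timed out after {timeout.total_seconds()} seconds")


class TimeoutTest(unittest.TestCase):
    def test_timeout_is_applied(self) -> None:
        start = time.monotonic()
        # The message embeds the configured timeout, so matching it shows
        # the value passed by the caller reached the waiting code.
        with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
            wait_for_peer(timedelta(seconds=0.01))
        # Optional extra guard: a generous upper bound on the elapsed time.
        self.assertLess(time.monotonic() - start, 5.0)
```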
torchft/process_group_test.py
Outdated
a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
    a.configure(store_addr, 0, 2)
Same suggestion as the gloo one about testing that timeout was changed
since we have the time in the message I think that's sufficient
torchft/process_group_test.py
Outdated
@@ -178,6 +193,17 @@ def test_baby_gloo(self) -> None:
        torch.testing.assert_close(at, bt)
Can we assert an exact value here? That would be more of a check that allreduce actually performs an allreduce, which might not be the intention here.
done
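The difference between comparing ranks against each other and comparing against an exact value can be sketched with a toy allreduce in plain Python. Everything here is hypothetical and unrelated to the torchft test code: `all_reduce_sum`, `broken_all_reduce`, and the two check helpers only illustrate why an exact-value assertion is the stronger check.

```python
from typing import List

Tensor = List[float]


def all_reduce_sum(per_rank: List[Tensor]) -> List[Tensor]:
    """Toy allreduce: every rank ends up with the element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def broken_all_reduce(per_rank: List[Tensor]) -> List[Tensor]:
    """Buggy variant: ranks agree with each other, but on the wrong value."""
    return [[0.0] * len(per_rank[0]) for _ in per_rank]


def ranks_agree(outputs: List[Tensor]) -> bool:
    """Weaker check: all ranks hold the same result."""
    return all(out == outputs[0] for out in outputs)


def matches_expected(outputs: List[Tensor], expected: Tensor) -> bool:
    """Stronger check: all ranks hold the known-correct result."""
    return all(out == expected for out in outputs)
```

With inputs `[[1.0, 2.0], [3.0, 4.0]]`, the buggy variant passes `ranks_agree` but fails `matches_expected` against `[4.0, 6.0]`, while the correct one passes both.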
torchft/process_group.py
Outdated
try:
    # pyre-fixme[20]: expects argument options
    pg = cls.PG_CLASS(store, rank, world_size)
except Exception as e:
    logger.exception(f"got exception in worker: {e}")
    tx.put(e)
    return
tx.put(None)
Since PGWrapper now uses _create_pg instead of PG_CLASS instantiation, should we make the code more consistent by having _create_pg here too?
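One way the suggestion could look, as a sketch under stated assumptions: `FakePG` and `WrapperSketch` are hypothetical names, the signature of `_create_pg` is guessed from the snippet above, and an in-process `queue.Queue` stands in for whatever channel the real worker uses. The point is only that the worker calls the same construction hook as the wrapper, then reports the init status as either the exception or `None`.

```python
import queue
from typing import Optional


class FakePG:
    """Hypothetical stand-in for a real process group."""

    def __init__(self, store_addr: str, rank: int, world_size: int) -> None:
        if world_size < 1:
            raise ValueError("world_size must be positive")
        self.rank = rank
        self.world_size = world_size


class WrapperSketch:
    @classmethod
    def _create_pg(cls, store_addr: str, rank: int, world_size: int) -> FakePG:
        # Single construction hook; a subclass would override this to
        # choose a different backend.
        return FakePG(store_addr, rank, world_size)

    @classmethod
    def _worker(
        cls,
        store_addr: str,
        rank: int,
        world_size: int,
        tx: "queue.Queue[Optional[Exception]]",
    ) -> None:
        # Report init status: the exception on failure, None on success.
        try:
            cls._create_pg(store_addr, rank, world_size)
        except Exception as e:
            tx.put(e)
            return
        tx.put(None)
```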
# fetch the status of the PG init
# if an exception was returned _get will throw
assert _get(rx, self._timeout) is None
If _get will throw the Exception, why do we need this assertion?
just to make sure we consume the value -- doesn't really matter what the output is as long as it's consistent
b1c3692 to d6f256c
Nice! lgtm
This does two things:
Test plan: