-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
manager: add per request timeouts #59
Conversation
2dcd290
to
731fe8c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! lgtm. Maybe add python tests for this as well?
54982e2
to
9a0cc4b
Compare
Updated with some unit tests and an integration test + added timeout field to |
torchft/manager_test.py
Outdated
@patch("torchft.manager.ManagerClient", autospec=True) | ||
def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: | ||
manager = self._create_manager(use_async_quorum=False) | ||
client_mock().should_commit = mock_should_commit | ||
|
||
client_mock().quorum.return_value = ( | ||
123, # quorum_id | ||
1, # replica_rank | ||
2, # replica_world | ||
"manager address", | ||
f"localhost:{self.store.port}", | ||
1, # max_step | ||
1, # max_rank | ||
2, # max_world_size | ||
False, # heal | ||
) | ||
|
||
manager.start_quorum(timeout=timedelta(seconds=12)) | ||
self.assertTrue(manager.should_commit(timeout=timedelta(seconds=23))) | ||
|
||
self.assertEqual(client_mock().quorum.call_count, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Im not sure if this actually tests that the timeout is passed correctly. Seems to just test that the _clients methods are called at all. Maybe check that _client was called with the expected timeout?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a check on the kwargs
torchft/manager_integ_test.py
Outdated
def test_quorum_timeout(self) -> None: | ||
with ExitStack() as stack: | ||
lighthouse = Lighthouse( | ||
bind="[::]:0", | ||
min_replicas=2, | ||
) | ||
stack.callback(lighthouse.shutdown) | ||
|
||
store = dist.TCPStore( | ||
host_name="localhost", | ||
port=0, | ||
is_master=True, | ||
wait_for_workers=False, | ||
) | ||
|
||
pg = ProcessGroupGloo() | ||
manager = Manager( | ||
pg=pg, | ||
min_replica_size=2, | ||
load_state_dict=lambda x: None, | ||
state_dict=lambda: None, | ||
store_addr="localhost", | ||
store_port=store.port, | ||
rank=0, | ||
world_size=2, | ||
lighthouse_addr=lighthouse.address(), | ||
port=19530, | ||
use_async_quorum=False, | ||
) | ||
stack.callback(manager.shutdown) | ||
|
||
with self.assertRaisesRegex( | ||
TimeoutError, | ||
"status: Cancelled, message.*Timeout expired", | ||
): | ||
manager.start_quorum(timeout=timedelta(seconds=0.01)) | ||
|
||
with self.assertRaisesRegex( | ||
TimeoutError, | ||
"status: Cancelled, message.*Timeout expired", | ||
): | ||
manager.should_commit(timeout=timedelta(seconds=0.01)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesnt really test that methods now have method specific timeouts though. They could still timeout on the default timeout and pass. Maybe set a different timeout in Manager init from the one used by method and assert the method one was used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added an elapsed time check -- the default is 60s so easy to verify that we're shorter
9a0cc4b
to
03c22d9
Compare
This adds per request timeouts to all the ManagerClient methods. It also forwards the timeout from the manager to the lighthouse for quorum operations.
Test plan:
updated unit tests