Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nvflare/private/fed/client/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,13 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
)
job_id = fl_ctx.get_job_id()

# Use at least the server-required minimum (e.g. for tensor streaming). When the server
# sends MIN_GET_TASK_TIMEOUT we update self.timeout; the caller may still pass a smaller
# config value, so ensure we never use less than the required minimum.
if not timeout:
timeout = self.timeout
else:
timeout = max(timeout, self.timeout)

parent_fqcn = determine_parent_fqcn(self.client_config, fl_ctx)
self.logger.debug(f"pulling task from parent FQCN: {parent_fqcn}")
Expand Down
75 changes: 75 additions & 0 deletions tests/unit_test/app_opt/tensor_stream/timeout_management_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,81 @@ def get_header_side_effect(key, default=None):
for log_message in all_log_messages:
assert "Automatically adjusting" not in log_message

@pytest.mark.parametrize(
    "initial_timeout,server_min_timeout,caller_timeout,expected_request_timeout",
    [
        # Caller value far below the server minimum: the server minimum is used.
        (5.0, 360.0, 5.0, 360.0),
        # Caller value moderately larger but still below the minimum: the server minimum is used.
        (5.0, 360.0, 30.0, 360.0),
        # Caller value above the server minimum: the caller value is used.
        (5.0, 360.0, 400.0, 400.0),
        # Caller value equal to the server minimum: that value is used.
        (5.0, 360.0, 360.0, 360.0),
        # Larger server minimum with the caller still below it: the server minimum is used.
        (5.0, 660.0, 600.0, 660.0),
    ],
)
@patch("nvflare.private.fed.client.communicator.new_cell_message")
@patch("nvflare.private.fed.client.communicator.determine_parent_fqcn")
@patch("nvflare.private.fed.client.communicator.gen_new_peer_ctx")
def test_explicit_timeout_enforces_server_minimum(
    self,
    mock_gen_ctx,
    mock_determine_parent,
    mock_new_cell_message,
    initial_timeout,
    server_min_timeout,
    caller_timeout,
    expected_request_timeout,
):
    """Check that pull_task never sends a request timeout below ``communicator.timeout``.

    The communicator is built with ``initial_timeout`` and its ``timeout`` attribute is
    then overwritten with ``server_min_timeout`` to imitate an earlier
    MIN_GET_TASK_TIMEOUT adjustment. ``pull_task`` is called with an explicit
    ``caller_timeout`` and the ``timeout`` keyword handed to ``cell.send_request`` must
    equal ``expected_request_timeout`` (the larger of the two values in every case).
    """
    from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode

    communicator = Communicator(timeout=initial_timeout, client_config={"client_name": "test_client"})
    communicator.engine = Mock()
    communicator.cell = Mock()

    # Stand-in for a minimum that was raised by an earlier server reply.
    communicator.timeout = server_min_timeout

    mock_determine_parent.return_value = "parent_fqcn"
    mock_new_cell_message.return_value = Mock()

    # Payload of the canned reply; its content is irrelevant to the timeout check.
    reply_payload = Shareable()
    reply_payload.set_header(ServerCommandKey.TASK_NAME, "train")
    reply_payload.set_header(FLContextKey.TASK_ID, "task_123")

    # Headers the canned reply exposes through get_header(); other keys fall back to the default.
    canned_headers = {
        MessageHeaderKey.RETURN_CODE: ReturnCode.OK,
        MessageHeaderKey.PAYLOAD_LEN: 1024,
    }

    def lookup_header(key, default=None):
        return canned_headers.get(key, default)

    reply_message = Mock()
    reply_message.get_header = Mock(side_effect=lookup_header)
    reply_message.payload = reply_payload
    communicator.cell.send_request.return_value = reply_message

    fl_ctx_double = Mock()
    fl_ctx_double.get_job_id.return_value = "job_123"
    fl_ctx_double.get_run_abort_signal.return_value = None
    fl_ctx_double.set_prop = Mock()

    communicator.logger = Mock()

    # The explicit timeout is the behaviour under test.
    communicator.pull_task("project", "token", "ssid", fl_ctx_double, timeout=caller_timeout)

    # Inspect the keyword arguments of the most recent send_request call.
    _, request_kwargs = communicator.cell.send_request.call_args
    actual_timeout = request_kwargs["timeout"]
    assert actual_timeout == expected_request_timeout, (
        f"Expected send_request timeout={expected_request_timeout}, got {actual_timeout}. "
        f"caller_timeout={caller_timeout}, self.timeout={server_min_timeout}"
    )

@patch("nvflare.private.fed.client.communicator.new_cell_message")
@patch("nvflare.private.fed.client.communicator.determine_parent_fqcn")
@patch("nvflare.private.fed.client.communicator.gen_new_peer_ctx")
Expand Down