diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py
index 1387f9ca9c..49565ae927 100644
--- a/nvflare/private/fed/client/communicator.py
+++ b/nvflare/private/fed/client/communicator.py
@@ -368,8 +368,13 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
         )
         job_id = fl_ctx.get_job_id()
 
+        # Enforce the server-required minimum (e.g. for tensor streaming): when the server sends
+        # MIN_GET_TASK_TIMEOUT we update self.timeout, yet a caller may still pass a smaller configured
+        # value. A falsy timeout (None/0) falls back to self.timeout; otherwise the larger of the two wins.
         if not timeout:
             timeout = self.timeout
+        else:
+            timeout = max(timeout, self.timeout)
 
         parent_fqcn = determine_parent_fqcn(self.client_config, fl_ctx)
         self.logger.debug(f"pulling task from parent FQCN: {parent_fqcn}")
diff --git a/tests/unit_test/app_opt/tensor_stream/timeout_management_test.py b/tests/unit_test/app_opt/tensor_stream/timeout_management_test.py
index a27c99fd4a..fc1205cf09 100644
--- a/tests/unit_test/app_opt/tensor_stream/timeout_management_test.py
+++ b/tests/unit_test/app_opt/tensor_stream/timeout_management_test.py
@@ -267,6 +267,81 @@ def get_header_side_effect(key, default=None):
         for log_message in all_log_messages:
             assert "Automatically adjusting" not in log_message
 
+    @pytest.mark.parametrize(
+        "initial_timeout,server_min_timeout,caller_timeout,expected_request_timeout",
+        [
+            (5.0, 360.0, 5.0, 360.0),  # Caller passes small value, server min is large → use server min
+            (5.0, 360.0, 30.0, 360.0),  # Caller passes moderate value, still below server min → use server min
+            (5.0, 360.0, 400.0, 400.0),  # Caller passes value above server min → use caller value
+            (5.0, 360.0, 360.0, 360.0),  # Caller passes exactly server min → use that
+            (5.0, 660.0, 600.0, 660.0),  # Large server min, caller still below → use server min
+        ],
+    )
+    @patch("nvflare.private.fed.client.communicator.new_cell_message")
+    @patch("nvflare.private.fed.client.communicator.determine_parent_fqcn")
+    @patch("nvflare.private.fed.client.communicator.gen_new_peer_ctx")
+    def test_explicit_timeout_enforces_server_minimum(
+        self,
+        mock_gen_ctx,
+        mock_determine_parent,
+        mock_new_cell_message,
+        initial_timeout,
+        server_min_timeout,
+        caller_timeout,
+        expected_request_timeout,
+    ):
+        """Test that an explicit caller-provided timeout is raised to the server minimum when needed.
+
+        The test sets communicator.timeout to server_min_timeout (standing in for a prior
+        MIN_GET_TASK_TIMEOUT bump), calls pull_task with caller_timeout, and checks that the timeout
+        keyword passed to cell.send_request equals max(caller_timeout, server_min_timeout).
+        """
+        from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
+
+        communicator = Communicator(client_config={"client_name": "test_client"}, timeout=initial_timeout)
+        communicator.engine = Mock()
+        communicator.cell = Mock()
+
+        # Overwrite the constructor timeout to simulate a prior MIN_GET_TASK_TIMEOUT bump
+        communicator.timeout = server_min_timeout
+
+        mock_determine_parent.return_value = "parent_fqcn"
+        mock_new_cell_message.return_value = Mock()
+
+        # Build a reply that reports ReturnCode.OK; only the request timeout is asserted below
+        response_shareable = Shareable()
+        response_shareable.set_header(ServerCommandKey.TASK_NAME, "train")
+        response_shareable.set_header(FLContextKey.TASK_ID, "task_123")
+
+        mock_task = Mock()
+
+        def get_header_side_effect(key, default=None):
+            return {
+                MessageHeaderKey.RETURN_CODE: ReturnCode.OK,
+                MessageHeaderKey.PAYLOAD_LEN: 1024,
+            }.get(key, default)
+
+        mock_task.get_header = Mock(side_effect=get_header_side_effect)
+        mock_task.payload = response_shareable
+        communicator.cell.send_request.return_value = mock_task
+
+        mock_fl_context = Mock()
+        mock_fl_context.get_job_id.return_value = "job_123"
+        mock_fl_context.get_run_abort_signal.return_value = None
+        mock_fl_context.set_prop = Mock()
+
+        communicator.logger = Mock()
+
+        # Call pull_task WITH an explicit timeout
+        communicator.pull_task("project", "token", "ssid", mock_fl_context, timeout=caller_timeout)
+
+        # call_args[1] is the kwargs dict of the last call: read the timeout keyword given to cell.send_request
+        actual_timeout = communicator.cell.send_request.call_args[1]["timeout"]
+        assert actual_timeout == expected_request_timeout, (
+            f"Expected send_request timeout={expected_request_timeout}, got {actual_timeout}. "
+            f"caller_timeout={caller_timeout}, self.timeout={server_min_timeout}"
+        )
+
     @patch("nvflare.private.fed.client.communicator.new_cell_message")
     @patch("nvflare.private.fed.client.communicator.determine_parent_fqcn")
     @patch("nvflare.private.fed.client.communicator.gen_new_peer_ctx")