diff --git a/docs/programming_guide/timeouts.rst b/docs/programming_guide/timeouts.rst index a9018e241e..23c99c42a7 100644 --- a/docs/programming_guide/timeouts.rst +++ b/docs/programming_guide/timeouts.rst @@ -2519,6 +2519,45 @@ application.conf Settings # Shutdown end_run_readiness_timeout = 10.0 + # Server startup/dead-job safety flags + strict_start_job_reply_check = false + sync_client_jobs_require_previous_report = true + + +.. _server_startup_dead_job_safety_flags: + +Server Startup and Dead-Job Safety Flags +---------------------------------------- + +These ``application.conf`` flags are server-side safety controls used during job startup +and client heartbeat synchronization: + +.. list-table:: + :header-rows: 1 + :widths: 36 12 52 + + * - Parameter + - Default + - Purpose + * - strict_start_job_reply_check + - false + - Enables strict START_JOB reply validation (detects missing/timeout replies and non-OK return codes). + * - sync_client_jobs_require_previous_report + - true + - Requires a prior positive heartbeat report before treating "missing job on client" as a dead-job signal. + +Recommended usage: + +- ``strict_start_job_reply_check`` defaults to ``false`` for backward compatibility. + Enable it (``true``) for large-scale or hierarchical deployments where startup timeouts + are expected and you want the server to proceed with the subset of clients that responded, + rather than failing the entire job. With ``false``, a timed-out reply is treated as a + silent success, which can mask startup problems. +- Keep ``sync_client_jobs_require_previous_report=true`` (default) to prevent false + dead-job reports during startup races and transient heartbeat delays. +- Set ``sync_client_jobs_require_previous_report=false`` only to restore legacy behavior + where the first missing-job heartbeat immediately triggers dead-job detection. 
+ Admin Client Session (Python API) --------------------------------- diff --git a/docs/user_guide/timeout_troubleshooting.rst b/docs/user_guide/timeout_troubleshooting.rst index bdd7fb486d..c08bc37376 100644 --- a/docs/user_guide/timeout_troubleshooting.rst +++ b/docs/user_guide/timeout_troubleshooting.rst @@ -214,6 +214,17 @@ Via Configuration Files get_task_timeout = 300.0 submit_task_result_timeout = 300.0 + # Server startup/dead-job safety flags + strict_start_job_reply_check = false + sync_client_jobs_require_previous_report = true + +Server-side safety flags guidance (see :ref:`server_startup_dead_job_safety_flags` for full details): + +- ``strict_start_job_reply_check`` (default ``false``): keep default for backward-compatible startup behavior; + set to ``true`` to enforce stricter START_JOB reply checks. +- ``sync_client_jobs_require_previous_report`` (default ``true``): keep enabled to avoid false dead-job reports + caused by transient startup or sync races. + **comm_config.json** (system-level, in startup kit): .. 
code-block:: json diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 018f03179a..1debdc841a 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -535,6 +535,12 @@ class ConfigVarName: # server: wait this long since job schedule time before starting to check dead/disconnected clients DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" + # server: strictly validate start-job replies (detect timed-out clients and non-OK return codes) so the + # run can proceed with the subset of clients that responded instead of silently treating timeouts as success + STRICT_START_JOB_REPLY_CHECK = "strict_start_job_reply_check" + + # server: require prior positive job observation before reporting "missing job on client" as dead-job + SYNC_CLIENT_JOBS_REQUIRE_PREVIOUS_REPORT = "sync_client_jobs_require_previous_report" + # customized nvflare decomposers module name DECOMPOSER_MODULE = "nvflare_decomposers" diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py index efe35222b7..99e7fd93bb 100644 --- a/nvflare/private/fed/client/client_run_manager.py +++ b/nvflare/private/fed/client/client_run_manager.py @@ -320,7 +320,15 @@ def get_job_clients(self, fl_ctx: FLContext): """ job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + if not isinstance(job_meta, dict): + raise RuntimeError(f"invalid job meta type: expected dict but got {type(job_meta)}") + job_clients = job_meta.get(JobMetaKey.JOB_CLIENTS) + if job_clients is None: + raise RuntimeError(f"missing {JobMetaKey.JOB_CLIENTS} in job meta") + if not isinstance(job_clients, list): + raise RuntimeError(f"invalid {JobMetaKey.JOB_CLIENTS} type: expected list but got {type(job_clients)}") + self.all_clients = [from_dict(d) for d in job_clients] for c in self.all_clients: self.name_to_clients[c.name] = c diff --git a/nvflare/private/fed/server/admin.py b/nvflare/private/fed/server/admin.py index 917a34b3af..7052bed0fc 100644 --- a/nvflare/private/fed/server/admin.py +++ b/nvflare/private/fed/server/admin.py @@ -32,7 +32,7 @@ from
nvflare.fuel.hci.server.hci import AdminServer from nvflare.fuel.hci.server.login import LoginModule, SessionManager from nvflare.fuel.sec.audit import Auditor, AuditService -from nvflare.private.admin_defs import Message +from nvflare.private.admin_defs import Message, MsgHeader, ReturnCode from nvflare.private.defs import ERROR_MSG_PREFIX, RequestHeader from nvflare.private.fed.server.message_send import ClientReply, send_requests @@ -77,7 +77,27 @@ def __init__(self, client, req: Message): self.req = req -def check_client_replies(replies: List[ClientReply], client_sites: List[str], command: str): +def check_client_replies( + replies: List[ClientReply], client_sites: List[str], command: str, strict: bool = False +) -> List[str]: + """Check client replies for errors. + + Args: + replies: list of client replies + client_sites: list of expected client names + command: command description for error messages + strict: if True, detect timed-out clients (reply=None) and return them as a list + rather than raising. Explicit errors (non-OK return code or error body) + always raise regardless of this flag. + + Returns: + List of client names whose reply was None (timed out). Only populated when + strict=True; always empty when strict=False. + + Raises: + RuntimeError: if no replies were received, reply count mismatches, structurally + missing replies (strict mode), or any client returned an explicit error. 
+ """ display_sites = ", ".join(client_sites) if not replies: raise RuntimeError(f"Failed to {command} to the clients {display_sites}: no replies.") @@ -85,12 +105,42 @@ def check_client_replies(replies: List[ClientReply], client_sites: List[str], co raise RuntimeError(f"Failed to {command} to the clients {display_sites}: not enough replies.") error_msg = "" - for r, client_name in zip(replies, client_sites): - if r.reply and ERROR_MSG_PREFIX in r.reply.body: - error_msg += f"\t{client_name}: {r.reply.body}\n" - if error_msg != "": + timed_out_clients = [] + replies_by_client = {r.client_name: r for r in replies} + + if strict: + missing_clients = [c for c in client_sites if c not in replies_by_client] + if missing_clients: + raise RuntimeError( + f"Failed to {command} to the clients {display_sites}: missing replies from {missing_clients}." + ) + + for client_name in client_sites: + r = replies_by_client[client_name] + if not r.reply: + # Timeout: record and continue — caller decides whether to exclude or abort. 
+ timed_out_clients.append(client_name) + continue + + return_code = r.reply.get_header(MsgHeader.RETURN_CODE, ReturnCode.OK) + if return_code != ReturnCode.OK: + detail = r.reply.body if r.reply.body else f"return code {return_code}" + error_msg += f"\t{client_name}: {detail}\n" + continue + + if isinstance(r.reply.body, str) and r.reply.body.startswith(ERROR_MSG_PREFIX): + error_msg += f"\t{client_name}: {r.reply.body}\n" + else: + for client_name in client_sites: + r = replies_by_client.get(client_name) + if r and r.reply and isinstance(r.reply.body, str) and r.reply.body.startswith(ERROR_MSG_PREFIX): + error_msg += f"\t{client_name}: {r.reply.body}\n" + + if error_msg: raise RuntimeError(f"Failed to {command} to the following clients: \n{error_msg}") + return timed_out_clients + class FedAdminServer(AdminServer): def __init__( diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index acb9604b43..a6d237114f 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -343,6 +343,13 @@ def __init__( self.name_to_reg = {} self.cred_keeper = CredKeeper() + # Tracks per-job which client tokens have been positively observed running the job. + # Keyed by job_id -> set of client tokens. Used by _sync_client_jobs() to require + # a prior positive heartbeat before classifying a client's missing job as "dead". + # Entries are cleaned up as soon as the job is no longer in run_processes. + self._job_reported_clients: Dict[str, set] = {} + self._job_reported_clients_lock = threading.Lock() + # these are used when the server sends a message to itself. self.my_own_auth_client_name = "server" self.my_own_token = "server" @@ -795,21 +802,63 @@ def client_heartbeat(self, request: Message) -> Message: def _sync_client_jobs(self, request, client_token): # jobs that are running on client but not on server need to be aborted! 
client_jobs = request.get_header(CellMessageHeaderKeys.JOB_IDS) - server_jobs = self.engine.run_processes.keys() - jobs_need_abort = list(set(client_jobs).difference(server_jobs)) - - # also check jobs that are running on server but not on the client - jobs_on_server_but_not_on_client = list(set(server_jobs).difference(client_jobs)) - if jobs_on_server_but_not_on_client: - # notify all the participating clients these jobs are not running on server anymore - for job_id in jobs_on_server_but_not_on_client: - job_info = self.engine.run_processes[job_id] + if not isinstance(client_jobs, (list, tuple, set)): + client_jobs = [] + + client_jobs = set(client_jobs) + server_jobs = set(self.engine.run_processes.keys()) + jobs_need_abort = list(client_jobs.difference(server_jobs)) + + require_previous_report = ConfigService.get_bool_var( + name=ConfigVarName.SYNC_CLIENT_JOBS_REQUIRE_PREVIOUS_REPORT, + conf=SystemConfigs.APPLICATION_CONF, + default=True, + ) + + with self._job_reported_clients_lock: + # Remove stale tracking entries for jobs that are no longer running. + for stale_job_id in list(self._job_reported_clients.keys()): + if stale_job_id not in server_jobs: + del self._job_reported_clients[stale_job_id] + + # Record jobs that this client has reported at least once. + # If require_previous_report is enabled, we only treat "missing job on client" + # as dead-job after first positive observation. + for job_id in server_jobs.intersection(client_jobs): + job_info = self.engine.run_processes.get(job_id) + if not job_info: + continue + participating_clients = job_info.get(RunProcessKey.PARTICIPANTS, None) - if participating_clients: + if not participating_clients or client_token not in participating_clients: + continue + + self._job_reported_clients.setdefault(job_id, set()).add(client_token) + + # Also check jobs that are running on server but not on the client. 
+ jobs_on_server_but_not_on_client = list(server_jobs.difference(client_jobs)) + dead_job_notifications = [] + if jobs_on_server_but_not_on_client: + for job_id in jobs_on_server_but_not_on_client: + job_info = self.engine.run_processes.get(job_id) + if not job_info: + continue + + participating_clients = job_info.get(RunProcessKey.PARTICIPANTS, None) + if not participating_clients: + continue + # this is a dict: token => nvflare.apis.client.Client client = participating_clients.get(client_token, None) - if client: - self._notify_dead_job(client, job_id, "missing job on client") + if not client: + continue + + reported_clients = self._job_reported_clients.get(job_id, set()) + if (not require_previous_report) or (client_token in reported_clients): + dead_job_notifications.append((client, job_id)) + + for client, job_id in dead_job_notifications: + self._notify_dead_job(client, job_id, "missing job on client") return jobs_need_abort diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index 46c3a03885..01cb4cb140 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -21,12 +21,21 @@ from nvflare.apis.client import Client from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SiteType, SystemComponents +from nvflare.apis.fl_constant import ( + AdminCommandNames, + ConfigVarName, + FLContextKey, + RunProcessKey, + SiteType, + SystemComponents, + SystemConfigs, +) from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import ALL_SITES, Job, JobMetaKey, RunStatus from nvflare.apis.job_scheduler_spec import DispatchInfo from nvflare.apis.workspace import Workspace from nvflare.fuel.utils.argument_utils import parse_vars +from nvflare.fuel.utils.config_service import ConfigService from nvflare.lighter.utils import verify_folder_signature 
from nvflare.private.admin_defs import Message, MsgHeader, ReturnCode from nvflare.private.defs import RequestHeader, TrainingTopic @@ -214,7 +223,12 @@ def _deploy_job(self, job: Job, sites: dict, fl_ctx: FLContext) -> Tuple[str, li else: deploy_detail.append(f"{client_name}: OK") else: - deploy_detail.append(f"{client_name}: unknown") + # No reply means the client timed out during deployment. + # Count this as a failure so the min_sites / required_sites check + # can decide whether to abort, rather than silently treating a + # timed-out client as successfully deployed. + failed_clients.append(client_name) + deploy_detail.append(f"{client_name}: no reply (deployment timeout)") # see whether any of the failed clients are required if failed_clients: @@ -248,15 +262,63 @@ def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo # job_clients is a dict of: token => Client assert isinstance(job_clients, dict) participating_clients = [c.to_dict() for c in job_clients.values()] + # start_client_job serializes job.meta into request headers; make sure + # JOB_CLIENTS is available before client startup. 
job.meta[JobMetaKey.JOB_CLIENTS] = participating_clients err = engine.start_app_on_server(fl_ctx, job=job, job_clients=job_clients) if err: raise RuntimeError(f"Could not start the server App for job: {job_id}.") replies = engine.start_client_job(job, client_sites, fl_ctx) - client_sites_names = list(client_sites.keys()) - check_client_replies(replies=replies, client_sites=client_sites_names, command=f"start job ({job_id})") - display_sites = ",".join(client_sites_names) + all_client_sites = list(client_sites.keys()) + active_client_sites = list(all_client_sites) + strict_start_reply_check = ConfigService.get_bool_var( + name=ConfigVarName.STRICT_START_JOB_REPLY_CHECK, + conf=SystemConfigs.APPLICATION_CONF, + default=False, + ) + timed_out = check_client_replies( + replies=replies, + client_sites=all_client_sites, + command=f"start job ({job_id})", + strict=strict_start_reply_check, + ) + if timed_out: + active_count = len(all_client_sites) - len(timed_out) + + # A required site timing out is fatal regardless of min_sites, same as deploy phase. + if job.required_sites: + for c in timed_out: + if c in job.required_sites: + raise RuntimeError(f"start job ({job_id}): required client {c} timed out") + + if job.min_sites and active_count < job.min_sites: + raise RuntimeError( + f"start job ({job_id}): {len(timed_out)} client(s) timed out and remaining " + f"{active_count} < min_sites {job.min_sites}: {timed_out}" + ) + self.log_warning( + fl_ctx, + f"start job ({job_id}): {len(timed_out)} client(s) timed out at start-job: {timed_out}; " + f"{active_count} of {len(all_client_sites)} clients started successfully.", + ) + active_client_sites = [c for c in all_client_sites if c not in timed_out] + + if not strict_start_reply_check: + # In non-strict mode, check_client_replies() does not return timed-out clients. + # Build active clients directly from actual replies so JOB_CLIENTS stays accurate. 
+ replies_by_client = {r.client_name: r for r in replies} + active_client_sites = [] + for client_name in all_client_sites: + client_reply = replies_by_client.get(client_name) + if client_reply and client_reply.reply: + active_client_sites.append(client_name) + + # Set metadata once, after any timeout exclusion, so it always reflects active participants. + active_sites = set(active_client_sites) + participating_clients = [c.to_dict() for c in job_clients.values() if c.name in active_sites] + job.meta[JobMetaKey.JOB_CLIENTS] = participating_clients + display_sites = ",".join(active_client_sites) self.log_info(fl_ctx, f"Started run: {job_id} for clients: {display_sites}") self.fire_event(EventType.JOB_STARTED, fl_ctx) diff --git a/tests/unit_test/private/fed/client/__init__.py b/tests/unit_test/private/fed/client/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/unit_test/private/fed/client/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/private/fed/client/client_run_manager_test.py b/tests/unit_test/private/fed/client/client_run_manager_test.py new file mode 100644 index 0000000000..fe54c19fc3 --- /dev/null +++ b/tests/unit_test/private/fed/client/client_run_manager_test.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import pytest + +from nvflare.apis.job_def import JobMetaKey +from nvflare.private.fed.client.client_run_manager import ClientRunManager + + +class _DummyRunManager: + def __init__(self): + self.all_clients = None + self.name_to_clients = {} + + +def test_get_job_clients_raises_if_job_clients_missing(): + run_manager = _DummyRunManager() + fl_ctx = MagicMock() + fl_ctx.get_prop.return_value = {} + + with pytest.raises(RuntimeError, match=f"missing {JobMetaKey.JOB_CLIENTS}"): + ClientRunManager.get_job_clients(run_manager, fl_ctx) + + +def test_get_job_clients_raises_if_job_clients_not_list(): + run_manager = _DummyRunManager() + fl_ctx = MagicMock() + fl_ctx.get_prop.return_value = {JobMetaKey.JOB_CLIENTS: "bad"} + + with pytest.raises(RuntimeError, match=f"invalid {JobMetaKey.JOB_CLIENTS} type"): + ClientRunManager.get_job_clients(run_manager, fl_ctx) + + +def test_get_job_clients_accepts_empty_list(): + run_manager = _DummyRunManager() + fl_ctx = MagicMock() + fl_ctx.get_prop.return_value = {JobMetaKey.JOB_CLIENTS: []} + + ClientRunManager.get_job_clients(run_manager, fl_ctx) + + assert run_manager.all_clients == [] + fl_ctx.set_prop.assert_called_once() diff --git a/tests/unit_test/private/fed/server/admin_test.py b/tests/unit_test/private/fed/server/admin_test.py new file mode 100644 index 0000000000..a4a5956000 --- /dev/null +++ 
b/tests/unit_test/private/fed/server/admin_test.py @@ -0,0 +1,154 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nvflare.private.admin_defs import Message, MsgHeader, ReturnCode +from nvflare.private.fed.server.admin import check_client_replies +from nvflare.private.fed.server.message_send import ClientReply + + +def _make_client_reply(client_name: str, return_code=ReturnCode.OK, body="ok"): + req = Message(topic="req", body="") + reply = Message(topic="reply", body=body) + reply.set_header(MsgHeader.RETURN_CODE, return_code) + return ClientReply(client_token=f"token-{client_name}", client_name=client_name, req=req, reply=reply) + + +def _make_timeout_reply(client_name: str): + """Simulate a client that did not respond (reply=None).""" + return ClientReply( + client_token=f"token-{client_name}", client_name=client_name, req=Message(topic="req", body=""), reply=None + ) + + +# --------------------------------------------------------------------------- +# Legacy (non-strict) mode +# --------------------------------------------------------------------------- + + +def test_check_client_replies_legacy_allows_timeout_reply(): + """Non-strict mode silently accepts a timeout reply.""" + replies = [_make_timeout_reply("C1")] + + result = check_client_replies(replies=replies, client_sites=["C1"], command="start", strict=False) + + assert result == [] + + +def 
test_check_client_replies_legacy_uses_dict_lookup_not_zip(): + """Non-strict mode uses name-keyed lookup; reply order does not matter.""" + # Replies in reverse order of client_sites — old zip() would give wrong names. + replies = [_make_client_reply("C2"), _make_client_reply("C1")] + + result = check_client_replies(replies=replies, client_sites=["C1", "C2"], command="start", strict=False) + + assert result == [] + + +# --------------------------------------------------------------------------- +# Strict mode — timeouts +# --------------------------------------------------------------------------- + + +def test_check_client_replies_strict_returns_timed_out_clients(): + """In strict mode a timeout reply is returned as a timed-out client, NOT raised.""" + replies = [_make_timeout_reply("C1")] + + timed_out = check_client_replies(replies=replies, client_sites=["C1"], command="start", strict=True) + + assert timed_out == ["C1"] + + +def test_check_client_replies_strict_returns_only_timed_out_clients(): + """Mixed: one OK, one timeout — only the timed-out client is returned.""" + replies = [_make_client_reply("C1"), _make_timeout_reply("C2")] + + timed_out = check_client_replies(replies=replies, client_sites=["C1", "C2"], command="start", strict=True) + + assert timed_out == ["C2"] + + +def test_check_client_replies_strict_no_timeouts_returns_empty(): + """All clients responded successfully — returns empty list.""" + replies = [_make_client_reply("C1"), _make_client_reply("C2")] + + result = check_client_replies(replies=replies, client_sites=["C1", "C2"], command="start", strict=True) + + assert result == [] + + +# --------------------------------------------------------------------------- +# Strict mode — explicit errors always raise +# --------------------------------------------------------------------------- + + +def test_check_client_replies_strict_raises_for_non_ok_return_code(): + replies = [_make_client_reply("C1", return_code=ReturnCode.ERROR, body="start 
failed")] + + with pytest.raises(RuntimeError, match="start failed"): + check_client_replies(replies=replies, client_sites=["C1"], command="start", strict=True) + + +def test_check_client_replies_strict_raises_for_missing_client_reply(): + """Structurally missing entry (client not in replies dict at all) always raises.""" + replies = [_make_client_reply("C1"), _make_client_reply("CX")] + + with pytest.raises(RuntimeError, match=r"missing replies from \["): + check_client_replies(replies=replies, client_sites=["C1", "C2"], command="start", strict=True) + + +def test_check_client_replies_strict_raises_but_not_for_timeout_when_mixed(): + """If one client has explicit error and another times out, explicit error raises.""" + replies = [_make_client_reply("C1", return_code=ReturnCode.ERROR, body="err"), _make_timeout_reply("C2")] + + with pytest.raises(RuntimeError, match="err"): + check_client_replies(replies=replies, client_sites=["C1", "C2"], command="start", strict=True) + + +# --------------------------------------------------------------------------- +# Strict mode — reply ordering +# --------------------------------------------------------------------------- + + +def test_check_client_replies_strict_allows_reordered_success_replies(): + replies = [_make_client_reply("C2"), _make_client_reply("C1")] + + check_client_replies(replies=replies, client_sites=["C1", "C2"], command="start", strict=True) + + +# --------------------------------------------------------------------------- +# Non-strict mode — ERROR_MSG_PREFIX detection +# --------------------------------------------------------------------------- + + +def test_check_client_replies_legacy_raises_when_body_starts_with_error_prefix(): + """Non-strict mode raises when reply body starts with ERROR_MSG_PREFIX.""" + from nvflare.private.defs import ERROR_MSG_PREFIX + + replies = [_make_client_reply("C1", body=f"{ERROR_MSG_PREFIX}: something went wrong")] + + with pytest.raises(RuntimeError, match="something went 
wrong"): + check_client_replies(replies=replies, client_sites=["C1"], command="start", strict=False) + + +def test_check_client_replies_legacy_does_not_raise_when_prefix_not_at_start(): + """Non-strict mode uses startswith — a body containing the prefix mid-string is NOT an error.""" + from nvflare.private.defs import ERROR_MSG_PREFIX + + replies = [_make_client_reply("C1", body=f"info: see {ERROR_MSG_PREFIX} for details")] + + result = check_client_replies(replies=replies, client_sites=["C1"], command="start", strict=False) + + assert result == [] diff --git a/tests/unit_test/private/fed/server/fed_server_test.py b/tests/unit_test/private/fed/server/fed_server_test.py index 235cfac0c9..c33a27937f 100644 --- a/tests/unit_test/private/fed/server/fed_server_test.py +++ b/tests/unit_test/private/fed/server/fed_server_test.py @@ -16,6 +16,7 @@ import pytest +from nvflare.apis.fl_constant import RunProcessKey from nvflare.apis.shareable import Shareable from nvflare.private.defs import CellMessageHeaderKeys, new_cell_message from nvflare.private.fed.server.fed_server import FederatedServer @@ -25,8 +26,7 @@ class TestFederatedServer: @pytest.mark.parametrize("server_state, expected", [(HotState(), ["extra_job"]), (ColdState(), [])]) def test_heart_beat_abort_jobs(self, server_state, expected): - with patch("nvflare.private.fed.server.fed_server.ServerEngine") as mock_engine: - + with patch("nvflare.private.fed.server.fed_server.ServerEngine"): server = FederatedServer( project_name="project_name", min_num_clients=1, @@ -53,3 +53,170 @@ def test_heart_beat_abort_jobs(self, server_state, expected): result = server.client_heartbeat(request) assert result.get_header(CellMessageHeaderKeys.ABORT_JOBS, []) == expected + + def test_sync_client_jobs_legacy_reports_missing_immediately(self): + with ( + patch("nvflare.private.fed.server.fed_server.ServerEngine"), + patch("nvflare.private.fed.server.fed_server.ConfigService.get_bool_var", return_value=False), + ): + server = 
# NOTE(review): this region is a whitespace-mangled unified-diff hunk of the
# fed_server sync-client-jobs tests.  It is reconstructed below with
# conventional formatting; "+" hunk markers and patch metadata are rendered as
# comments.  Statement nesting relative to the "with" blocks could not be
# recovered exactly from the mangled text -- TODO confirm against the repo.

        # (continuation) tail of a test method whose "def" line lies before this
        # window: it builds a FederatedServer with ServerEngine/ConfigService
        # patched, then verifies that a heartbeat reporting no jobs immediately
        # triggers a dead-job notification (legacy/disabled-flag behavior).
        FederatedServer(
            project_name="project_name",
            min_num_clients=1,
            max_num_clients=10,
            cmd_modules=None,
            heart_beat_timeout=600,
            args=MagicMock(),
            secure_train=False,
            snapshot_persistor=MagicMock(),
            overseer_agent=MagicMock(),
        )

        token = "token-1"
        client = MagicMock()
        client.name = "C1"
        # one running job whose only participant is this client's token
        server.engine.run_processes = {"job1": {RunProcessKey.PARTICIPANTS: {token: client}}}
        server.engine.notify_dead_job = MagicMock()

        # heartbeat carrying an empty JOB_IDS list -> "missing job on client"
        no_job_request = new_cell_message({CellMessageHeaderKeys.JOB_IDS: []}, Shareable())
        server._sync_client_jobs(no_job_request, token)

        server.engine.notify_dead_job.assert_called_once_with("job1", "C1", "missing job on client")

    def test_sync_client_jobs_reports_missing_only_after_prior_seen_when_enabled(self):
        """With the require-previous-report flag enabled, notify_dead_job must
        fire only after the client previously reported the job as present."""
        with (
            patch("nvflare.private.fed.server.fed_server.ServerEngine"),
            patch("nvflare.private.fed.server.fed_server.ConfigService.get_bool_var", return_value=True),
        ):
            server = FederatedServer(
                project_name="project_name",
                min_num_clients=1,
                max_num_clients=10,
                cmd_modules=None,
                heart_beat_timeout=600,
                args=MagicMock(),
                secure_train=False,
                snapshot_persistor=MagicMock(),
                overseer_agent=MagicMock(),
            )

            token = "token-1"
            client = MagicMock()
            client.name = "C1"
            server.engine.run_processes = {"job1": {RunProcessKey.PARTICIPANTS: {token: client}}}
            server.engine.notify_dead_job = MagicMock()

            # 1) first missing-job heartbeat: no prior positive report -> no alarm
            no_job_request = new_cell_message({CellMessageHeaderKeys.JOB_IDS: []}, Shareable())
            server._sync_client_jobs(no_job_request, token)
            server.engine.notify_dead_job.assert_not_called()

            # 2) positive report: client says job1 is running -> still no alarm
            job_present_request = new_cell_message({CellMessageHeaderKeys.JOB_IDS: ["job1"]}, Shareable())
            server._sync_client_jobs(job_present_request, token)
            server.engine.notify_dead_job.assert_not_called()

            # 3) job missing after having been seen -> dead-job fires exactly once
            server._sync_client_jobs(no_job_request, token)
            server.engine.notify_dead_job.assert_called_once_with("job1", "C1", "missing job on client")

    def test_sync_client_jobs_default_requires_prior_report(self):
        """Default behaviour (require_previous_report=True) must not fire on the
        first missing-job heartbeat — no config override needed."""
        # NOTE(review): the docstring says "no config override needed", yet
        # ConfigService.get_bool_var is still patched to True below -- confirm
        # whether this test should instead rely on the real default.
        with (
            patch("nvflare.private.fed.server.fed_server.ServerEngine"),
            patch("nvflare.private.fed.server.fed_server.ConfigService.get_bool_var", return_value=True),
        ):
            server = FederatedServer(
                project_name="project_name",
                min_num_clients=1,
                max_num_clients=10,
                cmd_modules=None,
                heart_beat_timeout=600,
                args=MagicMock(),
                secure_train=False,
                snapshot_persistor=MagicMock(),
                overseer_agent=MagicMock(),
            )

            token = "token-1"
            client = MagicMock()
            client.name = "C1"
            server.engine.run_processes = {"job1": {RunProcessKey.PARTICIPANTS: {token: client}}}
            server.engine.notify_dead_job = MagicMock()

            # First heartbeat: client says it has no job1 — should NOT fire yet
            no_job_request = new_cell_message({CellMessageHeaderKeys.JOB_IDS: []}, Shareable())
            server._sync_client_jobs(no_job_request, token)
            server.engine.notify_dead_job.assert_not_called()

    def test_sync_client_jobs_tracking_in_server_attr_not_job_info(self):
        """Positive observations must be recorded in server._job_reported_clients,
        NOT injected into the job_info dict."""
        with (
            patch("nvflare.private.fed.server.fed_server.ServerEngine"),
            patch("nvflare.private.fed.server.fed_server.ConfigService.get_bool_var", return_value=True),
        ):
            server = FederatedServer(
                project_name="project_name",
                min_num_clients=1,
                max_num_clients=10,
                cmd_modules=None,
                heart_beat_timeout=600,
                args=MagicMock(),
                secure_train=False,
                snapshot_persistor=MagicMock(),
                overseer_agent=MagicMock(),
            )

            token = "token-1"
            client = MagicMock()
            client.name = "C1"
            # keep a direct handle on job_info so we can assert it stays clean
            job_info = {RunProcessKey.PARTICIPANTS: {token: client}}
            server.engine.run_processes = {"job1": job_info}
            server.engine.notify_dead_job = MagicMock()

            # Positive observation heartbeat
            job_present_request = new_cell_message({CellMessageHeaderKeys.JOB_IDS: ["job1"]}, Shareable())
            server._sync_client_jobs(job_present_request, token)

            # Token recorded in server attribute
            assert "job1" in server._job_reported_clients
            assert token in server._job_reported_clients["job1"]

            # NOT injected into job_info dict
            assert "_reported_clients" not in job_info

    def test_sync_client_jobs_cleans_up_stale_job_tracking(self):
        """When a job is removed from run_processes the corresponding tracking
        entry in _job_reported_clients must be purged on the next sync call."""
        with (
            patch("nvflare.private.fed.server.fed_server.ServerEngine"),
            patch("nvflare.private.fed.server.fed_server.ConfigService.get_bool_var", return_value=True),
        ):
            server = FederatedServer(
                project_name="project_name",
                min_num_clients=1,
                max_num_clients=10,
                cmd_modules=None,
                heart_beat_timeout=600,
                args=MagicMock(),
                secure_train=False,
                snapshot_persistor=MagicMock(),
                overseer_agent=MagicMock(),
            )

            token = "token-1"
            client = MagicMock()
            client.name = "C1"
            server.engine.run_processes = {"job1": {RunProcessKey.PARTICIPANTS: {token: client}}}
            server.engine.notify_dead_job = MagicMock()

            # Positive observation — entry created in _job_reported_clients
            job_present = new_cell_message({CellMessageHeaderKeys.JOB_IDS: ["job1"]}, Shareable())
            server._sync_client_jobs(job_present, token)
            assert "job1" in server._job_reported_clients

            # Job finishes — removed from run_processes
            server.engine.run_processes = {}

            # Next sync call for any client should purge the stale entry
            other_request = new_cell_message({CellMessageHeaderKeys.JOB_IDS: []}, Shareable())
            server._sync_client_jobs(other_request, token)
            assert "job1" not in server._job_reported_clients

# ---- patch metadata (preserved as comments) ---------------------------------
# diff --git a/tests/unit_test/private/fed/server/job_runner_deploy_test.py
#          b/tests/unit_test/private/fed/server/job_runner_deploy_test.py
# new file mode 100644  index 0000000000..77841d1743  --- /dev/null  +++
# ---- patch metadata (preserved as comments) ---------------------------------
# +++ b/tests/unit_test/private/fed/server/job_runner_deploy_test.py
# (new file, hunk @@ -0,0 +1,319 @@)

# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for JobRunner._deploy_job() — focusing on timeout/failure classification
and the min_sites / required_sites abort logic.

The test infrastructure stubs out all engine/fl_ctx interaction so that only
_deploy_job()'s own logic is exercised."""

from contextlib import ExitStack
from unittest.mock import MagicMock, patch

import pytest

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.job_def import Job
from nvflare.private.admin_defs import Message, MsgHeader, ReturnCode
from nvflare.private.fed.server.job_runner import JobRunner

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _ok_reply():
    """Simulate a successful deployment ACK."""
    msg = Message(topic="reply", body="ok")
    msg.set_header(MsgHeader.RETURN_CODE, ReturnCode.OK)
    return msg


def _error_reply(body="deploy failed"):
    """Simulate an explicit error ACK (non-OK return code)."""
    msg = Message(topic="reply", body=body)
    msg.set_header(MsgHeader.RETURN_CODE, ReturnCode.ERROR)
    return msg


def _build_fl_ctx(token_to_reply: dict, job_id="job-1", min_sites=None, required_sites=None):
    """Build a minimal fl_ctx / engine mock for _deploy_job().

    Args:
        token_to_reply: mapping of client_token -> Message|None
            None simulates a deployment timeout for that client.
        job_id: job id assigned to the mocked Job.
        min_sites: job.min_sites value
        required_sites: job.required_sites list (or None)

    Returns:
        (runner, fl_ctx, engine, job, sites)
        Note: the docstring previously claimed a 4-tuple
        (runner, fl_ctx, job, client_sites); the function has always
        returned this 5-tuple — fixed here.
    """
    runner = JobRunner(workspace_root="/tmp")
    runner.log_info = MagicMock()
    runner.log_warning = MagicMock()
    runner.fire_event = MagicMock()

    # Build client objects matching token_to_reply keys; site names are
    # generated as site-1, site-2, ... in insertion order of the mapping.
    client_objects = []
    sites = {}
    for i, token in enumerate(token_to_reply):
        client_name = f"site-{i + 1}"
        c = MagicMock(spec=Client)
        c.token = token
        c.name = client_name
        client_objects.append(c)
        sites[client_name] = MagicMock()

    # Engine
    engine = MagicMock()
    engine.validate_targets.return_value = (client_objects, [])
    engine.get_clients.return_value = client_objects

    # AdminServer: replies keyed by token drive the per-client deploy outcome
    admin_server = MagicMock()
    admin_server.timeout = 10.0
    admin_server.send_requests_and_get_reply_dict.return_value = token_to_reply
    engine.server.admin_server = admin_server

    # fl_ctx
    fl_ctx = MagicMock()
    fl_ctx.get_engine.return_value = engine
    deploy_detail = []
    fl_ctx.get_prop.return_value = deploy_detail
    fl_ctx.set_prop.side_effect = lambda key, val: deploy_detail.__class__  # no-op for other props

    # Job
    job = MagicMock(spec=Job)
    job.job_id = job_id
    job.meta = {}
    job.min_sites = min_sites
    job.required_sites = required_sites or []

    # Simulate a single app deployment to all client sites
    deployment = {"app": list(sites.keys())}
    job.get_deployment.return_value = deployment
    job.get_application.return_value = b"app_data"

    return runner, fl_ctx, engine, job, sites


# ---------------------------------------------------------------------------
# Deployment timeout classified as failure
# ---------------------------------------------------------------------------

_DEPLOY_PATCHES = [
    "nvflare.private.fed.server.job_runner.Workspace",
    "nvflare.private.fed.server.job_runner.AppDeployer",
    "nvflare.private.fed.server.job_runner.verify_folder_signature",
]


def _run_deploy(runner, job, sites, fl_ctx, *, extra_patches=None):
    """Run _deploy_job with the standard set of external dependencies patched out.

    Args:
        extra_patches: optional iterable of additional patch targets.

    Fix: extra_patches used to be appended to a local list but never actually
    entered (only the three standard targets were patched), making the
    parameter dead code.  All targets are now activated via ExitStack.
    """
    with ExitStack() as stack:
        stack.enter_context(patch.object(runner, "_make_deploy_message", return_value=MagicMock()))
        stack.enter_context(patch(_DEPLOY_PATCHES[0]))
        stack.enter_context(patch(_DEPLOY_PATCHES[1]))
        # verify_folder_signature must report success or every deploy aborts
        stack.enter_context(patch(_DEPLOY_PATCHES[2], return_value=True))
        for target in extra_patches or ():
            stack.enter_context(patch(target))
        return runner._deploy_job(job, sites, fl_ctx)


class TestDeployJobTimeoutClassification:
    def test_timeout_reply_counted_as_failed_client(self):
        """A client that returns None (timeout) must appear in failed_clients."""
        token_to_reply = {"token-1": _ok_reply(), "token-2": None}
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=1)

        _, failed = _run_deploy(runner, job, sites, fl_ctx)

        assert "site-2" in failed

    def test_ok_reply_not_in_failed_clients(self):
        """A client that returns OK must not appear in failed_clients."""
        token_to_reply = {"token-1": _ok_reply()}
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=1)

        _, failed = _run_deploy(runner, job, sites, fl_ctx)

        assert failed == []

    def test_explicit_error_reply_counted_as_failed_client(self):
        """An explicit error reply (non-OK return code) must appear in failed_clients."""
        token_to_reply = {"token-1": _ok_reply(), "token-2": _error_reply("disk full")}
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=1)

        _, failed = _run_deploy(runner, job, sites, fl_ctx)

        assert "site-2" in failed

    def test_timeout_recorded_in_deploy_detail(self):
        """Timed-out clients must produce a 'deployment timeout' entry, not 'unknown'."""
        token_to_reply = {"token-1": None}
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=0)

        # Capture the deploy_detail list set on fl_ctx
        captured = {}

        def capture_set_prop(key, val, **kw):
            captured[key] = val

        fl_ctx.set_prop.side_effect = capture_set_prop

        _run_deploy(runner, job, sites, fl_ctx)

        detail = captured.get(FLContextKey.JOB_DEPLOY_DETAIL, [])
        assert any(
            "deployment timeout" in entry for entry in detail
        ), f"Expected 'deployment timeout' in deploy_detail but got: {detail}"
        assert not any("unknown" in entry for entry in detail), f"Old 'unknown' label should not appear; got: {detail}"

    def test_mixed_outcomes_all_correctly_classified(self):
        """OK + error + timeout in one batch: only error and timeout end up in failed_clients."""
        token_to_reply = {
            "token-1": _ok_reply(),
            "token-2": _error_reply("out of memory"),
            "token-3": None,  # timeout
        }
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=1)

        _, failed = _run_deploy(runner, job, sites, fl_ctx)

        assert "site-1" not in failed  # OK → not failed
        assert "site-2" in failed  # explicit error → failed
        assert "site-3" in failed  # timeout → failed


# ---------------------------------------------------------------------------
# min_sites logic with timeouts treated as failures
# ---------------------------------------------------------------------------


class TestDeployJobMinSites:
    def test_timeout_does_not_abort_when_within_min_sites(self):
        """One timeout but two OK; min_sites=2 → 2 ok ≥ 2 → proceed."""
        token_to_reply = {
            "token-1": _ok_reply(),
            "token-2": None,
            "token-3": _ok_reply(),
        }
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=2)

        job_id, failed = _run_deploy(runner, job, sites, fl_ctx)

        assert "site-2" in failed
        assert job_id == "job-1"

    def test_timeout_aborts_when_below_min_sites(self):
        """All clients time out; min_sites=2 → 0 ok < 2 → RuntimeError."""
        token_to_reply = {"token-1": None, "token-2": None}
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=2)

        with pytest.raises(RuntimeError, match="deploy failure"):
            _run_deploy(runner, job, sites, fl_ctx)

    def test_timeout_aborts_below_min_sites_mixed(self):
        """One OK but two fail (1 error + 1 timeout); min_sites=2 → 1 ok < 2 → abort."""
        token_to_reply = {
            "token-1": _ok_reply(),
            "token-2": None,
            "token-3": _error_reply("refused"),
        }
        runner, fl_ctx, engine, job, sites = _build_fl_ctx(token_to_reply, min_sites=2)

        with pytest.raises(RuntimeError, match="deploy failure"):
            _run_deploy(runner, job, sites, fl_ctx)


# ---------------------------------------------------------------------------
# Full startup sequence integration-style test
# ---------------------------------------------------------------------------


class TestDeployAndStartIntegration:
    """Verify the full deploy → start sequence correctly handles timeouts at both phases."""

    @patch("nvflare.private.fed.server.job_runner.check_client_replies")
    @patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
    def test_deploy_timeout_excluded_from_start_run(self, mock_get_bool, mock_check_replies):
        """Clients that time out at deployment are excluded from _start_run's client_sites
        so the start-job phase never sees them."""
        mock_check_replies.return_value = []  # all start-job replies OK

        runner = JobRunner(workspace_root="/tmp")
        runner.log_info = MagicMock()
        runner.log_warning = MagicMock()
        runner.fire_event = MagicMock()

        # Two clients: site-1 OK, site-2 deployment timeout
        client1 = MagicMock(spec=Client)
        client1.token = "token-1"
        client1.name = "site-1"
        client1.to_dict.return_value = {"name": "site-1"}

        client2 = MagicMock(spec=Client)
        client2.token = "token-2"
        client2.name = "site-2"

        engine = MagicMock()
        engine.validate_targets.return_value = ([client1, client2], [])
        engine.get_job_clients.return_value = {"token-1": client1}
        engine.start_app_on_server.return_value = ""
        engine.start_client_job.return_value = [MagicMock()]

        admin_server = MagicMock()
        admin_server.timeout = 10.0
        admin_server.send_requests_and_get_reply_dict.return_value = {
            "token-1": _ok_reply(),
            "token-2": None,  # deployment timeout
        }
        engine.server.admin_server = admin_server

        fl_ctx = MagicMock()
        fl_ctx.get_engine.return_value = engine
        deploy_detail = []
        fl_ctx.get_prop.return_value = deploy_detail

        job = MagicMock(spec=Job)
        job.job_id = "job-e2e"
        job.meta = {}
        job.min_sites = 1
        job.required_sites = []
        job.get_deployment.return_value = {"app": ["site-1", "site-2"]}
        job.get_application.return_value = b"app_data"

        client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}

        with (
            patch.object(runner, "_make_deploy_message", return_value=MagicMock()),
            patch("nvflare.private.fed.server.job_runner.Workspace"),
            patch("nvflare.private.fed.server.job_runner.AppDeployer"),
            patch("nvflare.private.fed.server.job_runner.verify_folder_signature", return_value=True),
        ):
            job_id, failed = runner._deploy_job(job, client_sites, fl_ctx)

        # site-2 must be in failed (deployment timeout)
        assert "site-2" in failed
        # site-1 must not be in failed
        assert "site-1" not in failed

        # In the real run() loop, deployable_clients = client_sites - failed_clients.
        # Verify _start_run actually uses only deployable clients.
        deployable = {k: v for k, v in client_sites.items() if k not in failed}
        assert "site-1" in deployable
        assert "site-2" not in deployable

        runner._start_run(job_id=job_id, job=job, client_sites=deployable, fl_ctx=fl_ctx)

        engine.start_client_job.assert_called_once_with(job, deployable, fl_ctx)
        assert mock_check_replies.call_args.kwargs["client_sites"] == ["site-1"]


# ---- patch metadata (preserved as comments) ---------------------------------
# diff --git a/tests/unit_test/private/fed/server/job_runner_test.py
#          b/tests/unit_test/private/fed/server/job_runner_test.py
# new file mode 100644  index 0000000000..479c832f17  --- /dev/null
# +++ b/tests/unit_test/private/fed/server/job_runner_test.py
# (new file, hunk @@ -0,0 +1,303 @@)

# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for JobRunner._start_run(): strict-flag wiring, timeout exclusion,
min_sites / required_sites enforcement, and JOB_CLIENTS meta maintenance."""

from unittest.mock import MagicMock, patch

import pytest

from nvflare.apis.job_def import JobMetaKey
from nvflare.private.admin_defs import Message, MsgHeader, ReturnCode
from nvflare.private.fed.server.job_runner import JobRunner
from nvflare.private.fed.server.message_send import ClientReply


def _make_runner_inputs(num_clients=1):
    """Build a JobRunner plus mocked fl_ctx/engine/job/client_sites for _start_run tests.

    Fix: *num_clients* was previously accepted but ignored (the helper always
    built exactly one client).  It now controls how many mock clients/sites are
    generated; the default of 1 reproduces the old fixtures exactly.

    Returns:
        (runner, fl_ctx, engine, job, client_sites)
    """
    runner = JobRunner(workspace_root="/tmp")
    runner.log_info = MagicMock()
    runner.log_warning = MagicMock()
    runner.fire_event = MagicMock()

    fl_ctx = MagicMock()
    engine = MagicMock()
    fl_ctx.get_engine.return_value = engine

    # token-i -> mock client whose to_dict() yields {"name": "site-i"}
    job_clients = {}
    client_sites = {}
    for i in range(1, num_clients + 1):
        client_obj = MagicMock()
        client_obj.to_dict.return_value = {"name": f"site-{i}"}
        job_clients[f"token-{i}"] = client_obj
        client_sites[f"site-{i}"] = MagicMock()
    engine.get_job_clients.return_value = job_clients
    engine.start_app_on_server.return_value = ""
    engine.start_client_job.return_value = [MagicMock()]

    job = MagicMock()
    job.job_id = "job-1"
    job.meta = {}
    job.min_sites = 0  # no minimum by default
    job.required_sites = None  # no required sites by default

    return runner, fl_ctx, engine, job, client_sites


# ---------------------------------------------------------------------------
# strict flag wiring
# ---------------------------------------------------------------------------


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=False)
def test_start_run_passes_strict_false_when_flag_disabled(mock_get_bool, mock_check_replies):
    """Flag disabled -> check_client_replies must receive strict=False."""
    mock_check_replies.return_value = []  # no timeouts
    runner, fl_ctx, _engine, job, client_sites = _make_runner_inputs()

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    mock_get_bool.assert_called_once()
    mock_check_replies.assert_called_once()
    assert mock_check_replies.call_args.kwargs["strict"] is False


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_passes_strict_true_when_flag_enabled(mock_get_bool, mock_check_replies):
    """Flag enabled -> check_client_replies must receive strict=True."""
    mock_check_replies.return_value = []  # no timeouts
    runner, fl_ctx, _engine, job, client_sites = _make_runner_inputs()

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    mock_get_bool.assert_called_once()
    mock_check_replies.assert_called_once()
    assert mock_check_replies.call_args.kwargs["strict"] is True


# ---------------------------------------------------------------------------
# timeout exclusion in _start_run
# ---------------------------------------------------------------------------


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_proceeds_when_timed_out_clients_within_min_sites(mock_get_bool, mock_check_replies):
    """When some clients time out but active count >= min_sites, job proceeds with a warning."""
    mock_check_replies.return_value = ["site-2"]  # site-2 timed out
    runner, fl_ctx, _engine, job, client_sites = _make_runner_inputs()
    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    job.min_sites = 1  # require at least 1; site-1 is still active

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    runner.log_warning.assert_called_once()
    warning_msg = runner.log_warning.call_args[0][1]
    assert "site-2" in warning_msg
    assert "timed out" in warning_msg


@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=False)
def test_start_run_non_strict_excludes_timed_out_clients_from_meta(mock_get_bool):
    """Even when strict checking is disabled, JOB_CLIENTS should include only active clients."""
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}

    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}

    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}

    # real ClientReply objects: site-1 answers OK, site-2's reply is None (timeout)
    ok_reply = Message(topic="reply", body="ok")
    ok_reply.set_header(MsgHeader.RETURN_CODE, ReturnCode.OK)
    req1 = Message(topic="req", body="")
    req2 = Message(topic="req", body="")
    engine.start_client_job.return_value = [
        ClientReply(client_token="token-site-1", client_name="site-1", req=req1, reply=ok_reply),
        ClientReply(client_token="token-site-2", client_name="site-2", req=req2, reply=None),
    ]

    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    assert job.meta[JobMetaKey.JOB_CLIENTS] == [{"name": "site-1"}]


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_raises_when_timed_out_clients_breach_min_sites(mock_get_bool, mock_check_replies):
    """When timeouts cause active count to fall below min_sites, _start_run raises."""
    mock_check_replies.return_value = ["site-1", "site-2"]  # both timed out
    runner, fl_ctx, _engine, job, client_sites = _make_runner_inputs()
    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    job.min_sites = 2  # need at least 2; 0 active after timeouts

    with pytest.raises(RuntimeError, match="min_sites"):
        runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_updates_job_clients_meta_after_timeout_exclusion(mock_get_bool, mock_check_replies):
    """Timed-out clients must be dropped from the JOB_CLIENTS meta entry."""
    mock_check_replies.return_value = ["site-2"]
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}

    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}

    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}
    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    job.min_sites = 1

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    assert job.meta[JobMetaKey.JOB_CLIENTS] == [{"name": "site-1"}]


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_keeps_job_clients_meta_when_no_timeouts(mock_get_bool, mock_check_replies):
    """With no timeouts, JOB_CLIENTS retains every client."""
    mock_check_replies.return_value = []
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}

    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}

    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}
    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    assert job.meta[JobMetaKey.JOB_CLIENTS] == [{"name": "site-1"}, {"name": "site-2"}]


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_sets_job_clients_meta_before_start_client_job(mock_get_bool, mock_check_replies):
    """JOB_CLIENTS must already be populated when start_client_job is invoked."""
    mock_check_replies.return_value = []
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}

    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}

    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}

    # Snapshot the meta value at the moment start_client_job is called
    seen_job_clients_meta = {}

    def _start_client_job_side_effect(passed_job, passed_client_sites, passed_fl_ctx):
        seen_job_clients_meta["value"] = passed_job.meta.get(JobMetaKey.JOB_CLIENTS)
        return [MagicMock()]

    engine.start_client_job.side_effect = _start_client_job_side_effect

    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    assert seen_job_clients_meta["value"] == [{"name": "site-1"}, {"name": "site-2"}]


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_raises_when_required_site_times_out(mock_get_bool, mock_check_replies):
    """A timed-out required site must abort the job even if active_count >= min_sites."""
    mock_check_replies.return_value = ["site-2"]  # site-2 timed out
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}
    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}
    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}

    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    job.min_sites = 1  # still satisfied after site-2 drops out
    job.required_sites = ["site-2"]  # but site-2 is required

    with pytest.raises(RuntimeError, match="required client site-2 timed out"):
        runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)


@patch("nvflare.private.fed.server.job_runner.check_client_replies")
@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_proceeds_when_non_required_site_times_out(mock_get_bool, mock_check_replies):
    """A timed-out non-required site proceeds normally when min_sites is still satisfied."""
    mock_check_replies.return_value = ["site-2"]  # site-2 timed out but is not required
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}
    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}
    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}

    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    job.min_sites = 1
    job.required_sites = ["site-1"]  # site-1 is required, site-2 is not

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    assert job.meta[JobMetaKey.JOB_CLIENTS] == [{"name": "site-1"}]
    runner.log_warning.assert_called_once()


@patch("nvflare.private.fed.server.job_runner.ConfigService.get_bool_var", return_value=True)
def test_start_run_integration_real_reply_check_updates_meta(mock_get_bool):
    """Integration-style check: _start_run + real check_client_replies timeout path."""
    runner, fl_ctx, engine, job, _client_sites = _make_runner_inputs()

    site1 = MagicMock()
    site1.name = "site-1"
    site1.to_dict.return_value = {"name": "site-1"}

    site2 = MagicMock()
    site2.name = "site-2"
    site2.to_dict.return_value = {"name": "site-2"}

    engine.get_job_clients.return_value = {"token-1": site1, "token-2": site2}

    ok_reply = Message(topic="reply", body="ok")
    ok_reply.set_header(MsgHeader.RETURN_CODE, ReturnCode.OK)
    req1 = Message(topic="req", body="")
    req2 = Message(topic="req", body="")
    engine.start_client_job.return_value = [
        ClientReply(client_token="token-site-1", client_name="site-1", req=req1, reply=ok_reply),
        ClientReply(client_token="token-site-2", client_name="site-2", req=req2, reply=None),
    ]

    client_sites = {"site-1": MagicMock(), "site-2": MagicMock()}
    job.min_sites = 1

    runner._start_run(job_id=job.job_id, job=job, client_sites=client_sites, fl_ctx=fl_ctx)

    assert job.meta[JobMetaKey.JOB_CLIENTS] == [{"name": "site-1"}]
    runner.log_warning.assert_called_once()
    warning_msg = runner.log_warning.call_args[0][1]
    assert "site-2" in warning_msg
    assert "timed out" in warning_msg