Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nvflare/app_common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error

__all__ = ["should_ignore_result_error", "get_error_handling_message"]
109 changes: 109 additions & 0 deletions nvflare/app_common/utils/error_handling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Set


def should_ignore_result_error(
    ignore_result_error: Optional[bool],
    client_name: str,
    failed_clients: Set[str],
    num_targets: int,
    min_responses: int,
) -> bool:
    """Decide whether a client's errored result can be tolerated.

    The three-state policy encoded in ``ignore_result_error``:

    * ``True`` (resilient): every error is tolerated, never panic.
    * ``False`` (strict): no error is tolerated, always panic.
    * ``None`` (dynamic): tolerate the error only while enough healthy
      clients remain to still reach ``min_responses``.

    Calling this repeatedly for the same client is safe: ``set.add`` is
    idempotent, so duplicate reports of one client do not skew the count
    of remaining healthy clients.

    Args:
        ignore_result_error: Error-handling mode (``None``/``False``/``True`` as above).
        client_name: Name of the client whose result carried the error.
        failed_clients: Running set of failed client names; mutated
            (client added) only in dynamic mode.
        num_targets: Total number of clients targeted by the current task.
        min_responses: Minimum number of successful responses required.

    Returns:
        ``True`` when the error should be ignored (no panic),
        ``False`` when the caller should panic.
    """
    # Resilient mode: errors are always tolerated.
    if ignore_result_error is True:
        return True
    # Strict mode: any error aborts the workflow.
    if ignore_result_error is False:
        return False
    # Dynamic mode (None): record this failure, then check whether the
    # remaining healthy clients can still satisfy min_responses.
    failed_clients.add(client_name)  # idempotent for repeated reports
    still_healthy = num_targets - len(failed_clients)
    return still_healthy >= min_responses


def get_error_handling_message(
    ignore_result_error: Optional[bool],
    client_name: str,
    error_code: Any,
    current_round: Optional[int],
    controller_name: str,
    failed_clients: Set[str],
    num_targets: int,
    min_responses: int,
) -> str:
    """Build the log message matching the configured error-handling mode.

    Args:
        ignore_result_error: Error-handling mode (``None``, ``False``, or ``True``).
        client_name: Name of the client whose result carried the error.
        error_code: Return code from the client result (ReturnCode constant or ``None``).
        current_round: Current training round; may be ``None`` if not present in the result.
        controller_name: Name of the controller class, used in panic messages.
        failed_clients: Set of client names that have failed so far.
        num_targets: Total number of clients targeted by the current task.
        min_responses: Minimum number of successful responses required.

    Returns:
        A message string suitable for logging (warning or panic, depending on mode).
    """
    # Resilient mode: a plain "ignored" notice.
    if ignore_result_error is True:
        return f"Ignore the result from {client_name} at round {current_round}. " f"Result error code: {error_code}"

    # Strict mode: the controller is going to exit.
    if ignore_result_error is False:
        return (
            f"Result from {client_name} is bad, error code: {error_code}. "
            f"{controller_name} exiting at round {current_round}."
        )

    # Dynamic mode (None): message depends on whether min_responses is still reachable.
    still_healthy = num_targets - len(failed_clients)
    if still_healthy >= min_responses:
        return (
            f"Ignore the result from {client_name} at round {current_round}. "
            f"Result error code: {error_code}. "
            f"Remaining good clients ({still_healthy}) >= min_responses ({min_responses})."
        )
    return (
        f"Result from {client_name} is bad, error code: {error_code}. "
        f"Cannot reach min_responses: remaining good clients ({still_healthy}) < min_responses ({min_responses}). "
        f"{controller_name} exiting at round {current_round}."
    )
108 changes: 88 additions & 20 deletions nvflare/app_common/workflows/base_model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
from abc import ABC, abstractmethod
from typing import Callable, List, Union
from typing import Callable, List, Optional, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey
Expand All @@ -29,6 +28,7 @@
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error
from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_int, check_str
Expand All @@ -39,16 +39,18 @@ class BaseModelController(Controller, FLComponentWrapper, ABC):
def __init__(
self,
persistor_id: str = AppConstants.DEFAULT_PERSISTOR_ID,
ignore_result_error: bool = False,
ignore_result_error: Optional[bool] = None,
allow_empty_global_weights: bool = False,
task_check_period: float = 0.5,
):
"""FLModel based controller.

Args:
persistor_id (str, optional): ID of the persistor component. Defaults to AppConstants.DEFAULT_PERSISTOR_ID ("persistor").
ignore_result_error (bool, optional): whether this controller can proceed if client result has errors.
Defaults to False.
ignore_result_error (bool or None, optional): How to handle client result errors.
- None: Dynamic mode (default) - ignore errors if min_responses still reachable, panic otherwise.
- False: Strict mode - panic on any client error.
- True: Resilient mode - always ignore client errors.
allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have
empty global weights at first round, such that clients start training from scratch without any global info.
Defaults to False.
Expand All @@ -73,6 +75,12 @@ def __init__(
# model related
self._results = []

# Task context for dynamic ignore_result_error mode (when ignore_result_error=None).
# These are reset per send_model() call to track error tolerance for the current task.
self._current_min_responses = 0 # Minimum successful responses needed for this task
self._current_num_targets = 0 # Total number of clients targeted for this task
self._current_failed_clients = set() # Set of client names that returned errors in this task

def start_controller(self, fl_ctx: FLContext) -> None:
self.fl_ctx = fl_ctx
self.info("Initializing BaseModelController workflow.")
Expand Down Expand Up @@ -139,6 +147,12 @@ def broadcast_model(
if not blocking and not isinstance(callback, Callable):
raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback)))

# Store task context for dynamic ignore_result_error mode
num_targets = len(targets) if targets else len(self.engine.get_clients())
self._current_min_responses = min_responses if min_responses > 0 else num_targets
Comment on lines +150 to +152
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In broadcast_model(), num_targets = len(targets) if targets else len(self.engine.get_clients()) treats an empty targets=[] the same as targets=None (all clients). That makes the dynamic ignore_result_error=None tolerance math incorrect, and it also conflicts with later logic that passes targets through to broadcast_and_wait. Use an explicit targets is not None check (and consider validating that targets is non-empty if empty is not meaningful).

Copilot uses AI. Check for mistakes.
self._current_num_targets = num_targets
self._current_failed_clients = set()

self.set_fl_context(data)

task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback)
Expand Down Expand Up @@ -231,9 +245,24 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
result_model.meta["client_name"] = client_name

self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT)
self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
accepted = self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT)

# If result was rejected (error ignored or panic), skip further processing
if not accepted:
client_task.result = None
return

# Now try to convert result to FLModel
try:
result_model = FLModelUtils.from_shareable(result)
result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA]
result_model.meta["client_name"] = client_name
Comment on lines +256 to +260
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_process_result() converts the raw Shareable to FLModel before calling _accept_train_result(), and then converts it again inside the new try/except block. This duplicates work and (more importantly) the first conversion is unguarded and can raise for errored/invalid results that should have been rejected by _accept_train_result() first. Convert only once, after _accept_train_result() returns True, and keep the conversion inside the try/except.

Copilot uses AI. Check for mistakes.
except Exception as e:
self.warning(f"Failed to convert result from {client_name} to FLModel: {e}")
client_task.result = None
return

callback = client_task.task.get_prop(AppConstants.TASK_PROP_CALLBACK)
if callback:
try:
Expand All @@ -245,38 +274,77 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:

# Cleanup task result
client_task.result = None

gc.collect()
# Note: Memory cleanup (gc.collect + malloc_trim) is handled by subclasses
# via _maybe_cleanup_memory() based on memory_gc_rounds setting

def process_result_of_unknown_task(
self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
) -> None:
if task_name == AppConstants.TASK_TRAIN:
self._accept_train_result(client_name=client.name, result=result, fl_ctx=fl_ctx)
self.info(f"Result of unknown task {task_name} sent to aggregator.")
accepted = self._accept_train_result(
client_name=client.name, result=result, fl_ctx=fl_ctx, is_unknown_task=True
)
if accepted:
self.info(f"Result of unknown task {task_name} sent to aggregator.")
else:
self.error("Ignoring result from unknown task.")

def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLContext):
def _accept_train_result(
self, client_name: str, result: Shareable, fl_ctx: FLContext, is_unknown_task: bool = False
) -> bool:
"""Accept or reject a training result based on error handling policy.

Args:
client_name: Name of the client that sent the result.
result: The Shareable result from the client.
fl_ctx: The FLContext.
is_unknown_task: Whether this result is from an unknown/late task.

Returns:
True if the result was accepted, False if it was rejected (error ignored or panic triggered).
"""
self.fl_ctx = fl_ctx
rc = result.get_return_code()

current_round = result.get_header(AppConstants.CURRENT_ROUND, None)

# For unknown/late tasks, always ignore errors (no valid tolerance context)
# For normal tasks, use the configured ignore_result_error setting
ignore_result_error = True if is_unknown_task else self._ignore_result_error

# Use empty set for unknown tasks since we don't have valid tracking context
failed_clients = set() if is_unknown_task else self._current_failed_clients
num_targets = 0 if is_unknown_task else self._current_num_targets
min_responses = 0 if is_unknown_task else self._current_min_responses

# Raise panic if bad peer context or execution exception.
if rc and rc != ReturnCode.OK:
if self._ignore_result_error:
self.warning(
f"Ignore the train result from {client_name} at round {current_round}. Train result error code: {rc}",
)
should_ignore = should_ignore_result_error(
ignore_result_error=ignore_result_error,
client_name=client_name,
failed_clients=failed_clients,
num_targets=num_targets,
min_responses=min_responses,
)
msg = get_error_handling_message(
ignore_result_error=ignore_result_error,
client_name=client_name,
error_code=rc,
current_round=current_round,
controller_name=self.__class__.__name__,
failed_clients=failed_clients,
num_targets=num_targets,
min_responses=min_responses,
)
if should_ignore:
self.warning(msg)
return False # Result rejected - error ignored
else:
self.panic(
f"Result from {client_name} is bad, error code: {rc}. "
f"{self.__class__.__name__} exiting at round {current_round}."
)
return
self.panic(msg)
return False # Result rejected - panic triggered

self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False)
return True # Result accepted

@abstractmethod
def run(self):
Expand Down
6 changes: 4 additions & 2 deletions nvflare/app_common/workflows/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ class Scaffold(BaseFedAvg):
num_clients (int, optional): The number of clients. Defaults to 3.
num_rounds (int, optional): The total number of training rounds. Defaults to 5.
persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
ignore_result_error (bool, optional): whether this controller can proceed if client result has errors.
Defaults to False.
ignore_result_error (bool or None, optional): How to handle client result errors.
- None: Dynamic mode (default) - ignore errors if min_responses still reachable, panic otherwise.
- False: Strict mode - panic on any client error.
- True: Resilient mode - always ignore client errors.
allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have
empty global weights at first round, such that clients start training from scratch without any global info.
Defaults to False.
Expand Down
Loading
Loading