-
Notifications
You must be signed in to change notification settings - Fork 242
Cherry-pick of #4084 and #4132 #4286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Any, Optional, Set | ||
|
|
||
|
|
||
def should_ignore_result_error(
    ignore_result_error: Optional[bool],
    client_name: str,
    failed_clients: Set[str],
    num_targets: int,
    min_responses: int,
) -> bool:
    """Decide whether a client's result error is tolerated or must trigger a panic.

    The policy has three modes, selected by ``ignore_result_error``:
        - True (Resilient): the error is always ignored.
        - False (Strict): the error is never ignored.
        - Anything else, normally None (Dynamic): the error is ignored only
          while enough non-failed targets remain to satisfy ``min_responses``.

    In dynamic mode the client is recorded in ``failed_clients`` before the
    remaining count is computed. Because ``failed_clients`` is a set, calling
    this function again for the same client does not change that count.

    Args:
        ignore_result_error: The error handling mode (True, False, or None).
        client_name: Name of the client whose result carried an error.
        failed_clients: Set of client names that have already failed. It is
            mutated (the client is added) in dynamic mode only.
        num_targets: Total number of target clients for the current task.
        min_responses: Minimum number of responses required.

    Returns:
        True if the error should be ignored (no panic needed),
        False if a panic should be triggered.
    """
    # Resilient (True) and Strict (False) modes answer with the flag itself
    # and leave failed_clients untouched.
    if ignore_result_error is True or ignore_result_error is False:
        return ignore_result_error

    # Dynamic mode: record this failure, then check whether the clients that
    # have not failed can still reach min_responses.
    failed_clients.add(client_name)
    still_good = num_targets - len(failed_clients)
    return still_good >= min_responses
|
|
||
|
|
||
def get_error_handling_message(
    ignore_result_error: Optional[bool],
    client_name: str,
    error_code: Any,
    current_round: Optional[int],
    controller_name: str,
    failed_clients: Set[str],
    num_targets: int,
    min_responses: int,
) -> str:
    """Build the log message that matches the active error handling mode.

    This function only formats text. It does not modify ``failed_clients``;
    in dynamic mode the remaining-client count is derived from the set's
    current size.

    Args:
        ignore_result_error: The error handling mode (None, False, or True).
        client_name: Name of the client with the error.
        error_code: The return code from the client result.
        current_round: Current training round (may be None).
        controller_name: Name of the controller class.
        failed_clients: Set of client names that have failed.
        num_targets: Total number of target clients.
        min_responses: Minimum number of responses required.

    Returns:
        The message string to log for this error.
    """
    # Resilient mode: the error is simply reported as ignored.
    if ignore_result_error is True:
        return f"Ignore the result from {client_name} at round {current_round}. Result error code: {error_code}"

    # Strict mode: the error is reported as fatal for the controller.
    if ignore_result_error is False:
        return (
            f"Result from {client_name} is bad, error code: {error_code}. "
            f"{controller_name} exiting at round {current_round}."
        )

    # Dynamic mode: the wording depends on whether min_responses is still reachable.
    still_good = num_targets - len(failed_clients)
    if still_good < min_responses:
        return (
            f"Result from {client_name} is bad, error code: {error_code}. "
            f"Cannot reach min_responses: remaining good clients ({still_good}) < min_responses ({min_responses}). "
            f"{controller_name} exiting at round {current_round}."
        )
    return (
        f"Ignore the result from {client_name} at round {current_round}. "
        f"Result error code: {error_code}. "
        f"Remaining good clients ({still_good}) >= min_responses ({min_responses})."
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,10 +12,9 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import gc | ||
| import random | ||
| from abc import ABC, abstractmethod | ||
| from typing import Callable, List, Union | ||
| from typing import Callable, List, Optional, Union | ||
|
|
||
| from nvflare.apis.client import Client | ||
| from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey | ||
|
|
@@ -29,6 +28,7 @@ | |
| from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable | ||
| from nvflare.app_common.app_constant import AppConstants | ||
| from nvflare.app_common.app_event_type import AppEventType | ||
| from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error | ||
| from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper | ||
| from nvflare.app_common.utils.fl_model_utils import FLModelUtils | ||
| from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_int, check_str | ||
|
|
@@ -39,16 +39,18 @@ class BaseModelController(Controller, FLComponentWrapper, ABC): | |
| def __init__( | ||
| self, | ||
| persistor_id: str = AppConstants.DEFAULT_PERSISTOR_ID, | ||
| ignore_result_error: bool = False, | ||
| ignore_result_error: Optional[bool] = None, | ||
| allow_empty_global_weights: bool = False, | ||
| task_check_period: float = 0.5, | ||
| ): | ||
| """FLModel based controller. | ||
|
|
||
| Args: | ||
| persistor_id (str, optional): ID of the persistor component. Defaults to AppConstants.DEFAULT_PERSISTOR_ID ("persistor"). | ||
| ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. | ||
| Defaults to False. | ||
| ignore_result_error (bool or None, optional): How to handle client result errors. | ||
| - None: Dynamic mode (default) - ignore errors if min_responses still reachable, panic otherwise. | ||
| - False: Strict mode - panic on any client error. | ||
| - True: Resilient mode - always ignore client errors. | ||
| allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have | ||
| empty global weights at first round, such that clients start training from scratch without any global info. | ||
| Defaults to False. | ||
|
|
@@ -73,6 +75,12 @@ def __init__( | |
| # model related | ||
| self._results = [] | ||
|
|
||
| # Task context for dynamic ignore_result_error mode (when ignore_result_error=None). | ||
| # These are reset per send_model() call to track error tolerance for the current task. | ||
| self._current_min_responses = 0 # Minimum successful responses needed for this task | ||
| self._current_num_targets = 0 # Total number of clients targeted for this task | ||
| self._current_failed_clients = set() # Set of client names that returned errors in this task | ||
|
|
||
| def start_controller(self, fl_ctx: FLContext) -> None: | ||
| self.fl_ctx = fl_ctx | ||
| self.info("Initializing BaseModelController workflow.") | ||
|
|
@@ -139,6 +147,12 @@ def broadcast_model( | |
| if not blocking and not isinstance(callback, Callable): | ||
| raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback))) | ||
|
|
||
| # Store task context for dynamic ignore_result_error mode | ||
| num_targets = len(targets) if targets else len(self.engine.get_clients()) | ||
| self._current_min_responses = min_responses if min_responses > 0 else num_targets | ||
| self._current_num_targets = num_targets | ||
| self._current_failed_clients = set() | ||
|
|
||
| self.set_fl_context(data) | ||
|
|
||
| task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback) | ||
|
|
@@ -231,9 +245,24 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: | |
| result_model.meta["client_name"] = client_name | ||
|
|
||
| self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) | ||
| self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) | ||
| accepted = self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) | ||
| self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT) | ||
|
|
||
| # If result was rejected (error ignored or panic), skip further processing | ||
| if not accepted: | ||
| client_task.result = None | ||
| return | ||
|
|
||
| # Now try to convert result to FLModel | ||
| try: | ||
| result_model = FLModelUtils.from_shareable(result) | ||
| result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA] | ||
| result_model.meta["client_name"] = client_name | ||
|
Comment on lines
+256
to
+260
|
||
| except Exception as e: | ||
| self.warning(f"Failed to convert result from {client_name} to FLModel: {e}") | ||
| client_task.result = None | ||
| return | ||
|
|
||
| callback = client_task.task.get_prop(AppConstants.TASK_PROP_CALLBACK) | ||
| if callback: | ||
| try: | ||
|
|
@@ -245,38 +274,77 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: | |
|
|
||
| # Cleanup task result | ||
| client_task.result = None | ||
|
|
||
| gc.collect() | ||
| # Note: Memory cleanup (gc.collect + malloc_trim) is handled by subclasses | ||
| # via _maybe_cleanup_memory() based on memory_gc_rounds setting | ||
|
|
||
| def process_result_of_unknown_task( | ||
| self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext | ||
| ) -> None: | ||
| if task_name == AppConstants.TASK_TRAIN: | ||
| self._accept_train_result(client_name=client.name, result=result, fl_ctx=fl_ctx) | ||
| self.info(f"Result of unknown task {task_name} sent to aggregator.") | ||
| accepted = self._accept_train_result( | ||
| client_name=client.name, result=result, fl_ctx=fl_ctx, is_unknown_task=True | ||
| ) | ||
| if accepted: | ||
| self.info(f"Result of unknown task {task_name} sent to aggregator.") | ||
| else: | ||
| self.error("Ignoring result from unknown task.") | ||
|
|
||
| def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLContext): | ||
| def _accept_train_result( | ||
| self, client_name: str, result: Shareable, fl_ctx: FLContext, is_unknown_task: bool = False | ||
| ) -> bool: | ||
| """Accept or reject a training result based on error handling policy. | ||
|
|
||
| Args: | ||
| client_name: Name of the client that sent the result. | ||
| result: The Shareable result from the client. | ||
| fl_ctx: The FLContext. | ||
| is_unknown_task: Whether this result is from an unknown/late task. | ||
|
|
||
| Returns: | ||
| True if the result was accepted, False if it was rejected (error ignored or panic triggered). | ||
| """ | ||
| self.fl_ctx = fl_ctx | ||
| rc = result.get_return_code() | ||
|
|
||
| current_round = result.get_header(AppConstants.CURRENT_ROUND, None) | ||
|
|
||
| # For unknown/late tasks, always ignore errors (no valid tolerance context) | ||
| # For normal tasks, use the configured ignore_result_error setting | ||
| ignore_result_error = True if is_unknown_task else self._ignore_result_error | ||
|
|
||
| # Use empty set for unknown tasks since we don't have valid tracking context | ||
| failed_clients = set() if is_unknown_task else self._current_failed_clients | ||
| num_targets = 0 if is_unknown_task else self._current_num_targets | ||
| min_responses = 0 if is_unknown_task else self._current_min_responses | ||
|
|
||
| # Raise panic if bad peer context or execution exception. | ||
| if rc and rc != ReturnCode.OK: | ||
| if self._ignore_result_error: | ||
| self.warning( | ||
| f"Ignore the train result from {client_name} at round {current_round}. Train result error code: {rc}", | ||
| ) | ||
| should_ignore = should_ignore_result_error( | ||
| ignore_result_error=ignore_result_error, | ||
| client_name=client_name, | ||
| failed_clients=failed_clients, | ||
| num_targets=num_targets, | ||
| min_responses=min_responses, | ||
| ) | ||
| msg = get_error_handling_message( | ||
| ignore_result_error=ignore_result_error, | ||
| client_name=client_name, | ||
| error_code=rc, | ||
| current_round=current_round, | ||
| controller_name=self.__class__.__name__, | ||
| failed_clients=failed_clients, | ||
| num_targets=num_targets, | ||
| min_responses=min_responses, | ||
| ) | ||
| if should_ignore: | ||
| self.warning(msg) | ||
| return False # Result rejected - error ignored | ||
| else: | ||
| self.panic( | ||
| f"Result from {client_name} is bad, error code: {rc}. " | ||
| f"{self.__class__.__name__} exiting at round {current_round}." | ||
| ) | ||
| return | ||
| self.panic(msg) | ||
| return False # Result rejected - panic triggered | ||
|
|
||
| self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) | ||
| return True # Result accepted | ||
|
|
||
| @abstractmethod | ||
| def run(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In `broadcast_model()`, `num_targets = len(targets) if targets else len(self.engine.get_clients())` treats an empty `targets=[]` the same as `targets=None` (all clients). That makes the dynamic `ignore_result_error=None` tolerance math incorrect, and it also conflicts with later logic that passes `targets` through to `broadcast_and_wait`. Use an explicit `targets is not None` check (and consider validating that `targets` is non-empty if empty is not meaningful).