diff --git a/nvflare/app_common/utils/__init__.py b/nvflare/app_common/utils/__init__.py index 2db92b2574..6d654ca3b8 100644 --- a/nvflare/app_common/utils/__init__.py +++ b/nvflare/app_common/utils/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error + +__all__ = ["should_ignore_result_error", "get_error_handling_message"] diff --git a/nvflare/app_common/utils/error_handling_utils.py b/nvflare/app_common/utils/error_handling_utils.py new file mode 100644 index 0000000000..8de19b8eca --- /dev/null +++ b/nvflare/app_common/utils/error_handling_utils.py @@ -0,0 +1,109 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Set + + +def should_ignore_result_error( + ignore_result_error: Optional[bool], + client_name: str, + failed_clients: Set[str], + num_targets: int, + min_responses: int, +) -> bool: + """Determine whether a client result error should be ignored or cause a panic. + + This function implements the three-mode error handling policy: + - None (Dynamic): Ignore errors if min_responses can still be reached, panic otherwise. + - False (Strict): Never ignore errors, always panic. + - True (Resilient): Always ignore errors, never panic. + + Note: This function can be safely called multiple times for the same client error. + The failed_clients set uses idempotent add() operations, so duplicate calls for + the same client will not affect the remaining count calculation. + + Args: + ignore_result_error: The error handling mode. + - None: Dynamic mode - ignore if min_responses still reachable. + - False: Strict mode - always panic on error. + - True: Resilient mode - always ignore errors. + client_name: Name of the client with the error. + failed_clients: Set of client names that have already failed (will be updated + in dynamic mode only). + num_targets: Total number of target clients for the current task. + min_responses: Minimum number of responses required. + + Returns: + True if the error should be ignored (no panic needed). + False if a panic should be triggered. + """ + if ignore_result_error is True: + # Resilient mode - always ignore errors + return True + elif ignore_result_error is False: + # Strict mode - always panic on error + return False + else: + # Dynamic mode (None) - check if min_responses still reachable + failed_clients.add(client_name) + remaining_good_clients = num_targets - len(failed_clients) + return remaining_good_clients >= min_responses + + +def get_error_handling_message( + ignore_result_error: Optional[bool], + client_name: str, + error_code: Any, + current_round: Optional[int], + controller_name: str, + failed_clients: Set[str], + num_targets: int, + min_responses: int, +) -> str: + """Generate appropriate log message based on error handling mode. + + Args: + ignore_result_error: The error handling mode (None, False, or True). + client_name: Name of the client with the error. + error_code: The return code from the client result (ReturnCode constant or None). + current_round: Current training round (may be None if not set in result). + controller_name: Name of the controller class. + failed_clients: Set of client names that have failed. + num_targets: Total number of target clients. + min_responses: Minimum number of responses required. + + Returns: + Appropriate message string for logging. + """ + if ignore_result_error is True: + return f"Ignore the result from {client_name} at round {current_round}. " f"Result error code: {error_code}" + elif ignore_result_error is False: + return ( + f"Result from {client_name} is bad, error code: {error_code}. " + f"{controller_name} exiting at round {current_round}." + ) + else: + remaining_good_clients = num_targets - len(failed_clients) + if remaining_good_clients >= min_responses: + return ( + f"Ignore the result from {client_name} at round {current_round}. " + f"Result error code: {error_code}. " + f"Remaining good clients ({remaining_good_clients}) >= min_responses ({min_responses})." + ) + else: + return ( + f"Result from {client_name} is bad, error code: {error_code}. " + f"Cannot reach min_responses: remaining good clients ({remaining_good_clients}) < min_responses ({min_responses}). " + f"{controller_name} exiting at round {current_round}." + ) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index e137717d90..f8ed2e2b5e 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import random from abc import ABC, abstractmethod -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey @@ -29,6 +28,7 @@ from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.app_event_type import AppEventType +from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper from nvflare.app_common.utils.fl_model_utils import FLModelUtils from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_int, check_str @@ -39,7 +39,7 @@ class BaseModelController(Controller, FLComponentWrapper, ABC): def __init__( self, persistor_id: str = AppConstants.DEFAULT_PERSISTOR_ID, - ignore_result_error: bool = False, + ignore_result_error: Optional[bool] = None, allow_empty_global_weights: bool = False, task_check_period: float = 0.5, ): @@ -47,8 +47,10 @@ def __init__( Args: persistor_id (str, optional): ID of the persistor component. Defaults to AppConstants.DEFAULT_PERSISTOR_ID ("persistor"). - ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. - Defaults to False. + ignore_result_error (bool or None, optional): How to handle client result errors. + - None: Dynamic mode (default) - ignore errors if min_responses still reachable, panic otherwise. + - False: Strict mode - panic on any client error. + - True: Resilient mode - always ignore client errors. allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have empty global weights at first round, such that clients start training from scratch without any global info. Defaults to False. @@ -73,6 +75,12 @@ def __init__( # model related self._results = [] + # Task context for dynamic ignore_result_error mode (when ignore_result_error=None). + # These are reset per send_model() call to track error tolerance for the current task. + self._current_min_responses = 0 # Minimum successful responses needed for this task + self._current_num_targets = 0 # Total number of clients targeted for this task + self._current_failed_clients = set() # Set of client names that returned errors in this task + def start_controller(self, fl_ctx: FLContext) -> None: self.fl_ctx = fl_ctx self.info("Initializing BaseModelController workflow.") @@ -139,6 +147,12 @@ def broadcast_model( if not blocking and not isinstance(callback, Callable): raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback))) + # Store task context for dynamic ignore_result_error mode + num_targets = len(targets) if targets else len(self.engine.get_clients()) + self._current_min_responses = min_responses if min_responses > 0 else num_targets + self._current_num_targets = num_targets + self._current_failed_clients = set() + self.set_fl_context(data) task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback) @@ -231,9 +245,24 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: result_model.meta["client_name"] = client_name self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) - self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) + accepted = self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT) + # If result was rejected (error ignored or panic), skip further processing + if not accepted: + client_task.result = None + return + + # Now try to convert result to FLModel + try: + result_model = FLModelUtils.from_shareable(result) + result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA] + result_model.meta["client_name"] = client_name + except Exception as e: + self.warning(f"Failed to convert result from {client_name} to FLModel: {e}") + client_task.result = None + return + callback = client_task.task.get_prop(AppConstants.TASK_PROP_CALLBACK) if callback: try: @@ -245,38 +274,77 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: # Cleanup task result client_task.result = None - - gc.collect() + # Note: Memory cleanup (gc.collect + malloc_trim) is handled by subclasses + # via _maybe_cleanup_memory() based on memory_gc_rounds setting def process_result_of_unknown_task( self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext ) -> None: if task_name == AppConstants.TASK_TRAIN: - self._accept_train_result(client_name=client.name, result=result, fl_ctx=fl_ctx) - self.info(f"Result of unknown task {task_name} sent to aggregator.") + accepted = self._accept_train_result( + client_name=client.name, result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + if accepted: + self.info(f"Result of unknown task {task_name} sent to aggregator.") else: self.error("Ignoring result from unknown task.") - def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLContext): + def _accept_train_result( + self, client_name: str, result: Shareable, fl_ctx: FLContext, is_unknown_task: bool = False + ) -> bool: + """Accept or reject a training result based on error handling policy. + + Args: + client_name: Name of the client that sent the result. + result: The Shareable result from the client. + fl_ctx: The FLContext. + is_unknown_task: Whether this result is from an unknown/late task. + + Returns: + True if the result was accepted, False if it was rejected (error ignored or panic triggered). + """ self.fl_ctx = fl_ctx rc = result.get_return_code() current_round = result.get_header(AppConstants.CURRENT_ROUND, None) + # For unknown/late tasks, always ignore errors (no valid tolerance context) + # For normal tasks, use the configured ignore_result_error setting + ignore_result_error = True if is_unknown_task else self._ignore_result_error + + # Use empty set for unknown tasks since we don't have valid tracking context + failed_clients = set() if is_unknown_task else self._current_failed_clients + num_targets = 0 if is_unknown_task else self._current_num_targets + min_responses = 0 if is_unknown_task else self._current_min_responses + # Raise panic if bad peer context or execution exception. if rc and rc != ReturnCode.OK: - if self._ignore_result_error: - self.warning( - f"Ignore the train result from {client_name} at round {current_round}. Train result error code: {rc}", - ) + should_ignore = should_ignore_result_error( + ignore_result_error=ignore_result_error, + client_name=client_name, + failed_clients=failed_clients, + num_targets=num_targets, + min_responses=min_responses, + ) + msg = get_error_handling_message( + ignore_result_error=ignore_result_error, + client_name=client_name, + error_code=rc, + current_round=current_round, + controller_name=self.__class__.__name__, + failed_clients=failed_clients, + num_targets=num_targets, + min_responses=min_responses, + ) + if should_ignore: + self.warning(msg) + return False # Result rejected - error ignored else: - self.panic( - f"Result from {client_name} is bad, error code: {rc}. " - f"{self.__class__.__name__} exiting at round {current_round}." - ) - return + self.panic(msg) + return False # Result rejected - panic triggered self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) + return True # Result accepted @abstractmethod def run(self): diff --git a/nvflare/app_common/workflows/scaffold.py b/nvflare/app_common/workflows/scaffold.py index a4a26c72f8..092a280cf2 100644 --- a/nvflare/app_common/workflows/scaffold.py +++ b/nvflare/app_common/workflows/scaffold.py @@ -38,8 +38,10 @@ class Scaffold(BaseFedAvg): num_clients (int, optional): The number of clients. Defaults to 3. num_rounds (int, optional): The total number of training rounds. Defaults to 5. persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". - ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. - Defaults to False. + ignore_result_error (bool or None, optional): How to handle client result errors. + - None: Dynamic mode (default) - ignore errors if min_responses still reachable, panic otherwise. + - False: Strict mode - panic on any client error. + - True: Resilient mode - always ignore client errors. allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have empty global weights at first round, such that clients start training from scratch without any global info. Defaults to False. diff --git a/nvflare/app_common/workflows/scatter_and_gather.py b/nvflare/app_common/workflows/scatter_and_gather.py index d0fdfabe8e..3da3281a97 100644 --- a/nvflare/app_common/workflows/scatter_and_gather.py +++ b/nvflare/app_common/workflows/scatter_and_gather.py @@ -25,6 +25,7 @@ from nvflare.app_common.abstract.shareable_generator import ShareableGenerator from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.app_event_type import AppEventType +from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error from nvflare.fuel.utils.memory_utils import cleanup_memory from nvflare.fuel.utils.validation_utils import check_non_negative_int from nvflare.security.logging import secure_format_exception @@ -43,7 +44,7 @@ def __init__( shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, train_task_name=AppConstants.TASK_TRAIN, train_timeout: int = 0, - ignore_result_error: bool = False, + ignore_result_error: bool = None, allow_empty_global_weights: bool = False, task_check_period: float = 0.5, persist_every_n_rounds: int = 1, @@ -71,8 +72,10 @@ def __init__( shareable_generator_id (str, optional): ID of the shareable generator. Defaults to "shareable_generator". train_task_name (str, optional): Name of the train task. Defaults to "train". train_timeout (int, optional): Time to wait for clients to do local training. - ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. - Defaults to False. + ignore_result_error (bool or None, optional): How to handle client result errors. + - None: Dynamic mode (default) - ignore errors if min_clients still reachable, panic otherwise. + - False: Strict mode - panic on any client error. + - True: Resilient mode - always ignore client errors. allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have empty global weights at first round, such that clients start training from scratch without any global info. Defaults to False. @@ -143,6 +146,10 @@ def __init__( self._global_weights = make_model_learnable({}, {}) self._current_round = None + # Track failed clients for dynamic ignore_result_error mode + self._current_failed_clients = set() + self._current_num_targets = 0 + def _maybe_cleanup_memory(self): """Perform memory cleanup if configured (every N rounds based on memory_gc_rounds).""" if self._current_round is None: @@ -253,6 +260,10 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: result_received_cb=self._process_train_result, ) + # Reset tracking for dynamic ignore_result_error mode + self._current_failed_clients = set() + self._current_num_targets = len(self._engine.get_clients()) + self.broadcast_and_wait( task=train_task, min_responses=self._min_clients, @@ -351,30 +362,63 @@ def process_result_of_unknown_task( self, client: Client, task_name, client_task_id, result: Shareable, fl_ctx: FLContext ) -> None: if self._phase == AppConstants.PHASE_TRAIN and task_name == self.train_task_name: - self._accept_train_result(client_name=client.name, result=result, fl_ctx=fl_ctx) - self.log_info(fl_ctx, f"Result of unknown task {task_name} sent to aggregator.") + accepted = self._accept_train_result( + client_name=client.name, result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + if accepted: + self.log_info(fl_ctx, f"Result of unknown task {task_name} sent to aggregator.") else: self.log_error(fl_ctx, "Ignoring result from unknown task.") - def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLContext) -> bool: + def _accept_train_result( + self, client_name: str, result: Shareable, fl_ctx: FLContext, is_unknown_task: bool = False + ) -> bool: + """Accept or reject a training result based on error handling policy. + Args: + client_name: Name of the client that sent the result. + result: The Shareable result from the client. + fl_ctx: The FLContext. + is_unknown_task: Whether this result is from an unknown/late task. + + Returns: + True if the result was accepted, False if it was rejected (error ignored or panic triggered). + """ rc = result.get_return_code() + # For unknown/late tasks, always ignore errors (no valid tolerance context) + # For normal tasks, use the configured ignore_result_error setting + ignore_result_error_mode = True if is_unknown_task else self.ignore_result_error + + # Use empty set for unknown tasks since we don't have valid tracking context + failed_clients = set() if is_unknown_task else self._current_failed_clients + num_targets = 0 if is_unknown_task else self._current_num_targets + min_responses = 0 if is_unknown_task else self._min_clients + # Raise errors if bad peer context or execution exception. if rc and rc != ReturnCode.OK: - if self.ignore_result_error: - self.log_warning( - fl_ctx, - f"Ignore the train result from {client_name} at round {self._current_round}. Train result error code: {rc}", - ) - return False + should_ignore = should_ignore_result_error( + ignore_result_error=ignore_result_error_mode, + client_name=client_name, + failed_clients=failed_clients, + num_targets=num_targets, + min_responses=min_responses, + ) + msg = get_error_handling_message( + ignore_result_error=ignore_result_error_mode, + client_name=client_name, + error_code=rc, + current_round=self._current_round, + controller_name=self.__class__.__name__, + failed_clients=failed_clients, + num_targets=num_targets, + min_responses=min_responses, + ) + if should_ignore: + self.log_warning(fl_ctx, msg) else: - self.system_panic( - f"Result from {client_name} is bad, error code: {rc}. " - f"{self.__class__.__name__} exiting at round {self._current_round}.", - fl_ctx=fl_ctx, - ) - return False + self.system_panic(msg, fl_ctx=fl_ctx) + return False fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True) fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) diff --git a/nvflare/app_common/workflows/scatter_and_gather_scaffold.py b/nvflare/app_common/workflows/scatter_and_gather_scaffold.py index 52ca07eb9c..3fa8225ccc 100644 --- a/nvflare/app_common/workflows/scatter_and_gather_scaffold.py +++ b/nvflare/app_common/workflows/scatter_and_gather_scaffold.py @@ -40,7 +40,7 @@ def __init__( shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID, train_task_name=AppConstants.TASK_TRAIN, train_timeout: int = 0, - ignore_result_error: bool = False, + ignore_result_error: bool = None, task_check_period: float = 0.5, persist_every_n_rounds: int = 1, snapshot_every_n_rounds: int = 1, @@ -65,8 +65,10 @@ def __init__( shareable_generator_id (str, optional): ID of the shareable generator. Defaults to "shareable_generator". train_task_name (str, optional): Name of the train task. Defaults to "train". train_timeout (int, optional): Time to wait for clients to do local training. - ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. - Defaults to False. + ignore_result_error (bool or None, optional): How to handle client result errors. + - None: Dynamic mode (default) - ignore errors if min_clients still reachable, panic otherwise. + - False: Strict mode - panic on any client error. + - True: Resilient mode - always ignore client errors. task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5. persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1. If n is 0 then no persist. diff --git a/nvflare/recipe/poc_env.py b/nvflare/recipe/poc_env.py index ac6a67fadf..9c7d282ebb 100644 --- a/nvflare/recipe/poc_env.py +++ b/nvflare/recipe/poc_env.py @@ -14,11 +14,13 @@ import os import shutil +import threading import time from typing import Optional from pydantic import BaseModel, conint, model_validator +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.job_config.api import FedJob from nvflare.recipe.spec import ExecEnv from nvflare.recipe.utils import _collect_non_local_scripts @@ -104,6 +106,7 @@ def __init__( extra: extra env info. """ super().__init__(extra) + self.logger = get_obj_logger(self) v = _PocEnvValidator( num_clients=num_clients, @@ -124,15 +127,19 @@ def __init__( self.docker_image = v.docker_image self.username = v.username self._session_manager = None # Lazy initialization + self._session_manager_lock = threading.Lock() - def deploy(self, job: FedJob): + def deploy(self, job: FedJob) -> str: """Deploy a FedJob to the POC environment. Args: job (FedJob): The FedJob to deploy. Returns: - str: Job ID or deployment result. + str: Job ID. + + Raises: + ValueError: If scripts do not exist locally. """ # Validate scripts exist locally for POC non_local_scripts = _collect_non_local_scripts(job) @@ -143,9 +150,9 @@ def deploy(self, job: FedJob): ) if self._check_poc_running(): - self.stop(clean_poc=True) + self.stop(clean_up=True) - print("Preparing and starting fresh POC services...") + self.logger.info("Preparing and starting fresh POC services...") prepare_poc_provision( clients=self.clients or [], # Empty list if None, let prepare_clients generate number_of_clients=self.num_clients, @@ -162,7 +169,7 @@ def deploy(self, job: FedJob): excluded=[self.username], services_list=[], ) - print("POC services started successfully") + self.logger.info("POC services started successfully") # Give services time to start up time.sleep(SERVICE_START_TIMEOUT) @@ -171,9 +178,14 @@ def deploy(self, job: FedJob): return self._get_session_manager().submit_job(job) def _check_poc_running(self) -> bool: + """Check if POC services are currently running. + + Returns: + bool: True if POC is running, False otherwise. + """ try: project_config, service_config = setup_service_config(self.poc_workspace) - except Exception as e: + except Exception: # POC workspace is not initialized yet, so we don't need to stop and clean it return False @@ -182,16 +194,26 @@ def _check_poc_running(self) -> bool: return True - def stop(self, clean_poc: bool = False): + def stop(self, clean_up: bool = False) -> None: """Try to stop and clean existing POC. + This method is idempotent - safe to call multiple times. + Args: - clean_poc (bool, optional): Whether to clean the POC workspace. Defaults to False. + clean_up (bool, optional): Whether to clean the POC workspace. Defaults to False. """ - project_config, service_config = setup_service_config(self.poc_workspace) + # Check if already stopped (idempotent) + if not self._check_poc_running(): + # POC already stopped or workspace doesn't exist + if clean_up and os.path.exists(self.poc_workspace): + self.logger.info(f"Removing POC workspace: {self.poc_workspace}") + shutil.rmtree(self.poc_workspace, ignore_errors=True) + self._session_manager = None # Clear stale session manager + return try: - print("Stopping existing POC services...") + project_config, service_config = setup_service_config(self.poc_workspace) + self.logger.info("Stopping existing POC services...") _stop_poc( poc_workspace=self.poc_workspace, excluded=[self.username], # Exclude admin console (consistent with start) @@ -200,31 +222,60 @@ def stop(self, clean_poc: bool = False): count = 0 poc_running = True while count < STOP_POC_TIMEOUT: - if not is_poc_running(self.poc_workspace, service_config, project_config): + try: + if not is_poc_running(self.poc_workspace, service_config, project_config): + poc_running = False + break + except Exception: poc_running = False break time.sleep(1) count += 1 - if clean_poc: + if clean_up: if poc_running: - print( - f"Warning: POC still running after {STOP_POC_TIMEOUT} seconds, cannot clean workspace. Skipping cleanup." + self.logger.warning( + f"POC still running after {STOP_POC_TIMEOUT} seconds, cannot clean workspace. Skipping cleanup." ) else: - _clean_poc(self.poc_workspace) + try: + _clean_poc(self.poc_workspace) + except Exception as e: + self.logger.debug(f"Failed to clean POC: {e}") except Exception as e: - print(f"Warning: Failed to stop and clean existing POC: {e}") - print(f"Removing POC workspace: {self.poc_workspace}") - shutil.rmtree(self.poc_workspace, ignore_errors=True) + self.logger.warning(f"Failed to stop and clean existing POC: {e}") + finally: + self._session_manager = None # Clear stale session manager def get_job_status(self, job_id: str) -> Optional[str]: + """Get the status of a job. + + Args: + job_id: The job ID to check status for. + + Returns: + Optional[str]: The status of the job, or None if not available. + """ return self._get_session_manager().get_job_status(job_id) def abort_job(self, job_id: str) -> None: + """Abort a running job. + + Args: + job_id: The job ID to abort. + """ self._get_session_manager().abort_job(job_id) def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]: + """Get the result workspace of a job. + + Args: + job_id: The job ID to get results for. + timeout: The timeout for the job to complete. Defaults to 0.0 (no timeout). + + Returns: + Optional[str]: The result workspace path if job completed, None otherwise. + """ return self._get_session_manager().get_job_result(job_id, timeout) def _get_admin_startup_kit_path(self) -> str: @@ -248,15 +299,16 @@ def _get_admin_startup_kit_path(self) -> str: return admin_dir except Exception as e: - raise RuntimeError(f"Failed to locate admin startup kit: {e}") - - def _get_session_manager(self): - """Get or create SessionManager with lazy initialization.""" - if self._session_manager is None: - session_params = { - "username": self.username, - "startup_kit_location": self._get_admin_startup_kit_path(), - "timeout": self.get_extra_prop("login_timeout", 10), - } - self._session_manager = SessionManager(session_params) - return self._session_manager + raise RuntimeError(f"Failed to locate admin startup kit: {e}") from e + + def _get_session_manager(self) -> SessionManager: + """Get or create SessionManager with lazy initialization (thread-safe).""" + with self._session_manager_lock: + if self._session_manager is None: + session_params = { + "username": self.username, + "startup_kit_location": self._get_admin_startup_kit_path(), + "timeout": self.get_extra_prop("login_timeout", 10), + } + self._session_manager = SessionManager(session_params) + return self._session_manager diff --git a/nvflare/recipe/run.py b/nvflare/recipe/run.py index 4f32f76076..8374c55340 100644 --- a/nvflare/recipe/run.py +++ b/nvflare/recipe/run.py @@ -12,39 +12,116 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading from typing import Optional +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.recipe.spec import ExecEnv class Run: + """Represents a running or completed job execution. + + Provides methods to get job status, results, and abort the job. + Caches status and result after the execution environment is stopped. + + This class is thread-safe. All state-changing operations are protected by a lock. + """ + def __init__(self, exec_env: ExecEnv, job_id: str): + """Initialize a Run instance. + + Args: + exec_env: The execution environment managing this job. + job_id: The unique identifier for the job. + + Raises: + ValueError: If exec_env is None or job_id is empty. + """ + if exec_env is None: + raise ValueError("exec_env cannot be None") + if not job_id or not isinstance(job_id, str): + raise ValueError("job_id must be a non-empty string") + self.exec_env = exec_env self.job_id = job_id + self._lock = threading.Lock() + self._stopped = False + self._cached_status: Optional[str] = None + self._cached_result: Optional[str] = None + self.logger = get_obj_logger(self) def get_job_id(self) -> str: + """Get the job ID. + + Returns: + str: The job ID. + """ return self.job_id def get_status(self) -> Optional[str]: """Get the status of the run. Returns: - Optional[str]: The status of the run, or None if called in a simulation environment. + Optional[str]: The status of the run, or None if not available or on error. """ - return self.exec_env.get_job_status(self.job_id) + with self._lock: + if self._stopped: + return self._cached_status + try: + return self.exec_env.get_job_status(self.job_id) + except Exception as e: + self.logger.warning(f"Failed to get job status: {e}") + return None def get_result(self, timeout: float = 0.0) -> Optional[str]: """Get the result workspace of the run. + Waits for job to complete, caches status, then stops execution environment. + Args: - timeout (float, optional): The timeout for the job to complete. - Defaults to 0.0, means never timeout. + timeout (float, optional): Timeout for job completion. Defaults to 0.0 (no timeout). Returns: - Optional[str]: The result workspace path if job completed, None if still running or stopped early. + Optional[str]: Result workspace path, or None if job not finished or on error. """ - return self.exec_env.get_job_result(self.job_id, timeout=timeout) + with self._lock: + if self._stopped: + return self._cached_result + + result = None + try: + result = self.exec_env.get_job_result(self.job_id, timeout=timeout) + self._cached_result = result + except Exception as e: + self.logger.warning(f"Failed to get job result: {e}") + self._cached_result = None + + try: + self._cached_status = self.exec_env.get_job_status(self.job_id) + except Exception as e: + self.logger.warning(f"Failed to get job status: {e}") + self._cached_status = None - def abort(self): - """Abort the running job.""" - self.exec_env.abort_job(self.job_id) + try: + self.exec_env.stop(clean_up=True) + except Exception as e: + self.logger.warning(f"Failed to stop execution environment: {e}") + finally: + self._stopped = True + + return result + + def abort(self) -> None: + """Abort the running job. + + This is a no-op if the execution environment has already been stopped + (e.g., after get_result() was called). Errors are logged but not raised. + """ + with self._lock: + if self._stopped: + return + try: + self.exec_env.abort_job(self.job_id) + except Exception as e: + self.logger.warning(f"Failed to abort job: {e}") diff --git a/nvflare/recipe/spec.py b/nvflare/recipe/spec.py index 52b290a56f..af2aef8165 100644 --- a/nvflare/recipe/spec.py +++ b/nvflare/recipe/spec.py @@ -94,6 +94,19 @@ def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]: """ pass + def stop(self, clean_up: bool = False) -> None: + """Stop the execution environment and optionally clean up resources. + + This method is called after job execution to ensure proper cleanup. + Default implementation is a no-op. Override in subclasses that need cleanup. + + Args: + clean_up: If True, remove workspace and temporary files after stopping. + If False, only stop running processes but preserve workspace. + Defaults to False. + """ + pass + class Recipe(ABC): diff --git a/tests/unit_test/app_common/utils/error_handling_utils_test.py b/tests/unit_test/app_common/utils/error_handling_utils_test.py new file mode 100644 index 0000000000..ddf0f4fa18 --- /dev/null +++ b/tests/unit_test/app_common/utils/error_handling_utils_test.py @@ -0,0 +1,540 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.app_common.utils.error_handling_utils import get_error_handling_message, should_ignore_result_error + + +class TestShouldIgnoreResultError: + """Test should_ignore_result_error utility function.""" + + def test_true_mode_always_ignores(self): + """Test ignore_result_error=True always returns True (ignore error).""" + failed_clients = set() + result = should_ignore_result_error( + ignore_result_error=True, + client_name="site-1", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert result is True + # Client should not be added to failed_clients in True mode + assert "site-1" not in failed_clients + + def test_false_mode_always_panics(self): + """Test ignore_result_error=False always returns False (panic).""" + failed_clients = set() + result = should_ignore_result_error( + ignore_result_error=False, + client_name="site-1", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert result is False + # Client should not be added to failed_clients in False mode + assert "site-1" not in failed_clients + + def test_dynamic_mode_ignores_when_min_responses_reachable(self): + """Test ignore_result_error=None ignores error when min_responses still reachable.""" + failed_clients = set() + # 5 targets, 3 min_responses, 1 failure -> 4 remaining >= 3 -> ignore + result = should_ignore_result_error( + ignore_result_error=None, + client_name="site-1", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert result is True + assert "site-1" in failed_clients + + def test_dynamic_mode_panics_when_min_responses_not_reachable(self): + """Test ignore_result_error=None panics when min_responses not reachable.""" + failed_clients = {"site-1", "site-2"} # 2 already failed + # 5 targets, 3 min_responses, 3 failures (including new one) -> 2 remaining < 3 -> panic + result = should_ignore_result_error( + ignore_result_error=None, + client_name="site-3", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert result is False + assert "site-3" in failed_clients + + def test_dynamic_mode_exact_threshold(self): + """Test ignore_result_error=None at exact threshold boundary.""" + failed_clients = {"site-1"} # 1 already failed + # 5 targets, 3 min_responses, 2 failures -> 3 remaining == 3 -> ignore (just enough) + result = should_ignore_result_error( + ignore_result_error=None, + client_name="site-2", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert result is True + assert "site-2" in failed_clients + + def test_dynamic_mode_all_must_succeed(self): + """Test dynamic mode when min_responses equals num_targets (all must succeed).""" + failed_clients = set() + # 3 targets, 3 min_responses -> any failure means panic + result = should_ignore_result_error( + ignore_result_error=None, + client_name="site-1", + failed_clients=failed_clients, + num_targets=3, + min_responses=3, + ) + assert result is False + assert "site-1" in failed_clients + + def test_dynamic_mode_one_must_succeed(self): + """Test dynamic mode when min_responses is 1.""" + failed_clients = {"site-1", "site-2"} # 2 already failed + # 3 targets, 1 min_responses, 3 failures -> 0 remaining < 1 -> panic + result = should_ignore_result_error( + ignore_result_error=None, + client_name="site-3", + failed_clients=failed_clients, + num_targets=3, + min_responses=1, + ) + assert result is False + assert "site-3" in failed_clients + + def test_dynamic_mode_can_lose_all_but_one(self): + """Test dynamic mode allows losing all but min_responses clients.""" + failed_clients = {"site-1"} # 1 already failed + # 3 targets, 1 min_responses, 2 failures -> 1 remaining >= 1 -> ignore + result = should_ignore_result_error( + ignore_result_error=None, + client_name="site-2", + failed_clients=failed_clients, + num_targets=3, + min_responses=1, + ) + assert result is True + assert "site-2" in failed_clients + + +class TestGetErrorHandlingMessage: + """Test get_error_handling_message utility function.""" + + def test_true_mode_message(self): + """Test message for ignore_result_error=True.""" + failed_clients = set() + msg = get_error_handling_message( + ignore_result_error=True, + client_name="site-1", + error_code="EXECUTION_EXCEPTION", + current_round=5, + controller_name="FedAvg", + failed_clients=failed_clients, + num_targets=3, + min_responses=2, + ) + assert "Ignore the result from site-1" in msg + assert "round 5" in msg + assert "EXECUTION_EXCEPTION" in msg + + def test_false_mode_message(self): + """Test message for ignore_result_error=False.""" + failed_clients = set() + msg = get_error_handling_message( + ignore_result_error=False, + client_name="site-1", + error_code="TASK_ABORTED", + current_round=3, + controller_name="ScatterAndGather", + failed_clients=failed_clients, + num_targets=3, + min_responses=2, + ) + assert "Result from site-1 is bad" in msg + assert "TASK_ABORTED" in msg + assert "ScatterAndGather exiting" in msg + assert "round 3" in msg + + def test_dynamic_mode_ignore_message(self): + """Test message for ignore_result_error=None when ignoring.""" + failed_clients = {"site-1"} # 1 already failed, will add site-2 + failed_clients.add("site-2") # Simulate what should_ignore_result_error does + msg = get_error_handling_message( + ignore_result_error=None, + client_name="site-2", + error_code="EXECUTION_EXCEPTION", + current_round=2, + controller_name="FedAvg", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert "Ignore the result from site-2" in msg + assert "Remaining good clients (3) >= min_responses (3)" in msg + + def test_dynamic_mode_panic_message(self): + """Test message for ignore_result_error=None when panicking.""" + failed_clients = {"site-1", "site-2", "site-3"} # 3 already failed + msg = get_error_handling_message( + ignore_result_error=None, + client_name="site-3", + error_code="EXECUTION_EXCEPTION", + current_round=1, + controller_name="FedAvg", + failed_clients=failed_clients, + num_targets=5, + min_responses=3, + ) + assert "Result from site-3 is bad" in msg + assert "Cannot reach min_responses" in msg + assert "remaining good clients (2) < min_responses (3)" in msg + + +class TestIgnoreResultErrorIntegration: + """Integration tests for ignore_result_error behavior in controllers.""" + + def test_base_model_controller_default_is_none(self): + """Test BaseModelController defaults to ignore_result_error=None.""" + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg() + assert controller._ignore_result_error is None + + def test_base_model_controller_accepts_true(self): + """Test BaseModelController accepts ignore_result_error=True.""" + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg(ignore_result_error=True) + assert controller._ignore_result_error is True + + def test_base_model_controller_accepts_false(self): + """Test BaseModelController accepts ignore_result_error=False.""" + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg(ignore_result_error=False) + assert controller._ignore_result_error is False + + def test_scatter_and_gather_default_is_none(self): + """Test ScatterAndGather defaults to ignore_result_error=None.""" + from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather + + controller = ScatterAndGather() + assert controller.ignore_result_error is None + + def test_scatter_and_gather_accepts_all_modes(self): + """Test ScatterAndGather accepts all three modes.""" + from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather + + controller_none = ScatterAndGather(ignore_result_error=None) + assert controller_none.ignore_result_error is None + + controller_true = ScatterAndGather(ignore_result_error=True) + assert controller_true.ignore_result_error is True + + controller_false = ScatterAndGather(ignore_result_error=False) + assert controller_false.ignore_result_error is False + + +class TestAcceptTrainResultErrorHandling: + """Test _accept_train_result error handling behavior in BaseModelController.""" + + def _create_mock_result(self, return_code): + """Create a mock Shareable result with given return code.""" + from nvflare.apis.shareable import Shareable + + result = Shareable() + result.set_return_code(return_code) + return result + + def _create_mock_fl_ctx(self): + """Create a mock FLContext.""" + from unittest.mock import MagicMock + + fl_ctx = MagicMock() + fl_ctx.set_prop = MagicMock() + return fl_ctx + + def test_normal_task_with_ok_result_sets_context(self): + """Test that normal task with OK result sets the result in context and returns True.""" + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.app_constant import AppConstants + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg() + controller._current_failed_clients = set() + controller._current_num_targets = 5 + controller._current_min_responses = 3 + + result = self._create_mock_result(ReturnCode.OK) + fl_ctx = self._create_mock_fl_ctx() + + accepted = controller._accept_train_result(client_name="site-1", result=result, fl_ctx=fl_ctx) + + assert accepted is True + fl_ctx.set_prop.assert_called_once_with(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) + + def test_normal_task_with_error_and_ignore_does_not_set_context(self): + """Test that normal task with error when should_ignore=True returns False without setting context.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg(ignore_result_error=True) # Always ignore + controller._current_failed_clients = set() + controller._current_num_targets = 5 + controller._current_min_responses = 3 + controller.warning = MagicMock() + + result = self._create_mock_result(ReturnCode.EXECUTION_EXCEPTION) + fl_ctx = self._create_mock_fl_ctx() + + accepted = controller._accept_train_result(client_name="site-1", result=result, fl_ctx=fl_ctx) + + assert accepted is False + # Should NOT set the errored result in context + fl_ctx.set_prop.assert_not_called() + # Should log a warning + controller.warning.assert_called_once() + + def test_normal_task_with_error_and_panic_does_not_set_context(self): + """Test that normal task with error when should_ignore=False returns False and panics.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg(ignore_result_error=False) # Always panic + controller._current_failed_clients = set() + controller._current_num_targets = 5 + controller._current_min_responses = 3 + controller.panic = MagicMock() + + result = self._create_mock_result(ReturnCode.EXECUTION_EXCEPTION) + fl_ctx = self._create_mock_fl_ctx() + + accepted = controller._accept_train_result(client_name="site-1", result=result, fl_ctx=fl_ctx) + + assert accepted is False + # Should NOT set the errored result in context + fl_ctx.set_prop.assert_not_called() + # Should panic + controller.panic.assert_called_once() + + def test_unknown_task_always_ignores_errors(self): + """Test that unknown/late tasks always ignore errors and return False regardless of controller setting.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.fedavg import FedAvg + + # Even with ignore_result_error=False (strict mode), unknown tasks should ignore + controller = FedAvg(ignore_result_error=False) + controller._current_failed_clients = set() + controller._current_num_targets = 5 + controller._current_min_responses = 3 + controller.warning = MagicMock() + controller.panic = MagicMock() + + result = self._create_mock_result(ReturnCode.EXECUTION_EXCEPTION) + fl_ctx = self._create_mock_fl_ctx() + + accepted = controller._accept_train_result( + client_name="site-1", result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + + assert accepted is False + # Should warn, NOT panic (because is_unknown_task=True forces ignore) + controller.warning.assert_called_once() + controller.panic.assert_not_called() + # Should NOT set the errored result in context + fl_ctx.set_prop.assert_not_called() + + def test_unknown_task_uses_empty_tracking_context(self): + """Test that unknown tasks use empty tracking context to avoid stale data.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg(ignore_result_error=None) # Dynamic mode + # Stale tracking data from previous task + controller._current_failed_clients = {"site-1", "site-2"} + controller._current_num_targets = 3 + controller._current_min_responses = 3 + controller.warning = MagicMock() + controller.panic = MagicMock() + + result = self._create_mock_result(ReturnCode.EXECUTION_EXCEPTION) + fl_ctx = self._create_mock_fl_ctx() + + # With is_unknown_task=True, should use empty/zero context + # which means ignore_result_error=True (set in the method) + accepted = controller._accept_train_result( + client_name="site-3", result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + + assert accepted is False + # Should warn (because unknown task forces ignore_result_error=True) + controller.warning.assert_called_once() + controller.panic.assert_not_called() + + # Verify stale _current_failed_clients was NOT modified + assert "site-3" not in controller._current_failed_clients + + def test_unknown_task_with_ok_result_sets_context(self): + """Test that unknown task with OK result returns True and sets result in context.""" + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.app_constant import AppConstants + from nvflare.app_common.workflows.fedavg import FedAvg + + controller = FedAvg() + controller._current_failed_clients = set() + controller._current_num_targets = 5 + controller._current_min_responses = 3 + + result = self._create_mock_result(ReturnCode.OK) + fl_ctx = self._create_mock_fl_ctx() + + accepted = controller._accept_train_result( + client_name="site-1", result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + + assert accepted is True + # OK result should still be set in context + fl_ctx.set_prop.assert_called_once_with(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) + + +class TestScatterAndGatherUnknownTaskHandling: + """Test ScatterAndGather's handling of unknown/late task results.""" + + def _create_mock_result(self, return_code): + """Helper to create a mock Shareable result.""" + from unittest.mock import MagicMock + + result = MagicMock() + result.get_return_code.return_value = return_code + result.get_header.return_value = 1 # current_round + return result + + def _create_mock_fl_ctx(self): + """Helper to create a mock FLContext.""" + from unittest.mock import MagicMock + + fl_ctx = MagicMock() + fl_ctx.get_peer_context.return_value = None # Required for log_info to work properly + fl_ctx.get_identity_name.return_value = "test_identity" + fl_ctx.get_job_id.return_value = "test_job" + fl_ctx.get_prop.return_value = None + return fl_ctx + + def _create_mock_aggregator(self): + """Helper to create a mock aggregator.""" + from unittest.mock import MagicMock + + aggregator = MagicMock() + aggregator.accept.return_value = True + return aggregator + + def test_unknown_task_forces_ignore_error_mode(self): + """Test that unknown tasks force ignore_result_error=True regardless of controller setting.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather + + # Controller set to panic mode (False) + controller = ScatterAndGather(min_clients=3, ignore_result_error=False) + controller._current_failed_clients = set() + controller._current_num_targets = 5 + controller._min_clients = 3 + controller._current_round = 1 + controller.log_warning = MagicMock() + controller.system_panic = MagicMock() + controller.aggregator = self._create_mock_aggregator() + controller.fire_event = MagicMock() + + result = self._create_mock_result(ReturnCode.EXECUTION_EXCEPTION) + fl_ctx = self._create_mock_fl_ctx() + + # With is_unknown_task=True, should warn (not panic) even though ignore_result_error=False + accepted = controller._accept_train_result( + client_name="site-1", result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + + assert accepted is False + # Should warn (because unknown task forces ignore_result_error=True) + controller.log_warning.assert_called_once() + controller.system_panic.assert_not_called() + + def test_unknown_task_uses_empty_tracking_context(self): + """Test that unknown tasks use empty tracking context to avoid stale data.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather + + controller = ScatterAndGather(min_clients=3, ignore_result_error=None) + # Stale tracking data from previous task + controller._current_failed_clients = {"site-1", "site-2"} + controller._current_num_targets = 3 + controller._min_clients = 3 + controller._current_round = 1 + controller.log_warning = MagicMock() + controller.system_panic = MagicMock() + controller.aggregator = self._create_mock_aggregator() + controller.fire_event = MagicMock() + + result = self._create_mock_result(ReturnCode.EXECUTION_EXCEPTION) + fl_ctx = self._create_mock_fl_ctx() + + # With is_unknown_task=True, should use empty/zero context + # which means ignore_result_error=True (set in the method) + accepted = controller._accept_train_result( + client_name="site-3", result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + + assert accepted is False + # Should warn (because unknown task forces ignore_result_error=True) + controller.log_warning.assert_called_once() + controller.system_panic.assert_not_called() + + def test_unknown_task_with_ok_result_accepted(self): + """Test that unknown tasks with OK result are accepted.""" + from unittest.mock import MagicMock + + from nvflare.apis.fl_constant import ReturnCode + from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather + + controller = ScatterAndGather(min_clients=3) + controller._current_failed_clients = {"site-1"} + controller._current_num_targets = 3 + controller._min_clients = 3 + controller._current_round = 1 + controller.aggregator = self._create_mock_aggregator() + controller.fire_event = MagicMock() + + result = self._create_mock_result(ReturnCode.OK) + fl_ctx = self._create_mock_fl_ctx() + + accepted = controller._accept_train_result( + client_name="site-2", result=result, fl_ctx=fl_ctx, is_unknown_task=True + ) + + # OK result should be accepted and sent to aggregator + assert accepted is True + controller.aggregator.accept.assert_called_once_with(result, fl_ctx) diff --git a/tests/unit_test/app_common/workflow/fedavg_test.py b/tests/unit_test/app_common/workflow/fedavg_test.py index 0cb3517f49..4529fd3cca 100644 --- a/tests/unit_test/app_common/workflow/fedavg_test.py +++ b/tests/unit_test/app_common/workflow/fedavg_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import patch +from unittest.mock import MagicMock, patch from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, Task @@ -839,6 +839,9 @@ def test_process_result_sets_current_round_on_fl_ctx(self): def test_broadcast_model_does_not_fire_round_started(self): controller = FedAvg(num_clients=1) controller.fl_ctx = FLContext() + mock_engine = MagicMock() + mock_engine.get_clients.return_value = ["site-1"] + controller.engine = mock_engine model = FLModel(params={"w": 1.0}, current_round=2) with ( diff --git a/tests/unit_test/recipe/poc_env_test.py b/tests/unit_test/recipe/poc_env_test.py index b710a7e1e4..9221231efb 100644 --- a/tests/unit_test/recipe/poc_env_test.py +++ b/tests/unit_test/recipe/poc_env_test.py @@ -123,17 +123,18 @@ def test_get_admin_startup_kit_path_not_found(mock_setup, mock_get_prod_dir, moc @patch("nvflare.recipe.poc_env._stop_poc") @patch("nvflare.recipe.poc_env._clean_poc") @patch("nvflare.recipe.poc_env.is_poc_running") -@patch("nvflare.recipe.poc_env.shutil.rmtree") -def test_stop_poc(mock_rmtree, mock_is_running, mock_clean_poc, mock_stop_poc, mock_setup): +def test_stop_poc(mock_is_running, mock_clean_poc, mock_stop_poc, mock_setup): """Test stop and clean POC functionality.""" mock_setup.return_value = ({"name": "test"}, {"server": "server"}) - mock_is_running.return_value = False # POC stops successfully + # Mock is_poc_running to return True initially (POC is running), + # then False (POC stops successfully after _stop_poc is called) + mock_is_running.side_effect = [True, False] env = PocEnv() - env.stop(clean_poc=True) + env.stop(clean_up=True) mock_stop_poc.assert_called_once_with( poc_workspace=env.poc_workspace, excluded=["admin@nvidia.com"], services_list=[] ) + # _clean_poc handles workspace removal internally via shutil.rmtree mock_clean_poc.assert_called_once_with(env.poc_workspace) - mock_rmtree.assert_called_once_with(env.poc_workspace, ignore_errors=True) diff --git a/tests/unit_test/recipe/run_test.py b/tests/unit_test/recipe/run_test.py index c55e59cbc8..3c6669965e 100644 --- a/tests/unit_test/recipe/run_test.py +++ b/tests/unit_test/recipe/run_test.py @@ -15,11 +15,13 @@ import tempfile from unittest.mock import MagicMock, patch -from nvflare.recipe import PocEnv, ProdEnv, Run, SimEnv +import pytest + +from nvflare.recipe.run import Run class TestRunClass: - """Test the refactored Run class.""" + """Test the Run class.""" def setup_method(self): """Set up test fixtures.""" @@ -31,13 +33,18 @@ def test_initialization(self): """Test Run initialization.""" assert self.run.exec_env == self.mock_env assert self.run.job_id == self.job_id + assert self.run._stopped is False + assert self.run._cached_status is None + assert self.run._cached_result is None + assert self.run.logger is not None + assert self.run._lock is not None # Thread safety lock def test_get_job_id(self): """Test get_job_id method.""" assert self.run.get_job_id() == self.job_id def test_get_status_delegates_to_env(self): - """Test that get_status delegates to exec_env.""" + """Test that get_status delegates to exec_env when not stopped.""" self.mock_env.get_job_status.return_value = "RUNNING" result = self.run.get_status() @@ -54,56 +61,213 @@ def test_get_status_returns_none_for_sim(self): assert result is None self.mock_env.get_job_status.assert_called_once_with(self.job_id) - def test_get_result_delegates_to_env(self): - """Test that get_result delegates to exec_env.""" - self.mock_env.get_job_result.return_value = "/path/to/result" + def test_get_result_waits_caches_and_stops(self): + """Test that get_result waits for job, caches status, and stops env.""" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" result = self.run.get_result(timeout=30.0) - assert result == "/path/to/result" + assert result == "/tmp/workspace/test_job_123" self.mock_env.get_job_result.assert_called_once_with(self.job_id, timeout=30.0) + self.mock_env.get_job_status.assert_called_once_with(self.job_id) + self.mock_env.stop.assert_called_once_with(clean_up=True) + assert self.run._stopped is True + assert self.run._cached_status == "FINISHED" + assert self.run._cached_result == "/tmp/workspace/test_job_123" def test_get_result_default_timeout(self): """Test get_result with default timeout.""" - self.mock_env.get_job_result.return_value = "/path/to/result" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" result = self.run.get_result() - assert result == "/path/to/result" + assert result == "/tmp/workspace/test_job_123" self.mock_env.get_job_result.assert_called_once_with(self.job_id, timeout=0.0) + def test_get_status_returns_cached_after_stopped(self): + """Test that get_status returns cached value after get_result is called.""" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" + + # Call get_result first - this stops POC and caches status + self.run.get_result() + + # Reset mock to verify get_status doesn't call exec_env again + self.mock_env.get_job_status.reset_mock() + + # get_status should return cached value + status = self.run.get_status() + assert status == "FINISHED" + self.mock_env.get_job_status.assert_not_called() + + def test_get_result_returns_cached_after_stopped(self): + """Test that get_result returns cached value when called again.""" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" + + # First call + result1 = self.run.get_result() + + # Reset mocks + self.mock_env.get_job_result.reset_mock() + self.mock_env.stop.reset_mock() + + # Second call should return cached result + result2 = self.run.get_result() + assert result2 == "/tmp/workspace/test_job_123" + self.mock_env.get_job_result.assert_not_called() + self.mock_env.stop.assert_not_called() + + def test_get_result_stops_even_on_result_exception(self): + """Test that stop is called even if get_job_result raises exception.""" + self.mock_env.get_job_result.side_effect = RuntimeError("Connection failed") + self.mock_env.get_job_status.return_value = "FINISHED" + + result = self.run.get_result() + + # Exception is caught and logged, result is None + assert result is None + self.mock_env.stop.assert_called_once_with(clean_up=True) + assert self.run._stopped is True + assert self.run._cached_status == "FINISHED" + assert self.run._cached_result is None + + def test_get_result_sets_stopped_even_on_stop_exception(self): + """Test that _stopped is set to True even if stop() raises exception.""" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" + self.mock_env.stop.side_effect = RuntimeError("Stop failed") + + result = self.run.get_result() + + # Result is still returned, _stopped is True despite stop() failing + assert result == "/tmp/workspace/test_job_123" + assert self.run._stopped is True + assert self.run._cached_status == "FINISHED" + assert self.run._cached_result == "/tmp/workspace/test_job_123" + + def test_get_result_preserves_result_on_status_exception(self): + """Test that result is preserved even if get_job_status raises exception.""" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.side_effect = Exception("Status error") + + result = self.run.get_result() + + # Result is still returned and cached, only status cache is None + assert result == "/tmp/workspace/test_job_123" + assert self.run._cached_status is None + assert self.run._cached_result == "/tmp/workspace/test_job_123" + assert self.run._stopped is True + def test_abort_delegates_to_env(self): - """Test that abort delegates to exec_env.""" + """Test that abort delegates to exec_env when not stopped.""" + self.run.abort() + + self.mock_env.abort_job.assert_called_once_with(self.job_id) + + def test_abort_does_nothing_after_stopped(self): + """Test that abort does nothing after get_result has been called.""" + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" + + # Stop via get_result + self.run.get_result() + + # Reset mock + self.mock_env.abort_job.reset_mock() + + # Abort should not call exec_env.abort_job + self.run.abort() + self.mock_env.abort_job.assert_not_called() + + def test_get_status_before_get_result_does_not_stop(self): + """Test that get_status does not stop POC.""" + self.mock_env.get_job_status.return_value = "RUNNING" + + # Call get_status multiple times + self.run.get_status() + self.run.get_status() + self.run.get_status() + + # stop should never be called + self.mock_env.stop.assert_not_called() + assert self.run._stopped is False + + def test_init_with_none_exec_env_raises(self): + """Test that Run raises ValueError when exec_env is None.""" + with pytest.raises(ValueError, match="exec_env cannot be None"): + Run(exec_env=None, job_id="test_job") + + def test_init_with_empty_job_id_raises(self): + """Test that Run raises ValueError when job_id is empty.""" + with pytest.raises(ValueError, match="job_id must be a non-empty string"): + Run(exec_env=self.mock_env, job_id="") + + def test_init_with_none_job_id_raises(self): + """Test that Run raises ValueError when job_id is None.""" + with pytest.raises(ValueError, match="job_id must be a non-empty string"): + Run(exec_env=self.mock_env, job_id=None) + + def test_get_status_handles_exception(self): + """Test that get_status returns None and logs warning on exception.""" + self.mock_env.get_job_status.side_effect = RuntimeError("Connection failed") + + result = self.run.get_status() + + assert result is None + self.mock_env.get_job_status.assert_called_once_with(self.job_id) + + def test_abort_handles_exception(self): + """Test that abort logs warning on exception but doesn't raise.""" + self.mock_env.abort_job.side_effect = RuntimeError("Abort failed") + + # Should not raise self.run.abort() self.mock_env.abort_job.assert_called_once_with(self.job_id) - def test_run_with_different_environments(self): - """Test Run works with different environment types.""" - # Test with simulation environment (returns None for status) - sim_env = MagicMock() - sim_env.get_job_status.return_value = None - sim_env.get_job_result.return_value = "/sim/workspace/job_123" + def test_concurrent_get_result_calls(self): + """Test that concurrent get_result calls are handled safely.""" + import threading + + self.mock_env.get_job_result.return_value = "/tmp/workspace/test_job_123" + self.mock_env.get_job_status.return_value = "FINISHED" - sim_run = Run(exec_env=sim_env, job_id="sim_job") - assert sim_run.get_status() is None - assert sim_run.get_result() == "/sim/workspace/job_123" + results = [] + errors = [] - # Test with production environment - prod_env = MagicMock() - prod_env.get_job_status.return_value = "COMPLETED" - prod_env.get_job_result.return_value = "/prod/downloads/job_result" + def call_get_result(): + try: + result = self.run.get_result() + results.append(result) + except Exception as e: + errors.append(e) - prod_run = Run(exec_env=prod_env, job_id="prod_job") - assert prod_run.get_status() == "COMPLETED" - assert prod_run.get_result() == "/prod/downloads/job_result" + # Launch multiple threads calling get_result concurrently + threads = [threading.Thread(target=call_get_result) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + # All calls should succeed without errors + assert len(errors) == 0 + # All results should be the same (cached or fresh) + assert all(r == "/tmp/workspace/test_job_123" for r in results) + # stop should only be called once (first call stops, others return cached) + assert self.mock_env.stop.call_count == 1 + +@pytest.mark.skip(reason="Integration tests require full environment setup") class TestRunIntegration: """Integration tests for Run with actual environment classes.""" def test_run_with_sim_env(self): """Test Run with actual SimEnv.""" + from nvflare.recipe import SimEnv + sim_env = SimEnv(num_clients=2, workspace_root="/tmp/test_sim") run = Run(exec_env=sim_env, job_id="test_job") @@ -116,6 +280,7 @@ def test_run_with_sim_env(self): def test_run_with_poc_env(self): """Test Run with actual PocEnv (mocked dependencies).""" + from nvflare.recipe import PocEnv with patch("nvflare.recipe.poc_env.get_poc_workspace", return_value="/tmp/poc"): poc_env = PocEnv(num_clients=2) @@ -128,6 +293,8 @@ def test_run_with_poc_env(self): def test_run_with_prod_env(self): """Test Run with actual ProdEnv (mocked dependencies).""" + from nvflare.recipe import ProdEnv + with tempfile.TemporaryDirectory() as temp_dir: prod_env = ProdEnv(startup_kit_location=temp_dir) run = Run(exec_env=prod_env, job_id="prod_test_job")