From fffa5802544b8e977198a451c08691937c45d072 Mon Sep 17 00:00:00 2001 From: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Date: Tue, 24 Feb 2026 11:46:10 -0800 Subject: [PATCH] [2.7] Fix recipe API bug list and harden recipe behavior (#4228) ## Issues addressed This PR addresses the reported recipe bug list: - `BUG-8` (Critical): `CyclicRecipe` `model=None + initial_ckpt` created phantom persistor reference and runtime failure. - `BUG-1` (Critical): `CyclicRecipe` dict model config caused `TypeError`. - `BUG-7` (Critical): base `CyclicRecipe` silently ignored `initial_ckpt`. - `BUG-9` (High): `FedAvgRecipe` `per_site_config[\"server\"]` caused invalid target handling/crash. - `BUG-4` (High): `Recipe.execute/export` mutated job state and made recipe reuse unsafe. - `BUG-2` (Medium): `FedAvgRecipe` per-site `or` fallback ignored falsy overrides. - `BUG-3` (Medium): `SimEnv` with `num_clients=0` and client list resolved `num_threads=0`. - `BUG-5` (Low): base `FedAvgRecipe` accepted `initial_ckpt` + `model=None` but produced incomplete model source. - `BUG-6` (Low): `add_cross_site_evaluation` idempotency guard depended only on transient `_cse_added` attribute. ## Changes - Hardened base `CyclicRecipe` model/persistor setup: - fail-fast when a valid persistor cannot be configured; - support dict model config for supported frameworks; - apply/validate `initial_ckpt` for wrapper and framework-specific paths. - Hardened base `FedAvgRecipe`: - validate and reject reserved `per_site_config` targets (`server`, `@ALL`); - preserve falsy per-site override values via explicit `is not None` fallback; - fail-fast when no persistor and no model params are available. - Added shared framework persistor setup utility in recipe utils and wired both Cyclic/FedAvg paths to it. - Made `Recipe.execute()` and `Recipe.export()` reusable-safe by snapshot/restore of temporary execution params. - Fixed `SimEnv` default resolution for `num_clients`/`num_threads` when explicit client list is provided. - Hardened CSE idempotency with workflow-based detection in addition to `_cse_added` fast-path. - Updated impacted examples and exports: - `hello-cyclic` now passes `min_clients` explicitly; - `hello-numpy` uses transfer-type API; - recipe module exports include `add_cross_site_evaluation`. - Expanded recipe unit tests with targeted regressions for all above behaviors. ## Reason for these changes These updates enforce clear model-source contracts and remove silent/implicit fallback behavior that could generate invalid jobs or runtime crashes. The goal is deterministic recipe behavior, safe object reuse, and predictable per-site/CSE configuration semantics in 2.7. ## Affected files Core implementation: - `nvflare/recipe/cyclic.py` - `nvflare/recipe/fedavg.py` - `nvflare/recipe/spec.py` - `nvflare/recipe/sim_env.py` - `nvflare/recipe/utils.py` - `nvflare/recipe/poc_env.py` - `nvflare/recipe/__init__.py` Framework wrappers / examples: - `nvflare/app_opt/pt/recipes/cyclic.py` - `nvflare/app_opt/tf/recipes/cyclic.py` - `examples/hello-world/hello-cyclic/job.py` - `examples/hello-world/hello-numpy/job.py` - `examples/tutorials/job_recipe.ipynb` Tests: - `tests/unit_test/recipe/cyclic_recipe_test.py` - `tests/unit_test/recipe/fedavg_recipe_test.py` - `tests/unit_test/recipe/spec_test.py` - `tests/unit_test/recipe/sim_env_test.py` - `tests/unit_test/recipe/utils_test.py` ## Test strategy - Targeted regression suites: - `pytest -q tests/unit_test/recipe/cyclic_recipe_test.py tests/unit_test/recipe/fedavg_recipe_test.py tests/unit_test/recipe/spec_test.py` - `pytest -q tests/unit_test/recipe/sim_env_test.py tests/unit_test/recipe/utils_test.py` - Full recipe suite: - `pytest -q tests/unit_test/recipe` - Result: recipe suite passed (`183 passed, 13 skipped`). Made with [Cursor](https://cursor.com) --- examples/hello-world/hello-cyclic/job.py | 1 + examples/hello-world/hello-numpy/job.py | 4 +- examples/tutorials/job_recipe.ipynb | 62 +-------- nvflare/app_opt/pt/recipes/cyclic.py | 17 ++- nvflare/app_opt/pt/recipes/fedavg.py | 52 +++++--- nvflare/app_opt/tf/recipes/cyclic.py | 9 +- nvflare/app_opt/tf/recipes/fedavg.py | 26 ++-- nvflare/recipe/__init__.py | 4 +- nvflare/recipe/cyclic.py | 48 ++++--- nvflare/recipe/fedavg.py | 53 ++++++-- nvflare/recipe/poc_env.py | 6 +- nvflare/recipe/prod_env.py | 4 +- nvflare/recipe/sim_env.py | 9 +- nvflare/recipe/spec.py | 75 +++++++---- nvflare/recipe/utils.py | 50 ++++++- tests/unit_test/recipe/cyclic_recipe_test.py | 95 +++++++++++++- tests/unit_test/recipe/fedavg_recipe_test.py | 101 +++++++++++++- tests/unit_test/recipe/poc_env_test.py | 6 + tests/unit_test/recipe/sim_env_test.py | 6 + tests/unit_test/recipe/spec_test.py | 130 ++++++++++++++----- tests/unit_test/recipe/utils_test.py | 121 ++++++++++++++++- 21 files changed, 677 insertions(+), 202 deletions(-) diff --git a/examples/hello-world/hello-cyclic/job.py b/examples/hello-world/hello-cyclic/job.py index 9aa4807dde..a7019c2297 100644 --- a/examples/hello-world/hello-cyclic/job.py +++ b/examples/hello-world/hello-cyclic/job.py @@ -24,6 +24,7 @@ recipe = CyclicRecipe( num_rounds=num_rounds, + min_clients=n_clients, # Model can be specified as class instance or dict config: model=Net(), # Alternative: model={"class_path": "model.Net", "args": {}}, diff --git a/examples/hello-world/hello-numpy/job.py b/examples/hello-world/hello-numpy/job.py index 6ebf99a127..975455adb3 100644 --- a/examples/hello-world/hello-numpy/job.py +++ b/examples/hello-world/hello-numpy/job.py @@ -17,8 +17,8 @@ """ import argparse -from nvflare.apis.dxo import DataKind from nvflare.app_common.np.recipes.fedavg import NumpyFedAvgRecipe +from nvflare.client.config import TransferType from nvflare.recipe import SimEnv, add_experiment_tracking @@ -57,7 +57,7 @@ def main(): train_script="client.py", train_args=train_args, launch_external_process=launch_process, - aggregator_data_kind=DataKind.WEIGHTS if args.update_type == "full" else DataKind.WEIGHT_DIFF, + params_transfer_type=TransferType.FULL if args.update_type == "full" else TransferType.DIFF, ) add_experiment_tracking(recipe, tracking_type="tensorboard") if args.export_config: diff --git a/examples/tutorials/job_recipe.ipynb b/examples/tutorials/job_recipe.ipynb index c44c50e92f..585ea13d11 100644 --- a/examples/tutorials/job_recipe.ipynb +++ b/examples/tutorials/job_recipe.ipynb @@ -141,29 +141,7 @@ "cell_type": "markdown", "id": "b94e258c", "metadata": {}, - "source": [ - "## Pre-trained Checkpoint Path\n", - "\n", - "Use `initial_ckpt` to specify a path to pre-trained model weights:\n", - "\n", - "```python\n", - "recipe = FedAvgRecipe(\n", - " model=SimpleNetwork(),\n", - " initial_ckpt=\"/data/models/pretrained_model.pt\", # Absolute path\n", - " ...\n", - ")\n", - "```\n", - "\n", - "
\n", - "Checkpoint Path Requirements:\n", - "\n", - "
" - ] + "source": "## Pre-trained Checkpoint Path\n\nUse `initial_ckpt` to specify a path to pre-trained model weights:\n\n```python\nrecipe = FedAvgRecipe(\n model=SimpleNetwork(),\n initial_ckpt=\"/data/models/pretrained_model.pt\", # Absolute path (server-side)\n ...\n)\n```\n\n
\nCheckpoint Path Options:\n\n
" }, { "cell_type": "code", @@ -256,25 +234,7 @@ "cell_type": "markdown", "id": "pocenv-section", "metadata": {}, - "source": [ - "### PocEnv – Proof-of-Concept Environment\n", - "\n", - "Runs server and clients as **separate processes** on the same machine. This simulates real-world deployment within a single node, with server and clients are running in different processes. More realistic than `SimEnv`, but still lightweight enough for a single node.\n", - "\n", - "Best suited for:\n", - "* Demonstrations\n", - "* Small-scale validation before production deployment\n", - "* Debugging orchestration logic\n", - "\n", - "**Arguments:**\n", - "* `num_clients` (int, optional): Number of clients to use in POC mode. Defaults to 2.\n", - "* `clients` (List[str], optional): List of client names. If None, will generate site-1, site-2, etc.\n", - "* `gpu_ids` (List[int], optional): List of GPU IDs to assign to clients. If None, uses CPU only.\n", - "* `auto_stop` (bool, optional): Whether to automatically stop POC services after job completion.\n", - "* `use_he` (bool, optional): Whether to use HE. Defaults to False.\n", - "* `docker_image` (str, optional): Docker image to use for POC.\n", - "* `project_conf_path` (str, optional): Path to the project configuration file." - ] + "source": "### PocEnv – Proof-of-Concept Environment\n\nRuns server and clients as **separate processes** on the same machine. This simulates real-world deployment within a single node, with server and clients are running in different processes. More realistic than `SimEnv`, but still lightweight enough for a single node.\n\nBest suited for:\n* Demonstrations\n* Small-scale validation before production deployment\n* Debugging orchestration logic\n\n**Arguments:**\n* `num_clients` (int, optional): Number of clients to use in POC mode. Defaults to 2.\n* `clients` (List[str], optional): List of client names. If None, will generate site-1, site-2, etc.\n* `gpu_ids` (List[int], optional): List of GPU IDs to assign to clients. If None, uses CPU only.\n* `use_he` (bool, optional): Whether to use HE. Defaults to False.\n* `docker_image` (str, optional): Docker image to use for POC.\n* `project_conf_path` (str, optional): Path to the project configuration file." }, { "cell_type": "markdown", @@ -331,21 +291,7 @@ "cell_type": "markdown", "id": "prodenv-section", "metadata": {}, - "source": [ - "### ProdEnv – Production Environment\n", - "\n", - "We assume the system of a server and clients is up and running across **multiple machines and sites**. Uses secure communication channels and real-world NVFLARE deployment infrastructure. The ProdEnv will utilize the admin's startup package to communicate with an existing NVFlare system to execute and monitor that job execution.\n", - "\n", - "Best suited for:\n", - "* Enterprise federated learning deployments\n", - "* Multi-institution collaborations\n", - "* Production-scale workloads\n", - "\n", - "**Arguments:**\n", - "* `startup_kit_location` (str): the directory that contains the startup kit of the admin (generated by nvflare provisioning)\n", - "* `login_timeout` (float): timeout value for the admin to login to the system\n", - "* `monitor_job_duration` (int): duration to monitor the job execution, None means no monitoring at all" - ] + "source": "### ProdEnv – Production Environment\n\nWe assume the system of a server and clients is up and running across **multiple machines and sites**. Uses secure communication channels and real-world NVFLARE deployment infrastructure. The ProdEnv will utilize the admin's startup package to communicate with an existing NVFlare system to execute and monitor that job execution.\n\nBest suited for:\n* Enterprise federated learning deployments\n* Multi-institution collaborations\n* Production-scale workloads\n\n**Arguments:**\n* `startup_kit_location` (str): the directory that contains the startup kit of the admin (generated by nvflare provisioning)\n* `login_timeout` (float): timeout value for the admin to login to the system" }, { "cell_type": "markdown", @@ -522,4 +468,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/nvflare/app_opt/pt/recipes/cyclic.py b/nvflare/app_opt/pt/recipes/cyclic.py index 6145da312f..2f335743da 100644 --- a/nvflare/app_opt/pt/recipes/cyclic.py +++ b/nvflare/app_opt/pt/recipes/cyclic.py @@ -18,6 +18,7 @@ from nvflare.client.config import ExchangeFormat, TransferType from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe +from nvflare.recipe.utils import extract_persistor_id class CyclicRecipe(BaseCyclicRecipe): @@ -103,11 +104,17 @@ def _setup_model_and_persistor(self, job) -> str: # If model is already a PTModel wrapper (user passed PTModel directly), use as-is if hasattr(self.model, "add_to_fed_job"): result = job.to_server(self.model, id="persistor") - return result["persistor_id"] + return extract_persistor_id(result) - from nvflare.recipe.utils import prepare_initial_ckpt + from nvflare.recipe.utils import resolve_initial_ckpt - ckpt_path = prepare_initial_ckpt(self._pt_initial_ckpt, job) - pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path) + ckpt_path = resolve_initial_ckpt(self._pt_initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job) + if self.model is None and ckpt_path: + raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.") + if self.model is None: + return "" + + allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH + pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path, allow_numpy_conversion=allow_numpy_conversion) result = job.to_server(pt_model, id="persistor") - return result["persistor_id"] + return extract_persistor_id(result) diff --git a/nvflare/app_opt/pt/recipes/fedavg.py b/nvflare/app_opt/pt/recipes/fedavg.py index dca486cab8..d81cecbd82 100644 --- a/nvflare/app_opt/pt/recipes/fedavg.py +++ b/nvflare/app_opt/pt/recipes/fedavg.py @@ -162,21 +162,37 @@ def __init__( def _setup_model_and_persistor(self, job) -> str: """Override to handle PyTorch-specific model setup.""" - if self.model is not None or self.initial_ckpt is not None: - from nvflare.app_opt.pt.job_config.model import PTModel - from nvflare.recipe.utils import prepare_initial_ckpt - - # Disable numpy conversion when using tensor format to keep PyTorch tensors - allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH - - ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job) - pt_model = PTModel( - model=self.model, - initial_ckpt=ckpt_path, - persistor=self.model_persistor, - locator=self._pt_model_locator, - allow_numpy_conversion=allow_numpy_conversion, - ) - job.comp_ids.update(job.to_server(pt_model)) - return job.comp_ids.get("persistor_id", "") - return "" + from nvflare.app_opt.pt.job_config.model import PTModel + from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor + + persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor) + if persistor_id: + if hasattr(job, "comp_ids"): + job.comp_ids["persistor_id"] = persistor_id + if self._pt_model_locator is not None: + locator_id = job.to_server(self._pt_model_locator, id="locator") + if isinstance(locator_id, str) and locator_id: + job.comp_ids["locator_id"] = locator_id + return persistor_id + + ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job) + if self.model is None and ckpt_path: + raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.") + if self.model is None: + return "" + + # Disable numpy conversion when using tensor format to keep PyTorch tensors. + allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH + pt_model = PTModel( + model=self.model, + initial_ckpt=ckpt_path, + locator=self._pt_model_locator, + allow_numpy_conversion=allow_numpy_conversion, + ) + result = job.to_server(pt_model, id="persistor") + if isinstance(result, dict) and hasattr(job, "comp_ids"): + job.comp_ids.update(result) + persistor_id = extract_persistor_id(result) + if persistor_id and hasattr(job, "comp_ids"): + job.comp_ids.setdefault("persistor_id", persistor_id) + return persistor_id diff --git a/nvflare/app_opt/tf/recipes/cyclic.py b/nvflare/app_opt/tf/recipes/cyclic.py index bc531295fa..49f8ce3505 100644 --- a/nvflare/app_opt/tf/recipes/cyclic.py +++ b/nvflare/app_opt/tf/recipes/cyclic.py @@ -18,6 +18,7 @@ from nvflare.client.config import ExchangeFormat, TransferType from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe +from nvflare.recipe.utils import extract_persistor_id class CyclicRecipe(BaseCyclicRecipe): @@ -103,11 +104,11 @@ def _setup_model_and_persistor(self, job) -> str: # If model is already a TFModel wrapper (user passed TFModel directly), use as-is if hasattr(self.model, "add_to_fed_job"): result = job.to_server(self.model, id="persistor") - return result["persistor_id"] + return extract_persistor_id(result) - from nvflare.recipe.utils import prepare_initial_ckpt + from nvflare.recipe.utils import resolve_initial_ckpt - ckpt_path = prepare_initial_ckpt(self._tf_initial_ckpt, job) + ckpt_path = resolve_initial_ckpt(self._tf_initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job) tf_model = TFModel(model=self.model, initial_ckpt=ckpt_path) result = job.to_server(tf_model, id="persistor") - return result["persistor_id"] + return extract_persistor_id(result) diff --git a/nvflare/app_opt/tf/recipes/fedavg.py b/nvflare/app_opt/tf/recipes/fedavg.py index 4b632d7f72..674236c204 100644 --- a/nvflare/app_opt/tf/recipes/fedavg.py +++ b/nvflare/app_opt/tf/recipes/fedavg.py @@ -145,16 +145,16 @@ def __init__( def _setup_model_and_persistor(self, job) -> str: """Override to handle TensorFlow-specific model setup.""" - if self.model is not None or self.initial_ckpt is not None: - from nvflare.app_opt.tf.job_config.model import TFModel - from nvflare.recipe.utils import prepare_initial_ckpt - - ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job) - tf_model = TFModel( - model=self.model, - initial_ckpt=ckpt_path, - persistor=self.model_persistor, - ) - job.comp_ids["persistor_id"] = job.to_server(tf_model) - return job.comp_ids.get("persistor_id", "") - return "" + from nvflare.app_opt.tf.job_config.model import TFModel + from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor + + persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor) + if persistor_id: + return persistor_id + + ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job) + if self.model is None and not ckpt_path: + return "" + + tf_model = TFModel(model=self.model, initial_ckpt=ckpt_path) + return extract_persistor_id(job.to_server(tf_model, id="persistor")) diff --git a/nvflare/recipe/__init__.py b/nvflare/recipe/__init__.py index ecbe7ad1b1..d3590becdb 100644 --- a/nvflare/recipe/__init__.py +++ b/nvflare/recipe/__init__.py @@ -17,6 +17,6 @@ from .prod_env import ProdEnv from .run import Run from .sim_env import SimEnv -from .utils import add_experiment_tracking +from .utils import add_cross_site_evaluation, add_experiment_tracking -__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking", "FedAvgRecipe"] +__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking", "add_cross_site_evaluation", "FedAvgRecipe"] diff --git a/nvflare/recipe/cyclic.py b/nvflare/recipe/cyclic.py index e2406c38fb..e93ea97e0b 100644 --- a/nvflare/recipe/cyclic.py +++ b/nvflare/recipe/cyclic.py @@ -141,6 +141,7 @@ def __init__( if isinstance(self.model, dict): self.model = recipe_model_to_job_model(self.model) + self.min_clients = v.min_clients self.num_rounds = v.num_rounds self.train_script = v.train_script self.train_args = v.train_args @@ -162,13 +163,16 @@ def __init__( # Setup model persistor first - subclasses override for framework-specific handling persistor_id = self._setup_model_and_persistor(job) - # Use returned persistor_id or default to "persistor" if not persistor_id: - persistor_id = "persistor" + raise ValueError( + "Unable to configure a model persistor for CyclicRecipe. " + "Provide a supported model/framework combination (PyTorch or TensorFlow), " + "or pass a framework-specific model wrapper with add_to_fed_job()." + ) # Define the controller workflow and send to server controller = CyclicController( - num_rounds=num_rounds, + num_rounds=self.num_rounds, task_assignment_timeout=10, persistor_id=persistor_id, shareable_generator_id="shareable_generator", @@ -194,25 +198,39 @@ def __init__( super().__init__(job) def _setup_model_and_persistor(self, job) -> str: - """Setup framework-specific model components and persistor. + """Setup model wrapper persistor. - Handles PTModel/TFModel wrappers passed by framework-specific subclasses. + Handles the following model inputs: + - framework-specific wrappers (objects with ``add_to_fed_job``) Returns: str: The persistor_id to be used by the controller. """ - if self.model is None: - return "" + from nvflare.recipe.utils import extract_persistor_id, prepare_initial_ckpt - # Check if model is a model wrapper (PTModel, TFModel) + ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job) + + # Model wrapper path (PTModel/TFModel or custom wrapper) if hasattr(self.model, "add_to_fed_job"): - # It's a model wrapper - use its add_to_fed_job method + if ckpt_path: + if not hasattr(self.model, "initial_ckpt"): + raise ValueError( + f"initial_ckpt is provided, but model wrapper {type(self.model).__name__} " + "does not support 'initial_ckpt'." + ) + existing_ckpt = getattr(self.model, "initial_ckpt", None) + if existing_ckpt and existing_ckpt != ckpt_path: + raise ValueError( + f"Conflicting checkpoint values: model wrapper has initial_ckpt={existing_ckpt}, " + f"but recipe initial_ckpt={ckpt_path}." + ) + setattr(self.model, "initial_ckpt", ckpt_path) + result = job.to_server(self.model, id="persistor") - return result["persistor_id"] + return extract_persistor_id(result) - # Unknown model type - raise TypeError( - f"Unsupported model type: {type(self.model).__name__}. " - f"Use a framework-specific recipe (PTCyclicRecipe, TFCyclicRecipe, etc.) " - f"or wrap your model in PTModel/TFModel." + raise ValueError( + f"Unsupported framework '{self.framework}' for base CyclicRecipe model persistence. " + "Use a framework-specific CyclicRecipe subclass, or pass a framework-specific " + "model wrapper with add_to_fed_job()." ) diff --git a/nvflare/recipe/fedavg.py b/nvflare/recipe/fedavg.py index fd59424413..ad1d372af6 100644 --- a/nvflare/recipe/fedavg.py +++ b/nvflare/recipe/fedavg.py @@ -17,6 +17,7 @@ from pydantic import BaseModel from nvflare.apis.dxo import DataKind +from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME from nvflare.app_common.abstract.aggregator import Aggregator from nvflare.app_common.abstract.model_persistor import ModelPersistor from nvflare.app_common.workflows.fedavg import FedAvg @@ -236,6 +237,7 @@ def __init__( self.params_transfer_type = v.params_transfer_type self.model_persistor = v.model_persistor self.per_site_config = v.per_site_config + self._validate_per_site_config(self.per_site_config) self.launch_once = v.launch_once self.shutdown_timeout = v.shutdown_timeout self.key_metric = v.key_metric @@ -274,6 +276,12 @@ def __init__( has_persistor = persistor_id != "" model_params = None if has_persistor else self._get_model_params() + if not has_persistor and model_params is None: + raise ValueError( + "Unable to configure a model source for FedAvgRecipe: no persistor and no model parameters. " + "Use a framework-specific recipe for checkpoint-only initialization, or provide model/model_persistor." + ) + # Prepare aggregator for controller - must be ModelAggregator for FLModel-based aggregation model_aggregator = self._get_model_aggregator() @@ -311,10 +319,18 @@ def __init__( if site_config.get("launch_external_process") is not None else self.launch_external_process ) - command = site_config.get("command") or self.command - framework = site_config.get("framework") or self.framework - expected_format = site_config.get("server_expected_format") or self.server_expected_format - transfer_type = site_config.get("params_transfer_type") or self.params_transfer_type + command = site_config.get("command") if site_config.get("command") is not None else self.command + framework = site_config.get("framework") if site_config.get("framework") is not None else self.framework + expected_format = ( + site_config.get("server_expected_format") + if site_config.get("server_expected_format") is not None + else self.server_expected_format + ) + transfer_type = ( + site_config.get("params_transfer_type") + if site_config.get("params_transfer_type") is not None + else self.params_transfer_type + ) launch_once = ( site_config.get("launch_once") if site_config.get("launch_once") is not None else self.launch_once ) @@ -352,6 +368,23 @@ def __init__( Recipe.__init__(self, job) + @staticmethod + def _validate_per_site_config(per_site_config: Optional[Dict[str, Dict]]) -> None: + if per_site_config is None: + return + + reserved_targets = {SERVER_SITE_NAME, ALL_SITES} + for site_name, site_config in per_site_config.items(): + if not isinstance(site_name, str): + raise ValueError(f"per_site_config key must be str, got {type(site_name).__name__}") + if site_name in reserved_targets: + raise ValueError( + f"'{site_name}' is a reserved target name and cannot be used in per_site_config. " + f"Reserved names: {sorted(reserved_targets)}" + ) + if not isinstance(site_config, dict): + raise ValueError(f"per_site_config['{site_name}'] must be a dict, got {type(site_config).__name__}") + def _get_model_params(self) -> Optional[Dict]: """Convert model to dict of params. @@ -406,14 +439,14 @@ def _get_model_aggregator(self): return None def _setup_model_and_persistor(self, job: BaseFedJob) -> str: - """Setup framework-specific model components and persistor. + """Setup generic custom persistor only. - Base implementation handles custom persistor. Framework-specific subclasses - should override this to use PTModel/TFModel for their model types. + Framework-specific recipes (PT/TF/NumPy) override this method to build and + register their model wrappers and default persistors. Returns: str: The persistor_id to be used by the controller. """ - if self.model_persistor is not None: - return job.to_server(self.model_persistor, id="persistor") - return "" + from nvflare.recipe.utils import setup_custom_persistor + + return setup_custom_persistor(job=job, model_persistor=self.model_persistor) diff --git a/nvflare/recipe/poc_env.py b/nvflare/recipe/poc_env.py index ac6a67fadf..5b7c12a42b 100644 --- a/nvflare/recipe/poc_env.py +++ b/nvflare/recipe/poc_env.py @@ -21,7 +21,7 @@ from nvflare.job_config.api import FedJob from nvflare.recipe.spec import ExecEnv -from nvflare.recipe.utils import _collect_non_local_scripts +from nvflare.recipe.utils import collect_non_local_scripts from nvflare.tool.poc.poc_commands import ( _clean_poc, _start_poc, @@ -64,7 +64,7 @@ def check_client_configuration(self): ) # Check if num_clients is valid when clients is None - if self.clients is None and self.num_clients <= 0: + if self.clients is None and (self.num_clients is None or self.num_clients <= 0): raise ValueError("num_clients must be greater than 0") return self @@ -135,7 +135,7 @@ def deploy(self, job: FedJob): str: Job ID or deployment result. """ # Validate scripts exist locally for POC - non_local_scripts = _collect_non_local_scripts(job) + non_local_scripts = collect_non_local_scripts(job) if non_local_scripts: raise ValueError( f"The following scripts do not exist locally: {non_local_scripts}. " diff --git a/nvflare/recipe/prod_env.py b/nvflare/recipe/prod_env.py index de7dc6f9f9..485604e719 100644 --- a/nvflare/recipe/prod_env.py +++ b/nvflare/recipe/prod_env.py @@ -20,7 +20,7 @@ from nvflare.job_config.api import FedJob from nvflare.recipe.spec import ExecEnv -from nvflare.recipe.utils import _collect_non_local_scripts +from nvflare.recipe.utils import collect_non_local_scripts from .session_mgr import SessionManager @@ -85,7 +85,7 @@ def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]: def deploy(self, job: FedJob): """Deploy a job using SessionManager.""" # Log warnings for non-local scripts (assumed pre-installed on production) - non_local_scripts = _collect_non_local_scripts(job) + non_local_scripts = collect_non_local_scripts(job) for script in non_local_scripts: logger.warning( f"Script '{script}' not found locally. " f"Assuming it is pre-installed on the production system." diff --git a/nvflare/recipe/sim_env.py b/nvflare/recipe/sim_env.py index 1fe5c18be7..71e82ea6b3 100644 --- a/nvflare/recipe/sim_env.py +++ b/nvflare/recipe/sim_env.py @@ -20,7 +20,7 @@ from nvflare.job_config.api import FedJob from .spec import ExecEnv -from .utils import _collect_non_local_scripts +from .utils import collect_non_local_scripts WORKSPACE_ROOT = "/tmp/nvflare/simulation" @@ -87,8 +87,9 @@ def __init__( workspace_root=workspace_root, ) - self.num_clients = v.num_clients - self.num_threads = v.num_threads if v.num_threads is not None else v.num_clients + resolved_num_clients = v.num_clients if v.num_clients > 0 else len(v.clients or []) + self.num_clients = resolved_num_clients + self.num_threads = v.num_threads if v.num_threads is not None else resolved_num_clients self.gpu_config = v.gpu_config self.log_config = v.log_config self.clients = v.clients @@ -96,7 +97,7 @@ def __init__( def deploy(self, job: FedJob): # Validate scripts exist locally for simulation - non_local_scripts = _collect_non_local_scripts(job) + non_local_scripts = collect_non_local_scripts(job) if non_local_scripts: raise ValueError( f"The following scripts do not exist locally: {non_local_scripts}. " diff --git a/nvflare/recipe/spec.py b/nvflare/recipe/spec.py index 52b290a56f..27f5e023ef 100644 --- a/nvflare/recipe/spec.py +++ b/nvflare/recipe/spec.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod +from contextlib import contextmanager from typing import Dict, List, Optional, Union from nvflare.apis.filter import Filter @@ -30,7 +31,7 @@ def __init__(self, extra: Optional[dict] = None): Args: extra: a dict of extra properties """ - if not extra: + if extra is None: extra = {} if not isinstance(extra, dict): raise ValueError(f"extra must be dict but got {type(extra)}") @@ -114,6 +115,47 @@ def process_env(self, env: ExecEnv): """ pass + def _snapshot_additional_params(self) -> Dict[str, Dict]: + snapshot = {} + deploy_map = getattr(self.job, "_deploy_map", {}) + for target, app in deploy_map.items(): + app_config = getattr(app, "app_config", None) + if app_config is None: + continue + params = getattr(app_config, "additional_params", None) + if isinstance(params, dict): + snapshot[target] = dict(params) + return snapshot + + def _restore_additional_params(self, snapshot: Dict[str, Dict]) -> None: + deploy_map = getattr(self.job, "_deploy_map", {}) + for target, app in deploy_map.items(): + app_config = getattr(app, "app_config", None) + if app_config is None: + continue + params = getattr(app_config, "additional_params", None) + if isinstance(params, dict): + original = snapshot.get(target, {}) + params.clear() + params.update(original) + + @contextmanager + def _temporary_exec_params(self, server_exec_params: dict = None, client_exec_params: dict = None): + params_snapshot = None + if server_exec_params or client_exec_params: + params_snapshot = self._snapshot_additional_params() + + try: + if server_exec_params: + self.job.to_server(server_exec_params) + + if client_exec_params: + self._add_to_client_apps(client_exec_params) + yield + finally: + if params_snapshot is not None: + self._restore_additional_params(params_snapshot) + def _add_to_client_apps(self, obj, clients: Optional[List[str]] = None, **kwargs): """Add an object to client apps, preserving existing per-site structure. @@ -328,16 +370,10 @@ def export( Returns: None """ - if server_exec_params: - self.job.to_server(server_exec_params) - - if client_exec_params: - self._add_to_client_apps(client_exec_params) - - if env: - self.process_env(env) - - self.job.export_job(job_dir) + with self._temporary_exec_params(server_exec_params=server_exec_params, client_exec_params=client_exec_params): + if env: + self.process_env(env) + self.job.export_job(job_dir) def execute( self, env: ExecEnv, server_exec_params: Optional[dict] = None, client_exec_params: Optional[dict] = None @@ -352,15 +388,10 @@ def execute( Returns: Run to get job ID and execution results """ - if server_exec_params: - self.job.to_server(server_exec_params) - - if client_exec_params: - self._add_to_client_apps(client_exec_params) - - self.process_env(env) - job_id = env.deploy(self.job) - from nvflare.recipe.run import Run + with self._temporary_exec_params(server_exec_params=server_exec_params, client_exec_params=client_exec_params): + self.process_env(env) + job_id = env.deploy(self.job) + from nvflare.recipe.run import Run - run = Run(env, job_id) - return run + run = Run(env, job_id) + return run diff --git a/nvflare/recipe/utils.py b/nvflare/recipe/utils.py index 5bdbccf59f..ab041c833e 100644 --- a/nvflare/recipe/utils.py +++ b/nvflare/recipe/utils.py @@ -59,6 +59,27 @@ } +def _has_cross_site_eval_workflow(job: FedJob) -> bool: + """Check if CrossSiteModelEval workflow is already configured on server.""" + from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval + + deploy_map = getattr(job, "_deploy_map", {}) + server_app = deploy_map.get("server") + if not server_app or not hasattr(server_app, "app_config"): + return False + + workflows = getattr(server_app.app_config, "workflows", []) + for w in workflows: + # Server stores workflow definitions as wrapper objects (e.g. WorkFlow) + # with the actual controller on `controller`. + if isinstance(w, CrossSiteModelEval): + return True + controller = getattr(w, "controller", None) + if controller is not None and isinstance(controller, CrossSiteModelEval): + return True + return False + + def add_experiment_tracking( recipe: Recipe, tracking_type: str, @@ -250,8 +271,10 @@ def add_cross_site_evaluation( from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval from nvflare.job_config.script_runner import FrameworkType - # Idempotency check: prevent multiple calls on the same recipe - if hasattr(recipe, "_cse_added") and recipe._cse_added: + # Idempotency check: prevent multiple calls on the same recipe. + # Keep the explicit flag fast-path, but also verify server workflow state so + # protection remains effective even if dynamic attributes are lost. + if getattr(recipe, "_cse_added", False) or _has_cross_site_eval_workflow(recipe.job): name = recipe.name if hasattr(recipe, "name") else "cross-site-evaluation job" raise RuntimeError( f"Cross-site evaluation has already been added to recipe '{name}'. " @@ -399,7 +422,7 @@ def _has_task_executor(job, task_name: str) -> bool: return False -def _collect_non_local_scripts(job: FedJob) -> List[str]: +def collect_non_local_scripts(job: FedJob) -> List[str]: """Collect scripts that don't exist locally. This utility function is used by ExecEnv subclasses to validate script resources @@ -498,6 +521,27 @@ def prepare_initial_ckpt(initial_ckpt: Optional[str], job) -> Optional[str]: return os.path.basename(initial_ckpt) +def extract_persistor_id(result: Any) -> str: + if isinstance(result, dict): + persistor_id = result.get("persistor_id", "") + return persistor_id if isinstance(persistor_id, str) else "" + if isinstance(result, str): + return result + return "" + + +def resolve_initial_ckpt(initial_ckpt: Optional[str], prepared_initial_ckpt: Optional[str], job) -> Optional[str]: + if prepared_initial_ckpt is not None: + return prepared_initial_ckpt + return prepare_initial_ckpt(initial_ckpt, job) + + +def setup_custom_persistor(*, job, model_persistor=None) -> str: + if model_persistor is None: + return "" + return extract_persistor_id(job.to_server(model_persistor, id="persistor")) + + def validate_dict_model_config(model: Any) -> None: """Validate recipe dict model config structure. diff --git a/tests/unit_test/recipe/cyclic_recipe_test.py b/tests/unit_test/recipe/cyclic_recipe_test.py index 8374d1c69e..894355737a 100644 --- a/tests/unit_test/recipe/cyclic_recipe_test.py +++ b/tests/unit_test/recipe/cyclic_recipe_test.py @@ -19,6 +19,8 @@ import pytest import torch.nn as nn +from nvflare.app_opt.pt.job_config.model import PTModel +from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe @@ -58,15 +60,10 @@ def base_recipe_params(): class TestBaseCyclicRecipe: - """Test cases for base CyclicRecipe class. - - Note: Base CyclicRecipe doesn't directly support nn.Module or dict config. - Use framework-specific recipes (PTCyclicRecipe, TFCyclicRecipe) for those. - """ + """Test cases for base CyclicRecipe class.""" def test_initial_ckpt_must_exist_for_relative_path(self): """Test that non-existent relative paths are rejected (no mock - validation must run).""" - # Don't use mock_file_system fixture - we need real validation with pytest.raises(ValueError, match="does not exist locally"): BaseCyclicRecipe( name="test_relative", @@ -87,6 +84,76 @@ def test_requires_model_or_checkpoint(self, base_recipe_params): **base_recipe_params, ) + def test_rejects_non_wrapper_model_for_base_recipe(self, mock_file_system, base_recipe_params): + """Base CyclicRecipe no longer owns PT/TF model persistence for raw model inputs.""" + with pytest.raises(ValueError, match="Use a framework-specific CyclicRecipe subclass"): + BaseCyclicRecipe( + name="test_base_pt_dict", + model={"class_path": "torch.nn.Linear", "args": {"in_features": 10, "out_features": 2}}, + framework=FrameworkType.PYTORCH, + **base_recipe_params, + ) + + def test_rejects_pytorch_checkpoint_without_model(self, mock_file_system, base_recipe_params): + """Base recipe requires framework-specific subclass for PT checkpoint-only setup.""" + with pytest.raises(ValueError, match="Use a framework-specific CyclicRecipe subclass"): + BaseCyclicRecipe( + name="test_base_pt_ckpt_no_model", + model=None, + initial_ckpt="/abs/path/to/model.pt", + framework=FrameworkType.PYTORCH, + **base_recipe_params, + ) + + def test_rejects_ckpt_only_for_default_framework(self, mock_file_system, base_recipe_params): + """Fail fast for ckpt-only usage when no supported framework/wrapper is provided.""" + with pytest.raises(ValueError, match="Unsupported framework"): + BaseCyclicRecipe( + name="test_ckpt_only_default_framework", + model=None, + initial_ckpt="/abs/path/to/model.pt", + **base_recipe_params, + ) + + def test_applies_initial_ckpt_to_wrapper_model(self, mock_file_system, base_recipe_params, simple_model): + """When wrapper model is used, recipe-level initial_ckpt should be applied to persistor.""" + recipe = BaseCyclicRecipe( + name="test_wrapper_ckpt", + model=PTModel(model=simple_model), + initial_ckpt="/abs/path/to/model.pt", + framework=FrameworkType.PYTORCH, + **base_recipe_params, + ) + + server_app = recipe.job._deploy_map.get("server") + persistor = server_app.app_config.components.get("persistor") + assert persistor is not None + assert getattr(persistor, "source_ckpt_file_full_name", None) == "/abs/path/to/model.pt" + + def test_rejects_unsupported_framework_without_wrapper(self, mock_file_system, base_recipe_params): + """Fail fast for unsupported base framework/model persistence combinations.""" + with pytest.raises(ValueError, match="Unsupported framework"): + BaseCyclicRecipe( + name="test_unsupported_framework", + model={"class_path": "torch.nn.Linear", "args": {"in_features": 10, "out_features": 2}}, + framework=FrameworkType.NUMPY, + **base_recipe_params, + ) + + +class TestBaseCyclicRecipeAttributes: + """Test that CyclicRecipe stores validated attributes correctly.""" + + def test_min_clients_attribute(self, mock_file_system, base_recipe_params, simple_model): + """min_clients must be accessible as an instance attribute after construction.""" + recipe = BaseCyclicRecipe( + name="test_min_clients", + model=PTModel(model=simple_model), + framework=FrameworkType.PYTORCH, + **base_recipe_params, + ) + assert recipe.min_clients == base_recipe_params["min_clients"] + class TestPTCyclicRecipe: """Test cases for PyTorch CyclicRecipe.""" @@ -105,6 +172,22 @@ def test_pt_cyclic_initial_ckpt(self, mock_file_system, base_recipe_params, simp assert recipe.name == "test_pt_cyclic" assert recipe.job is not None + def test_pt_cyclic_with_ptmodel_wrapper_returns_persistor_id( + self, mock_file_system, base_recipe_params, simple_model + ): + """PTModel wrapper path must correctly extract persistor_id from dict return.""" + from nvflare.app_opt.pt.recipes.cyclic import CyclicRecipe as PTCyclicRecipe + + recipe = PTCyclicRecipe( + name="test_pt_wrapper", + model=PTModel(model=simple_model), + **base_recipe_params, + ) + + server_app = recipe.job._deploy_map.get("server") + assert server_app is not None + assert "persistor" in server_app.app_config.components + class TestTFCyclicRecipe: """Test cases for TensorFlow CyclicRecipe.""" diff --git a/tests/unit_test/recipe/fedavg_recipe_test.py b/tests/unit_test/recipe/fedavg_recipe_test.py index 6900d189a2..930bc97ab5 100644 --- a/tests/unit_test/recipe/fedavg_recipe_test.py +++ b/tests/unit_test/recipe/fedavg_recipe_test.py @@ -17,8 +17,10 @@ import pytest import torch.nn as nn -from nvflare.apis.job_def import SERVER_SITE_NAME +from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.abstract.model_locator import ModelLocator +from nvflare.app_common.abstract.model_persistor import ModelPersistor from nvflare.app_common.aggregators.model_aggregator import ModelAggregator from nvflare.app_common.np.recipes import NumpyFedAvgRecipe from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector @@ -79,6 +81,26 @@ def __init__(self): pass +class DummyPersistor(ModelPersistor): + """Minimal ModelPersistor used to test custom persistor wiring.""" + + def load_model(self, fl_ctx): + return {} + + def save_model(self, model, fl_ctx): + return None + + +class DummyLocator(ModelLocator): + """Minimal ModelLocator used to test explicit locator registration.""" + + def get_model_names(self, fl_ctx): + return [] + + def locate_model(self, model_name, fl_ctx): + return None + + @pytest.fixture def mock_file_system(): """Mock file system operations for all tests.""" @@ -414,6 +436,81 @@ def test_invalid_aggregator_type_raises_validation_error(self, mock_file_system, **base_recipe_params, ) + def test_per_site_config_rejects_reserved_server_target(self, mock_file_system, base_recipe_params, simple_model): + """Reserved target 'server' must not be allowed in per_site_config.""" + with pytest.raises(ValueError, match="reserved target name"): + FedAvgRecipe( + name="test_reserved_server_target", + model=simple_model, + per_site_config={"server": {}}, + **base_recipe_params, + ) + + def test_per_site_config_rejects_reserved_all_sites_target( + self, mock_file_system, base_recipe_params, simple_model + ): + """Reserved target '@ALL' must not be allowed in per_site_config.""" + with pytest.raises(ValueError, match="reserved target name"): + FedAvgRecipe( + name="test_reserved_all_sites_target", + model=simple_model, + per_site_config={ALL_SITES: {}}, + **base_recipe_params, + ) + + def test_per_site_empty_command_override_is_preserved(self, mock_file_system, base_recipe_params, simple_model): + """Falsy per-site override values (e.g. command='') must not be replaced by defaults.""" + recipe = FedAvgRecipe( + name="test_empty_command_override", + model=simple_model, + launch_external_process=True, + per_site_config={"site-1": {"command": ""}}, + **base_recipe_params, + ) + + site_app = recipe.job._deploy_map.get("site-1") + assert site_app is not None + launcher = site_app.app_config.components.get("launcher") + assert launcher is not None + assert "python3 -u" not in launcher._script + assert launcher._script.startswith(" custom/") + + def test_custom_model_persistor_tracks_persistor_id(self, mock_file_system, base_recipe_params, simple_model): + """Custom PT persistor path should persist comp_ids['persistor_id'] for later workflows.""" + recipe = FedAvgRecipe( + name="test_custom_persistor_comp_id", + model=simple_model, + model_persistor=DummyPersistor(), + **base_recipe_params, + ) + + persistor_id = recipe.job.comp_ids.get("persistor_id", "") + assert persistor_id + assert "locator_id" not in recipe.job.comp_ids + server_app = recipe.job._deploy_map.get(SERVER_SITE_NAME) + assert server_app is not None + assert persistor_id in server_app.app_config.components + + def test_custom_model_persistor_with_locator_registers_locator( + self, mock_file_system, base_recipe_params, simple_model + ): + """If custom model_locator is provided, it should be registered even on custom persistor path.""" + locator = DummyLocator() + recipe = FedAvgRecipe( + name="test_custom_persistor_with_locator", + model=simple_model, + model_persistor=DummyPersistor(), + model_locator=locator, + **base_recipe_params, + ) + + assert recipe.job.comp_ids.get("persistor_id", "") + locator_id = recipe.job.comp_ids.get("locator_id", "") + assert locator_id + server_app = recipe.job._deploy_map.get(SERVER_SITE_NAME) + assert server_app is not None + assert server_app.app_config.components.get(locator_id) is locator + def test_dict_config_missing_path_raises_error(self, mock_file_system, base_recipe_params): """Test that dict config without 'class_path' key raises error.""" with pytest.raises(ValueError, match="must have 'class_path' key"): @@ -452,7 +549,7 @@ def test_initial_ckpt_with_none_model_not_allowed_for_pt(self, mock_file_system, """Test that PT FedAvg rejects initial_ckpt with None model (PT needs architecture).""" # PyTorch requires model architecture even when loading from checkpoint # TensorFlow can load full models, but PT cannot - with pytest.raises(ValueError, match="Unable to add None to job"): + with pytest.raises(ValueError, match="FrameworkType.PYTORCH requires 'model' when using initial_ckpt"): FedAvgRecipe( name="test_ckpt_no_model", model=None, diff --git a/tests/unit_test/recipe/poc_env_test.py b/tests/unit_test/recipe/poc_env_test.py index b710a7e1e4..ee0656cb8a 100644 --- a/tests/unit_test/recipe/poc_env_test.py +++ b/tests/unit_test/recipe/poc_env_test.py @@ -62,6 +62,12 @@ def test_poc_env_validation(): PocEnv(num_clients=3, clients=["site1", "site2"]) +def test_poc_env_none_num_clients_raises(): + """Test that PocEnv(num_clients=None) raises ValueError instead of crashing with TypeError.""" + with pytest.raises(ValueError, match="num_clients must be greater than 0"): + PocEnv(num_clients=None, clients=None) + + def test_poc_env_client_names(): """Test PocEnv client name generation and validation.""" # Test auto-generated client names (delegated to prepare_poc_provision) diff --git a/tests/unit_test/recipe/sim_env_test.py b/tests/unit_test/recipe/sim_env_test.py index d1758b3819..276a0cfcb0 100644 --- a/tests/unit_test/recipe/sim_env_test.py +++ b/tests/unit_test/recipe/sim_env_test.py @@ -34,3 +34,9 @@ def test_sim_env_validation(): # Test with empty clients list and zero num_clients (invalid) with pytest.raises(ValueError, match="Either 'num_clients' must be > 0 or 'clients' list must be provided"): SimEnv(num_clients=0, clients=[]) + + # BUG-3 regression: when clients list is provided and num_clients=0, + # env should derive client/thread counts from the list. + env = SimEnv(num_clients=0, clients=["client1", "client2", "client3"]) + assert env.num_clients == 3 + assert env.num_threads == 3 diff --git a/tests/unit_test/recipe/spec_test.py b/tests/unit_test/recipe/spec_test.py index 4904b9d246..13dbb4c15b 100644 --- a/tests/unit_test/recipe/spec_test.py +++ b/tests/unit_test/recipe/spec_test.py @@ -23,11 +23,11 @@ from nvflare.job_config.api import FedApp, FedJob from nvflare.job_config.fed_app_config import ClientAppConfig -from nvflare.recipe.utils import _collect_non_local_scripts +from nvflare.recipe.utils import collect_non_local_scripts class TestCollectNonLocalScriptsUtility: - """Test the _collect_non_local_scripts utility function.""" + """Test the collect_non_local_scripts utility function.""" def setup_method(self): self.job = FedJob(name="test_job", min_clients=1) @@ -36,7 +36,7 @@ def setup_method(self): def test_no_scripts_returns_empty_list(self): """Test with no scripts added.""" - result = _collect_non_local_scripts(self.job) + result = collect_non_local_scripts(self.job) assert result == [] def test_local_script_not_included(self): @@ -47,7 +47,7 @@ def test_local_script_not_included(self): try: self.client_app.add_external_script(temp_path) - result = _collect_non_local_scripts(self.job) + result = collect_non_local_scripts(self.job) assert result == [] finally: os.unlink(temp_path) @@ -57,7 +57,7 @@ def test_non_local_absolute_path_included(self): non_local_script = "/preinstalled/remote_script.py" self.client_app.add_external_script(non_local_script) - result = _collect_non_local_scripts(self.job) + result = collect_non_local_scripts(self.job) assert non_local_script in result def test_multiple_non_local_scripts(self): @@ -70,7 +70,7 @@ def test_multiple_non_local_scripts(self): for script in scripts: self.client_app.add_external_script(script) - result = _collect_non_local_scripts(self.job) + result = collect_non_local_scripts(self.job) for script in scripts: assert script in result @@ -86,7 +86,7 @@ def test_mixed_local_and_non_local_scripts(self): non_local_script = "/preinstalled/remote_script.py" self.client_app.add_external_script(non_local_script) - result = _collect_non_local_scripts(self.job) + result = collect_non_local_scripts(self.job) assert non_local_script in result assert local_script not in result finally: @@ -102,7 +102,7 @@ def test_multiple_apps(self): self.client_app.add_external_script(script1) client_app2.add_external_script(script2) - result = _collect_non_local_scripts(self.job) + result = collect_non_local_scripts(self.job) assert script1 in result assert script2 in result @@ -171,7 +171,6 @@ def temp_script(self): def test_add_server_config(self, temp_script): """Test add_server_config adds params to server app.""" - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -179,8 +178,7 @@ def test_add_server_config(self, temp_script): num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, # NUMPY can load from ckpt without model + model={"class_path": "model.DummyModel", "args": {}}, ) config = {"np_download_chunk_size": 2097152} @@ -193,7 +191,6 @@ def test_add_server_config(self, temp_script): def test_add_client_config(self, temp_script): """Test add_client_config applies to all clients and specific clients.""" from nvflare.apis.job_def import ALL_SITES - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe # Test all clients @@ -202,8 +199,7 @@ def test_add_client_config(self, temp_script): num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, ) config = {"timeout": 600} recipe.add_client_config(config) @@ -215,7 +211,6 @@ def test_add_client_config(self, temp_script): def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script): """Test add_client_file stores file paths in ext_scripts and dirs in ext_dirs.""" from nvflare.apis.job_def import ALL_SITES - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -223,8 +218,7 @@ def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script): num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, ) with tempfile.TemporaryDirectory() as temp_dir: @@ -239,7 +233,6 @@ def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script): def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp_script): """Test add_client_file keeps per-site topology and does not create ALL_SITES app.""" from nvflare.apis.job_def import ALL_SITES - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -247,8 +240,7 @@ def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, per_site_config={"site-1": {}, "site-2": {}}, ) @@ -264,7 +256,6 @@ def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp def test_add_client_file_with_specific_clients_only_updates_selected_sites(self, temp_script): """Test add_client_file(..., clients=[...]) only adds file to specified sites.""" - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -272,8 +263,7 @@ def test_add_client_file_with_specific_clients_only_updates_selected_sites(self, num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, per_site_config={"site-1": {}, "site-2": {}, "site-3": {}}, ) @@ -298,7 +288,6 @@ def test_add_client_file_with_specific_clients_only_updates_selected_sites(self, def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_script): """Test add_server_file stores file paths in ext_scripts and dirs in ext_dirs.""" - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -306,8 +295,7 @@ def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_scri num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, ) with tempfile.TemporaryDirectory() as temp_dir: @@ -321,7 +309,6 @@ def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_scri def test_config_in_generated_json(self, temp_script): """Test that configs appear in generated JSON files.""" - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -329,8 +316,7 @@ def test_config_in_generated_json(self, temp_script): num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, ) recipe.add_server_config({"server_param": 123}) @@ -352,7 +338,6 @@ def test_config_in_generated_json(self, temp_script): def test_config_type_error(self, temp_script): """Test TypeError is raised for non-dict arguments.""" - from nvflare.fuel.utils.constants import FrameworkType from nvflare.recipe.fedavg import FedAvgRecipe recipe = FedAvgRecipe( @@ -360,8 +345,7 @@ def test_config_type_error(self, temp_script): num_rounds=2, min_clients=2, train_script=temp_script, - initial_ckpt="/abs/path/to/model.npy", - framework=FrameworkType.NUMPY, + model={"class_path": "model.DummyModel", "args": {}}, ) with pytest.raises(TypeError, match="config must be a dict"): @@ -369,3 +353,85 @@ def test_config_type_error(self, temp_script): with pytest.raises(TypeError, match="config must be a dict"): recipe.add_client_config(123) # type: ignore[arg-type] + + +class _DummyExecEnv: + def __init__(self): + self.extra = {} + + def get_extra_prop(self, prop_name, default=None): + return self.extra.get(prop_name, default) + + def deploy(self, job): + return "dummy-job-id" + + def get_job_status(self, job_id): + return None + + def abort_job(self, job_id): + return None + + def get_job_result(self, job_id, timeout: float = 0.0): + return None + + +class TestRecipeExecuteExportParamIsolation: + """Test that execute/export do not permanently mutate recipe additional_params.""" + + @pytest.fixture + def temp_script(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write("# Test training script\n") + temp_path = f.name + yield temp_path + os.unlink(temp_path) + + def test_execute_server_params_do_not_accumulate(self, temp_script): + from nvflare.recipe.fedavg import FedAvgRecipe + + recipe = FedAvgRecipe( + name="test_execute_param_isolation", + num_rounds=2, + min_clients=2, + train_script=temp_script, + model={"class_path": "model.DummyModel", "args": {}}, + ) + + env = _DummyExecEnv() + server_app = recipe.job._deploy_map.get("server") + assert server_app is not None + assert server_app.app_config.additional_params == {} + recipe.execute(env, server_exec_params={"param_a": 1}) + assert server_app.app_config.additional_params == {} + + recipe.execute(env, server_exec_params={"param_b": 2}) + assert server_app.app_config.additional_params == {} + + def test_execute_then_export_no_cross_contamination(self, temp_script): + from nvflare.recipe.fedavg import FedAvgRecipe + + recipe = FedAvgRecipe( + name="test_execute_export_isolation", + num_rounds=2, + min_clients=2, + train_script=temp_script, + model={"class_path": "model.DummyModel", "args": {}}, + ) + + env = _DummyExecEnv() + recipe.execute(env, server_exec_params={"from_execute": 1}) + + with tempfile.TemporaryDirectory() as tmpdir: + recipe.export(job_dir=tmpdir, server_exec_params={"from_export": 2}) + server_cfg_path = os.path.join( + tmpdir, "test_execute_export_isolation", "app", "config", "config_fed_server.json" + ) + with open(server_cfg_path) as f: + server_cfg = json.load(f) + + assert "from_execute" not in server_cfg + assert server_cfg.get("from_export") == 2 + + server_app = recipe.job._deploy_map.get("server") + assert server_app is not None + assert server_app.app_config.additional_params == {} diff --git a/tests/unit_test/recipe/utils_test.py b/tests/unit_test/recipe/utils_test.py index e1b69a9f66..19f525de15 100644 --- a/tests/unit_test/recipe/utils_test.py +++ b/tests/unit_test/recipe/utils_test.py @@ -18,7 +18,13 @@ import pytest -from nvflare.recipe.utils import prepare_initial_ckpt, validate_ckpt +from nvflare.recipe.utils import ( + extract_persistor_id, + prepare_initial_ckpt, + resolve_initial_ckpt, + setup_custom_persistor, + validate_ckpt, +) @pytest.fixture @@ -149,3 +155,116 @@ def test_multiple_calls_different_files(self, temp_workdir): assert result2 == "ckpt2.pt" assert job.add_file_to_server.call_count == 2 + + +class TestPersistorUtils: + """Tests for persistor utility helpers.""" + + def test_extract_persistor_id(self): + assert extract_persistor_id({"persistor_id": "persistor_a"}) == "persistor_a" + assert extract_persistor_id({"persistor_id": 123}) == "" + assert extract_persistor_id("persistor_b") == "persistor_b" + assert extract_persistor_id(None) == "" + + def test_setup_custom_persistor_returns_empty_when_not_provided(self): + job = MagicMock() + + result = setup_custom_persistor(job=job, model_persistor=None) + + assert result == "" + job.to_server.assert_not_called() + + def test_setup_custom_persistor_registers_component(self): + job = MagicMock() + custom_persistor = object() + job.to_server.return_value = "custom_persistor" + + result = setup_custom_persistor(job=job, model_persistor=custom_persistor) + + assert result == "custom_persistor" + job.to_server.assert_called_once_with(custom_persistor, id="persistor") + + def test_setup_custom_persistor_extracts_dict_result(self): + job = MagicMock() + custom_persistor = object() + job.to_server.return_value = {"persistor_id": "custom_from_dict"} + + result = setup_custom_persistor(job=job, model_persistor=custom_persistor) + + assert result == "custom_from_dict" + + def test_resolve_initial_ckpt_prefers_prepared_value(self): + job = MagicMock() + + result = resolve_initial_ckpt( + initial_ckpt="relative/path/model.pt", + prepared_initial_ckpt="already_prepared.pt", + job=job, + ) + + assert result == "already_prepared.pt" + job.add_file_to_server.assert_not_called() + + def test_resolve_initial_ckpt_uses_prepare_when_prepared_missing(self, monkeypatch): + calls = {} + + def fake_prepare(initial_ckpt, job): + calls["initial_ckpt"] = initial_ckpt + calls["job"] = job + return "prepared_by_helper.pt" + + monkeypatch.setattr("nvflare.recipe.utils.prepare_initial_ckpt", fake_prepare) + job = MagicMock() + + result = resolve_initial_ckpt( + initial_ckpt="relative/path/model.pt", + prepared_initial_ckpt=None, + job=job, + ) + + assert result == "prepared_by_helper.pt" + assert calls["initial_ckpt"] == "relative/path/model.pt" + assert calls["job"] is job + + +class TestRecipePackageExports: + """Tests for public API exports from nvflare.recipe.""" + + def test_add_cross_site_evaluation_importable_from_recipe(self): + """add_cross_site_evaluation must be importable from the top-level nvflare.recipe package.""" + from nvflare.recipe import add_cross_site_evaluation + + assert callable(add_cross_site_evaluation) + + +class TestCrossSiteEvalIdempotency: + """Tests for resilient idempotency in add_cross_site_evaluation.""" + + def test_idempotency_survives_missing_flag(self): + from nvflare.app_common.np.recipes.fedavg import NumpyFedAvgRecipe + from nvflare.recipe.utils import add_cross_site_evaluation + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write("# dummy train script\n") + train_script = f.name + + try: + recipe = NumpyFedAvgRecipe( + name="test_cse_idempotency", + model=[1.0, 2.0], + min_clients=2, + num_rounds=2, + train_script=train_script, + ) + + add_cross_site_evaluation(recipe) + assert getattr(recipe, "_cse_added", False) is True + + # Simulate transient attribute loss (e.g. serialization boundary). + del recipe._cse_added + assert not hasattr(recipe, "_cse_added") + + with pytest.raises(RuntimeError, match="already been added"): + add_cross_site_evaluation(recipe) + finally: + os.unlink(train_script)