From fffa5802544b8e977198a451c08691937c45d072 Mon Sep 17 00:00:00 2001
From: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
Date: Tue, 24 Feb 2026 11:46:10 -0800
Subject: [PATCH] [2.7] Fix recipe API bug list and harden recipe behavior
(#4228)
## Issues addressed
This PR addresses the reported recipe bug list:
- `BUG-8` (Critical): `CyclicRecipe` `model=None + initial_ckpt` created
phantom persistor reference and runtime failure.
- `BUG-1` (Critical): `CyclicRecipe` dict model config caused
`TypeError`.
- `BUG-7` (Critical): base `CyclicRecipe` silently ignored
`initial_ckpt`.
- `BUG-9` (High): `FedAvgRecipe` `per_site_config["server"]` caused
invalid target handling/crash.
- `BUG-4` (High): `Recipe.execute/export` mutated job state and made
recipe reuse unsafe.
- `BUG-2` (Medium): `FedAvgRecipe` per-site `or` fallback ignored falsy
overrides.
- `BUG-3` (Medium): `SimEnv` with `num_clients=0` plus an explicit
client list resolved `num_threads` to `0`.
- `BUG-5` (Low): base `FedAvgRecipe` accepted `initial_ckpt` +
`model=None` but produced incomplete model source.
- `BUG-6` (Low): `add_cross_site_evaluation` idempotency guard depended
only on transient `_cse_added` attribute.
## Changes
- Hardened base `CyclicRecipe` model/persistor setup:
- fail-fast when a valid persistor cannot be configured;
- support dict model config for supported frameworks;
- apply/validate `initial_ckpt` for wrapper and framework-specific
paths.
- Hardened base `FedAvgRecipe`:
- validate and reject reserved `per_site_config` targets (`server`,
`@ALL`);
- preserve falsy per-site override values via explicit `is not None`
fallback;
- fail-fast when no persistor and no model params are available.
- Added shared framework persistor setup utility in recipe utils and
wired both Cyclic/FedAvg paths to it.
- Made `Recipe.execute()` and `Recipe.export()` reusable-safe by
snapshot/restore of temporary execution params.
- Fixed `SimEnv` default resolution for `num_clients`/`num_threads` when
explicit client list is provided.
- Hardened CSE idempotency with workflow-based detection in addition to
`_cse_added` fast-path.
- Updated impacted examples and exports:
- `hello-cyclic` now passes `min_clients` explicitly;
- `hello-numpy` uses transfer-type API;
- recipe module exports include `add_cross_site_evaluation`.
- Expanded recipe unit tests with targeted regressions for all above
behaviors.
## Reason for these changes
These updates enforce clear model-source contracts and remove
silent/implicit fallback behavior that could generate invalid jobs or
runtime crashes. The goal is deterministic recipe behavior, safe object
reuse, and predictable per-site/CSE configuration semantics in 2.7.
## Affected files
Core implementation:
- `nvflare/recipe/cyclic.py`
- `nvflare/recipe/fedavg.py`
- `nvflare/recipe/spec.py`
- `nvflare/recipe/sim_env.py`
- `nvflare/recipe/utils.py`
- `nvflare/recipe/poc_env.py`
- `nvflare/recipe/__init__.py`
Framework wrappers / examples:
- `nvflare/app_opt/pt/recipes/cyclic.py`
- `nvflare/app_opt/tf/recipes/cyclic.py`
- `examples/hello-world/hello-cyclic/job.py`
- `examples/hello-world/hello-numpy/job.py`
- `examples/tutorials/job_recipe.ipynb`
Tests:
- `tests/unit_test/recipe/cyclic_recipe_test.py`
- `tests/unit_test/recipe/fedavg_recipe_test.py`
- `tests/unit_test/recipe/spec_test.py`
- `tests/unit_test/recipe/sim_env_test.py`
- `tests/unit_test/recipe/utils_test.py`
## Test strategy
- Targeted regression suites:
- `pytest -q tests/unit_test/recipe/cyclic_recipe_test.py
tests/unit_test/recipe/fedavg_recipe_test.py
tests/unit_test/recipe/spec_test.py`
- `pytest -q tests/unit_test/recipe/sim_env_test.py
tests/unit_test/recipe/utils_test.py`
- Full recipe suite:
- `pytest -q tests/unit_test/recipe`
- Result: recipe suite passed (`183 passed, 13 skipped`).
Made with [Cursor](https://cursor.com)
---
examples/hello-world/hello-cyclic/job.py | 1 +
examples/hello-world/hello-numpy/job.py | 4 +-
examples/tutorials/job_recipe.ipynb | 62 +--------
nvflare/app_opt/pt/recipes/cyclic.py | 17 ++-
nvflare/app_opt/pt/recipes/fedavg.py | 52 +++++---
nvflare/app_opt/tf/recipes/cyclic.py | 9 +-
nvflare/app_opt/tf/recipes/fedavg.py | 26 ++--
nvflare/recipe/__init__.py | 4 +-
nvflare/recipe/cyclic.py | 48 ++++---
nvflare/recipe/fedavg.py | 53 ++++++--
nvflare/recipe/poc_env.py | 6 +-
nvflare/recipe/prod_env.py | 4 +-
nvflare/recipe/sim_env.py | 9 +-
nvflare/recipe/spec.py | 75 +++++++----
nvflare/recipe/utils.py | 50 ++++++-
tests/unit_test/recipe/cyclic_recipe_test.py | 95 +++++++++++++-
tests/unit_test/recipe/fedavg_recipe_test.py | 101 +++++++++++++-
tests/unit_test/recipe/poc_env_test.py | 6 +
tests/unit_test/recipe/sim_env_test.py | 6 +
tests/unit_test/recipe/spec_test.py | 130 ++++++++++++++-----
tests/unit_test/recipe/utils_test.py | 121 ++++++++++++++++-
21 files changed, 677 insertions(+), 202 deletions(-)
diff --git a/examples/hello-world/hello-cyclic/job.py b/examples/hello-world/hello-cyclic/job.py
index 9aa4807dde..a7019c2297 100644
--- a/examples/hello-world/hello-cyclic/job.py
+++ b/examples/hello-world/hello-cyclic/job.py
@@ -24,6 +24,7 @@
recipe = CyclicRecipe(
num_rounds=num_rounds,
+ min_clients=n_clients,
# Model can be specified as class instance or dict config:
model=Net(),
# Alternative: model={"class_path": "model.Net", "args": {}},
diff --git a/examples/hello-world/hello-numpy/job.py b/examples/hello-world/hello-numpy/job.py
index 6ebf99a127..975455adb3 100644
--- a/examples/hello-world/hello-numpy/job.py
+++ b/examples/hello-world/hello-numpy/job.py
@@ -17,8 +17,8 @@
"""
import argparse
-from nvflare.apis.dxo import DataKind
from nvflare.app_common.np.recipes.fedavg import NumpyFedAvgRecipe
+from nvflare.client.config import TransferType
from nvflare.recipe import SimEnv, add_experiment_tracking
@@ -57,7 +57,7 @@ def main():
train_script="client.py",
train_args=train_args,
launch_external_process=launch_process,
- aggregator_data_kind=DataKind.WEIGHTS if args.update_type == "full" else DataKind.WEIGHT_DIFF,
+ params_transfer_type=TransferType.FULL if args.update_type == "full" else TransferType.DIFF,
)
add_experiment_tracking(recipe, tracking_type="tensorboard")
if args.export_config:
diff --git a/examples/tutorials/job_recipe.ipynb b/examples/tutorials/job_recipe.ipynb
index c44c50e92f..585ea13d11 100644
--- a/examples/tutorials/job_recipe.ipynb
+++ b/examples/tutorials/job_recipe.ipynb
@@ -141,29 +141,7 @@
"cell_type": "markdown",
"id": "b94e258c",
"metadata": {},
- "source": [
- "## Pre-trained Checkpoint Path\n",
- "\n",
- "Use `initial_ckpt` to specify a path to pre-trained model weights:\n",
- "\n",
- "```python\n",
- "recipe = FedAvgRecipe(\n",
- " model=SimpleNetwork(),\n",
- " initial_ckpt=\"/data/models/pretrained_model.pt\", # Absolute path\n",
- " ...\n",
- ")\n",
- "```\n",
- "\n",
- "
\n",
- "
Checkpoint Path Requirements:\n",
- "
\n",
- "- Absolute path required: The path must be an absolute path (e.g., `/data/models/model.pt`), not relative.
\n",
- "- May not exist locally: The checkpoint file does not need to exist on the machine where you create the recipe. It only needs to exist on the server when the model is actually loaded during job execution.
\n",
- "- PyTorch requires model architecture: For PyTorch, you must provide `model` (class instance or dict config) along with `initial_ckpt`, because PyTorch checkpoints contain only weights, not architecture.
\n",
- "- TensorFlow/Keras can use checkpoint alone: Keras `.h5` or SavedModel formats contain both architecture and weights, so `initial_ckpt` can be used without `model`.
\n",
- "
\n",
- "
"
- ]
+ "source": "## Pre-trained Checkpoint Path\n\nUse `initial_ckpt` to specify a path to pre-trained model weights:\n\n```python\nrecipe = FedAvgRecipe(\n model=SimpleNetwork(),\n initial_ckpt=\"/data/models/pretrained_model.pt\", # Absolute path (server-side)\n ...\n)\n```\n\n\n
Checkpoint Path Options:\n
\n- Relative path: The file is bundled into the job automatically and deployed to the server. The file must exist locally when the recipe is created.
\n- Absolute path: Treated as a server-side path and used as-is at runtime. The file does not need to exist locally — it only needs to exist on the server when the job runs.
\n- PyTorch requires model architecture: For PyTorch, you must provide `model` (class instance or dict config) along with `initial_ckpt`, because PyTorch checkpoints contain only weights, not architecture.
\n- TensorFlow/Keras can use checkpoint alone: Keras `.h5` or SavedModel formats contain both architecture and weights, so `initial_ckpt` can be used without `model`.
\n
\n
"
},
{
"cell_type": "code",
@@ -256,25 +234,7 @@
"cell_type": "markdown",
"id": "pocenv-section",
"metadata": {},
- "source": [
- "### PocEnv – Proof-of-Concept Environment\n",
- "\n",
- "Runs server and clients as **separate processes** on the same machine. This simulates real-world deployment within a single node, with server and clients are running in different processes. More realistic than `SimEnv`, but still lightweight enough for a single node.\n",
- "\n",
- "Best suited for:\n",
- "* Demonstrations\n",
- "* Small-scale validation before production deployment\n",
- "* Debugging orchestration logic\n",
- "\n",
- "**Arguments:**\n",
- "* `num_clients` (int, optional): Number of clients to use in POC mode. Defaults to 2.\n",
- "* `clients` (List[str], optional): List of client names. If None, will generate site-1, site-2, etc.\n",
- "* `gpu_ids` (List[int], optional): List of GPU IDs to assign to clients. If None, uses CPU only.\n",
- "* `auto_stop` (bool, optional): Whether to automatically stop POC services after job completion.\n",
- "* `use_he` (bool, optional): Whether to use HE. Defaults to False.\n",
- "* `docker_image` (str, optional): Docker image to use for POC.\n",
- "* `project_conf_path` (str, optional): Path to the project configuration file."
- ]
+ "source": "### PocEnv – Proof-of-Concept Environment\n\nRuns server and clients as **separate processes** on the same machine. This simulates real-world deployment within a single node, with server and clients are running in different processes. More realistic than `SimEnv`, but still lightweight enough for a single node.\n\nBest suited for:\n* Demonstrations\n* Small-scale validation before production deployment\n* Debugging orchestration logic\n\n**Arguments:**\n* `num_clients` (int, optional): Number of clients to use in POC mode. Defaults to 2.\n* `clients` (List[str], optional): List of client names. If None, will generate site-1, site-2, etc.\n* `gpu_ids` (List[int], optional): List of GPU IDs to assign to clients. If None, uses CPU only.\n* `use_he` (bool, optional): Whether to use HE. Defaults to False.\n* `docker_image` (str, optional): Docker image to use for POC.\n* `project_conf_path` (str, optional): Path to the project configuration file."
},
{
"cell_type": "markdown",
@@ -331,21 +291,7 @@
"cell_type": "markdown",
"id": "prodenv-section",
"metadata": {},
- "source": [
- "### ProdEnv – Production Environment\n",
- "\n",
- "We assume the system of a server and clients is up and running across **multiple machines and sites**. Uses secure communication channels and real-world NVFLARE deployment infrastructure. The ProdEnv will utilize the admin's startup package to communicate with an existing NVFlare system to execute and monitor that job execution.\n",
- "\n",
- "Best suited for:\n",
- "* Enterprise federated learning deployments\n",
- "* Multi-institution collaborations\n",
- "* Production-scale workloads\n",
- "\n",
- "**Arguments:**\n",
- "* `startup_kit_location` (str): the directory that contains the startup kit of the admin (generated by nvflare provisioning)\n",
- "* `login_timeout` (float): timeout value for the admin to login to the system\n",
- "* `monitor_job_duration` (int): duration to monitor the job execution, None means no monitoring at all"
- ]
+ "source": "### ProdEnv – Production Environment\n\nWe assume the system of a server and clients is up and running across **multiple machines and sites**. Uses secure communication channels and real-world NVFLARE deployment infrastructure. The ProdEnv will utilize the admin's startup package to communicate with an existing NVFlare system to execute and monitor that job execution.\n\nBest suited for:\n* Enterprise federated learning deployments\n* Multi-institution collaborations\n* Production-scale workloads\n\n**Arguments:**\n* `startup_kit_location` (str): the directory that contains the startup kit of the admin (generated by nvflare provisioning)\n* `login_timeout` (float): timeout value for the admin to login to the system"
},
{
"cell_type": "markdown",
@@ -522,4 +468,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/nvflare/app_opt/pt/recipes/cyclic.py b/nvflare/app_opt/pt/recipes/cyclic.py
index 6145da312f..2f335743da 100644
--- a/nvflare/app_opt/pt/recipes/cyclic.py
+++ b/nvflare/app_opt/pt/recipes/cyclic.py
@@ -18,6 +18,7 @@
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe
+from nvflare.recipe.utils import extract_persistor_id
class CyclicRecipe(BaseCyclicRecipe):
@@ -103,11 +104,17 @@ def _setup_model_and_persistor(self, job) -> str:
# If model is already a PTModel wrapper (user passed PTModel directly), use as-is
if hasattr(self.model, "add_to_fed_job"):
result = job.to_server(self.model, id="persistor")
- return result["persistor_id"]
+ return extract_persistor_id(result)
- from nvflare.recipe.utils import prepare_initial_ckpt
+ from nvflare.recipe.utils import resolve_initial_ckpt
- ckpt_path = prepare_initial_ckpt(self._pt_initial_ckpt, job)
- pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path)
+ ckpt_path = resolve_initial_ckpt(self._pt_initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
+ if self.model is None and ckpt_path:
+ raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.")
+ if self.model is None:
+ return ""
+
+ allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH
+ pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path, allow_numpy_conversion=allow_numpy_conversion)
result = job.to_server(pt_model, id="persistor")
- return result["persistor_id"]
+ return extract_persistor_id(result)
diff --git a/nvflare/app_opt/pt/recipes/fedavg.py b/nvflare/app_opt/pt/recipes/fedavg.py
index dca486cab8..d81cecbd82 100644
--- a/nvflare/app_opt/pt/recipes/fedavg.py
+++ b/nvflare/app_opt/pt/recipes/fedavg.py
@@ -162,21 +162,37 @@ def __init__(
def _setup_model_and_persistor(self, job) -> str:
"""Override to handle PyTorch-specific model setup."""
- if self.model is not None or self.initial_ckpt is not None:
- from nvflare.app_opt.pt.job_config.model import PTModel
- from nvflare.recipe.utils import prepare_initial_ckpt
-
- # Disable numpy conversion when using tensor format to keep PyTorch tensors
- allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH
-
- ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
- pt_model = PTModel(
- model=self.model,
- initial_ckpt=ckpt_path,
- persistor=self.model_persistor,
- locator=self._pt_model_locator,
- allow_numpy_conversion=allow_numpy_conversion,
- )
- job.comp_ids.update(job.to_server(pt_model))
- return job.comp_ids.get("persistor_id", "")
- return ""
+ from nvflare.app_opt.pt.job_config.model import PTModel
+ from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor
+
+ persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor)
+ if persistor_id:
+ if hasattr(job, "comp_ids"):
+ job.comp_ids["persistor_id"] = persistor_id
+ if self._pt_model_locator is not None:
+ locator_id = job.to_server(self._pt_model_locator, id="locator")
+ if isinstance(locator_id, str) and locator_id:
+ job.comp_ids["locator_id"] = locator_id
+ return persistor_id
+
+ ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
+ if self.model is None and ckpt_path:
+ raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.")
+ if self.model is None:
+ return ""
+
+ # Disable numpy conversion when using tensor format to keep PyTorch tensors.
+ allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH
+ pt_model = PTModel(
+ model=self.model,
+ initial_ckpt=ckpt_path,
+ locator=self._pt_model_locator,
+ allow_numpy_conversion=allow_numpy_conversion,
+ )
+ result = job.to_server(pt_model, id="persistor")
+ if isinstance(result, dict) and hasattr(job, "comp_ids"):
+ job.comp_ids.update(result)
+ persistor_id = extract_persistor_id(result)
+ if persistor_id and hasattr(job, "comp_ids"):
+ job.comp_ids.setdefault("persistor_id", persistor_id)
+ return persistor_id
diff --git a/nvflare/app_opt/tf/recipes/cyclic.py b/nvflare/app_opt/tf/recipes/cyclic.py
index bc531295fa..49f8ce3505 100644
--- a/nvflare/app_opt/tf/recipes/cyclic.py
+++ b/nvflare/app_opt/tf/recipes/cyclic.py
@@ -18,6 +18,7 @@
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe
+from nvflare.recipe.utils import extract_persistor_id
class CyclicRecipe(BaseCyclicRecipe):
@@ -103,11 +104,11 @@ def _setup_model_and_persistor(self, job) -> str:
# If model is already a TFModel wrapper (user passed TFModel directly), use as-is
if hasattr(self.model, "add_to_fed_job"):
result = job.to_server(self.model, id="persistor")
- return result["persistor_id"]
+ return extract_persistor_id(result)
- from nvflare.recipe.utils import prepare_initial_ckpt
+ from nvflare.recipe.utils import resolve_initial_ckpt
- ckpt_path = prepare_initial_ckpt(self._tf_initial_ckpt, job)
+ ckpt_path = resolve_initial_ckpt(self._tf_initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
tf_model = TFModel(model=self.model, initial_ckpt=ckpt_path)
result = job.to_server(tf_model, id="persistor")
- return result["persistor_id"]
+ return extract_persistor_id(result)
diff --git a/nvflare/app_opt/tf/recipes/fedavg.py b/nvflare/app_opt/tf/recipes/fedavg.py
index 4b632d7f72..674236c204 100644
--- a/nvflare/app_opt/tf/recipes/fedavg.py
+++ b/nvflare/app_opt/tf/recipes/fedavg.py
@@ -145,16 +145,16 @@ def __init__(
def _setup_model_and_persistor(self, job) -> str:
"""Override to handle TensorFlow-specific model setup."""
- if self.model is not None or self.initial_ckpt is not None:
- from nvflare.app_opt.tf.job_config.model import TFModel
- from nvflare.recipe.utils import prepare_initial_ckpt
-
- ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
- tf_model = TFModel(
- model=self.model,
- initial_ckpt=ckpt_path,
- persistor=self.model_persistor,
- )
- job.comp_ids["persistor_id"] = job.to_server(tf_model)
- return job.comp_ids.get("persistor_id", "")
- return ""
+ from nvflare.app_opt.tf.job_config.model import TFModel
+ from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor
+
+ persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor)
+ if persistor_id:
+ return persistor_id
+
+ ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
+ if self.model is None and not ckpt_path:
+ return ""
+
+ tf_model = TFModel(model=self.model, initial_ckpt=ckpt_path)
+ return extract_persistor_id(job.to_server(tf_model, id="persistor"))
diff --git a/nvflare/recipe/__init__.py b/nvflare/recipe/__init__.py
index ecbe7ad1b1..d3590becdb 100644
--- a/nvflare/recipe/__init__.py
+++ b/nvflare/recipe/__init__.py
@@ -17,6 +17,6 @@
from .prod_env import ProdEnv
from .run import Run
from .sim_env import SimEnv
-from .utils import add_experiment_tracking
+from .utils import add_cross_site_evaluation, add_experiment_tracking
-__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking", "FedAvgRecipe"]
+__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking", "add_cross_site_evaluation", "FedAvgRecipe"]
diff --git a/nvflare/recipe/cyclic.py b/nvflare/recipe/cyclic.py
index e2406c38fb..e93ea97e0b 100644
--- a/nvflare/recipe/cyclic.py
+++ b/nvflare/recipe/cyclic.py
@@ -141,6 +141,7 @@ def __init__(
if isinstance(self.model, dict):
self.model = recipe_model_to_job_model(self.model)
+ self.min_clients = v.min_clients
self.num_rounds = v.num_rounds
self.train_script = v.train_script
self.train_args = v.train_args
@@ -162,13 +163,16 @@ def __init__(
# Setup model persistor first - subclasses override for framework-specific handling
persistor_id = self._setup_model_and_persistor(job)
- # Use returned persistor_id or default to "persistor"
if not persistor_id:
- persistor_id = "persistor"
+ raise ValueError(
+ "Unable to configure a model persistor for CyclicRecipe. "
+ "Provide a supported model/framework combination (PyTorch or TensorFlow), "
+ "or pass a framework-specific model wrapper with add_to_fed_job()."
+ )
# Define the controller workflow and send to server
controller = CyclicController(
- num_rounds=num_rounds,
+ num_rounds=self.num_rounds,
task_assignment_timeout=10,
persistor_id=persistor_id,
shareable_generator_id="shareable_generator",
@@ -194,25 +198,39 @@ def __init__(
super().__init__(job)
def _setup_model_and_persistor(self, job) -> str:
- """Setup framework-specific model components and persistor.
+ """Setup model wrapper persistor.
- Handles PTModel/TFModel wrappers passed by framework-specific subclasses.
+ Handles the following model inputs:
+ - framework-specific wrappers (objects with ``add_to_fed_job``)
Returns:
str: The persistor_id to be used by the controller.
"""
- if self.model is None:
- return ""
+ from nvflare.recipe.utils import extract_persistor_id, prepare_initial_ckpt
- # Check if model is a model wrapper (PTModel, TFModel)
+ ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
+
+ # Model wrapper path (PTModel/TFModel or custom wrapper)
if hasattr(self.model, "add_to_fed_job"):
- # It's a model wrapper - use its add_to_fed_job method
+ if ckpt_path:
+ if not hasattr(self.model, "initial_ckpt"):
+ raise ValueError(
+ f"initial_ckpt is provided, but model wrapper {type(self.model).__name__} "
+ "does not support 'initial_ckpt'."
+ )
+ existing_ckpt = getattr(self.model, "initial_ckpt", None)
+ if existing_ckpt and existing_ckpt != ckpt_path:
+ raise ValueError(
+ f"Conflicting checkpoint values: model wrapper has initial_ckpt={existing_ckpt}, "
+ f"but recipe initial_ckpt={ckpt_path}."
+ )
+ setattr(self.model, "initial_ckpt", ckpt_path)
+
result = job.to_server(self.model, id="persistor")
- return result["persistor_id"]
+ return extract_persistor_id(result)
- # Unknown model type
- raise TypeError(
- f"Unsupported model type: {type(self.model).__name__}. "
- f"Use a framework-specific recipe (PTCyclicRecipe, TFCyclicRecipe, etc.) "
- f"or wrap your model in PTModel/TFModel."
+ raise ValueError(
+ f"Unsupported framework '{self.framework}' for base CyclicRecipe model persistence. "
+ "Use a framework-specific CyclicRecipe subclass, or pass a framework-specific "
+ "model wrapper with add_to_fed_job()."
)
diff --git a/nvflare/recipe/fedavg.py b/nvflare/recipe/fedavg.py
index fd59424413..ad1d372af6 100644
--- a/nvflare/recipe/fedavg.py
+++ b/nvflare/recipe/fedavg.py
@@ -17,6 +17,7 @@
from pydantic import BaseModel
from nvflare.apis.dxo import DataKind
+from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.workflows.fedavg import FedAvg
@@ -236,6 +237,7 @@ def __init__(
self.params_transfer_type = v.params_transfer_type
self.model_persistor = v.model_persistor
self.per_site_config = v.per_site_config
+ self._validate_per_site_config(self.per_site_config)
self.launch_once = v.launch_once
self.shutdown_timeout = v.shutdown_timeout
self.key_metric = v.key_metric
@@ -274,6 +276,12 @@ def __init__(
has_persistor = persistor_id != ""
model_params = None if has_persistor else self._get_model_params()
+ if not has_persistor and model_params is None:
+ raise ValueError(
+ "Unable to configure a model source for FedAvgRecipe: no persistor and no model parameters. "
+ "Use a framework-specific recipe for checkpoint-only initialization, or provide model/model_persistor."
+ )
+
# Prepare aggregator for controller - must be ModelAggregator for FLModel-based aggregation
model_aggregator = self._get_model_aggregator()
@@ -311,10 +319,18 @@ def __init__(
if site_config.get("launch_external_process") is not None
else self.launch_external_process
)
- command = site_config.get("command") or self.command
- framework = site_config.get("framework") or self.framework
- expected_format = site_config.get("server_expected_format") or self.server_expected_format
- transfer_type = site_config.get("params_transfer_type") or self.params_transfer_type
+ command = site_config.get("command") if site_config.get("command") is not None else self.command
+ framework = site_config.get("framework") if site_config.get("framework") is not None else self.framework
+ expected_format = (
+ site_config.get("server_expected_format")
+ if site_config.get("server_expected_format") is not None
+ else self.server_expected_format
+ )
+ transfer_type = (
+ site_config.get("params_transfer_type")
+ if site_config.get("params_transfer_type") is not None
+ else self.params_transfer_type
+ )
launch_once = (
site_config.get("launch_once") if site_config.get("launch_once") is not None else self.launch_once
)
@@ -352,6 +368,23 @@ def __init__(
Recipe.__init__(self, job)
+ @staticmethod
+ def _validate_per_site_config(per_site_config: Optional[Dict[str, Dict]]) -> None:
+ if per_site_config is None:
+ return
+
+ reserved_targets = {SERVER_SITE_NAME, ALL_SITES}
+ for site_name, site_config in per_site_config.items():
+ if not isinstance(site_name, str):
+ raise ValueError(f"per_site_config key must be str, got {type(site_name).__name__}")
+ if site_name in reserved_targets:
+ raise ValueError(
+ f"'{site_name}' is a reserved target name and cannot be used in per_site_config. "
+ f"Reserved names: {sorted(reserved_targets)}"
+ )
+ if not isinstance(site_config, dict):
+ raise ValueError(f"per_site_config['{site_name}'] must be a dict, got {type(site_config).__name__}")
+
def _get_model_params(self) -> Optional[Dict]:
"""Convert model to dict of params.
@@ -406,14 +439,14 @@ def _get_model_aggregator(self):
return None
def _setup_model_and_persistor(self, job: BaseFedJob) -> str:
- """Setup framework-specific model components and persistor.
+ """Setup generic custom persistor only.
- Base implementation handles custom persistor. Framework-specific subclasses
- should override this to use PTModel/TFModel for their model types.
+ Framework-specific recipes (PT/TF/NumPy) override this method to build and
+ register their model wrappers and default persistors.
Returns:
str: The persistor_id to be used by the controller.
"""
- if self.model_persistor is not None:
- return job.to_server(self.model_persistor, id="persistor")
- return ""
+ from nvflare.recipe.utils import setup_custom_persistor
+
+ return setup_custom_persistor(job=job, model_persistor=self.model_persistor)
diff --git a/nvflare/recipe/poc_env.py b/nvflare/recipe/poc_env.py
index ac6a67fadf..5b7c12a42b 100644
--- a/nvflare/recipe/poc_env.py
+++ b/nvflare/recipe/poc_env.py
@@ -21,7 +21,7 @@
from nvflare.job_config.api import FedJob
from nvflare.recipe.spec import ExecEnv
-from nvflare.recipe.utils import _collect_non_local_scripts
+from nvflare.recipe.utils import collect_non_local_scripts
from nvflare.tool.poc.poc_commands import (
_clean_poc,
_start_poc,
@@ -64,7 +64,7 @@ def check_client_configuration(self):
)
# Check if num_clients is valid when clients is None
- if self.clients is None and self.num_clients <= 0:
+ if self.clients is None and (self.num_clients is None or self.num_clients <= 0):
raise ValueError("num_clients must be greater than 0")
return self
@@ -135,7 +135,7 @@ def deploy(self, job: FedJob):
str: Job ID or deployment result.
"""
# Validate scripts exist locally for POC
- non_local_scripts = _collect_non_local_scripts(job)
+ non_local_scripts = collect_non_local_scripts(job)
if non_local_scripts:
raise ValueError(
f"The following scripts do not exist locally: {non_local_scripts}. "
diff --git a/nvflare/recipe/prod_env.py b/nvflare/recipe/prod_env.py
index de7dc6f9f9..485604e719 100644
--- a/nvflare/recipe/prod_env.py
+++ b/nvflare/recipe/prod_env.py
@@ -20,7 +20,7 @@
from nvflare.job_config.api import FedJob
from nvflare.recipe.spec import ExecEnv
-from nvflare.recipe.utils import _collect_non_local_scripts
+from nvflare.recipe.utils import collect_non_local_scripts
from .session_mgr import SessionManager
@@ -85,7 +85,7 @@ def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
def deploy(self, job: FedJob):
"""Deploy a job using SessionManager."""
# Log warnings for non-local scripts (assumed pre-installed on production)
- non_local_scripts = _collect_non_local_scripts(job)
+ non_local_scripts = collect_non_local_scripts(job)
for script in non_local_scripts:
logger.warning(
f"Script '{script}' not found locally. " f"Assuming it is pre-installed on the production system."
diff --git a/nvflare/recipe/sim_env.py b/nvflare/recipe/sim_env.py
index 1fe5c18be7..71e82ea6b3 100644
--- a/nvflare/recipe/sim_env.py
+++ b/nvflare/recipe/sim_env.py
@@ -20,7 +20,7 @@
from nvflare.job_config.api import FedJob
from .spec import ExecEnv
-from .utils import _collect_non_local_scripts
+from .utils import collect_non_local_scripts
WORKSPACE_ROOT = "/tmp/nvflare/simulation"
@@ -87,8 +87,9 @@ def __init__(
workspace_root=workspace_root,
)
- self.num_clients = v.num_clients
- self.num_threads = v.num_threads if v.num_threads is not None else v.num_clients
+ resolved_num_clients = v.num_clients if v.num_clients > 0 else len(v.clients or [])
+ self.num_clients = resolved_num_clients
+ self.num_threads = v.num_threads if v.num_threads is not None else resolved_num_clients
self.gpu_config = v.gpu_config
self.log_config = v.log_config
self.clients = v.clients
@@ -96,7 +97,7 @@ def __init__(
def deploy(self, job: FedJob):
# Validate scripts exist locally for simulation
- non_local_scripts = _collect_non_local_scripts(job)
+ non_local_scripts = collect_non_local_scripts(job)
if non_local_scripts:
raise ValueError(
f"The following scripts do not exist locally: {non_local_scripts}. "
diff --git a/nvflare/recipe/spec.py b/nvflare/recipe/spec.py
index 52b290a56f..27f5e023ef 100644
--- a/nvflare/recipe/spec.py
+++ b/nvflare/recipe/spec.py
@@ -13,6 +13,7 @@
# limitations under the License.
from abc import ABC, abstractmethod
+from contextlib import contextmanager
from typing import Dict, List, Optional, Union
from nvflare.apis.filter import Filter
@@ -30,7 +31,7 @@ def __init__(self, extra: Optional[dict] = None):
Args:
extra: a dict of extra properties
"""
- if not extra:
+ if extra is None:
extra = {}
if not isinstance(extra, dict):
raise ValueError(f"extra must be dict but got {type(extra)}")
@@ -114,6 +115,47 @@ def process_env(self, env: ExecEnv):
"""
pass
+ def _snapshot_additional_params(self) -> Dict[str, Dict]:
+ snapshot = {}
+ deploy_map = getattr(self.job, "_deploy_map", {})
+ for target, app in deploy_map.items():
+ app_config = getattr(app, "app_config", None)
+ if app_config is None:
+ continue
+ params = getattr(app_config, "additional_params", None)
+ if isinstance(params, dict):
+ snapshot[target] = dict(params)
+ return snapshot
+
+ def _restore_additional_params(self, snapshot: Dict[str, Dict]) -> None:
+ deploy_map = getattr(self.job, "_deploy_map", {})
+ for target, app in deploy_map.items():
+ app_config = getattr(app, "app_config", None)
+ if app_config is None:
+ continue
+ params = getattr(app_config, "additional_params", None)
+ if isinstance(params, dict):
+ original = snapshot.get(target, {})
+ params.clear()
+ params.update(original)
+
+ @contextmanager
+ def _temporary_exec_params(self, server_exec_params: dict = None, client_exec_params: dict = None):
+ params_snapshot = None
+ if server_exec_params or client_exec_params:
+ params_snapshot = self._snapshot_additional_params()
+
+ try:
+ if server_exec_params:
+ self.job.to_server(server_exec_params)
+
+ if client_exec_params:
+ self._add_to_client_apps(client_exec_params)
+ yield
+ finally:
+ if params_snapshot is not None:
+ self._restore_additional_params(params_snapshot)
+
def _add_to_client_apps(self, obj, clients: Optional[List[str]] = None, **kwargs):
"""Add an object to client apps, preserving existing per-site structure.
@@ -328,16 +370,10 @@ def export(
Returns: None
"""
- if server_exec_params:
- self.job.to_server(server_exec_params)
-
- if client_exec_params:
- self._add_to_client_apps(client_exec_params)
-
- if env:
- self.process_env(env)
-
- self.job.export_job(job_dir)
+ with self._temporary_exec_params(server_exec_params=server_exec_params, client_exec_params=client_exec_params):
+ if env:
+ self.process_env(env)
+ self.job.export_job(job_dir)
def execute(
self, env: ExecEnv, server_exec_params: Optional[dict] = None, client_exec_params: Optional[dict] = None
@@ -352,15 +388,10 @@ def execute(
Returns: Run to get job ID and execution results
"""
- if server_exec_params:
- self.job.to_server(server_exec_params)
-
- if client_exec_params:
- self._add_to_client_apps(client_exec_params)
-
- self.process_env(env)
- job_id = env.deploy(self.job)
- from nvflare.recipe.run import Run
+ with self._temporary_exec_params(server_exec_params=server_exec_params, client_exec_params=client_exec_params):
+ self.process_env(env)
+ job_id = env.deploy(self.job)
+ from nvflare.recipe.run import Run
- run = Run(env, job_id)
- return run
+ run = Run(env, job_id)
+ return run
diff --git a/nvflare/recipe/utils.py b/nvflare/recipe/utils.py
index 5bdbccf59f..ab041c833e 100644
--- a/nvflare/recipe/utils.py
+++ b/nvflare/recipe/utils.py
@@ -59,6 +59,27 @@
}
+def _has_cross_site_eval_workflow(job: FedJob) -> bool:
+ """Check if CrossSiteModelEval workflow is already configured on server."""
+ from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
+
+ deploy_map = getattr(job, "_deploy_map", {})
+ server_app = deploy_map.get("server")
+ if not server_app or not hasattr(server_app, "app_config"):
+ return False
+
+ workflows = getattr(server_app.app_config, "workflows", [])
+ for w in workflows:
+ # Server stores workflow definitions as wrapper objects (e.g. WorkFlow)
+ # with the actual controller on `controller`.
+ if isinstance(w, CrossSiteModelEval):
+ return True
+ controller = getattr(w, "controller", None)
+ if controller is not None and isinstance(controller, CrossSiteModelEval):
+ return True
+ return False
+
+
def add_experiment_tracking(
recipe: Recipe,
tracking_type: str,
@@ -250,8 +271,10 @@ def add_cross_site_evaluation(
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
from nvflare.job_config.script_runner import FrameworkType
- # Idempotency check: prevent multiple calls on the same recipe
- if hasattr(recipe, "_cse_added") and recipe._cse_added:
+ # Idempotency check: prevent multiple calls on the same recipe.
+ # Keep the explicit flag fast-path, but also verify server workflow state so
+ # protection remains effective even if dynamic attributes are lost.
+ if getattr(recipe, "_cse_added", False) or _has_cross_site_eval_workflow(recipe.job):
name = recipe.name if hasattr(recipe, "name") else "cross-site-evaluation job"
raise RuntimeError(
f"Cross-site evaluation has already been added to recipe '{name}'. "
@@ -399,7 +422,7 @@ def _has_task_executor(job, task_name: str) -> bool:
return False
-def _collect_non_local_scripts(job: FedJob) -> List[str]:
+def collect_non_local_scripts(job: FedJob) -> List[str]:
"""Collect scripts that don't exist locally.
This utility function is used by ExecEnv subclasses to validate script resources
@@ -498,6 +521,27 @@ def prepare_initial_ckpt(initial_ckpt: Optional[str], job) -> Optional[str]:
return os.path.basename(initial_ckpt)
+def extract_persistor_id(result: Any) -> str:
+ if isinstance(result, dict):
+ persistor_id = result.get("persistor_id", "")
+ return persistor_id if isinstance(persistor_id, str) else ""
+ if isinstance(result, str):
+ return result
+ return ""
+
+
+def resolve_initial_ckpt(initial_ckpt: Optional[str], prepared_initial_ckpt: Optional[str], job) -> Optional[str]:
+ if prepared_initial_ckpt is not None:
+ return prepared_initial_ckpt
+ return prepare_initial_ckpt(initial_ckpt, job)
+
+
+def setup_custom_persistor(*, job, model_persistor=None) -> str:
+ if model_persistor is None:
+ return ""
+ return extract_persistor_id(job.to_server(model_persistor, id="persistor"))
+
+
def validate_dict_model_config(model: Any) -> None:
"""Validate recipe dict model config structure.
diff --git a/tests/unit_test/recipe/cyclic_recipe_test.py b/tests/unit_test/recipe/cyclic_recipe_test.py
index 8374d1c69e..894355737a 100644
--- a/tests/unit_test/recipe/cyclic_recipe_test.py
+++ b/tests/unit_test/recipe/cyclic_recipe_test.py
@@ -19,6 +19,8 @@
import pytest
import torch.nn as nn
+from nvflare.app_opt.pt.job_config.model import PTModel
+from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe
@@ -58,15 +60,10 @@ def base_recipe_params():
class TestBaseCyclicRecipe:
- """Test cases for base CyclicRecipe class.
-
- Note: Base CyclicRecipe doesn't directly support nn.Module or dict config.
- Use framework-specific recipes (PTCyclicRecipe, TFCyclicRecipe) for those.
- """
+ """Test cases for base CyclicRecipe class."""
def test_initial_ckpt_must_exist_for_relative_path(self):
"""Test that non-existent relative paths are rejected (no mock - validation must run)."""
- # Don't use mock_file_system fixture - we need real validation
with pytest.raises(ValueError, match="does not exist locally"):
BaseCyclicRecipe(
name="test_relative",
@@ -87,6 +84,76 @@ def test_requires_model_or_checkpoint(self, base_recipe_params):
**base_recipe_params,
)
+ def test_rejects_non_wrapper_model_for_base_recipe(self, mock_file_system, base_recipe_params):
+ """Base CyclicRecipe no longer owns PT/TF model persistence for raw model inputs."""
+ with pytest.raises(ValueError, match="Use a framework-specific CyclicRecipe subclass"):
+ BaseCyclicRecipe(
+ name="test_base_pt_dict",
+ model={"class_path": "torch.nn.Linear", "args": {"in_features": 10, "out_features": 2}},
+ framework=FrameworkType.PYTORCH,
+ **base_recipe_params,
+ )
+
+ def test_rejects_pytorch_checkpoint_without_model(self, mock_file_system, base_recipe_params):
+ """Base recipe requires framework-specific subclass for PT checkpoint-only setup."""
+ with pytest.raises(ValueError, match="Use a framework-specific CyclicRecipe subclass"):
+ BaseCyclicRecipe(
+ name="test_base_pt_ckpt_no_model",
+ model=None,
+ initial_ckpt="/abs/path/to/model.pt",
+ framework=FrameworkType.PYTORCH,
+ **base_recipe_params,
+ )
+
+ def test_rejects_ckpt_only_for_default_framework(self, mock_file_system, base_recipe_params):
+ """Fail fast for ckpt-only usage when no supported framework/wrapper is provided."""
+ with pytest.raises(ValueError, match="Unsupported framework"):
+ BaseCyclicRecipe(
+ name="test_ckpt_only_default_framework",
+ model=None,
+ initial_ckpt="/abs/path/to/model.pt",
+ **base_recipe_params,
+ )
+
+ def test_applies_initial_ckpt_to_wrapper_model(self, mock_file_system, base_recipe_params, simple_model):
+ """When wrapper model is used, recipe-level initial_ckpt should be applied to persistor."""
+ recipe = BaseCyclicRecipe(
+ name="test_wrapper_ckpt",
+ model=PTModel(model=simple_model),
+ initial_ckpt="/abs/path/to/model.pt",
+ framework=FrameworkType.PYTORCH,
+ **base_recipe_params,
+ )
+
+ server_app = recipe.job._deploy_map.get("server")
+ persistor = server_app.app_config.components.get("persistor")
+ assert persistor is not None
+ assert getattr(persistor, "source_ckpt_file_full_name", None) == "/abs/path/to/model.pt"
+
+ def test_rejects_unsupported_framework_without_wrapper(self, mock_file_system, base_recipe_params):
+ """Fail fast for unsupported base framework/model persistence combinations."""
+ with pytest.raises(ValueError, match="Unsupported framework"):
+ BaseCyclicRecipe(
+ name="test_unsupported_framework",
+ model={"class_path": "torch.nn.Linear", "args": {"in_features": 10, "out_features": 2}},
+ framework=FrameworkType.NUMPY,
+ **base_recipe_params,
+ )
+
+
+class TestBaseCyclicRecipeAttributes:
+ """Test that CyclicRecipe stores validated attributes correctly."""
+
+ def test_min_clients_attribute(self, mock_file_system, base_recipe_params, simple_model):
+ """min_clients must be accessible as an instance attribute after construction."""
+ recipe = BaseCyclicRecipe(
+ name="test_min_clients",
+ model=PTModel(model=simple_model),
+ framework=FrameworkType.PYTORCH,
+ **base_recipe_params,
+ )
+ assert recipe.min_clients == base_recipe_params["min_clients"]
+
class TestPTCyclicRecipe:
"""Test cases for PyTorch CyclicRecipe."""
@@ -105,6 +172,22 @@ def test_pt_cyclic_initial_ckpt(self, mock_file_system, base_recipe_params, simp
assert recipe.name == "test_pt_cyclic"
assert recipe.job is not None
+ def test_pt_cyclic_with_ptmodel_wrapper_returns_persistor_id(
+ self, mock_file_system, base_recipe_params, simple_model
+ ):
+ """PTModel wrapper path must correctly extract persistor_id from dict return."""
+ from nvflare.app_opt.pt.recipes.cyclic import CyclicRecipe as PTCyclicRecipe
+
+ recipe = PTCyclicRecipe(
+ name="test_pt_wrapper",
+ model=PTModel(model=simple_model),
+ **base_recipe_params,
+ )
+
+ server_app = recipe.job._deploy_map.get("server")
+ assert server_app is not None
+ assert "persistor" in server_app.app_config.components
+
class TestTFCyclicRecipe:
"""Test cases for TensorFlow CyclicRecipe."""
diff --git a/tests/unit_test/recipe/fedavg_recipe_test.py b/tests/unit_test/recipe/fedavg_recipe_test.py
index 6900d189a2..930bc97ab5 100644
--- a/tests/unit_test/recipe/fedavg_recipe_test.py
+++ b/tests/unit_test/recipe/fedavg_recipe_test.py
@@ -17,8 +17,10 @@
import pytest
import torch.nn as nn
-from nvflare.apis.job_def import SERVER_SITE_NAME
+from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.abstract.model_locator import ModelLocator
+from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.aggregators.model_aggregator import ModelAggregator
from nvflare.app_common.np.recipes import NumpyFedAvgRecipe
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
@@ -79,6 +81,26 @@ def __init__(self):
pass
+class DummyPersistor(ModelPersistor):
+ """Minimal ModelPersistor used to test custom persistor wiring."""
+
+ def load_model(self, fl_ctx):
+ return {}
+
+ def save_model(self, model, fl_ctx):
+ return None
+
+
+class DummyLocator(ModelLocator):
+ """Minimal ModelLocator used to test explicit locator registration."""
+
+ def get_model_names(self, fl_ctx):
+ return []
+
+ def locate_model(self, model_name, fl_ctx):
+ return None
+
+
@pytest.fixture
def mock_file_system():
"""Mock file system operations for all tests."""
@@ -414,6 +436,81 @@ def test_invalid_aggregator_type_raises_validation_error(self, mock_file_system,
**base_recipe_params,
)
+ def test_per_site_config_rejects_reserved_server_target(self, mock_file_system, base_recipe_params, simple_model):
+ """Reserved target 'server' must not be allowed in per_site_config."""
+ with pytest.raises(ValueError, match="reserved target name"):
+ FedAvgRecipe(
+ name="test_reserved_server_target",
+ model=simple_model,
+ per_site_config={"server": {}},
+ **base_recipe_params,
+ )
+
+ def test_per_site_config_rejects_reserved_all_sites_target(
+ self, mock_file_system, base_recipe_params, simple_model
+ ):
+ """Reserved target '@ALL' must not be allowed in per_site_config."""
+ with pytest.raises(ValueError, match="reserved target name"):
+ FedAvgRecipe(
+ name="test_reserved_all_sites_target",
+ model=simple_model,
+ per_site_config={ALL_SITES: {}},
+ **base_recipe_params,
+ )
+
+ def test_per_site_empty_command_override_is_preserved(self, mock_file_system, base_recipe_params, simple_model):
+ """Falsy per-site override values (e.g. command='') must not be replaced by defaults."""
+ recipe = FedAvgRecipe(
+ name="test_empty_command_override",
+ model=simple_model,
+ launch_external_process=True,
+ per_site_config={"site-1": {"command": ""}},
+ **base_recipe_params,
+ )
+
+ site_app = recipe.job._deploy_map.get("site-1")
+ assert site_app is not None
+ launcher = site_app.app_config.components.get("launcher")
+ assert launcher is not None
+ assert "python3 -u" not in launcher._script
+ assert launcher._script.startswith(" custom/")
+
+ def test_custom_model_persistor_tracks_persistor_id(self, mock_file_system, base_recipe_params, simple_model):
+ """Custom PT persistor path should persist comp_ids['persistor_id'] for later workflows."""
+ recipe = FedAvgRecipe(
+ name="test_custom_persistor_comp_id",
+ model=simple_model,
+ model_persistor=DummyPersistor(),
+ **base_recipe_params,
+ )
+
+ persistor_id = recipe.job.comp_ids.get("persistor_id", "")
+ assert persistor_id
+ assert "locator_id" not in recipe.job.comp_ids
+ server_app = recipe.job._deploy_map.get(SERVER_SITE_NAME)
+ assert server_app is not None
+ assert persistor_id in server_app.app_config.components
+
+ def test_custom_model_persistor_with_locator_registers_locator(
+ self, mock_file_system, base_recipe_params, simple_model
+ ):
+ """If custom model_locator is provided, it should be registered even on custom persistor path."""
+ locator = DummyLocator()
+ recipe = FedAvgRecipe(
+ name="test_custom_persistor_with_locator",
+ model=simple_model,
+ model_persistor=DummyPersistor(),
+ model_locator=locator,
+ **base_recipe_params,
+ )
+
+ assert recipe.job.comp_ids.get("persistor_id", "")
+ locator_id = recipe.job.comp_ids.get("locator_id", "")
+ assert locator_id
+ server_app = recipe.job._deploy_map.get(SERVER_SITE_NAME)
+ assert server_app is not None
+ assert server_app.app_config.components.get(locator_id) is locator
+
def test_dict_config_missing_path_raises_error(self, mock_file_system, base_recipe_params):
"""Test that dict config without 'class_path' key raises error."""
with pytest.raises(ValueError, match="must have 'class_path' key"):
@@ -452,7 +549,7 @@ def test_initial_ckpt_with_none_model_not_allowed_for_pt(self, mock_file_system,
"""Test that PT FedAvg rejects initial_ckpt with None model (PT needs architecture)."""
# PyTorch requires model architecture even when loading from checkpoint
# TensorFlow can load full models, but PT cannot
- with pytest.raises(ValueError, match="Unable to add None to job"):
+ with pytest.raises(ValueError, match="FrameworkType.PYTORCH requires 'model' when using initial_ckpt"):
FedAvgRecipe(
name="test_ckpt_no_model",
model=None,
diff --git a/tests/unit_test/recipe/poc_env_test.py b/tests/unit_test/recipe/poc_env_test.py
index b710a7e1e4..ee0656cb8a 100644
--- a/tests/unit_test/recipe/poc_env_test.py
+++ b/tests/unit_test/recipe/poc_env_test.py
@@ -62,6 +62,12 @@ def test_poc_env_validation():
PocEnv(num_clients=3, clients=["site1", "site2"])
+def test_poc_env_none_num_clients_raises():
+ """Test that PocEnv(num_clients=None) raises ValueError instead of crashing with TypeError."""
+ with pytest.raises(ValueError, match="num_clients must be greater than 0"):
+ PocEnv(num_clients=None, clients=None)
+
+
def test_poc_env_client_names():
"""Test PocEnv client name generation and validation."""
# Test auto-generated client names (delegated to prepare_poc_provision)
diff --git a/tests/unit_test/recipe/sim_env_test.py b/tests/unit_test/recipe/sim_env_test.py
index d1758b3819..276a0cfcb0 100644
--- a/tests/unit_test/recipe/sim_env_test.py
+++ b/tests/unit_test/recipe/sim_env_test.py
@@ -34,3 +34,9 @@ def test_sim_env_validation():
# Test with empty clients list and zero num_clients (invalid)
with pytest.raises(ValueError, match="Either 'num_clients' must be > 0 or 'clients' list must be provided"):
SimEnv(num_clients=0, clients=[])
+
+ # BUG-3 regression: when clients list is provided and num_clients=0,
+ # env should derive client/thread counts from the list.
+ env = SimEnv(num_clients=0, clients=["client1", "client2", "client3"])
+ assert env.num_clients == 3
+ assert env.num_threads == 3
diff --git a/tests/unit_test/recipe/spec_test.py b/tests/unit_test/recipe/spec_test.py
index 4904b9d246..13dbb4c15b 100644
--- a/tests/unit_test/recipe/spec_test.py
+++ b/tests/unit_test/recipe/spec_test.py
@@ -23,11 +23,11 @@
from nvflare.job_config.api import FedApp, FedJob
from nvflare.job_config.fed_app_config import ClientAppConfig
-from nvflare.recipe.utils import _collect_non_local_scripts
+from nvflare.recipe.utils import collect_non_local_scripts
class TestCollectNonLocalScriptsUtility:
- """Test the _collect_non_local_scripts utility function."""
+ """Test the collect_non_local_scripts utility function."""
def setup_method(self):
self.job = FedJob(name="test_job", min_clients=1)
@@ -36,7 +36,7 @@ def setup_method(self):
def test_no_scripts_returns_empty_list(self):
"""Test with no scripts added."""
- result = _collect_non_local_scripts(self.job)
+ result = collect_non_local_scripts(self.job)
assert result == []
def test_local_script_not_included(self):
@@ -47,7 +47,7 @@ def test_local_script_not_included(self):
try:
self.client_app.add_external_script(temp_path)
- result = _collect_non_local_scripts(self.job)
+ result = collect_non_local_scripts(self.job)
assert result == []
finally:
os.unlink(temp_path)
@@ -57,7 +57,7 @@ def test_non_local_absolute_path_included(self):
non_local_script = "/preinstalled/remote_script.py"
self.client_app.add_external_script(non_local_script)
- result = _collect_non_local_scripts(self.job)
+ result = collect_non_local_scripts(self.job)
assert non_local_script in result
def test_multiple_non_local_scripts(self):
@@ -70,7 +70,7 @@ def test_multiple_non_local_scripts(self):
for script in scripts:
self.client_app.add_external_script(script)
- result = _collect_non_local_scripts(self.job)
+ result = collect_non_local_scripts(self.job)
for script in scripts:
assert script in result
@@ -86,7 +86,7 @@ def test_mixed_local_and_non_local_scripts(self):
non_local_script = "/preinstalled/remote_script.py"
self.client_app.add_external_script(non_local_script)
- result = _collect_non_local_scripts(self.job)
+ result = collect_non_local_scripts(self.job)
assert non_local_script in result
assert local_script not in result
finally:
@@ -102,7 +102,7 @@ def test_multiple_apps(self):
self.client_app.add_external_script(script1)
client_app2.add_external_script(script2)
- result = _collect_non_local_scripts(self.job)
+ result = collect_non_local_scripts(self.job)
assert script1 in result
assert script2 in result
@@ -171,7 +171,6 @@ def temp_script(self):
def test_add_server_config(self, temp_script):
"""Test add_server_config adds params to server app."""
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -179,8 +178,7 @@ def test_add_server_config(self, temp_script):
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY, # NUMPY can load from ckpt without model
+ model={"class_path": "model.DummyModel", "args": {}},
)
config = {"np_download_chunk_size": 2097152}
@@ -193,7 +191,6 @@ def test_add_server_config(self, temp_script):
def test_add_client_config(self, temp_script):
"""Test add_client_config applies to all clients and specific clients."""
from nvflare.apis.job_def import ALL_SITES
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
# Test all clients
@@ -202,8 +199,7 @@ def test_add_client_config(self, temp_script):
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
)
config = {"timeout": 600}
recipe.add_client_config(config)
@@ -215,7 +211,6 @@ def test_add_client_config(self, temp_script):
def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script):
"""Test add_client_file stores file paths in ext_scripts and dirs in ext_dirs."""
from nvflare.apis.job_def import ALL_SITES
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -223,8 +218,7 @@ def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script):
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
)
with tempfile.TemporaryDirectory() as temp_dir:
@@ -239,7 +233,6 @@ def test_add_client_file_adds_to_ext_scripts_and_ext_dirs(self, temp_script):
def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp_script):
"""Test add_client_file keeps per-site topology and does not create ALL_SITES app."""
from nvflare.apis.job_def import ALL_SITES
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -247,8 +240,7 @@ def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
per_site_config={"site-1": {}, "site-2": {}},
)
@@ -264,7 +256,6 @@ def test_add_client_file_preserves_per_site_clients_without_all_sites(self, temp
def test_add_client_file_with_specific_clients_only_updates_selected_sites(self, temp_script):
"""Test add_client_file(..., clients=[...]) only adds file to specified sites."""
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -272,8 +263,7 @@ def test_add_client_file_with_specific_clients_only_updates_selected_sites(self,
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
per_site_config={"site-1": {}, "site-2": {}, "site-3": {}},
)
@@ -298,7 +288,6 @@ def test_add_client_file_with_specific_clients_only_updates_selected_sites(self,
def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_script):
"""Test add_server_file stores file paths in ext_scripts and dirs in ext_dirs."""
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -306,8 +295,7 @@ def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_scri
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
)
with tempfile.TemporaryDirectory() as temp_dir:
@@ -321,7 +309,6 @@ def test_add_server_file_adds_to_server_ext_scripts_and_ext_dirs(self, temp_scri
def test_config_in_generated_json(self, temp_script):
"""Test that configs appear in generated JSON files."""
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -329,8 +316,7 @@ def test_config_in_generated_json(self, temp_script):
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
)
recipe.add_server_config({"server_param": 123})
@@ -352,7 +338,6 @@ def test_config_in_generated_json(self, temp_script):
def test_config_type_error(self, temp_script):
"""Test TypeError is raised for non-dict arguments."""
- from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe
recipe = FedAvgRecipe(
@@ -360,8 +345,7 @@ def test_config_type_error(self, temp_script):
num_rounds=2,
min_clients=2,
train_script=temp_script,
- initial_ckpt="/abs/path/to/model.npy",
- framework=FrameworkType.NUMPY,
+ model={"class_path": "model.DummyModel", "args": {}},
)
with pytest.raises(TypeError, match="config must be a dict"):
@@ -369,3 +353,85 @@ def test_config_type_error(self, temp_script):
with pytest.raises(TypeError, match="config must be a dict"):
recipe.add_client_config(123) # type: ignore[arg-type]
+
+
+class _DummyExecEnv:
+ def __init__(self):
+ self.extra = {}
+
+ def get_extra_prop(self, prop_name, default=None):
+ return self.extra.get(prop_name, default)
+
+ def deploy(self, job):
+ return "dummy-job-id"
+
+ def get_job_status(self, job_id):
+ return None
+
+ def abort_job(self, job_id):
+ return None
+
+ def get_job_result(self, job_id, timeout: float = 0.0):
+ return None
+
+
+class TestRecipeExecuteExportParamIsolation:
+ """Test that execute/export do not permanently mutate recipe additional_params."""
+
+ @pytest.fixture
+ def temp_script(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
+ f.write("# Test training script\n")
+ temp_path = f.name
+ yield temp_path
+ os.unlink(temp_path)
+
+ def test_execute_server_params_do_not_accumulate(self, temp_script):
+ from nvflare.recipe.fedavg import FedAvgRecipe
+
+ recipe = FedAvgRecipe(
+ name="test_execute_param_isolation",
+ num_rounds=2,
+ min_clients=2,
+ train_script=temp_script,
+ model={"class_path": "model.DummyModel", "args": {}},
+ )
+
+ env = _DummyExecEnv()
+ server_app = recipe.job._deploy_map.get("server")
+ assert server_app is not None
+ assert server_app.app_config.additional_params == {}
+ recipe.execute(env, server_exec_params={"param_a": 1})
+ assert server_app.app_config.additional_params == {}
+
+ recipe.execute(env, server_exec_params={"param_b": 2})
+ assert server_app.app_config.additional_params == {}
+
+ def test_execute_then_export_no_cross_contamination(self, temp_script):
+ from nvflare.recipe.fedavg import FedAvgRecipe
+
+ recipe = FedAvgRecipe(
+ name="test_execute_export_isolation",
+ num_rounds=2,
+ min_clients=2,
+ train_script=temp_script,
+ model={"class_path": "model.DummyModel", "args": {}},
+ )
+
+ env = _DummyExecEnv()
+ recipe.execute(env, server_exec_params={"from_execute": 1})
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ recipe.export(job_dir=tmpdir, server_exec_params={"from_export": 2})
+ server_cfg_path = os.path.join(
+ tmpdir, "test_execute_export_isolation", "app", "config", "config_fed_server.json"
+ )
+ with open(server_cfg_path) as f:
+ server_cfg = json.load(f)
+
+ assert "from_execute" not in server_cfg
+ assert server_cfg.get("from_export") == 2
+
+ server_app = recipe.job._deploy_map.get("server")
+ assert server_app is not None
+ assert server_app.app_config.additional_params == {}
diff --git a/tests/unit_test/recipe/utils_test.py b/tests/unit_test/recipe/utils_test.py
index e1b69a9f66..19f525de15 100644
--- a/tests/unit_test/recipe/utils_test.py
+++ b/tests/unit_test/recipe/utils_test.py
@@ -18,7 +18,13 @@
import pytest
-from nvflare.recipe.utils import prepare_initial_ckpt, validate_ckpt
+from nvflare.recipe.utils import (
+ extract_persistor_id,
+ prepare_initial_ckpt,
+ resolve_initial_ckpt,
+ setup_custom_persistor,
+ validate_ckpt,
+)
@pytest.fixture
@@ -149,3 +155,116 @@ def test_multiple_calls_different_files(self, temp_workdir):
assert result2 == "ckpt2.pt"
assert job.add_file_to_server.call_count == 2
+
+
+class TestPersistorUtils:
+ """Tests for persistor utility helpers."""
+
+ def test_extract_persistor_id(self):
+ assert extract_persistor_id({"persistor_id": "persistor_a"}) == "persistor_a"
+ assert extract_persistor_id({"persistor_id": 123}) == ""
+ assert extract_persistor_id("persistor_b") == "persistor_b"
+ assert extract_persistor_id(None) == ""
+
+ def test_setup_custom_persistor_returns_empty_when_not_provided(self):
+ job = MagicMock()
+
+ result = setup_custom_persistor(job=job, model_persistor=None)
+
+ assert result == ""
+ job.to_server.assert_not_called()
+
+ def test_setup_custom_persistor_registers_component(self):
+ job = MagicMock()
+ custom_persistor = object()
+ job.to_server.return_value = "custom_persistor"
+
+ result = setup_custom_persistor(job=job, model_persistor=custom_persistor)
+
+ assert result == "custom_persistor"
+ job.to_server.assert_called_once_with(custom_persistor, id="persistor")
+
+ def test_setup_custom_persistor_extracts_dict_result(self):
+ job = MagicMock()
+ custom_persistor = object()
+ job.to_server.return_value = {"persistor_id": "custom_from_dict"}
+
+ result = setup_custom_persistor(job=job, model_persistor=custom_persistor)
+
+ assert result == "custom_from_dict"
+
+ def test_resolve_initial_ckpt_prefers_prepared_value(self):
+ job = MagicMock()
+
+ result = resolve_initial_ckpt(
+ initial_ckpt="relative/path/model.pt",
+ prepared_initial_ckpt="already_prepared.pt",
+ job=job,
+ )
+
+ assert result == "already_prepared.pt"
+ job.add_file_to_server.assert_not_called()
+
+ def test_resolve_initial_ckpt_uses_prepare_when_prepared_missing(self, monkeypatch):
+ calls = {}
+
+ def fake_prepare(initial_ckpt, job):
+ calls["initial_ckpt"] = initial_ckpt
+ calls["job"] = job
+ return "prepared_by_helper.pt"
+
+ monkeypatch.setattr("nvflare.recipe.utils.prepare_initial_ckpt", fake_prepare)
+ job = MagicMock()
+
+ result = resolve_initial_ckpt(
+ initial_ckpt="relative/path/model.pt",
+ prepared_initial_ckpt=None,
+ job=job,
+ )
+
+ assert result == "prepared_by_helper.pt"
+ assert calls["initial_ckpt"] == "relative/path/model.pt"
+ assert calls["job"] is job
+
+
+class TestRecipePackageExports:
+ """Tests for public API exports from nvflare.recipe."""
+
+ def test_add_cross_site_evaluation_importable_from_recipe(self):
+ """add_cross_site_evaluation must be importable from the top-level nvflare.recipe package."""
+ from nvflare.recipe import add_cross_site_evaluation
+
+ assert callable(add_cross_site_evaluation)
+
+
+class TestCrossSiteEvalIdempotency:
+ """Tests for resilient idempotency in add_cross_site_evaluation."""
+
+ def test_idempotency_survives_missing_flag(self):
+ from nvflare.app_common.np.recipes.fedavg import NumpyFedAvgRecipe
+ from nvflare.recipe.utils import add_cross_site_evaluation
+
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
+ f.write("# dummy train script\n")
+ train_script = f.name
+
+ try:
+ recipe = NumpyFedAvgRecipe(
+ name="test_cse_idempotency",
+ model=[1.0, 2.0],
+ min_clients=2,
+ num_rounds=2,
+ train_script=train_script,
+ )
+
+ add_cross_site_evaluation(recipe)
+ assert getattr(recipe, "_cse_added", False) is True
+
+ # Simulate transient loss of the guard attribute (e.g. across a serialization boundary).
+ del recipe._cse_added
+ assert not hasattr(recipe, "_cse_added")
+
+ with pytest.raises(RuntimeError, match="already been added"):
+ add_cross_site_evaluation(recipe)
+ finally:
+ os.unlink(train_script)