1 change: 1 addition & 0 deletions examples/hello-world/hello-cyclic/job.py
@@ -24,6 +24,7 @@

recipe = CyclicRecipe(
num_rounds=num_rounds,
min_clients=n_clients,
# Model can be specified as class instance or dict config:
model=Net(),
# Alternative: model={"class_path": "model.Net", "args": {}},
4 changes: 2 additions & 2 deletions examples/hello-world/hello-numpy/job.py
@@ -17,8 +17,8 @@
"""
import argparse

from nvflare.apis.dxo import DataKind
from nvflare.app_common.np.recipes.fedavg import NumpyFedAvgRecipe
from nvflare.client.config import TransferType
from nvflare.recipe import SimEnv, add_experiment_tracking


@@ -57,7 +57,7 @@ def main():
train_script="client.py",
train_args=train_args,
launch_external_process=launch_process,
aggregator_data_kind=DataKind.WEIGHTS if args.update_type == "full" else DataKind.WEIGHT_DIFF,
params_transfer_type=TransferType.FULL if args.update_type == "full" else TransferType.DIFF,
)
add_experiment_tracking(recipe, tracking_type="tensorboard")
if args.export_config:
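The hello-numpy hunk above swaps the removed `aggregator_data_kind` flag for `params_transfer_type`. A minimal sketch of the same full-vs-diff selection, using a stand-in enum (the real `TransferType` lives in `nvflare.client.config` and its members may differ):

```python
from enum import Enum

# Stand-in for nvflare.client.config.TransferType; illustrative only.
class TransferType(str, Enum):
    FULL = "FULL"  # send complete model weights each round
    DIFF = "DIFF"  # send only weight differences

def select_transfer_type(update_type: str) -> TransferType:
    # Mirrors the replaced line: "full" means full weights, anything else diffs.
    return TransferType.FULL if update_type == "full" else TransferType.DIFF
```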
62 changes: 4 additions & 58 deletions examples/tutorials/job_recipe.ipynb
@@ -141,29 +141,7 @@
"cell_type": "markdown",
"id": "b94e258c",
"metadata": {},
"source": [
"## Pre-trained Checkpoint Path\n",
"\n",
"Use `initial_ckpt` to specify a path to pre-trained model weights:\n",
"\n",
"```python\n",
"recipe = FedAvgRecipe(\n",
" model=SimpleNetwork(),\n",
" initial_ckpt=\"/data/models/pretrained_model.pt\", # Absolute path\n",
" ...\n",
")\n",
"```\n",
"\n",
"<div class=\"alert alert-block alert-warning\">\n",
"<b>Checkpoint Path Requirements:</b>\n",
"<ul>\n",
"<li><b>Absolute path required:</b> The path must be an absolute path (e.g., `/data/models/model.pt`), not relative.</li>\n",
"<li><b>May not exist locally:</b> The checkpoint file does <b>not</b> need to exist on the machine where you create the recipe. It only needs to exist on the <b>server</b> when the model is actually loaded during job execution.</li>\n",
"<li><b>PyTorch requires model architecture:</b> For PyTorch, you must provide `model` (class instance or dict config) along with `initial_ckpt`, because PyTorch checkpoints contain only weights, not architecture.</li>\n",
"<li><b>TensorFlow/Keras can use checkpoint alone:</b> Keras `.h5` or SavedModel formats contain both architecture and weights, so `initial_ckpt` can be used without `model`.</li>\n",
"</ul>\n",
"</div>"
]
"source": "## Pre-trained Checkpoint Path\n\nUse `initial_ckpt` to specify a path to pre-trained model weights:\n\n```python\nrecipe = FedAvgRecipe(\n model=SimpleNetwork(),\n initial_ckpt=\"/data/models/pretrained_model.pt\", # Absolute path (server-side)\n ...\n)\n```\n\n<div class=\"alert alert-block alert-warning\">\n<b>Checkpoint Path Options:</b>\n<ul>\n<li><b>Relative path:</b> The file is bundled into the job automatically and deployed to the server. The file must exist locally when the recipe is created.</li>\n<li><b>Absolute path:</b> Treated as a server-side path and used as-is at runtime. The file does <b>not</b> need to exist locally — it only needs to exist on the <b>server</b> when the job runs.</li>\n<li><b>PyTorch requires model architecture:</b> For PyTorch, you must provide `model` (class instance or dict config) along with `initial_ckpt`, because PyTorch checkpoints contain only weights, not architecture.</li>\n<li><b>TensorFlow/Keras can use checkpoint alone:</b> Keras `.h5` or SavedModel formats contain both architecture and weights, so `initial_ckpt` can be used without `model`.</li>\n</ul>\n</div>"
},
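The checkpoint-path rules in the rewritten cell can be sketched as a small resolver. The function name, the `bundle` list, and the error message are illustrative assumptions, not NVFlare API:

```python
import os
from typing import List, Optional

def resolve_ckpt(path: Optional[str], bundle: List[str]) -> Optional[str]:
    """Apply the relative-vs-absolute rule described above (sketch)."""
    if not path:
        return None
    if os.path.isabs(path):
        # Server-side path: used as-is at runtime, need not exist locally.
        return path
    # Relative path: must exist locally so it can be packed into the job.
    if not os.path.exists(path):
        raise FileNotFoundError(f"relative initial_ckpt not found locally: {path}")
    bundle.append(path)
    return path
```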
{
"cell_type": "code",
@@ -256,25 +234,7 @@
"cell_type": "markdown",
"id": "pocenv-section",
"metadata": {},
"source": [
"### PocEnv – Proof-of-Concept Environment\n",
"\n",
"Runs server and clients as **separate processes** on the same machine. This simulates real-world deployment within a single node, with server and clients are running in different processes. More realistic than `SimEnv`, but still lightweight enough for a single node.\n",
"\n",
"Best suited for:\n",
"* Demonstrations\n",
"* Small-scale validation before production deployment\n",
"* Debugging orchestration logic\n",
"\n",
"**Arguments:**\n",
"* `num_clients` (int, optional): Number of clients to use in POC mode. Defaults to 2.\n",
"* `clients` (List[str], optional): List of client names. If None, will generate site-1, site-2, etc.\n",
"* `gpu_ids` (List[int], optional): List of GPU IDs to assign to clients. If None, uses CPU only.\n",
"* `auto_stop` (bool, optional): Whether to automatically stop POC services after job completion.\n",
"* `use_he` (bool, optional): Whether to use HE. Defaults to False.\n",
"* `docker_image` (str, optional): Docker image to use for POC.\n",
"* `project_conf_path` (str, optional): Path to the project configuration file."
]
"source": "### PocEnv – Proof-of-Concept Environment\n\nRuns server and clients as **separate processes** on the same machine. This simulates real-world deployment within a single node, with server and clients running in different processes. More realistic than `SimEnv`, but still lightweight enough for a single node.\n\nBest suited for:\n* Demonstrations\n* Small-scale validation before production deployment\n* Debugging orchestration logic\n\n**Arguments:**\n* `num_clients` (int, optional): Number of clients to use in POC mode. Defaults to 2.\n* `clients` (List[str], optional): List of client names. If None, will generate site-1, site-2, etc.\n* `gpu_ids` (List[int], optional): List of GPU IDs to assign to clients. If None, uses CPU only.\n* `use_he` (bool, optional): Whether to use homomorphic encryption (HE). Defaults to False.\n* `docker_image` (str, optional): Docker image to use for POC.\n* `project_conf_path` (str, optional): Path to the project configuration file."
},
{
"cell_type": "markdown",
@@ -331,21 +291,7 @@
"cell_type": "markdown",
"id": "prodenv-section",
"metadata": {},
"source": [
"### ProdEnv – Production Environment\n",
"\n",
"We assume the system of a server and clients is up and running across **multiple machines and sites**. Uses secure communication channels and real-world NVFLARE deployment infrastructure. The ProdEnv will utilize the admin's startup package to communicate with an existing NVFlare system to execute and monitor that job execution.\n",
"\n",
"Best suited for:\n",
"* Enterprise federated learning deployments\n",
"* Multi-institution collaborations\n",
"* Production-scale workloads\n",
"\n",
"**Arguments:**\n",
"* `startup_kit_location` (str): the directory that contains the startup kit of the admin (generated by nvflare provisioning)\n",
"* `login_timeout` (float): timeout value for the admin to login to the system\n",
"* `monitor_job_duration` (int): duration to monitor the job execution, None means no monitoring at all"
]
"source": "### ProdEnv – Production Environment\n\nAssumes a system of a server and clients is already up and running across **multiple machines and sites**. Uses secure communication channels and real-world NVFLARE deployment infrastructure. ProdEnv uses the admin's startup package to communicate with the existing NVFlare system and to execute and monitor the job.\n\nBest suited for:\n* Enterprise federated learning deployments\n* Multi-institution collaborations\n* Production-scale workloads\n\n**Arguments:**\n* `startup_kit_location` (str): the directory that contains the admin's startup kit (generated by NVFlare provisioning)\n* `login_timeout` (float): timeout value for the admin to log in to the system"
},
{
"cell_type": "markdown",
@@ -522,4 +468,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
17 changes: 12 additions & 5 deletions nvflare/app_opt/pt/recipes/cyclic.py
@@ -18,6 +18,7 @@
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe
from nvflare.recipe.utils import extract_persistor_id


class CyclicRecipe(BaseCyclicRecipe):
@@ -103,11 +104,17 @@ def _setup_model_and_persistor(self, job) -> str:
# If model is already a PTModel wrapper (user passed PTModel directly), use as-is
if hasattr(self.model, "add_to_fed_job"):
result = job.to_server(self.model, id="persistor")
return result["persistor_id"]
return extract_persistor_id(result)

from nvflare.recipe.utils import prepare_initial_ckpt
from nvflare.recipe.utils import resolve_initial_ckpt

ckpt_path = prepare_initial_ckpt(self._pt_initial_ckpt, job)
pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path)
ckpt_path = resolve_initial_ckpt(self._pt_initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
if self.model is None and ckpt_path:
raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.")
if self.model is None:
return ""

allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH
pt_model = PTModel(model=self.model, initial_ckpt=ckpt_path, allow_numpy_conversion=allow_numpy_conversion)
result = job.to_server(pt_model, id="persistor")
return result["persistor_id"]
return extract_persistor_id(result)
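The call sites above pass both the raw checkpoint path and a `_prepared_initial_ckpt` attribute to `resolve_initial_ckpt`. A guess at its contract, inferred only from these call sites (the real helper in `nvflare.recipe.utils` may do more, such as path validation):

```python
from typing import Optional

def resolve_initial_ckpt(raw: Optional[str], prepared: Optional[str], job=None) -> Optional[str]:
    # A path already prepared (bundled) during recipe setup wins over the raw one.
    # `job` is accepted for signature parity with the call sites; unused here.
    return prepared or raw or None
```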
52 changes: 34 additions & 18 deletions nvflare/app_opt/pt/recipes/fedavg.py
@@ -162,21 +162,37 @@ def __init__(

def _setup_model_and_persistor(self, job) -> str:
"""Override to handle PyTorch-specific model setup."""
if self.model is not None or self.initial_ckpt is not None:
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.recipe.utils import prepare_initial_ckpt

# Disable numpy conversion when using tensor format to keep PyTorch tensors
allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH

ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
pt_model = PTModel(
model=self.model,
initial_ckpt=ckpt_path,
persistor=self.model_persistor,
locator=self._pt_model_locator,
allow_numpy_conversion=allow_numpy_conversion,
)
job.comp_ids.update(job.to_server(pt_model))
return job.comp_ids.get("persistor_id", "")
return ""
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor

persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor)
if persistor_id:
if hasattr(job, "comp_ids"):
job.comp_ids["persistor_id"] = persistor_id
if self._pt_model_locator is not None:
locator_id = job.to_server(self._pt_model_locator, id="locator")
if isinstance(locator_id, str) and locator_id:
job.comp_ids["locator_id"] = locator_id
Review comment (Contributor) on lines +172 to +175:
locator_id silently dropped when job.to_server returns a dict

job.to_server(self._pt_model_locator, id="locator") can return a dict when the locator implements add_to_fed_job (e.g., a future custom PTFileModelLocator subclass or any locator wrapper). The guard isinstance(locator_id, str) and locator_id would then be False, so comp_ids["locator_id"] is never populated.

add_cross_site_evaluation later reads recipe.job.comp_ids.get("locator_id", "") (via locator_config), and a missing key leads to the locator not being wired up for cross-site evaluation without any error being raised — a silent failure.

Applying extract_persistor_id (or an equivalent extract_component_id helper) here would make the handling consistent with how persistor_id is resolved elsewhere in the same method:

if self._pt_model_locator is not None:
    raw = job.to_server(self._pt_model_locator, id="locator")
    locator_id = raw if isinstance(raw, str) and raw else (
        raw.get("locator_id", "") if isinstance(raw, dict) else ""
    )
    if locator_id:
        job.comp_ids["locator_id"] = locator_id

return persistor_id

ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
if self.model is None and ckpt_path:
raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.")
if self.model is None:
return ""

# Disable numpy conversion when using tensor format to keep PyTorch tensors.
allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH
pt_model = PTModel(
model=self.model,
initial_ckpt=ckpt_path,
locator=self._pt_model_locator,
allow_numpy_conversion=allow_numpy_conversion,
)
result = job.to_server(pt_model, id="persistor")
if isinstance(result, dict) and hasattr(job, "comp_ids"):
job.comp_ids.update(result)
persistor_id = extract_persistor_id(result)
if persistor_id and hasattr(job, "comp_ids"):
job.comp_ids.setdefault("persistor_id", persistor_id)
return persistor_id
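The review comment above suggests resolving component ids uniformly. A sketch of the proposed `extract_component_id` helper — the name and behavior follow the reviewer's snippet, not an existing NVFlare function:

```python
from typing import Union

def extract_component_id(result: Union[str, dict, None], key: str) -> str:
    """Resolve a component id whether job.to_server returned a string or a dict of ids."""
    if isinstance(result, str) and result:
        return result
    if isinstance(result, dict):
        return result.get(key, "")
    return ""
```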
9 changes: 5 additions & 4 deletions nvflare/app_opt/tf/recipes/cyclic.py
@@ -18,6 +18,7 @@
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.cyclic import CyclicRecipe as BaseCyclicRecipe
from nvflare.recipe.utils import extract_persistor_id


class CyclicRecipe(BaseCyclicRecipe):
@@ -103,11 +104,11 @@ def _setup_model_and_persistor(self, job) -> str:
# If model is already a TFModel wrapper (user passed TFModel directly), use as-is
if hasattr(self.model, "add_to_fed_job"):
result = job.to_server(self.model, id="persistor")
return result["persistor_id"]
return extract_persistor_id(result)

from nvflare.recipe.utils import prepare_initial_ckpt
from nvflare.recipe.utils import resolve_initial_ckpt

ckpt_path = prepare_initial_ckpt(self._tf_initial_ckpt, job)
ckpt_path = resolve_initial_ckpt(self._tf_initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
tf_model = TFModel(model=self.model, initial_ckpt=ckpt_path)
result = job.to_server(tf_model, id="persistor")
return result["persistor_id"]
return extract_persistor_id(result)
26 changes: 13 additions & 13 deletions nvflare/app_opt/tf/recipes/fedavg.py
@@ -145,16 +145,16 @@ def __init__(

def _setup_model_and_persistor(self, job) -> str:
"""Override to handle TensorFlow-specific model setup."""
if self.model is not None or self.initial_ckpt is not None:
from nvflare.app_opt.tf.job_config.model import TFModel
from nvflare.recipe.utils import prepare_initial_ckpt

ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)
tf_model = TFModel(
model=self.model,
initial_ckpt=ckpt_path,
persistor=self.model_persistor,
)
job.comp_ids["persistor_id"] = job.to_server(tf_model)
return job.comp_ids.get("persistor_id", "")
return ""
from nvflare.app_opt.tf.job_config.model import TFModel
from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor

persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor)
if persistor_id:
return persistor_id

ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
if self.model is None and not ckpt_path:
return ""

tf_model = TFModel(model=self.model, initial_ckpt=ckpt_path)
return extract_persistor_id(job.to_server(tf_model, id="persistor"))
4 changes: 2 additions & 2 deletions nvflare/recipe/__init__.py
@@ -17,6 +17,6 @@
from .prod_env import ProdEnv
from .run import Run
from .sim_env import SimEnv
from .utils import add_experiment_tracking
from .utils import add_cross_site_evaluation, add_experiment_tracking

__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking", "FedAvgRecipe"]
__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking", "add_cross_site_evaluation", "FedAvgRecipe"]
48 changes: 33 additions & 15 deletions nvflare/recipe/cyclic.py
@@ -141,6 +141,7 @@ def __init__(
if isinstance(self.model, dict):
self.model = recipe_model_to_job_model(self.model)

self.min_clients = v.min_clients
self.num_rounds = v.num_rounds
self.train_script = v.train_script
self.train_args = v.train_args
@@ -162,13 +163,16 @@ def __init__(
# Setup model persistor first - subclasses override for framework-specific handling
persistor_id = self._setup_model_and_persistor(job)

# Use returned persistor_id or default to "persistor"
if not persistor_id:
persistor_id = "persistor"
raise ValueError(
"Unable to configure a model persistor for CyclicRecipe. "
"Provide a supported model/framework combination (PyTorch or TensorFlow), "
"or pass a framework-specific model wrapper with add_to_fed_job()."
)

# Define the controller workflow and send to server
controller = CyclicController(
num_rounds=num_rounds,
num_rounds=self.num_rounds,
task_assignment_timeout=10,
persistor_id=persistor_id,
shareable_generator_id="shareable_generator",
@@ -194,25 +198,39 @@ def __init__(
super().__init__(job)

def _setup_model_and_persistor(self, job) -> str:
"""Setup framework-specific model components and persistor.
"""Setup model wrapper persistor.

Handles PTModel/TFModel wrappers passed by framework-specific subclasses.
Handles the following model inputs:
- framework-specific wrappers (objects with ``add_to_fed_job``)

Returns:
str: The persistor_id to be used by the controller.
"""
if self.model is None:
return ""
from nvflare.recipe.utils import extract_persistor_id, prepare_initial_ckpt

# Check if model is a model wrapper (PTModel, TFModel)
ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job)

# Model wrapper path (PTModel/TFModel or custom wrapper)
if hasattr(self.model, "add_to_fed_job"):
# It's a model wrapper - use its add_to_fed_job method
if ckpt_path:
if not hasattr(self.model, "initial_ckpt"):
raise ValueError(
f"initial_ckpt is provided, but model wrapper {type(self.model).__name__} "
"does not support 'initial_ckpt'."
)
existing_ckpt = getattr(self.model, "initial_ckpt", None)
if existing_ckpt and existing_ckpt != ckpt_path:
raise ValueError(
f"Conflicting checkpoint values: model wrapper has initial_ckpt={existing_ckpt}, "
f"but recipe initial_ckpt={ckpt_path}."
)
setattr(self.model, "initial_ckpt", ckpt_path)
Review comment (Contributor):
In-place mutation of a caller-provided model wrapper

setattr(self.model, "initial_ckpt", ckpt_path) modifies the wrapper object that the caller passed in. If the same PTModel / TFModel instance is passed to two different recipes (or to execute + export in succession), the second call will observe the initial_ckpt that was injected by the first, which may not be what the caller intended.

Consider either copying the wrapper before mutating it, or documenting clearly that the recipe takes ownership of the wrapper and callers must not reuse the same instance:

import copy
model_wrapper = copy.copy(self.model)
setattr(model_wrapper, "initial_ckpt", ckpt_path)
result = job.to_server(model_wrapper, id="persistor")
return extract_persistor_id(result)

A shallow copy is sufficient because only the initial_ckpt attribute is changed.


result = job.to_server(self.model, id="persistor")
return result["persistor_id"]
return extract_persistor_id(result)

# Unknown model type
raise TypeError(
f"Unsupported model type: {type(self.model).__name__}. "
f"Use a framework-specific recipe (PTCyclicRecipe, TFCyclicRecipe, etc.) "
f"or wrap your model in PTModel/TFModel."
raise ValueError(
f"Unsupported framework '{self.framework}' for base CyclicRecipe model persistence. "
"Use a framework-specific CyclicRecipe subclass, or pass a framework-specific "
"model wrapper with add_to_fed_job()."
)
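The copy-before-mutate fix proposed in the review comment above can be demonstrated with a stand-in wrapper class (`PTModel`/`TFModel` are the real classes this models):

```python
import copy

# Minimal stand-in for a model wrapper such as PTModel/TFModel.
class ModelWrapper:
    def __init__(self, model, initial_ckpt=None):
        self.model = model
        self.initial_ckpt = initial_ckpt

def with_initial_ckpt(wrapper: ModelWrapper, ckpt_path: str) -> ModelWrapper:
    # Shallow copy suffices: only the initial_ckpt attribute is replaced,
    # so the caller's original wrapper is left untouched.
    clone = copy.copy(wrapper)
    clone.initial_ckpt = ckpt_path
    return clone
```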