-
Notifications
You must be signed in to change notification settings - Fork 242
Cherry-pick [2.7] Fix recipe API bug list and harden recipe behavior (#4228) #4293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,21 +162,37 @@ def __init__( | |
|
|
||
| def _setup_model_and_persistor(self, job) -> str: | ||
| """Override to handle PyTorch-specific model setup.""" | ||
| if self.model is not None or self.initial_ckpt is not None: | ||
| from nvflare.app_opt.pt.job_config.model import PTModel | ||
| from nvflare.recipe.utils import prepare_initial_ckpt | ||
|
|
||
| # Disable numpy conversion when using tensor format to keep PyTorch tensors | ||
| allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH | ||
|
|
||
| ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job) | ||
| pt_model = PTModel( | ||
| model=self.model, | ||
| initial_ckpt=ckpt_path, | ||
| persistor=self.model_persistor, | ||
| locator=self._pt_model_locator, | ||
| allow_numpy_conversion=allow_numpy_conversion, | ||
| ) | ||
| job.comp_ids.update(job.to_server(pt_model)) | ||
| return job.comp_ids.get("persistor_id", "") | ||
| return "" | ||
| from nvflare.app_opt.pt.job_config.model import PTModel | ||
| from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor | ||
|
|
||
| persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor) | ||
| if persistor_id: | ||
| if hasattr(job, "comp_ids"): | ||
| job.comp_ids["persistor_id"] = persistor_id | ||
| if self._pt_model_locator is not None: | ||
| locator_id = job.to_server(self._pt_model_locator, id="locator") | ||
| if isinstance(locator_id, str) and locator_id: | ||
| job.comp_ids["locator_id"] = locator_id | ||
|
Comment on lines
+172
to
+175
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Applying if self._pt_model_locator is not None:
raw = job.to_server(self._pt_model_locator, id="locator")
locator_id = raw if isinstance(raw, str) and raw else (
raw.get("locator_id", "") if isinstance(raw, dict) else ""
)
if locator_id:
job.comp_ids["locator_id"] = locator_id |
||
| return persistor_id | ||
|
|
||
| ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job) | ||
| if self.model is None and ckpt_path: | ||
| raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.") | ||
| if self.model is None: | ||
| return "" | ||
|
|
||
| # Disable numpy conversion when using tensor format to keep PyTorch tensors. | ||
| allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH | ||
| pt_model = PTModel( | ||
| model=self.model, | ||
| initial_ckpt=ckpt_path, | ||
| locator=self._pt_model_locator, | ||
| allow_numpy_conversion=allow_numpy_conversion, | ||
| ) | ||
| result = job.to_server(pt_model, id="persistor") | ||
| if isinstance(result, dict) and hasattr(job, "comp_ids"): | ||
| job.comp_ids.update(result) | ||
| persistor_id = extract_persistor_id(result) | ||
| if persistor_id and hasattr(job, "comp_ids"): | ||
| job.comp_ids.setdefault("persistor_id", persistor_id) | ||
| return persistor_id | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -141,6 +141,7 @@ def __init__( | |
| if isinstance(self.model, dict): | ||
| self.model = recipe_model_to_job_model(self.model) | ||
|
|
||
| self.min_clients = v.min_clients | ||
| self.num_rounds = v.num_rounds | ||
| self.train_script = v.train_script | ||
| self.train_args = v.train_args | ||
|
|
@@ -162,13 +163,16 @@ def __init__( | |
| # Setup model persistor first - subclasses override for framework-specific handling | ||
| persistor_id = self._setup_model_and_persistor(job) | ||
|
|
||
| # Use returned persistor_id or default to "persistor" | ||
| if not persistor_id: | ||
| persistor_id = "persistor" | ||
| raise ValueError( | ||
| "Unable to configure a model persistor for CyclicRecipe. " | ||
| "Provide a supported model/framework combination (PyTorch or TensorFlow), " | ||
| "or pass a framework-specific model wrapper with add_to_fed_job()." | ||
| ) | ||
|
|
||
| # Define the controller workflow and send to server | ||
| controller = CyclicController( | ||
| num_rounds=num_rounds, | ||
| num_rounds=self.num_rounds, | ||
| task_assignment_timeout=10, | ||
| persistor_id=persistor_id, | ||
| shareable_generator_id="shareable_generator", | ||
|
|
@@ -194,25 +198,39 @@ def __init__( | |
| super().__init__(job) | ||
|
|
||
| def _setup_model_and_persistor(self, job) -> str: | ||
| """Setup framework-specific model components and persistor. | ||
| """Setup model wrapper persistor. | ||
|
|
||
| Handles PTModel/TFModel wrappers passed by framework-specific subclasses. | ||
| Handles the following model inputs: | ||
| - framework-specific wrappers (objects with ``add_to_fed_job``) | ||
|
|
||
| Returns: | ||
| str: The persistor_id to be used by the controller. | ||
| """ | ||
| if self.model is None: | ||
| return "" | ||
| from nvflare.recipe.utils import extract_persistor_id, prepare_initial_ckpt | ||
|
|
||
| # Check if model is a model wrapper (PTModel, TFModel) | ||
| ckpt_path = prepare_initial_ckpt(self.initial_ckpt, job) | ||
|
|
||
| # Model wrapper path (PTModel/TFModel or custom wrapper) | ||
| if hasattr(self.model, "add_to_fed_job"): | ||
| # It's a model wrapper - use its add_to_fed_job method | ||
| if ckpt_path: | ||
| if not hasattr(self.model, "initial_ckpt"): | ||
| raise ValueError( | ||
| f"initial_ckpt is provided, but model wrapper {type(self.model).__name__} " | ||
| "does not support 'initial_ckpt'." | ||
| ) | ||
| existing_ckpt = getattr(self.model, "initial_ckpt", None) | ||
| if existing_ckpt and existing_ckpt != ckpt_path: | ||
| raise ValueError( | ||
| f"Conflicting checkpoint values: model wrapper has initial_ckpt={existing_ckpt}, " | ||
| f"but recipe initial_ckpt={ckpt_path}." | ||
| ) | ||
| setattr(self.model, "initial_ckpt", ckpt_path) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In-place mutation of a caller-provided model wrapper
Consider either copying the wrapper before mutating it, or documenting clearly that the recipe takes ownership of the wrapper and callers must not reuse the same instance: import copy
model_wrapper = copy.copy(self.model)
setattr(model_wrapper, "initial_ckpt", ckpt_path)
result = job.to_server(model_wrapper, id="persistor")
return extract_persistor_id(result)A shallow copy is sufficient because only the |
||
|
|
||
| result = job.to_server(self.model, id="persistor") | ||
| return result["persistor_id"] | ||
| return extract_persistor_id(result) | ||
|
|
||
| # Unknown model type | ||
| raise TypeError( | ||
| f"Unsupported model type: {type(self.model).__name__}. " | ||
| f"Use a framework-specific recipe (PTCyclicRecipe, TFCyclicRecipe, etc.) " | ||
| f"or wrap your model in PTModel/TFModel." | ||
| raise ValueError( | ||
| f"Unsupported framework '{self.framework}' for base CyclicRecipe model persistence. " | ||
| "Use a framework-specific CyclicRecipe subclass, or pass a framework-specific " | ||
| "model wrapper with add_to_fed_job()." | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.