18 changes: 18 additions & 0 deletions examples/advanced/edge/README.md
@@ -275,3 +275,21 @@ You will see the following results:
<img src="./figs/cifar10_adv_acc.png" alt="Cifar10 Advanced Results" width="800" >

As shown, due to the large number of devices and the limited number of samples per device, the training process can be much slower than in the previous experiments, and the accuracy converges to a lower level.

### General Hierarchical FL System
Note that the above FL system is not limited to cross-edge FL; it can also be used for general hierarchical FL scenarios, where the hierarchy can be deployed to any type of server.

To illustrate this, we can use the same system to run a cross-silo FL example with 4 silos, each with 1 client running language model training, a much heavier workload than the CIFAR10 training.

We first prepare the data, following the same steps as our [HF_SFT example](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/llm_hf/README.md):
```commandline
mkdir /tmp/nvflare/dataset
git clone https://huggingface.co/datasets/databricks/databricks-dolly-15k /tmp/nvflare/dataset
mkdir /tmp/nvflare/dataset/dolly
python ./utils/preprocess_dolly.py --training_file /tmp/nvflare/dataset/databricks-dolly-15k.jsonl --output_dir /tmp/nvflare/dataset/dolly
```
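
For reference, the snippet below is a minimal sketch of the kind of split we assume `preprocess_dolly.py` performs: reading the Dolly JSONL file and writing `training.jsonl` / `validation.jsonl` into the output directory (the paths match the defaults used by `jobs/hf_sft_job.py`). The actual utility in the repository may differ in details such as the split ratio; see the linked HF_SFT example for the authoritative version.

```python
# Sketch of the assumed behavior of utils/preprocess_dolly.py (illustrative only).
import argparse
import json
import os


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_file", required=True, help="Path to databricks-dolly-15k.jsonl")
    parser.add_argument("--output_dir", required=True, help="Directory for training/validation jsonl files")
    args = parser.parse_args()

    # Read all records from the raw Dolly jsonl file
    with open(args.training_file) as f:
        records = [json.loads(line) for line in f if line.strip()]

    # Hold out a fraction of the records for validation (assumed 10% here)
    split = int(len(records) * 0.9)
    os.makedirs(args.output_dir, exist_ok=True)
    for name, subset in [("training.jsonl", records[:split]), ("validation.jsonl", records[split:])]:
        with open(os.path.join(args.output_dir, name), "w") as f:
            for rec in subset:
                f.write(json.dumps(rec) + "\n")


if __name__ == "__main__":
    main()
```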

Assuming the provisioning and startup steps have been completed following the same procedure as above, we can submit the job to run cross-silo FL with 4 silos, each with 1 client; each client will use a subset of 3000 samples for local training.
```commandline
python3 jobs/hf_sft_job.py --subset_size 3000 --no_delay
```
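
Alternatively, instead of submitting the job directly, the same script can stage it for later submission: passing `--export_job` writes the generated job definition to the `transfer` directory of the admin startup kit (see the script below):

```commandline
python3 jobs/hf_sft_job.py --subset_size 3000 --no_delay --export_job
```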
227 changes: 227 additions & 0 deletions examples/advanced/edge/jobs/hf_sft_job.py
@@ -0,0 +1,227 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from processors.hf_sft_task_processor import HFSFTTaskProcessor
from processors.models.hf_sft_model import CausalLMModel

from nvflare.edge.tools.edge_fed_buff_recipe import (
DeviceManagerConfig,
EdgeFedBuffRecipe,
ModelManagerConfig,
SimulationConfig,
)
from nvflare.recipe.prod_env import ProdEnv


def create_hf_sft_recipe(
model_name_or_path,
data_path_train,
data_path_valid,
output_path,
devices_per_leaf,
num_leaf_nodes,
global_rounds,
subset_size=None,
local_epochs=1,
local_batch_size=4,
local_lr=5e-4,
lr_scheduler="constant",
no_delay=False,
):
"""
Create an HuggingFace SFT edge recipe for federated learning.
Uses synchronous federated learning mode only.

Args:
model_name_or_path (str): HuggingFace model name or path
data_path_train (str): Path to training data
data_path_valid (str): Path to validation data
output_path (str): Output directory for model checkpoints
devices_per_leaf (int): Number of devices at each leaf node
num_leaf_nodes (int): Number of leaf nodes in the hierarchy
global_rounds (int): Number of global federated rounds
subset_size (int): Size of data subset for each device (None for full dataset)
local_epochs (int): Number of local training epochs per round
batch_size (int): Training batch size
gradient_accumulation_steps (int): Gradient accumulation steps
learning_rate (float): Learning rate for training
lr_scheduler (str): Learning rate scheduler type
no_delay (bool): If True, set communication delay and device speed to 0.0
"""
total_devices = devices_per_leaf * num_leaf_nodes

# Set communication delay and device speed based on no_delay flag
if no_delay:
communication_delay = {"mean": 0.0, "std": 0.0}
device_speed = {"mean": [0.0], "std": [0.0]}
suffix = "_no_delay"
else:
# Adjust delays for longer HF training times
communication_delay = {"mean": 10.0, "std": 2.0}
device_speed = {"mean": [300.0, 600.0, 1200.0], "std": [30.0, 60.0, 120.0]}
suffix = ""

# Create the HF SFT task processor
task_processor = HFSFTTaskProcessor(
model_name_or_path=model_name_or_path,
data_path_train=data_path_train,
data_path_valid=data_path_valid,
output_path=output_path,
communication_delay=communication_delay,
device_speed=device_speed,
subset_size=subset_size,
total_epochs=local_epochs * global_rounds,
local_epochs=local_epochs,
local_batch_size=local_batch_size,
local_lr=local_lr,
lr_scheduler=lr_scheduler,
)

# Configure model manager for synchronous FL
model_manager_config = ModelManagerConfig(
global_lr=1.0, # Use simple averaging for SFT
# Need all devices to train for one global model version
num_updates_for_model=total_devices,
max_model_version=global_rounds,
update_timeout=1800, # Longer timeout for HF training (30 minutes)
)

# Configure device manager for synchronous FL
device_manager_config = DeviceManagerConfig(
# Each leaf node has devices_per_leaf devices
device_selection_size=total_devices,
# Wait for all devices to finish training before starting
# dispatching the next global model version (synchronous)
min_hole_to_fill=total_devices,
# Always reuse the same devices for federated learning
device_reuse=True,
)

# Create the recipe
recipe = EdgeFedBuffRecipe(
job_name=f"hf_sft_job_sync{suffix}",
model=CausalLMModel(model_name_or_path=model_name_or_path),
model_manager_config=model_manager_config,
device_manager_config=device_manager_config,
evaluator_config=None, # No built-in evaluator for HF models
simulation_config=SimulationConfig(
task_processor=task_processor,
job_timeout=7200.0, # 2 hour timeout for HF training (increased)
num_workers=2, # Reduced workers to avoid resource conflicts
# Simulation config is for each leaf node
num_devices=devices_per_leaf,
),
custom_source_root=None,
)

return recipe


def main():
parser = argparse.ArgumentParser(description="Create HuggingFace SFT edge recipe for federated learning")
parser.add_argument(
"--model_name_or_path", type=str, default="facebook/opt-125m", help="HuggingFace model name or path"
)
parser.add_argument(
"--data_path_train", type=str, default="/tmp/nvflare/dataset/dolly/training.jsonl", help="Path to training data"
)
parser.add_argument(
"--data_path_valid",
type=str,
default="/tmp/nvflare/dataset/dolly/validation.jsonl",
help="Path to validation data",
)
parser.add_argument(
"--output_path",
type=str,
default="./workspace_federated/opt-125m-dolly-sft",
help="Output directory for model checkpoints",
)
parser.add_argument(
"--subset_size", type=int, default=None, help="Size of data subset for each device (None for full dataset)"
)
parser.add_argument("--devices_per_leaf", type=int, default=1, help="Number of devices on each leaf node")
parser.add_argument("--num_leaf_nodes", type=int, default=4, help="Number of leaf nodes in the hierarchy")
parser.add_argument("--global_rounds", type=int, default=3, help="Number of global federated rounds")
parser.add_argument("--local_epochs", type=int, default=1, help="Number of local training epochs per round")
parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
parser.add_argument("--gradient_accumulation_steps", type=int, default=10, help="Gradient accumulation steps")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate for training")
parser.add_argument("--lr_scheduler", type=str, default="constant", help="Learning rate scheduler type")
parser.add_argument("--workspace_dir", type=str, default="/tmp/nvflare/workspaces", help="Workspace directory")
parser.add_argument(
"--no_delay",
action="store_true",
help="If set, disable communication delay and device speed variations (set to 0.0)",
)
parser.add_argument(
"--export_job", action="store_true", help="If set, export the recipe to the admin's transfer directory"
)
parser.add_argument("--project_name", type=str, default="edge_example", help="Project name")

args = parser.parse_args()

prod_dir = os.path.join(args.workspace_dir, args.project_name, "prod_00")
admin_startup_kit_dir = os.path.join(prod_dir, "[email protected]")

try:
print("Creating HuggingFace SFT federated learning recipe...")

        # If subset_size is not specified, suggest a reasonable value
        # so that each device gets its own portion of the dataset
if args.subset_size is None:
total_devices = args.devices_per_leaf * args.num_leaf_nodes
print(
f"No subset size specified. Consider setting --subset_size to distribute data across {total_devices} devices"
)
print("Example: For a 15000-sample dataset with 4 devices, use --subset_size 3750")

recipe = create_hf_sft_recipe(
model_name_or_path=args.model_name_or_path,
data_path_train=args.data_path_train,
data_path_valid=args.data_path_valid,
output_path=args.output_path,
devices_per_leaf=args.devices_per_leaf,
num_leaf_nodes=args.num_leaf_nodes,
global_rounds=args.global_rounds,
subset_size=args.subset_size,
local_epochs=args.local_epochs,
local_batch_size=args.batch_size,
local_lr=args.learning_rate,
lr_scheduler=args.lr_scheduler,
no_delay=args.no_delay,
)

except Exception as e:
print(f"Error creating recipe: {e}")
return 1

if args.export_job:
output_dir = os.path.join(admin_startup_kit_dir, "transfer")
print(f"Exporting recipe to {output_dir}")
recipe.export(job_dir=output_dir)
else:
env = ProdEnv(startup_kit_location=admin_startup_kit_dir, username="[email protected]")
run = recipe.execute(env)
print()
print("Result can be found in:", run.get_result())
print("Job Status is:", run.get_status())
print()


if __name__ == "__main__":
exit(main())
64 changes: 37 additions & 27 deletions examples/advanced/edge/jobs/processors/cifar10_pt_task_processor.py
@@ -59,24 +59,21 @@ def __init__(
self.local_epochs = local_epochs
self.local_lr = local_lr
self.local_momentum = local_momentum
# Training
self.train_loader = None
self.net = None
self.optimizer = None
self.criterion = None

def setup(self, job: JobResponse) -> None:
pass

def shutdown(self) -> None:
pass

def _pytorch_training(self, global_model):
# Data loading code
device_id = self.device.device_id if self.device else "unknown"
self.logger.info(f"Device {device_id}: setup...")
transform = transforms.Compose([transforms.ToTensor()])
batch_size = self.local_batch_size

# CIFAR10 dataset
# Add file lock to prevent multiple simultaneous downloads
lock_file = os.path.join(self.data_root, "cifar10.lock")
with filelock.FileLock(lock_file):
train_set = datasets.CIFAR10(root=self.data_root, train=True, download=True, transform=transform)

# Generate seed according to device_id
# Randomly select a subset of the training set
# generate a random indices list
indices = list(range(len(train_set)))
@@ -85,35 +82,48 @@ def _pytorch_training(self, global_model):
indices = indices[: self.subset_size]
# create a new train_set from the selected indices
train_subset = Subset(train_set, indices)
# create a dataloader for the train_subset
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
# create dataloader for the train_subset
self.train_loader = torch.utils.data.DataLoader(
train_subset, batch_size=self.local_batch_size, shuffle=True, num_workers=2
)

# Network loading
net = Cifar10ConvNet()
net.load_state_dict(global_model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=self.local_lr, momentum=self.local_momentum)
net.to(DEVICE)
# Training-related components
self.net = Cifar10ConvNet()
self.net.to(DEVICE)
self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(self.net.parameters(), lr=self.local_lr, momentum=self.local_momentum)

def shutdown(self) -> None:
if self.train_loader:
del self.train_loader
if self.net:
del self.net
if self.optimizer:
del self.optimizer
if self.criterion:
del self.criterion

def _pytorch_training(self, global_model):
# Load global model params
self.net.load_state_dict(global_model)

# Training loop
        # Run the configured number of local epochs
for epoch in range(self.local_epochs):
for i, data in enumerate(train_loader, 0):
for i, data in enumerate(self.train_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

# zero the parameter gradients
optimizer.zero_grad()

self.optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
outputs = self.net(inputs)
loss = self.criterion(outputs, labels)
loss.backward()
optimizer.step()
self.optimizer.step()

# Calculate the model param diff
diff_dict = {}
for key, param in net.state_dict().items():
for key, param in self.net.state_dict().items():
numpy_param = param.cpu().numpy() - global_model[key].numpy()
# Convert numpy array to list for serialization
diff_dict[key] = numpy_param.tolist()