Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions examples/advanced/edge/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Remember to enable allow self signed certs from the device SDK side.

After the startup of NVFlare system, we can start the job with:
```
python jobs/et_job.py --total_num_of_devices 4 --num_of_simulated_devices_on_each_leaf 1
python jobs/et_cifar10/job.py --total_num_of_devices 4 --num_of_simulated_devices_on_each_leaf 1
```

This is going to run 1 simulated device on each leaf client, so a total of 4 devices.
Expand All @@ -115,7 +115,7 @@ This is going to run 1 simulated device on each leaf client, so a total of 4 dev
We could submit the job to NVFlare system with:

```
python jobs/et_job.py --total_num_of_devices 5 --num_of_simulated_devices_on_each_leaf 1
python jobs/et_cifar10/job.py --total_num_of_devices 5 --num_of_simulated_devices_on_each_leaf 1
```

Note that we are using 1 simulated_devices_on_each_leaf and we have 4 leaf clients, so there are a total of 4 simulated devices; we set total_num_of_devices to 5 because the additional 1 is the real device.
Expand Down Expand Up @@ -180,10 +180,10 @@ This will generate two job configurations and submit them to run on the FL syste
- async assumes server updates the global model and immediately dispatch it after receiving ONE device's update.

```commandline
python3 jobs/pt_job.py --fl_mode sync
python3 jobs/pt_job.py --fl_mode async
python3 jobs/pt_job.py --fl_mode sync --no_delay
python3 jobs/pt_job.py --fl_mode async --no_delay
python3 jobs/pt_cifar10_sync/job.py
python3 jobs/pt_cifar10_async/job.py
python3 jobs/pt_cifar10_sync/job.py --no_delay
python3 jobs/pt_cifar10_async/job.py --no_delay
```

You will then see the simulated devices start receiving the model from the server and complete local trainings.
Expand Down Expand Up @@ -264,7 +264,7 @@ These settings will simulate a realistic cross-device federated learning scenari
In admin console, submit the job:

```commandline
python3 jobs/pt_job_adv.py
python3 jobs/pt_cifar10_adv/job.py
```

Upon finishing, we can visualize the results in TensorBoard as before:
Expand All @@ -275,3 +275,21 @@ You will see the following results:
<img src="./figs/cifar10_adv_acc.png" alt="Cifar10 Advanced Results" width="800" >

As shown, due to the large number of devices and the limited number of samples for each device, the training process can be much slower than the previous experiments, and the accuracy converges to a lower level.

### General Hierarchical FL System
Note that the above FL system is not limited to cross-edge FL; it can also be used for general hierarchical FL scenarios, where the hierarchy can be deployed to any type of server.

To illustrate this, we can use the same system to run a cross-silo FL example with 4 silos, each with 1 client running language model training, which is much heavier than the CIFAR10 training:

We first prepare the data, same as our [HF_SFT example](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/llm_hf/README.md):
```commandline
mkdir /tmp/nvflare/dataset
git clone https://huggingface.co/datasets/databricks/databricks-dolly-15k /tmp/nvflare/dataset
mkdir /tmp/nvflare/dataset/dolly
python ./utils/preprocess_dolly.py --training_file /tmp/nvflare/dataset/databricks-dolly-15k.jsonl --output_dir /tmp/nvflare/dataset/dolly
```

Assuming we have the same provisioning and setting steps following the same procedure as above, then we can submit the job to run cross-silo FL with 4 silos, each with 1 client, and each client will use a subset of 3000 samples for local training.
```commandline
python3 jobs/pt_hf_sft/job.py --subset_size 3000 --no_delay
```
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import argparse
import os

from client import Cifar10ETTaskProcessor
from model import TrainingNet

from nvflare.edge.tools.et_fed_buff_recipe import (
DeviceManagerConfig,
ETFedBuffRecipe,
Expand All @@ -26,7 +29,6 @@

parser = argparse.ArgumentParser()
parser.add_argument("--export_job", action="store_true")
parser.add_argument("--dataset", type=str, default="cifar10")
parser.add_argument("--workspace_dir", type=str, default="/tmp/nvflare/workspaces")
parser.add_argument("--project_name", type=str, default="edge_example")
parser.add_argument("--total_num_of_devices", type=int, default=4)
Expand All @@ -38,47 +40,25 @@
total_num_of_devices = args.total_num_of_devices
num_of_simulated_devices_on_each_leaf = args.num_of_simulated_devices_on_each_leaf

if args.dataset == "cifar10":
from processors.cifar10_et_task_processor import Cifar10ETTaskProcessor
from processors.models.cifar10_model import TrainingNet

dataset_root = "/tmp/nvflare/cifar10"
job_name = "cifar10_et"
device_model = TrainingNet()
batch_size = 4
input_shape = (batch_size, 3, 32, 32)
output_shape = (batch_size,)
task_processor = Cifar10ETTaskProcessor(
data_path=dataset_root,
training_config={
"batch_size": batch_size,
"shuffle": True,
"num_workers": 0,
},
subset_size=100,
)
evaluator_config = EvaluatorConfig(
torchvision_dataset={"name": "CIFAR10", "path": dataset_root},
eval_frequency=1,
)
elif args.dataset == "xor":
from processors.models.xor_model import TrainingNet
from processors.xor_et_task_processor import XorETTaskProcessor

job_name = "xor_et"
device_model = TrainingNet()
batch_size = 1
input_shape = (batch_size, 2)
output_shape = (batch_size,)
task_processor = XorETTaskProcessor(
training_config={
"batch_size": batch_size,
"shuffle": True,
"num_workers": 0,
},
)
evaluator_config = None

dataset_root = "/tmp/nvflare/cifar10"
job_name = "cifar10_et"
device_model = TrainingNet()
batch_size = 4
input_shape = (batch_size, 3, 32, 32)
output_shape = (batch_size,)
task_processor = Cifar10ETTaskProcessor(
data_path=dataset_root,
training_config={
"batch_size": batch_size,
"shuffle": True,
"num_workers": 0,
},
subset_size=100,
)
evaluator_config = EvaluatorConfig(
torchvision_dataset={"name": "CIFAR10", "path": dataset_root},
eval_frequency=1,
)

recipe = ETFedBuffRecipe(
job_name=job_name,
Expand Down
93 changes: 93 additions & 0 deletions examples/advanced/edge/jobs/et_xor/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build and run (or export) an ExecuTorch FedBuff job for the XOR toy model.

Constructs an ``ETFedBuffRecipe`` around ``TrainingNet``/``XorETTaskProcessor``
and either exports the generated job configuration into the admin startup
kit's ``transfer`` directory (``--export_job``) or submits it to a running,
provisioned NVFlare system through ``ProdEnv``.
"""

import argparse
import os

from client import XorETTaskProcessor
from model import TrainingNet

from nvflare.edge.tools.et_fed_buff_recipe import (
    DeviceManagerConfig,
    ETFedBuffRecipe,
    ModelManagerConfig,
    SimulationConfig,
)
from nvflare.recipe.prod_env import ProdEnv

parser = argparse.ArgumentParser()
parser.add_argument("--export_job", action="store_true")
parser.add_argument("--workspace_dir", type=str, default="/tmp/nvflare/workspaces")
parser.add_argument("--project_name", type=str, default="edge_example")
parser.add_argument("--total_num_of_devices", type=int, default=4)
parser.add_argument("--num_of_simulated_devices_on_each_leaf", type=int, default=1)
args = parser.parse_args()

# Locate the admin startup kit created by provisioning (prod_00 workspace layout).
prod_dir = os.path.join(args.workspace_dir, args.project_name, "prod_00")
admin_startup_kit_dir = os.path.join(prod_dir, "[email protected]")
total_num_of_devices = args.total_num_of_devices
num_of_simulated_devices_on_each_leaf = args.num_of_simulated_devices_on_each_leaf

# XOR task setup: each sample has 2 inputs and is trained with batch size 1.
job_name = "xor_et"
device_model = TrainingNet()
batch_size = 1
input_shape = (batch_size, 2)
output_shape = (batch_size,)
task_processor = XorETTaskProcessor(
    training_config={
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    },
)
# No server-side evaluator for this toy example (unlike the CIFAR10 jobs).
evaluator_config = None

recipe = ETFedBuffRecipe(
    job_name=job_name,
    device_model=device_model,
    input_shape=input_shape,
    output_shape=output_shape,
    model_manager_config=ModelManagerConfig(
        # max_num_active_model_versions=1,
        # NOTE(review): presumably max_model_version caps the number of global
        # model versions and num_updates_for_model is how many device updates
        # trigger an aggregation — confirm against ETFedBuffRecipe docs.
        max_model_version=3,
        update_timeout=1000,
        num_updates_for_model=total_num_of_devices,
        # max_model_history=1,
    ),
    device_manager_config=DeviceManagerConfig(
        device_selection_size=total_num_of_devices,
        min_hole_to_fill=total_num_of_devices,
    ),
    evaluator_config=evaluator_config,
    # With 0 simulated devices per leaf, no simulation config is attached and
    # only real devices participate.
    simulation_config=(
        SimulationConfig(
            task_processor=task_processor,
            num_devices=num_of_simulated_devices_on_each_leaf,
        )
        if num_of_simulated_devices_on_each_leaf > 0
        else None
    ),
    device_training_params={"epoch": 3, "lr": 0.0001, "batch_size": batch_size},
)
if args.export_job:
    # Write the job definition into the admin transfer folder for manual submission.
    output_dir = os.path.join(admin_startup_kit_dir, "transfer")
    print(f"Exporting recipe to {output_dir}")
    recipe.export(job_dir=output_dir)
else:
    # Submit the job directly to the running FL system as the admin user and
    # block until a result/status is available.
    env = ProdEnv(startup_kit_location=admin_startup_kit_dir, username="[email protected]")
    run = recipe.execute(env)
    print()
    print("Result can be found in :", run.get_result())
    print("Job Status is:", run.get_status())
    print()
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import filelock
import torch
from model import Cifar10ConvNet
from torch.utils.data import Subset
from torchvision import datasets, transforms

Expand All @@ -26,8 +27,6 @@
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.task_response import TaskResponse

from .models.cifar10_model import Cifar10ConvNet

log = logging.getLogger(__name__)
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -59,24 +58,21 @@ def __init__(
self.local_epochs = local_epochs
self.local_lr = local_lr
self.local_momentum = local_momentum
# Training
self.train_loader = None
self.net = None
self.optimizer = None
self.criterion = None

def setup(self, job: JobResponse) -> None:
pass

def shutdown(self) -> None:
pass

def _pytorch_training(self, global_model):
# Data loading code
device_id = self.device.device_id if self.device else "unknown"
self.logger.info(f"Device {device_id}: setup...")
transform = transforms.Compose([transforms.ToTensor()])
batch_size = self.local_batch_size

# CIFAR10 dataset
# Add file lock to prevent multiple simultaneous downloads
lock_file = os.path.join(self.data_root, "cifar10.lock")
with filelock.FileLock(lock_file):
train_set = datasets.CIFAR10(root=self.data_root, train=True, download=True, transform=transform)

# Generate seed according to device_id
# Randomly select a subset of the training set
# generate a random indices list
indices = list(range(len(train_set)))
Expand All @@ -85,35 +81,48 @@ def _pytorch_training(self, global_model):
indices = indices[: self.subset_size]
# create a new train_set from the selected indices
train_subset = Subset(train_set, indices)
# create a dataloader for the train_subset
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
# create dataloader for the train_subset
self.train_loader = torch.utils.data.DataLoader(
train_subset, batch_size=self.local_batch_size, shuffle=True, num_workers=2
)

# Network loading
net = Cifar10ConvNet()
net.load_state_dict(global_model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=self.local_lr, momentum=self.local_momentum)
net.to(DEVICE)
# Training-related components
self.net = Cifar10ConvNet()
self.net.to(DEVICE)
self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(self.net.parameters(), lr=self.local_lr, momentum=self.local_momentum)

def shutdown(self) -> None:
if self.train_loader:
del self.train_loader
if self.net:
del self.net
if self.optimizer:
del self.optimizer
if self.criterion:
del self.criterion

def _pytorch_training(self, global_model):
# Load global model params
self.net.load_state_dict(global_model)

# Training loop
# Let's do 4 local epochs
for epoch in range(self.local_epochs):
for i, data in enumerate(train_loader, 0):
for i, data in enumerate(self.train_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

# zero the parameter gradients
optimizer.zero_grad()

self.optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
outputs = self.net(inputs)
loss = self.criterion(outputs, labels)
loss.backward()
optimizer.step()
self.optimizer.step()

# Calculate the model param diff
diff_dict = {}
for key, param in net.state_dict().items():
for key, param in self.net.state_dict().items():
numpy_param = param.cpu().numpy() - global_model[key].numpy()
# Convert numpy array to list for serialization
diff_dict[key] = numpy_param.tolist()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from processors.cifar10_pt_task_processor import Cifar10PTTaskProcessor
from processors.models.cifar10_model import Cifar10ConvNet
from client import Cifar10PTTaskProcessor
from model import Cifar10ConvNet

from nvflare.edge.tools.edge_fed_buff_recipe import (
DeviceManagerConfig,
Expand Down
Loading
Loading