
Commit 04eed03

Merge pull request #16 from VectorInstitute/apfl
Apfl
2 parents 6e30be2 + 48ab246 commit 04eed03

18 files changed: +831 -66 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@ repos:
   - id: black

 - repo: https://github.com/PyCQA/isort
-  rev: 5.5.2
+  rev: 5.11.5
   hooks:
   - id: isort
```

examples/apfl_example/README.md

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@

# APFL Federated Learning Example

This is an example of [Adaptive Personalized Federated Learning](https://arxiv.org/pdf/2003.13461.pdf) (APFL). APFL is a popular method for achieving personalization in the federated learning setting, which is important in real-world applications where client distributions are heterogeneous. APFL extends FedAvg by having each client train a local model distinct from the global model that is jointly learned across clients. A personalized model is generated as a convex combination of the local and global models with a per-client mixing parameter that is learned. The local model is updated to minimize the loss of the personalized model on the local data.
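
As a minimal sketch (not the repository's `APFLModule` implementation), the personalized prediction is a convex combination of the two models' outputs with a learnable per-client mixing parameter `alpha`:

```python
import torch
import torch.nn as nn


class ApflSketch(nn.Module):
    # Illustrative only: combines a local and a global model via a learned
    # mixing parameter alpha, as described in the APFL paper.
    def __init__(self, local_model: nn.Module, global_model: nn.Module) -> None:
        super().__init__()
        self.local_model = local_model
        self.global_model = global_model
        # alpha is learned per client; clamped to [0, 1] at use time to keep
        # the combination convex.
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = self.alpha.clamp(0.0, 1.0)
        return alpha * self.local_model(x) + (1.0 - alpha) * self.global_model(x)
```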

In this demo, APFL is applied to an augmented version of the MNIST dataset that is non-IID. The FL server expects three clients to be spun up (i.e. it will wait until three clients report in before starting training). Each client has a modified version of the MNIST dataset. This modification essentially subsamples a certain number from the original training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties of the clients' training/validation data. In theory, the models should be able to perform well on their local data while learning from other clients' data that has different statistical properties. The proportion of labels at each client is determined by a Dirichlet distribution across the classes. The lower the beta parameter is for each class, the higher the degree of label heterogeneity.
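
For intuition (the internals of `DirichletLabelBasedSampler` may differ), label proportions drawn from a Dirichlet distribution become more skewed as the concentration parameter beta shrinks:

```python
import numpy as np

rng = np.random.default_rng(0)

# Label proportions for the 10 MNIST classes at one hypothetical client.
# Small beta -> spiky, heterogeneous proportions; large beta -> near uniform.
for beta in (0.1, 1.0, 100.0):
    proportions = rng.dirichlet(beta * np.ones(10))
    print(f"beta={beta}: {np.round(proportions, 2)}")
```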

The server has some custom metrics aggregation and uses Federated Averaging as its server-side optimization. The implementation uses a special type of weight exchange based on named-layer identification.

# Running the Example

In order to run the example, first ensure you have the virtual env of your choice activated and run
```
pip install --upgrade pip
pip install -r requirements.txt
```
to install all of the dependencies for this project.

## Starting Server

The next step is to start the server by running
```
python -m examples.apfl_example.server --config_path /path/to/config.yaml
```
from the FL4Health directory. The following arguments must be present in the specified config file:
* `n_clients`: number of clients the server waits for in order to run the FL training
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: the number of rounds to run FL
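
For reference, the `config.yaml` shipped with this example (reproduced in full later in this commit) provides these values:
```
n_server_rounds: 25
n_clients: 3
local_epochs: 1
batch_size: 128
```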

## Starting Clients

Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three clients. This is done by simply running (remembering to activate your environment)
```
python -m examples.apfl_example.client --dataset_path /path/to/data
```
**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be automatically downloaded to the path specified and used in the run.

After all three clients have been started, federated learning should commence.

examples/apfl_example/client.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@

```python
import argparse
from pathlib import Path
from typing import List

import flwr as fl
import torch
from flwr.common.typing import Config

from examples.models.cnn_model import MnistNet
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import APFLModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.sampler import DirichletLabelBasedSampler


class MnistApflClient(ApflClient):
    def __init__(
        self,
        data_path: Path,
        metrics: List[Metric],
        device: torch.device,
    ) -> None:
        super().__init__(data_path=data_path, metrics=metrics, device=device)

    def setup_client(self, config: Config) -> None:
        batch_size = self.narrow_config_type(config, "batch_size", int)
        self.model: APFLModule = APFLModule(MnistNet()).to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss()
        # APFL trains the local and global models with separate optimizers
        self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
        self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
        # Induce label heterogeneity by subsampling this client's MNIST data
        sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)

        self.train_loader, self.val_loader, self.num_examples = load_mnist_data(self.data_path, batch_size, sampler)
        # Only the layers named by the model are exchanged with the server
        self.parameter_exchanger = FixedLayerExchanger(self.model.layers_to_exchange())

        super().setup_client(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Client Main")
    parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")

    args = parser.parse_args()

    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_path = Path(args.dataset_path)

    client = MnistApflClient(data_path, [Accuracy()], DEVICE)
    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
```

examples/apfl_example/config.yaml

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@

```yaml
# Parameters that describe server
n_server_rounds: 25 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for each client
batch_size: 128 # The batch size for client training
```
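
The server reads this file via `fl4health.utils.config.load_config`; as a minimal sketch assuming plain PyYAML parsing (the real helper may add validation or defaults), such a loader looks like:

```python
from typing import Any, Dict

import yaml


def load_config_sketch(config_path: str) -> Dict[str, Any]:
    # Parse the YAML config into a plain dictionary of server/client parameters
    with open(config_path, "r") as handle:
        return yaml.safe_load(handle)
```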

examples/apfl_example/server.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@

```python
import argparse
from functools import partial
from typing import Any, Dict, List, Tuple

import flwr as fl
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config, Metrics, Parameters
from flwr.server.strategy import FedAvg

from examples.models.cnn_model import MnistNet
from examples.simple_metric_aggregation import metric_aggregation, normalize_metrics
from fl4health.model_bases.apfl_base import APFLModule
from fl4health.utils.config import load_config


def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # This function is run by the server to aggregate metrics returned by each client's fit function
    # NOTE: The first value of the tuple is the number of examples for FedAvg
    total_examples, aggregated_metrics = metric_aggregation(all_client_metrics)
    return normalize_metrics(total_examples, aggregated_metrics)


def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # This function is run by the server to aggregate metrics returned by each client's evaluate function
    # NOTE: The first value of the tuple is the number of examples for FedAvg
    total_examples, aggregated_metrics = metric_aggregation(all_client_metrics)
    return normalize_metrics(total_examples, aggregated_metrics)


def get_initial_model_parameters() -> Parameters:
    # Initializing the model parameters on the server side.
    # Currently uses the PyTorch default initialization for the model parameters.
    initial_model = APFLModule(MnistNet())
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()])


def fit_config(local_epochs: int, batch_size: int, n_server_rounds: int, current_round: int) -> Config:
    return {
        "local_epochs": local_epochs,
        "batch_size": batch_size,
        "n_server_rounds": n_server_rounds,
    }


def main(config: Dict[str, Any]) -> None:
    # This function will be used to produce a config that is sent to each client to initialize their own environment
    fit_config_fn = partial(
        fit_config,
        config["local_epochs"],
        config["batch_size"],
        config["n_server_rounds"],
    )

    # Server performs simple FedAveraging as its server-side optimization strategy
    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_initial_model_parameters(),
    )

    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
        strategy=strategy,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Server Main")
    parser.add_argument(
        "--config_path",
        action="store",
        type=str,
        help="Path to configuration file.",
        default="config.yaml",
    )
    args = parser.parse_args()

    config = load_config(args.config_path)

    main(config)
```
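
The aggregation helpers come from `examples.simple_metric_aggregation`; as a hedged sketch (the repository's implementation may differ), example-weighted metric aggregation of the kind used above can be written as:

```python
from typing import Dict, List, Tuple

# Metrics maps a metric name to a numeric value, mirroring Flower's Metrics type
Metrics = Dict[str, float]


def metric_aggregation_sketch(all_client_metrics: List[Tuple[int, Metrics]]) -> Tuple[int, Metrics]:
    # Sum each metric weighted by the reporting client's example count
    total_examples = 0
    aggregated: Metrics = {}
    for num_examples, metrics in all_client_metrics:
        total_examples += num_examples
        for name, value in metrics.items():
            aggregated[name] = aggregated.get(name, 0.0) + num_examples * value
    return total_examples, aggregated


def normalize_metrics_sketch(total_examples: int, aggregated: Metrics) -> Metrics:
    # Divide the weighted sums by the total example count to recover averages
    return {name: value / total_examples for name, value in aggregated.items()}
```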

examples/fenda_example/README.md

Lines changed: 2 additions & 3 deletions
```diff
@@ -1,14 +1,13 @@
 # FENDA Federated Learning Example
 This example provides an example of training a FENDA type model on a non-IID subset of the MNIST data. The FL server
 expects two clients to be spun up (i.e. it will wait until two clients report in before starting training). Each client
-has a modified version of the MNIST dataset. This modification essentially subsamples certain number from the original
+has a modified version of the MNIST dataset. This modification essentially subsamples a certain number from the original
 training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties
 of the clients training/validation data. In theory, the models should be able to perform well on their local data
 while learning from other clients data that has different statistical properties. The subsampling is specified by
 sending a list of integers between 0-9 to the clients when they are run with the argument `--minority_numbers`.

-The server has some custom metrics aggregation and using Federated Averaging as its server-side optimization. The
-implementation uses a special type of weight exchange based on named-layer identification.
+The server has some custom metrics aggregation and uses Federated Averaging as its server-side optimization. The implementation uses a special type of weight exchange based on named-layer identification.

 # Running the Example
 In order to run the example, first ensure you have the virtual env of your choice activated and run
```

examples/fenda_example/client.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -10,11 +10,12 @@
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.utils.data import DataLoader

-from examples.fenda_example.client_data import load_data
 from examples.models.fenda_cnn import FendaClassifier, GlobalCnn, LocalCnn
 from fl4health.clients.numpy_fl_client import NumpyFlClient
 from fl4health.model_bases.fenda_base import FendaJoinMode, FendaModel
 from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
+from fl4health.utils.load_data import load_mnist_data
+from fl4health.utils.sampler import MinorityLabelBasedSampler


 def train(
@@ -120,9 +121,9 @@ def setup_client(self, config: Config) -> None:
         batch_size = self.narrow_config_type(config, "batch_size", int)
         downsample_percentage = self.narrow_config_type(config, "downsampling_ratio", float)

-        train_loader, validation_loader, num_examples = load_data(
-            self.data_path, batch_size, downsample_percentage, self.minority_numbers
-        )
+        sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers)
+
+        train_loader, validation_loader, num_examples = load_mnist_data(self.data_path, batch_size, sampler)

         self.train_loader = train_loader
         self.validation_loader = validation_loader
```

examples/fenda_example/client_data.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

examples/models/cnn_model.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -21,3 +21,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.relu(self.fc2(x))
         x = self.fc3(x)
         return x
+
+
+class MnistNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 8, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(8, 16, 5)
+        self.fc1 = nn.Linear(16 * 4 * 4, 120)
+        self.fc2 = nn.Linear(120, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 4 * 4)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return x
```
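
The `16 * 4 * 4` flattening follows from the shapes: a 28x28 MNIST image becomes 24x24 after `conv1` (kernel 5), 12x12 after pooling, 8x8 after `conv2`, and 4x4 after the second pooling, with 16 channels. A quick sanity check:

```python
import torch

from examples.models.cnn_model import MnistNet

# A batch of four single-channel 28x28 images, as in MNIST
net = MnistNet()
logits = net(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10]) -- one output per digit class
```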
