Skip to content

Commit f12305f

Browse files
authored
Merge pull request #12 from VectorInstitute/dbe/fenda-implementation-other-fixes
FENDA model implementation and other fixes
2 parents 2f60e5b + 0982d96 commit f12305f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+911
-386
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,6 @@ settings.json
139139
.DS_Store
140140

141141
#datasets
142-
**/cifar_data/**
143-
**/news_classification/partitioned_datasets/**
142+
**/datasets/cifar_data/**
143+
**/datasets/news_classification/partitioned_datasets/**
144+
**/datasets/mnist_data/**

.pre-commit-config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@ repos:
3131
- id: isort
3232

3333
- repo: https://github.com/pre-commit/mirrors-mypy
34-
rev: v0.942
34+
rev: v0.991
3535
hooks:
3636
- id: mypy
37+
name: mypy
38+
entry: ./run_mypy.sh
39+
language: system
3740

3841
- repo: https://github.com/nbQA-dev/nbQA
3942
rev: 1.3.1

examples/basic_example/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,8 @@ clients. This is done by simply running (remembering to activate your environmen
3232
```
3333
python -m examples.basic_example.client --dataset_path /path/to/data
3434
```
35+
**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If
36+
the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be
37+
automatically downloaded to the path specified and used in the run.
38+
3539
After both clients have been started, federated learning should commence.

examples/basic_example/client.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import argparse
2-
from collections import OrderedDict
32
from logging import INFO
43
from pathlib import Path
54
from typing import Dict, Tuple
@@ -14,6 +13,8 @@
1413
from torchvision.datasets import CIFAR10
1514

1615
from examples.models.cnn_model import Net
16+
from fl4health.clients.numpy_fl_client import NumpyFlClient
17+
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
1718

1819

1920
def load_data(data_dir: Path, batch_size: int) -> Tuple[DataLoader, DataLoader, Dict[str, int]]:
@@ -96,35 +97,19 @@ def validate(
9697
return loss / n_batches, accuracy
9798

9899

99-
class CifarClient(fl.client.NumPyClient):
100-
def __init__(
101-
self,
102-
data_path: Path,
103-
device: torch.device,
104-
) -> None:
105-
self.data_path = data_path
106-
self.device = device
107-
self.initialized = False
108-
109-
def get_parameters(self, config: Config) -> NDArrays:
110-
# Determines which weights are sent back to the server for aggregation.
111-
# Currently sending all of them ordered by state_dict keys
112-
# NOTE: Order matters, because it is relied upon by set_parameters below
113-
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
114-
115-
def set_parameters(self, parameters: NDArrays, config: Config) -> None:
116-
# Sets the local model parameters transferred from the server. The state_dict is
117-
# reconstituted because parameters is simply a list of bytes
118-
params_dict = zip(self.model.state_dict().keys(), parameters)
119-
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
120-
self.model.load_state_dict(state_dict, strict=True)
100+
class CifarClient(NumpyFlClient):
101+
def __init__(self, data_path: Path, device: torch.device) -> None:
102+
super().__init__(data_path, device)
103+
self.model = Net().to(self.device)
104+
self.parameter_exchanger = FullParameterExchanger()
121105

122106
def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
123107
if not self.initialized:
124108
self.setup_client(config)
125-
126109
self.set_parameters(parameters, config)
127-
accuracy = train(self.model, self.train_loader, epochs=config["local_epochs"], device=self.device)
110+
111+
local_epochs = self.narrow_config_type(config, "local_epochs", int)
112+
accuracy = train(self.model, self.train_loader, epochs=local_epochs, device=self.device)
128113
# FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics
129114
# calculation results.
130115
return (
@@ -145,16 +130,15 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di
145130
)
146131

147132
def setup_client(self, config: Config) -> None:
133+
super().setup_client(config)
134+
batch_size = self.narrow_config_type(config, "batch_size", int)
148135

149-
train_loader, validation_loader, num_examples = load_data(self.data_path, config["batch_size"])
136+
train_loader, validation_loader, num_examples = load_data(self.data_path, batch_size)
150137

151138
self.train_loader = train_loader
152139
self.validation_loader = validation_loader
153140
self.num_examples = num_examples
154141

155-
model = Net().to(self.device)
156-
self.model = model
157-
158142

159143
if __name__ == "__main__":
160144
parser = argparse.ArgumentParser(description="FL Client Main")

examples/basic_example/server.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,10 @@
88
from flwr.server.strategy import FedAvg
99

1010
from examples.models.cnn_model import Net
11+
from examples.simple_metric_aggregation import metric_aggregation, normalize_metrics
1112
from fl4health.utils.config import load_config
1213

1314

14-
def metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Tuple[int, Metrics]:
15-
aggregated_metrics: Metrics = {}
16-
total_examples = 0
17-
# Run through all of the metrics
18-
for num_examples_on_client, client_metrics in all_client_metrics:
19-
total_examples += num_examples_on_client
20-
for metric_name, metric_value in client_metrics.items():
21-
# Here we assume each metric is normalized by the number of examples on the client. So we scale up to
22-
# get the "raw" value
23-
if metric_name in aggregated_metrics:
24-
aggregated_metrics[metric_name] += num_examples_on_client * metric_value
25-
else:
26-
aggregated_metrics[metric_name] = num_examples_on_client * metric_value
27-
return total_examples, aggregated_metrics
28-
29-
30-
def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metrics:
31-
# Normalize all metric values by the total count of examples seen.
32-
return {metric_name: metric_value / total_examples for metric_name, metric_value in aggregated_metrics.items()}
33-
34-
3515
def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics:
3616
# This function is run by the server to aggregate metrics returned by each client's fit function
3717
# NOTE: The first value of the tuple is number of examples for FedAvg
@@ -56,9 +36,9 @@ def get_initial_model_parameters() -> Parameters:
5636
def fit_config(
5737
local_epochs: int,
5838
batch_size: int,
59-
n_server_rounds: int,
39+
current_round: int,
6040
) -> Config:
61-
return {"local_epochs": local_epochs, "batch_size": batch_size, "n_server_rounds": n_server_rounds}
41+
return {"local_epochs": local_epochs, "batch_size": batch_size, "current_round": current_round}
6242

6343

6444
def main(config: Dict[str, Any]) -> None:

examples/docker_basic_example/fl_client/client.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import argparse
2-
from collections import OrderedDict
32
from logging import INFO
43
from pathlib import Path
54
from typing import Dict, Tuple
@@ -14,6 +13,8 @@
1413
from torchvision.datasets import CIFAR10
1514

1615
from examples.docker_basic_example.model import Net
16+
from fl4health.clients.numpy_fl_client import NumpyFlClient
17+
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
1718

1819

1920
def load_data(data_dir: Path, batch_size: int) -> Tuple[DataLoader, DataLoader, Dict[str, int]]:
@@ -96,36 +97,19 @@ def validate(
9697
return loss / n_batches, accuracy
9798

9899

99-
class CifarClient(fl.client.NumPyClient):
100-
def __init__(
101-
self,
102-
data_path: Path,
103-
device: torch.device,
104-
) -> None:
105-
106-
self.data_path = data_path
107-
self.device = device
108-
self.initialized = False
109-
110-
def get_parameters(self, config: Config) -> NDArrays:
111-
# Determines which weights are sent back to the server for aggregation.
112-
# Currently sending all of them ordered by state_dict keys
113-
# NOTE: Order matters, because it is relied upon by set_parameters below
114-
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
115-
116-
def set_parameters(self, parameters: NDArrays, config: Config) -> None:
117-
# Sets the local model parameters transferred from the server. The state_dict is
118-
# reconstituted because parameters is simply a list of bytes
119-
params_dict = zip(self.model.state_dict().keys(), parameters)
120-
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
121-
self.model.load_state_dict(state_dict, strict=True)
100+
class CifarClient(NumpyFlClient):
101+
def __init__(self, data_path: Path, device: torch.device) -> None:
102+
super().__init__(data_path, device)
103+
self.model = Net()
104+
self.parameter_exchanger = FullParameterExchanger()
122105

123106
def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
124107
if not self.initialized:
125108
self.setup_client(config)
126-
127109
self.set_parameters(parameters, config)
128-
accuracy = train(self.model, self.train_loader, epochs=config["local_epochs"], device=self.device)
110+
111+
local_epochs = self.narrow_config_type(config, "local_epochs", int)
112+
accuracy = train(self.model, self.train_loader, epochs=local_epochs, device=self.device)
129113
# FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics
130114
# calculation results.
131115
return (
@@ -146,16 +130,14 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di
146130
)
147131

148132
def setup_client(self, config: Config) -> None:
149-
150-
train_loader, validation_loader, num_examples = load_data(self.data_path, config["batch_size"])
133+
super().setup_client(config)
134+
batch_size = self.narrow_config_type(config, "batch_size", int)
135+
train_loader, validation_loader, num_examples = load_data(self.data_path, batch_size)
151136

152137
self.train_loader = train_loader
153138
self.validation_loader = validation_loader
154139
self.num_examples = num_examples
155140

156-
model = Net()
157-
self.model = model
158-
159141

160142
if __name__ == "__main__":
161143
parser = argparse.ArgumentParser(description="FL Client Main")

examples/docker_basic_example/fl_server/server.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,10 @@
88
from flwr.server.strategy import FedAvg
99

1010
from examples.docker_basic_example.model import Net
11+
from examples.simple_metric_aggregation import metric_aggregation, normalize_metrics
1112
from fl4health.utils.config import load_config
1213

1314

14-
def metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Tuple[int, Metrics]:
15-
aggregated_metrics: Metrics = {}
16-
total_examples = 0
17-
# Run through all of the metrics
18-
for num_examples_on_client, client_metrics in all_client_metrics:
19-
total_examples += num_examples_on_client
20-
for metric_name, metric_value in client_metrics.items():
21-
# Here we assume each metric is normalized by the number of examples on the client. So we scale up to
22-
# get the "raw" value
23-
if metric_name in aggregated_metrics:
24-
aggregated_metrics[metric_name] += num_examples_on_client * metric_value
25-
else:
26-
aggregated_metrics[metric_name] = num_examples_on_client * metric_value
27-
return total_examples, aggregated_metrics
28-
29-
30-
def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metrics:
31-
# Normalize all metric values by the total count of examples seen.
32-
return {metric_name: metric_value / total_examples for metric_name, metric_value in aggregated_metrics.items()}
33-
34-
3515
def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics:
3616
# This function is run by the server to aggregate metrics returned by each client's fit function
3717
# NOTE: The first value of the tuple is number of examples for FedAvg
@@ -49,9 +29,9 @@ def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]
4929
def fit_config(
5030
local_epochs: int,
5131
batch_size: int,
52-
n_server_rounds: int,
32+
current_round: int,
5333
) -> Config:
54-
return {"local_epochs": local_epochs, "batch_size": batch_size, "n_server_rounds": n_server_rounds}
34+
return {"local_epochs": local_epochs, "batch_size": batch_size, "current_round": current_round}
5535

5636

5737
def get_initial_model_parameters() -> Parameters:

examples/docker_basic_example/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class Net(nn.Module):
77
def __init__(self) -> None:
8-
super(Net, self).__init__()
8+
super().__init__()
99
self.conv1 = nn.Conv2d(3, 6, 5)
1010
self.pool = nn.MaxPool2d(2, 2)
1111
self.conv2 = nn.Conv2d(6, 16, 5)

examples/dp_fed_examples/client_level_dp/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,8 @@ Once the server has started and logged "FL starting," the next step, in separate
2323
```
2424
python -m examples.dp_fed_examples.client_level_dp.client --dataset_path /path/to/data
2525
```
26+
**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If
27+
the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be
28+
automatically downloaded to the path specified and used in the run.
29+
2630
After the minimum number of clients has been started, federated learning should commence.

examples/dp_fed_examples/client_level_dp/client.py

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import argparse
2-
from collections import OrderedDict
32
from logging import INFO
43
from pathlib import Path
54
from typing import Dict, Tuple
65

76
import flwr as fl
8-
import numpy as np
97
import torch
108
import torch.nn as nn
119
import torchvision.transforms as transforms
@@ -16,6 +14,7 @@
1614

1715
from examples.models.cnn_model import Net
1816
from fl4health.clients.clipping_client import NumpyClippingClient
17+
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
1918

2019

2120
def load_data(data_dir: Path, batch_size: int) -> Tuple[DataLoader, DataLoader, Dict[str, int]]:
@@ -98,57 +97,25 @@ def validate(
9897

9998

10099
class CifarClient(NumpyClippingClient):
101-
def __init__(
102-
self,
103-
data_path: Path,
104-
device: torch.device,
105-
) -> None:
106-
super().__init__()
107-
self.device = device
108-
self.data_path = data_path
109-
self.initialized = False
110-
self.train_loader: DataLoader
111-
112-
def get_parameters(self, config: Config) -> NDArrays:
113-
# Determines which weights are sent back to the server for aggregation.
114-
# Currently sending all of them ordered by state_dict keys
115-
# NOTE: Order matters, because it is relied upon by set_parameters below
116-
model_weights = [val.cpu().numpy() for _, val in self.model.state_dict().items()]
117-
# Clip the weights and store clipping information in parameters
118-
clipped_weight_update, clipping_bit = self.compute_weight_update_and_clip(model_weights)
119-
return clipped_weight_update + [np.array([clipping_bit])]
120-
121-
def set_parameters(self, parameters: NDArrays, config: Config) -> None:
122-
# Sets the local model parameters transfered from the server. The state_dict is
123-
# reconstituted because parameters is simply a list of bytes
124-
# The last entry in the parameters list is assumed to be a clipping bound (even if we're evaluating)
125-
server_model_parameters = parameters[:-1]
126-
params_dict = zip(self.model.state_dict().keys(), server_model_parameters)
127-
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
128-
self.model.load_state_dict(state_dict, strict=True)
129-
130-
# Store the starting parameters without clipping bound before client optimization steps
131-
self.current_weights = server_model_parameters
132-
133-
# Expectation is that the last entry in the parameters NDArrays is a clipping bound
134-
clipping_bound = parameters[-1]
135-
self.clipping_bound = float(clipping_bound)
100+
def __init__(self, data_path: Path, device: torch.device) -> None:
101+
super().__init__(data_path, device)
102+
self.model = Net().to(self.device)
103+
self.parameter_exchanger = FullParameterExchanger()
136104

137105
def setup_client(self, config: Config) -> None:
138-
self.batch_size = config["batch_size"]
139-
self.local_epochs = config["local_epochs"]
140-
self.adaptive_clipping = config["adaptive_clipping"]
106+
super().setup_client(config)
107+
self.batch_size = self.narrow_config_type(config, "batch_size", int)
108+
self.local_epochs = self.narrow_config_type(config, "local_epochs", int)
109+
# Server explicitly sets the clipping strategy
110+
self.adaptive_clipping = self.narrow_config_type(config, "adaptive_clipping", bool)
141111

142112
train_loader, validation_loader, num_examples = load_data(self.data_path, self.batch_size)
143113

144114
self.train_loader = train_loader
145115
self.validation_loader = validation_loader
146116
self.num_examples = num_examples
147-
self.model = Net().to(self.device)
148-
self.initialized = True
149117

150118
def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
151-
152119
self.set_parameters(parameters, config)
153120
accuracy = train(
154121
self.model,

0 commit comments

Comments
 (0)