Skip to content

Commit 70dd368

Browse files
authored
Merge pull request #10 from VectorInstitute/weighted-dp-fedavg
Weighted dp fedavg
2 parents 5a1cbbb + d4201cf commit 70dd368

File tree

19 files changed

+1152
-34
lines changed

19 files changed

+1152
-34
lines changed

examples/datasets/breast_cancer_data/hospital_0.csv

Lines changed: 193 additions & 0 deletions
Large diffs are not rendered by default.

examples/datasets/breast_cancer_data/hospital_1.csv

Lines changed: 165 additions & 0 deletions
Large diffs are not rendered by default.

examples/datasets/breast_cancer_data/hospital_2.csv

Lines changed: 214 additions & 0 deletions
Large diffs are not rendered by default.

examples/dp_fed_examples/client_level_dp/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ n_server_rounds: 20
66
# much more noise can kill server side convergence.
77
server_noise_multiplier: 0.01
88
n_clients: 3
9-
client_sampling: 0.66
9+
client_sampling_rate: 0.66
1010
server_learning_rate: 1.0
1111
server_momentum: 0.2
1212

examples/dp_fed_examples/client_level_dp/server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,16 @@ def main(config: Dict[str, Any]) -> None:
8686
client_manager = PoissonSamplingClientManager()
8787

8888
# Accountant that computes the privacy through training
89-
accountant = FlClientLevelAccountantPoissonSampling(config["client_sampling"], config["server_noise_multiplier"])
89+
accountant = FlClientLevelAccountantPoissonSampling(
90+
config["client_sampling_rate"], config["server_noise_multiplier"]
91+
)
9092
target_delta = 1.0 / config["n_clients"]
9193
epsilon = accountant.get_epsilon(config["n_server_rounds"], target_delta)
9294
log(INFO, f"Model privacy after full training will be ({epsilon}, {target_delta})")
9395

9496
# Server performs simple FedAveraging as it's server-side optimization strategy
9597
strategy = ClientLevelDPFedAvgM(
96-
fraction_fit=config["client_sampling"],
98+
fraction_fit=config["client_sampling_rate"],
9799
# Server waits for min_available_clients before starting FL rounds
98100
min_available_clients=config["n_clients"],
99101
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Client Level Differential Privacy Federated Learning Example
2+
3+
This example shows how to implement Differential Privacy into the Federated Learning framework. In this case we focus on *client level* privacy which is a more substantial version of instance level DP, where the participation of an entire client's set of data is protected from training dataset membership inference. This example uses the FedAvgM implementation with weighted averaging suggested in Learning Differentially Private Recurrent Language Models along with the adaptive clipping scheme proposed in Differentially Private Learning with Adaptive Clipping. The example uses an accountant specifically tailered to this approach. The clients are Poisson sampled by default.
4+
5+
The example involves collaboratively learning a logistic regression model across multiple hospitals to classify breast cancer given 31 features. The dataset is sourced from [kaggle](https://www.kaggle.com/competitions/breast-cancer-classification/overview). A processed federated version of the dataset is available in the repository.
6+
7+
# Running the Example
8+
In order to run the example, first ensure you have the virtual env of your choice activated and run
9+
```
10+
pip install --upgrade pip
11+
pip install -r requirements.txt
12+
```
13+
to install all of the dependencies for this project.
14+
15+
## Starting Server
16+
17+
The next step is to start the server by running
18+
```
19+
python -m examples.dp_fed_examples.client_level_dp_weighted.server --config_path examples/dp_fed_examples/client_level_dp_weighted/config.yaml
20+
```
21+
22+
## Starting Clients
23+
24+
Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the clients expected by the server. This is done by simply running (remembering to activate your environment)
25+
```
26+
python -m examples.dp_fed_examples.client_level_dp_weighted.client --dataset_path examples/datasets/breast_cancer_data/hospital_#.csv
27+
```
28+
After the minimum number of clients have been started federated learning should commence.

examples/dp_fed_examples/client_level_dp_weighted/__init__.py

Whitespace-only changes.
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import argparse
2+
from collections import OrderedDict
3+
from logging import INFO
4+
from pathlib import Path
5+
from typing import Dict, Tuple
6+
7+
import flwr as fl
8+
import numpy as np
9+
import torch
10+
import torch.nn as nn
11+
from flwr.common.logger import log
12+
from flwr.common.typing import Config, NDArrays, Scalar
13+
from torch.utils.data import DataLoader
14+
15+
from examples.dp_fed_examples.client_level_dp_weighted.data import load_data
16+
from examples.models.logistic_regression import LogisticRegression
17+
from fl4health.clients.clipping_client import NumpyClippingClient
18+
19+
20+
def train(net: nn.Module, train_loader: DataLoader, epochs: int, device: torch.device = torch.device("cpu")) -> float:
21+
22+
criterion = torch.nn.BCELoss()
23+
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
24+
25+
for epoch in range(epochs):
26+
correct, total, running_loss = 0, 0, 0.0
27+
n_batches = len(train_loader)
28+
for features, labels in train_loader:
29+
features, labels = features.to(device), labels.to(device)
30+
optimizer.zero_grad()
31+
preds = net(features)
32+
loss = criterion(preds, labels)
33+
loss.backward()
34+
optimizer.step()
35+
36+
running_loss += loss.item()
37+
predicted = preds.data >= 0.5
38+
39+
total += labels.size(0)
40+
correct += (predicted.int() == labels.int()).sum().item()
41+
42+
accuracy = correct / total
43+
# Local client logging.
44+
log(
45+
INFO,
46+
f"Epoch: {epoch}, Client Training Loss: {running_loss/n_batches},"
47+
f" Client Training Accuracy: {accuracy}",
48+
)
49+
return accuracy
50+
51+
52+
def validate(
53+
net: nn.Module,
54+
validation_loader: DataLoader,
55+
device: torch.device = torch.device("cpu"),
56+
) -> Tuple[float, float]:
57+
"""Validate the network on the entire validation set."""
58+
criterion = torch.nn.BCELoss()
59+
correct, total, loss = 0, 0, 0.0
60+
with torch.no_grad():
61+
n_batches = len(validation_loader)
62+
for features, labels in validation_loader:
63+
features, labels = features.to(device), labels.to(device)
64+
preds = net(features)
65+
loss += criterion(preds, labels).item()
66+
predicted = preds.data >= 0.5
67+
total += labels.size(0)
68+
correct += (predicted.int() == labels.int()).sum().item()
69+
accuracy = correct / total
70+
# Local client logging.
71+
log(INFO, f"Client Validation Loss: {loss/n_batches} Client Validation Accuracy: {accuracy}")
72+
return loss / n_batches, accuracy
73+
74+
75+
class HospitalClient(NumpyClippingClient):
76+
def __init__(
77+
self,
78+
data_path: Path,
79+
device: torch.device,
80+
) -> None:
81+
super().__init__()
82+
self.device = device
83+
self.data_path = data_path
84+
self.initialized = False
85+
self.train_loader: DataLoader
86+
87+
def get_parameters(self, config: Config) -> NDArrays:
88+
# Determines which weights are sent back to the server for aggregation.
89+
# Currently sending all of them ordered by state_dict keys
90+
# NOTE: Order matters, because it is relied upon by set_parameters below
91+
model_weights = [val.cpu().numpy() for _, val in self.model.state_dict().items()]
92+
# Clipped the weights and store clipping information in parameters
93+
clipped_weight_update, clipping_bit = self.compute_weight_update_and_clip(model_weights)
94+
return clipped_weight_update + [np.array([clipping_bit])]
95+
96+
def set_parameters(self, parameters: NDArrays, config: Config) -> None:
97+
# Sets the local model parameters transfered from the server. The state_dict is
98+
# reconstituted because parameters is simply a list of bytes
99+
# The last entry in the parameters list is assumed to be a clipping bound (even if we're evaluating)
100+
server_model_parameters = parameters[:-1]
101+
params_dict = zip(self.model.state_dict().keys(), server_model_parameters)
102+
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
103+
self.model.load_state_dict(state_dict, strict=True)
104+
105+
# Store the starting parameters without clipping bound before client optimization steps
106+
self.current_weights = server_model_parameters
107+
108+
clipping_bound = parameters[-1]
109+
self.clipping_bound = float(clipping_bound)
110+
111+
def setup_client(self, config: Config) -> None:
112+
self.batch_size = config["batch_size"]
113+
self.local_epochs = config["local_epochs"]
114+
self.adaptive_clipping = config["adaptive_clipping"]
115+
self.scaler_bytes = config["scaler"]
116+
117+
train_loader, validation_loader, num_examples = load_data(self.data_path, self.batch_size, self.scaler_bytes)
118+
119+
self.train_loader = train_loader
120+
self.validation_loader = validation_loader
121+
self.num_examples = num_examples
122+
self.model = LogisticRegression(input_dim=31, output_dim=1).to(self.device)
123+
self.initialized = True
124+
125+
def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
126+
# Expectation is that the last entry in the parameters NDArrays is a clipping bound
127+
if not self.initialized:
128+
self.setup_client(config)
129+
self.set_parameters(parameters, config)
130+
accuracy = train(
131+
self.model,
132+
self.train_loader,
133+
self.local_epochs,
134+
self.device,
135+
)
136+
# FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics
137+
# calculation results.
138+
return (
139+
self.get_parameters(config),
140+
self.num_examples["train_set"],
141+
{"accuracy": accuracy},
142+
)
143+
144+
def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]:
145+
# Expectation is that the last entry in the parameters NDArrays is a clipping bound (even if it isn't used
146+
# for evaluation)
147+
if not self.initialized:
148+
self.setup_client(config)
149+
self.set_parameters(parameters, config)
150+
loss, accuracy = validate(self.model, self.validation_loader, device=self.device)
151+
# EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics
152+
# calculation results.
153+
return (
154+
loss,
155+
self.num_examples["validation_set"],
156+
{"accuracy": accuracy},
157+
)
158+
159+
160+
if __name__ == "__main__":
161+
parser = argparse.ArgumentParser(description="FL Client Main")
162+
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
163+
args = parser.parse_args()
164+
165+
# Load model and data
166+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
167+
data_path = Path(args.dataset_path)
168+
client = HospitalClient(data_path, DEVICE)
169+
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Server parameters
2+
n_server_rounds: 25
3+
4+
# NOTE: This multiplier is small, yielding a vacuous epsilon for privacy. It is set to this small value for this
5+
# example due to the small number of clients (3, see below), which, when combined with the clipping implies that
6+
# much more noise can kill server side convergence.
7+
server_noise_multiplier: 0.01
8+
n_clients: 3
9+
client_sampling_rate: 0.667
10+
server_learning_rate: 1.0
11+
server_momentum: 1.0
12+
weighted_averaging: True
13+
14+
# Client training parameters
15+
local_epochs: 1
16+
batch_size: 32
17+
total_samples: 569
18+
19+
# Clipping settings for update and optionally
20+
# adaptive clipping
21+
adaptive_clipping: True
22+
clipping_bound: 0.1
23+
clipping_learning_rate: 0.5
24+
# NOTE: The noise multiplier here is just picked for convenience. The recommended heuristic is
25+
# expected clients per round/20
26+
clipping_bit_noise_multiplier: 0.5
27+
clipping_quantile: 0.5
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pickle
2+
from pathlib import Path
3+
from typing import Dict, Tuple
4+
5+
import numpy as np
6+
import pandas as pd
7+
import torch
8+
from sklearn.preprocessing import MinMaxScaler
9+
from torch.utils.data import DataLoader, TensorDataset
10+
11+
12+
class Scaler:
13+
def __init__(self) -> None:
14+
self.scaler = MinMaxScaler()
15+
16+
def __call__(self, train_x: np.ndarray, val_x: np.ndarray) -> np.ndarray:
17+
scaled_train_x = self.scaler.fit_transform(train_x)
18+
scaled_val_x = self.scaler.transform(val_x)
19+
return scaled_train_x, scaled_val_x
20+
21+
22+
def load_data(data_dir: Path, batch_size: int, scaler_bytes: bytes) -> Tuple[DataLoader, DataLoader, Dict[str, int]]:
23+
data = pd.read_csv(data_dir, index_col=False)
24+
features = data.loc[:, data.columns != "label"].values
25+
labels = data["label"].values
26+
n_samples = data.shape[0]
27+
28+
scaler = pickle.loads(scaler_bytes)
29+
train_samples = int(n_samples * 0.8)
30+
train_features, train_labels = features[:train_samples, :], labels[:train_samples]
31+
val_features, val_labels = features[train_samples:, :], labels[train_samples:]
32+
train_features, val_features = scaler(train_features, val_features)
33+
train_X, train_Y = torch.from_numpy(train_features).float(), torch.from_numpy(train_labels).float()
34+
val_X, val_Y = torch.from_numpy(val_features).float(), torch.from_numpy(val_labels).float()
35+
train_ds, val_ds = TensorDataset(train_X, train_Y), TensorDataset(val_X, val_Y)
36+
train_loader = DataLoader(train_ds, batch_size=batch_size, drop_last=True)
37+
val_loader = DataLoader(val_ds, batch_size=batch_size, drop_last=True)
38+
39+
num_examples = {"train_set": train_samples, "validation_set": n_samples - train_samples}
40+
41+
return train_loader, val_loader, num_examples

0 commit comments

Comments
 (0)