|
1 | 1 | import argparse |
2 | | -from collections import OrderedDict |
3 | 2 | from logging import INFO |
4 | 3 | from pathlib import Path |
5 | 4 | from typing import Dict, Tuple |
6 | 5 |
|
7 | 6 | import flwr as fl |
8 | | -import numpy as np |
9 | 7 | import torch |
10 | 8 | import torch.nn as nn |
11 | 9 | import torchvision.transforms as transforms |
|
16 | 14 |
|
17 | 15 | from examples.models.cnn_model import Net |
18 | 16 | from fl4health.clients.clipping_client import NumpyClippingClient |
| 17 | +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger |
19 | 18 |
|
20 | 19 |
|
21 | 20 | def load_data(data_dir: Path, batch_size: int) -> Tuple[DataLoader, DataLoader, Dict[str, int]]: |
@@ -98,57 +97,25 @@ def validate( |
98 | 97 |
|
99 | 98 |
|
100 | 99 | class CifarClient(NumpyClippingClient): |
101 | | - def __init__( |
102 | | - self, |
103 | | - data_path: Path, |
104 | | - device: torch.device, |
105 | | - ) -> None: |
106 | | - super().__init__() |
107 | | - self.device = device |
108 | | - self.data_path = data_path |
109 | | - self.initialized = False |
110 | | - self.train_loader: DataLoader |
111 | | - |
112 | | - def get_parameters(self, config: Config) -> NDArrays: |
113 | | - # Determines which weights are sent back to the server for aggregation. |
114 | | - # Currently sending all of them ordered by state_dict keys |
115 | | - # NOTE: Order matters, because it is relied upon by set_parameters below |
116 | | - model_weights = [val.cpu().numpy() for _, val in self.model.state_dict().items()] |
117 | | - # Clipped the weights and store clipping information in parameters |
118 | | - clipped_weight_update, clipping_bit = self.compute_weight_update_and_clip(model_weights) |
119 | | - return clipped_weight_update + [np.array([clipping_bit])] |
120 | | - |
121 | | - def set_parameters(self, parameters: NDArrays, config: Config) -> None: |
122 | | - # Sets the local model parameters transfered from the server. The state_dict is |
123 | | - # reconstituted because parameters is simply a list of bytes |
124 | | - # The last entry in the parameters list is assumed to be a clipping bound (even if we're evaluating) |
125 | | - server_model_parameters = parameters[:-1] |
126 | | - params_dict = zip(self.model.state_dict().keys(), server_model_parameters) |
127 | | - state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) |
128 | | - self.model.load_state_dict(state_dict, strict=True) |
129 | | - |
130 | | - # Store the starting parameters without clipping bound before client optimization steps |
131 | | - self.current_weights = server_model_parameters |
132 | | - |
133 | | - # Expectation is that the last entry in the parameters NDArrays is a clipping bound |
134 | | - clipping_bound = parameters[-1] |
135 | | - self.clipping_bound = float(clipping_bound) |
| 100 | + def __init__(self, data_path: Path, device: torch.device) -> None: |
| 101 | + super().__init__(data_path, device) |
| 102 | + self.model = Net().to(self.device) |
| 103 | + self.parameter_exchanger = FullParameterExchanger() |
136 | 104 |
|
137 | 105 | def setup_client(self, config: Config) -> None: |
138 | | - self.batch_size = config["batch_size"] |
139 | | - self.local_epochs = config["local_epochs"] |
140 | | - self.adaptive_clipping = config["adaptive_clipping"] |
| 106 | + super().setup_client(config) |
| 107 | + self.batch_size = self.narrow_config_type(config, "batch_size", int) |
| 108 | + self.local_epochs = self.narrow_config_type(config, "local_epochs", int) |
| 109 | + # Server explicitly sets the clipping strategy |
| 110 | + self.adaptive_clipping = self.narrow_config_type(config, "adaptive_clipping", bool) |
141 | 111 |
|
142 | 112 | train_loader, validation_loader, num_examples = load_data(self.data_path, self.batch_size) |
143 | 113 |
|
144 | 114 | self.train_loader = train_loader |
145 | 115 | self.validation_loader = validation_loader |
146 | 116 | self.num_examples = num_examples |
147 | | - self.model = Net().to(self.device) |
148 | | - self.initialized = True |
149 | 117 |
|
150 | 118 | def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: |
151 | | - |
152 | 119 | self.set_parameters(parameters, config) |
153 | 120 | accuracy = train( |
154 | 121 | self.model, |
|
0 commit comments