"""Ditto Personalized Mixin"""

- from abc import ABC, abstractmethod
from logging import DEBUG, INFO
- from typing import cast
+ from typing import cast, runtime_checkable

import torch
import torch.nn as nn
+ import warnings
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays, Scalar
from torch.optim import Optimizer

- from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin
+ from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin, AdaptiveProtocol
from fl4health.mixins.personalized.base import BasePersonalizedMixin
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
- from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.utils.config import narrow_dict_type
- from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses
+ from fl4health.utils.losses import EvaluationLosses, TrainingLosses
from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType
+ from fl4health.clients.basic_client import BasicClientProtocol

- class DittoPersonalizedMixin(AdaptiveDriftConstrainedMixin, BasePersonalizedMixin, ABC):
+ @runtime_checkable
+ class DittoProtocol(AdaptiveProtocol):
+     global_model: torch.nn.Module | None

-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self._global_model = None
+     def get_global_model(self, config: Config) -> nn.Module: ...
+
+     def _copy_optimizer_with_new_params(self, original_optimizer: Optimizer): ...
+
+     def set_initial_global_tensors(self) -> None: ...
+
+     def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: ...

-     @property
-     def global_model(self) -> torch.nn.Module | None:
-         """Gets the global model."""
-         return self._global_model

-     @global_model.setter
-     def global_model(self, value: torch.nn.Module) -> None:
-         self._global_model = value
+ class DittoPersonalizedMixin(AdaptiveDriftConstrainedMixin, BasePersonalizedMixin):
+
+     def __init_subclass__(cls, **kwargs):
+         """Called when a class inherits from DittoPersonalizedMixin."""
+         super().__init_subclass__(**kwargs)
+
+         # Check at class definition time whether any other base class satisfies BasicClientProtocol
+         for base in cls.__bases__:
+             if base is not DittoPersonalizedMixin and isinstance(base, BasicClientProtocol):
+                 return
+
+         # If we get here, no compatible base was found
+         warnings.warn(
+             f"Class {cls.__name__} inherits from DittoPersonalizedMixin but none of its other "
+             f"base classes implement BasicClientProtocol. This may cause runtime errors.",
+             RuntimeWarning,
+         )
+
+     def __init__(self, *args, **kwargs):
+         # Verify at instance creation time
+         if not isinstance(self, BasicClientProtocol):
+             raise TypeError(
+                 f"Class {self.__class__.__name__} uses DittoPersonalizedMixin but does not "
+                 f"implement BasicClientProtocol. Make sure a parent class implements "
+                 f"the required methods and attributes."
+             )
+         super().__init__(*args, **kwargs)

    @property
-     def optimizer_keys(self) -> list[str]:
+     def optimizer_keys(self: DittoProtocol) -> list[str]:
        """Returns the optimizer keys."""
        return ["local", "global"]

-     def _copy_optimizer_with_new_params(self, original_optimizer):
+     def _copy_optimizer_with_new_params(self: DittoProtocol, original_optimizer: Optimizer):
        OptimClass = original_optimizer.__class__
        state_dict = original_optimizer.state_dict()

@@ -59,7 +85,23 @@ def _copy_optimizer_with_new_params(self, original_optimizer):

        return global_optimizer

-     def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
+     def get_global_model(self: DittoProtocol, config: Config) -> nn.Module:
+         """
+         Returns the global model to be used during Ditto training and as a constraint for the local model.
+
+         The global model should be the same architecture as the local model so we reuse the ``get_model`` call. We
+         explicitly send the model to the desired device. This is idempotent.
+
+         Args:
+             config (Config): The config from the server.
+
+         Returns:
+             nn.Module: The PyTorch model serving as the global model for Ditto
+         """
+         config["for_global"] = True
+         return self.get_model(config).to(self.device)
+
+     def get_optimizer(self: DittoProtocol, config: Config) -> dict[str, Optimizer]:
        if self.global_model is None:
            # try set it here
            self.global_model = self.get_global_model(config)  # is this the same config?
@@ -80,23 +122,7 @@ def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
        global_optimizer = self._copy_optimizer_with_new_params(original_optimizer)
        return {"local": original_optimizer, "global": global_optimizer}

-     def get_global_model(self, config: Config) -> nn.Module:
-         """
-         Returns the global model to be used during Ditto training and as a constraint for the local model.
-
-         The global model should be the same architecture as the local model so we reuse the ``get_model`` call. We
-         explicitly send the model to the desired device. This is idempotent.
-
-         Args:
-             config (Config): The config from the server.
-
-         Returns:
-             nn.Module: The PyTorch model serving as the global model for Ditto
-         """
-         config["for_global"] = True
-         return self.get_model(config).to(self.device)
-
-     def set_optimizer(self, config: Config) -> None:
+     def set_optimizer(self: DittoProtocol, config: Config) -> None:
        """
        Ditto requires an optimizer for the global model and one for the local model. This function simply ensures that
        the optimizers setup by the user have the proper keys and that there are two optimizers.
@@ -108,7 +134,7 @@ def set_optimizer(self, config: Config) -> None:
        assert isinstance(optimizers, dict) and set(self.optimizer_keys) == set(optimizers.keys())
        self.optimizers = optimizers

-     def setup_client(self, config: Config) -> None:
+     def setup_client(self: DittoProtocol, config: Config) -> None:
        """
        Set dataloaders, optimizers, parameter exchangers and other attributes derived from these.
        Then set initialized attribute to True. In this class, this function simply adds the additional step of
@@ -127,7 +153,7 @@ def setup_client(self, config: Config) -> None:
        super().setup_client(config)
        # Need to setup the global model here as well. It should be the same architecture as the local model.

-     def get_parameters(self, config: Config) -> NDArrays:
+     def get_parameters(self: DittoProtocol, config: Config) -> NDArrays:
        """
        For Ditto, we transfer the **GLOBAL** model weights to the server to be aggregated. The local model weights
        stay with the client.
@@ -164,7 +190,7 @@ def get_parameters(self, config: Config) -> NDArrays:
        log(INFO, "Successfully packed parameters of global model")
        return packed_params

-     def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
+     def set_parameters(self: DittoProtocol, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
        """
        Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
        unpacked for the clients to use in training. The parameters being passed are to be routed to the global model.
@@ -197,7 +223,7 @@ def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bo
        log(INFO, "Setting the global model weights")
        parameter_exchanger.pull_parameters(server_model_state, self.global_model, config)

-     def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> None:
+     def initialize_all_model_weights(self: DittoProtocol, parameters: NDArrays, config: Config) -> None:
        """
        If this is the first time we're initializing the model weights, we initialize both the global and the local
        weights together.
@@ -210,16 +236,17 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) ->
        parameter_exchanger.pull_parameters(parameters, self.model, config)
        parameter_exchanger.pull_parameters(parameters, self.global_model, config)

-     def set_initial_global_tensors(self) -> None:
+     def set_initial_global_tensors(self: DittoProtocol) -> None:
        """
        Saving the initial **GLOBAL MODEL** weights and detaching them so that we don't compute gradients with
        respect to the tensors. These are used to form the Ditto local update penalty term.
        """
-         self.drift_penalty_tensors = [
-             initial_layer_weights.detach().clone() for initial_layer_weights in self.global_model.parameters()
-         ]
+         if self.global_model:
+             self.drift_penalty_tensors = [
+                 initial_layer_weights.detach().clone() for initial_layer_weights in self.global_model.parameters()
+             ]

-     def update_before_train(self, current_server_round: int) -> None:
+     def update_before_train(self: DittoProtocol, current_server_round: int) -> None:
        """
        Procedures that should occur before proceeding with the training loops for the models. In this case, we
        save the global models parameters to be used in constraining training of the local model.
@@ -234,7 +261,9 @@ def update_before_train(self, current_server_round: int) -> None:

        super().update_before_train(current_server_round)

-     def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]:
+     def train_step(
+         self: DittoProtocol, input: TorchInputType, target: TorchTargetType
+     ) -> tuple[TrainingLosses, TorchPredType]:
        """
        Mechanics of training loop follow from original Ditto implementation: https://github.com/litian96/ditto

@@ -325,7 +354,7 @@ def predict(
        else:
            raise ValueError(f"Unsupported pred type: {type(global_preds)}.")

-     def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]):
+     def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        if kind not in ["global", "local"]:
            raise ValueError("Unsupported kind of prediction. Must be 'global' or 'local'.")

@@ -336,7 +365,7 @@ def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]):
        return retval

    def compute_loss_and_additional_losses(
-         self,
+         self: DittoProtocol,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
@@ -379,7 +408,7 @@ def compute_loss_and_additional_losses(
        return local_loss, additional_losses

    def compute_training_loss(
-         self,
+         self: DittoProtocol,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
@@ -430,7 +459,7 @@ def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict
        return super().validate(include_losses_in_metrics=include_losses_in_metrics)

    def compute_evaluation_loss(
-         self,
+         self: DittoProtocol,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
@@ -451,5 +480,5 @@ def compute_evaluation_loss(
        indexed by name.
        """
        # Check that both models are in eval mode
-         assert not self.global_model.training and not self.model.training
+         assert self.global_model is not None and not self.global_model.training and not self.model.training
        return super().compute_evaluation_loss(preds, features, target)
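
A minimal composition sketch of how this mixin is intended to be used (client, network, and import path names here are hypothetical illustrations, not part of this change): the mixin sits beside a base class that satisfies BasicClientProtocol, so the __init_subclass__ check above finds a compatible base and __init__ does not raise.

    # Hypothetical usage sketch; assumes fl4health's BasicClient satisfies BasicClientProtocol.
    import torch.nn as nn

    from flwr.common.typing import Config
    from fl4health.clients.basic_client import BasicClient
    from fl4health.mixins.personalized import DittoPersonalizedMixin  # import path assumed for illustration


    class MyDittoClient(DittoPersonalizedMixin, BasicClient):
        """BasicClient supplies the methods and attributes the protocol checks expect."""

        def get_model(self, config: Config) -> nn.Module:
            # The same architecture is reused for the local and global models;
            # get_global_model() sets config["for_global"] = True before delegating here,
            # so a subclass could branch on that flag if the two models ever need to differ.
            return MyNet().to(self.device)  # MyNet is a placeholder architecture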