@@ -5,6 +5,7 @@
 from typing import cast
 
 import torch
+import warnings
 from flwr.common.logger import log
 from flwr.common.typing import Config, NDArrays
 
@@ -16,61 +17,47 @@
 from fl4health.utils.losses import TrainingLosses
 from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType
 
+from fl4health.clients.basic_client import BasicClientProtocol
+
+
+class AdaptiveProtocol(BasicClientProtocol):
+    loss_for_adaptation: float | None
+    drift_penalty_tensors: list[torch.Tensor] | None
+    drift_penalty_weight: float | None
+
 
 class AdaptiveDriftConstrainedMixin:
+    def __init_subclass__(cls, **kwargs):
+        """This method is called when a class inherits from AdaptiveMixin"""
+        super().__init_subclass__(**kwargs)
+
+        # Check at class definition time if the parent class satisfies BasicClientProtocol
+        for base in cls.__bases__:
+            if base is not AdaptiveDriftConstrainedMixin and isinstance(base, BasicClientProtocol):
+                return
+
+        # If we get here, no compatible base was found
+        warnings.warn(
+            f"Class {cls.__name__} inherits from AdaptiveMixin but none of its other "
+            f"base classes implement BasicClientProtocol. This may cause runtime errors.",
+            RuntimeWarning,
+        )
+
+    def __init__(self, *args, **kwargs):
+        # Verify at instance creation time
+        if not isinstance(self, BasicClientProtocol):
+            raise TypeError(
+                f"Class {self.__class__.__name__} uses AdaptiveMixin but does not "
+                f"implement BasicClientProtocol. Make sure a parent class implements "
+                f"the required methods and attributes."
+            )
+        super().__init__(*args, **kwargs)
 
-    _drift_penalty_tensors = None
-    _parameter_exchanger = None
-    _drift_penalty_weight = None
-    _loss_for_adaptation = 0.1
-
-    @property
-    def drift_penalty_tensors(self) -> list[torch.Tensor] | None:
-        """These are the tensors that will be used to compute the penalty loss"""
-        return self._drift_penalty_tensors
-
-    @drift_penalty_tensors.setter
-    def drift_penalty_tensors(self, value: list[torch.Tensor]) -> None:
-        self._drift_penalty_tensors = value
-
-    @property
-    def parameter_exchanger(self) -> FullParameterExchangerWithPacking[float] | None:
-        """Exchanger with packing to be able to exchange the weights and auxiliary information with the server for adaptation"""
-        return self._parameter_exchanger
-
-    @parameter_exchanger.setter
-    def parameter_exchanger(self, value: FullParameterExchangerWithPacking[float]) -> None:
-        self._parameter_exchanger = value
-
-    @property
-    def drift_penalty_weight(self) -> float:
-        """Weight on the penalty loss to be used in backprop. This is what might be adapted via server calculations."""
-        return self._drift_penalty_weight
-
-    @drift_penalty_weight.setter
-    def drift_penalty_weight(self, value: float) -> None:
-        self._drift_penalty_weight = value
-
-    @property
-    def loss_for_adaptation(self) -> float | None:
-        """This is the loss value to be sent back to the server on which adaptation decisions will be made."""
-        return self._loss_for_adaptation
-
-    @loss_for_adaptation.setter
-    def loss_for_adaptation(self, value: float) -> None:
-        self._loss_for_adaptation = value
-
-    @property
-    def penalty_loss_function(self) -> WeightDriftLoss:
+    def penalty_loss_function(self: AdaptiveProtocol) -> WeightDriftLoss:
         """Function to compute the penalty loss."""
-        try:
-            device = self.device
-            device = cast(torch.device, device)
-            return WeightDriftLoss(self.device)
-        except AttributeError as err:
-            raise ValueError("Parent Client is missing a `device` attribute.") from err
-
-    def get_parameters(self, config: Config) -> NDArrays:
+        return WeightDriftLoss(self.device)
+
+    def get_parameters(self: AdaptiveProtocol, config: Config) -> NDArrays:
         """
         Packs the parameters and training loss into a single ``NDArrays`` to be sent to the server for aggregation. If
        the client has not been initialized, this means the server is requesting parameters for initialization and
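For orientation, a minimal sketch of how this mixin is meant to be composed with a concrete client. It assumes that fl4health's BasicClient satisfies BasicClientProtocol; the client class name below is illustrative and not part of this change.

from fl4health.clients.basic_client import BasicClient

# Hypothetical client: the mixin is listed alongside a base that implements BasicClientProtocol.
# If no such base were present, __init_subclass__ above would emit a RuntimeWarning at class
# definition time and __init__ would raise a TypeError at instantiation.
class MyAdaptiveClient(AdaptiveDriftConstrainedMixin, BasicClient):
    pass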
@@ -109,7 +96,7 @@ def get_parameters(self, config: Config) -> NDArrays:
         packed_params = self.parameter_exchanger.pack_parameters(model_weights, self.loss_for_adaptation)
         return packed_params
 
-    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
+    def set_parameters(self: AdaptiveProtocol, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
         """
         Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
         unpacked for the clients to use in training. In the first fitting round, we assume the full model is being
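A rough illustration of the packing round trip described in the docstring above. The import paths and the unpack_parameters counterpart are my best reading of the fl4health parameter-exchange API, so treat this as a sketch rather than definitive usage.

import numpy as np

from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint

exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
model_weights = [np.ones((2, 2)), np.zeros(3)]  # stand-in for the model's NDArrays
packed = exchanger.pack_parameters(model_weights, 0.42)  # appends the adaptation loss for the server
weights, loss_for_adaptation = exchanger.unpack_parameters(packed)  # assumed inverse of pack_parameters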
@@ -132,7 +119,7 @@ def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
         super().set_parameters(server_model_state, config, fitting_round)
 
     def compute_training_loss(
-        self,
+        self: AdaptiveProtocol,
         preds: TorchPredType,
         features: TorchFeatureType,
         target: TorchTargetType,
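A quick note on the self: AdaptiveProtocol annotations introduced throughout this diff: annotating self with a Protocol is the standard way to tell a static type checker which attributes and methods a mixin expects its host class to provide, without forcing an inheritance relationship. A self-contained sketch of the pattern, with made-up names:

from typing import Protocol

import torch

class HostProtocol(Protocol):
    device: torch.device

class DeviceAwareMixin:
    def to_host_device(self: HostProtocol, tensor: torch.Tensor) -> torch.Tensor:
        # The self annotation lets mypy check self.device against HostProtocol rather than DeviceAwareMixin
        return tensor.to(self.device)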
@@ -165,7 +152,7 @@ def compute_training_loss(
 
         return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses)
 
-    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
+    def get_parameter_exchanger(self: AdaptiveProtocol, config: Config) -> ParameterExchanger:
         """
         Setting up the parameter exchanger to include the appropriate packing functionality.
         By default we assume that we're exchanging all parameters. Can be overridden for other behavior
@@ -179,7 +166,9 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
 
         return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
 
-    def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
+    def update_after_train(
+        self: AdaptiveProtocol, local_steps: int, loss_dict: dict[str, float], config: Config
+    ) -> None:
         """
         Called after training with the number of ``local_steps`` performed over the FL round and the corresponding loss
         dictionary. We use this to store the training loss that we want to use to adapt the penalty weight parameter
@@ -195,7 +184,7 @@ def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
         self.loss_for_adaptation = loss_dict["loss_for_adaptation"]
         super().update_after_train(local_steps, loss_dict, config)
 
-    def compute_penalty_loss(self) -> torch.Tensor:
+    def compute_penalty_loss(self: AdaptiveProtocol) -> torch.Tensor:
         """
         Computes the drift loss for the client model and drift tensors
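For intuition about what compute_penalty_loss produces: WeightDriftLoss penalizes the drift between the current model weights and the stored drift_penalty_tensors, scaled by drift_penalty_weight. A hand-rolled stand-in is sketched below; it is not the library's implementation and the exact scaling may differ.

import torch

def drift_penalty(model: torch.nn.Module, reference: list[torch.Tensor], weight: float) -> torch.Tensor:
    # FedProx-style penalty: (weight / 2) * sum of squared L2 distances between current parameters
    # and the reference tensors (e.g. the global weights received at the start of the round).
    total = torch.zeros((), device=next(model.parameters()).device)
    for param, ref in zip(model.parameters(), reference):
        total = total + torch.sum((param - ref.detach()) ** 2)
    return (weight / 2.0) * total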