Commit 8c1331e

adaptive protocols
1 parent 2450258 commit 8c1331e

2 files changed: +45 / -59 lines changed


fl4health/mixins/adaptive_drift_contrained.py

Lines changed: 45 additions & 56 deletions
@@ -5,6 +5,7 @@
 from typing import cast
 
 import torch
+import warnings
 from flwr.common.logger import log
 from flwr.common.typing import Config, NDArrays
 
@@ -16,61 +17,47 @@
 from fl4health.utils.losses import TrainingLosses
 from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType
 
+from fl4health.clients.basic_client import BasicClientProtocol
+
+
+class AdaptiveProtocol(BasicClientProtocol):
+    loss_for_adaptation: float | None
+    drift_penalty_tensors: list[torch.Tensor] | None
+    drift_penalty_weight: float | None
+
 
 class AdaptiveDriftConstrainedMixin:
+    def __init_subclass__(cls, **kwargs):
+        """This method is called when a class inherits from AdaptiveMixin"""
+        super().__init_subclass__(**kwargs)
+
+        # Check at class definition time if the parent class satisfies BasicClientProtocol
+        for base in cls.__bases__:
+            if base is not AdaptiveDriftConstrainedMixin and isinstance(base, BasicClientProtocol):
+                return
+
+        # If we get here, no compatible base was found
+        warnings.warn(
+            f"Class {cls.__name__} inherits from AdaptiveMixin but none of its other "
+            f"base classes implement BasicClientProtocol. This may cause runtime errors.",
+            RuntimeWarning,
+        )
+
+    def __init__(self, *args, **kwargs):
+        # Verify at instance creation time
+        if not isinstance(self, BasicClientProtocol):
+            raise TypeError(
+                f"Class {self.__class__.__name__} uses AdaptiveMixin but does not "
+                f"implement BasicClientProtocol. Make sure a parent class implements "
+                f"the required methods and attributes."
+            )
+        super().__init__(*args, **kwargs)
 
-    _drift_penalty_tensors = None
-    _parameter_exchanger = None
-    _drift_penalty_weight = None
-    _loss_for_adaptation = 0.1
-
-    @property
-    def drift_penalty_tensors(self) -> list[torch.Tensor] | None:
-        """These are the tensors that will be used to compute the penalty loss"""
-        return self._drift_penalty_tensors
-
-    @drift_penalty_tensors.setter
-    def drift_penalty_tensors(self, value: list[torch.Tensor]) -> None:
-        self._drift_penalty_tensors = value
-
-    @property
-    def parameter_exchanger(self) -> FullParameterExchangerWithPacking[float] | None:
-        """Exchanger with packing to be able to exchange the weights and auxiliary information with the server for adaptation"""
-        return self._parameter_exchanger
-
-    @parameter_exchanger.setter
-    def parameter_exchanger(self, value: FullParameterExchangerWithPacking[float]) -> None:
-        self._parameter_exchanger = value
-
-    @property
-    def drift_penalty_weight(self) -> float:
-        """Weight on the penalty loss to be used in backprop. This is what might be adapted via server calculations."""
-        return self._drift_penalty_weight
-
-    @drift_penalty_weight.setter
-    def drift_penalty_weight(self, value: float) -> None:
-        self._drift_penalty_weight = value
-
-    @property
-    def loss_for_adaptation(self) -> float | None:
-        """This is the loss value to be sent back to the server on which adaptation decisions will be made."""
-        return self._loss_for_adaptation
-
-    @loss_for_adaptation.setter
-    def loss_for_adaptation(self, value: float) -> None:
-        self._loss_for_adaptation = value
-
-    @property
-    def penalty_loss_function(self) -> WeightDriftLoss:
+    def penalty_loss_function(self: AdaptiveProtocol) -> WeightDriftLoss:
         """Function to compute the penalty loss."""
-        try:
-            device = self.device
-            device = cast(torch.device, device)
-            return WeightDriftLoss(self.device)
-        except AttributeError as err:
-            raise ValueError("Parent Client is missing a `device` attribute.") from err
-
-    def get_parameters(self, config: Config) -> NDArrays:
+        return WeightDriftLoss(self.device)
+
+    def get_parameters(self: AdaptiveProtocol, config: Config) -> NDArrays:
         """
         Packs the parameters and training loss into a single ``NDArrays`` to be sent to the server for aggregation. If
         the client has not been initialized, this means the server is requesting parameters for initialization and
@@ -109,7 +96,7 @@ def get_parameters(self, config: Config) -> NDArrays:
         packed_params = self.parameter_exchanger.pack_parameters(model_weights, self.loss_for_adaptation)
         return packed_params
 
-    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
+    def set_parameters(self: AdaptiveProtocol, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
         """
         Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
         unpacked for the clients to use in training. In the first fitting round, we assume the full model is being
@@ -132,7 +119,7 @@ def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bo
         super().set_parameters(server_model_state, config, fitting_round)
 
     def compute_training_loss(
-        self,
+        self: AdaptiveProtocol,
         preds: TorchPredType,
         features: TorchFeatureType,
         target: TorchTargetType,
@@ -165,7 +152,7 @@ def compute_training_loss(
 
         return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses)
 
-    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
+    def get_parameter_exchanger(self: AdaptiveProtocol, config: Config) -> ParameterExchanger:
         """
         Setting up the parameter exchanger to include the appropriate packing functionality.
         By default we assume that we're exchanging all parameters. Can be overridden for other behavior
@@ -179,7 +166,9 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
 
         return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
 
-    def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
+    def update_after_train(
+        self: AdaptiveProtocol, local_steps: int, loss_dict: dict[str, float], config: Config
+    ) -> None:
         """
         Called after training with the number of ``local_steps`` performed over the FL round and the corresponding loss
         dictionary. We use this to store the training loss that we want to use to adapt the penalty weight parameter
@@ -195,7 +184,7 @@ def update_after_train(self, local_steps: int, loss_dict: dict[str, float], conf
         self.loss_for_adaptation = loss_dict["loss_for_adaptation"]
         super().update_after_train(local_steps, loss_dict, config)
 
-    def compute_penalty_loss(self) -> torch.Tensor:
+    def compute_penalty_loss(self: AdaptiveProtocol) -> torch.Tensor:
         """
         Computes the drift loss for the client model and drift tensors

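The mixin now validates its host class in two places: __init_subclass__ warns at class-definition time if no sibling base appears to satisfy BasicClientProtocol, and __init__ raises a TypeError at instantiation time if the resulting object still does not. Note that isinstance checks against a typing.Protocol only succeed when the protocol is decorated with @typing.runtime_checkable, so BasicClientProtocol is presumably declared that way in fl4health.clients.basic_client. The snippet below is a minimal, hypothetical usage sketch, not part of this commit; it assumes BasicClient satisfies the protocol and that the module path matches the file shown above.

# Hypothetical usage sketch (not part of this commit). Assumes BasicClient
# satisfies BasicClientProtocol; constructor arguments are passed through
# unchanged via the mixin's cooperative super().__init__ call.
from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin


class MyAdaptiveClient(AdaptiveDriftConstrainedMixin, BasicClient):
    # Mixin listed first so its get_parameters / set_parameters /
    # compute_training_loss overrides take precedence in the MRO, while
    # BasicClient supplies the attributes the protocol requires (e.g. device).
    pass

With mixin-first ordering, the cooperative super() calls in the mixin (for example in __init__ and update_after_train) continue down the MRO into BasicClient, so the base client's behaviour is preserved.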
fl4health/mixins/personalized/base.py

Lines changed: 0 additions & 3 deletions
@@ -1,10 +1,7 @@
 """Base Personalized Mixins"""
 
-from abc import ABC, abstractmethod
 from enum import Enum
 
-from fl4health.clients.basic_client import BasicClient
-
 
 class PersonalizedMethod(str, Enum):
     DITTO = "ditto"
