
Commit 4bf8565

Merge pull request #384 from VectorInstitute/nerdai/personalized-mixins
[Feature] Add `BasicClientProtocol` and `AdaptiveDriftConstrainedMixin`
2 parents 9338be2 + a175745 commit 4bf8565

File tree

7 files changed: +531 / -1 lines changed


fl4health/clients/basic_client.py

Lines changed: 1 addition & 1 deletion
@@ -1094,7 +1094,7 @@ def set_optimizer(self, config: Config) -> None:
         assert not isinstance(optimizer, dict)
         self.optimizers = {"global": optimizer}
 
-    def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         """
         User defined method that returns a PyTorch Train DataLoader
         and a PyTorch Validation DataLoader
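
This hunk narrows the return annotation of get_data_loaders from a variable-length tuple to exactly two loaders: train first, then validation. The sketch below is not part of this commit; it only illustrates what a user-defined override looks like under the tightened signature. The subclass name and the toy tensor datasets are assumptions made for the example.

# Hypothetical override sketch; MyClient and the random tensors are illustrative only.
import torch
from flwr.common.typing import Config
from torch.utils.data import DataLoader, TensorDataset

from fl4health.clients.basic_client import BasicClient


class MyClient(BasicClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        # Exactly two loaders are returned: train first, then validation.
        train_ds = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
        val_ds = TensorDataset(torch.randn(16, 10), torch.randint(0, 2, (16,)))
        return DataLoader(train_ds, batch_size=8, shuffle=True), DataLoader(val_ds, batch_size=8)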

fl4health/mixins/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from fl4health.mixins.adaptive_drift_constrained import AdaptiveDriftConstrainedMixin

__all__ = ["AdaptiveDriftConstrainedMixin"]
fl4health/mixins/adaptive_drift_constrained.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
"""AdaptiveDriftConstrainedMixin"""

import warnings
from logging import INFO, WARN
from typing import Any, Protocol, runtime_checkable

import torch
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays

from fl4health.clients.basic_client import BasicClient
from fl4health.losses.weight_drift_loss import WeightDriftLoss
from fl4health.mixins.core_protocols import BasicClientProtocol, BasicClientProtocolPreSetup
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint
from fl4health.utils.losses import TrainingLosses
from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType


@runtime_checkable
class AdaptiveDriftConstrainedProtocol(BasicClientProtocol, Protocol):
    loss_for_adaptation: float
    drift_penalty_tensors: list[torch.Tensor] | None
    drift_penalty_weight: float | None
    penalty_loss_function: WeightDriftLoss
    parameter_exchanger: FullParameterExchangerWithPacking[float]

    def compute_penalty_loss(self) -> torch.Tensor: ...  # noqa: E704


class AdaptiveDriftConstrainedMixin:
    def __init__(self, *args: Any, **kwargs: Any):
        """Adaptive Drift Constrained Mixin

        To be used with `~fl4health.BaseClient` in order to add the ability to compute
        losses via a constrained adaptive drift.

        Raises:
            RuntimeError: when the inheriting class does not satisfy `BasicClientProtocolPreSetup`.
        """
        # Initialize mixin-specific attributes with default values
        self.loss_for_adaptation = 0.1
        self.drift_penalty_tensors = None
        self.drift_penalty_weight = None

        # Call parent's init
        try:
            super().__init__(*args, **kwargs)
        except TypeError:
            # if a parent class doesn't take args/kwargs
            super().__init__()

        # set penalty_loss_function
        if not isinstance(self, BasicClientProtocolPreSetup):
            raise RuntimeError("This object needs to satisfy `BasicClientProtocolPreSetup`.")
        self.penalty_loss_function = WeightDriftLoss(self.device)

    def __init_subclass__(cls, **kwargs: Any):
        """This method is called when a class inherits from AdaptiveDriftConstrainedMixin."""
        super().__init_subclass__(**kwargs)

        # Skip check for other mixins
        if cls.__name__.endswith("Mixin"):
            return

        # Skip validation for dynamically created classes
        if hasattr(cls, "_dynamically_created"):
            return

        # Check at class definition time if the parent class satisfies BasicClientProtocol
        for base in cls.__bases__:
            if base is not AdaptiveDriftConstrainedMixin and issubclass(base, BasicClient):
                return

        # If we get here, no compatible base was found
        msg = (
            f"Class {cls.__name__} inherits from AdaptiveDriftConstrainedMixin but none of its other "
            f"base classes is a BasicClient. This may cause runtime errors."
        )
        log(WARN, msg)
        warnings.warn(
            msg,
            RuntimeWarning,
        )

    def get_parameters(self: AdaptiveDriftConstrainedProtocol, config: Config) -> NDArrays:
        """
        Packs the parameters and training loss into a single ``NDArrays`` to be sent to the server for aggregation. If
        the client has not been initialized, this means the server is requesting parameters for initialization and
        just the model parameters are sent. When using the ``FedAvgWithAdaptiveConstraint`` strategy, this should not
        happen, as that strategy requires server-side initialization parameters. However, other strategies may handle
        this case.

        Args:
            config (Config): Configurations to allow for customization of this function's behavior.

        Returns:
            NDArrays: Parameters and training loss packed together into a list of numpy arrays to be sent to the server
        """
        if not self.initialized:
            log(INFO, "Setting up client and providing full model parameters to the server for initialization")

            # If initialized is False, the server is requesting model parameters from which to initialize all other
            # clients. As such get_parameters is being called before fit or evaluate, so we must call
            # setup_client first.
            self.setup_client(config)

            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        else:
            # Make sure the proper components are there
            assert (
                self.model is not None
                and self.parameter_exchanger is not None
                and self.loss_for_adaptation is not None
            )
            model_weights = self.parameter_exchanger.push_parameters(self.model, config=config)

            # Weights and training loss sent to server for aggregation. Training loss is sent because server will
            # decide to increase or decrease the penalty weight, if adaptivity is turned on.
            packed_params = self.parameter_exchanger.pack_parameters(model_weights, self.loss_for_adaptation)
            return packed_params

    def set_parameters(
        self: AdaptiveDriftConstrainedProtocol, parameters: NDArrays, config: Config, fitting_round: bool
    ) -> None:
        """
        Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
        unpacked for the clients to use in training. In the first fitting round, we assume the full model is being
        initialized and use the ``FullParameterExchanger()`` to set all model weights.

        Args:
            parameters (NDArrays): Parameters have information about model state to be added to the relevant client
                model and also the penalty weight to be applied during training.
            config (Config): The config is sent by the FL server to allow for customization in the function if desired.
            fitting_round (bool): Boolean that indicates whether the current federated learning round is a fitting
                round or an evaluation round. This is used to help determine which parameter exchange should be used
                for pulling parameters. A full parameter exchanger is always used if the current federated learning
                round is the very first fitting round.
        """
        assert self.model is not None and self.parameter_exchanger is not None

        server_model_state, self.drift_penalty_weight = self.parameter_exchanger.unpack_parameters(parameters)
        log(INFO, f"Penalty weight received from the server: {self.drift_penalty_weight}")

        super().set_parameters(server_model_state, config, fitting_round)  # type: ignore[safe-super]

    def compute_training_loss(
        self: AdaptiveDriftConstrainedProtocol,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
    ) -> TrainingLosses:
        """
        Computes training loss given predictions of the model and ground truth data. Adds to objective by including
        penalty loss.

        Args:
            preds (TorchPredType): Prediction(s) of the model(s) indexed by name. All predictions included in
                dictionary will be used to compute metrics.
            features (TorchFeatureType): Feature(s) of the model(s) indexed by name.
            target (TorchTargetType): Ground truth data to evaluate predictions against.

        Returns:
            TrainingLosses: An instance of ``TrainingLosses`` containing backward loss and additional losses indexed
                by name. Additional losses includes penalty loss.
        """
        loss, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
        if additional_losses is None:
            additional_losses = {}

        additional_losses["loss"] = loss.clone()
        # adding the vanilla loss to the additional losses to be used by update_after_train for potential adaptation
        additional_losses["loss_for_adaptation"] = loss.clone()

        # Compute the drift penalty loss and store it in the additional losses dictionary.
        penalty_loss = self.compute_penalty_loss()
        additional_losses["penalty_loss"] = penalty_loss.clone()

        return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses)

    def get_parameter_exchanger(self: AdaptiveDriftConstrainedProtocol, config: Config) -> ParameterExchanger:
        """
        Setting up the parameter exchanger to include the appropriate packing functionality.
        By default we assume that we're exchanging all parameters. Can be overridden for other behavior.

        Args:
            config (Config): The config is sent by the FL server to allow for customization in the function if desired.

        Returns:
            ParameterExchanger: Exchanger that can handle packing/unpacking auxiliary server information.
        """
        return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())

    def update_after_train(
        self: AdaptiveDriftConstrainedProtocol, local_steps: int, loss_dict: dict[str, float], config: Config
    ) -> None:
        """
        Called after training with the number of ``local_steps`` performed over the FL round and the corresponding
        loss dictionary. We use this to store the training loss that we want to use to adapt the penalty weight
        parameter on the server side.

        Args:
            local_steps (int): The number of steps so far in the round in the local training.
            loss_dict (dict[str, float]): A dictionary of losses from local training.
            config (Config): The config from the server.
        """
        assert "loss_for_adaptation" in loss_dict
        # Store current loss which is the vanilla loss without the penalty term added in
        self.loss_for_adaptation = loss_dict["loss_for_adaptation"]
        super().update_after_train(local_steps, loss_dict, config)  # type: ignore[safe-super]

    def compute_penalty_loss(self: AdaptiveDriftConstrainedProtocol) -> torch.Tensor:
        """
        Computes the drift loss for the client model and drift tensors.

        Returns:
            torch.Tensor: Computed penalty loss tensor
        """
        # Penalty tensors must have been set for these clients.
        assert self.drift_penalty_tensors is not None

        return self.penalty_loss_function(self.model, self.drift_penalty_tensors, self.drift_penalty_weight)


def apply_adaptive_drift_to_client(client_base_type: type[BasicClient]) -> type[BasicClient]:
    """Dynamically create an adapted client class.

    Args:
        client_base_type (type[BasicClient]): The class to be mixed.

    Returns:
        type[BasicClient]: A basic client that has been mixed with `AdaptiveDriftConstrainedMixin`.
    """
    return type(
        f"AdaptiveDrift{client_base_type.__name__}",
        (
            AdaptiveDriftConstrainedMixin,
            client_base_type,
        ),
        {
            # Special flag to bypass validation
            "_dynamically_created": True
        },
    )
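
The mixin above is meant to be combined with an existing client either by listing it as a base class or through the apply_adaptive_drift_to_client factory at the bottom of the file. Below is a minimal usage sketch, not part of this commit; MyClient stands in for a concrete BasicClient subclass whose hooks (get_model, get_data_loaders, etc.) a project would already define.

# Hypothetical usage sketch; MyClient is a placeholder for an existing concrete client.
from fl4health.clients.basic_client import BasicClient
from fl4health.mixins import AdaptiveDriftConstrainedMixin
from fl4health.mixins.adaptive_drift_constrained import apply_adaptive_drift_to_client


class MyClient(BasicClient):
    """A concrete client; get_model, get_data_loaders, get_optimizer, etc. would be defined here."""


# Route 1: inherit directly. Listing the mixin first places its get_parameters,
# set_parameters and compute_training_loss overrides ahead of the client's in the MRO,
# and __init_subclass__ finds a BasicClient among the other bases, so no warning is raised.
class MyAdaptiveClient(AdaptiveDriftConstrainedMixin, MyClient):
    pass


# Route 2: wrap an existing class dynamically. The returned type carries the
# _dynamically_created flag, so the base-class check in __init_subclass__ is skipped.
AdaptiveDriftMyClient = apply_adaptive_drift_to_client(MyClient)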

fl4health/mixins/core_protocols.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType


@runtime_checkable
class NumPyClientMinimalProtocol(Protocol):
    """A minimal protocol for NumPyClient with just essential methods."""

    def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
        pass  # pragma: no cover

    def fit(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[NDArrays, int, dict[str, Scalar]]:
        pass  # pragma: no cover

    def evaluate(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[float, int, dict[str, Scalar]]:
        pass  # pragma: no cover

    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
        pass  # pragma: no cover

    def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
        pass  # pragma: no cover


@runtime_checkable
class BasicClientProtocolPreSetup(NumPyClientMinimalProtocol, Protocol):
    """A minimal protocol for BasicClient focused on methods."""

    device: torch.device
    initialized: bool

    # Include only methods, not attributes that get initialized later
    def setup_client(self, config: Config) -> None:
        pass  # pragma: no cover

    def get_model(self, config: Config) -> nn.Module:
        pass  # pragma: no cover

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]:
        pass  # pragma: no cover

    def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]:
        pass  # pragma: no cover

    def get_criterion(self, config: Config) -> _Loss:
        pass  # pragma: no cover

    def compute_loss_and_additional_losses(
        self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:
        pass  # pragma: no cover


@runtime_checkable
class BasicClientProtocol(BasicClientProtocolPreSetup, Protocol):
    """A minimal protocol for BasicClient focused on methods."""

    model: nn.Module
    optimizers: dict[str, torch.optim.Optimizer]
    train_loader: DataLoader
    val_loader: DataLoader
    test_loader: DataLoader | None
tests/mixins/__init__.py

Whitespace-only changes.

tests/mixins/conftest.py

Whitespace-only changes.
