
Commit 7edc69b

ditto protocols
1 parent dcf507a commit 7edc69b


2 files changed: +82 −51 lines changed


fl4health/mixins/adaptive_drift_contrained.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ class AdaptiveProtocol(BasicClientProtocol):
     drift_penalty_tensors: list[torch.Tensor] | None
     drift_penalty_weight: float | None
 
+    def compute_penalty_loss(self) -> torch.Tensor: ...
+
 
 class AdaptiveDriftConstrainedMixin:
     def __init_subclass__(cls, **kwargs):
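
Note: the `...` body added above is the usual way to declare a member of a `typing.Protocol`. A minimal, self-contained sketch of how a runtime-checkable protocol behaves (the names `PenaltyProtocol` and `SomeClient` are illustrative toys, not fl4health classes; assumes Python 3.10+ and torch installed):

from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class PenaltyProtocol(Protocol):
    """Toy stand-in for AdaptiveProtocol: any object with a matching member passes isinstance()."""

    def compute_penalty_loss(self) -> torch.Tensor: ...


class SomeClient:
    """Does not inherit from PenaltyProtocol, yet satisfies it structurally."""

    def compute_penalty_loss(self) -> torch.Tensor:
        return torch.tensor(0.0)


# isinstance() against a runtime_checkable protocol only checks that the member exists
# (and is callable); it does not validate the signature or return type.
assert isinstance(SomeClient(), PenaltyProtocol)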

fl4health/mixins/personalized/ditto.py

Lines changed: 80 additions & 51 deletions
@@ -1,45 +1,71 @@
 """Ditto Personalized Mixin"""
 
-from abc import ABC, abstractmethod
 from logging import DEBUG, INFO
-from typing import cast
+from typing import cast, runtime_checkable
 
 import torch
 import torch.nn as nn
+import warnings
 from flwr.common.logger import log
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.optim import Optimizer
 
-from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin
+from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin, AdaptiveProtocol
 from fl4health.mixins.personalized.base import BasePersonalizedMixin
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
-from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
 from fl4health.utils.config import narrow_dict_type
-from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses
+from fl4health.utils.losses import EvaluationLosses, TrainingLosses
 from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType
+from fl4health.clients.basic_client import BasicClientProtocol
 
 
-class DittoPersonalizedMixin(AdaptiveDriftConstrainedMixin, BasePersonalizedMixin, ABC):
+@runtime_checkable
+class DittoProtocol(AdaptiveProtocol):
+    global_model: torch.nn.Module | None
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._global_model = None
+    def get_global_model(self, config: Config) -> nn.Module: ...
+
+    def _copy_optimizer_with_new_params(self, original_optimizer: Optimizer): ...
+
+    def set_initial_global_tensors(self) -> None: ...
+
+    def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: ...
 
-    @property
-    def global_model(self) -> torch.nn.Module | None:
-        """Gets the global model."""
-        return self._global_model
 
-    @global_model.setter
-    def global_model(self, value: torch.nn.Module) -> None:
-        self._global_model = value
+class DittoPersonalizedMixin(AdaptiveDriftConstrainedMixin, BasePersonalizedMixin):
+
+    def __init_subclass__(cls, **kwargs):
+        """This method is called when a class inherits from AdaptiveMixin"""
+        super().__init_subclass__(**kwargs)
+
+        # Check at class definition time if the parent class satisfies BasicClientProtocol
+        for base in cls.__bases__:
+            if base is not DittoPersonalizedMixin and isinstance(base, BasicClientProtocol):
+                return
+
+        # If we get here, no compatible base was found
+        warnings.warn(
+            f"Class {cls.__name__} inherits from DittoPersonalizedMixin but none of its other "
+            f"base classes implement BasicClientProtocol. This may cause runtime errors.",
+            RuntimeWarning,
+        )
+
+    def __init__(self, *args, **kwargs):
+        # Verify at instance creation time
+        if not isinstance(self, BasicClientProtocol):
+            raise TypeError(
+                f"Class {self.__class__.__name__} uses AdaptiveMixin but does not "
+                f"implement BasicClientProtocol. Make sure a parent class implements "
+                f"the required methods and attributes."
+            )
+        super().__init__(*args, **kwargs)
 
     @property
-    def optimizer_keys(self) -> list[str]:
+    def optimizer_keys(self: DittoProtocol) -> list[str]:
         """Returns the optimizer keys."""
         return ["local", "global"]
 
-    def _copy_optimizer_with_new_params(self, original_optimizer):
+    def _copy_optimizer_with_new_params(self: DittoProtocol, original_optimizer: Optimizer):
         OptimClass = original_optimizer.__class__
         state_dict = original_optimizer.state_dict()
 
@@ -59,7 +85,23 @@ def _copy_optimizer_with_new_params(self, original_optimizer):
 
         return global_optimizer
 
-    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
+    def get_global_model(self: DittoProtocol, config: Config) -> nn.Module:
+        """
+        Returns the global model to be used during Ditto training and as a constraint for the local model.
+
+        The global model should be the same architecture as the local model so we reuse the ``get_model`` call. We
+        explicitly send the model to the desired device. This is idempotent.
+
+        Args:
+            config (Config): The config from the server.
+
+        Returns:
+            nn.Module: The PyTorch model serving as the global model for Ditto
+        """
+        config["for_global"] = True
+        return self.get_model(config).to(self.device)
+
+    def get_optimizer(self: DittoProtocol, config: Config) -> dict[str, Optimizer]:
         if self.global_model is None:
             # try set it here
             self.global_model = self.get_global_model(config)  # is this the same config?
@@ -80,23 +122,7 @@ def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         global_optimizer = self._copy_optimizer_with_new_params(original_optimizer)
         return {"local": original_optimizer, "global": global_optimizer}
 
-    def get_global_model(self, config: Config) -> nn.Module:
-        """
-        Returns the global model to be used during Ditto training and as a constraint for the local model.
-
-        The global model should be the same architecture as the local model so we reuse the ``get_model`` call. We
-        explicitly send the model to the desired device. This is idempotent.
-
-        Args:
-            config (Config): The config from the server.
-
-        Returns:
-            nn.Module: The PyTorch model serving as the global model for Ditto
-        """
-        config["for_global"] = True
-        return self.get_model(config).to(self.device)
-
-    def set_optimizer(self, config: Config) -> None:
+    def set_optimizer(self: DittoProtocol, config: Config) -> None:
         """
         Ditto requires an optimizer for the global model and one for the local model. This function simply ensures that
         the optimizers setup by the user have the proper keys and that there are two optimizers.
@@ -108,7 +134,7 @@ def set_optimizer(self, config: Config) -> None:
         assert isinstance(optimizers, dict) and set(self.optimizer_keys) == set(optimizers.keys())
         self.optimizers = optimizers
 
-    def setup_client(self, config: Config) -> None:
+    def setup_client(self: DittoProtocol, config: Config) -> None:
         """
         Set dataloaders, optimizers, parameter exchangers and other attributes derived from these.
         Then set initialized attribute to True. In this class, this function simply adds the additional step of
@@ -127,7 +153,7 @@ def setup_client(self, config: Config) -> None:
         super().setup_client(config)
         # Need to setup the global model here as well. It should be the same architecture as the local model.
 
-    def get_parameters(self, config: Config) -> NDArrays:
+    def get_parameters(self: DittoProtocol, config: Config) -> NDArrays:
         """
         For Ditto, we transfer the **GLOBAL** model weights to the server to be aggregated. The local model weights
         stay with the client.
@@ -164,7 +190,7 @@ def get_parameters(self, config: Config) -> NDArrays:
         log(INFO, "Successfully packed parameters of global model")
         return packed_params
 
-    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
+    def set_parameters(self: DittoProtocol, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
         """
         Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
         unpacked for the clients to use in training. The parameters being passed are to be routed to the global model.
@@ -197,7 +223,7 @@ def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bo
         log(INFO, "Setting the global model weights")
         parameter_exchanger.pull_parameters(server_model_state, self.global_model, config)
 
-    def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> None:
+    def initialize_all_model_weights(self: DittoProtocol, parameters: NDArrays, config: Config) -> None:
         """
         If this is the first time we're initializing the model weights, we initialize both the global and the local
         weights together.
@@ -210,16 +236,17 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) ->
         parameter_exchanger.pull_parameters(parameters, self.model, config)
         parameter_exchanger.pull_parameters(parameters, self.global_model, config)
 
-    def set_initial_global_tensors(self) -> None:
+    def set_initial_global_tensors(self: DittoProtocol) -> None:
         """
         Saving the initial **GLOBAL MODEL** weights and detaching them so that we don't compute gradients with
         respect to the tensors. These are used to form the Ditto local update penalty term.
         """
-        self.drift_penalty_tensors = [
-            initial_layer_weights.detach().clone() for initial_layer_weights in self.global_model.parameters()
-        ]
+        if self.global_model:
+            self.drift_penalty_tensors = [
+                initial_layer_weights.detach().clone() for initial_layer_weights in self.global_model.parameters()
+            ]
 
-    def update_before_train(self, current_server_round: int) -> None:
+    def update_before_train(self: DittoProtocol, current_server_round: int) -> None:
         """
         Procedures that should occur before proceeding with the training loops for the models. In this case, we
         save the global models parameters to be used in constraining training of the local model.
@@ -234,7 +261,9 @@ def update_before_train(self, current_server_round: int) -> None:
 
         super().update_before_train(current_server_round)
 
-    def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]:
+    def train_step(
+        self: DittoProtocol, input: TorchInputType, target: TorchTargetType
+    ) -> tuple[TrainingLosses, TorchPredType]:
         """
         Mechanics of training loop follow from original Ditto implementation: https://github.com/litian96/ditto
 
@@ -325,7 +354,7 @@ def predict(
         else:
             raise ValueError(f"Unsupported pred type: {type(global_preds)}.")
 
-    def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]):
+    def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         if kind not in ["global", "local"]:
             raise ValueError("Unsupported kind of prediction. Must be 'global' or 'local'.")
 
@@ -336,7 +365,7 @@ def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]):
         return retval
 
     def compute_loss_and_additional_losses(
-        self,
+        self: DittoProtocol,
         preds: TorchPredType,
         features: TorchFeatureType,
         target: TorchTargetType,
@@ -379,7 +408,7 @@ def compute_loss_and_additional_losses(
         return local_loss, additional_losses
 
     def compute_training_loss(
-        self,
+        self: DittoProtocol,
         preds: TorchPredType,
         features: TorchFeatureType,
         target: TorchTargetType,
@@ -430,7 +459,7 @@ def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict
         return super().validate(include_losses_in_metrics=include_losses_in_metrics)
 
     def compute_evaluation_loss(
-        self,
+        self: DittoProtocol,
         preds: TorchPredType,
         features: TorchFeatureType,
         target: TorchTargetType,
@@ -451,5 +480,5 @@ def compute_evaluation_loss(
                 indexed by name.
         """
         # Check that both models are in eval mode
-        assert not self.global_model.training and not self.model.training
+        assert not self.global_model is None and not self.global_model.training and not self.model.training
         return super().compute_evaluation_loss(preds, features, target)
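
Note: the `__init_subclass__`/`__init__` pair added above implements a two-stage mixin check: a warning at class-definition time if no sibling base satisfies the runtime-checkable protocol, and a hard `TypeError` at instance-creation time. A minimal, self-contained sketch of the same pattern with toy classes (`ClientLike`, `ToyDittoMixin`, `ToyClient`, `GoodClient`, and `BadClient` are illustrative names, not the fl4health API):

import warnings
from typing import Protocol, runtime_checkable


@runtime_checkable
class ClientLike(Protocol):
    """Toy stand-in for BasicClientProtocol."""

    def get_model(self, config: dict) -> object: ...


class ToyDittoMixin:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Class-definition-time check: warn if no other base satisfies the protocol.
        for base in cls.__bases__:
            if base is not ToyDittoMixin and isinstance(base, ClientLike):
                return
        warnings.warn(f"{cls.__name__} has no base implementing ClientLike", RuntimeWarning)

    def __init__(self, *args, **kwargs):
        # Instance-creation-time check: fail hard if the protocol is not satisfied.
        if not isinstance(self, ClientLike):
            raise TypeError(f"{self.__class__.__name__} does not implement ClientLike")
        super().__init__(*args, **kwargs)


class ToyClient:
    def get_model(self, config: dict) -> object:
        return object()


class GoodClient(ToyDittoMixin, ToyClient):  # no warning; the __init__ check passes
    pass


class BadClient(ToyDittoMixin):  # emits a RuntimeWarning at class definition
    pass


GoodClient()   # fine
# BadClient()  # would raise TypeError: no get_model anywhere in the MRO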
