@@ -1,26 +1,26 @@
 """Ditto Personalized Mixin"""

+import warnings
 from logging import DEBUG, INFO
-from typing import cast, runtime_checkable
+from typing import Protocol, cast, runtime_checkable

 import torch
 import torch.nn as nn
-import warnings
 from flwr.common.logger import log
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.optim import Optimizer

+from fl4health.clients.basic_client import BasicClientProtocol
 from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin, AdaptiveProtocol
 from fl4health.mixins.personalized.base import BasePersonalizedMixin
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import EvaluationLosses, TrainingLosses
 from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType
-from fl4health.clients.basic_client import BasicClientProtocol


 @runtime_checkable
-class DittoProtocol(AdaptiveProtocol):
+class DittoProtocol(AdaptiveProtocol, Protocol):
     global_model: torch.nn.Module | None

     def get_global_model(self, config: Config) -> nn.Module: ...
@@ -50,16 +50,6 @@ def __init_subclass__(cls, **kwargs):
                 RuntimeWarning,
             )

-    def __init__(self, *args, **kwargs):
-        # Verify at instance creation time
-        if not isinstance(self, BasicClientProtocol):
-            raise TypeError(
-                f"Class {self.__class__.__name__} uses AdaptiveMixin but does not "
-                f"implement BasicClientProtocol. Make sure a parent class implements "
-                f"the required methods and attributes."
-            )
-        super().__init__(*args, **kwargs)
-
     @property
     def optimizer_keys(self: DittoProtocol) -> list[str]:
         """Returns the optimizer keys."""
@@ -143,6 +133,8 @@ def setup_client(self: DittoProtocol, config: Config) -> None:
         Args:
             config (Config): The config from the server.
         """
+        self.ensure_protocol_compliance()
+
         try:
             self.global_model = self.get_global_model(config)
             log(INFO, f"global model set: {type(self.global_model).__name__}")
@@ -254,6 +246,7 @@ def update_before_train(self: DittoProtocol, current_server_round: int) -> None:
         Args:
             current_server_round (int): Indicates which server round we are currently executing.
         """
+        self.ensure_protocol_compliance()
         self.set_initial_global_tensors()

         # Need to also set the global model to train mode before any training begins.
@@ -479,6 +472,7 @@ def compute_evaluation_loss(
             EvaluationLosses: An instance of ``EvaluationLosses`` containing checkpoint loss and additional losses
                 indexed by name.
         """
+        self.ensure_protocol_compliance()
         # Check that both models are in eval mode
         assert not self.global_model is None and not self.global_model.training and not self.model.training
         return super().compute_evaluation_loss(preds, features, target)
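For context on the change from `class DittoProtocol(AdaptiveProtocol):` to `class DittoProtocol(AdaptiveProtocol, Protocol):`: a subclass of a protocol class only remains a protocol when `Protocol` is listed explicitly among its bases; otherwise it is treated as a concrete class and `@runtime_checkable` raises a `TypeError`. The sketch below illustrates that pattern with standalone names; `HasModel`, `HasGlobalModel`, and `ToyClient` are illustrative and not part of fl4health.

```python
from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn


@runtime_checkable
class HasModel(Protocol):
    """Illustrative stand-in for AdaptiveProtocol."""

    def get_model(self) -> nn.Module: ...


# Listing Protocol explicitly keeps the subclass a structural interface;
# without it, HasGlobalModel would become a concrete class and
# @runtime_checkable would raise TypeError.
@runtime_checkable
class HasGlobalModel(HasModel, Protocol):
    global_model: torch.nn.Module | None

    def get_global_model(self) -> nn.Module: ...


class ToyClient:
    """Satisfies HasGlobalModel structurally, without inheriting from it."""

    def __init__(self) -> None:
        self.global_model: nn.Module | None = None

    def get_model(self) -> nn.Module:
        return nn.Linear(4, 2)

    def get_global_model(self) -> nn.Module:
        return nn.Linear(4, 2)


# isinstance() against a runtime_checkable protocol checks only that the
# required attributes and methods exist, not their signatures.
assert isinstance(ToyClient(), HasGlobalModel)
```

With `Protocol` restored in the bases, `isinstance` checks against `DittoProtocol` stay meaningful for any client that structurally provides the required members.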
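The other recurring change is dropping the `__init__`-time `isinstance` check and instead calling `self.ensure_protocol_compliance()` at the entry points that rely on the protocol (`setup_client`, `update_before_train`, `compute_evaluation_loss`). Below is a hypothetical sketch of that deferral pattern; the real `ensure_protocol_compliance` in fl4health may be implemented differently, and `BasicClientLike`, `DittoLikeMixin`, and the client classes are illustrative names only.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class BasicClientLike(Protocol):
    """Hypothetical stand-in for BasicClientProtocol; method names are illustrative."""

    def get_parameters(self, config: dict) -> list: ...

    def set_parameters(self, parameters: list, config: dict) -> None: ...


class DittoLikeMixin:
    """Sketch of deferring the protocol check from __init__ to the call sites."""

    def ensure_protocol_compliance(self) -> None:
        # Hypothetical implementation; the real fl4health helper may differ.
        if not isinstance(self, BasicClientLike):
            raise TypeError(
                f"{type(self).__name__} uses DittoLikeMixin but does not implement "
                "BasicClientLike. A parent class must provide the required methods."
            )

    def setup_client(self, config: dict) -> None:
        # Check at the point of use instead of at instance creation,
        # as the removed __init__ override used to do.
        self.ensure_protocol_compliance()


class GoodClient(DittoLikeMixin):
    def get_parameters(self, config: dict) -> list:
        return []

    def set_parameters(self, parameters: list, config: dict) -> None:
        pass


class BadClient(DittoLikeMixin):
    pass


GoodClient().setup_client({})   # passes the deferred compliance check
# BadClient().setup_client({})  # would raise TypeError at call time, not at __init__
```

Deferring the check means a misconfigured subclass fails at the first checked call rather than at construction, which keeps the mixin out of the cooperative `__init__` chain.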