Commit 9ad36b8

wip
1 parent 7edc69b commit 9ad36b8

3 files changed: +17 −26 lines

examples/nnunet_example/client.py
1 addition & 2 deletions

```diff
@@ -28,7 +28,6 @@
 from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
 from fl4health.utils.nnunet_utils import get_segs_from_probs, set_nnunet_env
 
-
 personalized_client_classes = {"ditto": make_it_personal(NnunetClient, "ditto")}
 
 
@@ -107,7 +106,7 @@ def main(
         log(INFO, f"Setting up client without personalization")
         client = NnunetClient(**client_kwargs)
     log(INFO, f"Using client: {type(client).__name__}")
-    log(INFO, f"Parameter exchanger: {type(client.parameter_exchanger).__name__}")
+    # log(INFO, f"Parameter exchanger: {type(client.parameter_exchanger).__name__}")
 
     start_client(server_address=server_address, client=client.to_client())
```

fl4health/mixins/adaptive_drift_contrained.py
8 additions & 10 deletions

```diff
@@ -3,7 +3,7 @@
 import warnings
 from collections.abc import Sequence
 from logging import INFO
-from typing import cast
+from typing import Protocol, runtime_checkable
 
 import torch
 from flwr.common.logger import log
@@ -19,13 +19,16 @@
 from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType
 
 
-class AdaptiveProtocol(BasicClientProtocol):
+@runtime_checkable
+class AdaptiveProtocol(BasicClientProtocol, Protocol):
     loss_for_adaptation: float | None
     drift_penalty_tensors: list[torch.Tensor] | None
     drift_penalty_weight: float | None
 
     def compute_penalty_loss(self) -> torch.Tensor: ...
 
+    def ensure_protocol_compliance(self) -> None: ...
+
 
 class AdaptiveDriftConstrainedMixin:
     def __init_subclass__(cls, **kwargs):
@@ -44,15 +47,10 @@ def __init_subclass__(cls, **kwargs):
                 RuntimeWarning,
             )
 
-    def __init__(self, *args, **kwargs):
-        # Verify at instance creation time
+    def ensure_protocol_compliance(self) -> None:
+        """Call this after the object is fully initialized"""
         if not isinstance(self, BasicClientProtocol):
-            raise TypeError(
-                f"Class {self.__class__.__name__} uses AdaptiveMixin but does not "
-                f"implement BasicClientProtocol. Make sure a parent class implements "
-                f"the required methods and attributes."
-            )
-        super().__init__(*args, **kwargs)
+            raise TypeError(f"Protocol requirements not met.")
 
     def penalty_loss_function(self: AdaptiveProtocol) -> WeightDriftLoss:
         """Function to compute the penalty loss."""
```

fl4health/mixins/personalized/ditto.py
8 additions & 14 deletions

```diff
@@ -1,26 +1,26 @@
 """Ditto Personalized Mixin"""
 
+import warnings
 from logging import DEBUG, INFO
-from typing import cast, runtime_checkable
+from typing import Protocol, cast, runtime_checkable
 
 import torch
 import torch.nn as nn
-import warnings
 from flwr.common.logger import log
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.optim import Optimizer
 
+from fl4health.clients.basic_client import BasicClientProtocol
 from fl4health.mixins.adaptive_drift_contrained import AdaptiveDriftConstrainedMixin, AdaptiveProtocol
 from fl4health.mixins.personalized.base import BasePersonalizedMixin
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import EvaluationLosses, TrainingLosses
 from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType
-from fl4health.clients.basic_client import BasicClientProtocol
 
 
 @runtime_checkable
-class DittoProtocol(AdaptiveProtocol):
+class DittoProtocol(AdaptiveProtocol, Protocol):
     global_model: torch.nn.Module | None
 
     def get_global_model(self, config: Config) -> nn.Module: ...
@@ -50,16 +50,6 @@ def __init_subclass__(cls, **kwargs):
                 RuntimeWarning,
             )
 
-    def __init__(self, *args, **kwargs):
-        # Verify at instance creation time
-        if not isinstance(self, BasicClientProtocol):
-            raise TypeError(
-                f"Class {self.__class__.__name__} uses AdaptiveMixin but does not "
-                f"implement BasicClientProtocol. Make sure a parent class implements "
-                f"the required methods and attributes."
-            )
-        super().__init__(*args, **kwargs)
-
     @property
     def optimizer_keys(self: DittoProtocol) -> list[str]:
         """Returns the optimizer keys."""
@@ -143,6 +133,8 @@ def setup_client(self: DittoProtocol, config: Config) -> None:
         Args:
             config (Config): The config from the server.
         """
+        self.ensure_protocol_compliance()
+
         try:
             self.global_model = self.get_global_model(config)
             log(INFO, f"global model set: {type(self.global_model).__name__}")
@@ -254,6 +246,7 @@ def update_before_train(self: DittoProtocol, current_server_round: int) -> None:
         Args:
             current_server_round (int): Indicates which server round we are currently executing.
         """
+        self.ensure_protocol_compliance()
         self.set_initial_global_tensors()
 
         # Need to also set the global model to train mode before any training begins.
@@ -479,6 +472,7 @@ def compute_evaluation_loss(
             EvaluationLosses: An instance of ``EvaluationLosses`` containing checkpoint loss and additional losses
                 indexed by name.
         """
+        self.ensure_protocol_compliance()
         # Check that both models are in eval mode
         assert not self.global_model is None and not self.global_model.training and not self.model.training
         return super().compute_evaluation_loss(preds, features, target)
```
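The commit also moves protocol verification out of `__init__` (deleted in both mixins above) and into an explicit `ensure_protocol_compliance()` call at entry points such as `setup_client`, `update_before_train`, and `compute_evaluation_loss`. A rough sketch of that lifecycle under stand-in classes; only the name `ensure_protocol_compliance` mirrors the real method, everything else is hypothetical:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ClientProtocol(Protocol):
    def get_model(self) -> object: ...


class ComplianceMixin:
    def ensure_protocol_compliance(self) -> None:
        # Runs after the object is fully constructed, so members supplied
        # by any cooperating base class are already visible; an
        # __init__-time check could fire before they exist.
        if not isinstance(self, ClientProtocol):
            raise TypeError("Protocol requirements not met.")

    def setup_client(self, config: dict) -> None:
        self.ensure_protocol_compliance()  # entry-point check
        # ... normal setup would follow ...


class ConcreteClient(ComplianceMixin):
    def get_model(self) -> object:
        return object()


ConcreteClient().setup_client({})  # passes the structural check
```

Deferring the check this way sidesteps ordering problems in cooperative multiple inheritance, where a mixin's `__init__` may run before the base client has attached the attributes the protocol requires.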
