Skip to content

[Feature] Add DittoPersonalizedMixin #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0765492
add ditto mixin
nerdai May 8, 2025
c811fbc
reimports and public api for mixins.personalized
nerdai May 8, 2025
2d22847
don't cover protocols
nerdai May 15, 2025
9e82518
tests for ensure_protocol_compliance and add unit tests
nerdai May 15, 2025
0894136
unit test for make_it_personal factory method
nerdai May 15, 2025
bfe1dda
wip
nerdai May 15, 2025
1f00d95
don't cover protocols
nerdai May 15, 2025
789cb1b
test raise no warnings
nerdai May 15, 2025
5e64570
test get params
nerdai May 15, 2025
26553fa
test get params uninitialized
nerdai May 15, 2025
066ae4a
test set params
nerdai May 15, 2025
f367007
test get optimizers
nerdai May 15, 2025
097d82f
predict unit test
nerdai May 15, 2025
19b9557
test extract pred
nerdai May 15, 2025
7975cc5
test update before train
nerdai May 15, 2025
30a895c
test predict private delegation
nerdai May 15, 2025
85ee409
wip
nerdai May 15, 2025
6c02563
unit test train step
nerdai May 16, 2025
68c9da7
update tornado
nerdai May 16, 2025
2b578f4
cr
nerdai May 16, 2025
6556076
cr
nerdai May 16, 2025
acb279c
cr
nerdai May 16, 2025
7df26eb
cr
nerdai May 16, 2025
e3a86d6
cr
nerdai May 16, 2025
b49e4e0
cr
nerdai May 16, 2025
bb63f1a
cr
nerdai May 16, 2025
2032e14
cr
nerdai May 16, 2025
71c41bf
cr
nerdai May 16, 2025
cefdd82
cr
nerdai May 16, 2025
8a6ac3e
cr
nerdai May 16, 2025
496d14b
cr
nerdai May 16, 2025
eeb9096
name change to PersonalizedMode
nerdai May 16, 2025
58c9244
cr
nerdai May 16, 2025
e124499
cr
nerdai May 16, 2025
137c9c6
cr
nerdai May 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion fl4health/mixins/core_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType
from fl4health.utils.losses import EvaluationLosses, TrainingLosses
from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType


@runtime_checkable
Expand Down Expand Up @@ -68,3 +69,35 @@ class BasicClientProtocol(BasicClientProtocolPreSetup, Protocol):
train_loader: DataLoader
val_loader: DataLoader
test_loader: DataLoader | None
criterion: _Loss

def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> None:
pass # pragma: no cover

def update_before_train(self, current_server_round: int) -> None:
pass # pragma: no cover

def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]:
pass # pragma: no cover

def transform_target(self, target: TorchTargetType) -> TorchTargetType:
pass # pragma: no cover

def compute_training_loss(
self,
preds: TorchPredType,
features: TorchFeatureType,
target: TorchTargetType,
) -> TrainingLosses:
pass # pragma: no cover

def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]:
pass # pragma: no cover

def compute_evaluation_loss(
self,
preds: TorchPredType,
features: TorchFeatureType,
target: TorchTargetType,
) -> EvaluationLosses:
pass # pragma: no cover
39 changes: 39 additions & 0 deletions fl4health/mixins/personalized/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from enum import Enum
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it normal to put such things in the __init__.py? I'm not sure of the co, but this feels kind of weird.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is, why put this in here instead of something like a utility file or something else?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I have seen this adopted in practice, by engineers more experienced than me; and in other libraries. That is, I am treating this as a factory/registry, which I have seen be put in these module-level init files. I think so long as we're not baking in the actual logic here, then we're okay... but i'm also not married to this. If you really prefer to outsource this to some other module, say factory or registry, then I'm not picky. :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason I think that these factories/registries are created here is because often all of the classes you need are already being re-imported (i.e., and being exposed in the public api "all")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair. I honestly have no idea what best practices apply to __init__.py files. Seems kind of wild west ha.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea me too... I just go with things I have seen before in libraries i've worked in, or have read source code in, and how past teammates have used __init__.py.


from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.personalized.ditto import DittoPersonalizedMixin, DittoPersonalizedProtocol


class PersonalizedMode(Enum):
DITTO = "ditto"


PersonalizedMixinRegistry = {PersonalizedMode.DITTO: DittoPersonalizedMixin}


def make_it_personal(client_base_type: type[BasicClient], mode: PersonalizedMode) -> type[BasicClient]:
"""A mixed class factory for converting basic clients to personalized versions."""
if mode == PersonalizedMode.DITTO:

return type(
f"Ditto{client_base_type.__name__}",
(
PersonalizedMixinRegistry[mode],
client_base_type,
),
{
# Special flag to bypass validation
"_dynamically_created": True
},
)
else:
raise ValueError("Unrecognized personalized mode.")


__all__ = [
"DittoPersonalizedMixin",
"DittoPersonalizedProtocol",
"PersonalizedMode",
"PersonalizedMixinRegistry",
"make_it_personal",
]
Loading
Loading