[Feature] Add DittoPersonalizedMixin #385


Merged: 35 commits merged into main from nerdai/ditto-mixin on May 16, 2025

Conversation


@nerdai nerdai commented May 8, 2025

PR Type

Feature

Short Description

Continuation of using mixins to implement pFL methods. This PR adds the DittoPersonalizedMixin, which subclasses the previously merged AdaptiveDriftConstrainedMixin.

Tests Added

Describe the tests that have been added to ensure the code's correctness, if applicable.

@nerdai nerdai requested a review from emersodb May 8, 2025 14:12

nerdai commented May 8, 2025

@emersodb: as discussed, here is the WIP PR for adding the Ditto mixin. I will add unit tests to get the coverage checks passing, but in the meantime you can take a look around the PR and add your review.


codecov bot commented May 8, 2025

Codecov Report

Attention: Patch coverage is 80.17621% with 45 lines in your changes missing coverage. Please review.

Project coverage is 76.86%. Comparing base (87d8638) to head (137c9c6).
Report is 36 commits behind head on main.

Files with missing lines | Patch % | Lines
fl4health/mixins/personalized/ditto.py | 77.15% | 45 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #385      +/-   ##
==========================================
+ Coverage   76.78%   76.86%   +0.07%     
==========================================
  Files         152      155       +3     
  Lines        9236     9462     +226     
==========================================
+ Hits         7092     7273     +181     
- Misses       2144     2189      +45     


@emersodb emersodb left a comment

A few comments throughout. I think some of them stem from this PR setting up scaffolding for follow-ups.

DITTO = "ditto"


PersonalizedMixinRegistry = {"ditto": DittoPersonalizedMixin}
Collaborator

Should the keys of this registry be typed as the enum above?

Collaborator Author

Technically, because we also inherit from str, this allows us to use PersonalizedModes members as plain strings. But I can change this to the enum, no problem.
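For illustration, here is a minimal sketch of what a str-backed enum buys us (the exact class definition in the PR may differ slightly):

from enum import Enum


class PersonalizedModes(str, Enum):
    DITTO = "ditto"


# Because members also inherit from str, they compare equal to plain strings
# and can be used anywhere a str is expected.
assert PersonalizedModes.DITTO == "ditto"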

Collaborator

True...I think if we want to limit to the enum types here, it's a little more "rigid" if we use the enum explicitly. If you disagree that's alright.



def make_it_personal(client_base_type: type[BasicClient], mode: PersonalizedModes) -> type[BasicClient]:
if mode == "ditto":
Collaborator

Outside of it using fewer characters, is there an advantage to using the string here instead of PersonalizedModes.DITTO? The advantage to the latter, that I can think of, is that, if we were to do a bulk rename of DITTO to DITTO_PER (for some reason, who knows), this wouldn't get left behind.

Collaborator Author
@nerdai nerdai May 15, 2025

The main advantage I have used it for is code reduction, but you can also apply str operations to the enum directly if needed. For our users, I think code reduction is pretty nice, as most are too lazy to import an enum when they can pass in a str instead. In our lib, we can add validation to ensure the passed-in str is a valid member of the enum, which is what we do in FedRAG. And, if I'm not mistaken, Pydantic has some nice validations for this as well.

UPDATE:

I just saw mypy errors for this. I have been working with Pydantic for so long that I have been spoiled by it! Pydantic's validation allows passing in a str as the argument for mode without mypy complaining. Without Pydantic, we would have to add some extra logic, e.g., a Union type, to be able to pass in the str.
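As a rough, hypothetical sketch of that extra logic (not part of this PR), a helper accepting either a plain string or a member of the PersonalizedModes enum sketched earlier could look like:

def _coerce_mode(mode: str | PersonalizedModes) -> PersonalizedModes:
    # Normalize a plain string to the enum member; raise a helpful error otherwise.
    try:
        return PersonalizedModes(mode)
    except ValueError as err:
        raise ValueError(
            f"Invalid personalized mode: {mode!r}. "
            f"Expected one of {[m.value for m in PersonalizedModes]}."
        ) from err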

Collaborator Author

BTW, if we handle renames like that, then we should keep in mind to always add some backwards compatibility functionality so that we don't break things like that so easily.

Collaborator

Yeah that's fair. I get the argument for code reduction, but I also think the advantage of enums is "forced" adherence to some structure. So I tend to not "bail people out" with the string escape 😂

Collaborator Author

Ahhhh, I still think we keep the forced adherence; we just provide convenience at the same time. With proper validation we give convenience to the user (and developer) without losing any of the benefits of using enums in Python?

I'm not married to any particular approach for this library, tbh. So I think I'll just use the enum everywhere if that's the convention...

Collaborator Author

made the mode a strict enum!
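For reference, a rough usage sketch with the strict enum (illustrative only; the import paths here are my assumption and the real call site may differ):

from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.personalized import PersonalizedModes, make_it_personal

# Produce a Ditto-personalized variant of an existing client class.
DittoifiedClient = make_it_personal(BasicClient, PersonalizedModes.DITTO)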

@@ -0,0 +1,32 @@
from enum import Enum
Collaborator

Is it normal to put such things in the __init__.py? I'm not sure of the convention, but this feels kind of weird.

Collaborator

That is, why put this in here instead of something like a utility file or something else?

Collaborator Author

Yeah, I have seen this adopted in practice by engineers more experienced than me, and in other libraries. That is, I am treating this as a factory/registry, which I have seen put in these module-level __init__ files. I think so long as we're not baking the actual logic in here, then we're okay... but I'm also not married to this. If you really prefer to outsource this to some other module, say factory or registry, then I'm not picky. :)

Collaborator Author

The main reason these factories/registries get created here, I think, is that all of the classes you need are often already being re-imported (i.e., exposed in the public API via __all__).
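To make that concrete, a hypothetical sketch of the pattern (the layout is illustrative, not the exact file in this PR): the package __init__.py already re-imports the classes for re-export, so the registry can be assembled right next to them.

# fl4health/mixins/personalized/__init__.py (illustrative layout)
from enum import Enum

from fl4health.mixins.personalized.ditto import DittoPersonalizedMixin


class PersonalizedModes(str, Enum):
    DITTO = "ditto"


PersonalizedMixinRegistry = {PersonalizedModes.DITTO: DittoPersonalizedMixin}

__all__ = ["DittoPersonalizedMixin", "PersonalizedModes", "PersonalizedMixinRegistry"]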

Collaborator

That's fair. I honestly have no idea what best practices apply to __init__.py files. Seems kind of wild west ha.

Collaborator Author

Yeah, me too... I just go with things I have seen before in libraries I've worked in or read the source code of, and how past teammates have used __init__.py.

@@ -0,0 +1,527 @@
"""Ditto Personalized Mixin"""
Collaborator

We haven't been putting stuff like this at the top of files. Is there a reason you like doing it?

Collaborator Author

Yeah two reasons:

  1. I used to have static checkers check for this, so I just got used to adding them.
  2. Some automatic docs tools will use this as the header/title in the API reference for the module.

Collaborator

Ah I see... I guess the latter would be a reason to do it. Just kind of seems like "redundant" comments is all 🙂


@property
def optimizer_keys(self: DittoProtocol) -> list[str]:
"""Returns the optimizer keys."""
Collaborator

Given that this method is so small, maybe this is overly pedantic, but this doesn't quite follow the docs structure we typically use.

Collaborator Author

No worries. I'm more than happy to conform to our docs structure. Google, and for all methods? Or what's the standard?

Collaborator

Yeah, Google-style docs are what we've been using. There's a nice VS Code integration that you can set up if you look at the ## Code Documentation section of contributing.md.
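For instance, a Google-style version of the small property above might look roughly like this (wrapped in a placeholder class so the snippet stands alone; names are illustrative):

class _DocExample:
    @property
    def optimizer_keys(self) -> list[str]:
        """Return the keys identifying this client's optimizers.

        Returns:
            list[str]: The optimizer keys.
        """
        return []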

Collaborator Author

Yeah, I've got autodocstring already installed :) and it's also set to Google style in my setup. Do we document all methods in this way, just the ones we want referenced in the API docs, or by some other criteria?

Collaborator

I've been documenting all of them in this way unless they are trivial. However, I'm open to establishing a different norm.

Collaborator Author

Trivial in the sense that the method is self-documenting, i.e., perhaps its obvious naming suggests what it's doing?

Collaborator

Yeah. Everyone has their own definition of trivial I suppose, but that's what I meant 😂. I find "self-documenting" is a problematic adjective. Everyone thinks their code is self-documenting ha

Collaborator Author

For sure. Okay, I think what I'm hearing is that the standard is just doc string everything. No problem and just good to know.

retval.update(**{f"local-{k}": v for k, v in local_preds.items()})
return retval, {}
else:
raise ValueError(f"Unsupported pred type: {type(global_preds)}.")
Collaborator

This only reports the global_preds type. I know the models are essentially copies, but on the off chance the types are different, should we report both?
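For example, something like this hypothetical helper (not the PR's actual code) would surface both:

def _unsupported_pred_error(global_preds: object, local_preds: object) -> ValueError:
    # Build an error message that reports both prediction types, not just the global one.
    return ValueError(
        f"Unsupported pred types: global={type(global_preds)}, local={type(local_preds)}."
    )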

Collaborator Author

done!

else:
raise ValueError(f"Unsupported pred type: {type(global_preds)}.")

def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
Collaborator

Perhaps some docs?

Collaborator Author

done!


# Compute global model vanilla loss

if hasattr(self, "_special_compute_loss_and_additional_losses"):
Collaborator

Is this defined anywhere? I'm not 100% sure of its envisioned use.

Collaborator Author

This will be added in the nnunet_client. I am thinking of making this into a protocol; it's needed because there's some special business logic in nnunet_client that requires delegating compute_loss_and_additional_losses to the underlying class. Though, I do think this is a sensible thing to do in general. The form factor might need polishing, but this shouldn't affect the user API. In the next PR I'll take care of modifying the nnunet client as such; this PR is already too big.
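As a rough sketch of the protocol idea (names and the exact signature are speculative; the follow-up PR may differ):

from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class _SpecialLossCapable(Protocol):
    # Clients implementing this method want the loss computation delegated to them.
    def _special_compute_loss_and_additional_losses(
        self,
        preds: dict[str, torch.Tensor],
        features: dict[str, torch.Tensor],
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:
        ...

The mixin could then check isinstance(self, _SpecialLossCapable) instead of the hasattr probe shown above.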

Collaborator

Gotcha

log(INFO, "Using '_special_compute_loss_and_additional_losses' to compute loss")
global_loss, _ = self._special_compute_loss_and_additional_losses(global_preds, features, target)

# Compute local model loss + ditto constraint term
Collaborator

I know you took this from my docs and it's also out of date there, but we don't actually compute the ditto constraint here.

Collaborator

Same on line 443 below

Collaborator Author

Ah okay, should it be though? Is this an issue with missing logic or just a mismatch in the comment?

Collaborator

Just a mismatch in the comment. The ditto constraint is added elsewhere.

Collaborator Author

Gotcha. Adjusted the comment by removing the bit about the "+ ditto constraint term".


# local loss is stored in loss, global model loss is stored in additional losses.
loss, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
additional_losses = additional_losses or {} # make mypy happy
Collaborator

What was the issue for mypy here?

Collaborator Author

I think nnunet_client broke the interface for compute_loss_and_additional_losses (or maybe it's the other way around)? Specifically:

# from nnunet client
    def compute_loss_and_additional_losses(
        self,
        preds: TorchPredType,
        features: dict[str, torch.Tensor],
        target: TorchTargetType,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:

# from ditto client
    def compute_loss_and_additional_losses(
        self,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

I went with nnunet_client here because that's what I was trying to make work during development. So additional_losses can be None, but we treat it as a dict. In the event it is None, I fall back to an empty dict so that mypy recognizes it as a dict and doesn't complain.

Collaborator

mmm. I see.

@nerdai nerdai force-pushed the nerdai/ditto-mixin branch from 6b8eec3 to 9e82518 Compare May 15, 2025 03:20
@nerdai nerdai marked this pull request as ready for review May 16, 2025 17:58
@VectorInstitute VectorInstitute deleted a comment from emersodb May 16, 2025

optimizer_kwargs = {k: v for k, v in param_group.items() if k not in ("params", "lr")}
assert self.global_model is not None
global_optimizer = OptimClass(
Collaborator Author

Oh, how do we handle this in DittoClient? My intention was to make this mixin a mirror of that client. Maybe I missed it there, or are we doing something different?

optimizer_kwargs = {k: v for k, v in param_group.items() if k not in ("params", "lr")}
assert self.global_model is not None
global_optimizer = OptimClass(
self.global_model.parameters(), lr=param_group["lr"], **optimizer_kwargs
Collaborator Author

Hmmm, the approach I've taken here in this mixin is to basically just copy what the client has, which I thought would be a sensible default/starting place. If we believe users will really want to customize this, then I think we'll need to consider a good design for it. The question is: do we need that out of the gate, or is the approach I'm taking useful in its current state?

I think this can be a quick and easy way (despite being restrictive) to ditto-ify an existing client. If a user needs more flexibility, then we should think about optimal designs for this building off the mixin approach, or introduce some "builder" pattern, or, if very esoteric, just write in the docs how to create a subclass of DittoClient (i.e., the lower-level approach).


optimizer_kwargs = {k: v for k, v in param_group.items() if k not in ("params", "lr")}
assert self.global_model is not None
global_optimizer = OptimClass(
Collaborator Author

I see, in DittoClient this is an abstract method. Please see my response to your other comment. The approach this mixin takes leans toward quick/simple personalization rather than offering full flexibility out of the gate.

assert self.global_model is not None
global_optimizer = OptimClass(
self.global_model.parameters(), lr=param_group["lr"], **optimizer_kwargs
) # type:ignore [call-arg]
Collaborator Author

fl4health/mixins/personalized/ditto.py:137: error: Unexpected keyword argument
"lr" for "Optimizer"  [call-arg]
            global_optimizer = OptimClass(self.global_model.parameters(), ...

I guess it's not included in the base Optimizer class, but some optimizers require it. Example:

torch.optim.AdamW(self.global_model.parameters(), lr=0.01)

assert self.global_model is not None
global_optimizer = OptimClass(
self.global_model.parameters(), lr=param_group["lr"], **optimizer_kwargs
) # type:ignore [call-arg]
Collaborator Author

We could maybe just simplify and keep it in the optimizer_kwargs.
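A standalone sketch of that simplification, with illustrative values (not the PR's actual code):

import torch

model = torch.nn.Linear(4, 2)
param_group = {"params": model.parameters(), "lr": 0.01, "weight_decay": 1e-4}

# Keep "lr" inside the kwargs dict so no explicit lr keyword is needed at the call site.
optimizer_kwargs = {k: v for k, v in param_group.items() if k != "params"}
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)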

@emersodb emersodb left a comment

LGTM!


nerdai commented May 16, 2025

LGTM!

wooohoo! thanks for the review -- sorry it's a bit nuanced.

@nerdai nerdai merged commit d91edb3 into main May 16, 2025
10 checks passed
@nerdai nerdai deleted the nerdai/ditto-mixin branch May 16, 2025 21:58