[Feature] Add DittoPersonalizedMixin #385
Conversation
@emersodb: As discussed, here is the WIP PR for adding the ditto mixin. I will add unit tests to get the coverage checks passing, but in the meantime you can take a look around the PR and add your review.
Codecov Report
Attention: Patch coverage is ...

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #385      +/-   ##
==========================================
+ Coverage   76.78%   76.86%   +0.07%
==========================================
  Files         152      155       +3
  Lines        9236     9462     +226
==========================================
+ Hits         7092     7273     +181
- Misses       2144     2189      +45
A few comments throughout. I think some of them might be related to this PR setting up scaffolding for follow-ups.
DITTO = "ditto" | ||
|
||
|
||
PersonalizedMixinRegistry = {"ditto": DittoPersonalizedMixin} |
Should the keys of this registry be typed as the enum above?
Technically, because we inherit from str as well, this allows us to use PersonalizedMode as a str. But I can change this to the enum, no problem.
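For context, a minimal sketch of the pattern under discussion, assuming the enum is defined roughly like this (the class body isn't shown in the PR):

from enum import Enum

class PersonalizedModes(str, Enum):
    # Inheriting from str first means members also behave like plain strings.
    DITTO = "ditto"

assert PersonalizedModes.DITTO == "ditto"          # compares equal to the raw string
assert PersonalizedModes.DITTO.upper() == "DITTO"  # str methods apply directly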
True... I think if we want to limit to the enum types here, it's a little more "rigid" if we use the enum explicitly. If you disagree, that's alright.
def make_it_personal(client_base_type: type[BasicClient], mode: PersonalizedModes) -> type[BasicClient]:
    if mode == "ditto":
Outside of it using fewer characters, is there an advantage to using the string here instead of PersonalizedModes.DITTO? The advantage to the latter, that I can think of, is that if we were to do a bulk rename of DITTO to DITTO_PER (for some reason, who knows), this wouldn't get left behind.
The main advantage that I have used it for is code reduction, but you can also apply str operations to the enum directly if needed. For our users, I think code reduction is pretty nice, as most are too lazy to import an enum when they can pass in a str instead. In our lib, we can create validations to ensure the passed-in str is a correct member of the enum, which is what we do in FedRAG. And, if I'm not mistaken, if using Pydantic, then I think they've got some nice validations for this as well.

UPDATE:
I just saw mypy errors for this. I have been working so long with Pydantic that I have been spoiled by it! Pydantic's validation allows passing in a str as the argument for mode without mypy complaining. Without Pydantic, we would have to add some extra logic, e.g. a Union, to be able to pass in the str.
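For illustration, the kind of validation described here might look like the following hypothetical helper (not part of the PR; it assumes the mode parameter is widened to str | PersonalizedModes):

def _coerce_mode(mode: str | PersonalizedModes) -> PersonalizedModes:
    """Accept either the enum member or its raw string value."""
    try:
        # Enum lookup by value handles both a member and a plain str.
        return PersonalizedModes(mode)
    except ValueError:
        valid = [m.value for m in PersonalizedModes]
        raise ValueError(f"Unknown personalized mode: {mode!r}; expected one of {valid}")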
BTW, if we handle renames like that, then we should keep in mind to always add some backwards-compatibility functionality so that we don't break things so easily.
Yeah that's fair. I get the argument for code reduction, but I also think the advantage of enums is "forced" adherence to some structure. So I tend to not "bail people out" with the string escape 😂
Ahhhh, I still think we have the forced adherence, but we just provide convenience at the same time. With proper validation we provide convenience to the user (and developer) without losing any of the benefits of using enums in Python?

I'm not married to any particular way, tbh, for this library. So I think I'll just use the enum everywhere if that's the convention...
made the mode a strict enum!
@@ -0,0 +1,32 @@
from enum import Enum
Is it normal to put such things in the __init__.py? I'm not sure of the convention, but this feels kind of weird.
That is, why put this in here instead of something like a utility file or something else?
Yeah, I have seen this adopted in practice, by engineers more experienced than me, and in other libraries. That is, I am treating this as a factory/registry, which I have seen put in these module-level init files. I think so long as we're not baking in the actual logic here, then we're okay... but I'm also not married to this. If you really prefer to outsource this to some other module, say factory or registry, then I'm not picky. :)
The main reason I think these factories/registries are created here is that often all of the classes you need are already being re-imported (i.e., exposed in the public API __all__).
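As a rough sketch of that layout (only the ditto module path is taken from the PR; the rest is illustrative):

# fl4health/mixins/personalized/__init__.py
from fl4health.mixins.personalized.ditto import DittoPersonalizedMixin

# The registry lives here because the class it maps to is already being
# imported for re-export below.
PersonalizedMixinRegistry = {"ditto": DittoPersonalizedMixin}

__all__ = ["DittoPersonalizedMixin", "PersonalizedMixinRegistry"]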
That's fair. I honestly have no idea what best practices apply to __init__.py files. Seems kind of wild west ha.
Yea, me too... I just go with things I have seen before in libraries I've worked in or have read source code from, and how past teammates have used __init__.py.
@@ -0,0 +1,527 @@
"""Ditto Personalized Mixin"""
We haven't been putting stuff like this at the top of files. Is there a reason you like doing it?
Yeah, two reasons:
- I used to have static checkers check for this, so I just got used to doing them.
- Some automatic docs tools will use this as the header/title in the API reference for this module.
Ah I see... I guess the latter would be a reason to do it. Just kind of seems like "redundant" comments is all 🙂
@property
def optimizer_keys(self: DittoProtocol) -> list[str]:
    """Returns the optimizer keys."""
Given that this method is so small, maybe this is overly pedantic, but this doesn't quite follow the docs structure we typically use.
No worries. I'm more than happy to conform to our docs structure. Google, and for all methods? Or what's the standard?
Yeah, Google style docs is what we've been using. There's a nice VS Code integration that you can set up if you look at the ## Code Documentation section of the contributing.md.
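For reference, a short example of the Google style in question, applied to the property above (the body and the exact wording are illustrative, assuming an optimizers dict on the protocol):

@property
def optimizer_keys(self: DittoProtocol) -> list[str]:
    """Return the keys under which this client's optimizers are registered.

    Returns:
        list[str]: Names of the configured optimizers.
    """
    return list(self.optimizers.keys())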
Yeah, I've got autoDocstring already installed :) and it's also set to Google on my setup. Do we doc all methods in this way, or just ones that we want referenced in the API docs, or some other criteria?
I've been documenting all of them in this way unless they are trivial. However, I'm open to establishing a different norm.
Trivial in the sense that the method is self-documenting, i.e., perhaps by its obvious naming suggesting what it's doing?
Yeah. Everyone has their own definition of trivial I suppose, but that's what I meant 😂. I find "self-documenting" is a problematic adjective. Everyone thinks their code is self-documenting ha
For sure. Okay, I think what I'm hearing is that the standard is to just docstring everything. No problem, and just good to know.
    retval.update(**{f"local-{k}": v for k, v in local_preds.items()})
    return retval, {}
else:
    raise ValueError(f"Unsupported pred type: {type(global_preds)}.")
This only reports the global_preds type. I know the models are essentially copies, but on the off chance the return values are different, should we report both?
done!
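Presumably the fix reports both types, along these lines (exact wording assumed):

raise ValueError(
    f"Unsupported pred types: global={type(global_preds)}, local={type(local_preds)}."
)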
else:
    raise ValueError(f"Unsupported pred type: {type(global_preds)}.")

def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
Perhaps some docs?
done!
# Compute global model vanilla loss

if hasattr(self, "_special_compute_loss_and_additional_losses"):
Is this defined anywhere? I'm not 100% sure of its envisioned use.
This will be added in the nnunet_client. I am thinking of making this into a protocol, but it's needed because there's some special business logic in nnunet_client that required delegation of compute_loss_and_additional_losses to the underlying class. Though, I do think this is a sensible thing to do in general. The form factor might need polishing, but this shouldn't affect the user API. In the next PR I'll take care of modifying the Nnunet client as such. This PR is already too big.
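As a sketch of the protocol idea being floated here (hypothetical; the plain tensor types stand in for the library's own aliases):

from typing import Protocol, runtime_checkable

import torch

@runtime_checkable
class SpecialLossProtocol(Protocol):
    def _special_compute_loss_and_additional_losses(
        self,
        preds: dict[str, torch.Tensor],
        features: dict[str, torch.Tensor],
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]: ...

# The mixin could then replace the hasattr check with:
# if isinstance(self, SpecialLossProtocol): ...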
Gotcha
log(INFO, "Using '_special_compute_loss_and_additional_losses' to compute loss")
global_loss, _ = self._special_compute_loss_and_additional_losses(global_preds, features, target)

# Compute local model loss + ditto constraint term
I know you took this from my docs, and it's also out of date there too, but we don't actually compute the ditto constraint here.
Same on line 443 below
Ah okay, should it be though? Is this an issue with missing logic or just a mismatch in the comment?
Just a mismatch in the comment. The ditto constraint is added elsewhere.
Gotcha. Adjusted the comment by removing the bit on the "+ ditto constraint term".
# local loss is stored in loss, global model loss is stored in additional losses.
loss, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
additional_losses = additional_losses or {}  # make mypy happy
What was the issue for mypy here?
I think nnunet_client or the other way around broke the interface for compute_loss_and_additional_losses? Specifically:

# from nnunet client
def compute_loss_and_additional_losses(
    self,
    preds: TorchPredType,
    features: dict[str, torch.Tensor],
    target: TorchTargetType,
) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:

# from ditto client
def compute_loss_and_additional_losses(
    self,
    preds: TorchPredType,
    features: TorchFeatureType,
    target: TorchTargetType,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

I went with nnunet_client here because that's what I was trying to make work during development. So additional_losses can be None, but we treat it as a dict. In the event it is None, I fall back to an empty dict so that mypy recognizes it as a dict and doesn't complain.
mmm. I see.
Force-pushed from 6b8eec3 to 9e82518.
optimizer_kwargs = {k: v for k, v in param_group.items() if k not in ("params", "lr")}
assert self.global_model is not None
global_optimizer = OptimClass(
Oh, how do we handle this in DittoClient? My intention was to base this mixin as a mirror of that client. Maybe I missed it there, if we're doing something different?
optimizer_kwargs = {k: v for k, v in param_group.items() if k not in ("params", "lr")}
assert self.global_model is not None
global_optimizer = OptimClass(
    self.global_model.parameters(), lr=param_group["lr"], **optimizer_kwargs
hmmm, the approach I've taken here in this mixin is to basically just copy what the client has, which I thought would be a sensible default/starting place. If we believe the user will really want to customize this, then I think we'll need to consider a good design for it. The question is, do we need this out of the gate, or is the approach I'm taking useful in its current stage?

I think this can be a quick and easy way (despite being restrictive) to ditto-ify an existing client; an example is sketched below. If a user needs more flexibility, then we should think of optimal designs to build off this mixin approach, or introduce some "builder" pattern, or, if very esoteric, we just write in the docs how to create a subclass of DittoClient (i.e., a lower-level approach).
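For instance, the quick path might look like this (MyClient and the BasicClient import path are assumptions for illustration):

from fl4health.clients.basic_client import BasicClient  # path assumed

class MyClient(BasicClient):
    ...  # an existing, non-personalized client

# One-line "ditto-ification" of the existing client:
DittoMyClient = make_it_personal(MyClient, PersonalizedModes.DITTO)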
optimizer_kwargs = {k: v for k, v in param_group.items() if k not in ("params", "lr")}
assert self.global_model is not None
global_optimizer = OptimClass(
I see, in DittoClient this is an abstract method. Please see my response to your other comment. The approach this mixin takes leans on quick/simple personalization versus offering full flexibility out of the gate.
assert self.global_model is not None
global_optimizer = OptimClass(
    self.global_model.parameters(), lr=param_group["lr"], **optimizer_kwargs
)  # type: ignore [call-arg]
fl4health/mixins/personalized/ditto.py:137: error: Unexpected keyword argument "lr" for "Optimizer" [call-arg]
    global_optimizer = OptimClass(self.global_model.parameters(), ...

I guess it's not included in the base optim class, but some optimizers require it. Example:

torch.optim.AdamW(self.global_model.parameters(), lr=0.01)
assert self.global_model is not None
global_optimizer = OptimClass(
    self.global_model.parameters(), lr=param_group["lr"], **optimizer_kwargs
)  # type: ignore [call-arg]
We could maybe just simplify and keep it in the optimizer_kwargs.
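i.e., something along these lines (a sketch; it simply stops excluding "lr" from the comprehension, so "lr" is picked up from param_group automatically):

# Keeping "lr" inside optimizer_kwargs sidesteps mypy's complaint about
# the base Optimizer signature.
optimizer_kwargs = {k: v for k, v in param_group.items() if k != "params"}
assert self.global_model is not None
global_optimizer = OptimClass(self.global_model.parameters(), **optimizer_kwargs)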
LGTM!
wooohoo! thanks for the review -- sorry it's a bit nuanced.
PR Type
Feature

Short Description
Continuation of using mixins to implement pFL methods. This PR adds the DittoPersonalizedMixin, which subclasses the previously merged AdaptiveDriftConstrainedMixin.
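As a rough sketch of that relationship (the import path for the parent mixin is assumed; the real class body lives in the PR):

from fl4health.mixins.adaptive_drift_constrained import AdaptiveDriftConstrainedMixin  # path assumed

class DittoPersonalizedMixin(AdaptiveDriftConstrainedMixin):
    """Adds Ditto-style personalization on top of the adaptive drift
    constrained behaviour (sketch only)."""
    ...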
Tests Added
Describe the tests that have been added to ensure the code's correctness, if applicable.