
NNunet with Ditto Example (CU-868d7w56m) #364


Draft: wants to merge 20 commits into main from nerdai/pfl-nnunet

Conversation


@nerdai nerdai commented Mar 31, 2025

PR Type

Feature + Example

Short Description

While working on a new example for Ditto and NNunet, I saw an opportunity to reduce inheritance-heavy code and promote composition patterns, which in my opinion makes the library more modular and compact.

This PR promotes the use of a DittoMixin rather than a DittoClient:

  • The mixin approach is composable with clients, which makes transforming non-personalized clients into personalized ones much easier. Validations would still need to be performed, but the PoC appears to be working on our ditto_example

BEFORE

from examples.models.cnn_model import MnistNet
from fl4health.clients.ditto_client import DittoClient
...

class MnistDittoClient(DittoClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        sample_percentage = narrow_dict_type(config, "downsampling_ratio", float)
        sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=sample_percentage, beta=1)
        batch_size = narrow_dict_type(config, "batch_size", int)
        train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        return MnistNet().to(self.device)

    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
        # Note that the global optimizer operates on self.global_model.parameters()
        global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=0.01)
        local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.01)
        return {"global": global_optimizer, "local": local_optimizer}

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

AFTER

from examples.models.cnn_model import MnistNet
from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.personalized import make_it_personal
...

class MnistClient(BasicClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        sample_percentage = narrow_dict_type(config, "downsampling_ratio", float)
        sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=sample_percentage, beta=1)
        batch_size = narrow_dict_type(config, "batch_size", int)
        train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        return MnistNet().to(self.device)

    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
        # Only the local optimizer is defined here; the personalization mixin is expected to add the global one
        local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.01)
        return {"local": local_optimizer}

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()


MnistDittoClient = make_it_personal(MnistClient, "ditto")

This should mean that any of our non-personalized clients, including NnunetClient, can be Ditto-ified in one line:

NnunetDittoClient = make_it_personal(NnunetClient, "ditto")
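
To make the composition concrete, the rough sketch below shows one way a make_it_personal-style factory could layer a mixin over an existing client class. The body of DittoMixin, the use of super() and type(), and the "global"/"local" optimizer keys mirror the BEFORE example but are illustrative assumptions only, not necessarily the PR's actual implementation:

# Illustrative sketch only: not the actual fl4health implementation.
import torch
from torch.optim import Optimizer


class DittoMixin:
    # Layers Ditto-style personalization over a non-personalized client.
    # Assumes the composed client exposes self.global_model in addition to self.model.
    def get_optimizer(self, config) -> dict[str, Optimizer]:
        # Reuse the client's own "local" optimizer and add a "global" one
        # that operates on the global model's parameters.
        optimizers = super().get_optimizer(config)
        optimizers["global"] = torch.optim.AdamW(self.global_model.parameters(), lr=0.01)
        return optimizers


def make_it_personal(client_cls: type, mode: str = "ditto") -> type:
    # Dynamically build a subclass with the mixin first in the MRO,
    # so its overrides take precedence over the wrapped client's methods.
    mixins = {"ditto": DittoMixin}
    name = f"{mode.capitalize()}{client_cls.__name__}"
    return type(name, (mixins[mode], client_cls), {})

Because the mixin sits ahead of the client in the MRO, it can intercept hooks such as get_optimizer while delegating everything else to the wrapped client via super().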

OLDER (original PR description)

Example of how to stand up an NNunet with Ditto (pFL technique) FL system. It combines two existing examples:

  • Ditto Example
  • NNunet Example

Clickup Ticket(s):

CU-868d7w56m


Tests Added

Describe the tests that have been added to ensure the code's correctness, if applicable.


codecov bot commented Mar 31, 2025

Codecov Report

Attention: Patch coverage is 18.48341% with 344 lines in your changes missing coverage. Please review.

Project coverage is 58.17%. Comparing base (ab6c725) to head (223a787).
Report is 9 commits behind head on main.

Files with missing lines                        Patch %   Lines
fl4health/mixins/personalized/ditto.py           0.00%    187 Missing ⚠️
fl4health/mixins/adaptive_drift_contrained.py    0.00%     77 Missing ⚠️
fl4health/clients/basic_client.py               65.78%     39 Missing ⚠️
fl4health/clients/nnunet_client.py               0.00%     25 Missing ⚠️
fl4health/mixins/personalized/__init__.py        0.00%     10 Missing ⚠️
fl4health/mixins/personalized/base.py            0.00%      5 Missing ⚠️
fl4health/servers/base_server.py                 0.00%      1 Missing ⚠️

❌ Your project check has failed because the head coverage (58.17%) is below the target coverage (65.00%). You can increase the head coverage or adjust the target coverage.

❗ There is a different number of reports uploaded between BASE (ab6c725) and HEAD (223a787): HEAD has 1 upload less than BASE.

Flag    BASE (ab6c725)    HEAD (223a787)
              2                 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #364       +/-   ##
===========================================
- Coverage   75.79%   58.17%   -17.63%     
===========================================
  Files         139      143        +4     
  Lines        8742     9149      +407     
===========================================
- Hits         6626     5322     -1304     
- Misses       2116     3827     +1711     

☔ View full report in Codecov by Sentry.

@nerdai force-pushed the nerdai/pfl-nnunet branch from 59be90a to bbb9d27 on April 11, 2025 14:55
@nerdai force-pushed the nerdai/pfl-nnunet branch from 01a3e43 to 94e734a on April 23, 2025 16:00
@nerdai force-pushed the nerdai/pfl-nnunet branch from 1ed3aec to 8c1331e on April 25, 2025 17:30
@nerdai force-pushed the nerdai/pfl-nnunet branch from a92a73f to 7edc69b on April 25, 2025 17:55
@nerdai force-pushed the nerdai/pfl-nnunet branch from 1467c1c to 9ad36b8 on April 25, 2025 18:19