NNunet with Ditto Example (CU-868d7w56m) #364

Draft: wants to merge 20 commits into base branch main
10 changes: 6 additions & 4 deletions examples/ditto_example/client.py
@@ -12,7 +12,8 @@
from torch.utils.data import DataLoader

from examples.models.cnn_model import MnistNet
from fl4health.clients.ditto_client import DittoClient
from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.personalized import PersonalizedModes, make_it_personal
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
@@ -21,7 +22,7 @@
from fl4health.utils.sampler import DirichletLabelBasedSampler


class MnistDittoClient(DittoClient):
class MnistClient(BasicClient):
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
sample_percentage = narrow_dict_type(config, "downsampling_ratio", float)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=sample_percentage, beta=1)
@@ -34,14 +35,15 @@ def get_model(self, config: Config) -> nn.Module:

def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
# Note that the global optimizer operates on self.global_model.parameters()
global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=0.01)
local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.01)
return {"global": global_optimizer, "local": local_optimizer}
return {"local": local_optimizer}

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()


MnistDittoClient = make_it_personal(MnistClient, PersonalizedModes.DITTO)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FL Client Main")
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
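Note on the change above: the example no longer subclasses `DittoClient` directly; instead a plain `BasicClient` subclass is wrapped with the new `make_it_personal` utility. A minimal sketch of that pattern, assuming the imports behave as shown in this diff (the `MyClient` / `MyDittoClient` names and the optimizer settings are illustrative only, not part of the PR):

```python
import torch

from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.personalized import PersonalizedModes, make_it_personal


class MyClient(BasicClient):
    # get_data_loaders, get_model and get_criterion would be defined as in the
    # example above; only the optimizer hook is sketched here.
    def get_optimizer(self, config):
        # As in the diff, the base client now returns only a "local" optimizer.
        return {"local": torch.optim.AdamW(self.model.parameters(), lr=0.01)}


# Wrapping the vanilla client yields a new class with Ditto personalization applied.
MyDittoClient = make_it_personal(MyClient, PersonalizedModes.DITTO)
```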
2 changes: 1 addition & 1 deletion examples/ditto_example/config.yaml
@@ -2,7 +2,7 @@
n_server_rounds: 5 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
n_clients: 1 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 32 # The batch size for client training

133 changes: 78 additions & 55 deletions examples/nnunet_example/client.py
@@ -4,6 +4,7 @@
from logging import DEBUG, INFO
from os.path import exists, join
from pathlib import Path
from typing import Any, Literal

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
@@ -21,11 +22,14 @@
from torchmetrics.segmentation import GeneralizedDiceScore

from fl4health.clients.nnunet_client import NnunetClient
from fl4health.mixins.personalized import PersonalizedModes, make_it_personal
from fl4health.utils.load_data import load_msd_dataset
from fl4health.utils.metrics import TorchMetric, TransformsMetric
from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
from fl4health.utils.nnunet_utils import get_segs_from_probs, set_nnunet_env

personalized_client_classes = {"ditto": make_it_personal(NnunetClient, PersonalizedModes.DITTO)}


def main(
dataset_path: Path,
@@ -37,66 +41,76 @@ def main(
compile: bool = True,
intermediate_client_state_dir: str | None = None,
client_name: str | None = None,
personalized_strategy: Literal["ditto"] | None = None,
) -> None:
# Log device and server address
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Using device: {device}")
log(INFO, f"Using server address: {server_address}")

# Load the dataset if necessary
msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
nnUNet_raw = join(dataset_path, "nnunet_raw")
if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
load_msd_dataset(nnUNet_raw, msd_dataset_name)

# The dataset ID will be the same as the MSD Task number
dataset_id = int(msd_dataset_enum.value[4:6])
nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"

# Convert the msd dataset if necessary
if not exists(join(nnUNet_raw, nnunet_dataset_name)):
log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))

# Create a metric
dice = TransformsMetric(
metric=TorchMetric(
name="Pseudo DICE",
metric=GeneralizedDiceScore(
num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
).to(device),
),
pred_transforms=[torch.sigmoid, get_segs_from_probs],
)

if intermediate_client_state_dir is not None:
checkpoint_and_state_module = ClientCheckpointAndStateModule(
state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_client_state_dir))
with torch.autograd.set_detect_anomaly(True):
# Log device and server address
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Using device: {device}")
log(INFO, f"Using server address: {server_address}")

# Load the dataset if necessary
msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
nnUNet_raw = join(dataset_path, "nnunet_raw")
if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
load_msd_dataset(nnUNet_raw, msd_dataset_name)

# The dataset ID will be the same as the MSD Task number
dataset_id = int(msd_dataset_enum.value[4:6])
nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"

# Convert the msd dataset if necessary
if not exists(join(nnUNet_raw, nnunet_dataset_name)):
log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))

# Create a metric
dice = TransformsMetric(
metric=TorchMetric(
name="Pseudo DICE",
metric=GeneralizedDiceScore(
num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
).to(device),
),
pred_transforms=[torch.sigmoid, get_segs_from_probs],
)
else:
checkpoint_and_state_module = None

# Create client
client = NnunetClient(
# Args specific to nnUNetClient
dataset_id=dataset_id,
fold=fold,
always_preprocess=always_preprocess,
verbose=verbose,
compile=compile,
# BaseClient Args
device=device,
metrics=[dice],
progress_bar=verbose,
checkpoint_and_state_module=checkpoint_and_state_module,
client_name=client_name,
)
if intermediate_client_state_dir is not None:
checkpoint_and_state_module = ClientCheckpointAndStateModule(
state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_client_state_dir))
)
else:
checkpoint_and_state_module = None

# Create client
client_kwargs: dict[str, Any] = {}
client_kwargs.update(
# Args specific to nnUNetClient
dataset_id=dataset_id,
fold=fold,
always_preprocess=always_preprocess,
verbose=verbose,
compile=compile,
# BaseClient Args
device=device,
metrics=[dice],
progress_bar=verbose,
checkpoint_and_state_module=checkpoint_and_state_module,
client_name=client_name,
)
if personalized_strategy:
log(INFO, f"Setting up client for personalized strategy: {personalized_strategy}")
client = personalized_client_classes[personalized_strategy](**client_kwargs)
else:
log(INFO, "Setting up client without personalization")
client = NnunetClient(**client_kwargs)
log(INFO, f"Using client: {type(client).__name__}")

start_client(server_address=server_address, client=client.to_client())
start_client(server_address=server_address, client=client.to_client())

# Shutdown the client
client.shutdown()
# Shutdown the client
client.shutdown()


if __name__ == "__main__":
@@ -189,6 +203,14 @@ def main(
help="[OPTIONAL] Name of the client used to name client state checkpoint. \
Defaults to None, in which case a random name is generated for the client",
)
parser.add_argument(
"--personalized-strategy",
type=str,
required=False,
default=None,
help="[OPTIONAL] Personalized strategy to use. Can be 'ditto' or 'mr-mtl' \
Defaults to None, in which no personalized strategy is applied.",
)

args = parser.parse_args()

@@ -218,4 +240,5 @@ def main(
compile=not args.skip_compile,
intermediate_client_state_dir=args.intermediate_client_state_dir,
client_name=args.client_name,
personalized_strategy=args.personalized_strategy,
)
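In short, the new `--personalized-strategy` flag determines which client class gets instantiated in `main`. A condensed sketch of that dispatch, reusing only names that appear in this diff (the `select_client_class` helper is hypothetical, added here purely to isolate the branch):

```python
from fl4health.clients.nnunet_client import NnunetClient
from fl4health.mixins.personalized import PersonalizedModes, make_it_personal

# Names accepted on the CLI map to dynamically created personalized client classes.
personalized_client_classes = {"ditto": make_it_personal(NnunetClient, PersonalizedModes.DITTO)}


def select_client_class(personalized_strategy: str | None):
    # Mirrors the branch in main(): a recognised strategy name picks the
    # personalized class, otherwise the plain NnunetClient is used.
    if personalized_strategy:
        return personalized_client_classes[personalized_strategy]
    return NnunetClient
```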
29 changes: 26 additions & 3 deletions examples/nnunet_example/server.py
@@ -10,12 +10,16 @@
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_packer import (
ParameterPackerAdaptiveConstraint,
)
from fl4health.servers.nnunet_server import NnunetServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import make_dict_with_epochs_or_steps
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn

@@ -73,15 +77,30 @@ def main(
else:
params = None

strategy = FedAvg(
# strategy = FedAvg(
# min_fit_clients=config["n_clients"],
# min_evaluate_clients=config["n_clients"],
# min_available_clients=config["n_clients"],
# on_fit_config_fn=fit_config_fn,
# on_evaluate_config_fn=fit_config_fn, # Nothing changes for eval
# fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
# evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
# initial_parameters=params,
# )

strategy = FedAvgWithAdaptiveConstraint(
min_fit_clients=config["n_clients"],
min_evaluate_clients=config["n_clients"],
# Server waits for min_available_clients before starting FL rounds
min_available_clients=config["n_clients"],
on_fit_config_fn=fit_config_fn,
on_evaluate_config_fn=fit_config_fn, # Nothing changes for eval
# We use the same fit config function, as nothing changes for eval
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=params,
initial_loss_weight=0.1,
adapt_loss_weight=False,
)

state_checkpointer = (
@@ -93,6 +112,10 @@
model=None, parameter_exchanger=FullParameterExchanger(), state_checkpointer=state_checkpointer
)

checkpoint_and_state_module.parameter_exchanger = FullParameterExchangerWithPacking( # type:ignore [assignment]
ParameterPackerAdaptiveConstraint()
)

server = NnunetServer(
client_manager=SimpleClientManager(),
fl_config=config,
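One server-side detail worth highlighting: `FedAvgWithAdaptiveConstraint` sends a constraint (loss) weight to clients alongside the model parameters, which is presumably why the parameter exchanger is also swapped for a packing variant. A condensed sketch of the pairing, keeping only the names and values that appear in this diff; the remaining FedAvg-style arguments (client counts, config functions, metric aggregation, initial parameters) are assumed to have defaults here, though the PR passes them explicitly:

```python
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint

# As in the diff: a fixed penalty weight of 0.1 is broadcast to clients, and it is
# not adapted between rounds.
strategy = FedAvgWithAdaptiveConstraint(
    initial_loss_weight=0.1,
    adapt_loss_weight=False,
)

# Because the strategy packs the loss weight together with the model parameters,
# the exchanger attached to the server's checkpoint/state module must use the
# matching packer to split that combined payload apart again.
parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
```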