
Commit bbb9d27 ("wip")
1 parent: 71ee464

File tree: 4 files changed (+118, -103 lines)


examples/nnunet_example/client.py

Lines changed: 69 additions & 67 deletions
@@ -22,11 +22,12 @@
 from torchmetrics.segmentation import GeneralizedDiceScore
 
 from fl4health.clients.nnunet_client import NnunetClient
+from fl4health.mixins.personalized import make_it_personal
 from fl4health.utils.load_data import load_msd_dataset
 from fl4health.utils.metrics import TorchMetric, TransformsMetric
 from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
 from fl4health.utils.nnunet_utils import get_segs_from_probs, set_nnunet_env
-from fl4health.mixins.personalized import make_it_personal
+
 
 personalized_client_classes = {"ditto": make_it_personal(NnunetClient, "ditto")}
 
@@ -43,74 +44,75 @@ def main(
     client_name: str | None = None,
     personalized_strategy: Literal["ditto"] | None = None,
 ) -> None:
-    # Log device and server address
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    log(INFO, f"Using device: {device}")
-    log(INFO, f"Using server address: {server_address}")
-
-    # Load the dataset if necessary
-    msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
-    nnUNet_raw = join(dataset_path, "nnunet_raw")
-    if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
-        log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
-        load_msd_dataset(nnUNet_raw, msd_dataset_name)
-
-    # The dataset ID will be the same as the MSD Task number
-    dataset_id = int(msd_dataset_enum.value[4:6])
-    nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"
-
-    # Convert the msd dataset if necessary
-    if not exists(join(nnUNet_raw, nnunet_dataset_name)):
-        log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
-        convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))
-
-    # Create a metric
-    dice = TransformsMetric(
-        metric=TorchMetric(
-            name="Pseudo DICE",
-            metric=GeneralizedDiceScore(
-                num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
-            ).to(device),
-        ),
-        pred_transforms=[torch.sigmoid, get_segs_from_probs],
-    )
+    with torch.autograd.set_detect_anomaly(True):
+        # Log device and server address
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        log(INFO, f"Using device: {device}")
+        log(INFO, f"Using server address: {server_address}")
+
+        # Load the dataset if necessary
+        msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
+        nnUNet_raw = join(dataset_path, "nnunet_raw")
+        if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
+            log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
+            load_msd_dataset(nnUNet_raw, msd_dataset_name)
+
+        # The dataset ID will be the same as the MSD Task number
+        dataset_id = int(msd_dataset_enum.value[4:6])
+        nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"
+
+        # Convert the msd dataset if necessary
+        if not exists(join(nnUNet_raw, nnunet_dataset_name)):
+            log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
+            convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))
+
+        # Create a metric
+        dice = TransformsMetric(
+            metric=TorchMetric(
+                name="Pseudo DICE",
+                metric=GeneralizedDiceScore(
+                    num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
+                ).to(device),
+            ),
+            pred_transforms=[torch.sigmoid, get_segs_from_probs],
+        )
 
-    if intermediate_client_state_dir is not None:
-        checkpoint_and_state_module = ClientCheckpointAndStateModule(
-            state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_client_state_dir))
+        if intermediate_client_state_dir is not None:
+            checkpoint_and_state_module = ClientCheckpointAndStateModule(
+                state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_client_state_dir))
+            )
+        else:
+            checkpoint_and_state_module = None
+
+        # Create client
+        client_kwargs = {}
+        client_kwargs.update(
+            # Args specific to nnUNetClient
+            dataset_id=dataset_id,
+            fold=fold,
+            always_preprocess=always_preprocess,
+            verbose=verbose,
+            compile=compile,
+            # BaseClient Args
+            device=device,
+            metrics=[dice],
+            progress_bar=verbose,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            client_name=client_name,
         )
-    else:
-        checkpoint_and_state_module = None
-
-    # Create client
-    client_kwargs = {}
-    client_kwargs.update(
-        # Args specific to nnUNetClient
-        dataset_id=dataset_id,
-        fold=fold,
-        always_preprocess=always_preprocess,
-        verbose=verbose,
-        compile=compile,
-        # BaseClient Args
-        device=device,
-        metrics=[dice],
-        progress_bar=verbose,
-        checkpoint_and_state_module=checkpoint_and_state_module,
-        client_name=client_name,
-    )
-    if personalized_strategy:
-        log(INFO, f"Setting up client for personalized strategy: {personalized_strategy}")
-        client = personalized_client_classes[personalized_strategy](**client_kwargs)
-    else:
-        log(INFO, f"Setting up client without personalization")
-        client = NnunetClient(**client_kwargs)
-    log(INFO, f"Using client: {type(client).__name__}")
-    log(INFO, f"Parameter exchanger: {type(client.parameter_exchanger).__name__}")
-
-    start_client(server_address=server_address, client=client.to_client())
-
-    # Shutdown the client
-    client.shutdown()
+        if personalized_strategy:
+            log(INFO, f"Setting up client for personalized strategy: {personalized_strategy}")
+            client = personalized_client_classes[personalized_strategy](**client_kwargs)
+        else:
+            log(INFO, f"Setting up client without personalization")
+            client = NnunetClient(**client_kwargs)
+        log(INFO, f"Using client: {type(client).__name__}")
+        log(INFO, f"Parameter exchanger: {type(client.parameter_exchanger).__name__}")
+
+        start_client(server_address=server_address, client=client.to_client())
+
+        # Shutdown the client
+        client.shutdown()
 
 
 if __name__ == "__main__":
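Two things worth noting about the new client flow: `torch.autograd.set_detect_anomaly(True)` is the standard PyTorch context manager that makes the backward pass fail at the exact op producing a NaN/Inf (useful for a wip debugging commit), and client construction dispatches on `personalized_strategy` through the `personalized_client_classes` dict built from `make_it_personal(NnunetClient, "ditto")`. A minimal, self-contained sketch of that dispatch pattern, with a hypothetical `DummyClient` and a stand-in factory rather than the fl4health implementation:

```python
from typing import Literal

import torch


class DummyClient:
    """Stand-in for NnunetClient; only stores its kwargs."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs


def make_it_personal_sketch(base_cls: type, strategy: str) -> type:
    # Assumed behaviour: return a subclass of the base client tagged with the strategy name.
    return type(f"{strategy.capitalize()}{base_cls.__name__}", (base_cls,), {"strategy": strategy})


personalized_client_classes = {"ditto": make_it_personal_sketch(DummyClient, "ditto")}


def build_client(personalized_strategy: Literal["ditto"] | None, **client_kwargs: object) -> DummyClient:
    if personalized_strategy:
        return personalized_client_classes[personalized_strategy](**client_kwargs)
    return DummyClient(**client_kwargs)


if __name__ == "__main__":
    # Anomaly detection wraps the whole run, as in the diff above, so autograd reports
    # the operation that produced a NaN/Inf instead of failing later in backward().
    with torch.autograd.set_detect_anomaly(True):
        print(type(build_client("ditto", device="cpu")).__name__)  # DittoDummyClient
        print(type(build_client(None, device="cpu")).__name__)     # DummyClient
```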

examples/nnunet_example/server.py

Lines changed: 5 additions & 5 deletions
@@ -11,19 +11,19 @@
 from flwr.common.typing import Config
 from flwr.server.client_manager import SimpleClientManager
 from flwr.server.strategy import FedAvg
-from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
 
 from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
 from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule
-from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
+from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
+from fl4health.parameter_exchange.parameter_packer import (
+    ParameterPackerAdaptiveConstraint,
+)
 from fl4health.servers.nnunet_server import NnunetServer
+from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
 from fl4health.utils.config import make_dict_with_epochs_or_steps
 from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
 from fl4health.utils.parameter_extraction import get_all_model_parameters
-from fl4health.parameter_exchange.parameter_packer import (
-    ParameterPackerAdaptiveConstraint,
-)
 
 
 def get_config(
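The reordered imports above pair `FedAvgWithAdaptiveConstraint` with `FullParameterExchangerWithPacking` and `ParameterPackerAdaptiveConstraint`, presumably because an adaptive-constraint strategy needs to ship an extra scalar (the constraint weight) alongside the model weights, so the exchanged parameter list has to be packed and unpacked. Below is a conceptual NumPy sketch of that packing idea only; it is not the fl4health API and the function names are made up for illustration.

```python
import numpy as np


def pack(model_arrays: list[np.ndarray], constraint_weight: float) -> list[np.ndarray]:
    # Append the extra scalar as one more ndarray so everything travels as a single parameter list.
    return model_arrays + [np.array([constraint_weight])]


def unpack(packed: list[np.ndarray]) -> tuple[list[np.ndarray], float]:
    # Split the scalar back off the end and return the original weight list.
    return packed[:-1], float(packed[-1][0])


weights = [np.ones((2, 2)), np.zeros(3)]
packed = pack(weights, constraint_weight=0.1)
restored, lam = unpack(packed)
assert lam == 0.1 and len(restored) == 2
```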

fl4health/clients/nnunet_client.py

Lines changed: 28 additions & 23 deletions
@@ -605,7 +605,10 @@ def setup_client(self, config: Config) -> None:
         # We have to call parent method after setting up nnunet trainer
         super().setup_client(config)
 
-    def _special_predict(self, model, input) -> tuple[TorchPredType, dict[str, torch.Tensor]]:
+    def _special_predict(
+        self, model: torch.nn.Module, input: torch.Tensor
+    ) -> tuple[TorchPredType, dict[str, torch.Tensor]]:
+        model.train()
         if isinstance(input, torch.Tensor):
             # If device type is cuda, nnUNet defaults to mixed precision forward pass
             if self.device.type == "cuda":
@@ -643,29 +646,8 @@ def predict(self, input: TorchInputType) -> tuple[TorchPredType, dict[str, torch
                 name. The second element is unused by this subclass and therefore is always an empty dict
         """
         return self._special_predict(self.model, input)
-        # if isinstance(input, torch.Tensor):
-        #     # If device type is cuda, nnUNet defaults to mixed precision forward pass
-        #     if self.device.type == "cuda":
-        #         with torch.autocast(self.device.type, enabled=True):
-        #             output = self.model(input)
-        #     else:
-        #         output = self.model(input)
-        # else:
-        #     raise TypeError('"input" must be of type torch.Tensor for nnUNetClient')
-
-        # if isinstance(output, torch.Tensor):
-        #     return {"prediction": output}, {}
-        # # If output is a list or tuple then deep supervision is on and we need to convert preds into a dict
-        # elif isinstance(output, (list, tuple)):
-        #     num_spatial_dims = NNUNET_N_SPATIAL_DIMS[self.nnunet_config]
-        #     preds = convert_deep_supervision_list_to_dict(output, num_spatial_dims)
-        #     return preds, {}
-        # else:
-        #     raise TypeError(
-        #         "Was expecting nnunet model output to be either a torch.Tensor or a list/tuple of torch.Tensors"
-        #     )
 
-    def compute_loss_and_additional_losses(
+    def _special_compute_loss_and_additional_losses(
         self,
         preds: TorchPredType,
         features: dict[str, torch.Tensor],
@@ -688,6 +670,7 @@ def compute_loss_and_additional_losses(
         # If deep supervision is turned on we must convert loss and target dicts into lists
         loss_preds = prepare_loss_arg(preds)
         loss_targets = prepare_loss_arg(target)
+        log(DEBUG, f"Prepared loss_preds: {type(loss_preds)}, loss_targets: {type(loss_targets)}")
 
         # Ensure we have the same number of predictions and targets
         assert isinstance(
@@ -709,6 +692,28 @@
 
         return loss
 
+    def compute_loss_and_additional_losses(
+        self,
+        preds: TorchPredType,
+        features: dict[str, torch.Tensor],
+        target: TorchTargetType,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:
+        """
+        Checks the pred and target types and computes the loss. If device type is cuda, loss computed in mixed
+        precision.
+
+        Args:
+            preds (TorchPredType): Dictionary of model output tensors indexed by name
+            features (dict[str, torch.Tensor]): Not used by this subclass
+            target (TorchTargetType): The targets to evaluate the predictions with. If multiple prediction tensors
+                are given, target must be a dictionary with the same number of tensors
+
+        Returns:
+            tuple[torch.Tensor, dict[str, torch.Tensor] | None]: A tuple where the first element is the loss and the
+                second element is an optional additional loss
+        """
+        return self._special_compute_loss_and_additional_losses(preds, features, target)
+
     def mask_data(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Masks the pred and target tensors according to nnunet ``ignore_label``. The number of classes in the input
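For orientation, the commented-out code removed from `predict` above describes what `_special_predict` is now responsible for: nnU-Net models return either a single tensor or, when deep supervision is enabled, a list/tuple of tensors that must be turned into a named prediction dict. A rough standalone illustration of that conversion follows; the helper and key names are hypothetical stand-ins, not fl4health's `convert_deep_supervision_list_to_dict`.

```python
import torch


def to_pred_dict(output: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]) -> dict[str, torch.Tensor]:
    # Single tensor: ordinary prediction. List/tuple: one head per deep-supervision resolution.
    if isinstance(output, torch.Tensor):
        return {"prediction": output}
    if isinstance(output, (list, tuple)):
        return {f"prediction-ds{i}": head for i, head in enumerate(output)}
    raise TypeError("Expected a torch.Tensor or a list/tuple of torch.Tensors")


# Fake two-head deep-supervision output at decreasing resolutions
heads = [torch.rand(1, 3, 16, 16), torch.rand(1, 3, 8, 8)]
print(list(to_pred_dict(heads).keys()))  # ['prediction-ds0', 'prediction-ds1']
```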

fl4health/mixins/personalized/ditto.py

Lines changed: 16 additions & 8 deletions
@@ -1,7 +1,7 @@
 """Ditto Personalized Mixin"""
 
 from abc import ABC, abstractmethod
-from logging import INFO
+from logging import INFO, DEBUG
 from typing import cast
 
 import torch
@@ -297,8 +297,9 @@ def predict(
 
         if hasattr(self, "_special_predict"):
             log(INFO, "Using '_special_predict' to make predictions")
-            global_preds = self._special_predict(self.global_model, input)
-            local_preds = self._special_predict(self.model, input)
+            global_preds, _ = self._special_predict(self.global_model, input)
+            local_preds, _ = self._special_predict(self.model, input)
+            log(INFO, f"Successfully predicted for global and local models")
         else:
             if isinstance(input, torch.Tensor):
                 global_preds = self.global_model(input)
@@ -338,12 +339,19 @@ def compute_loss_and_additional_losses(
         """
 
         # Compute global model vanilla loss
-        assert "global" in preds
-        global_loss = self.criterion(preds["global"], target)
 
-        # Compute local model loss + ditto constraint term
-        assert "local" in preds
-        local_loss = self.criterion(preds["local"], target)
+        if hasattr(self, "_special_compute_loss_and_additional_losses"):
+            log(INFO, "Using '_special_compute_loss_and_additional_losses' to compute loss")
+            global_loss, _ = self._special_compute_loss_and_additional_losses(preds["global"], features, target)
+
+            # Compute local model loss + ditto constraint term
+            local_loss, _ = self._special_compute_loss_and_additional_losses(preds["local"], features, target)
+
+        else:
+            global_loss = self.criterion(preds["global"], target)
+
+            # Compute local model loss + ditto constraint term
+            local_loss = self.criterion(preds["local"], target)
 
         additional_losses = {"local_loss": local_loss.clone(), "global_loss": global_loss}
 
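The mixin change above is a duck-typed hook: when the concrete client defines `_special_predict` / `_special_compute_loss_and_additional_losses` (as `NnunetClient` now does), the Ditto mixin delegates to them and unpacks the `(value, extras)` tuple they return; otherwise it falls back to calling `self.criterion` directly. A small generic sketch of that `hasattr` delegation, with hypothetical classes rather than the fl4health mixin:

```python
import torch
from torch import nn


class LossMixinSketch:
    """Falls back to self.criterion unless the subclass provides a special hook."""

    criterion = nn.MSELoss()

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "_special_compute_loss_and_additional_losses"):
            # Hook returns (loss, optional additional losses); keep only the loss here.
            loss, _ = self._special_compute_loss_and_additional_losses(pred, target)
            return loss
        return self.criterion(pred, target)


class PlainClient(LossMixinSketch):
    pass


class SpecialClient(LossMixinSketch):
    def _special_compute_loss_and_additional_losses(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:
        # e.g. a scaled loss and no additional losses
        return 2.0 * self.criterion(pred, target), None


pred, target = torch.ones(4), torch.zeros(4)
print(PlainClient().compute_loss(pred, target).item())    # 1.0 (fallback path)
print(SpecialClient().compute_loss(pred, target).item())  # 2.0 (hook path)
```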