Commit 94e734a

janky but working
1 parent bbb9d27 commit 94e734a

File tree: 2 files changed, +47 −14 lines changed


fl4health/clients/nnunet_client.py

Lines changed: 21 additions & 9 deletions
@@ -1,4 +1,5 @@
 import gc
+import copy
 import logging
 import os
 import pickle
@@ -186,6 +187,7 @@ def __init__(
         self.nnunet_trainer_class = nnunet_trainer_class
         self.nnunet_trainer_class_kwargs = nnunet_trainer_class_kwargs
         self.nnunet_trainer: nnUNetTrainer
+
         self.nnunet_config: NnunetConfig
         self.plans: dict[str, Any] | None = None
         self.steps_per_round: int  # N steps per server round
@@ -227,7 +229,7 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[Tr

         # As in the nnUNetTrainer, we implement mixed precision using torch.autocast and torch.GradScaler
         # Clear gradients from optimizer if they exist
-        self.optimizers["global"].zero_grad()
+        self.optimizers["local"].zero_grad()

         # Call user defined methods to get predictions and compute loss
         preds, features = self.predict(input)
@@ -239,11 +241,11 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[Tr
         scaled_backward_loss.backward()

         # Rescale gradients then clip based on specified norm
-        self.grad_scaler.unscale_(self.optimizers["global"])
+        self.grad_scaler.unscale_(self.optimizers["local"])
         self.transform_gradients(losses)

         # Update parameters and scaler
-        self.grad_scaler.step(self.optimizers["global"])
+        self.grad_scaler.step(self.optimizers["local"])
         self.grad_scaler.update()

         return losses, preds
@@ -314,7 +316,11 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         return train_loader, val_loader

     def get_model(self, config: Config) -> nn.Module:
-        return self.nnunet_trainer.network
+        for_global = config.get("for_global", False)
+        if for_global:
+            return copy.deepcopy(self.nnunet_trainer.network)
+        else:
+            return self.nnunet_trainer.network

     def get_criterion(self, config: Config) -> _Loss:
         if isinstance(self.nnunet_trainer.loss, DeepSupervisionWrapper):
@@ -608,7 +614,6 @@ def setup_client(self, config: Config) -> None:
     def _special_predict(
         self, model: torch.nn.Module, input: torch.Tensor
     ) -> tuple[TorchPredType, dict[str, torch.Tensor]]:
-        model.train()
         if isinstance(input, torch.Tensor):
             # If device type is cuda, nnUNet defaults to mixed precision forward pass
             if self.device.type == "cuda":
@@ -770,8 +775,14 @@ def update_metric_manager(
             target (TorchTargetType): the targets generated by the dataloader to evaluate the preds with
             metric_manager (MetricManager): the metric manager to update
         """
+        preds = {k: v for k, v in preds.items() if "local" in k}
+        # remove prefix
+        preds = {k.replace(f"local-", ""): v for k, v in preds.items()}
+
         if len(preds) > 1:
             # for nnunet the first pred in the output list is the main one
+            log(DEBUG, f"preds keys: {preds.keys()}")
+
             m_pred = convert_deep_supervision_dict_to_list(preds)[0]

         if isinstance(target, torch.Tensor):
@@ -828,7 +839,7 @@ def get_client_specific_logs(
         logging_mode: LoggingMode,
     ) -> tuple[str, list[tuple[LogLevel, str]]]:
         if logging_mode == LoggingMode.TRAIN:
-            lr = float(self.optimizers["global"].param_groups[0]["lr"])
+            lr = float(self.optimizers["local"].param_groups[0]["lr"])
             if current_epoch is None:
                 # Assume training by steps
                 return f"Initial LR {lr}", []
@@ -838,7 +849,7 @@ def get_client_specific_logs(
         return "", []

     def get_client_specific_reports(self) -> dict[str, Any]:
-        return {"learning_rate": float(self.optimizers["global"].param_groups[0]["lr"])}
+        return {"learning_rate": float(self.optimizers["local"].param_groups[0]["lr"])}

     @use_default_signal_handlers  # Experiment planner spawns a process I think
     def get_properties(self, config: Config) -> dict[str, Scalar]:
@@ -942,12 +953,13 @@ def update_before_train(self, current_server_round: int) -> None:
         # freeze before the first pass, gc.collect has to check all those variables
         gc.freeze()

-    def transform_gradients(self, losses: TrainingLosses) -> None:
+    def transform_gradients(self, losses: TrainingLosses, model: nn.Module | None = None) -> None:
         """
         Apply the gradient clipping performed by the default nnunet trainer. This is the default behavior for
         nnunet 2.5.1

         Args:
             losses (TrainingLosses): Not used for this transformation.
         """
-        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+        model = model if model else self.model
+        nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
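
Net effect of the nnunet_client.py changes: the mixed-precision train step now drives the "local" optimizer, get_model hands back an independent copy of the nnU-Net network when the config requests the global model, and transform_gradients can clip an arbitrary model instead of only self.model. A minimal sketch of that behaviour, with a plain nn.Linear standing in for self.nnunet_trainer.network; the names network and get_model here, and the 12.0 clip value, are illustrative stand-ins rather than part of the commit:

import copy

import torch
import torch.nn as nn

# Stand-in for self.nnunet_trainer.network; the real client wraps an nnU-Net model.
network = nn.Linear(4, 2)

def get_model(config: dict) -> nn.Module:
    # Mirrors the patched NnunetClient.get_model: a deep copy when the config
    # asks for the Ditto global model, the trainer's own network otherwise.
    if config.get("for_global", False):
        return copy.deepcopy(network)
    return network

local_model = get_model({})                      # same object as the trainer's network
global_model = get_model({"for_global": True})   # independent copy with its own parameters
assert local_model is network and global_model is not network

# transform_gradients now clips whichever model is passed in (defaulting to the local one).
max_grad_norm = 12.0  # stand-in for self.max_grad_norm
global_model(torch.randn(8, 4)).sum().backward()
nn.utils.clip_grad_norm_(global_model.parameters(), max_grad_norm)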

fl4health/mixins/personalized/ditto.py

Lines changed: 26 additions & 5 deletions
@@ -93,6 +93,7 @@ def get_global_model(self, config: Config) -> nn.Module:
         Returns:
             nn.Module: The PyTorch model serving as the global model for Ditto
         """
+        config["for_global"] = True
         return self.get_model(config).to(self.device)

     def set_optimizer(self, config: Config) -> None:
@@ -315,7 +316,24 @@ def predict(
         # TODO: Perhaps loosen this at a later date.
         # assert isinstance(global_preds, torch.Tensor)
         # assert isinstance(local_preds, torch.Tensor)
-        return {"global": global_preds, "local": local_preds}, {}
+        if isinstance(global_preds, torch.Tensor) and isinstance(local_preds, torch.Tensor):
+            return {"global": global_preds, "local": local_preds}, {}
+        elif isinstance(global_preds, dict) and isinstance(local_preds, dict):
+            retval = {f"global-{k}": v for k, v in global_preds.items()}
+            retval.update(**{f"local-{k}": v for k, v in local_preds.items()})
+            return retval, {}
+        else:
+            raise ValueError(f"Unsupported pred type: {type(global_preds)}.")
+
+    def _extract_pred(self, kind: str, preds: dict[str, torch.Tensor]):
+        if kind not in ["global", "local"]:
+            raise ValueError("Unsupported kind of prediction. Must be 'global' or 'local'.")
+
+        # filter
+        retval = {k: v for k, v in preds.items() if kind in k}
+        # remove prefix
+        retval = {k.replace(f"{kind}-", ""): v for k, v in retval.items()}
+        return retval

     def compute_loss_and_additional_losses(
         self,
@@ -338,20 +356,23 @@ def compute_loss_and_additional_losses(
             - A dictionary with ``local_loss``, ``global_loss`` as additionally reported loss values.
         """

+        global_preds = self._extract_pred(kind="global", preds=preds)
+        local_preds = self._extract_pred(kind="local", preds=preds)
+
         # Compute global model vanilla loss

         if hasattr(self, "_special_compute_loss_and_additional_losses"):
             log(INFO, "Using '_special_compute_loss_and_additional_losses' to compute loss")
-            global_loss, _ = self._special_compute_loss_and_additional_losses(preds["global"], features, target)
+            global_loss, _ = self._special_compute_loss_and_additional_losses(global_preds, features, target)

             # Compute local model loss + ditto constraint term
-            local_loss, _ = self._special_compute_loss_and_additional_losses(preds["local"], features, target)
+            local_loss, _ = self._special_compute_loss_and_additional_losses(local_preds, features, target)

         else:
-            global_loss = self.criterion(preds["global"], target)
+            global_loss = self.criterion(global_preds, target)

             # Compute local model loss + ditto constraint term
-            local_loss = self.criterion(preds["local"], target)
+            local_loss = self.criterion(local_preds, target)

         additional_losses = {"local_loss": local_loss.clone(), "global_loss": global_loss}
