@@ -1,4 +1,5 @@
 import gc
+import copy
 import logging
 import os
 import pickle
@@ -186,6 +187,7 @@ def __init__(
         self.nnunet_trainer_class = nnunet_trainer_class
         self.nnunet_trainer_class_kwargs = nnunet_trainer_class_kwargs
         self.nnunet_trainer: nnUNetTrainer
+
         self.nnunet_config: NnunetConfig
         self.plans: dict[str, Any] | None = None
         self.steps_per_round: int  # N steps per server round
@@ -227,7 +229,7 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[Tr

         # As in the nnUNetTrainer, we implement mixed precision using torch.autocast and torch.GradScaler
         # Clear gradients from optimizer if they exist
-        self.optimizers["global"].zero_grad()
+        self.optimizers["local"].zero_grad()

         # Call user defined methods to get predictions and compute loss
         preds, features = self.predict(input)
@@ -239,11 +241,11 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[Tr
         scaled_backward_loss.backward()

         # Rescale gradients then clip based on specified norm
-        self.grad_scaler.unscale_(self.optimizers["global"])
+        self.grad_scaler.unscale_(self.optimizers["local"])
         self.transform_gradients(losses)

         # Update parameters and scaler
-        self.grad_scaler.step(self.optimizers["global"])
+        self.grad_scaler.step(self.optimizers["local"])
         self.grad_scaler.update()

         return losses, preds
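Note on the hunk above: the zero_grad / backward / unscale_ / step / update sequence is the standard torch.amp mixed-precision recipe, and unscale_ must run before gradient clipping so the clip threshold applies to true-scale gradients. A minimal, self-contained sketch of that pattern, with illustrative names, assuming a CUDA device (max_norm=12 mirrors the nnUNet default clip value):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.amp.GradScaler("cuda")

    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randn(4, 2, device="cuda")

    optimizer.zero_grad()
    with torch.autocast(device_type="cuda"):
        loss = nn.functional.mse_loss(model(inputs), targets)

    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.unscale_(optimizer)     # bring gradients back to true scale...
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=12)  # ...so clipping sees real norms
    scaler.step(optimizer)         # skips the step if gradients overflowed
    scaler.update()                # adapt the scale factor for the next iteration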
@@ -314,7 +316,11 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         return train_loader, val_loader

     def get_model(self, config: Config) -> nn.Module:
-        return self.nnunet_trainer.network
+        for_global = config.get("for_global", False)
+        if for_global:
+            return copy.deepcopy(self.nnunet_trainer.network)
+        else:
+            return self.nnunet_trainer.network

     def get_criterion(self, config: Config) -> _Loss:
         if isinstance(self.nnunet_trainer.loss, DeepSupervisionWrapper):
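Note on get_model above: when the config asks for the global model, the network is deep-copied so the global and local models stop sharing parameter storage; otherwise the client's two optimizers would be stepping the same weights. A tiny sketch of the aliasing issue, with illustrative names:

    import copy
    import torch.nn as nn

    net = nn.Linear(4, 2)  # stand-in for self.nnunet_trainer.network
    aliased = net          # without deepcopy: both names share one set of weights
    independent = copy.deepcopy(net)

    assert aliased.weight.data_ptr() == net.weight.data_ptr()
    assert independent.weight.data_ptr() != net.weight.data_ptr()
    # Stepping an optimizer on `net` would also change `aliased`,
    # but leave `independent` untouched.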
@@ -608,7 +614,6 @@ def setup_client(self, config: Config) -> None:
     def _special_predict(
         self, model: torch.nn.Module, input: torch.Tensor
     ) -> tuple[TorchPredType, dict[str, torch.Tensor]]:
-        model.train()
         if isinstance(input, torch.Tensor):
             # If device type is cuda, nnUNet defaults to mixed precision forward pass
             if self.device.type == "cuda":
@@ -770,8 +775,14 @@ def update_metric_manager(
             target (TorchTargetType): the targets generated by the dataloader to evaluate the preds with
             metric_manager (MetricManager): the metric manager to update
         """
+        preds = {k: v for k, v in preds.items() if "local" in k}
+        # remove prefix
+        preds = {k.replace("local-", ""): v for k, v in preds.items()}
+
         if len(preds) > 1:
             # for nnunet the first pred in the output list is the main one
+            log(DEBUG, f"preds keys: {preds.keys()}")
+
             m_pred = convert_deep_supervision_dict_to_list(preds)[0]

         if isinstance(target, torch.Tensor):
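Note on the filtering added above: metrics are computed only on the local model's outputs, so predictions are first restricted to keys carrying the "local" prefix and then renamed back to their plain names. A small sketch with hypothetical keys:

    import torch

    preds = {
        "local-prediction": torch.rand(1, 2, 8, 8),   # hypothetical local-model output
        "global-prediction": torch.rand(1, 2, 8, 8),  # hypothetical global-model output
    }
    # Keep only the local model's predictions, then strip the prefix
    preds = {k: v for k, v in preds.items() if "local" in k}
    preds = {k.replace("local-", ""): v for k, v in preds.items()}
    assert list(preds) == ["prediction"]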
@@ -828,7 +839,7 @@ def get_client_specific_logs(
         logging_mode: LoggingMode,
     ) -> tuple[str, list[tuple[LogLevel, str]]]:
         if logging_mode == LoggingMode.TRAIN:
-            lr = float(self.optimizers["global"].param_groups[0]["lr"])
+            lr = float(self.optimizers["local"].param_groups[0]["lr"])
             if current_epoch is None:
                 # Assume training by steps
                 return f"Initial LR {lr}", []
@@ -838,7 +849,7 @@ def get_client_specific_logs(
         return "", []

     def get_client_specific_reports(self) -> dict[str, Any]:
-        return {"learning_rate": float(self.optimizers["global"].param_groups[0]["lr"])}
+        return {"learning_rate": float(self.optimizers["local"].param_groups[0]["lr"])}

     @use_default_signal_handlers  # Experiment planner spawns a process I think
     def get_properties(self, config: Config) -> dict[str, Scalar]:
@@ -942,12 +953,14 @@ def update_before_train(self, current_server_round: int) -> None:
         # freeze before the first pass, gc.collect has to check all those variables
         gc.freeze()

-    def transform_gradients(self, losses: TrainingLosses) -> None:
+    def transform_gradients(self, losses: TrainingLosses, model: nn.Module | None = None) -> None:
         """
         Apply the gradient clipping performed by the default nnunet trainer. This is the default behavior for
         nnunet 2.5.1

         Args:
             losses (TrainingLosses): Not used for this transformation.
+            model (nn.Module | None): The model whose gradients are clipped. Defaults to self.model.
         """
-        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+        model = model if model else self.model
+        nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)