22
22
from torchmetrics .segmentation import GeneralizedDiceScore
23
23
24
24
from fl4health .clients .nnunet_client import NnunetClient
25
+ from fl4health .mixins .personalized import make_it_personal
25
26
from fl4health .utils .load_data import load_msd_dataset
26
27
from fl4health .utils .metrics import TorchMetric , TransformsMetric
27
28
from fl4health .utils .msd_dataset_sources import get_msd_dataset_enum , msd_num_labels
28
29
from fl4health .utils .nnunet_utils import get_segs_from_probs , set_nnunet_env
29
- from fl4health . mixins . personalized import make_it_personal
30
+
30
31
31
32
personalized_client_classes = {"ditto" : make_it_personal (NnunetClient , "ditto" )}
32
33
@@ -43,74 +44,75 @@ def main(
43
44
client_name : str | None = None ,
44
45
personalized_strategy : Literal ["ditto" ] | None = None ,
45
46
) -> None :
46
- # Log device and server address
47
- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
48
- log (INFO , f"Using device: { device } " )
49
- log (INFO , f"Using server address: { server_address } " )
50
-
51
- # Load the dataset if necessary
52
- msd_dataset_enum = get_msd_dataset_enum (msd_dataset_name )
53
- nnUNet_raw = join (dataset_path , "nnunet_raw" )
54
- if not exists (join (nnUNet_raw , msd_dataset_enum .value )):
55
- log (INFO , f"Downloading and extracting { msd_dataset_enum .value } dataset" )
56
- load_msd_dataset (nnUNet_raw , msd_dataset_name )
57
-
58
- # The dataset ID will be the same as the MSD Task number
59
- dataset_id = int (msd_dataset_enum .value [4 :6 ])
60
- nnunet_dataset_name = f"Dataset{ dataset_id :03d} _{ msd_dataset_enum .value .split ('_' )[1 ]} "
61
-
62
- # Convert the msd dataset if necessary
63
- if not exists (join (nnUNet_raw , nnunet_dataset_name )):
64
- log (INFO , f"Converting { msd_dataset_enum .value } into nnunet dataset" )
65
- convert_msd_dataset (source_folder = join (nnUNet_raw , msd_dataset_enum .value ))
66
-
67
- # Create a metric
68
- dice = TransformsMetric (
69
- metric = TorchMetric (
70
- name = "Pseudo DICE" ,
71
- metric = GeneralizedDiceScore (
72
- num_classes = msd_num_labels [msd_dataset_enum ], weight_type = "square" , include_background = False
73
- ).to (device ),
74
- ),
75
- pred_transforms = [torch .sigmoid , get_segs_from_probs ],
76
- )
47
+ with torch .autograd .set_detect_anomaly (True ):
48
+ # Log device and server address
49
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
50
+ log (INFO , f"Using device: { device } " )
51
+ log (INFO , f"Using server address: { server_address } " )
52
+
53
+ # Load the dataset if necessary
54
+ msd_dataset_enum = get_msd_dataset_enum (msd_dataset_name )
55
+ nnUNet_raw = join (dataset_path , "nnunet_raw" )
56
+ if not exists (join (nnUNet_raw , msd_dataset_enum .value )):
57
+ log (INFO , f"Downloading and extracting { msd_dataset_enum .value } dataset" )
58
+ load_msd_dataset (nnUNet_raw , msd_dataset_name )
59
+
60
+ # The dataset ID will be the same as the MSD Task number
61
+ dataset_id = int (msd_dataset_enum .value [4 :6 ])
62
+ nnunet_dataset_name = f"Dataset{ dataset_id :03d} _{ msd_dataset_enum .value .split ('_' )[1 ]} "
63
+
64
+ # Convert the msd dataset if necessary
65
+ if not exists (join (nnUNet_raw , nnunet_dataset_name )):
66
+ log (INFO , f"Converting { msd_dataset_enum .value } into nnunet dataset" )
67
+ convert_msd_dataset (source_folder = join (nnUNet_raw , msd_dataset_enum .value ))
68
+
69
+ # Create a metric
70
+ dice = TransformsMetric (
71
+ metric = TorchMetric (
72
+ name = "Pseudo DICE" ,
73
+ metric = GeneralizedDiceScore (
74
+ num_classes = msd_num_labels [msd_dataset_enum ], weight_type = "square" , include_background = False
75
+ ).to (device ),
76
+ ),
77
+ pred_transforms = [torch .sigmoid , get_segs_from_probs ],
78
+ )
77
79
78
- if intermediate_client_state_dir is not None :
79
- checkpoint_and_state_module = ClientCheckpointAndStateModule (
80
- state_checkpointer = PerRoundStateCheckpointer (Path (intermediate_client_state_dir ))
80
+ if intermediate_client_state_dir is not None :
81
+ checkpoint_and_state_module = ClientCheckpointAndStateModule (
82
+ state_checkpointer = PerRoundStateCheckpointer (Path (intermediate_client_state_dir ))
83
+ )
84
+ else :
85
+ checkpoint_and_state_module = None
86
+
87
+ # Create client
88
+ client_kwargs = {}
89
+ client_kwargs .update (
90
+ # Args specific to nnUNetClient
91
+ dataset_id = dataset_id ,
92
+ fold = fold ,
93
+ always_preprocess = always_preprocess ,
94
+ verbose = verbose ,
95
+ compile = compile ,
96
+ # BaseClient Args
97
+ device = device ,
98
+ metrics = [dice ],
99
+ progress_bar = verbose ,
100
+ checkpoint_and_state_module = checkpoint_and_state_module ,
101
+ client_name = client_name ,
81
102
)
82
- else :
83
- checkpoint_and_state_module = None
84
-
85
- # Create client
86
- client_kwargs = {}
87
- client_kwargs .update (
88
- # Args specific to nnUNetClient
89
- dataset_id = dataset_id ,
90
- fold = fold ,
91
- always_preprocess = always_preprocess ,
92
- verbose = verbose ,
93
- compile = compile ,
94
- # BaseClient Args
95
- device = device ,
96
- metrics = [dice ],
97
- progress_bar = verbose ,
98
- checkpoint_and_state_module = checkpoint_and_state_module ,
99
- client_name = client_name ,
100
- )
101
- if personalized_strategy :
102
- log (INFO , f"Setting up client for personalized strategy: { personalized_strategy } " )
103
- client = personalized_client_classes [personalized_strategy ](** client_kwargs )
104
- else :
105
- log (INFO , f"Setting up client without personalization" )
106
- client = NnunetClient (** client_kwargs )
107
- log (INFO , f"Using client: { type (client ).__name__ } " )
108
- log (INFO , f"Parameter exchanger: { type (client .parameter_exchanger ).__name__ } " )
109
-
110
- start_client (server_address = server_address , client = client .to_client ())
111
-
112
- # Shutdown the client
113
- client .shutdown ()
103
+ if personalized_strategy :
104
+ log (INFO , f"Setting up client for personalized strategy: { personalized_strategy } " )
105
+ client = personalized_client_classes [personalized_strategy ](** client_kwargs )
106
+ else :
107
+ log (INFO , f"Setting up client without personalization" )
108
+ client = NnunetClient (** client_kwargs )
109
+ log (INFO , f"Using client: { type (client ).__name__ } " )
110
+ log (INFO , f"Parameter exchanger: { type (client .parameter_exchanger ).__name__ } " )
111
+
112
+ start_client (server_address = server_address , client = client .to_client ())
113
+
114
+ # Shutdown the client
115
+ client .shutdown ()
114
116
115
117
116
118
if __name__ == "__main__" :
0 commit comments