@@ -57,6 +57,8 @@ def __init__(
5757 self .test_output_path = test_output_path
5858 self .optimizer_momentum = optimizer_momentum
5959 self .optimizer_weight_decay = optimizer_weight_decay
60+ self .learning_rates = learning_rates
61+ self .dimensions = dimensions
6062
6163 for param in self .model .parameters ():
6264 param .requires_grad = False
@@ -67,20 +69,6 @@ def __init__(
6769 else :
6870 self .global_pool = nn .AdaptiveAvgPool2d ((1 , 1 ))
6971
70- try :
71- # Asparagus by default expects decoder head here
72- feature_dim = self .model .decoder .fc .in_features
73- except AttributeError as e :
74- # Our MAE model has different layer name
75- logging .warning (f"`self.model.decoder.fc.in_features` raised { e } , falling back to `self.model.head.in_features`." )
76- feature_dim = self .model .head .in_features
77-
78- self .heads = nn .ModuleDict ()
79- for lr in learning_rates :
80- head_name = self ._lr_to_linear_head_name (lr )
81- head = self ._make_head (feature_dim , num_classes )
82- self .heads [head_name ] = head
83-
8472 self .loss_fn = nn .CrossEntropyLoss (weight = torch .Tensor (loss_weight ) if loss_weight else None )
8573
8674 self .train_metrics = self .configure_metrics ("train" )
@@ -89,25 +77,37 @@ def __init__(
8977 # Test metrics (only for best head)
9078 self .test_metrics = self .configure_test_metrics ()
9179
80+ self .heads = None
9281 self .best_head_lr = None
9382 self .ignore_index_in_metrics = - 1
9483
95- def _make_head (self , feature_dim : int , num_classes : int ) -> nn .Module :
def configure_model(self):
    """Lazily create one linear probe head per configured learning rate.

    The input width of the heads is not read from a named layer of the
    backbone; it is measured by pushing a small all-zero dummy volume
    through ``self.get_features``. The heads are only built while
    ``self.heads`` is ``None``, so calling this hook again leaves existing
    heads untouched.
    """
    if self.heads is not None:
        return

    # The backbone is frozen and ``train()`` always forces it into eval mode,
    # so probe it in eval mode as well. Otherwise any train-mode layers would
    # see the all-zero dummy input.
    self.model.eval()

    # NOTE(review): the probe assumes a single input channel and that a
    # 32-voxel/pixel edge is large enough for the encoder -- confirm.
    probe_shape = (1, 1, 32, 32, 32) if self.dimensions == "3D" else (1, 1, 32, 32)

    # Build the dummy input where the backbone weights live so that the probe
    # also works when the backbone is not on the default device / dtype.
    anchor = next(self.model.parameters(), None)
    if anchor is not None and anchor.is_floating_point():
        probe = torch.zeros(probe_shape, device=anchor.device, dtype=anchor.dtype)
    elif anchor is not None:
        probe = torch.zeros(probe_shape, device=anchor.device)
    else:
        probe = torch.zeros(probe_shape)

    # Batch size is 1, so the element count equals the per-sample feature width.
    feature_dim = self.get_features(probe).numel()

    heads = nn.ModuleDict()
    for lr in self.learning_rates:
        heads[self.lr_to_linear_head_name(lr)] = self.make_head(feature_dim, self.num_classes)
    # Publish only once every head was built successfully.
    self.heads = heads
94+
def make_head(self, feature_dim: int, num_classes: int) -> nn.Module:
    """Return a freshly initialised linear classification head.

    Args:
        feature_dim: Width of the flattened feature vector fed to the head.
        num_classes: Number of output logits.

    Returns:
        An ``nn.Linear`` whose weights are drawn from N(0, 0.01^2) and whose
        bias is all zeros.
    """
    layer = nn.Linear(in_features=feature_dim, out_features=num_classes)
    # Small-variance weights and a zero bias replace the default Linear init.
    nn.init.normal_(layer.weight, mean=0.0, std=0.01)
    nn.init.zeros_(layer.bias)
    return layer
100100
@staticmethod
def lr_to_linear_head_name(lr: float) -> str:
    """Turn a learning rate into a key for the per-learning-rate head and metric dicts.

    The rate is rendered in zero-precision scientific notation and then
    sanitised: ``.`` becomes ``_``, ``+`` is dropped and ``-`` becomes ``m``
    (e.g. ``1e-3`` -> ``"lr_1em03"``).
    """
    substitutions = str.maketrans({".": "_", "+": None, "-": "m"})
    return f"lr_{lr:.0e}".translate(substitutions)
104104
def train(self, mode=True):
    """Switch the module's train/eval mode while keeping the backbone in eval mode.

    The parent implementation is applied first with ``mode``; afterwards
    ``self.model`` is put back into eval mode regardless of ``mode``.

    Returns:
        ``self``.
    """
    super().train(mode)
    backbone = self.model
    backbone.eval()
    return self
109109
110- def _get_features (self , x : torch .Tensor ) -> torch .Tensor :
110+ def get_features (self , x : torch .Tensor ) -> torch .Tensor :
111111 with torch .no_grad ():
112112 skips = self .model ._encode (x )
113113
@@ -117,13 +117,12 @@ def _get_features(self, x: torch.Tensor) -> torch.Tensor:
117117 return torch .flatten (features , 1 )
118118
def on_before_batch_transfer(self, batch, dataloader_idx):
    """Flatten the classification labels to a 1-D int64 tensor.

    Args:
        batch: Mapping that holds the label tensor under ``"CLSREG_label"``;
            it is modified in place.
        dataloader_idx: Unused.

    Returns:
        The same ``batch`` object with the label entry replaced.
    """
    # ``reshape`` instead of ``view``: identical result for contiguous labels,
    # but it does not raise when the label tensor happens to be non-contiguous.
    batch["CLSREG_label"] = batch["CLSREG_label"].reshape(-1).long()
    return batch
122122
123123 def training_step (self , batch , batch_idx ):
124124 x , y = batch ["image" ], batch ["CLSREG_label" ]
125- features = self ._get_features (x )
126-
125+ features = self .get_features (x )
127126 total_loss = 0.0
128127 for head_name , head in self .heads .items ():
129128 logits = head (features )
@@ -143,16 +142,15 @@ def training_step(self, batch, batch_idx):
@torch.no_grad()
def on_train_epoch_end(self):
    """Compute, log and reset the training metrics of every head.

    Heads are visited in the order of ``self.learning_rates``.
    """
    head_names = (self.lr_to_linear_head_name(lr) for lr in self.learning_rates)
    for head_name in head_names:
        collection = self.train_metrics[head_name]
        formatted = format_multilabel_metrics(
            collection.compute(),
            ignore_index=self.ignore_index_in_metrics,
        )
        self.log_dict(formatted, sync_dist=True)
        # Clear the accumulated state once the values have been logged.
        collection.reset()
151150
152151 def validation_step (self , batch , batch_idx ):
153152 x , y = batch ["image" ], batch ["CLSREG_label" ]
154- features = self ._get_features (x )
155-
153+ features = self .get_features (x )
156154 for head_name , head in self .heads .items ():
157155 logits = head (features )
158156 loss = self .loss_fn (logits , y )
@@ -170,7 +168,7 @@ def validation_step(self, batch, batch_idx):
170168 def on_validation_epoch_end (self ):
171169 current_aurocs = {}
172170 for lr in self .learning_rates :
173- head_name = self ._lr_to_linear_head_name (lr )
171+ head_name = self .lr_to_linear_head_name (lr )
174172 metrics = self .val_metrics [head_name ].compute ()
175173 formatted = format_multilabel_metrics (metrics , ignore_index = self .ignore_index_in_metrics )
176174 self .log_dict (formatted , sync_dist = True )
@@ -206,7 +204,7 @@ def configure_test_metrics(self):
206204 def configure_metrics (self , prefix : str ):
207205 metrics = nn .ModuleDict ()
208206 for lr in self .learning_rates :
209- head_name = self ._lr_to_linear_head_name (lr )
207+ head_name = self .lr_to_linear_head_name (lr )
210208 metrics [head_name ] = MetricCollection (
211209 {
212210 f"{ prefix } /{ head_name } /auroc_macro" : MulticlassAUROC (num_classes = self .num_classes , average = "macro" ),
@@ -219,15 +217,15 @@ def configure_metrics(self, prefix: str):
219217 return metrics
220218
def on_test_epoch_start(self):
    """Announce the head used for testing and reset the test accumulators."""
    head_name = self.lr_to_linear_head_name(self.best_head_lr)
    logging.info(f"Testing with head: {head_name} (lr={self.best_head_lr})")
    # Fresh containers for the results, logits and labels collected during testing.
    self.results, self.logits, self.labels = {}, [], []
226224
227225 def test_step (self , batch , batch_idx ):
228226 x = batch ["image" ]
229- features = self ._get_features (x )
230- logits = self .heads [self ._lr_to_linear_head_name (self .best_head_lr )](features )
227+ features = self .get_features (x )
228+ logits = self .heads [self .lr_to_linear_head_name (self .best_head_lr )](features )
231229
232230 label = batch ["CLSREG_label" ]
233231 self .results [batch ["file_path" ]] = {
@@ -239,15 +237,13 @@ def test_step(self, batch, batch_idx):
239237
def on_test_epoch_end(self):
    """Aggregate the collected test predictions, then save and log the results.

    The logits and labels gathered during testing are stacked, passed to
    ``self.test_metrics`` and the resulting values are stored, together with
    the identity of the best head, in ``self.results``, which is written to
    ``self.test_output_path`` as JSON.
    """
    logits_tensor = torch.stack(self.logits).float()
    # Labels are flattened to 1-D before being handed to the metrics.
    labels_tensor = torch.stack(self.labels).view(-1)

    avg_results = self.test_metrics(logits_tensor, labels_tensor)
    avg_results = {key: value.cpu().numpy().tolist() for key, value in avg_results.items()}

    # Count the per-file entries *before* the summary keys are added below;
    # otherwise the log line overstates the number of files.
    num_files = len(self.results)

    best_head = self.lr_to_linear_head_name(self.best_head_lr)
    self.results["metrics"] = avg_results
    self.results["best_head"] = best_head
    self.results["best_head_lr"] = self.best_head_lr

    # A bare file name has an empty directory part, and os.makedirs("") raises.
    output_dir = os.path.dirname(self.test_output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_json(self.results, self.test_output_path)

    logging.info(f"Test using best head: {best_head} (lr={self.best_head_lr})")
    logging.info(f"Aggregated test results for {num_files} files: {avg_results}")
0 commit comments