 from kornia.augmentation import AugmentationSequential
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
-from tools.utils import denormalization, load_weights_from_checkpoint
-from tools.visualization import visualize_prediction
 from torch import Tensor
+from torch.optim.lr_scheduler import _LRScheduler
 from torchmetrics.segmentation import MeanIoU
 from torchmetrics.wrappers import ClasswiseWrapper

+from geo_deep_learning.tools.utils import denormalization, load_weights_from_checkpoint
+from geo_deep_learning.tools.visualization import visualize_prediction
+
 # Ignore warning about default grid_sample and affine_grid behavior triggered by kornia
 warnings.filterwarnings(
     "ignore",
@@ -140,17 +142,27 @@ def configure_model(self) -> None:
             map_location=map_location,
         )

-    def configure_optimizers(self) -> list[list[dict[str, Any]]]:
-        """Configure optimizers."""
+    def configure_optimizers(self) -> list:
+        """Configure optimizers and schedulers."""
         optimizer = self.optimizer(self.parameters())
-        if (
-            self.hparams["scheduler"]["class_path"]
-            == "torch.optim.lr_scheduler.OneCycleLR"
-        ):
-            max_lr = (
-                self.hparams.get("scheduler", {}).get("init_args", {}).get("max_lr")
-            )
+        scheduler_cfg = self.hparams.get("scheduler", None)
+
+        # Initialize scheduler variable (either an LR scheduler or None)
+        scheduler: _LRScheduler | None = None
+
+        # Handle non-CLI case
+        if not scheduler_cfg or not isinstance(scheduler_cfg, dict):
+            scheduler = self.scheduler(optimizer) if callable(self.scheduler) else None
+            if scheduler:
+                return [optimizer], [{"scheduler": scheduler, **self.scheduler_config}]
+            return [optimizer]
+
+        # CLI-compatible config logic
+        scheduler_class_path = scheduler_cfg.get("class_path", "")
+        if scheduler_class_path == "torch.optim.lr_scheduler.OneCycleLR":
+            max_lr = scheduler_cfg.get("init_args", {}).get("max_lr")
             stepping_batches = self.trainer.estimated_stepping_batches
+
             if stepping_batches > -1:
                 scheduler = torch.optim.lr_scheduler.OneCycleLR(
                     optimizer,
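
The branching added above distinguishes how the scheduler can arrive: built programmatically (a callable that maps the optimizer to a scheduler) or through the Lightning CLI, which stores the configuration as a `class_path`/`init_args` dict in `self.hparams`. As a hedged illustration only (the scheduler classes and values below are assumptions, not the project's defaults), the two forms look roughly like this:

from functools import partial

import torch

# Non-CLI path: the module is constructed in code and `scheduler` is a
# callable that maps an optimizer to an LR scheduler instance.
scheduler_callable = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100)

# CLI path: the scheduler arrives as a dict carrying the two keys read by the
# refactored configure_optimizers ("class_path" and "init_args").
scheduler_cfg = {
    "class_path": "torch.optim.lr_scheduler.OneCycleLR",
    "init_args": {"max_lr": 1e-3},
}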
@@ -165,31 +177,31 @@ def configure_optimizers(self) -> list[list[dict[str, Any]]]:
                 epoch_size = self.trainer.datamodule.epoch_size
                 accumulate_grad_batches = self.trainer.accumulate_grad_batches
                 max_epochs = self.trainer.max_epochs
+
                 steps_per_epoch = math.ceil(
                     epoch_size / (batch_size * accumulate_grad_batches),
                 )
                 buffer_steps = int(steps_per_epoch * accumulate_grad_batches)
+
                 scheduler = torch.optim.lr_scheduler.OneCycleLR(
                     optimizer,
                     max_lr=max_lr,
                     steps_per_epoch=steps_per_epoch + buffer_steps,
                     epochs=max_epochs,
                 )
             else:
-                stepping_batches = (
-                    self.hparams.get("scheduler", {})
-                    .get("init_args", {})
-                    .get("total_steps")
-                )
+                total_steps = scheduler_cfg.get("init_args", {}).get("total_steps")
                 scheduler = torch.optim.lr_scheduler.OneCycleLR(
                     optimizer,
                     max_lr=max_lr,
-                    total_steps=stepping_batches,
+                    total_steps=total_steps,
                 )
         else:
             scheduler = self.scheduler(optimizer)

-        return [optimizer], [{"scheduler": scheduler, **self.scheduler_config}]
+        if scheduler:
+            return [optimizer], [{"scheduler": scheduler, **self.scheduler_config}]
+        return [optimizer]

     def forward(self, image: Tensor) -> Tensor:
         """Forward pass."""
@@ -275,6 +287,7 @@ def test_step(
         y_hat = self(x)
         loss = self.loss(y_hat, y)
         y = y.squeeze(1).long()
+
         if self.num_classes == 1:
             y_hat = (y_hat.sigmoid().squeeze(1) > self.threshold).long()
         else:
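
The test-step change keeps the binary and multi-class paths separate before the IoU metrics: a single-logit output is squeezed and thresholded after a sigmoid, while the multi-class branch is cut off in this hunk and presumably reduces over the class dimension. A rough sketch of the two conversions, with shapes and the threshold chosen purely for illustration:

import torch

threshold = 0.5  # assumed value; the module reads self.threshold

# Binary case: (N, 1, H, W) logits -> (N, H, W) integer mask
binary_logits = torch.randn(4, 1, 64, 64)
binary_pred = (binary_logits.sigmoid().squeeze(1) > threshold).long()

# Multi-class case: (N, C, H, W) logits -> (N, H, W) class indices
# (argmax here is an assumption about the branch not shown in the diff)
multi_logits = torch.randn(4, 5, 64, 64)
multi_pred = multi_logits.argmax(dim=1)

# Both forms produce integer label maps matching the squeezed integer targets above.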