1212from torch .optim .lr_scheduler import ReduceLROnPlateau
1313
1414# local imports
15- from . .utils .logging_helpers import plot_tensors
16- from . .utils .model_descriptions import print_model_summary
17- from . .utils .radiometrics import histogram as histogram_match
18- from . .utils .radiometrics import normalise_10k
19- from .model_blocks import ExponentialMovingAverage
15+ from opensr_srgan .utils .logging_helpers import plot_tensors
16+ from opensr_srgan .utils .model_descriptions import print_model_summary
17+ from opensr_srgan .utils .radiometrics import histogram as histogram_match
18+ from opensr_srgan .utils .radiometrics import normalise_10k
19+ from opensr_srgan . model .model_blocks import ExponentialMovingAverage
2020
2121
2222#############################################################################################################
@@ -44,15 +44,17 @@ def __init__(self, config="config.yaml", mode="train"):
4444 # SECTION: Load Configuration
4545 # Purpose: Load and parse model/training hyperparameters from YAML file.
4646 # ======================================================================
47- if isinstance (config , Path ) or isinstance (config , str ):
48- self . config = OmegaConf .load (config ) # load config file with OmegaConf
47+ if isinstance (config , str ) or isinstance (config , Path ):
48+ config = OmegaConf .load (config )
4949 elif isinstance (config , dict ):
50- self . config = OmegaConf .create (config ) # create config from dict
50+ config = OmegaConf .create (config )
5151 elif OmegaConf .is_config (config ):
52- self . config = config # already an OmegaConf object
52+ pass
5353 else :
54- print ( "Invalid config type; must be file path or OmegaConf/ dict." )
54+ raise TypeError ( "Config must be a filepath (str or Path), dict, or OmegaConf object ." )
5555 assert mode in {"train" , "eval" }, "Mode must be 'train' or 'eval'" # validate mode
56+
57+ self .config = config
5658 self .mode = mode # store mode (train/eval)
5759
5860 # --- Training settings ---
@@ -91,7 +93,7 @@ def __init__(self, config="config.yaml", mode="train"):
9193 # Purpose: Configure generator content loss and discriminator adversarial loss.
9294 # ======================================================================
9395 if self .mode == "train" :
94- from .loss import GeneratorContentLoss
96+ from opensr_srgan . model .loss import GeneratorContentLoss
9597 self .content_loss_criterion = GeneratorContentLoss (self .config ) # perceptual loss (VGG + pixel)
9698 self .adversarial_loss_criterion = torch .nn .BCEWithLogitsLoss () # binary cross-entropy for D/G
9799
@@ -109,7 +111,7 @@ def get_models(self, mode):
109111
110112 if generator_type == 'SRResNet' :
111113 # Standard SRResNet generator
112- from .generators .srresnet import Generator
114+ from opensr_srgan . model .generators .srresnet import Generator
113115 self .generator = Generator (
114116 in_channels = self .config .Model .in_bands , # number of input channels
115117 large_kernel_size = self .config .Generator .large_kernel_size ,
@@ -120,7 +122,7 @@ def get_models(self, mode):
120122 )
121123 elif generator_type in ['res' , 'rcab' , 'rrdb' , 'lka' ]:
122124 # Advanced generator variants (ResNet, RCAB, RRDB, etc.)
123- from .generators .flexible_generator import FlexibleGenerator
125+ from opensr_srgan . model .generators .flexible_generator import FlexibleGenerator
124126 self .generator = FlexibleGenerator (
125127 in_channels = self .config .Model .in_bands ,
126128 n_channels = self .config .Generator .n_channels ,
@@ -131,7 +133,7 @@ def get_models(self, mode):
131133 block_type = self .config .Generator .model_type
132134 )
133135 elif generator_type .lower () in ['conditional_cgan' , 'cgan' ]:
134- from .generators import ConditionalGANGenerator
136+ from opensr_srgan . model .generators import ConditionalGANGenerator
135137
136138 self .generator = ConditionalGANGenerator (
137139 in_channels = self .config .Model .in_bands ,
@@ -156,7 +158,7 @@ def get_models(self, mode):
156158 n_blocks = getattr (self .config .Discriminator , 'n_blocks' , None )
157159
158160 if discriminator_type == 'standard' :
159- from .discriminators .srgan_discriminator import Discriminator
161+ from opensr_srgan . model .discriminators .srgan_discriminator import Discriminator
160162
161163 discriminator_kwargs = {
162164 "in_channels" : self .config .Model .in_bands ,
@@ -166,7 +168,7 @@ def get_models(self, mode):
166168
167169 self .discriminator = Discriminator (** discriminator_kwargs )
168170 elif discriminator_type == 'patchgan' :
169- from .discriminators .patchgan import PatchGANDiscriminator
171+ from opensr_srgan . model .discriminators .patchgan import PatchGANDiscriminator
170172
171173 patchgan_layers = n_blocks if n_blocks is not None else 3
172174 self .discriminator = PatchGANDiscriminator (
@@ -198,9 +200,9 @@ def predict_step(self, lr_imgs):
198200 lr_min , lr_max = lr_imgs .min ().item (), lr_imgs .max ().item () # get value range
199201 if lr_max > 1.5 : # Sentinel-2 style raw reflectance → normalize
200202 lr_imgs = normalise_10k (lr_imgs , stage = "norm" ) # normalize to 0–1 range
201- normalized = True
203+ needs_normalization = True
202204 else :
203- normalized = False # already normalized
205+ needs_normalization = False # already normalized
204206
205207 # --- Perform super-resolution (optionally using EMA weights) ---
206208 context = self .ema .average_parameters (self .generator ) if self .ema is not None else nullcontext ()
@@ -211,7 +213,7 @@ def predict_step(self, lr_imgs):
211213 sr_imgs = histogram_match (lr_imgs , sr_imgs ) # match distributions
212214
213215 # --- Denormalize only if normalization was applied ---
214- if normalized :
216+ if needs_normalization :
215217 sr_imgs = normalise_10k (sr_imgs , stage = "denorm" ) # convert back to original scale
216218
217219 # --- Move to CPU and return ---
@@ -300,6 +302,7 @@ def training_step(self,batch,batch_idx,optimizer_idx):
300302 # run discriminator and get loss between pred labels and true labels
301303 sr_discriminated = self .discriminator (sr_imgs ) # D(SR): logits for generator outputs
302304 adversarial_loss = self .adversarial_loss_criterion (sr_discriminated , torch .ones_like (sr_discriminated )) # keep taargets 1.0 for G loss
305+ self .log ("generator/adversarial_loss" ,adversarial_loss ,sync_dist = True ) # log unweighted adversarial loss
303306
304307 """ 3. Weight the losses"""
305308 adv_weight = self ._adv_loss_weight () # get adversarial weight based on current step
@@ -317,8 +320,9 @@ def optimizer_step(
317320 optimizer ,
318321 optimizer_idx ,
319322 optimizer_closure ,
320- on_tpu = False ,
323+ on_tpu = False , # these arguments are needed in case we're running on PL>2.0
321324 using_lbfgs = False ,
325+
322326 ):
323327 optimizer .step (closure = optimizer_closure )
324328 optimizer .zero_grad ()
@@ -485,13 +489,13 @@ def configure_optimizers(self):
485489 optimizer_g , mode = 'min' ,
486490 factor = self .config .Schedulers .factor_g ,
487491 patience = self .config .Schedulers .patience_g ,
488- verbose = self .config .Schedulers .verbose
492+ # verbose=self.config.Schedulers.verbose
489493 )
490494 scheduler_d = ReduceLROnPlateau (
491495 optimizer_d , mode = 'min' ,
492496 factor = self .config .Schedulers .factor_d ,
493497 patience = self .config .Schedulers .patience_d ,
494- verbose = self .config .Schedulers .verbose
498+ # verbose=self.config.Schedulers.verbose
495499 )
496500
497501 # optional generator warmup scheduler (step-based)
@@ -556,7 +560,7 @@ def on_fit_start(self): # called once at the start of training
556560 # SECTION: Print Model Summary
557561 # Purpose: Output model architecture and parameter counts (only once).
558562 # ======================================================================
559- from . .utils .gpu_rank import _is_global_zero
563+ from opensr_srgan .utils .gpu_rank import _is_global_zero
560564 if _is_global_zero ():
561565 print_model_summary (self ) # print model summary to console
562566
@@ -756,8 +760,7 @@ def load_from_checkpoint(self,ckpt_path):
756760
757761
758762if __name__ == "__main__" :
759- config_path = Path (__file__ ).resolve ().parents [1 ] / "configs" / "config_20m.yaml"
760- model = SRGAN_model (config_file_path = str (config_path ))
761- model .forward (torch .randn (1 ,6 ,32 ,32 ))
762-
763- model .load_from_checkpoint ("logs/SRGAN_6bands/2025-10-11_23-53-20/last.ckpt" )
763+ config_path = "opensr_srgan/configs/config_10m.yaml"
764+ model = SRGAN_model (config = str (config_path ))
765+ model .forward (torch .randn (1 ,4 ,32 ,32 ))
766+
0 commit comments