Skip to content

Commit 21a0687

Browse files
authored
Merge pull request #71 from simon-donike/dev
Merge Dev: PL 2 support (exc training)
2 parents b83566d + da0c2b0 commit 21a0687

10 files changed

Lines changed: 368 additions & 80 deletions

File tree

opensr_srgan/configs/config_10m.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ Model:
3333
# ============================================================================ #
3434
# 🏋️ TRAINING CONFIGURATION
3535
# ---------------------------------------------------------------------------- #
36-
Training:
36+
Training:
3737
# --- Hardware Setup
38-
gpus: [0,1,2,3] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
38+
device: "cuda" # Runtime device backend: ['cuda', 'cpu']
39+
gpus: [2,3] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
3940
# --- General Training Setup
4041
max_epochs: 9999 # Maximum number of training epochs
4142
val_check_interval: 1.0 # Validate at x percent of epoch (float) or every N steps (int)
@@ -114,6 +115,6 @@ Schedulers:
114115
Logging:
115116
num_val_images: 5 # Number of validation images logged per epoch
116117
wandb:
117-
enabled: False # Toggle Weights & Biases logging on/off
118+
enabled: True # Toggle Weights & Biases logging on/off
118119
entity: "opensr" # W&B entity or team name
119120
project: "SRGAN_10m" # W&B project name

opensr_srgan/configs/config_20m.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Model:
3535
# ---------------------------------------------------------------------------- #
3636
Training:
3737
# --- Hardware Setup
38+
device: "cuda" # Runtime device backend: ['cuda', 'cpu']
3839
gpus: [2,3] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
3940
# --- General Training Setup
4041
max_epochs: 9999 # Maximum number of training epochs
@@ -113,6 +114,6 @@ Schedulers:
113114
Logging:
114115
num_val_images: 5 # Number of validation images logged per epoch
115116
wandb:
116-
enabled: False # Toggle Weights & Biases logging on/off
117+
enabled: True # Toggle Weights & Biases logging on/off
117118
entity: "opensr" # W&B entity or team name
118119
project: "SRGAN_20m" # W&B project name

opensr_srgan/configs/config_training_example.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ Model:
3333
# ============================================================================ #
3434
# 🏋️ TRAINING CONFIGURATION
3535
# ---------------------------------------------------------------------------- #
36-
Training:
36+
Training:
3737
# --- Hardware Setup
38+
device: "cpu" # Runtime device backend: ['cuda', 'cpu']
3839
gpus: [0] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
3940
# --- General Training Setup
4041
max_epochs: 5 # Maximum number of training epochs

opensr_srgan/data/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

opensr_srgan/data/dataset_selector.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from pathlib import Path
2-
31
def select_dataset(config):
42
"""
53
Build train/val datasets from `config` and wrap them into a LightningDataModule.
@@ -23,6 +21,10 @@ def select_dataset(config):
2321
A tiny DataModule that exposes train/val DataLoaders built from the selected datasets.
2422
"""
2523
dataset_selection = config.Data.dataset_type
24+
25+
# Please Note: The "S2_6b","S2_4b","SISR_WW" settings are leftover from previous versions
26+
# I dont want to delete them in case they are needed again.
27+
# Only the "ExampleDataset" is actively used in the current version.
2628

2729
if dataset_selection == "S2_6b":
2830
# Import here to avoid import costs when other datasets are used elsewhere.
@@ -104,10 +106,10 @@ def select_dataset(config):
104106
path = "example_dataset/"
105107
ds_train = ExampleDataset(folder=path, phase="train")
106108
ds_val = ExampleDataset(folder=path, phase="val")
107-
108109
else:
109110
# Centralized error so unsupported keys fail loudly & clearly.
110-
raise NotImplementedError(f"Dataset {dataset_selection} not implemented")
111+
raise NotImplementedError(f"Dataset {dataset_selection} not implemented!"
112+
f"Add your dataset in data/dataset_selector.py to train on that.")
111113

112114
# Wrap the two datasets into a LightningDataModule with config-driven loader knobs.
113115
pl_datamodule = datamodule_from_datasets(config, ds_train, ds_val)

opensr_srgan/data/example_data/download_example_dataset.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
11
from huggingface_hub import hf_hub_download
2-
import zipfile
3-
import os
2+
import zipfile, os
43

54
def get_example_dataset(out_dir: str = "example_dataset/"):
65
"""Download and extract the example dataset for SRGAN training."""
7-
# make sure the target dir exists
86
os.makedirs(out_dir, exist_ok=True)
97

10-
# download the file from your repo
118
repo_id = "simon-donike/SR-GAN"
129
filename = "example_dataset.zip"
1310

14-
print("Downloading from Hugging Face Hub...")
11+
print("📦 Downloading from Hugging Face Hub...")
1512
zip_path = hf_hub_download(repo_id=repo_id, filename=filename)
1613

17-
# unzip after download
1814
with zipfile.ZipFile(zip_path, "r") as z:
19-
z.extractall(out_dir)
15+
members = z.namelist()
2016

17+
# detect common top-level folder (e.g. "example_data/")
18+
prefix = os.path.commonprefix(members)
19+
if prefix and prefix.endswith("/"):
20+
for member in members:
21+
# strip the prefix
22+
target = member[len(prefix):]
23+
if not target: # skip folder itself
24+
continue
25+
target_path = os.path.join(out_dir, target)
26+
os.makedirs(os.path.dirname(target_path), exist_ok=True)
27+
with z.open(member) as src, open(target_path, "wb") as dst:
28+
dst.write(src.read())
29+
else:
30+
z.extractall(out_dir)
31+
32+
os.remove(zip_path)
2133
print(f"✅ Extracted dataset to: {os.path.abspath(out_dir)}")
22-
23-
# delete the zip file to save space
24-
os.remove(zip_path)

opensr_srgan/model/SRGAN.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
from torch.optim.lr_scheduler import ReduceLROnPlateau
1313

1414
# local imports
15-
from ..utils.logging_helpers import plot_tensors
16-
from ..utils.model_descriptions import print_model_summary
17-
from ..utils.radiometrics import histogram as histogram_match
18-
from ..utils.radiometrics import normalise_10k
19-
from .model_blocks import ExponentialMovingAverage
15+
from opensr_srgan.utils.logging_helpers import plot_tensors
16+
from opensr_srgan.utils.model_descriptions import print_model_summary
17+
from opensr_srgan.utils.radiometrics import histogram as histogram_match
18+
from opensr_srgan.utils.radiometrics import normalise_10k
19+
from opensr_srgan.model.model_blocks import ExponentialMovingAverage
2020

2121

2222
#############################################################################################################
@@ -44,15 +44,17 @@ def __init__(self, config="config.yaml", mode="train"):
4444
# SECTION: Load Configuration
4545
# Purpose: Load and parse model/training hyperparameters from YAML file.
4646
# ======================================================================
47-
if isinstance(config, Path) or isinstance(config, str):
48-
self.config = OmegaConf.load(config) # load config file with OmegaConf
47+
if isinstance(config, str) or isinstance(config, Path):
48+
config = OmegaConf.load(config)
4949
elif isinstance(config, dict):
50-
self.config = OmegaConf.create(config) # create config from dict
50+
config = OmegaConf.create(config)
5151
elif OmegaConf.is_config(config):
52-
self.config = config # already an OmegaConf object
52+
pass
5353
else:
54-
print("Invalid config type; must be file path or OmegaConf/dict.")
54+
raise TypeError("Config must be a filepath (str or Path), dict, or OmegaConf object.")
5555
assert mode in {"train", "eval"}, "Mode must be 'train' or 'eval'" # validate mode
56+
57+
self.config = config
5658
self.mode = mode # store mode (train/eval)
5759

5860
# --- Training settings ---
@@ -91,7 +93,7 @@ def __init__(self, config="config.yaml", mode="train"):
9193
# Purpose: Configure generator content loss and discriminator adversarial loss.
9294
# ======================================================================
9395
if self.mode == "train":
94-
from .loss import GeneratorContentLoss
96+
from opensr_srgan.model.loss import GeneratorContentLoss
9597
self.content_loss_criterion = GeneratorContentLoss(self.config) # perceptual loss (VGG + pixel)
9698
self.adversarial_loss_criterion = torch.nn.BCEWithLogitsLoss() # binary cross-entropy for D/G
9799

@@ -109,7 +111,7 @@ def get_models(self, mode):
109111

110112
if generator_type == 'SRResNet':
111113
# Standard SRResNet generator
112-
from .generators.srresnet import Generator
114+
from opensr_srgan.model.generators.srresnet import Generator
113115
self.generator = Generator(
114116
in_channels=self.config.Model.in_bands, # number of input channels
115117
large_kernel_size=self.config.Generator.large_kernel_size,
@@ -120,7 +122,7 @@ def get_models(self, mode):
120122
)
121123
elif generator_type in ['res', 'rcab', 'rrdb', 'lka']:
122124
# Advanced generator variants (ResNet, RCAB, RRDB, etc.)
123-
from .generators.flexible_generator import FlexibleGenerator
125+
from opensr_srgan.model.generators.flexible_generator import FlexibleGenerator
124126
self.generator = FlexibleGenerator(
125127
in_channels=self.config.Model.in_bands,
126128
n_channels=self.config.Generator.n_channels,
@@ -131,7 +133,7 @@ def get_models(self, mode):
131133
block_type=self.config.Generator.model_type
132134
)
133135
elif generator_type.lower() in ['conditional_cgan', 'cgan']:
134-
from .generators import ConditionalGANGenerator
136+
from opensr_srgan.model.generators import ConditionalGANGenerator
135137

136138
self.generator = ConditionalGANGenerator(
137139
in_channels=self.config.Model.in_bands,
@@ -156,7 +158,7 @@ def get_models(self, mode):
156158
n_blocks = getattr(self.config.Discriminator, 'n_blocks', None)
157159

158160
if discriminator_type == 'standard':
159-
from .discriminators.srgan_discriminator import Discriminator
161+
from opensr_srgan.model.discriminators.srgan_discriminator import Discriminator
160162

161163
discriminator_kwargs = {
162164
"in_channels": self.config.Model.in_bands,
@@ -166,7 +168,7 @@ def get_models(self, mode):
166168

167169
self.discriminator = Discriminator(**discriminator_kwargs)
168170
elif discriminator_type == 'patchgan':
169-
from .discriminators.patchgan import PatchGANDiscriminator
171+
from opensr_srgan.model.discriminators.patchgan import PatchGANDiscriminator
170172

171173
patchgan_layers = n_blocks if n_blocks is not None else 3
172174
self.discriminator = PatchGANDiscriminator(
@@ -198,9 +200,9 @@ def predict_step(self, lr_imgs):
198200
lr_min, lr_max = lr_imgs.min().item(), lr_imgs.max().item() # get value range
199201
if lr_max > 1.5: # Sentinel-2 style raw reflectance → normalize
200202
lr_imgs = normalise_10k(lr_imgs, stage="norm") # normalize to 0–1 range
201-
normalized = True
203+
needs_normalization = True
202204
else:
203-
normalized = False # already normalized
205+
needs_normalization = False # already normalized
204206

205207
# --- Perform super-resolution (optionally using EMA weights) ---
206208
context = self.ema.average_parameters(self.generator) if self.ema is not None else nullcontext()
@@ -211,7 +213,7 @@ def predict_step(self, lr_imgs):
211213
sr_imgs = histogram_match(lr_imgs, sr_imgs) # match distributions
212214

213215
# --- Denormalize only if normalization was applied ---
214-
if normalized:
216+
if needs_normalization:
215217
sr_imgs = normalise_10k(sr_imgs, stage="denorm") # convert back to original scale
216218

217219
# --- Move to CPU and return ---
@@ -300,6 +302,7 @@ def training_step(self,batch,batch_idx,optimizer_idx):
300302
# run discriminator and get loss between pred labels and true labels
301303
sr_discriminated = self.discriminator(sr_imgs) # D(SR): logits for generator outputs
302304
adversarial_loss = self.adversarial_loss_criterion(sr_discriminated, torch.ones_like(sr_discriminated)) # keep taargets 1.0 for G loss
305+
self.log("generator/adversarial_loss",adversarial_loss,sync_dist=True) # log unweighted adversarial loss
303306

304307
""" 3. Weight the losses"""
305308
adv_weight = self._adv_loss_weight() # get adversarial weight based on current step
@@ -317,8 +320,9 @@ def optimizer_step(
317320
optimizer,
318321
optimizer_idx,
319322
optimizer_closure,
320-
on_tpu=False,
323+
on_tpu=False, # these arguments are needed in case we're running on PL>2.0
321324
using_lbfgs=False,
325+
322326
):
323327
optimizer.step(closure=optimizer_closure)
324328
optimizer.zero_grad()
@@ -485,13 +489,13 @@ def configure_optimizers(self):
485489
optimizer_g, mode='min',
486490
factor=self.config.Schedulers.factor_g,
487491
patience=self.config.Schedulers.patience_g,
488-
verbose=self.config.Schedulers.verbose
492+
#verbose=self.config.Schedulers.verbose
489493
)
490494
scheduler_d = ReduceLROnPlateau(
491495
optimizer_d, mode='min',
492496
factor=self.config.Schedulers.factor_d,
493497
patience=self.config.Schedulers.patience_d,
494-
verbose=self.config.Schedulers.verbose
498+
#verbose=self.config.Schedulers.verbose
495499
)
496500

497501
# optional generator warmup scheduler (step-based)
@@ -556,7 +560,7 @@ def on_fit_start(self): # called once at the start of training
556560
# SECTION: Print Model Summary
557561
# Purpose: Output model architecture and parameter counts (only once).
558562
# ======================================================================
559-
from ..utils.gpu_rank import _is_global_zero
563+
from opensr_srgan.utils.gpu_rank import _is_global_zero
560564
if _is_global_zero():
561565
print_model_summary(self) # print model summary to console
562566

@@ -756,8 +760,7 @@ def load_from_checkpoint(self,ckpt_path):
756760

757761

758762
if __name__=="__main__":
759-
config_path = Path(__file__).resolve().parents[1] / "configs" / "config_20m.yaml"
760-
model = SRGAN_model(config_file_path=str(config_path))
761-
model.forward(torch.randn(1,6,32,32))
762-
763-
model.load_from_checkpoint("logs/SRGAN_6bands/2025-10-11_23-53-20/last.ckpt")
763+
config_path = "opensr_srgan/configs/config_10m.yaml"
764+
model = SRGAN_model(config=str(config_path))
765+
model.forward(torch.randn(1,4,32,32))
766+

0 commit comments

Comments
 (0)