Skip to content

Commit f1e816c

Browse files
committed
train_example
1 parent f47601a commit f1e816c

2 files changed

Lines changed: 131 additions & 1 deletion

File tree

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
| **PyPI** | **Versions** | **Docs & License** | **Tests** | **Reference** |
66
|:---------:|:-------------:|:------------------:|:----------:|:--------------:|
7-
| [![PyPI](https://img.shields.io/pypi/v/opensr-srgan)](https://pypi.org/project/opensr-srgan/) | ![PythonVersion](https://img.shields.io/badge/Python-v3.10%20v3.12-blue.svg)<br>![PLVersion](https://img.shields.io/badge/PyTorchLightning-v1.9%20v2.0-blue.svg) | [![Docs](https://img.shields.io/badge/docs-mkdocs%20material-brightgreen)](https://srgan.opensr.eu)<br>![License: Apache](https://img.shields.io/badge/license-Apache%20License%202.0-blue) | [![CI](https://github.com/simon-donike/SISR-RS-SRGAN/actions/workflows/ci.yml/badge.svg)](https://github.com/simon-donike/SISR-RS-SRGAN/actions/workflows/ci.yml)<br>[![codecov](https://codecov.io/gh/simon-donike/SISR-RS-SRGAN/graph/badge.svg?token=PWZND7MHRR)](https://codecov.io/gh/simon-donike/SISR-RS-SRGAN) | [![arXiv](https://img.shields.io/badge/arXiv-2511.10461-b31b1b.svg)](https://arxiv.org/abs/2511.10461)<br>[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17590993.svg)](https://doi.org/10.5281/zenodo.17590993) |
7+
| [![PyPI](https://img.shields.io/pypi/v/opensr-srgan)](https://pypi.org/project/opensr-srgan/) | ![PythonVersion](https://img.shields.io/badge/Python-v3.10%20v3.12-blue.svg)<br>![PLVersion](https://img.shields.io/badge/PyTorchLightning-v1.9%20v2.0-blue.svg) | [![Docs](https://img.shields.io/badge/docs-mkdocs%20material-brightgreen)](https://srgan.opensr.eu)<br>![License: Apache](https://img.shields.io/badge/license-Apache%20License%202.0-blue) | [![CI](https://github.com/simon-donike/SISR-RS-SRGAN/actions/workflows/ci.yml/badge.svg)](https://github.com/simon-donike/SISR-RS-SRGAN/actions/workflows/ci.yml)<br>[![codecov](https://codecov.io/github/ESAOpenSR/SRGAN/graph/badge.svg?token=LQ69MIMLVE)](https://codecov.io/github/ESAOpenSR/SRGAN) | [![arXiv](https://img.shields.io/badge/arXiv-2511.10461-b31b1b.svg)](https://arxiv.org/abs/2511.10461)<br>[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17590993.svg)](https://doi.org/10.5281/zenodo.17590993) |
88

99
![Super-resolved Sentinel-2 example](assets/6band_banner.png)
1010

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# ============================================================================ #
2+
# ⚙️ GENERAL CONFIGURATION FILE
3+
# ---------------------------------------------------------------------------- #
4+
# This file defines all key components for training and evaluating the SR model.
5+
# Sections: Data, Model, Training, Architecture, Optimization, Scheduling, Logging
6+
# ============================================================================ #
7+
8+
9+
# ============================================================================ #
10+
# 🗂️ DATA SETTINGS
11+
# ---------------------------------------------------------------------------- #
12+
Data:
13+
# Loader parameters
14+
train_batch_size: 2 # Batch size for training
15+
val_batch_size: 2 # Batch size for validation
16+
num_workers: 1 # Number of parallel workers for dataloader
17+
prefetch_factor: 2 # Samples prefetched per worker (2 is stable default)
18+
19+
# Dataset configuration
20+
dataset_type: 'SISR_WW' # Choose dataset type: ['cv', 'SPOT6', 'S2_6b', 'SISR_WW']
21+
normalization: 'normalise_10k' # Normalization strategy for data processing
22+
23+
24+
# ============================================================================ #
25+
# 🧠 MODEL AND CHECKPOINT SETTINGS
26+
# ---------------------------------------------------------------------------- #
27+
Model:
28+
in_bands: 4 # Number of input channels (e.g. RGB-NIR-SWIR etc.)
29+
continue_training: False # Resume full training (PL checkpoint path or False)
30+
load_checkpoint: False # Load weights only (path or False)
31+
32+
33+
# ============================================================================ #
34+
# 🏋️ TRAINING CONFIGURATION
35+
# ---------------------------------------------------------------------------- #
36+
Training:
37+
# --- Hardware Setup
38+
device: "cuda" # Runtime device backend: ['cuda', 'cpu']
39+
gpus: [0] # Number of GPUs to use, individually in list form, e.g. [0] or [0,2]
40+
# --- General Training Setup
41+
max_epochs: 9999 # Maximum number of training epochs
42+
val_check_interval: 0.25 # Validate at x percent of epoch (float) or every N steps (int)
43+
limit_val_batches: 250 # Limit number of validation batches
44+
45+
# --- Pretraining and adversarial setup ---
46+
pretrain_g_only: True # Train generator only for initial phase
47+
g_pretrain_steps: 1000 # Number of generator-only warmup steps
48+
adv_loss_ramp_steps: 500 # Gradual adversarial weight ramp steps
49+
label_smoothing: True # Discriminator target smoothing (1.0 → 0.9)
50+
51+
EMA:
52+
enabled: False # Maintain exponential moving average of generator weights
53+
decay: 0.999 # EMA decay factor (closer to 1.0 → smoother updates)
54+
update_after_step: 0 # Delay EMA updates until this global step (0 = immediate)
55+
use_num_updates: True # Use adaptive decay warmup based on number of updates
56+
57+
Losses:
58+
# --- GAN term ---
59+
adv_loss_beta: 0.001 # Final adversarial loss weight after ramp-up - original 0.001
60+
adv_loss_schedule: 'cosine' # Adversarial weight ramp type: ['linear', 'cosine']
61+
62+
# --- Content loss components (GeneratorContentLoss) ---
63+
l1_weight: 1.0 # L1 loss over all bands
64+
sam_weight: 0.05 # Spectral Angle Mapper loss
65+
perceptual_weight: 0.2 # Perceptual similarity term weight
66+
perceptual_metric: 'vgg' # ['vgg', 'lpips'] - LPIPS requires pip install lpips
67+
tv_weight: 0.0 # Total Variation regularization (optional)
68+
69+
# --- Metric evaluation settings ---
70+
max_val: 1.0 # Peak value assumed by PSNR/SSIM after metric preprocessing
71+
ssim_win: 11 # SSIM window size (must be odd integer)
72+
73+
74+
# ============================================================================ #
75+
# 🧩 ARCHITECTURAL PARAMETERS
76+
# ---------------------------------------------------------------------------- #
77+
# See Docs for archtecture details and suggestions
78+
Generator:
79+
model_type: 'SRResNet' # Generator family: ['SRResNet', 'stochastic_gan', 'esrgan']
80+
block_type: 'rrdb' # SRResNet block variant: ['standard', 'res', 'rcab', 'rrdb', 'lka']
81+
large_kernel_size: 9 # Kernel for head and tail conv layers (SRResNet/stochastic)
82+
small_kernel_size: 3 # Kernel for intermediate blocks (SRResNet/stochastic)
83+
n_channels: 32 # Feature width (RRDB/ESRGAN uses this as trunk width)
84+
n_blocks: 4 # Residual/attention blocks (ESRGAN: number of RRDB blocks)
85+
scaling_factor: 4 # Upscaling factor (e.g., 2×, 4×, 8×)
86+
growth_channels: 32 # ESRGAN-specific RRDB growth channels (ignored otherwise)
87+
res_scale: 0.2 # Residual scaling used by stochastic/ESRGAN variants
88+
89+
Discriminator:
90+
model_type: 'standard' # Discriminator architecture selector ['standard', 'patchgan', 'esrgan']
91+
n_blocks: 2 # Convolutional depth for SRGAN/PatchGAN (ignored by ESRGAN)
92+
base_channels: 32 # ESRGAN discriminator base feature width (ignored otherwise)
93+
linear_size: 1024 # Hidden dim of ESRGAN discriminator head (ignored otherwise)
94+
95+
# ============================================================================ #
96+
# 🧮 OPTIMIZATION SETTINGS
97+
# ---------------------------------------------------------------------------- #
98+
Optimizers:
99+
optim_g_lr: 1e-4 # Learning rate for Generator
100+
optim_d_lr: 1e-6 # Learning rate for Discriminator
101+
gradient_clip_val: 1.0 # Gradient clipping value (0 disables clipping)
102+
betas: [0.0, 0.99] # optional
103+
eps: 1.0e-7 # optional
104+
weight_decay_g: 0.0 # optional
105+
weight_decay_d: 0.0 # optional
106+
107+
# ============================================================================ #
108+
# 📉 SCHEDULERS AND EARLY STOPPING
109+
# ---------------------------------------------------------------------------- #
110+
Schedulers:
111+
g_warmup_steps: 10 # Generator warmup LR curve duration in steps (0 disables warmup)
112+
g_warmup_type: 'cosine' # Generator warmup curve: ['cosine', 'linear']
113+
metric_g: 'val_metrics/l1' # Metric monitored for Generator LR scheduler
114+
metric_d: 'discriminator/adversarial_loss' # Metric monitored for Discriminator LR scheduler
115+
patience_g: 10 # Patience (epochs) for Generator LR scheduler
116+
patience_d: 10 # Patience (epochs) for Discriminator LR scheduler
117+
factor_g: 0.5 # LR reduction factor for Generator
118+
factor_d: 0.5 # LR reduction factor for Discriminator
119+
verbose: True # Enable scheduler logging output
120+
121+
122+
# ============================================================================ #
123+
# 🧾 LOGGING SETTINGS
124+
# ---------------------------------------------------------------------------- #
125+
Logging:
126+
num_val_images: 5 # Number of validation images logged per epoch
127+
wandb:
128+
enabled: False # Toggle Weights & Biases logging on/off
129+
entity: "opensr" # W&B entity or team name
130+
project: "SRGAN_10m" # W&B project name

0 commit comments

Comments
 (0)