|
| 1 | +# ============================================================================ # |
| 2 | +# ⚙️ GENERAL CONFIGURATION FILE |
| 3 | +# ---------------------------------------------------------------------------- # |
| 4 | +# This file defines all key components for training and evaluating the SR model. |
| 5 | +# Sections: Data, Model, Training, Architecture, Optimization, Scheduling, Logging |
| 6 | +# ============================================================================ # |
| 7 | + |
| 8 | + |
| 9 | +# ============================================================================ # |
| 10 | +# 🗂️ DATA SETTINGS |
| 11 | +# ---------------------------------------------------------------------------- # |
| 12 | +Data: |
| 13 | + # Loader parameters |
| 14 | + train_batch_size: 2 # Batch size for training |
| 15 | + val_batch_size: 2 # Batch size for validation |
| 16 | + num_workers: 1 # Number of parallel workers for dataloader |
| 17 | + prefetch_factor: 2 # Samples prefetched per worker (2 is stable default) |
| 18 | + |
| 19 | + # Dataset configuration |
| 20 | + dataset_type: 'SISR_WW' # Choose dataset type: ['cv', 'SPOT6', 'S2_6b', 'SISR_WW'] |
| 21 | + normalization: 'normalise_10k' # Normalization strategy for data processing |
| 22 | + |
| 23 | + |
| 24 | +# ============================================================================ # |
| 25 | +# 🧠 MODEL AND CHECKPOINT SETTINGS |
| 26 | +# ---------------------------------------------------------------------------- # |
| 27 | +Model: |
| 28 | + in_bands: 4 # Number of input channels (e.g. RGB-NIR-SWIR etc.) |
| 29 | + continue_training: False # Resume full training (PL checkpoint path or False) |
| 30 | + load_checkpoint: False # Load weights only (path or False) |
| 31 | + |
| 32 | + |
| 33 | +# ============================================================================ # |
| 34 | +# 🏋️ TRAINING CONFIGURATION |
| 35 | +# ---------------------------------------------------------------------------- # |
| 36 | +Training: |
| 37 | + # --- Hardware Setup |
| 38 | + device: "cuda" # Runtime device backend: ['cuda', 'cpu'] |
| 39 | + gpus: [0] # GPU device indices to use, given as a list, e.g. [0] or [0,2] |
| 40 | + # --- General Training Setup |
| 41 | + max_epochs: 9999 # Maximum number of training epochs |
| 42 | + val_check_interval: 0.25 # Validate every x fraction of an epoch (float, e.g. 0.25 = 4 times per epoch) or every N steps (int) |
| 43 | + limit_val_batches: 250 # Limit number of validation batches |
| 44 | + |
| 45 | + # --- Pretraining and adversarial setup --- |
| 46 | + pretrain_g_only: True # Train generator only for initial phase |
| 47 | + g_pretrain_steps: 1000 # Number of generator-only warmup steps |
| 48 | + adv_loss_ramp_steps: 500 # Gradual adversarial weight ramp steps |
| 49 | + label_smoothing: True # Discriminator target smoothing (1.0 → 0.9) |
| 50 | + |
| 51 | + EMA: |
| 52 | + enabled: False # Maintain exponential moving average of generator weights |
| 53 | + decay: 0.999 # EMA decay factor (closer to 1.0 → smoother updates) |
| 54 | + update_after_step: 0 # Delay EMA updates until this global step (0 = immediate) |
| 55 | + use_num_updates: True # Use adaptive decay warmup based on number of updates |
| 56 | + |
| 57 | + Losses: |
| 58 | + # --- GAN term --- |
| 59 | + adv_loss_beta: 0.001 # Final adversarial loss weight after ramp-up - original 0.001 |
| 60 | + adv_loss_schedule: 'cosine' # Adversarial weight ramp type: ['linear', 'cosine'] |
| 61 | + |
| 62 | + # --- Content loss components (GeneratorContentLoss) --- |
| 63 | + l1_weight: 1.0 # L1 loss over all bands |
| 64 | + sam_weight: 0.05 # Spectral Angle Mapper loss |
| 65 | + perceptual_weight: 0.2 # Perceptual similarity term weight |
| 66 | + perceptual_metric: 'vgg' # ['vgg', 'lpips'] - LPIPS requires pip install lpips |
| 67 | + tv_weight: 0.0 # Total Variation regularization (optional) |
| 68 | + |
| 69 | + # --- Metric evaluation settings --- |
| 70 | + max_val: 1.0 # Peak value assumed by PSNR/SSIM after metric preprocessing |
| 71 | + ssim_win: 11 # SSIM window size (must be odd integer) |
| 72 | + |
| 73 | + |
| 74 | +# ============================================================================ # |
| 75 | +# 🧩 ARCHITECTURAL PARAMETERS |
| 76 | +# ---------------------------------------------------------------------------- # |
| 77 | +# See Docs for architecture details and suggestions |
| 78 | +Generator: |
| 79 | + model_type: 'SRResNet' # Generator family: ['SRResNet', 'stochastic_gan', 'esrgan'] |
| 80 | + block_type: 'rrdb' # SRResNet block variant: ['standard', 'res', 'rcab', 'rrdb', 'lka'] |
| 81 | + large_kernel_size: 9 # Kernel for head and tail conv layers (SRResNet/stochastic) |
| 82 | + small_kernel_size: 3 # Kernel for intermediate blocks (SRResNet/stochastic) |
| 83 | + n_channels: 32 # Feature width (RRDB/ESRGAN uses this as trunk width) |
| 84 | + n_blocks: 4 # Residual/attention blocks (ESRGAN: number of RRDB blocks) |
| 85 | + scaling_factor: 4 # Upscaling factor (e.g., 2×, 4×, 8×) |
| 86 | + growth_channels: 32 # ESRGAN-specific RRDB growth channels (ignored otherwise) |
| 87 | + res_scale: 0.2 # Residual scaling used by stochastic/ESRGAN variants |
| 88 | + |
| 89 | +Discriminator: |
| 90 | + model_type: 'standard' # Discriminator architecture selector ['standard', 'patchgan', 'esrgan'] |
| 91 | + n_blocks: 2 # Convolutional depth for SRGAN/PatchGAN (ignored by ESRGAN) |
| 92 | + base_channels: 32 # ESRGAN discriminator base feature width (ignored otherwise) |
| 93 | + linear_size: 1024 # Hidden dim of ESRGAN discriminator head (ignored otherwise) |
| 94 | + |
| 95 | +# ============================================================================ # |
| 96 | +# 🧮 OPTIMIZATION SETTINGS |
| 97 | +# ---------------------------------------------------------------------------- # |
| 98 | +Optimizers: |
| 99 | + optim_g_lr: 1.0e-4 # Learning rate for Generator |
| 100 | + optim_d_lr: 1.0e-6 # Learning rate for Discriminator |
| 101 | + gradient_clip_val: 1.0 # Gradient clipping value (0 disables clipping) |
| 102 | + betas: [0.0, 0.99] # optional |
| 103 | + eps: 1.0e-7 # optional |
| 104 | + weight_decay_g: 0.0 # optional |
| 105 | + weight_decay_d: 0.0 # optional |
| 106 | + |
| 107 | +# ============================================================================ # |
| 108 | +# 📉 SCHEDULERS AND EARLY STOPPING |
| 109 | +# ---------------------------------------------------------------------------- # |
| 110 | +Schedulers: |
| 111 | + g_warmup_steps: 10 # Generator warmup LR curve duration in steps (0 disables warmup) |
| 112 | + g_warmup_type: 'cosine' # Generator warmup curve: ['cosine', 'linear'] |
| 113 | + metric_g: 'val_metrics/l1' # Metric monitored for Generator LR scheduler |
| 114 | + metric_d: 'discriminator/adversarial_loss' # Metric monitored for Discriminator LR scheduler |
| 115 | + patience_g: 10 # Patience (epochs) for Generator LR scheduler |
| 116 | + patience_d: 10 # Patience (epochs) for Discriminator LR scheduler |
| 117 | + factor_g: 0.5 # LR reduction factor for Generator |
| 118 | + factor_d: 0.5 # LR reduction factor for Discriminator |
| 119 | + verbose: True # Enable scheduler logging output |
| 120 | + |
| 121 | + |
| 122 | +# ============================================================================ # |
| 123 | +# 🧾 LOGGING SETTINGS |
| 124 | +# ---------------------------------------------------------------------------- # |
| 125 | +Logging: |
| 126 | + num_val_images: 5 # Number of validation images logged per epoch |
| 127 | + wandb: |
| 128 | + enabled: False # Toggle Weights & Biases logging on/off |
| 129 | + entity: "opensr" # W&B entity or team name |
| 130 | + project: "SRGAN_10m" # W&B project name |
0 commit comments