
Super-Resolution of Medical MNIST Images Using Wasserstein Generative Adversarial Networks and Gradient Penalty (SR-WGAN-GP)

The goal of this project is to develop an interesting, although rudimentary, generative model capable of upscaling 28x28 low-resolution medical scan images (Breast MRIs, Chest X-Rays, ...) to 64x64. The challenge, apart from achieving stability in training, is to obtain a generator model capable of capturing small details (like ribs or knuckles) without losing the human-perceived realism of the images.


📘 Table of Contents

  • 🌟 About the Project
  • 👾 Tech Stack
  • 🎯 Features and Roadmap
  • 🧰 Getting Started
  • 👀 Usage
  • 💎 Acknowledgements and References


🌟 About the Project

The project is structured as follows:

  • Goal: Build and train a Wasserstein GAN (WGAN) to enhance the resolution of Medical MNIST images. Specifically, the project focuses on generating higher-resolution (e.g., 64x64) versions of the original 28x28 Medical MNIST images. This super-resolution task shows how WGANs can be applied to improve image quality while leveraging the Wasserstein distance to stabilize GAN training (along with GP).
  • Expected Implementation Steps: The implementation focuses on developing a WGAN architecture, including both the Generator and the Critic (the WGAN version of the Discriminator), and the Wasserstein loss function with gradient penalty (a sketch of the penalty term follows this list). It can also include pooling layers and/or optimizers like Adam (although they are not the focus of this implementation). The technological stack I would like to use is the usual Python, NumPy, PyTorch (or TensorFlow, depending on which is easier to use with my GPU), TensorBoard (which allows tracking and visualizing metrics such as loss and accuracy), and OpenCV (for image manipulation).
  • Model Evaluation: The quality of the generated images can be evaluated using qualitative (visual inspection) and quantitative (metrics like Peak Signal-to-Noise Ratio, Structural Similarity Index Measure, etc...) methods.
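
As a minimal sketch of the gradient penalty term mentioned above (following the standard WGAN-GP formulation; the function name and tensor shapes are illustrative, not this repo's actual API):

import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    # Interpolate real and fake images with a random factor per sample
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (eps * real + (1 - eps) * fake).requires_grad_(True)

    # Gradient of the critic's scores w.r.t. the interpolated images
    scores = critic(interpolated)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0].view(batch_size, -1)

    # Penalize deviation of the gradient norm from 1
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

The critic loss is then -(critic(real).mean() - critic(fake).mean()) plus this term scaled by lambda_gp (10 in the config shown later).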

👾 Tech Stack

Core Languages and Libraries

For more details, please refer to the pyproject.toml file.

🎯 Features and Roadmap

The following features have been implemented for the current project:

  • Configuration
    • Configuration parser module /config/config_parser.py that reads from /config/config.yaml
    • Load config from CLI
  • Logging
    • Automatic logging through Tensorboard of losses, images and metrics during training
    • Recording logs in separate .csv file
  • Input
    • Type
      • 1 Channel images
      • 3 Channel images (standard RGB)
    • Loading with ad-hoc module
  • Preprocessing
    • Normalization
    • Grayscale
    • Gaussian Noise
    • X/Y Flip
    • Rotation
  • Generator Models
    • Classic GAN Generator
    • WGAN generator
    • Advanced SR WGAN Generator (residual blocks, progressive, etc...)
  • Critic Models
    • Classic GAN Critic
    • WGAN critic with weight clip or gradient penalty
    • Advanced SR WGAN Critic
  • Losses (a combined generator-loss sketch follows this list)
    • Generator
      • Adversarial loss
      • Pixel-wise L1 loss
      • VGG-based perceptual loss
    • Critic
      • Wasserstein distance
      • Gradient penalty
  • Training loop
    • Load custom train and val sets
    • Batch mode
    • Validation phase
    • Save model .pth
    • Save a copy of model config
  • Test
    • Load custom test set
    • Save images (both real and fake)
    • Compute metrics
    • Compute FID of all test images
    • Save .csv file of per-image performance
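
To illustrate how the generator losses listed above could be combined, here is a hedged sketch; the weights mirror lambda_adv_loss and lambda_l1_loss from the config shown later, while the function and the optional VGG feature extractor are illustrative, not this repo's actual code:

import torch.nn as nn

l1 = nn.L1Loss()

def generator_loss(critic, fake, target, lambda_adv=0.5, lambda_l1=0.2, vgg=None):
    # Adversarial term: the generator wants the critic to score fakes highly
    loss = lambda_adv * (-critic(fake).mean())
    # Pixel-wise L1 term against the ground-truth high-resolution image
    loss = loss + lambda_l1 * l1(fake, target)
    # Optional VGG-based perceptual term: L1 distance in feature space
    if vgg is not None:
        loss = loss + l1(vgg(fake), vgg(target))
    return loss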

🧰 Getting Started

🔧 Configuration, Data Positioning and Environment Variables

Before doing anything, make sure your project looks like this, or just create the missing folders yourself.

/
|- config
|- data/    # create if missing, this is where the MedicalMNIST data will be copied
|- logs/    # create if missing, this is where the logs for tensorboard are located
|- models/
|- runs/    # create if missing, this is where the results and model state will be saved
|- src/
|- .flake8
|- .gitignore
|- main.py
|- poetry.toml
|- Project Presentation Slides.pdf
|- Project Presentation.mp4
|- pyproject.toml
|- README.md
|- requirements.txt

Then to ensure we have the data:

  • You must retrieve the dataset from here and unpack it
  • Enter the main folder where the class subfolders are present (BreastMRI, CXR, ...)
  • In each class folder we want to use, split the images: put 90% of them into a new folder named Train and the rest into a folder named Test. Repeat this for every class of interest (a split sketch follows this list).
  • Name the main folder MedicalMNIST and copy it into data/
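
A minimal Python sketch of this 90/10 split, assuming the .jpeg images sit directly inside each class folder (the paths and the 0.9 ratio are assumptions mirroring train_percentage in the config):

import random
import shutil
from pathlib import Path

def split_class_folder(folder: Path, train_ratio: float = 0.9) -> None:
    """Move images into Train/ and Test/ subfolders of `folder`."""
    images = sorted(p for p in folder.iterdir() if p.suffix == ".jpeg")
    random.shuffle(images)
    n_train = int(len(images) * train_ratio)
    for name, subset in (("Train", images[:n_train]), ("Test", images[n_train:])):
        dest = folder / name
        dest.mkdir(exist_ok=True)
        for img in subset:
            shutil.move(str(img), str(dest / img.name))

for cls in ("BreastMRI", "CXR"):  # repeat for every class of interest
    split_class_folder(Path("MedicalMNIST") / cls)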

In the end, the project should look like this:

/
|- config
|- data/    # create if missing, this is where the MedicalMNIST data will be copied
    |- MedicalMNIST/
        |- BreastMRI/
            |- Train/
                |- 000000.jpeg
                |- 000001.jpeg
                ....
            |- Test/
                |- 001000.jpeg
                |- 001001.jpeg
                ....
        |- CXR/ 
            |- Train/
                |- 000000.jpeg
                |- 000001.jpeg
                ....
            |- Test/
                |- 001000.jpeg
                |- 001001.jpeg
                ....
        ...
|- logs/    # create if missing, this is where the logs for tensorboard are located
|- models/
|- runs/    # create if missing, this is where the results and model state will be saved
|- src/
|- .flake8
|- .gitignore
|- main.py
|- poetry.toml
|- Project Presentation Slides.pdf
|- Project Presentation.mp4
|- pyproject.toml
|- README.md
|- requirements.txt

This is needed for the data to be loaded. The very next thing to do is to check the config.yaml file:

logging:
  interval: 5
  image_log_count: 16

data:
  dataset: "CXR"            # Folder to load the train set
  img_channels: 3
  train_percentage: 0.9
  batch_size: 16

preprocessing:              # For now they are ignored, but will be used in the future
  normalize: true
  normalize_mean: 0.5
  normalize_std: 0.5
  rotation: false         
  rotation_deg: 5
  rotation_p: 0.1   
  flip_horizontal: false  
  flip_h_p: 0.1
  flip_vertical: false    
  flip_v_p: 0.1
  gaussian_noise: false 
  noise_percentage: 0.1
  gaussian_p: 0.1

training:
  epochs: 50

validating:
  interval: 5

generator:
  mode: "classic"
  lr: 0.00001
  beta_1: 0.0
  beta_2: 0.9
  lambda_adv_loss: 0.5
  lambda_l1_loss: 0.2

critic:
  mode: "gp"                      # classic (weight clip) or gp
  lr: 0.00001
  beta_1: 0.0
  beta_2: 0.9
  lambda_wasserstrein: 1.0
  lambda_gp: 10
  weight_clip: 0.01           
  c_iter: 5

Now that the project is correctly configured, we can move on. In the future we will add a CLI and a way to validate the configuration.
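
For reference, the config can be read with plain PyYAML; this is a generic sketch, not the repo's /config/config_parser.py module:

import yaml  # pip install pyyaml

with open("config/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["data"]["dataset"])      # "CXR"
print(cfg["critic"]["lambda_gp"])  # 10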

❗ Prerequisites

First of all, install Python 3.12 (preferably 3.12.8) and create a virtual environment. Then activate the environment and use either requirements.txt or Poetry 1.8.5 to install the dependencies.

# Poetry
(.venv) pip install poetry==1.8.5
# ...
(.venv) poetry install

# Or use pip
(.venv) pip install -r requirements.txt

Once the virtual environment is ready, if you are using a GPU, make sure it is compatible with the versions of PyTorch and related packages we use (or just use the CPU).
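
A quick way to verify that PyTorch actually sees the GPU before training:

import torch

# Falls back to CPU if no compatible CUDA build or driver is found
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device, torch.cuda.get_device_name(0) if device.type == "cuda" else "")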


👀 Usage

We have two things we can do with our project: Train and Test

🚀 Training

After settling on a configuration, which we can modify through the config.yaml file, we can run our project with

python models/train.py

This will start the execution and create the following folders:

/
...
|- logs/
    |- [MODEL NAME]/
        |- fake/
        |- metrics/
        |- real/
...
|- runs/
    |- [MODEL NAME]/
        |- config.yaml
...

Where [MODEL NAME] is the auto-generated name for the current run (e.g., wgan_gp_CXR_2025-06-21_18-38-03).

To see how the training is going we can just use tensorboard where everything is plotted for us:

tensorboard --logdir logs/[MODEL NAME]/

# Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
# TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
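
For context, logs in this format can be produced with torch.utils.tensorboard; a minimal sketch, not the repo's actual logging module:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/my_model")
# During training: scalars for losses, image grids for real/fake batches
writer.add_scalar("loss/critic", 0.42, global_step=1)    # illustrative values
writer.add_scalar("loss/generator", 1.37, global_step=1)
# writer.add_images("fake", fake_batch, global_step=1)   # NCHW tensor in [0, 1]
writer.close()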

Once the training is finished, we find the model state in that folder:

/
...
|- runs/
    |- [MODEL NAME]/
        |- config.yaml
        |- model.pth    # model state
...

🔬 Testing

Now it is time to run the tests. We just need to find the folder in which our model state has been saved and run the following

python models/test.py runs/[MODEL NAME]/

This generates a benchmark.csv inside the folder, along with an output/ folder where the generated and test images are copied. The result looks like this:

/
...
|- runs/
    |- [MODEL NAME]/
        |- output/
            |- generated/
                |- 00000.jpeg
                |- 00001.jpeg
                ....
            |- original/
                |- 00000.jpeg
                |- 00001.jpeg
                ....
        |- benchmark.csv
        |- config.yaml
        |- model.pth    # model state
...

It also prints a summary over the whole test set:

======== Test Set Evaluation ========
MSE: 0.0016
RMSE: 0.0398
PSNR: 28.1763
SSIM: 0.9245
LPIPS: 0.1074
Fid: 34.15489524220507
Saved metrics to: runs/wgan_gp_CXR_2025-06-21_18-38-03/benchmark.csv
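
Per-image metrics like these can be computed with torchmetrics; a minimal sketch assuming tensors scaled to [0, 1], independent of the repo's actual evaluation code:

import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

fake = torch.rand(1, 1, 64, 64)  # stand-ins for generated / ground-truth images
real = torch.rand(1, 1, 64, 64)
print(f"PSNR: {psnr(fake, real):.4f}")
print(f"SSIM: {ssim(fake, real):.4f}")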

💎 Acknowledgements and References

The following references have been the core foundations for us to develop the project:

Please feel free to contact me if anything is wrong, incorrect, or not cited/used properly. I will immediately remove or modify anything that goes against copyright or any other guideline I might be overlooking.
