SAM Fine-tuning for Medical Image Segmentation

A PyTorch implementation for fine-tuning the Segment Anything Model (SAM) on medical images, specifically designed for gastrointestinal endoscopy image segmentation. This project unfreezes both the image encoder and mask decoder while keeping the prompt encoder frozen.

🎯 Project Overview

This repository provides a complete pipeline for:

  • Fine-tuning SAM on custom medical datasets
  • Training with both single-GPU and multi-GPU distributed training
  • Comprehensive data preprocessing and augmentation
  • Model evaluation and visualization tools

📁 Project Structure

Core Training Files

  • model.py - Main GISAM model class that combines SAM components with custom training logic
  • train_one_GPU.py - Single GPU training script with visualization support
  • train_multi_GPU.py - Multi-GPU distributed training script using PyTorch DDP
  • test.py - Model evaluation script with comprehensive metrics and visualization
  • catalog_dataset.py - Custom PyTorch dataset class for loading images, masks, and bounding boxes
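The catalog-driven loading in catalog_dataset.py can be pictured with a minimal sketch. The split-file line format (`relative/path label`) and the parallel images/masks layout are assumptions for illustration; the real class subclasses torch.utils.data.Dataset and additionally loads the mask image and its bounding boxes:

```python
from pathlib import Path

class CatalogSplitIndex:
    """Minimal sketch of a split-file-backed sample index.

    Assumes each split-file line looks like "data1/img_001.png 0"
    (relative image path, integer label) and that masks mirror the
    image directory layout. The repository's catalog_dataset.py
    subclasses torch.utils.data.Dataset and also returns the mask
    and its bounding boxes for the prompt encoder.
    """

    def __init__(self, split_lines, images_root="images", masks_root="masks"):
        self.samples = []
        for line in split_lines:
            line = line.strip()
            if not line:
                continue
            rel_path, label = line.rsplit(" ", 1)
            self.samples.append((
                Path(images_root) / rel_path,   # image file
                Path(masks_root) / rel_path,    # mask shares the relative path
                int(label),
            ))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # A real Dataset would read and transform the image/mask here.
        return self.samples[idx]
```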

Data Organization

/images/ and /masks/

images/                     # Training images organized by categories
├── data1/                 # Category 1 images
├── data2/                 # Category 2 images  
├── data3/                 # Category 3 images
└── data4/                 # Category 4 images

masks/                      # Corresponding segmentation masks
├── data1/                 # Category 1 masks
├── data2/                 # Category 2 masks
├── data3/                 # Category 3 masks
└── data4/                 # Category 4 masks

/meta/

  • catalog_image.xlsx - Excel catalog of all training images with metadata
  • catalog_mask.xlsx - Excel catalog of all masks with connected component information
  • prepare_image_dataset_excel.py - Script to generate image catalog from directory structure
  • prepare_mask_dataset_excel.py - Script to analyze masks and generate mask catalog with bounding boxes
  • make_splits.py - Script to create train/validation/test splits from catalogs
  • /splits/ - Generated dataset splits
    • meta_with_split.xlsx - Complete dataset with split assignments
    • train.txt - Training set file paths and labels
    • val.txt - Validation set file paths and labels
    • test.txt - Test set file paths and labels
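The splitting step can be sketched as a seeded shuffle-and-cut. The 80/10/10 ratio and fixed seed here are illustrative choices, not necessarily what meta/make_splits.py actually uses:

```python
import random

def make_splits(filenames, ratios=(0.8, 0.1, 0.1), seed=42):
    """Shuffle filenames and cut them into train/val/test lists.

    The 80/10/10 ratio and the fixed seed are illustrative; the
    repository's make_splits.py may split per category or by other rules.
    """
    assert abs(sum(ratios) - 1.0) < 1e-9
    files = list(filenames)
    random.Random(seed).shuffle(files)      # seeded for reproducible splits
    n_train = int(len(files) * ratios[0])
    n_val = int(len(files) * ratios[1])
    return (files[:n_train],
            files[n_train:n_train + n_val],
            files[n_train + n_val:])
```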

SAM Model Implementation

/SAM_model/

Core SAM model components with modifications for fine-tuning:

  • __init__.py - Package initialization
  • build_sam.py - SAM model factory for different variants (ViT-B, ViT-L, ViT-H)
  • predictor.py - SAM predictor class for inference
  • automatic_mask_generator.py - Automatic mask generation utilities

/SAM_model/modeling/

Neural network architecture implementations:

  • sam.py - Main SAM architecture combining all components
  • image_encoder.py - Vision Transformer (ViT) image encoder with relative position encoding
  • mask_decoder.py - Lightweight mask decoder with transformer-based cross-attention
  • prompt_encoder.py - Prompt encoder for points, boxes, and mask prompts
  • transformer.py - Two-way transformer for mask-image attention
  • common.py - Shared components (MLP blocks, LayerNorm2d)

/SAM_model/utils/

  • transforms.py - Image preprocessing and augmentation utilities
  • amg.py - Automatic mask generation helpers
  • onnx.py - ONNX export utilities

Utility Scripts

/utils/

  • data_alian.py - Data alignment and preprocessing utilities
  • merge_neopolyp_to_binary.py - Script to merge multi-class masks to binary
  • reorginaze_picture.py - Image organization and renaming utilities
  • split.py - Additional data splitting utilities
  • tif_to_png.py - Image format conversion (TIFF to PNG)
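The format-conversion step can be sketched with Pillow; the actual utils/tif_to_png.py may use a different library (e.g. OpenCV) and naming scheme:

```python
from pathlib import Path
from PIL import Image

def tif_to_png(src_dir, dst_dir):
    """Convert every .tif/.tiff under src_dir to a PNG in dst_dir.

    Sketch of the conversion step only; the repository's
    utils/tif_to_png.py may differ in library choice and file handling.
    """
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    converted = []
    for tif in sorted(Path(src_dir).glob("*.tif*")):
        out = dst / (tif.stem + ".png")
        Image.open(tif).save(out)   # output format inferred from the .png extension
        converted.append(out)
    return converted
```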

🚀 Quick Start

1. Environment Setup

pip install torch torchvision
pip install opencv-python scikit-image
pip install pandas openpyxl
pip install matplotlib tqdm

2. Data Preparation

# Organize your images and masks in the required structure
# Generate image catalog
python meta/prepare_image_dataset_excel.py

# Generate mask catalog with bounding boxes
python meta/prepare_mask_dataset_excel.py

# Create train/val/test splits
python meta/make_splits.py

3. Training

Single GPU Training

python train_one_GPU.py

Multi-GPU Training

# For 2 GPUs
python -m torch.distributed.launch --nproc_per_node=2 train_multi_GPU.py

# Note: torch.distributed.launch is deprecated in recent PyTorch releases;
# on newer versions the equivalent launcher is:
# torchrun --nproc_per_node=2 train_multi_GPU.py

Key Training Parameters

  • Model Architecture: Supports ViT-B, ViT-L, ViT-H variants
  • Parameter Freezing: Prompt encoder frozen; image encoder and mask decoder trainable
  • Loss Function: Combined Dice loss + Binary Cross Entropy
  • Optimization: AdamW with weight decay
  • Mixed Precision: Optional AMP support for faster training
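The combined objective can be written out numerically. Below is a NumPy sketch of the Dice + BCE loss on probability maps; the training code itself operates on torch tensors, and the equal weighting and smoothing constant here are illustrative assumptions:

```python
import numpy as np

def dice_bce_loss(pred, target, bce_weight=0.5, smooth=1e-6):
    """Combined Dice + binary cross-entropy loss on probability maps.

    pred and target are float arrays in [0, 1] of the same shape.
    The 0.5/0.5 weighting and the smoothing constant are illustrative;
    the actual training scripts may weight the terms differently.
    """
    pred = np.clip(pred, 1e-7, 1 - 1e-7)        # avoid log(0)
    bce = -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))
    intersection = np.sum(pred * target)
    dice = (2.0 * intersection + smooth) / (np.sum(pred) + np.sum(target) + smooth)
    return bce_weight * bce + (1 - bce_weight) * (1.0 - dice)
```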

4. Evaluation

python test.py --model_path checkpoints/best_model.pth --test_txt meta/splits/test.txt
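Per-image IoU and Dice are the usual evaluation metrics for this task; exactly which metrics test.py reports is assumed here, as is the 0.5 binarization threshold:

```python
import numpy as np

def segmentation_metrics(pred_prob, gt_mask, threshold=0.5):
    """Per-image IoU and Dice for a binary segmentation.

    Thresholding predictions at 0.5 is a common convention; the exact
    metric set that test.py computes is an assumption for illustration.
    """
    pred = pred_prob >= threshold
    gt = gt_mask.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    iou = inter / union if union else 1.0                # two empty masks agree
    denom = pred.sum() + gt.sum()
    dice = 2 * inter / denom if denom else 1.0
    return float(iou), float(dice)
```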

🔧 Model Architecture

GISAM (Gastrointestinal SAM)

The fine-tuned model consists of three main components:

  1. Image Encoder (Trainable)
    • Vision Transformer (ViT) backbone
    • Processes input images into feature embeddings
  2. Prompt Encoder (Frozen)
    • Encodes bounding-box prompts
    • Preserves the pre-trained prompt understanding
  3. Mask Decoder (Trainable)
    • Generates segmentation masks
    • Uses cross-attention between image and prompt features

Training Strategy

  • Selective Fine-tuning: Only image encoder and mask decoder are updated
  • Prompt Guidance: Uses bounding box annotations as prompts
  • Multi-scale Training: Supports various input resolutions
  • Data Augmentation: Includes geometric and photometric augmentations
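Selective fine-tuning boils down to toggling requires_grad before building the optimizer. A sketch with stand-in modules follows; the attribute names mirror SAM's convention (image_encoder, prompt_encoder, mask_decoder), but treat the exact model API and the hyperparameters as assumptions:

```python
import torch
import torch.nn as nn

# Stand-in modules; in the repository these would be SAM's actual
# image encoder, prompt encoder, and mask decoder.
class TinySAM(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(8, 8)
        self.prompt_encoder = nn.Linear(8, 8)
        self.mask_decoder = nn.Linear(8, 1)

model = TinySAM()

# Freeze the prompt encoder; leave encoder and decoder trainable.
for p in model.prompt_encoder.parameters():
    p.requires_grad = False

# Hand only trainable parameters to AdamW (lr and weight decay illustrative).
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-5, weight_decay=0.01)
```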

📊 Dataset Format

Image Catalog (catalog_image.xlsx)

Column        Description
folder_name   Image category/class
filename      Image filename
width         Image width
height        Image height
channels      Number of channels

Mask Catalog (catalog_mask.xlsx)

Column           Description
folder_name      Mask category
filename         Mask filename
num_components   Number of connected components
bboxes           Bounding boxes for each component
areas            Area of each component
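The per-component analysis behind catalog_mask.xlsx can be sketched without dependencies using a 4-connectivity flood fill. The real prepare_mask_dataset_excel.py likely uses OpenCV or scikit-image, and the (x_min, y_min, x_max, y_max) bbox convention here is an assumption:

```python
from collections import deque

def analyze_mask(mask):
    """Find 4-connected foreground components in a binary mask.

    mask is a list of rows of 0/1 values. Returns one dict per component
    with its pixel area and bounding box as (x_min, y_min, x_max, y_max).
    Both the connectivity and the bbox convention are illustrative choices.
    """
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    components = []
    for y in range(h):
        for x in range(w):
            if mask[y][x] and not seen[y][x]:
                # BFS flood fill over this component.
                queue = deque([(y, x)])
                seen[y][x] = True
                area, x0, y0, x1, y1 = 0, x, y, x, y
                while queue:
                    cy, cx = queue.popleft()
                    area += 1
                    x0, x1 = min(x0, cx), max(x1, cx)
                    y0, y1 = min(y0, cy), max(y1, cy)
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx),
                                   (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w \
                                and mask[ny][nx] and not seen[ny][nx]:
                            seen[ny][nx] = True
                            queue.append((ny, nx))
                components.append({"area": area, "bbox": (x0, y0, x1, y1)})
    return components
```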

🔍 Key Features

  • Distributed Training: Multi-GPU support with PyTorch DDP
  • Mixed Precision: Automatic Mixed Precision (AMP) for faster training
  • Comprehensive Logging: Training metrics, loss curves, and model checkpoints
  • Flexible Data Loading: Custom collate functions for variable-sized annotations
  • Visualization Tools: Built-in visualization for training progress and results
  • Modular Design: Easy to extend for different medical imaging tasks

📈 Monitoring and Visualization

The training scripts provide:

  • Real-time loss monitoring
  • Learning rate scheduling
  • Model checkpoint saving
  • Training/validation curves
  • Sample prediction visualization
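A minimal sketch of per-epoch metric logging of the kind listed above; the column set and CSV layout are assumptions, since the training scripts may log differently:

```python
import csv
from pathlib import Path

def log_epoch(log_path, epoch, train_loss, val_loss, lr):
    """Append one epoch's metrics to a CSV log, writing the header once.

    The column set (epoch, train_loss, val_loss, lr) is illustrative;
    training/validation curves can be plotted from this file afterwards.
    """
    path = Path(log_path)
    write_header = not path.exists()
    with path.open("a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["epoch", "train_loss", "val_loss", "lr"])
        writer.writerow([epoch, f"{train_loss:.4f}", f"{val_loss:.4f}", f"{lr:.2e}"])
```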

🏥 Medical Imaging Applications

This implementation is particularly suitable for:

  • Endoscopy image segmentation
  • Polyp detection and segmentation
  • Gastrointestinal lesion analysis
  • Multi-class medical image segmentation

📝 Citation

If you use this code in your research, please cite the original SAM paper:

@article{kirillov2023segany,
  title={Segment Anything},
  author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
  journal={arXiv:2304.02643},
  year={2023}
}

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙋‍♂️ Support

If you encounter any issues or have questions, please open an issue in this repository.
