A PyTorch implementation for fine-tuning the Segment Anything Model (SAM) on medical images, specifically designed for gastrointestinal endoscopy image segmentation. This project unfreezes both the image encoder and mask decoder while keeping the prompt encoder frozen.
This repository provides a complete pipeline for:
- Fine-tuning SAM on custom medical datasets
- Training with both single-GPU and multi-GPU distributed training
- Comprehensive data preprocessing and augmentation
- Model evaluation and visualization tools
- `model.py` - Main GISAM model class that combines SAM components with custom training logic
- `train_one_GPU.py` - Single-GPU training script with visualization support
- `train_multi_GPU.py` - Multi-GPU distributed training script using PyTorch DDP
- `test.py` - Model evaluation script with comprehensive metrics and visualization
- `catalog_dataset.py` - Custom PyTorch dataset class for loading images, masks, and bounding boxes
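For orientation, here is a minimal sketch of what a catalog-driven dataset in the spirit of `catalog_dataset.py` might look like; the field names, file formats, and tensor layout are assumptions for illustration, not the repo's actual interface:

```python
# Hypothetical sketch of a catalog-driven dataset; names are illustrative.
import cv2
import torch
from torch.utils.data import Dataset

class CatalogDataset(Dataset):
    def __init__(self, image_paths, mask_paths, boxes_per_image):
        self.image_paths = image_paths          # list of image file paths
        self.mask_paths = mask_paths            # list of mask file paths
        self.boxes_per_image = boxes_per_image  # per-image list of XYXY boxes

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the image as RGB and the mask as a single grayscale channel.
        image = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        return {
            "image": torch.from_numpy(image).permute(2, 0, 1).float(),  # (3, H, W)
            "mask": torch.from_numpy(mask).float() / 255.0,             # (H, W)
            "boxes": torch.tensor(self.boxes_per_image[idx], dtype=torch.float32),
        }
```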
The data directory is organized as follows:

```
images/        # Training images organized by categories
├── data1/     # Category 1 images
├── data2/     # Category 2 images
├── data3/     # Category 3 images
└── data4/     # Category 4 images
masks/         # Corresponding segmentation masks
├── data1/     # Category 1 masks
├── data2/     # Category 2 masks
├── data3/     # Category 3 masks
└── data4/     # Category 4 masks
```
- `catalog_image.xlsx` - Excel catalog of all training images with metadata
- `catalog_mask.xlsx` - Excel catalog of all masks with connected component information
- `prepare_image_dataset_excel.py` - Script to generate the image catalog from the directory structure
- `prepare_mask_dataset_excel.py` - Script to analyze masks and generate the mask catalog with bounding boxes
- `make_splits.py` - Script to create train/validation/test splits from the catalogs
- `splits/` - Generated dataset splits
  - `meta_with_split.xlsx` - Complete dataset with split assignments
  - `train.txt` - Training set file paths and labels
  - `val.txt` - Validation set file paths and labels
  - `test.txt` - Test set file paths and labels
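For illustration, a minimal sketch of how such splits could be produced from the image catalog (in the spirit of `make_splits.py`); the split ratios, seed, and line format are assumptions:

```python
# Hypothetical split generation; 80/10/10 ratios and seed are illustrative.
import os
import pandas as pd

df = pd.read_excel("meta/catalog_image.xlsx")
df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)  # shuffle

n_train, n_val = int(0.8 * len(df)), int(0.9 * len(df))
df["split"] = "test"
df.loc[:n_train - 1, "split"] = "train"
df.loc[n_train:n_val - 1, "split"] = "val"

os.makedirs("meta/splits", exist_ok=True)
df.to_excel("meta/splits/meta_with_split.xlsx", index=False)
for name, part in df.groupby("split"):
    paths = part["folder_name"] + "/" + part["filename"]
    with open(f"meta/splits/{name}.txt", "w") as f:
        f.write("\n".join(paths))
```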
Core SAM model components with modifications for fine-tuning:
- `__init__.py` - Package initialization
- `build_sam.py` - SAM model factory for different variants (ViT-B, ViT-L, ViT-H)
- `predictor.py` - SAM predictor class for inference
- `automatic_mask_generator.py` - Automatic mask generation utilities
Neural network architecture implementations:
- `sam.py` - Main SAM architecture combining all components
- `image_encoder.py` - Vision Transformer (ViT) image encoder with relative position encoding
- `mask_decoder.py` - Lightweight mask decoder with transformer-based cross-attention
- `prompt_encoder.py` - Prompt encoder for points, boxes, and mask prompts
- `transformer.py` - Two-way transformer for mask-image attention
- `common.py` - Shared components (MLP blocks, LayerNorm2d)
Supporting utilities:

- `transforms.py` - Image preprocessing and augmentation utilities
- `amg.py` - Automatic mask generation helpers
- `onnx.py` - ONNX export utilities
Standalone data preparation scripts:

- `data_alian.py` - Data alignment and preprocessing utilities
- `merge_neopolyp_to_binary.py` - Script to merge multi-class masks to binary
- `reorginaze_picture.py` - Image organization and renaming utilities
- `split.py` - Additional data splitting utilities
- `tif_to_png.py` - Image format conversion (TIFF to PNG)
```bash
pip install torch torchvision
pip install opencv-python scikit-image
pip install pandas openpyxl
pip install matplotlib tqdm
```

Prepare the dataset catalogs and splits:

```bash
# Organize your images and masks in the required structure

# Generate image catalog
python meta/prepare_image_dataset_excel.py

# Generate mask catalog with bounding boxes
python meta/prepare_mask_dataset_excel.py

# Create train/val/test splits
python meta/make_splits.py
```

Train on a single GPU:

```bash
python train_one_GPU.py
```

Or launch distributed training:

```bash
# For 2 GPUs
python -m torch.distributed.launch --nproc_per_node=2 train_multi_GPU.py
```

- Model Architecture: Supports ViT-B, ViT-L, and ViT-H variants
- Frozen/Trainable Components: Prompt encoder frozen; image encoder and mask decoder trainable
- Loss Function: Combined Dice loss + Binary Cross Entropy
- Optimization: AdamW with weight decay
- Mixed Precision: Optional AMP support for faster training
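As a rough illustration of this configuration, here is a minimal sketch assuming the factory in `segment_anything/build_sam.py` matches the upstream `sam_model_registry`; the hyperparameter values and checkpoint path are placeholders, not the repo's defaults:

```python
# Sketch of selective freezing, AdamW setup, and the combined Dice + BCE loss.
import torch
import torch.nn.functional as F
from segment_anything import sam_model_registry

# Build SAM (ViT-B variant shown; checkpoint path is illustrative).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Selective fine-tuning: freeze only the prompt encoder.
for p in sam.prompt_encoder.parameters():
    p.requires_grad = False

# AdamW with weight decay over the trainable parameters
# (image encoder + mask decoder).
optimizer = torch.optim.AdamW(
    (p for p in sam.parameters() if p.requires_grad),
    lr=1e-5, weight_decay=0.01,
)

def seg_loss(logits, targets, eps=1.0):
    """Combined soft Dice + binary cross-entropy loss on mask logits."""
    probs = torch.sigmoid(logits)
    num = 2 * (probs * targets).sum(dim=(-2, -1)) + eps
    den = probs.sum(dim=(-2, -1)) + targets.sum(dim=(-2, -1)) + eps
    dice = 1 - (num / den).mean()
    return dice + F.binary_cross_entropy_with_logits(logits, targets)
```

For the optional AMP support, the forward pass and loss would run under `torch.cuda.amp.autocast` with a `GradScaler` handling the backward step.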
Evaluate a trained checkpoint on the test split:

```bash
python test.py --model_path checkpoints/best_model.pth --test_txt meta/splits/test.txt
```
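A minimal sketch of per-image Dice and IoU, the kind of metrics an evaluation script like this typically reports; the exact metric set and threshold here are assumptions:

```python
# Hypothetical per-image Dice/IoU computation on binarized predictions.
import torch

def dice_and_iou(pred_logits, gt_mask, thresh=0.5, eps=1e-6):
    pred = (torch.sigmoid(pred_logits) > thresh).float()
    inter = (pred * gt_mask).sum()
    union = pred.sum() + gt_mask.sum()
    dice = (2 * inter + eps) / (union + eps)
    iou = (inter + eps) / (union - inter + eps)
    return dice.item(), iou.item()
```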
The fine-tuned model consists of three main components:

- **Image Encoder** (trainable)
  - Vision Transformer backbone
  - Processes input images to feature embeddings
- **Prompt Encoder** (frozen)
  - Encodes bounding box prompts
  - Maintains pre-trained prompt understanding
- **Mask Decoder** (trainable)
  - Generates segmentation masks
  - Uses cross-attention between image and prompt features
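As a sketch of how these components compose in one forward pass, following the upstream segment-anything module signatures (this repo's modified versions may differ); `sam` is the model built in the training-setup sketch above:

```python
import torch

images = torch.randn(1, 3, 1024, 1024)                 # preprocessed input batch
boxes = torch.tensor([[100.0, 150.0, 400.0, 480.0]])   # one XYXY box prompt

# 1. Image encoder (trainable): images -> feature embeddings.
image_embeddings = sam.image_encoder(images)

# 2. Prompt encoder (frozen): box prompts -> sparse/dense embeddings.
with torch.no_grad():
    sparse, dense = sam.prompt_encoder(points=None, boxes=boxes, masks=None)

# 3. Mask decoder (trainable): cross-attention between image and prompt features.
low_res_masks, iou_pred = sam.mask_decoder(
    image_embeddings=image_embeddings,
    image_pe=sam.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse,
    dense_prompt_embeddings=dense,
    multimask_output=False,
)
```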
- Selective Fine-tuning: Only image encoder and mask decoder are updated
- Prompt Guidance: Uses bounding box annotations as prompts
- Multi-scale Training: Supports various input resolutions
- Data Augmentation: Includes geometric and photometric augmentations
Columns in `catalog_image.xlsx`:

| Column | Description |
|---|---|
| folder_name | Image category/class |
| filename | Image filename |
| width | Image width in pixels |
| height | Image height in pixels |
| channels | Number of channels |
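A hypothetical sketch of how such a catalog could be generated (in the spirit of `prepare_image_dataset_excel.py`); only the column names come from the table above:

```python
# Walk the images/ tree, read each file, and record its metadata.
import os
import cv2
import pandas as pd

rows = []
for folder in sorted(os.listdir("images")):
    for fname in sorted(os.listdir(os.path.join("images", folder))):
        img = cv2.imread(os.path.join("images", folder, fname))
        if img is None:  # skip unreadable files
            continue
        h, w, c = img.shape
        rows.append({"folder_name": folder, "filename": fname,
                     "width": w, "height": h, "channels": c})

pd.DataFrame(rows).to_excel("meta/catalog_image.xlsx", index=False)
```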
Columns in `catalog_mask.xlsx`:

| Column | Description |
|---|---|
| folder_name | Mask category |
| filename | Mask filename |
| num_components | Number of connected components |
| bboxes | Bounding boxes for each component |
| areas | Area of each component |
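A hypothetical sketch of the per-mask connected-component analysis behind such a catalog (in the spirit of `prepare_mask_dataset_excel.py`), using scikit-image, which is already a dependency; the bounding-box format shown is an assumption:

```python
# Summarize one binary mask as a catalog row.
import cv2
from skimage.measure import label, regionprops

def catalog_one_mask(folder, fname, path):
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    props = regionprops(label(mask > 0))  # one entry per connected component
    return {
        "folder_name": folder,
        "filename": fname,
        "num_components": len(props),
        # regionprops bboxes are (min_row, min_col, max_row, max_col)
        "bboxes": [p.bbox for p in props],
        "areas": [int(p.area) for p in props],
    }
```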
- Distributed Training: Multi-GPU support with PyTorch DDP
- Mixed Precision: Automatic Mixed Precision (AMP) for faster training
- Comprehensive Logging: Training metrics, loss curves, and model checkpoints
- Flexible Data Loading: Custom collate functions for variable-sized annotations (see the sketch after this list)
- Visualization Tools: Built-in visualization for training progress and results
- Modular Design: Easy to extend for different medical imaging tasks
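A minimal sketch of a collate function for variable-sized annotations: fixed-size images and masks are stacked into batch tensors, while each sample's box prompts stay in a Python list because their count varies per image. The names follow the hypothetical dataset sketch earlier, not the repo's actual code:

```python
import torch

def collate_variable_boxes(batch):
    return {
        "image": torch.stack([s["image"] for s in batch]),  # (B, 3, H, W)
        "mask": torch.stack([s["mask"] for s in batch]),    # (B, H, W)
        "boxes": [s["boxes"] for s in batch],               # list of (Ni, 4)
    }
```

It would be passed to the loader via `DataLoader(dataset, batch_size=..., collate_fn=collate_variable_boxes)`.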
The training scripts provide:
- Real-time loss monitoring
- Learning rate scheduling
- Model checkpoint saving
- Training/validation curves
- Sample prediction visualization
This implementation is particularly suitable for:
- Endoscopy image segmentation
- Polyp detection and segmentation
- Gastrointestinal lesion analysis
- Multi-class medical image segmentation
If you use this code in your research, please cite the original SAM paper:
```bibtex
@article{kirillov2023segany,
  title={Segment Anything},
  author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
  journal={arXiv:2304.02643},
  year={2023}
}
```

Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
If you encounter any issues or have questions, please open an issue in this repository.