Fundus Disease Detection — Model Development Report

This document describes the dataset selection, preprocessing, model architecture, training methodology, and evaluation of the AI model for retinal fundus disease detection, and outlines future directions.


Dataset Selection

Primary Dataset: AMDNet23

  • A high-quality, publicly available dataset specifically curated for Age-related Macular Degeneration detection.
  • Access: AMDNet23 Dataset
  • Images are already cropped, preprocessed, and CLAHE-enhanced.
  • Includes 4 classes: AMD, Cataract, Diabetes, Normal.

Image Preprocessing Pipeline

1. Gaussian Blur Enhancement

  • Applied a Gaussian blur to every AMDNet23 image as an additional enhancement step, intended to improve edge contrast.
  • The blur was done with OpenCV, because torchvision.transforms.GaussianBlur proved computationally expensive.

2. Training-Time Augmentations

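from torchvision import transforms

# ImageNet statistics matching the pretrained weights.
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds = [0.229, 0.224, 0.225]
pretrained_size = 300  # assumed value: EfficientNet-B3's nominal input resolution

train_transforms = transforms.Compose([
    transforms.Resize(pretrained_size),       # shorter side -> pretrained_size
    transforms.RandomRotation(20),            # rotate by up to ±20 degrees
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.CenterCrop(pretrained_size),   # square crop at the model input size
    transforms.ToTensor(),
    transforms.Normalize(mean=pretrained_means, std=pretrained_stds),
])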
  • These augmentations encourage generalization and robustness to orientation and color variation.
  • Inputs are normalized with the ImageNet mean and standard deviation that the pretrained weights expect.

Model Architecture

  • Base Model: EfficientNet-B3
  • Why EfficientNet-B3?
    • Outperformed ResNet50 and EfficientNet-B0 on this task.
    • Better feature representation, particularly for medical images with subtle class differences.
  • Final classification head modified for 4-class output.

[Figure: EfficientNet-B3 architecture diagram]

Training Setup

| Component      | Choice            |
|----------------|-------------------|
| Loss Function  | CrossEntropyLoss  |
| Optimizer      | Adam              |
| LR Scheduler   | ReduceLROnPlateau |
| Early Stopping | Enabled           |
| Batch Size     | 32                |
| Epochs         | 50                |
| Framework      | PyTorch           |
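The components listed above fit together roughly as in the following sketch. The learning rate, scheduler factor, both patience values, and the `train` helper are assumptions for illustration; the report does not state them. A tiny linear model on synthetic data stands in for EfficientNet-B3 and the fundus images so the loop runs quickly.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, train_loader, val_loader, max_epochs=50, lr=1e-3, patience=5):
    """Adam + CrossEntropyLoss, ReduceLROnPlateau, early stopping on validation loss."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=2
    )
    best_loss, bad_epochs, best_state, history = float("inf"), 0, None, []

    for _ in range(max_epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            criterion(model(x), y).backward()
            optimizer.step()

        # Sample-weighted mean validation loss for this epoch.
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                total += criterion(model(x), y).item() * len(y)
                count += len(y)
        val_loss = total / count
        history.append(val_loss)
        scheduler.step(val_loss)  # lowers the LR when val_loss plateaus

        if val_loss < best_loss:
            best_loss, bad_epochs = val_loss, 0
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            bad_epochs += 1
            if bad_epochs >= patience:  # early stopping
                break

    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best checkpoint
    return history


# Synthetic 4-class data, batch size 32 as in the table.
torch.manual_seed(0)
features = torch.randn(128, 8)
labels = torch.randint(0, 4, (128,))
train_loader = DataLoader(TensorDataset(features[:96], labels[:96]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(features[96:], labels[96:]), batch_size=32)
history = train(nn.Linear(8, 4), train_loader, val_loader)
```

With early stopping enabled, 50 is an upper bound on the number of epochs rather than the number actually run.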

Evaluation on Test Set (from AMDNet23)

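| Class    | Precision | Recall | F1-Score | Support |
|----------|-----------|--------|----------|---------|
| amd      | 0.98      | 1.00   | 0.99     | 100     |
| cataract | 1.00      | 0.99   | 0.99     | 100     |
| diabetes | 0.99      | 0.95   | 0.97     | 100     |
| normal   | 0.96      | 0.99   | 0.98     | 100     |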

Overall Metrics:

  • Accuracy: 98%
  • Macro Avg F1: 0.98
  • Weighted Avg F1: 0.98
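The overall figures follow from the per-class rows. The short check below uses only the values reported in the table; the variable names are illustrative. Because every class has the same support, accuracy equals the mean recall and the weighted F1 equals the macro F1.

```python
# Per-class values copied from the evaluation table above.
report = {
    "amd":      {"precision": 0.98, "recall": 1.00, "f1": 0.99, "support": 100},
    "cataract": {"precision": 1.00, "recall": 0.99, "f1": 0.99, "support": 100},
    "diabetes": {"precision": 0.99, "recall": 0.95, "f1": 0.97, "support": 100},
    "normal":   {"precision": 0.96, "recall": 0.99, "f1": 0.98, "support": 100},
}

total = sum(row["support"] for row in report.values())

# Macro average: unweighted mean of the per-class F1 scores.
macro_f1 = sum(row["f1"] for row in report.values()) / len(report)

# Weighted average: per-class F1 weighted by class support.
weighted_f1 = sum(row["f1"] * row["support"] for row in report.values()) / total

# recall * support is the number of correctly classified images in each class,
# so the sum divided by the total is the overall accuracy.
accuracy = sum(row["recall"] * row["support"] for row in report.values()) / total

print(f"accuracy={accuracy:.4f} macro_f1={macro_f1:.4f} weighted_f1={weighted_f1:.4f}")
```

All three come out at about 0.9825, which rounds to the reported 98% and 0.98.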



Explainability: Grad-CAM Integration

  • Grad-CAM maps are generated from the last convolutional block of EfficientNet-B3 (model.features[-1]).
  • They highlight the image regions that contribute most to the predicted class.
  • Each heatmap is displayed in the UI alongside its prediction.

[Figure: Grad-CAM heatmap example]

Future Work

1. Improve Generalization:

  • Train on larger, more diverse datasets such as EyePACS and Messidor, as well as other real-world fundus collections.

2. Enhance Attention with CBAM:

  • Integrate CBAM (Convolutional Block Attention Module) into EfficientNet-B3.
  • Expected to sharpen class-specific attention, which should help with fine-grained features in diseases such as AMD and diabetic retinopathy (DR).
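For orientation, a CBAM block can be sketched as below: channel attention from a shared MLP over average- and max-pooled features, followed by spatial attention from a convolution over channel-wise average and max maps. The reduction ratio of 16 and the 7x7 kernel are common defaults assumed here, and where such a block would be placed inside EfficientNet-B3 is still an open design choice.

```python
import torch
from torch import nn


class ChannelAttention(nn.Module):
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        # Shared MLP, written as 1x1 convolutions over pooled (N, C, 1, 1) tensors.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(x.mean(dim=(2, 3), keepdim=True))
        mx = self.mlp(x.amax(dim=(2, 3), keepdim=True))
        return torch.sigmoid(avg + mx)  # (N, C, 1, 1) channel weights


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = x.mean(dim=1, keepdim=True)
        mx = x.amax(dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))  # (N, 1, H, W)


class CBAM(nn.Module):
    """Channel attention followed by spatial attention; output shape equals input shape."""

    def __init__(self, channels: int, reduction: int = 16, kernel_size: int = 7):
        super().__init__()
        self.channel = ChannelAttention(channels, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel(x)
        return x * self.spatial(x)


features = torch.rand(2, 64, 8, 8)  # stand-in for an intermediate feature map
refined = CBAM(channels=64)(features)
```

Because the block preserves tensor shape, it can be inserted after a convolutional stage without changing the layers that follow.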

3. On-Device Inference:

  • Explore deployment on mobile or edge devices.
  • Target efficient models capable of real-time diagnosis on low-resource hardware.

Trained Model Output

  • Model exported as: model.pth
  • Load with:
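import torch
# model.pth holds the full pickled model, so PyTorch >= 2.6 needs weights_only=False; load trusted files only.
model = torch.load("model.pth", map_location="cpu", weights_only=False)
model.eval()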

Summary

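| Component                      | Choice / Outcome                                            |
|--------------------------------|-------------------------------------------------------------|
| Model                          | EfficientNet-B3                                             |
| Dataset                        | AMDNet23 (preprocessed fundus dataset)                      |
| Preprocessing and augmentation | Gaussian blur preprocessing + training-time augmentations   |
| Accuracy                       | 98% on test set                                             |
| Explainability                 | Grad-CAM integrated                                         |
| Future Focus                   | CBAM attention + mobile inference                           |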

Reproducibility

The training pipeline used to build this model is available in:

training/fundus_notebook.ipynb

It includes all preprocessing, augmentations, architecture setup, training loop, and evaluation logic.