A deep learning solution for plant pathology classification to identify plant diseases from leaf images.
This project trains an image classifier that assigns each leaf image to one of four conditions:
- Healthy: healthy plant leaves
- Multiple Diseases: leaves showing symptoms of more than one disease
- Rust: leaves affected by rust disease
- Scab: leaves affected by scab disease
The classifier is a DeiT3 (Data-efficient image Transformers) Vision Transformer, fine-tuned in two phases: first the classification head alone, then the whole network.
- State-of-the-art Architecture: DeiT3 Vision Transformer with about 83% top-1 accuracy on ImageNet-1k
- Two-Phase Fine-tuning: train the classification head on frozen features, then fine-tune all layers
- Advanced Augmentation: image augmentation pipeline built with Albumentations
- Smart Data Loading: one loader function for labeled (train/validation) and unlabeled (test) data
- Advanced Training: early stopping, learning-rate scheduling, and AUROC tracking
- Flexible Design: easy to customize and extend for different use cases
- Python 3.8+
- PyTorch 1.9+
- CUDA-compatible GPU (recommended)
# Clone the repository
git clone https://github.com/yourusername/plant-pathology.git
cd plant-pathology
# Create virtual environment
python -m venv plant_venv
source plant_venv/bin/activate # On Windows: plant_venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt

# Or install the dependencies individually:
pip install torch torchvision
pip install timm albumentations
pip install torchmetrics tqdm
pip install opencv-python matplotlib pandas
pip install jupyter scikit-learn

import pandas as pd
from plant_pathology_data import plant_data_loader
# Load training data
train_df = pd.read_csv('train.csv')
train_df['image_path'] = 'images/' + train_df['image_id'] + '.jpg'
# Split into train/validation, stratified by class
from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(
    train_df,
    test_size=0.2,
    stratify=train_df['target'],  # expects a 'target' column holding the class label
    random_state=42
)

# Create data loaders
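# (train_transform, val_transform and test_transform are Albumentations pipelines,
# such as the training pipeline shown later in this README)
train_loader, val_loader = plant_data_loader(
    df=train_df,
    val_df=val_df,
    batch_size=16,
    train_transform=train_transform,
    val_transform=val_transform
)

# For test data (no labels); test_df is built from test.csv the same way as train_df
test_loader = plant_data_loader(
    df=test_df,
    batch_size=16,
    train_transform=test_transform,
    is_test=True
)

from trainer_V2 import Trainer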
from timm import create_model
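# Create model: DeiT3-Small (ImageNet-22k pretrained, ImageNet-1k fine-tuned) with a new 4-class head
model = create_model(
    'deit3_small_patch16_224.fb_in22k_ft_in1k',
    pretrained=True,
    num_classes=4
)

# Phase 1: train only the classification head
# (freeze_feature_extractor, optimizer and scheduler are not shown in this README)
model = freeze_feature_extractor(model)
trainer = Trainer(model, train_loader, val_loader, optimizer, scheduler)
history = trainer.fit(epochs=30)

# Phase 2: unfreeze and fine-tune all layers
for param in model.parameters():
    param.requires_grad = True
# If the Phase 1 optimizer only received the head's parameters, create a new
# optimizer (and scheduler) here so the backbone weights are updated as well
trainer = Trainer(model, train_loader, val_loader, optimizer, scheduler)
history = trainer.fit(epochs=30)

from predictor_V2 import Predictor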
import torch

# Load trained model weights
model.load_state_dict(torch.load('state_dict', map_location='cpu'))

# Make predictions (test_images: preprocessed test images, not shown here)
predictor = Predictor(model)
probabilities = predictor.predict_proba(test_images)
predictions = predictor.predict(test_images)

plant_pathology/
├── Data & Models
│   ├── images/                  # Training and test images
│   ├── train.csv                # Training data with labels
│   ├── test.csv                 # Test data for predictions
│   ├── sample_submission.csv    # Submission format example
│   └── state_dict               # Trained model weights
│
├── Core Modules
│   ├── plant_pathology_data.py  # Data loading and preprocessing
│   ├── trainer_V2.py            # Training loop and validation
│   └── predictor_V2.py          # Prediction and inference
│
├── Jupyter Notebooks
│   ├── 1_check_images.ipynb     # Data exploration and visualization
│   └── 2_training_model.ipynb   # Model training and evaluation
│
└── Additional Files
    ├── train_df.pkl             # Preprocessed training data
    ├── test_df.pkl              # Preprocessed test data
    └── .gitignore               # Git ignore patterns
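The train/validation split in the usage code stratifies on `train_df['target']`. If `train.csv` follows the original Kaggle Plant Pathology 2020 layout (an `image_id` column plus one-hot `healthy`, `multiple_diseases`, `rust` and `scab` columns; that layout is an assumption here), a single integer label could be derived as in this minimal sketch. `add_target_column` and `LABEL_COLUMNS` are hypothetical names, not part of `plant_pathology_data.py`.

```python
import pandas as pd

# ASSUMED one-hot label columns of train.csv (Kaggle Plant Pathology 2020 layout).
# The order matches the class list at the top of this README: indices 0..3.
LABEL_COLUMNS = ["healthy", "multiple_diseases", "rust", "scab"]


def add_target_column(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with an integer 'target' column (hypothetical helper)."""
    out = df.copy()
    # idxmax(axis=1) returns, for each row, the name of the column holding the 1.
    winning_column = out[LABEL_COLUMNS].idxmax(axis=1)
    out["target"] = winning_column.map({name: i for i, name in enumerate(LABEL_COLUMNS)})
    return out


# Toy rows in the assumed layout; real data would come from pd.read_csv('train.csv').
demo = pd.DataFrame({
    "image_id": ["Train_0", "Train_1", "Train_2"],
    "healthy": [0, 1, 0],
    "multiple_diseases": [0, 0, 0],
    "rust": [0, 0, 1],
    "scab": [1, 0, 0],
})
print(add_target_column(demo)[["image_id", "target"]])
```

Mapping column names to fixed indices keeps the label order stable across train, validation and test, which matters when probabilities are later written to a submission file column by column.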
- Base Model: DeiT3 Small Patch16 224
- Input Size: 224×224×3 RGB images
- Output: 4-class classification with probability scores
- Pretrained: on ImageNet-22k, then fine-tuned on ImageNet-1k
- Performance: about 83% top-1 accuracy on ImageNet-1k (pretrained backbone, before fine-tuning on leaf images)
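To make the input and output shapes concrete: a patch-16 ViT cuts a 224×224 image into a 14×14 grid of patches and prepends one class token, and the 4-class head turns the final class-token embedding into logits that a softmax maps to probabilities. The NumPy sketch below walks through that arithmetic. The logits are made up, the class order is assumed to follow the list at the top of this README, and whether `Predictor.predict_proba` applies exactly this softmax is an assumption.

```python
import numpy as np

# ASSUMED class order, following the README's class list.
CLASS_NAMES = ["healthy", "multiple_diseases", "rust", "scab"]


def vit_token_count(image_size: int = 224, patch_size: int = 16, num_prefix_tokens: int = 1) -> int:
    """Tokens seen by a ViT: one per non-overlapping patch plus prefix (class) tokens."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    return patches_per_side ** 2 + num_prefix_tokens


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


logits = np.array([[2.0, 0.1, -1.0, 0.5]])  # made-up head output for one image
probs = softmax(logits)                      # shape (1, 4); each row sums to 1
pred = probs.argmax(axis=1)                  # index of the most likely class

print(vit_token_count())          # 197 = 14 * 14 patches + 1 class token
print(probs.round(3))
print(CLASS_NAMES[int(pred[0])])  # healthy
```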
- Phase 1: Head Training
  - Freeze the feature-extractor backbone
  - Train only the classification head
  - Converges quickly because the pretrained features stay fixed
- Phase 2: Full Fine-tuning
  - Unfreeze all layers
  - Fine-tune the entire model end to end
  - Further improves performance over head-only training
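`freeze_feature_extractor` is not shown in this README. The following is a minimal sketch of the same idea, assuming the classifier lives in a submodule named `head` (timm's ViT-family models, DeiT3 included, use that name, and `model.get_classifier()` returns it). `TinyClassifier`, `set_backbone_trainable` and `make_optimizer` are hypothetical names; the toy model lets the example run without downloading weights, and the learning rates are placeholders, not the project's values.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Toy stand-in for the DeiT3 model: a 'backbone' followed by a 'head'."""

    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 64))
        self.head = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))


def set_backbone_trainable(model: nn.Module, trainable: bool, head_name: str = "head") -> None:
    """Freeze or unfreeze every parameter outside the head; the head always trains."""
    for name, param in model.named_parameters():
        is_head = name.startswith(head_name + ".")
        param.requires_grad = True if is_head else trainable


def make_optimizer(model: nn.Module, lr: float) -> torch.optim.Optimizer:
    """Adam over the currently trainable parameters only."""
    return torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)


def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = TinyClassifier()

# Phase 1: head only (placeholder learning rate)
set_backbone_trainable(model, trainable=False)
phase1_optimizer = make_optimizer(model, lr=1e-3)
print("phase 1 trainable parameters:", count_trainable(model))  # 260

# Phase 2: unfreeze everything and rebuild the optimizer so it sees all parameters
set_backbone_trainable(model, trainable=True)
phase2_optimizer = make_optimizer(model, lr=1e-5)
print("phase 2 trainable parameters:", count_trainable(model))  # 6532
```

The optimizer is rebuilt for Phase 2 because Adam only updates the parameters it was constructed with; flipping `requires_grad` alone does not add the backbone weights to an optimizer that was created with the head's parameters only.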
- Optimizer: Adam with configurable learning rate
- Scheduler: ReduceLROnPlateau with configurable patience
- Loss Function: CrossEntropyLoss
- Metrics: AUROC (Area Under the ROC Curve)
- Early Stopping: configurable patience
- Data Augmentation: flips, rotation, color/brightness/gamma changes, and blur via Albumentations
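The early-stopping rule can be sketched in plain Python. This `EarlyStopping` class is hypothetical (the project's `Trainer` internals are not shown); it assumes the monitored metric is validation AUROC, where higher is better, and treats `patience` like the `early_patience` argument of `Trainer`. `min_delta` is an extra knob added only for illustration.

```python
class EarlyStopping:
    """Stop training when a higher-is-better metric (e.g. validation AUROC) stops improving.

    Hypothetical helper: `patience` mirrors Trainer's `early_patience`;
    `min_delta` is an assumed extra knob for illustration.
    """

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def step(self, score: float) -> bool:
        """Record one epoch's validation score; return True when training should stop."""
        if score > self.best + self.min_delta:
            self.best = score
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=3)
val_aurocs = [0.80, 0.85, 0.86, 0.86, 0.855, 0.859, 0.90]  # made-up scores
for epoch, auroc in enumerate(val_aurocs, start=1):
    if stopper.step(auroc):
        # prints: stopping after epoch 6; best AUROC 0.860
        print(f"stopping after epoch {epoch}; best AUROC {stopper.best:.3f}")
        break
```

In PyTorch, `torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=...)` is stepped with the same metric via `scheduler.step(val_auroc)`. Giving the scheduler a smaller patience than early stopping lets the learning rate drop at least once before training halts.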
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.VerticalFlip(p=0.4),
    A.HorizontalFlip(p=0.4),
    A.Rotate(limit=15, p=0.4),
    # Apply one of these photometric transforms with probability 0.3
    A.OneOf([
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
        # limits are fractions of the full range (0.2 = up to ±20%)
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
    ], p=0.3),
    A.Blur(blur_limit=3, p=0.3),
    A.Resize(height=224, width=224),
    # ImageNet mean/std, matching the pretrained backbone
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

The model achieves competitive performance through the two-phase fine-tuning approach:
- Phase 1 (Head-only): Quick convergence with frozen backbone
- Phase 2 (Full fine-tuning): Further improvement by fine-tuning all layers
- Final Performance: High AUROC scores on validation set
Training curves show consistent improvement in both loss and AUROC metrics across epochs.
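For four classes, AUROC is presumably the one-vs-rest, macro-averaged variant; that averaging choice is an assumption. The NumPy sketch below computes it from the definition: for each class, the probability that a randomly chosen positive image scores higher than a randomly chosen negative one, with ties counted as half. `binary_auroc` and `macro_auroc` are illustrative names; in practice a library implementation such as `torchmetrics.AUROC(task="multiclass", num_classes=4)` or scikit-learn's `roc_auc_score(..., multi_class="ovr")` would be used.

```python
import numpy as np


def binary_auroc(scores: np.ndarray, labels: np.ndarray) -> float:
    """P(score of a random positive > score of a random negative), ties counting 1/2."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUROC needs at least one positive and one negative example")
    diff = pos[:, None] - neg[None, :]  # every positive/negative pair, O(n_pos * n_neg)
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def macro_auroc(probs: np.ndarray, targets: np.ndarray) -> float:
    """One-vs-rest AUROC per class, averaged; probs has shape (N, C), targets (N,)."""
    per_class = [
        binary_auroc(probs[:, c], (targets == c).astype(int))
        for c in range(probs.shape[1])
    ]
    return float(np.mean(per_class))


# Made-up validation outputs: 6 images, 4 classes (rows sum to 1).
probs = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.50, 0.20, 0.10],
    [0.10, 0.10, 0.70, 0.10],
    [0.10, 0.10, 0.20, 0.60],
    [0.60, 0.10, 0.20, 0.10],
    [0.30, 0.10, 0.20, 0.40],
])
targets = np.array([0, 1, 2, 3, 0, 2])
print(f"macro AUROC: {macro_auroc(probs, targets):.3f}")  # macro AUROC: 0.953
```

AUROC depends only on how images are ranked, not on the absolute probabilities, so it stays informative when classes such as Multiple Diseases are rare and plain accuracy would be dominated by the common classes.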

# Customize training parameters
trainer = Trainer(
    model, train_loader, val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    early_patience=15,  # Early stopping patience
    early_stop=True,    # Enable/disable early stopping
    cutmix=True,        # Enable CutMix augmentation
    cutmix_prob=0.5     # CutMix probability
)

- DeiT3 Paper for the Vision Transformer architecture
- Albumentations for image augmentation
- timm for model implementations
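The `cutmix` and `cutmix_prob` options of `Trainer` in the customization example refer to CutMix (Yun et al., 2019): a rectangle from one image is pasted into another and the labels are mixed in proportion to the pasted area. The NumPy sketch below shows the standard recipe under stated assumptions: channels-last batches of shape (N, H, W, C), one-hot labels, a Beta(α, α) draw for the mixing ratio, and `prob` applied once per batch. It is not the `Trainer`'s implementation, and `cutmix` here is a hypothetical function name.

```python
from typing import Optional, Tuple

import numpy as np


def cutmix(
    images: np.ndarray,
    labels: np.ndarray,
    alpha: float = 1.0,
    prob: float = 0.5,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray, float]:
    """CutMix for a batch: images (N, H, W, C), one-hot labels (N, K).

    Returns mixed images, mixed (soft) labels, and lam, the fraction of each
    original image that was kept.
    """
    rng = rng if rng is not None else np.random.default_rng()
    if rng.random() >= prob:  # skip CutMix for this batch
        return images, labels, 1.0

    n, h, w = images.shape[:3]
    perm = rng.permutation(n)     # partner image for every sample
    lam = rng.beta(alpha, alpha)  # target fraction of the original to keep
    cut_h = int(h * np.sqrt(1.0 - lam))
    cut_w = int(w * np.sqrt(1.0 - lam))
    cy, cx = rng.integers(0, h), rng.integers(0, w)  # box centre
    y1, y2 = np.clip([cy - cut_h // 2, cy + cut_h // 2], 0, h)
    x1, x2 = np.clip([cx - cut_w // 2, cx + cut_w // 2], 0, w)

    mixed = images.copy()
    mixed[:, y1:y2, x1:x2] = images[perm, y1:y2, x1:x2]
    # Recompute lam from the clipped box so labels match the pixels actually pasted.
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed, mixed_labels, float(lam)


rng = np.random.default_rng(0)
batch = np.stack([np.full((8, 8, 3), float(i)) for i in range(4)])  # image i is filled with i
onehot = np.eye(4)
mixed, mixed_labels, lam = cutmix(batch, onehot, prob=1.0, rng=rng)
print(lam, mixed_labels.round(2))
```

With soft labels like these, PyTorch's `nn.CrossEntropyLoss` accepts class-probability targets from version 1.10 on; on older versions the equivalent loss is `lam * CE(output, y_a) + (1 - lam) * CE(output, y_b)`.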