Deepfake detection using patch-level teacher-student distillation with quantization auditing.
This project implements a scalable deepfake detection system based on:
- Teacher-Student Architecture: A large pretrained LaDeDa teacher model distills knowledge to a lightweight TinyLaDeDa student
- Patch-based Analysis: Processes images as patches to identify localized deepfake artifacts
- Top-K Pooling: Aggregates patch-level predictions for final classification
- Quantization Auditing: Rigorous testing of quantized models for edge deployment
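The Top-K pooling step above can be sketched in a few lines; the function and tensor names here are illustrative, not the repository's actual API (the real implementation lives in `models/pooling.py`):

```python
import torch

def topk_pool(patch_logits: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Aggregate per-patch logits (B, N) into one image-level score (B,)
    by averaging the k highest-scoring patches."""
    k = min(k, patch_logits.shape[1])
    topk_vals, _ = torch.topk(patch_logits, k, dim=1)
    return topk_vals.mean(dim=1)

patch_logits = torch.randn(2, 196)  # e.g. a 14x14 patch grid per image
image_scores = topk_pool(patch_logits, k=10)
```

Averaging only the most suspicious patches lets a few localized artifacts dominate the image-level decision, which is the point of patch-based detection.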
```
deepfake-patch-audit/
├── config/                    # Configuration files
│   ├── base.yaml              # Architecture contract
│   ├── dataset.yaml           # Dataset paths and splits
│   ├── train.yaml             # Training hyperparameters
│   └── quant.yaml             # Quantization settings
│
├── data/                      # Data directory
│   ├── splits/                # CSV files for train/val/test
│   └── samples/               # Debug sample images
│
├── datasets/                  # Data loading
│   ├── base_dataset.py        # Base dataset class
│   ├── frame_dataset.py       # Frame-based loader
│   └── transforms.py          # Preprocessing (normalize, JPEG-shift)
│
├── models/                    # Model implementations
│   ├── teacher/
│   │   ├── ladeda_wrapper.py  # Pretrained LaDeDa
│   │   └── patch_adapter.py   # Patch grid enforcement
│   │
│   ├── student/
│   │   ├── tiny_ladeda.py     # Lightweight student
│   │   └── blocks.py          # Conv/Bottleneck blocks
│   │
│   └── pooling.py             # Top-K aggregation
│
├── losses/                    # Training losses
│   └── distillation.py        # KD loss (MSE + BCE)
│
├── inference/                 # Prediction pipeline
│   ├── pipeline.py            # Main inference
│   └── heatmap.py             # Patch visualization
│
├── training/                  # Training utilities
│   ├── train_student.py       # Student training loop
│   └── eval_loop.py           # Validation/testing
│
├── evaluation/                # Metrics and analysis
│   ├── metrics.py             # AUC, Accuracy@τ
│   ├── threshold.py           # Threshold tuning
│   └── distribution_shift.py  # JPEG robustness tests
│
├── quantization/              # Quantization tools
│   ├── dynamic_range.py       # Preferred quantization
│   ├── full_int8.py           # Optional (harder)
│   └── audit.py               # Float vs quant comparison
│
├── federated/                 # Federated learning (boxed extension)
│   ├── client.py              # Client trainer
│   ├── server.py              # FedProx aggregator
│   └── simulation.py          # Federated simulation
│
├── scripts/                   # Utility scripts
│   ├── audit_teacher.py       # Verify patch alignment
│   ├── audit_student.py       # Check student structure
│   ├── export_tflite.py       # Export to TFLite
│   └── run_eval.py            # One-command evaluation
│
├── tests/                     # Unit tests
│   ├── test_patch_alignment.py
│   ├── test_topk_pooling.py
│   ├── test_quant_consistency.py
│   └── test_inference_contract.py
│
├── outputs/                   # Results directory
│   ├── checkpoints/           # Model weights
│   ├── logs/                  # Training logs
│   └── heatmaps/              # Visualizations
│
├── requirements.txt           # Python dependencies
├── setup.py                   # Package setup
└── README.md                  # This file
```
```bash
# Install dependencies
pip install -r requirements.txt

# Or install as package
pip install -e .
```

Check if the pretrained teacher is informative on your dataset:

```bash
python scripts/diagnose_teacher.py
# Output: Teacher AUC on validation set
```

Interpretation:
- AUC > 0.65: Teacher is good → skip fine-tuning, proceed to student training
- AUC 0.55-0.65: Teacher is marginal → consider fine-tuning
- AUC < 0.55: Teacher is bad → MUST fine-tune or skip distillation
If teacher AUC < 0.55, fine-tune it on your training data:
```bash
python scripts/train_teacher.py \
    --epochs 50 \
    --batch-size 32 \
    --lr 0.0001 \
    --patience 5 \
    --device cuda
# Output: weights/teacher/teacher_finetuned_best.pth
```

Features:
- Unfreezes only last layers (preserves pretrained knowledge)
- BCE loss for binary classification
- Early stopping on validation AUC
- Saves best checkpoint
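The "unfreeze only the last layers" idea can be sketched as follows; the `ToyTeacher` model and `unfreeze_last_blocks` helper are illustrative stand-ins, not the repository's actual code (the real teacher is the pretrained LaDeDa wrapper, and early stopping would additionally track validation AUC):

```python
import torch.nn as nn

def unfreeze_last_blocks(model: nn.Module, blocks: list) -> None:
    """Freeze all parameters, then unfreeze only the named submodules
    (e.g. the final block and the classifier head)."""
    for p in model.parameters():
        p.requires_grad = False
    for name in blocks:
        for p in model.get_submodule(name).parameters():
            p.requires_grad = True

class ToyTeacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
        self.head = nn.Linear(8, 1)

teacher = ToyTeacher()
unfreeze_last_blocks(teacher, ["head"])
trainable = [n for n, p in teacher.named_parameters() if p.requires_grad]
# Only the head remains trainable; the pretrained backbone is preserved.
```

Keeping the backbone frozen is what preserves the pretrained knowledge while the head adapts to the new data.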
Evaluate the fine-tuned teacher:
```bash
python scripts/evaluate_teacher.py \
    --checkpoint weights/teacher/teacher_finetuned_best.pth
# Output: Teacher AUC and distillation readiness assessment
```

Train the student model using the fine-tuned or pretrained teacher:
```bash
# Two-stage training (recommended)
# --teacher-weights accepts 'finetuned', 'wildrf', or 'forensyth'
python scripts/train_student_two_stage.py \
    --dataset-root dataset \
    --teacher-weights finetuned \
    --batch-size 32 \
    --epochs-s1 5 \
    --epochs-s2 20 \
    --lr-s1 0.001 \
    --lr-s2 0.0001 \
    --device cuda

# OR single-stage training
python scripts/train_student.py \
    --dataset-root dataset \
    --teacher-weights finetuned \
    --epochs 50 \
    --batch-size 32 \
    --lr 0.001 \
    --device cuda
```

Key Points:
- Default alpha values start small (0.05 distill / 0.95 task) for stable learning
- Patch MSE is averaged per patch cell (not summed) for proper loss scaling
- Two-stage training: Stage 1 trains classifier, Stage 2 fine-tunes backbone
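The loss described in the key points can be sketched as follows; the function and tensor names are illustrative, not the repository's actual API (the real implementation lives in `losses/distillation.py`):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_patches, teacher_patches, image_logit, label,
            alpha_distill=0.05, alpha_task=0.95):
    """Patch-level MSE averaged per cell (reduction='mean', not sum)
    plus BCE on the pooled image-level logit, with a small distill weight."""
    distill = F.mse_loss(student_patches, teacher_patches)
    task = F.binary_cross_entropy_with_logits(image_logit, label)
    return alpha_distill * distill + alpha_task * task

student_patches = torch.randn(4, 196)  # per-patch logits from the student
teacher_patches = torch.randn(4, 196)  # per-patch logits from the teacher
image_logit = torch.randn(4)           # pooled image-level logit
label = torch.randint(0, 2, (4,)).float()
loss = kd_loss(student_patches, teacher_patches, image_logit, label)
```

Averaging per cell keeps the distillation term on the same scale regardless of patch-grid size, which is why the small default alpha still contributes a meaningful gradient.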
Evaluate the trained student on test set:
```bash
# Single-stage
python scripts/evaluate_student.py \
    --checkpoint outputs/checkpoints/student_final.pt

# Two-stage
python scripts/evaluate_student.py \
    --checkpoint outputs/checkpoints_two_stage/student_final.pt \
    --two-stage
```

See if distillation actually improved the student:
```bash
python scripts/evaluate_both.py \
    --teacher-checkpoint weights/teacher/teacher_finetuned_best.pth \
    --student-checkpoint outputs/checkpoints_two_stage/student_final.pt
# Output: Side-by-side comparison showing improvement
```

For a quick start without fine-tuning:
```bash
# 1. Prepare Data
mkdir -p dataset/{train,val}/{real,fake}
# Copy your images into these directories

# 2. Train Student with Pretrained Teacher
python scripts/train_student_two_stage.py \
    --dataset-root dataset \
    --teacher-weights wildrf \
    --device cuda

# 3. Evaluate
python scripts/evaluate_student.py \
    --checkpoint outputs/checkpoints_two_stage/student_final.pt \
    --two-stage
```

```python
from inference.pipeline import InferencePipeline

# Load model
pipeline = InferencePipeline(student, device="cuda", threshold=0.5)

# Predict single image
result = pipeline.predict("image.jpg")
print(f"Fake probability: {result['fake_probability']:.4f}")
print(f"Is fake: {result['is_fake']}")

# Predict batch
results = pipeline.predict_batch(["img1.jpg", "img2.jpg"])
```

- Images processed as overlapping patches (224×224)
- Patch grid enforced via `PatchAdapter`
- Individual patch predictions aggregated with Top-K pooling
- Large teacher (LaDeDa) guides lightweight student (TinyLaDeDa)
- Configurable temperature and loss weighting
- Supports both KL divergence and MSE distillation
- Dynamic range quantization (preferred)
- Full INT8 quantization (optional)
- Float vs quantized model comparison
- Tolerance-based validation
- JPEG compression robustness testing
- Configurable quality levels
- Automatic threshold analysis across qualities
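The JPEG robustness test boils down to re-encoding each image at several quality levels and re-scoring it. A minimal sketch with Pillow (the loop body and names are illustrative; the real tests live in `evaluation/distribution_shift.py`):

```python
import io

import numpy as np
from PIL import Image

def jpeg_recompress(img: Image.Image, quality: int) -> Image.Image:
    """Round-trip an image through JPEG at the given quality level."""
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    buf.seek(0)
    return Image.open(buf).convert("RGB")

img = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
for q in (90, 70, 50, 30):
    degraded = jpeg_recompress(img, q)
    # In the real test, score `degraded` with the model and log AUC per quality.
    print(q, degraded.size)
```

Plotting AUC against quality level shows how quickly detection degrades as compression destroys the local artifacts the patches rely on.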
Defines frozen model architecture:
```yaml
model:
  teacher:
    pretrained: true
    freeze_backbone: false
  student:
    depth_multiplier: 0.5
    width_multiplier: 0.75
patches:
  patch_size: 224
  stride: 16
  enable_padding: true
```

Dataset paths and preprocessing:
```yaml
dataset:
  root: "../dataset"
  train_ratio: 0.7
  val_ratio: 0.15
  test_ratio: 0.15
  resize_size: 224
```

Training hyperparameters:
```yaml
training:
  epochs: 50
  batch_size: 32
  learning_rate: 0.001
  scheduler: "cosine"
distillation:
  temperature: 4.0
  alpha: 0.5  # weight of KD loss
```

Quantization and threshold settings:
```yaml
quantization:
  strategy: "dynamic_range"
  bits: 8
  calibration_samples: 100
threshold:
  search_method: "grid"
  search_range: [0.3, 0.5, 0.7]
  search_metric: "f1"
```

- Accuracy: Overall classification accuracy
- Precision/Recall: Per-class performance
- F1 Score: Harmonic mean of precision and recall
- AUC: Area under ROC curve
- Accuracy@τ: Accuracy at specific threshold τ
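The metrics above are standard; a minimal sketch with scikit-learn, including an Accuracy@τ helper (the helper name and toy data are illustrative; the real implementations live in `evaluation/metrics.py`):

```python
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score

def accuracy_at_tau(y_true: np.ndarray, y_prob: np.ndarray, tau: float) -> float:
    """Accuracy when fake-probabilities are thresholded at tau."""
    return float(np.mean((y_prob >= tau).astype(int) == y_true))

y_true = np.array([0, 0, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.4, 0.8, 0.65, 0.3, 0.2])

auc = roc_auc_score(y_true, y_prob)             # threshold-free ranking quality
acc = accuracy_at_tau(y_true, y_prob, tau=0.5)  # Accuracy@τ for τ = 0.5
f1 = f1_score(y_true, (y_prob >= 0.5).astype(int))
```

AUC is threshold-free, while Accuracy@τ and F1 depend on the chosen τ, which is why threshold tuning is a separate step.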
- Calibration: Collect activation statistics (100 samples)
- Quantization: Convert weights to INT8 (dynamic range)
- Validation: Test quantized model on validation set
- Audit: Compare float vs quantized predictions
- Threshold Tuning: Find optimal threshold for quantized model
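The quantize-then-audit steps can be sketched with PyTorch's dynamic quantization; the toy model and the 0.5 tolerance are illustrative stand-ins (the real student is TinyLaDeDa and the audit lives in `quantization/audit.py`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the student model.
model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1)).eval()

# Dynamic range quantization: weights become INT8, activations stay float.
quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8)

# Audit: compare float vs quantized predictions on the same inputs.
x = torch.randn(16, 64)
with torch.no_grad():
    float_out = model(x)
    quant_out = quantized(x)
max_err = (float_out - quant_out).abs().max().item()
print(f"max |float - quant| = {max_err:.4f}")
```

The audit then flags the model if `max_err` exceeds the configured tolerance, or if decisions flip near the operating threshold.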
Run unit tests:
```bash
python -m pytest tests/
```

Key tests:
- `test_patch_alignment.py`: Verify patch grid correctness
- `test_topk_pooling.py`: Test Top-K aggregation
- `test_quant_consistency.py`: Check quantization fidelity
- `test_inference_contract.py`: Validate inference pipeline
Run all metrics at once:
```bash
python scripts/evaluate_comprehensive.py \
    --checkpoint weights/student/student_best.pth \
    --split test \
    --threshold 0.5
```

Outputs: ROC-AUC, PR-AUC, TPR@FPR=1%, Brier Score, ECE, Latency
Evaluate on external benchmark:
```bash
# First, prepare benchmark (one-time)
python -m datasets.celebdf_dataset --root-dir /path/to/CelebDF --fps 1.0

# Then evaluate
python scripts/evaluate_benchmark.py \
    --checkpoint weights/student/student_best.pth \
    --benchmark-root /path/to/CelebDF \
    --threshold 0.5
```

Generate visual evidence (heatmaps + sanity checks):
```bash
python scripts/create_forensic_pack.py \
    --checkpoint weights/student/student_best.pth \
    --data-root data/ \
    --num-samples 25 \
    --output-dir outputs/forensic_pack
```

Outputs:
- Patch probability heatmaps
- Top-k suspicious patch overlays
- Deletion sanity check (verifies model behavior)
Analyze training logs to diagnose issues:
```bash
python scripts/diagnose_training.py \
    --history outputs/checkpoints/training_history.json \
    --output outputs/diagnostics
```

Detects:
- Distillation/task loss imbalance
- Training stagnation
- Overfitting
- Stage 1 vs Stage 2 comparison
Verify TFLite/ONNX exports match PyTorch:
```bash
python scripts/validate_export_parity.py \
    --pytorch-checkpoint weights/student/student_best.pth \
    --tflite-path exports/student.tflite \
    --data-root data/
```

Checks: AUC delta, decision agreement, correlation
The training pipeline includes automatic guardrails:
| Guardrail | Purpose |
|---|---|
| Shape Contract | Validates teacher/student patch dimensions |
| Scale Logging | Tracks logit statistics every batch |
| Determinism Test | Monitors outputs on fixed audit set |
Guardrails run automatically during `train_student_two_stage.py` and save logs to `outputs/guardrails/`.
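The shape-contract guardrail can be as simple as the following sketch; the function name and tensors are illustrative, not the repository's actual API:

```python
import torch

def check_shape_contract(teacher_patches: torch.Tensor,
                         student_patches: torch.Tensor) -> None:
    """Fail fast when teacher and student patch grids disagree."""
    if teacher_patches.shape != student_patches.shape:
        raise ValueError(
            f"patch grid mismatch: teacher {tuple(teacher_patches.shape)} "
            f"vs student {tuple(student_patches.shape)}")

# Matching grids pass silently; a mismatch raises immediately.
check_shape_contract(torch.randn(2, 196), torch.randn(2, 196))
```

Failing at the first batch is far cheaper than discovering a grid mismatch after a full distillation run.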
Cause: Distillation loss dominated, model ignored labels.
Fix:
```bash
python scripts/train_student_two_stage.py \
    --alpha-distill 0.1 \
    --alpha-task 0.9
# OR try --alpha-distill 0.0 (no distillation)
```

Diagnosis: Run `scripts/diagnose_training.py` on the training history.
Fix: Reduce `alpha_distill` to 0.1 or lower.
Fix: Fine-tune teacher longer or unfreeze more layers:
```bash
python scripts/train_teacher.py --epochs 100 --unfreeze-blocks 2
```

Cause: Quantization introduced too much error.
Fix: Use dynamic quantization instead of full INT8.
Export student model for edge deployment:
```python
from scripts.export_tflite import export_to_tflite

export_to_tflite(student, "outputs/student.tflite")
```

- PyTorch >= 2.0.0
- NumPy >= 1.24.0
- scikit-learn >= 1.3.0
- Pillow >= 9.0.0
- OpenCV >= 4.8.0
- LaDeDa: Teacher model architecture
- Knowledge Distillation: Hinton et al. (2015)
- Patch-based Detection: Local artifact analysis
- Quantization: Dynamic range quantization for edge deployment
Team Converge Research Project
For questions or contributions, please reach out to the Team Converge research group.