This repo builds a classifier to distinguish real vs. AI-generated images using two complementary branches and a fusion model that combines them:
- Spatial branch (src/spatial_model.py): CNN with residual-style BasicBlocks, global pooling, dropout, and a linear head. Trained via src/train_spatial.py, which uses strong data augmentation, mixup, AMP, and AdamW.
- Frequency branch (src/freq_model.py): Operates on log-magnitude FFT views; trained via src/train_freq.py.
- Fusion branch (src/fusion_model.py): Pools spatial + frequency features and fuses them with a small MLP head; trained via src/train_fusion.py and evaluated with src/eval_fusion.py.
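
As a rough illustration of the log-magnitude FFT views the frequency branch consumes, here is a minimal sketch using NumPy. The function name `log_magnitude_fft`, the use of `log1p`, and the centering via `fftshift` are assumptions for illustration; the actual transform in src/data_processing.py may normalize or resize differently.

```python
import numpy as np

def log_magnitude_fft(gray: np.ndarray) -> np.ndarray:
    """Return a centered log-magnitude spectrum for a 2-D grayscale image."""
    # 2-D FFT, then shift the zero-frequency (DC) term to the center.
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    # log1p compresses the dynamic range and maps zero magnitude to 0.
    return np.log1p(np.abs(spectrum))

# A constant image carries all of its energy in the DC term.
image = np.ones((8, 8), dtype=np.float64)
view = log_magnitude_fft(image)
peak = np.unravel_index(view.argmax(), view.shape)
```

For the constant test image the only non-zero coefficient is the DC term, so the peak of the view sits at the center of the array.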
- Download & merge (src/data_downloader.py): Pulls Kaggle/Hugging Face datasets and merges class folders into rawdata/*.
- Process dual views (src/data_processing.py): Creates spatial (224×224 RGB) and frequency (320×320 log-FFT) tensors and saves them under dataset/<source>/spatial|freq/{real,fake}.
- Train/val split (src/split_train_val.py): Splits processed data into dataset/dataset_split/train|val/spatial|freq/{real,fake} with a fixed seed.
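
The fixed-seed split can be sketched as follows. This is not the code in src/split_train_val.py; the function name, the 20% validation fraction, the seed value 42, and the example file names are all hypothetical and only show why a fixed seed makes the split reproducible.

```python
import random

def split_train_val(items, val_fraction=0.2, seed=42):
    """Shuffle a copy of items with a fixed seed and return (train, val)."""
    # Sort first so the result does not depend on directory listing order.
    shuffled = sorted(items)
    # A dedicated Random instance keeps the global RNG state untouched.
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]

# Hypothetical file names standing in for processed tensors.
files = [f"img_{i:03d}.npy" for i in range(10)]
train, val = split_train_val(files)
```

Calling the function again with the same seed, even on a differently ordered input, yields the same train and val lists.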
- Spatial: python src/train_spatial.py (auto-detects dataset/dataset_split; falls back to in-memory split if absent).
- Frequency: python src/train_freq.py.
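
The auto-detection described for the spatial trainer might look roughly like this. The helper `resolve_split_dirs` and its return convention are assumptions made for illustration, not the logic actually used in src/train_spatial.py.

```python
from pathlib import Path

def resolve_split_dirs(root="dataset/dataset_split", view="spatial"):
    """Return (train_dir, val_dir) if a pre-made split exists, else None."""
    train_dir = Path(root) / "train" / view
    val_dir = Path(root) / "val" / view
    if train_dir.is_dir() and val_dir.is_dir():
        return train_dir, val_dir
    # No split on disk: the caller would fall back to an in-memory split.
    return None
```

A caller can check for `None` and build the split in memory only when the on-disk folders are missing.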
- Use src/eval_spatial.py (and the corresponding frequency eval script) on saved checkpoints.
- Download: python src/data_downloader.py
- Process: python src/data_processing.py
- Split: python src/split_train_val.py
- Train spatial: python src/train_spatial.py
- Train frequency: python src/train_freq.py
- Train fusion: python src/train_fusion.py
- Eval fusion: python src/eval_fusion.py
# Create venv
python3 -m venv venv
# Activate (Mac/Linux)
source venv/bin/activate
# Activate (Windows)
# venv\Scripts\activate

pip install --upgrade pip
pip install -r requirements.txt

# Download your Kaggle API token from https://www.kaggle.com/settings
# Place kaggle.json in ~/.kaggle/
mkdir -p ~/.kaggle
cp /path/to/kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json

# Download raw datasets
python src/data_downloader.py
# Process to dual views
python src/data_processing.py
# Create train/val splits
python src/split_train_val.py
# Train models
python src/train_spatial.py
python src/train_freq.py