This repository hosts the implementation of the numerical studies reported in “On Flow Matching KL Divergence,” Su et al., 2025. The code reproduces the empirical evidence supporting the paper’s KL error bounds.
For reference, the paper builds on the KL evolution identity

$$\frac{d}{dt}\,\mathrm{KL}(p_t \,\|\, q_t) \;=\; \mathbb{E}_{x \sim p_t}\!\left[\big(u(x,t) - v_{\theta}(x,t)\big)^{\top}\big(\nabla \log p_t(x) - \nabla \log q_t(x)\big)\right],$$

where:

- $p_t$ evolves under the true velocity field $u(x,t) = a(t)\,x$
- $q_t$ evolves under the learned velocity field $v_{\theta}(x,t)$
- Both start from the standard Gaussian: $p_0 = q_0 = \mathcal{N}(0, I)$
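When the learned field is replaced by a second linear field $v(x,t) = b(t)\,x$, both $p_t$ and $q_t$ stay centered Gaussians, so each side of the identity has a closed form and the identity can be checked numerically without any training (in the spirit of the no-learning scripts below). The sketch that follows is a standalone illustration; the constant schedules `a`, `b` and the helper functions are assumptions, not code from `core/true_path.py`.

```python
# Quick numeric check of the identity in the simplest (fully linear) case.
import numpy as np
from scipy.integrate import quad

d = 2                                  # dimension
a = lambda t: 1.0                      # assumed "true" schedule:   u(x, t) = a(t) x
b = lambda t: 0.5                      # assumed linear stand-in for v_theta: v(x, t) = b(t) x

def scale(sched, t):
    # std of N(0, I) pushed through dx/dt = sched(t) x, i.e. s(t) = exp(int_0^t sched)
    return np.exp(quad(sched, 0.0, t)[0])

def kl(t):
    # KL( N(0, s_p^2 I) || N(0, s_q^2 I) ) in d dimensions
    r = scale(a, t) ** 2 / scale(b, t) ** 2
    return 0.5 * d * (r - 1.0 - np.log(r))

t, h = 0.7, 1e-4
lhs = (kl(t + h) - kl(t - h)) / (2.0 * h)        # d/dt KL(p_t || q_t), finite difference
rhs = d * (a(t) - b(t)) * (scale(a, t) ** 2 / scale(b, t) ** 2 - 1.0)  # expectation in closed form
print(lhs, rhs)                                   # the two values agree up to finite-difference error
```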
```
flow_kl/
├── README.md
├── requirements.txt
├── core/                          # Shared schedules, densities, utilities
│   ├── __init__.py
│   ├── true_path.py               # Schedules, sampling, densities, scores
│   └── utils.py                   # Seeding, device helpers, plotting I/O
├── part1/                         # Part 1: learned velocity identity experiments
│   ├── __init__.py
│   ├── experiment.py              # CLI: python -m part1.experiment
│   ├── eval.py                    # LHS / RHS evaluation routines
│   ├── model.py                   # Velocity MLP
│   └── train.py                   # Training loop and plotting helpers
├── part2/
│   ├── __init__.py
│   ├── synthetic/                 # Part 2A: synthetic perturbation studies
│   │   ├── __init__.py
│   │   ├── experiment.py          # CLI: python -m part2.synthetic.experiment
│   │   ├── eval.py                # Part 2 evaluation helpers
│   │   ├── synthetic_velocity.py
│   │   └── run_all_experiments.py
│   └── learned/                   # Part 2B: learned perturbation studies
│       ├── __init__.py
│       ├── experiment.py          # CLI: python -m part2.learned.experiment
│       ├── eval.py
│       ├── model.py
│       └── train.py
├── plotting/                      # Plot regeneration & epsilon-curve utilities
│   ├── __init__.py
│   ├── plot_eps_curves.py
│   ├── regenerate_plots.py
│   └── regenerate_plots_from_csv.py
├── scripts/                       # Automation & shell entry points
│   ├── run_all_experiments.py / .ps1
│   ├── run_all_cross_eval.sh / .ps1
│   ├── run_all_nolearning.sh / .ps1
│   ├── run_all_pt2_experiments.ps1
│   ├── run_all_pt2_learn_experiments.ps1
│   └── nolearning_test.py
├── tests/                         # Unit / integration tests
│   ├── __init__.py
│   ├── test_golden_path.py
│   ├── test_rhs.py
│   ├── test_pt2.py
│   ├── test_learn_pt2.py
│   └── test_eps_curves.py
└── data/                          # Generated checkpoints, plots, metrics
```
- Clone the repository (or navigate to the project directory)
- Create a conda environment:

```bash
conda create -n flow-kl python=3.10
conda activate flow-kl
```

- Install dependencies:
```bash
pip install -r requirements.txt
```

The no-learning test verifies the identity using analytic formulas (no neural networks):
```bash
conda activate flow-kl
python scripts/nolearning_test.py --schedule_p a1 --schedule_q a2 --skip_ode
```

For all 6 schedule permutations:
```bash
bash scripts/run_all_nolearning.sh   # or: pwsh scripts/run_all_nolearning.ps1
```
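A rough Python equivalent of that driver is sketched below, purely for orientation. It assumes three schedule names `a1`, `a2`, `a3` (only `a1` and `a2` are confirmed by the commands above), which yields the 6 ordered `(schedule_p, schedule_q)` pairs; adjust to whatever schedules `core/true_path.py` actually defines.

```python
# Illustrative loop over ordered schedule pairs; schedule names are assumptions.
import itertools
import subprocess

schedules = ["a1", "a2", "a3"]
for p, q in itertools.permutations(schedules, 2):   # 3 * 2 = 6 ordered pairs
    subprocess.run(
        ["python", "scripts/nolearning_test.py",
         "--schedule_p", p, "--schedule_q", q, "--skip_ode"],
        check=True,
    )
```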
Train a model to learn the velocity field:

```bash
python -m part1.experiment --schedule a1 --target_mse 0.05
```

This reproduces the Section 5.2 checkpoints. Specifically, it will:
- Train a neural network to match the true velocity (a minimal training-step sketch follows this list)
- Write checkpoints, metrics, and plots into the configured output directory
- Evaluate the KL identity
- Generate plots showing LHS vs RHS
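For orientation, a bare-bones version of the velocity-matching step could look like the sketch below. The MLP shape, the constant schedule, and the plain MSE objective are assumptions for illustration; the actual model and training loop live in `part1/model.py` and `part1/train.py`.

```python
# Bare-bones velocity-matching training step (illustrative only).
import torch
import torch.nn as nn

a = lambda t: torch.ones_like(t)                 # assumed schedule: u(x, t) = a(t) x with a = 1
s = lambda t: torch.exp(t)                       # std of p_t for this constant schedule

v_theta = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))  # input (x, t), d = 2
opt = torch.optim.Adam(v_theta.parameters(), lr=1e-3)

for step in range(1000):
    t = torch.rand(256, 1)                       # t ~ Uniform[0, 1]
    x = s(t) * torch.randn(256, 2)               # x ~ p_t = N(0, s(t)^2 I)
    target = a(t) * x                            # true velocity u(x, t)
    pred = v_theta(torch.cat([x, t], dim=1))
    loss = ((pred - target) ** 2).mean()         # velocity-matching MSE
    opt.zero_grad()
    loss.backward()
    opt.step()
```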
To reuse a previously trained model instead of retraining:

```bash
python -m part1.experiment --schedule a1 --load_model path/to/vtheta_schedule_a1_mse_0-05_TIMESTAMP.pth
```

Validate the bound with synthetic perturbations:
```bash
python -m part2.synthetic.experiment --schedule a1 --delta_beta 0.0 0.05 0.1 0.2
```
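To build intuition for what a perturbation of size `delta_beta` does, here is one hypothetical reading: add it to the linear schedule and compute the exact KL between the two resulting Gaussians. How `--delta_beta` actually enters `part2/synthetic/synthetic_velocity.py` may differ, so treat this purely as a sketch.

```python
# Hypothetical additive perturbation of the schedule; parameterization is an assumption.
import numpy as np

d, t, a = 2, 1.0, 1.0                      # dimension, time, assumed constant schedule
for delta_beta in [0.0, 0.05, 0.1, 0.2]:
    sp = np.exp(a * t)                     # std of p_t under u(x, t) = a x
    sq = np.exp((a + delta_beta) * t)      # std of q_t under v(x, t) = (a + delta_beta) x
    r = sp**2 / sq**2
    kl = 0.5 * d * (r - 1.0 - np.log(r))   # exact KL(p_t || q_t) for centered Gaussians
    print(f"delta_beta={delta_beta:.2f}  KL={kl:.6f}")
```

In this toy setting the KL grows roughly like $\delta\beta^2$ for small perturbations.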
Run all Part 2 experiments:

```bash
python -m part2.synthetic.run_all_experiments
```

Train a velocity MLP and verify the bound across training checkpoints:
```bash
python -m part2.learned.experiment --schedule a1 --epochs 400 --eval_checkpoints "all"
```

This will:
- Train a neural network for up to 400 epochs
- Save multiple checkpoints (best, final, and on improvement)
- Evaluate the bound for all saved checkpoints (a rough per-checkpoint loop is sketched after this list)
- Generate scatter plots showing bound tightening with training
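As a rough picture of the per-checkpoint evaluation, the sketch below reloads each saved `.pth` file and tracks a Monte Carlo estimate of the velocity-matching error, a simple proxy for the quantity that should shrink as training proceeds. The checkpoint location, the MLP constructor, and the choice of metric are assumptions; the real evaluation lives in `part2/learned/eval.py`.

```python
# Hypothetical post-hoc sweep over saved checkpoints (illustrative only).
import glob
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def build_velocity_mlp():
    # Must match whatever architecture produced the checkpoints (assumed here).
    return nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))

def velocity_mse(model, a=1.0, n=4096):
    # E_{t, x ~ p_t} ||v_theta(x, t) - u(x, t)||^2 for u(x, t) = a x, p_t = N(0, e^{2at} I)
    t = torch.rand(n, 1)
    x = torch.exp(a * t) * torch.randn(n, 2)
    with torch.no_grad():
        pred = model(torch.cat([x, t], dim=1))
    return ((pred - a * x) ** 2).sum(dim=1).mean().item()

errors = []
for path in sorted(glob.glob("data/*.pth")):          # assumed checkpoint location
    model = build_velocity_mlp()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    errors.append(velocity_mse(model))

plt.scatter(range(len(errors)), errors)
plt.xlabel("checkpoint index")
plt.ylabel("velocity MSE")
plt.savefig("checkpoint_velocity_mse.png")
```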
Dependencies (from `requirements.txt`):

- `torch>=2.0.0`: Neural networks and autograd
- `torchdiffeq>=0.2.3`: ODE solving
- `numpy>=1.24.0`: Numerical computation
- `matplotlib>=3.7.0`: Plotting
- `scipy>=1.10.0`: Scientific computing
- `tqdm>=4.65.0`: Progress bars
- `seaborn>=0.12.0`: Statistical plots
This project is distributed under the MIT License.