|
| 1 | +//! Class-conditional flow matching on 2D synthetic data. |
| 2 | +//! |
| 3 | +//! Trains one linear conditional field per class, then generates samples |
| 4 | +//! conditioned on each class label and checks they land near the correct |
| 5 | +//! cluster center. |
| 6 | +//! |
| 7 | +//! - Class 0: cluster at (-3, 0) |
| 8 | +//! - Class 1: cluster at ( 3, 0) |
| 9 | +//! - Class 2: cluster at ( 0, 3) |
| 10 | +//! |
| 11 | +//! Source distribution: standard normal for all classes. |
| 12 | +//! |
| 13 | +//! Run: |
| 14 | +//! ```bash |
| 15 | +//! cargo run -p flowmatch --example rfm_conditional_2d |
| 16 | +//! ``` |
| 17 | +
|
| 18 | +use flowmatch::linear::LinearCondField; |
| 19 | +use flowmatch::ode::{integrate_fixed, OdeMethod}; |
| 20 | +use ndarray::{Array1, Array2}; |
| 21 | +use rand::SeedableRng; |
| 22 | +use rand_chacha::ChaCha8Rng; |
| 23 | +use rand_distr::{Distribution, Normal, StandardNormal}; |
| 24 | + |
/// Cluster centers for classes 0, 1 and 2 (see module docs).
const CENTERS: [[f32; 2]; 3] = [[-3.0, 0.0], [3.0, 0.0], [0.0, 3.0]];
/// Training points drawn per class.
const N_PER_CLASS: usize = 200;
/// Std-dev of the Gaussian noise around each cluster center.
const NOISE: f32 = 0.3;
/// SGD iterations used to fit each per-class field.
const TRAIN_STEPS: usize = 6_000;
/// SGD learning rate.
const LR: f32 = 3e-3;
/// Fixed Euler steps when integrating the sampling ODE over t in [0, 1].
const SAMPLE_STEPS: usize = 30;
/// Samples generated per class at evaluation time.
const N_EVAL: usize = 200;
| 32 | + |
| 33 | +/// Generate target samples for one class: Gaussian cluster around `center`. |
| 34 | +fn sample_cluster(center: [f32; 2], n: usize, rng: &mut ChaCha8Rng) -> Array2<f32> { |
| 35 | + let noise_dist = Normal::new(0.0f32, NOISE).unwrap(); |
| 36 | + let mut out = Array2::<f32>::zeros((n, 2)); |
| 37 | + for i in 0..n { |
| 38 | + out[[i, 0]] = center[0] + noise_dist.sample(rng); |
| 39 | + out[[i, 1]] = center[1] + noise_dist.sample(rng); |
| 40 | + } |
| 41 | + out |
| 42 | +} |
| 43 | + |
| 44 | +/// Train a `LinearCondField` on (source, target) pairs for one class. |
| 45 | +/// |
| 46 | +/// Uses the standard conditional FM objective: sample t ~ U(0,1), |
| 47 | +/// form x_t = (1-t)*x0 + t*y, regress v_theta(x_t, t; y) toward u = y - x0. |
| 48 | +fn train_class_field(target: &Array2<f32>, seed: u64) -> LinearCondField { |
| 49 | + let n = target.nrows(); |
| 50 | + let mut field = LinearCondField::new_zeros(2); |
| 51 | + let mut rng = ChaCha8Rng::seed_from_u64(seed); |
| 52 | + |
| 53 | + for _ in 0..TRAIN_STEPS { |
| 54 | + // Sample a random target point. |
| 55 | + let idx = rand::Rng::random_range(&mut rng, 0..n); |
| 56 | + let y = target.row(idx); |
| 57 | + |
| 58 | + // Sample x0 ~ N(0, I). |
| 59 | + let x0_0: f32 = StandardNormal.sample(&mut rng); |
| 60 | + let x0_1: f32 = StandardNormal.sample(&mut rng); |
| 61 | + let x0 = Array1::from_vec(vec![x0_0, x0_1]); |
| 62 | + |
| 63 | + // Sample t ~ U(0,1), clamped away from boundaries. |
| 64 | + let t: f32 = rand::Rng::random::<f32>(&mut rng).clamp(1e-5, 1.0 - 1e-5); |
| 65 | + |
| 66 | + // Interpolant: x_t = (1-t)*x0 + t*y. |
| 67 | + let xt = Array1::from_vec(vec![ |
| 68 | + (1.0 - t) * x0[0] + t * y[0], |
| 69 | + (1.0 - t) * x0[1] + t * y[1], |
| 70 | + ]); |
| 71 | + |
| 72 | + // Target velocity: u = y - x0. |
| 73 | + let u = Array1::from_vec(vec![y[0] - x0[0], y[1] - x0[1]]); |
| 74 | + |
| 75 | + field.sgd_step(&xt.view(), t, &y, &u.view(), LR); |
| 76 | + } |
| 77 | + field |
| 78 | +} |
| 79 | + |
| 80 | +/// Generate samples from a trained field by integrating from N(0,I). |
| 81 | +/// |
| 82 | +/// Uses the target cluster mean as the conditioning signal y (since at |
| 83 | +/// inference we condition on the class, not on individual target points). |
| 84 | +fn generate_samples( |
| 85 | + field: &LinearCondField, |
| 86 | + center: [f32; 2], |
| 87 | + n: usize, |
| 88 | + seed: u64, |
| 89 | +) -> Array2<f32> { |
| 90 | + let mut rng = ChaCha8Rng::seed_from_u64(seed); |
| 91 | + let dt = 1.0f32 / (SAMPLE_STEPS as f32); |
| 92 | + let y = Array1::from_vec(vec![center[0], center[1]]); |
| 93 | + |
| 94 | + let mut out = Array2::<f32>::zeros((n, 2)); |
| 95 | + for i in 0..n { |
| 96 | + let x0 = Array1::from_vec(vec![ |
| 97 | + StandardNormal.sample(&mut rng), |
| 98 | + StandardNormal.sample(&mut rng), |
| 99 | + ]); |
| 100 | + let x1 = integrate_fixed(OdeMethod::Euler, &x0, 0.0, dt, SAMPLE_STEPS, |xt, t| { |
| 101 | + field.eval(xt, t, &y.view()) |
| 102 | + }); |
| 103 | + out[[i, 0]] = x1[0]; |
| 104 | + out[[i, 1]] = x1[1]; |
| 105 | + } |
| 106 | + out |
| 107 | +} |
| 108 | + |
/// Euclidean distance between a 2D point slice `a` and a fixed center `b`.
fn dist(a: &[f32], b: &[f32; 2]) -> f32 {
    let dx = a[0] - b[0];
    let dy = a[1] - b[1];
    (dx * dx + dy * dy).sqrt()
}
| 112 | + |
| 113 | +fn main() { |
| 114 | + let mut rng = ChaCha8Rng::seed_from_u64(42); |
| 115 | + |
| 116 | + // Generate target data per class. |
| 117 | + let targets: Vec<Array2<f32>> = CENTERS |
| 118 | + .iter() |
| 119 | + .map(|c| sample_cluster(*c, N_PER_CLASS, &mut rng)) |
| 120 | + .collect(); |
| 121 | + |
| 122 | + // Train one field per class. |
| 123 | + let fields: Vec<LinearCondField> = targets |
| 124 | + .iter() |
| 125 | + .enumerate() |
| 126 | + .map(|(c, t)| { |
| 127 | + let f = train_class_field(t, 100 + c as u64); |
| 128 | + println!("Trained class {c} field."); |
| 129 | + f |
| 130 | + }) |
| 131 | + .collect(); |
| 132 | + |
| 133 | + // Generate and evaluate. |
| 134 | + let mut total_correct = 0usize; |
| 135 | + let mut total = 0usize; |
| 136 | + |
| 137 | + println!("\nPer-class evaluation ({N_EVAL} samples each):"); |
| 138 | + for (c, field) in fields.iter().enumerate() { |
| 139 | + let samples = generate_samples(field, CENTERS[c], N_EVAL, 500 + c as u64); |
| 140 | + |
| 141 | + // Mean distance to true center. |
| 142 | + let mean_dist: f32 = (0..N_EVAL) |
| 143 | + .map(|i| dist(&[samples[[i, 0]], samples[[i, 1]]], &CENTERS[c])) |
| 144 | + .sum::<f32>() |
| 145 | + / N_EVAL as f32; |
| 146 | + |
| 147 | + // Accuracy: fraction closer to correct center than any other. |
| 148 | + let correct = (0..N_EVAL) |
| 149 | + .filter(|&i| { |
| 150 | + let s = [samples[[i, 0]], samples[[i, 1]]]; |
| 151 | + let d_own = dist(&s, &CENTERS[c]); |
| 152 | + CENTERS.iter().enumerate().all(|(k, ck)| k == c || dist(&s, ck) > d_own) |
| 153 | + }) |
| 154 | + .count(); |
| 155 | + |
| 156 | + total_correct += correct; |
| 157 | + total += N_EVAL; |
| 158 | + |
| 159 | + println!( |
| 160 | + " class {c} center={:?}: mean_dist={mean_dist:.3}, accuracy={}/{N_EVAL} ({:.1}%)", |
| 161 | + CENTERS[c], |
| 162 | + correct, |
| 163 | + 100.0 * correct as f32 / N_EVAL as f32, |
| 164 | + ); |
| 165 | + } |
| 166 | + |
| 167 | + let overall = 100.0 * total_correct as f32 / total as f32; |
| 168 | + println!("\nOverall accuracy: {total_correct}/{total} ({overall:.1}%)"); |
| 169 | +} |
0 commit comments