Skip to content

Commit 37d37ab

Browse files
author
Henry Wallace
committed
examples: add class-conditional 2D flow matching
Three-class conditional generation: separate LinearCondField per class, Euler integration for sampling, per-class accuracy and mean distance evaluation. 100% accuracy on well-separated Gaussian clusters.
1 parent 9e48212 commit 37d37ab

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

examples/rfm_conditional_2d.rs

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//! Class-conditional flow matching on 2D synthetic data.
2+
//!
3+
//! Trains one linear conditional field per class, then generates samples
4+
//! conditioned on each class label and checks they land near the correct
5+
//! cluster center.
6+
//!
7+
//! - Class 0: cluster at (-3, 0)
8+
//! - Class 1: cluster at ( 3, 0)
9+
//! - Class 2: cluster at ( 0, 3)
10+
//!
11+
//! Source distribution: standard normal for all classes.
12+
//!
13+
//! Run:
14+
//! ```bash
15+
//! cargo run -p flowmatch --example rfm_conditional_2d
16+
//! ```
17+
18+
use flowmatch::linear::LinearCondField;
19+
use flowmatch::ode::{integrate_fixed, OdeMethod};
20+
use ndarray::{Array1, Array2};
21+
use rand::SeedableRng;
22+
use rand_chacha::ChaCha8Rng;
23+
use rand_distr::{Distribution, Normal, StandardNormal};
24+
25+
const CENTERS: [[f32; 2]; 3] = [[-3.0, 0.0], [3.0, 0.0], [0.0, 3.0]];
26+
const N_PER_CLASS: usize = 200;
27+
const NOISE: f32 = 0.3;
28+
const TRAIN_STEPS: usize = 6_000;
29+
const LR: f32 = 3e-3;
30+
const SAMPLE_STEPS: usize = 30;
31+
const N_EVAL: usize = 200;
32+
33+
/// Generate target samples for one class: Gaussian cluster around `center`.
34+
fn sample_cluster(center: [f32; 2], n: usize, rng: &mut ChaCha8Rng) -> Array2<f32> {
35+
let noise_dist = Normal::new(0.0f32, NOISE).unwrap();
36+
let mut out = Array2::<f32>::zeros((n, 2));
37+
for i in 0..n {
38+
out[[i, 0]] = center[0] + noise_dist.sample(rng);
39+
out[[i, 1]] = center[1] + noise_dist.sample(rng);
40+
}
41+
out
42+
}
43+
44+
/// Train a `LinearCondField` on (source, target) pairs for one class.
45+
///
46+
/// Uses the standard conditional FM objective: sample t ~ U(0,1),
47+
/// form x_t = (1-t)*x0 + t*y, regress v_theta(x_t, t; y) toward u = y - x0.
48+
fn train_class_field(target: &Array2<f32>, seed: u64) -> LinearCondField {
49+
let n = target.nrows();
50+
let mut field = LinearCondField::new_zeros(2);
51+
let mut rng = ChaCha8Rng::seed_from_u64(seed);
52+
53+
for _ in 0..TRAIN_STEPS {
54+
// Sample a random target point.
55+
let idx = rand::Rng::random_range(&mut rng, 0..n);
56+
let y = target.row(idx);
57+
58+
// Sample x0 ~ N(0, I).
59+
let x0_0: f32 = StandardNormal.sample(&mut rng);
60+
let x0_1: f32 = StandardNormal.sample(&mut rng);
61+
let x0 = Array1::from_vec(vec![x0_0, x0_1]);
62+
63+
// Sample t ~ U(0,1), clamped away from boundaries.
64+
let t: f32 = rand::Rng::random::<f32>(&mut rng).clamp(1e-5, 1.0 - 1e-5);
65+
66+
// Interpolant: x_t = (1-t)*x0 + t*y.
67+
let xt = Array1::from_vec(vec![
68+
(1.0 - t) * x0[0] + t * y[0],
69+
(1.0 - t) * x0[1] + t * y[1],
70+
]);
71+
72+
// Target velocity: u = y - x0.
73+
let u = Array1::from_vec(vec![y[0] - x0[0], y[1] - x0[1]]);
74+
75+
field.sgd_step(&xt.view(), t, &y, &u.view(), LR);
76+
}
77+
field
78+
}
79+
80+
/// Generate samples from a trained field by integrating from N(0,I).
81+
///
82+
/// Uses the target cluster mean as the conditioning signal y (since at
83+
/// inference we condition on the class, not on individual target points).
84+
fn generate_samples(
85+
field: &LinearCondField,
86+
center: [f32; 2],
87+
n: usize,
88+
seed: u64,
89+
) -> Array2<f32> {
90+
let mut rng = ChaCha8Rng::seed_from_u64(seed);
91+
let dt = 1.0f32 / (SAMPLE_STEPS as f32);
92+
let y = Array1::from_vec(vec![center[0], center[1]]);
93+
94+
let mut out = Array2::<f32>::zeros((n, 2));
95+
for i in 0..n {
96+
let x0 = Array1::from_vec(vec![
97+
StandardNormal.sample(&mut rng),
98+
StandardNormal.sample(&mut rng),
99+
]);
100+
let x1 = integrate_fixed(OdeMethod::Euler, &x0, 0.0, dt, SAMPLE_STEPS, |xt, t| {
101+
field.eval(xt, t, &y.view())
102+
});
103+
out[[i, 0]] = x1[0];
104+
out[[i, 1]] = x1[1];
105+
}
106+
out
107+
}
108+
109+
fn dist(a: &[f32], b: &[f32; 2]) -> f32 {
110+
((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
111+
}
112+
113+
fn main() {
114+
let mut rng = ChaCha8Rng::seed_from_u64(42);
115+
116+
// Generate target data per class.
117+
let targets: Vec<Array2<f32>> = CENTERS
118+
.iter()
119+
.map(|c| sample_cluster(*c, N_PER_CLASS, &mut rng))
120+
.collect();
121+
122+
// Train one field per class.
123+
let fields: Vec<LinearCondField> = targets
124+
.iter()
125+
.enumerate()
126+
.map(|(c, t)| {
127+
let f = train_class_field(t, 100 + c as u64);
128+
println!("Trained class {c} field.");
129+
f
130+
})
131+
.collect();
132+
133+
// Generate and evaluate.
134+
let mut total_correct = 0usize;
135+
let mut total = 0usize;
136+
137+
println!("\nPer-class evaluation ({N_EVAL} samples each):");
138+
for (c, field) in fields.iter().enumerate() {
139+
let samples = generate_samples(field, CENTERS[c], N_EVAL, 500 + c as u64);
140+
141+
// Mean distance to true center.
142+
let mean_dist: f32 = (0..N_EVAL)
143+
.map(|i| dist(&[samples[[i, 0]], samples[[i, 1]]], &CENTERS[c]))
144+
.sum::<f32>()
145+
/ N_EVAL as f32;
146+
147+
// Accuracy: fraction closer to correct center than any other.
148+
let correct = (0..N_EVAL)
149+
.filter(|&i| {
150+
let s = [samples[[i, 0]], samples[[i, 1]]];
151+
let d_own = dist(&s, &CENTERS[c]);
152+
CENTERS.iter().enumerate().all(|(k, ck)| k == c || dist(&s, ck) > d_own)
153+
})
154+
.count();
155+
156+
total_correct += correct;
157+
total += N_EVAL;
158+
159+
println!(
160+
" class {c} center={:?}: mean_dist={mean_dist:.3}, accuracy={}/{N_EVAL} ({:.1}%)",
161+
CENTERS[c],
162+
correct,
163+
100.0 * correct as f32 / N_EVAL as f32,
164+
);
165+
}
166+
167+
let overall = 100.0 * total_correct as f32 / total as f32;
168+
println!("\nOverall accuracy: {total_correct}/{total} ({overall:.1}%)");
169+
}

0 commit comments

Comments
 (0)