Skip to content

Commit d195077

Browse files
Merge pull request #10 from YuminosukeSato/feat/rust-speedup
perf: replace Gauss-Jordan with Cholesky precision sampler and pre-compute loop invariants
2 parents 9bbc2ac + 537ecc1 commit d195077

4 files changed

Lines changed: 601 additions & 143 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ pyo3 = { version = "0.28", features = ["extension-module"] }
1212
rand = "0.8"
1313
rayon = "1.10"
1414
rand_distr = "0.4"
15+
16+
[profile.release]
17+
lto = "thin"
18+
codegen-units = 1

src/distributions.rs

Lines changed: 230 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,88 @@ pub fn sample_normal<R: Rng>(rng: &mut R, mean: f64, variance: f64) -> f64 {
1515
mean + std * z
1616
}
1717

18-
pub fn sample_mvnormal<R: Rng>(rng: &mut R, mean: &[f64], cov: &[Vec<f64>]) -> Vec<f64> {
19-
let lower_cholesky = cholesky(cov);
20-
let z: Vec<f64> = mean.iter().map(|_| rng.sample(StandardNormal)).collect();
21-
22-
let mut result = mean.to_vec();
23-
for (i, value) in result.iter_mut().enumerate() {
24-
*value += lower_cholesky[i]
25-
.iter()
26-
.zip(z.iter())
27-
.take(i + 1)
28-
.map(|(lhs, rhs)| lhs * rhs)
29-
.sum::<f64>();
30-
}
31-
result
32-
}
33-
34-
fn cholesky(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
18+
/// Cholesky factorization of a symmetric positive-definite matrix.
///
/// Returns the lower-triangular factor `L` with `A = L L^T`
/// (Cholesky–Banachiewicz, row by row). A non-positive pivot is
/// replaced by 1e-12 so the routine never panics on a nearly
/// singular input.
fn cholesky_lower(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    let mut l = vec![vec![0.0; n]; n];
    for row in 0..n {
        for col in 0..=row {
            // Dot product of the already-computed prefixes of rows `row` and `col`.
            let mut acc = 0.0;
            for t in 0..col {
                acc += l[row][t] * l[col][t];
            }
            if row == col {
                let pivot = a[row][row] - acc;
                // Clamp non-positive pivots for numerical stability.
                l[row][col] = if pivot > 0.0 { pivot.sqrt() } else { 1e-12 };
            } else {
                l[row][col] = (a[row][col] - acc) / l[col][col];
            }
        }
    }
    l
}
46+
47+
/// Forward substitution: solve `L x = b` for lower-triangular `L`.
fn forward_solve(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    let mut x: Vec<f64> = Vec::with_capacity(n);
    for i in 0..n {
        // `x` holds exactly the first `i` solved components, so the zip
        // naturally stops at column `i`.
        let acc: f64 = l[i].iter().zip(x.iter()).map(|(c, v)| c * v).sum();
        x.push((b[i] - acc) / l[i][i]);
    }
    x
}
57+
58+
/// Backward substitution: solve `L^T x = b`, where `L` is lower
/// triangular (so `L^T` is upper triangular and we walk rows bottom-up).
fn backward_solve_lt(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    let mut x = vec![0.0; n];
    let mut i = n;
    while i > 0 {
        i -= 1;
        // Row i of L^T is column i of L below the diagonal.
        let mut acc = 0.0;
        for j in (i + 1)..n {
            acc += l[j][i] * x[j];
        }
        x[i] = (b[i] - acc) / l[i][i];
    }
    x
}
68+
69+
/// Solve `L L^T x = b`: forward substitution for `L y = b`, then
/// backward substitution for `L^T x = y` (both passes inlined).
fn chol_solve_lower(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    // Forward pass: y solves L y = b.
    let mut y = vec![0.0; n];
    for i in 0..n {
        let s: f64 = l[i].iter().zip(y.iter()).take(i).map(|(a, v)| a * v).sum();
        y[i] = (b[i] - s) / l[i][i];
    }
    // Backward pass: x solves L^T x = y.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let s: f64 = ((i + 1)..n).map(|j| l[j][i] * x[j]).sum();
        x[i] = (y[i] - s) / l[i][i];
    }
    x
}
74+
75+
/// Sample beta ~ N(A^{-1}b, sigma2 * A^{-1}) using Cholesky of precision A.
76+
///
77+
/// Algorithm (matches R bsts):
78+
/// 1. L L^T = A (Cholesky of precision matrix)
79+
/// 2. y = L^{-1} b (forward solve)
80+
/// 3. mean = L^{-T} y (backward solve)
81+
/// 4. z ~ N(0, I_k)
82+
/// 5. eps = sqrt(sigma2) * L^{-T} z (backward solve)
83+
/// 6. return mean + eps
84+
pub fn sample_from_precision<R: Rng>(
85+
rng: &mut R,
86+
precision: &[Vec<f64>],
87+
rhs: &[f64],
88+
sigma2_obs: f64,
89+
) -> Vec<f64> {
90+
let k = rhs.len();
91+
let l = cholesky_lower(precision);
92+
let mean = chol_solve_lower(&l, rhs);
93+
let z: Vec<f64> = (0..k).map(|_| rng.sample(StandardNormal)).collect();
94+
let scale = sigma2_obs.sqrt();
95+
let eps = backward_solve_lt(&l, &z);
96+
mean.iter()
97+
.zip(eps.iter())
98+
.map(|(m, e)| m + scale * e)
99+
.collect()
58100
}
59101

60102
#[cfg(test)]
@@ -86,21 +128,176 @@ mod tests {
86128
);
87129
}
88130

89-
#[test]
90-
fn test_mvnormal_dimension() {
91-
let mut rng = StdRng::seed_from_u64(42);
92-
let mean = vec![1.0, 2.0];
93-
let cov = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
94-
let sample = sample_mvnormal(&mut rng, &mean, &cov);
95-
assert_eq!(sample.len(), 2);
96-
}
97-
98131
#[test]
fn test_cholesky_identity() {
    // The identity matrix factors to itself.
    let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let factor = cholesky_lower(&eye);
    assert!((factor[0][0] - 1.0).abs() < 1e-12);
    assert!((factor[1][1] - 1.0).abs() < 1e-12);
    assert!((factor[1][0]).abs() < 1e-12);
}
139+
140+
#[test]
fn test_cholesky_2x2() {
    // [[4, 2], [2, 3]] factors as [[2, 0], [1, sqrt(2)]].
    let m = vec![vec![4.0, 2.0], vec![2.0, 3.0]];
    let factor = cholesky_lower(&m);
    assert!((factor[0][0] - 2.0).abs() < 1e-12);
    assert!((factor[1][0] - 1.0).abs() < 1e-12);
    assert!((factor[1][1] - 2.0_f64.sqrt()).abs() < 1e-12);
}
149+
150+
#[test]
fn test_cholesky_near_singular() {
    // Nearly rank-deficient input: the trailing pivot is ~0 after the
    // subtraction, exercising the 1e-12 clamp instead of producing NaN.
    let off = 1.0 - 1e-14;
    let a = vec![vec![1.0, off], vec![off, 1.0]];
    let factor = cholesky_lower(&a);
    for val in factor.iter().flatten() {
        assert!(val.is_finite(), "Cholesky result must be finite");
    }
}
162+
163+
#[test]
fn test_chol_solve_identity() {
    // With L = I the system I I^T x = b reduces to x = b.
    let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let rhs = vec![3.0, 7.0];
    let solution = chol_solve_lower(&identity, &rhs);
    assert!((solution[0] - 3.0).abs() < 1e-12);
    assert!((solution[1] - 7.0).abs() < 1e-12);
}
172+
173+
#[test]
fn test_chol_solve_2x2() {
    // A = [[4, 2], [2, 3]], b = [10, 8].
    // A^{-1} = [[3/8, -1/4], [-1/4, 1/2]], so x = A^{-1} b = [1.75, 1.5].
    let factor = cholesky_lower(&vec![vec![4.0, 2.0], vec![2.0, 3.0]]);
    let solution = chol_solve_lower(&factor, &[10.0, 8.0]);
    assert!((solution[0] - 1.75).abs() < 1e-10);
    assert!((solution[1] - 1.5).abs() < 1e-10);
}
185+
186+
#[test]
fn test_chol_solve_1x1() {
    // Scalar system: 5x = 15 => x = 3.
    let factor = cholesky_lower(&vec![vec![5.0]]);
    let solution = chol_solve_lower(&factor, &[15.0]);
    assert!((solution[0] - 3.0).abs() < 1e-12);
}
193+
194+
#[test]
fn test_sample_from_precision_1x1() {
    // Scalar case: precision = 2, rhs = 6, sigma2 = 0.5, so the target
    // distribution has mean rhs/precision = 3.0 and variance
    // sigma2/precision = 0.25.
    let mut rng = StdRng::seed_from_u64(42);
    let n = 10_000;
    let precision = vec![vec![2.0]];
    let rhs = vec![6.0];
    let sigma2 = 0.5;
    let draws: Vec<f64> = (0..n)
        .map(|_| sample_from_precision(&mut rng, &precision, &rhs, sigma2)[0])
        .collect();
    let sample_mean = draws.iter().sum::<f64>() / n as f64;
    let sample_var =
        draws.iter().map(|d| d * d).sum::<f64>() / n as f64 - sample_mean * sample_mean;
    assert!(
        (sample_mean - 3.0).abs() < 0.1,
        "Mean {sample_mean} should be near 3.0"
    );
    assert!(
        (sample_var - 0.25).abs() < 0.1,
        "Variance {sample_var} should be near 0.25"
    );
}
221+
222+
#[test]
fn test_sample_from_precision_diagonal() {
    // With a diagonal precision the components are independent:
    // precision = diag(4, 9), rhs = [12, 27], sigma2 = 1.0
    // => mean = [3, 3], variance = [1/4, 1/9].
    let mut rng = StdRng::seed_from_u64(123);
    let n = 20_000;
    let precision = vec![vec![4.0, 0.0], vec![0.0, 9.0]];
    let rhs = vec![12.0, 27.0];
    let sigma2 = 1.0;
    let mut sum = [0.0_f64; 2];
    let mut sum_sq = [0.0_f64; 2];
    for _ in 0..n {
        let draw = sample_from_precision(&mut rng, &precision, &rhs, sigma2);
        for (j, value) in draw.iter().enumerate() {
            sum[j] += value;
            sum_sq[j] += value * value;
        }
    }
    for j in 0..2 {
        let mean = sum[j] / n as f64;
        let var = sum_sq[j] / n as f64 - mean * mean;
        assert!(
            (mean - 3.0).abs() < 0.1,
            "Component {j}: mean {mean} should be near 3.0"
        );
        let expected_var = sigma2 / precision[j][j];
        assert!(
            (var - expected_var).abs() < 0.1,
            "Component {j}: var {var} should be near {expected_var}"
        );
    }
}
255+
256+
#[test]
fn test_sample_from_precision_identity() {
    // Identity precision: mean = rhs = [5, -3], cov = sigma2 * I = 2 * I.
    let mut rng = StdRng::seed_from_u64(99);
    let n = 10_000;
    let precision = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let rhs = vec![5.0, -3.0];
    let sigma2 = 2.0;
    let mut totals = [0.0_f64; 2];
    for _ in 0..n {
        let draw = sample_from_precision(&mut rng, &precision, &rhs, sigma2);
        totals[0] += draw[0];
        totals[1] += draw[1];
    }
    let mean0 = totals[0] / n as f64;
    let mean1 = totals[1] / n as f64;
    assert!((mean0 - 5.0).abs() < 0.2, "Mean[0] {mean0} should be near 5");
    assert!(
        (mean1 - (-3.0)).abs() < 0.2,
        "Mean[1] {mean1} should be near -3"
    );
}
280+
281+
#[test]
fn test_sample_from_precision_finite_k20() {
    // Tridiagonal 20x20 precision matrix; every sampled component must be finite.
    let mut rng = StdRng::seed_from_u64(42);
    let k = 20;
    let mut precision = vec![vec![0.0; k]; k];
    for i in 0..k {
        precision[i][i] = 10.0;
    }
    for i in 1..k {
        precision[i][i - 1] = 0.1;
        precision[i - 1][i] = 0.1;
    }
    let rhs: Vec<f64> = (0..k).map(|i| i as f64).collect();
    let sigma2 = 1.0;
    for _ in 0..100 {
        let draw = sample_from_precision(&mut rng, &precision, &rhs, sigma2);
        for (j, val) in draw.iter().enumerate() {
            assert!(val.is_finite(), "k=20 sample[{j}] is not finite: {val}");
        }
    }
}
106303
}

0 commit comments

Comments
 (0)