@@ -15,46 +15,88 @@ pub fn sample_normal<R: Rng>(rng: &mut R, mean: f64, variance: f64) -> f64 {
1515 mean + std * z
1616}
1717
18- pub fn sample_mvnormal < R : Rng > ( rng : & mut R , mean : & [ f64 ] , cov : & [ Vec < f64 > ] ) -> Vec < f64 > {
19- let lower_cholesky = cholesky ( cov) ;
20- let z: Vec < f64 > = mean. iter ( ) . map ( |_| rng. sample ( StandardNormal ) ) . collect ( ) ;
21-
22- let mut result = mean. to_vec ( ) ;
23- for ( i, value) in result. iter_mut ( ) . enumerate ( ) {
24- * value += lower_cholesky[ i]
25- . iter ( )
26- . zip ( z. iter ( ) )
27- . take ( i + 1 )
28- . map ( |( lhs, rhs) | lhs * rhs)
29- . sum :: < f64 > ( ) ;
30- }
31- result
32- }
33-
34- fn cholesky ( a : & [ Vec < f64 > ] ) -> Vec < Vec < f64 > > {
18+ /// Cholesky decomposition: A = L L^T, returns L (lower triangular).
19+ /// A must be symmetric positive definite.
20+ /// Near-zero or negative diagonals are clamped to 1e-12 for numerical stability.
21+ fn cholesky_lower ( a : & [ Vec < f64 > ] ) -> Vec < Vec < f64 > > {
3522 let k = a. len ( ) ;
36- let mut lower_cholesky = vec ! [ vec![ 0.0 ; k] ; k] ;
23+ let mut lower = vec ! [ vec![ 0.0 ; k] ; k] ;
3724 for i in 0 ..k {
3825 for j in 0 ..=i {
39- let sum = lower_cholesky [ i]
26+ let sum = lower [ i]
4027 . iter ( )
41- . zip ( lower_cholesky [ j] . iter ( ) )
28+ . zip ( lower [ j] . iter ( ) )
4229 . take ( j)
4330 . map ( |( lhs, rhs) | lhs * rhs)
4431 . sum :: < f64 > ( ) ;
4532 if i == j {
4633 let diagonal = a[ i] [ i] - sum;
47- lower_cholesky [ i] [ j] = if diagonal > 0.0 {
34+ lower [ i] [ j] = if diagonal > 0.0 {
4835 diagonal. sqrt ( )
4936 } else {
5037 1e-12
5138 } ;
5239 } else {
53- lower_cholesky [ i] [ j] = ( a[ i] [ j] - sum) / lower_cholesky [ j] [ j] ;
40+ lower [ i] [ j] = ( a[ i] [ j] - sum) / lower [ j] [ j] ;
5441 }
5542 }
5643 }
57- lower_cholesky
44+ lower
45+ }
46+
47+ /// Solve L x = b via forward substitution (L is lower triangular).
48+ fn forward_solve ( l : & [ Vec < f64 > ] , b : & [ f64 ] ) -> Vec < f64 > {
49+ let k = b. len ( ) ;
50+ let mut x = vec ! [ 0.0 ; k] ;
51+ for i in 0 ..k {
52+ let sum: f64 = l[ i] . iter ( ) . zip ( x. iter ( ) ) . take ( i) . map ( |( a, b) | a * b) . sum ( ) ;
53+ x[ i] = ( b[ i] - sum) / l[ i] [ i] ;
54+ }
55+ x
56+ }
57+
58+ /// Solve L^T x = b via backward substitution (L is lower triangular).
59+ fn backward_solve_lt ( l : & [ Vec < f64 > ] , b : & [ f64 ] ) -> Vec < f64 > {
60+ let k = b. len ( ) ;
61+ let mut x = vec ! [ 0.0 ; k] ;
62+ for i in ( 0 ..k) . rev ( ) {
63+ let sum: f64 = ( ( i + 1 ) ..k) . map ( |j| l[ j] [ i] * x[ j] ) . sum ( ) ;
64+ x[ i] = ( b[ i] - sum) / l[ i] [ i] ;
65+ }
66+ x
67+ }
68+
69+ /// Solve L L^T x = b via forward + backward substitution.
70+ fn chol_solve_lower ( l : & [ Vec < f64 > ] , b : & [ f64 ] ) -> Vec < f64 > {
71+ let y = forward_solve ( l, b) ;
72+ backward_solve_lt ( l, & y)
73+ }
74+
75+ /// Sample beta ~ N(A^{-1}b, sigma2 * A^{-1}) using Cholesky of precision A.
76+ ///
77+ /// Algorithm (matches R bsts):
78+ /// 1. L L^T = A (Cholesky of precision matrix)
79+ /// 2. y = L^{-1} b (forward solve)
80+ /// 3. mean = L^{-T} y (backward solve)
81+ /// 4. z ~ N(0, I_k)
82+ /// 5. eps = sqrt(sigma2) * L^{-T} z (backward solve)
83+ /// 6. return mean + eps
84+ pub fn sample_from_precision < R : Rng > (
85+ rng : & mut R ,
86+ precision : & [ Vec < f64 > ] ,
87+ rhs : & [ f64 ] ,
88+ sigma2_obs : f64 ,
89+ ) -> Vec < f64 > {
90+ let k = rhs. len ( ) ;
91+ let l = cholesky_lower ( precision) ;
92+ let mean = chol_solve_lower ( & l, rhs) ;
93+ let z: Vec < f64 > = ( 0 ..k) . map ( |_| rng. sample ( StandardNormal ) ) . collect ( ) ;
94+ let scale = sigma2_obs. sqrt ( ) ;
95+ let eps = backward_solve_lt ( & l, & z) ;
96+ mean. iter ( )
97+ . zip ( eps. iter ( ) )
98+ . map ( |( m, e) | m + scale * e)
99+ . collect ( )
58100}
59101
60102#[ cfg( test) ]
@@ -86,21 +128,176 @@ mod tests {
86128 ) ;
87129 }
88130
89- #[ test]
90- fn test_mvnormal_dimension ( ) {
91- let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
92- let mean = vec ! [ 1.0 , 2.0 ] ;
93- let cov = vec ! [ vec![ 1.0 , 0.0 ] , vec![ 0.0 , 1.0 ] ] ;
94- let sample = sample_mvnormal ( & mut rng, & mean, & cov) ;
95- assert_eq ! ( sample. len( ) , 2 ) ;
96- }
97-
98131 #[ test]
99132 fn test_cholesky_identity ( ) {
100133 let a = vec ! [ vec![ 1.0 , 0.0 ] , vec![ 0.0 , 1.0 ] ] ;
101- let l = cholesky ( & a) ;
134+ let l = cholesky_lower ( & a) ;
102135 assert ! ( ( l[ 0 ] [ 0 ] - 1.0 ) . abs( ) < 1e-12 ) ;
103136 assert ! ( ( l[ 1 ] [ 1 ] - 1.0 ) . abs( ) < 1e-12 ) ;
104137 assert ! ( ( l[ 1 ] [ 0 ] ) . abs( ) < 1e-12 ) ;
105138 }
139+
140+ #[ test]
141+ fn test_cholesky_2x2 ( ) {
142+ // A = [[4, 2], [2, 3]] => L = [[2, 0], [1, sqrt(2)]]
143+ let a = vec ! [ vec![ 4.0 , 2.0 ] , vec![ 2.0 , 3.0 ] ] ;
144+ let l = cholesky_lower ( & a) ;
145+ assert ! ( ( l[ 0 ] [ 0 ] - 2.0 ) . abs( ) < 1e-12 ) ;
146+ assert ! ( ( l[ 1 ] [ 0 ] - 1.0 ) . abs( ) < 1e-12 ) ;
147+ assert ! ( ( l[ 1 ] [ 1 ] - 2.0_f64 . sqrt( ) ) . abs( ) < 1e-12 ) ;
148+ }
149+
150+ #[ test]
151+ fn test_cholesky_near_singular ( ) {
152+ // Near-singular: diagonal element becomes near-zero after subtraction
153+ let a = vec ! [ vec![ 1.0 , 1.0 - 1e-14 ] , vec![ 1.0 - 1e-14 , 1.0 ] ] ;
154+ let l = cholesky_lower ( & a) ;
155+ // Should not panic, result should be finite
156+ for row in & l {
157+ for val in row {
158+ assert ! ( val. is_finite( ) , "Cholesky result must be finite" ) ;
159+ }
160+ }
161+ }
162+
163+ #[ test]
164+ fn test_chol_solve_identity ( ) {
165+ // L = I => solve I I^T x = b => x = b
166+ let l = vec ! [ vec![ 1.0 , 0.0 ] , vec![ 0.0 , 1.0 ] ] ;
167+ let b = vec ! [ 3.0 , 7.0 ] ;
168+ let x = chol_solve_lower ( & l, & b) ;
169+ assert ! ( ( x[ 0 ] - 3.0 ) . abs( ) < 1e-12 ) ;
170+ assert ! ( ( x[ 1 ] - 7.0 ) . abs( ) < 1e-12 ) ;
171+ }
172+
173+ #[ test]
174+ fn test_chol_solve_2x2 ( ) {
175+ // A = [[4, 2], [2, 3]], b = [10, 8]
176+ // A^{-1} = [[3/8, -1/4], [-1/4, 1/2]]
177+ // x = A^{-1}b = [3/8*10 + (-1/4)*8, (-1/4)*10 + 1/2*8] = [1.75, 1.5]
178+ let a = vec ! [ vec![ 4.0 , 2.0 ] , vec![ 2.0 , 3.0 ] ] ;
179+ let l = cholesky_lower ( & a) ;
180+ let b = vec ! [ 10.0 , 8.0 ] ;
181+ let x = chol_solve_lower ( & l, & b) ;
182+ assert ! ( ( x[ 0 ] - 1.75 ) . abs( ) < 1e-10 ) ;
183+ assert ! ( ( x[ 1 ] - 1.5 ) . abs( ) < 1e-10 ) ;
184+ }
185+
186+ #[ test]
187+ fn test_chol_solve_1x1 ( ) {
188+ // k=1: scalar case. A = [[5]], b = [15] => x = 3
189+ let l = cholesky_lower ( & vec ! [ vec![ 5.0 ] ] ) ;
190+ let x = chol_solve_lower ( & l, & [ 15.0 ] ) ;
191+ assert ! ( ( x[ 0 ] - 3.0 ) . abs( ) < 1e-12 ) ;
192+ }
193+
194+ #[ test]
195+ fn test_sample_from_precision_1x1 ( ) {
196+ // k=1: precision=2, rhs=6, sigma2=0.5
197+ // mean = rhs/precision = 3.0, variance = sigma2/precision = 0.25
198+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
199+ let n = 10_000 ;
200+ let precision = vec ! [ vec![ 2.0 ] ] ;
201+ let rhs = vec ! [ 6.0 ] ;
202+ let sigma2 = 0.5 ;
203+ let mut sum = 0.0 ;
204+ let mut sum_sq = 0.0 ;
205+ for _ in 0 ..n {
206+ let s = sample_from_precision ( & mut rng, & precision, & rhs, sigma2) ;
207+ sum += s[ 0 ] ;
208+ sum_sq += s[ 0 ] * s[ 0 ] ;
209+ }
210+ let sample_mean = sum / n as f64 ;
211+ let sample_var = sum_sq / n as f64 - sample_mean * sample_mean;
212+ assert ! (
213+ ( sample_mean - 3.0 ) . abs( ) < 0.1 ,
214+ "Mean {sample_mean} should be near 3.0"
215+ ) ;
216+ assert ! (
217+ ( sample_var - 0.25 ) . abs( ) < 0.1 ,
218+ "Variance {sample_var} should be near 0.25"
219+ ) ;
220+ }
221+
222+ #[ test]
223+ fn test_sample_from_precision_diagonal ( ) {
224+ // Diagonal precision: each component independent
225+ // precision = diag(4, 9), rhs = [12, 27], sigma2 = 1.0
226+ // mean = [3, 3], variance = [1/4, 1/9]
227+ let mut rng = StdRng :: seed_from_u64 ( 123 ) ;
228+ let n = 20_000 ;
229+ let precision = vec ! [ vec![ 4.0 , 0.0 ] , vec![ 0.0 , 9.0 ] ] ;
230+ let rhs = vec ! [ 12.0 , 27.0 ] ;
231+ let sigma2 = 1.0 ;
232+ let mut sum = vec ! [ 0.0 ; 2 ] ;
233+ let mut sum_sq = vec ! [ 0.0 ; 2 ] ;
234+ for _ in 0 ..n {
235+ let s = sample_from_precision ( & mut rng, & precision, & rhs, sigma2) ;
236+ for j in 0 ..2 {
237+ sum[ j] += s[ j] ;
238+ sum_sq[ j] += s[ j] * s[ j] ;
239+ }
240+ }
241+ for j in 0 ..2 {
242+ let mean = sum[ j] / n as f64 ;
243+ let var = sum_sq[ j] / n as f64 - mean * mean;
244+ assert ! (
245+ ( mean - 3.0 ) . abs( ) < 0.1 ,
246+ "Component {j}: mean {mean} should be near 3.0"
247+ ) ;
248+ let expected_var = sigma2 / precision[ j] [ j] ;
249+ assert ! (
250+ ( var - expected_var) . abs( ) < 0.1 ,
251+ "Component {j}: var {var} should be near {expected_var}"
252+ ) ;
253+ }
254+ }
255+
256+ #[ test]
257+ fn test_sample_from_precision_identity ( ) {
258+ // precision = I, rhs = [5, -3], sigma2 = 2.0
259+ // mean = rhs = [5, -3], cov = 2*I
260+ let mut rng = StdRng :: seed_from_u64 ( 99 ) ;
261+ let n = 10_000 ;
262+ let precision = vec ! [ vec![ 1.0 , 0.0 ] , vec![ 0.0 , 1.0 ] ] ;
263+ let rhs = vec ! [ 5.0 , -3.0 ] ;
264+ let sigma2 = 2.0 ;
265+ let mut sum = vec ! [ 0.0 ; 2 ] ;
266+ for _ in 0 ..n {
267+ let s = sample_from_precision ( & mut rng, & precision, & rhs, sigma2) ;
268+ for j in 0 ..2 {
269+ sum[ j] += s[ j] ;
270+ }
271+ }
272+ let mean0 = sum[ 0 ] / n as f64 ;
273+ let mean1 = sum[ 1 ] / n as f64 ;
274+ assert ! ( ( mean0 - 5.0 ) . abs( ) < 0.2 , "Mean[0] {mean0} should be near 5" ) ;
275+ assert ! (
276+ ( mean1 - ( -3.0 ) ) . abs( ) < 0.2 ,
277+ "Mean[1] {mean1} should be near -3"
278+ ) ;
279+ }
280+
281+ #[ test]
282+ fn test_sample_from_precision_finite_k20 ( ) {
283+ // k=20: verify all samples are finite
284+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
285+ let k = 20 ;
286+ let mut precision = vec ! [ vec![ 0.0 ; k] ; k] ;
287+ for i in 0 ..k {
288+ precision[ i] [ i] = 10.0 ;
289+ if i > 0 {
290+ precision[ i] [ i - 1 ] = 0.1 ;
291+ precision[ i - 1 ] [ i] = 0.1 ;
292+ }
293+ }
294+ let rhs: Vec < f64 > = ( 0 ..k) . map ( |i| i as f64 ) . collect ( ) ;
295+ let sigma2 = 1.0 ;
296+ for _ in 0 ..100 {
297+ let s = sample_from_precision ( & mut rng, & precision, & rhs, sigma2) ;
298+ for ( j, val) in s. iter ( ) . enumerate ( ) {
299+ assert ! ( val. is_finite( ) , "k=20 sample[{j}] is not finite: {val}" ) ;
300+ }
301+ }
302+ }
106303}
0 commit comments