Skip to content

Commit e9a4366

Browse files
refactor: apply refactor-cleaner review findings
- Remove panic="abort" from [profile.release] to preserve PyO3's panic catch mechanism (prevents Python process crash on Rust panic) - Eliminate xtx pre.clone() by building posterior_precision directly from xtx_ref + prior_precision (avoids k*k Vec clone per iteration) - Change xtx_precomputed type from Option<&Vec<Vec<f64>>> to Option<&[Vec<f64>]> (idiomatic Rust, use .as_deref() at call sites) - Skip xtx_static computation when spike-and-slab is active (coordinate- wise sampling does not use XtX) - Remove extra blank lines left from scale_matrix deletion
1 parent ed00fd0 commit e9a4366

2 files changed

Lines changed: 20 additions & 18 deletions

File tree

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,3 @@ rand_distr = "0.4"
1616
[profile.release]
1717
lto = "thin"
1818
codegen-units = 1
19-
panic = "abort"

src/sampler.rs

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ fn run_single_chain_dynamic(
319319
sigma2_obs,
320320
&prior.beta_mean,
321321
&prior.beta_precision,
322-
xtx_seasonal.as_ref(),
322+
xtx_seasonal.as_deref(),
323323
);
324324
}
325325

@@ -422,9 +422,9 @@ fn run_single_chain_static(
422422
};
423423
let g = pre_end as f64; // Zellner's g-prior parameter
424424

425-
// Pre-compute X^TX for static and seasonal regression.
426-
// X is constant across all Gibbs iterations, so this is computed once.
427-
let xtx_static = if k > 0 {
425+
// Pre-compute X^TX for static regression (not needed when spike-and-slab
426+
// is active because it uses coordinate-wise sampling without XtX).
427+
let xtx_static = if k > 0 && !use_spike_slab {
428428
Some(cross_product_matrix(model.covariates(), pre_end))
429429
} else {
430430
None
@@ -513,7 +513,7 @@ fn run_single_chain_static(
513513
sigma2_obs,
514514
&prior.beta_mean,
515515
&prior.beta_precision,
516-
xtx_seasonal.as_ref(),
516+
xtx_seasonal.as_deref(),
517517
);
518518
}
519519
} else if k > 0 || k_seasonal > 0 {
@@ -534,7 +534,7 @@ fn run_single_chain_static(
534534
sigma2_obs,
535535
&prior.beta_mean,
536536
&prior.beta_precision,
537-
xtx_static.as_ref(),
537+
xtx_static.as_deref(),
538538
);
539539
for gj in gamma.iter_mut() {
540540
*gj = true;
@@ -553,7 +553,7 @@ fn run_single_chain_static(
553553
sigma2_obs,
554554
&prior.beta_mean,
555555
&prior.beta_precision,
556-
xtx_seasonal.as_ref(),
556+
xtx_seasonal.as_deref(),
557557
);
558558
}
559559
let sigma_guess = static_regression_prior
@@ -821,13 +821,17 @@ fn sample_beta_with_normal_prior<R: rand::Rng>(
821821
sigma2_obs: f64,
822822
prior_mean: &[f64],
823823
prior_precision: &[Vec<f64>],
824-
xtx_precomputed: Option<&Vec<Vec<f64>>>,
824+
xtx_precomputed: Option<&[Vec<f64>]>,
825825
) -> Vec<f64> {
826826
let k = x.len();
827827
// Use pre-computed X^TX if available (avoids O(k^2 T) per iteration)
828-
let xtx = match xtx_precomputed {
829-
Some(pre) => pre.clone(),
830-
None => cross_product_matrix(x, y_pre.len()),
828+
let xtx_owned;
829+
let xtx_ref: &[Vec<f64>] = match xtx_precomputed {
830+
Some(pre) => pre,
831+
None => {
832+
xtx_owned = cross_product_matrix(x, y_pre.len());
833+
&xtx_owned
834+
}
831835
};
832836

833837
let residuals: Vec<f64> = y_pre
@@ -846,10 +850,11 @@ fn sample_beta_with_normal_prior<R: rand::Rng>(
846850
.sum();
847851
}
848852

849-
let mut posterior_precision = xtx;
850-
for (i, row) in posterior_precision.iter_mut().enumerate() {
851-
for (j, value) in row.iter_mut().enumerate() {
852-
*value += prior_precision[i][j];
853+
// Build posterior_precision = XtX + prior_precision without cloning XtX
854+
let mut posterior_precision = vec![vec![0.0; k]; k];
855+
for i in 0..k {
856+
for j in 0..k {
857+
posterior_precision[i][j] = xtx_ref[i][j] + prior_precision[i][j];
853858
}
854859
}
855860

@@ -1139,8 +1144,6 @@ fn matrix_vector_product(matrix: &[Vec<f64>], vector: &[f64]) -> Vec<f64> {
11391144
.collect()
11401145
}
11411146

1142-
1143-
11441147
fn sample_post_period_states<R: rand::Rng>(
11451148
rng: &mut R,
11461149
last_pre_state: f64,

0 commit comments

Comments
 (0)