diff --git a/sts_jax/structural_time_series/learning.py b/sts_jax/structural_time_series/learning.py index 45e18a9..adeafed 100644 --- a/sts_jax/structural_time_series/learning.py +++ b/sts_jax/structural_time_series/learning.py @@ -162,9 +162,10 @@ def unnorm_log_pos(_unc_params): return lp # Initialize the HMC sampler using window_adaptations - warmup = blackjax.window_adaptation(blackjax.nuts, unnorm_log_pos, num_steps=warmup_steps, progress_bar=verbose) + warmup = blackjax.window_adaptation(blackjax.nuts, unnorm_log_pos, progress_bar=verbose) init_key, key = jr.split(key) - hmc_initial_state, hmc_kernel, _ = warmup.run(init_key, initial_unc_params) + (hmc_initial_state, parameters), _ = warmup.run(init_key, initial_unc_params, num_steps=warmup_steps) + hmc_kernel = blackjax.nuts(unnorm_log_pos, **parameters).step @jit def hmc_step(hmc_state, step_key): @@ -180,7 +181,7 @@ def hmc_step(hmc_state, step_key): for _ in pbar: step_key, key = jr.split(key) hmc_state, params = hmc_step(hmc_state, step_key) - log_probs.append(-hmc_state.potential_energy) + log_probs.append(-hmc_state.logdensity) samples.append(params) # Combine the samples into a single pytree