Skip to content

Commit 55c970e

Browse files
Merge pull request #307 from astro-informatics/20D_gaussian
Add batching of evidence estimation inputs and update examples
2 parents 9726ce2 + 16b72d5 commit 55c970e

File tree

5 files changed

+393
-177
lines changed

5 files changed

+393
-177
lines changed

examples/gaussian_nondiagcov.py

+63-79
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
import emcee
8+
import logging
89

910

1011
def ln_analytic_evidence(ndim, cov):
@@ -99,29 +100,33 @@ def run_example(
99100
inv_cov = jnp.linalg.inv(cov)
100101
training_proportion = 0.5
101102
if flow_type == "RealNVP":
102-
epochs_num = 5
103+
epochs_num = 10 #5
103104
elif flow_type == "RQSpline":
104-
epochs_num = 3
105+
#epochs_num = 5
106+
epochs_num = 110
105107

106108
# Beginning of path where plots will be saved
107109
save_name_start = "examples/plots/" + flow_type
108110

109-
temperature = 0.8
111+
temperature = 0.9
110112
standardize = True
111113
verbose = True
112-
114+
113115
# Spline params
114-
n_layers = 5
115-
n_bins = 5
116+
n_layers = 3
117+
n_bins = 128
116118
hidden_size = [32, 32]
117119
spline_range = (-10.0, 10.0)
118120

121+
if flow_type == "RQSpline":
122+
save_name_start += "_" + str(n_layers) + "l_" + str(n_bins) + "b_" + str(epochs_num) + "e_" + str(int(training_proportion * 100)) + "perc_" + str(temperature) + "T" + "_emcee"
123+
119124
# Start timer.
120125
clock = time.process_time()
121126

122127
# Run multiple realisations.
123-
n_realisations = 1
124-
evidence_inv_summary = np.zeros((n_realisations, 3))
128+
n_realisations = 100
129+
ln_evidence_inv_summary = np.zeros((n_realisations, 5))
125130
for i_realisation in range(n_realisations):
126131
if n_realisations > 0:
127132
hm.logs.info_log(
@@ -130,7 +135,7 @@ def run_example(
130135
# Define the number of dimensions and the mean of the Gaussian
131136
num_samples = nchains * samples_per_chain
132137
# Initialize a PRNG key (you can use any valid key)
133-
key = jax.random.PRNGKey(0)
138+
key = jax.random.PRNGKey(i_realisation)
134139
mean = jnp.zeros(ndim)
135140

136141
# Generate random samples from the 2D Gaussian distribution
@@ -139,7 +144,7 @@ def run_example(
139144
samples = jnp.reshape(samples, (nchains, -1, ndim))
140145
lnprob = jnp.reshape(lnprob, (nchains, -1))
141146

142-
MCMC = False
147+
MCMC = True
143148
if MCMC:
144149
nburn = 500
145150
# Set up and run sampler.
@@ -151,7 +156,7 @@ def run_example(
151156
rstate = np.random.get_state() # Set random state to repeatable
152157
# across calls.
153158
(pos, prob, state) = sampler.run_mcmc(
154-
pos, samples_per_chain, rstate0=rstate
159+
pos, samples_per_chain, rstate0=rstate, progress=True
155160
)
156161
samples = np.ascontiguousarray(sampler.chain[:, nburn:, :])
157162
lnprob = np.ascontiguousarray(sampler.lnprobability[:, nburn:])
@@ -191,92 +196,68 @@ def run_example(
191196
ev = hm.Evidence(chains_test.nchains, model)
192197
# ev.set_mean_shift(0.0)
193198
ev.add_chains(chains_test)
194-
ln_evidence, ln_evidence_std = ev.compute_ln_evidence()
199+
err_ln_inv_evidence = ev.compute_ln_inv_evidence_errors()
195200

196201
# Compute analytic evidence.
197202
if i_realisation == 0:
198203
ln_evidence_analytic = ln_analytic_evidence(ndim, cov)
199204

200-
# ======================================================================
201-
# Display evidence computation results.
202-
# ======================================================================
203205
hm.logs.info_log("---------------------------------")
206+
hm.logs.info_log("The inverse evidence in log space is:")
204207
hm.logs.info_log(
205-
"Evidence: analytic = {}, estimated = {}".format(
206-
np.exp(ln_evidence_analytic), np.exp(ln_evidence)
208+
"ln_inv_evidence = {} +/- {}".format(
209+
ev.ln_evidence_inv, err_ln_inv_evidence
207210
)
208211
)
209212
hm.logs.info_log(
210-
"Evidence: std = {}, std / estimate = {}".format(
211-
np.exp(ln_evidence_std), np.exp(ln_evidence_std - ln_evidence)
213+
"ln evidence = {} +/- {} {}".format(
214+
-ev.ln_evidence_inv, -err_ln_inv_evidence[1], -err_ln_inv_evidence[0]
212215
)
213216
)
214-
diff = np.log(np.abs(np.exp(ln_evidence_analytic) - np.exp(ln_evidence)))
217+
hm.logs.info_log("Analytic ln evidence is {}".format(ln_evidence_analytic))
218+
delta = -ln_evidence_analytic - ev.ln_evidence_inv
215219
hm.logs.info_log(
216-
"Evidence: |analytic - estimate| / estimate = {}".format(
217-
np.exp(diff - ln_evidence)
218-
)
219-
)
220-
# ======================================================================
221-
# Display inverse evidence computation results.
222-
# ======================================================================
223-
hm.logs.debug_log("---------------------------------")
224-
hm.logs.debug_log(
225-
"Inv Evidence: analytic = {}, estimate = {}".format(
226-
np.exp(-ln_evidence_analytic), ev.evidence_inv
227-
)
228-
)
229-
hm.logs.debug_log(
230-
"Inv Evidence: std = {}, std / estimate = {}".format(
231-
np.sqrt(ev.evidence_inv_var),
232-
np.sqrt(ev.evidence_inv_var) / ev.evidence_inv,
233-
)
234-
)
235-
hm.logs.debug_log(
236-
"Inv Evidence: kurtosis = {}, sqrt( 2 / ( n_eff - 1 ) ) = {}".format(
237-
ev.kurtosis, np.sqrt(2.0 / (ev.n_eff - 1))
238-
)
239-
)
240-
hm.logs.debug_log(
241-
"Inv Evidence: sqrt( var(var) ) / var = {}".format(
242-
np.sqrt(ev.evidence_inv_var_var) / ev.evidence_inv_var
220+
"Difference between analytic and harmonic is {} +- {} {}".format(
221+
delta, err_ln_inv_evidence[0], err_ln_inv_evidence[1]
243222
)
244223
)
224+
225+
hm.logs.info_log("kurtosis = {}".format(ev.kurtosis))
226+
hm.logs.info_log(" Aim for ~3.")
227+
check = np.exp(0.5 * ev.ln_evidence_inv_var_var - ev.ln_evidence_inv_var)
228+
hm.logs.info_log("sqrt( var(var) ) / var = {}".format(check))
245229
hm.logs.info_log(
246-
"Inv Evidence: |analytic - estimate| / estimate = {}".format(
247-
np.abs(np.exp(-ln_evidence_analytic) - ev.evidence_inv)
248-
/ ev.evidence_inv
249-
)
230+
" Aim for sqrt( 2/(n_eff-1) ) = {}".format(np.sqrt(2.0 / (ev.n_eff - 1)))
250231
)
251232

252233
# ===========================================================================
253234
# Display more technical details
254235
# ===========================================================================
255-
hm.logs.debug_log("---------------------------------")
256-
hm.logs.debug_log("Technical Details")
257-
hm.logs.debug_log("---------------------------------")
258-
hm.logs.debug_log(
236+
hm.logs.info_log("---------------------------------")
237+
hm.logs.info_log("Technical Details")
238+
hm.logs.info_log("---------------------------------")
239+
hm.logs.info_log(
259240
"lnargmax = {}, lnargmin = {}".format(ev.lnargmax, ev.lnargmin)
260241
)
261-
hm.logs.debug_log(
242+
hm.logs.info_log(
262243
"lnprobmax = {}, lnprobmin = {}".format(ev.lnprobmax, ev.lnprobmin)
263244
)
264-
hm.logs.debug_log(
245+
hm.logs.info_log(
265246
"lnpredictmax = {}, lnpredictmin = {}".format(
266247
ev.lnpredictmax, ev.lnpredictmin
267248
)
268249
)
269-
hm.logs.debug_log("---------------------------------")
270-
hm.logs.debug_log(
250+
hm.logs.info_log("---------------------------------")
251+
hm.logs.info_log(
271252
"shift = {}, shift setting = {}".format(ev.shift_value, ev.shift)
272253
)
273-
hm.logs.debug_log("running sum total = {}".format(sum(ev.running_sum)))
274-
hm.logs.debug_log("running sum = \n{}".format(ev.running_sum))
275-
hm.logs.debug_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
276-
hm.logs.debug_log(
254+
hm.logs.info_log("running sum total = {}".format(sum(ev.running_sum)))
255+
hm.logs.info_log("running sum = \n{}".format(ev.running_sum))
256+
hm.logs.info_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
257+
hm.logs.info_log(
277258
"nsamples eff per chain = \n{}".format(ev.nsamples_eff_per_chain)
278259
)
279-
hm.logs.debug_log("===============================")
260+
hm.logs.info_log("===============================")
280261

281262
# ======================================================================
282263
# Create corner/triangle plot.
@@ -314,28 +295,31 @@ def run_example(
314295

315296
plt.show()
316297

317-
evidence_inv_summary[i_realisation, 0] = ev.evidence_inv
318-
evidence_inv_summary[i_realisation, 1] = ev.evidence_inv_var
319-
evidence_inv_summary[i_realisation, 2] = ev.evidence_inv_var_var
298+
# Save out realisations for violin plot.
299+
ln_evidence_inv_summary[i_realisation, 0] = ev.ln_evidence_inv
300+
ln_evidence_inv_summary[i_realisation, 1] = err_ln_inv_evidence[0]
301+
ln_evidence_inv_summary[i_realisation, 2] = err_ln_inv_evidence[1]
302+
ln_evidence_inv_summary[i_realisation, 3] = ev.ln_evidence_inv_var
303+
ln_evidence_inv_summary[i_realisation, 4] = ev.ln_evidence_inv_var_var
320304

321305
clock = time.process_time() - clock
322306
hm.logs.info_log("Execution_time = {}s".format(clock))
323307

324308
if n_realisations > 1:
325309
save_name = (
326310
save_name_start
327-
+ "_gaussian_nondiagcov_evidence_inv"
311+
+ "_gaussian_nondiagcov_ln_evidence_inv"
328312
+ "_realisations_{}D.dat".format(ndim)
329313
)
330-
np.savetxt(save_name, evidence_inv_summary)
331-
evidence_inv_analytic_summary = np.zeros(1)
332-
evidence_inv_analytic_summary[0] = np.exp(-ln_evidence_analytic)
314+
np.savetxt(save_name, ln_evidence_inv_summary)
315+
ln_evidence_inv_analytic_summary = np.zeros(1)
316+
ln_evidence_inv_analytic_summary[0] = -ln_evidence_analytic
333317
save_name = (
334318
save_name_start
335-
+ "_gaussian_nondiagcov_evidence_inv"
319+
+ "_gaussian_nondiagcov_ln_evidence_inv"
336320
+ "_analytic_{}D.dat".format(ndim)
337321
)
338-
np.savetxt(save_name, evidence_inv_analytic_summary)
322+
np.savetxt(save_name, ln_evidence_inv_analytic_summary)
339323

340324
created_plots = True
341325
if created_plots:
@@ -344,14 +328,14 @@ def run_example(
344328

345329
if __name__ == "__main__":
346330
# Setup logging config.
347-
hm.logs.setup_logging()
331+
hm.logs.setup_logging(default_level=logging.DEBUG)
348332

349333
# Define parameters.
350-
ndim = 5
351-
nchains = 100
334+
ndim = 21
335+
nchains = 80
352336
samples_per_chain = 5000
353-
flow_str = "RealNVP"
354-
# flow_str = "RQSpline"
337+
#flow_str = "RealNVP"
338+
flow_str = "RQSpline"
355339
np.random.seed(10) # used for initializing covariance matrix
356340

357341
hm.logs.info_log("Non-diagonal Covariance Gaussian example")
@@ -365,4 +349,4 @@ def run_example(
365349
hm.logs.debug_log("-------------------------")
366350

367351
# Run example.
368-
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=False)
352+
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=True)

0 commit comments

Comments
 (0)