5
5
import jax
6
6
import jax .numpy as jnp
7
7
import emcee
8
+ import logging
8
9
9
10
10
11
def ln_analytic_evidence (ndim , cov ):
@@ -99,29 +100,33 @@ def run_example(
99
100
inv_cov = jnp .linalg .inv (cov )
100
101
training_proportion = 0.5
101
102
if flow_type == "RealNVP" :
102
- epochs_num = 5
103
+ epochs_num = 10 # 5
103
104
elif flow_type == "RQSpline" :
104
- epochs_num = 3
105
+ #epochs_num = 5
106
+ epochs_num = 110
105
107
106
108
# Beginning of path where plots will be saved
107
109
save_name_start = "examples/plots/" + flow_type
108
110
109
- temperature = 0.8
111
+ temperature = 0.9
110
112
standardize = True
111
113
verbose = True
112
-
114
+
113
115
# Spline params
114
- n_layers = 5
115
- n_bins = 5
116
+ n_layers = 3
117
+ n_bins = 128
116
118
hidden_size = [32 , 32 ]
117
119
spline_range = (- 10.0 , 10.0 )
118
120
121
+ if flow_type == "RQSpline" :
122
+ save_name_start += "_" + str (n_layers ) + "l_" + str (n_bins ) + "b_" + str (epochs_num ) + "e_" + str (int (training_proportion * 100 )) + "perc_" + str (temperature ) + "T" + "_emcee"
123
+
119
124
# Start timer.
120
125
clock = time .process_time ()
121
126
122
127
# Run multiple realisations.
123
- n_realisations = 1
124
- evidence_inv_summary = np .zeros ((n_realisations , 3 ))
128
+ n_realisations = 100
129
+ ln_evidence_inv_summary = np .zeros ((n_realisations , 5 ))
125
130
for i_realisation in range (n_realisations ):
126
131
if n_realisations > 0 :
127
132
hm .logs .info_log (
@@ -130,7 +135,7 @@ def run_example(
130
135
# Define the number of dimensions and the mean of the Gaussian
131
136
num_samples = nchains * samples_per_chain
132
137
# Initialize a PRNG key (you can use any valid key)
133
- key = jax .random .PRNGKey (0 )
138
+ key = jax .random .PRNGKey (i_realisation )
134
139
mean = jnp .zeros (ndim )
135
140
136
141
# Generate random samples from the 2D Gaussian distribution
@@ -139,7 +144,7 @@ def run_example(
139
144
samples = jnp .reshape (samples , (nchains , - 1 , ndim ))
140
145
lnprob = jnp .reshape (lnprob , (nchains , - 1 ))
141
146
142
- MCMC = False
147
+ MCMC = True
143
148
if MCMC :
144
149
nburn = 500
145
150
# Set up and run sampler.
@@ -151,7 +156,7 @@ def run_example(
151
156
rstate = np .random .get_state () # Set random state to repeatable
152
157
# across calls.
153
158
(pos , prob , state ) = sampler .run_mcmc (
154
- pos , samples_per_chain , rstate0 = rstate
159
+ pos , samples_per_chain , rstate0 = rstate , progress = True
155
160
)
156
161
samples = np .ascontiguousarray (sampler .chain [:, nburn :, :])
157
162
lnprob = np .ascontiguousarray (sampler .lnprobability [:, nburn :])
@@ -191,92 +196,68 @@ def run_example(
191
196
ev = hm .Evidence (chains_test .nchains , model )
192
197
# ev.set_mean_shift(0.0)
193
198
ev .add_chains (chains_test )
194
- ln_evidence , ln_evidence_std = ev .compute_ln_evidence ()
199
+ err_ln_inv_evidence = ev .compute_ln_inv_evidence_errors ()
195
200
196
201
# Compute analytic evidence.
197
202
if i_realisation == 0 :
198
203
ln_evidence_analytic = ln_analytic_evidence (ndim , cov )
199
204
200
- # ======================================================================
201
- # Display evidence computation results.
202
- # ======================================================================
203
205
hm .logs .info_log ("---------------------------------" )
206
+ hm .logs .info_log ("The inverse evidence in log space is:" )
204
207
hm .logs .info_log (
205
- "Evidence: analytic = {}, estimated = {}" .format (
206
- np . exp ( ln_evidence_analytic ), np . exp ( ln_evidence )
208
+ "ln_inv_evidence = {} +/- {}" .format (
209
+ ev . ln_evidence_inv , err_ln_inv_evidence
207
210
)
208
211
)
209
212
hm .logs .info_log (
210
- "Evidence: std = {}, std / estimate = {}" .format (
211
- np . exp ( ln_evidence_std ), np . exp ( ln_evidence_std - ln_evidence )
213
+ "ln evidence = {} +/- {} {}" .format (
214
+ - ev . ln_evidence_inv , - err_ln_inv_evidence [ 1 ], - err_ln_inv_evidence [ 0 ]
212
215
)
213
216
)
214
- diff = np .log (np .abs (np .exp (ln_evidence_analytic ) - np .exp (ln_evidence )))
217
+ hm .logs .info_log ("Analytic ln evidence is {}" .format (ln_evidence_analytic ))
218
+ delta = - ln_evidence_analytic - ev .ln_evidence_inv
215
219
hm .logs .info_log (
216
- "Evidence: |analytic - estimate| / estimate = {}" .format (
217
- np .exp (diff - ln_evidence )
218
- )
219
- )
220
- # ======================================================================
221
- # Display inverse evidence computation results.
222
- # ======================================================================
223
- hm .logs .debug_log ("---------------------------------" )
224
- hm .logs .debug_log (
225
- "Inv Evidence: analytic = {}, estimate = {}" .format (
226
- np .exp (- ln_evidence_analytic ), ev .evidence_inv
227
- )
228
- )
229
- hm .logs .debug_log (
230
- "Inv Evidence: std = {}, std / estimate = {}" .format (
231
- np .sqrt (ev .evidence_inv_var ),
232
- np .sqrt (ev .evidence_inv_var ) / ev .evidence_inv ,
233
- )
234
- )
235
- hm .logs .debug_log (
236
- "Inv Evidence: kurtosis = {}, sqrt( 2 / ( n_eff - 1 ) ) = {}" .format (
237
- ev .kurtosis , np .sqrt (2.0 / (ev .n_eff - 1 ))
238
- )
239
- )
240
- hm .logs .debug_log (
241
- "Inv Evidence: sqrt( var(var) ) / var = {}" .format (
242
- np .sqrt (ev .evidence_inv_var_var ) / ev .evidence_inv_var
220
+ "Difference between analytic and harmonic is {} +- {} {}" .format (
221
+ delta , err_ln_inv_evidence [0 ], err_ln_inv_evidence [1 ]
243
222
)
244
223
)
224
+
225
+ hm .logs .info_log ("kurtosis = {}" .format (ev .kurtosis ))
226
+ hm .logs .info_log (" Aim for ~3." )
227
+ check = np .exp (0.5 * ev .ln_evidence_inv_var_var - ev .ln_evidence_inv_var )
228
+ hm .logs .info_log ("sqrt( var(var) ) / var = {}" .format (check ))
245
229
hm .logs .info_log (
246
- "Inv Evidence: |analytic - estimate| / estimate = {}" .format (
247
- np .abs (np .exp (- ln_evidence_analytic ) - ev .evidence_inv )
248
- / ev .evidence_inv
249
- )
230
+ " Aim for sqrt( 2/(n_eff-1) ) = {}" .format (np .sqrt (2.0 / (ev .n_eff - 1 )))
250
231
)
251
232
252
233
# ===========================================================================
253
234
# Display more technical details
254
235
# ===========================================================================
255
- hm .logs .debug_log ("---------------------------------" )
256
- hm .logs .debug_log ("Technical Details" )
257
- hm .logs .debug_log ("---------------------------------" )
258
- hm .logs .debug_log (
236
+ hm .logs .info_log ("---------------------------------" )
237
+ hm .logs .info_log ("Technical Details" )
238
+ hm .logs .info_log ("---------------------------------" )
239
+ hm .logs .info_log (
259
240
"lnargmax = {}, lnargmin = {}" .format (ev .lnargmax , ev .lnargmin )
260
241
)
261
- hm .logs .debug_log (
242
+ hm .logs .info_log (
262
243
"lnprobmax = {}, lnprobmin = {}" .format (ev .lnprobmax , ev .lnprobmin )
263
244
)
264
- hm .logs .debug_log (
245
+ hm .logs .info_log (
265
246
"lnpredictmax = {}, lnpredictmin = {}" .format (
266
247
ev .lnpredictmax , ev .lnpredictmin
267
248
)
268
249
)
269
- hm .logs .debug_log ("---------------------------------" )
270
- hm .logs .debug_log (
250
+ hm .logs .info_log ("---------------------------------" )
251
+ hm .logs .info_log (
271
252
"shift = {}, shift setting = {}" .format (ev .shift_value , ev .shift )
272
253
)
273
- hm .logs .debug_log ("running sum total = {}" .format (sum (ev .running_sum )))
274
- hm .logs .debug_log ("running sum = \n {}" .format (ev .running_sum ))
275
- hm .logs .debug_log ("nsamples per chain = \n {}" .format (ev .nsamples_per_chain ))
276
- hm .logs .debug_log (
254
+ hm .logs .info_log ("running sum total = {}" .format (sum (ev .running_sum )))
255
+ hm .logs .info_log ("running sum = \n {}" .format (ev .running_sum ))
256
+ hm .logs .info_log ("nsamples per chain = \n {}" .format (ev .nsamples_per_chain ))
257
+ hm .logs .info_log (
277
258
"nsamples eff per chain = \n {}" .format (ev .nsamples_eff_per_chain )
278
259
)
279
- hm .logs .debug_log ("===============================" )
260
+ hm .logs .info_log ("===============================" )
280
261
281
262
# ======================================================================
282
263
# Create corner/triangle plot.
@@ -314,28 +295,31 @@ def run_example(
314
295
315
296
plt .show ()
316
297
317
- evidence_inv_summary [i_realisation , 0 ] = ev .evidence_inv
318
- evidence_inv_summary [i_realisation , 1 ] = ev .evidence_inv_var
319
- evidence_inv_summary [i_realisation , 2 ] = ev .evidence_inv_var_var
298
+ # Save out realisations for violin plot.
299
+ ln_evidence_inv_summary [i_realisation , 0 ] = ev .ln_evidence_inv
300
+ ln_evidence_inv_summary [i_realisation , 1 ] = err_ln_inv_evidence [0 ]
301
+ ln_evidence_inv_summary [i_realisation , 2 ] = err_ln_inv_evidence [1 ]
302
+ ln_evidence_inv_summary [i_realisation , 3 ] = ev .ln_evidence_inv_var
303
+ ln_evidence_inv_summary [i_realisation , 4 ] = ev .ln_evidence_inv_var_var
320
304
321
305
clock = time .process_time () - clock
322
306
hm .logs .info_log ("Execution_time = {}s" .format (clock ))
323
307
324
308
if n_realisations > 1 :
325
309
save_name = (
326
310
save_name_start
327
- + "_gaussian_nondiagcov_evidence_inv "
311
+ + "_gaussian_nondiagcov_ln_evidence_inv "
328
312
+ "_realisations_{}D.dat" .format (ndim )
329
313
)
330
- np .savetxt (save_name , evidence_inv_summary )
331
- evidence_inv_analytic_summary = np .zeros (1 )
332
- evidence_inv_analytic_summary [0 ] = np . exp ( - ln_evidence_analytic )
314
+ np .savetxt (save_name , ln_evidence_inv_summary )
315
+ ln_evidence_inv_analytic_summary = np .zeros (1 )
316
+ ln_evidence_inv_analytic_summary [0 ] = - ln_evidence_analytic
333
317
save_name = (
334
318
save_name_start
335
- + "_gaussian_nondiagcov_evidence_inv "
319
+ + "_gaussian_nondiagcov_ln_evidence_inv "
336
320
+ "_analytic_{}D.dat" .format (ndim )
337
321
)
338
- np .savetxt (save_name , evidence_inv_analytic_summary )
322
+ np .savetxt (save_name , ln_evidence_inv_analytic_summary )
339
323
340
324
created_plots = True
341
325
if created_plots :
@@ -344,14 +328,14 @@ def run_example(
344
328
345
329
if __name__ == "__main__" :
346
330
# Setup logging config.
347
- hm .logs .setup_logging ()
331
+ hm .logs .setup_logging (default_level = logging . DEBUG )
348
332
349
333
# Define parameters.
350
- ndim = 5
351
- nchains = 100
334
+ ndim = 21
335
+ nchains = 80
352
336
samples_per_chain = 5000
353
- flow_str = "RealNVP"
354
- # flow_str = "RQSpline"
337
+ # flow_str = "RealNVP"
338
+ flow_str = "RQSpline"
355
339
np .random .seed (10 ) # used for initializing covariance matrix
356
340
357
341
hm .logs .info_log ("Non-diagonal Covariance Gaussian example" )
@@ -365,4 +349,4 @@ def run_example(
365
349
hm .logs .debug_log ("-------------------------" )
366
350
367
351
# Run example.
368
- run_example (flow_str , ndim , nchains , samples_per_chain , plot_corner = False )
352
+ run_example (flow_str , ndim , nchains , samples_per_chain , plot_corner = True )
0 commit comments