-
Notifications
You must be signed in to change notification settings - Fork 286
Expand file tree
/
Copy pathtest_mcmc.py
More file actions
1305 lines (1093 loc) · 46.1 KB
/
test_mcmc.py
File metadata and controls
1305 lines (1093 loc) · 46.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import os
import sys
import numpy as np
from numpy.testing import assert_allclose
import pytest
import jax
from jax import device_get, jit, lax, pmap, random, vmap
import jax.numpy as jnp
from jax.scipy.special import logit
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import AIES, ESS, HMC, MCMC, NUTS, SA, BarkerMH, init_to_value
from numpyro.infer.hmc import hmc
from numpyro.infer.reparam import TransformReparam
from numpyro.infer.sa import _get_proposal_loc_and_scale, _numpy_delete
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect, is_prng_key
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS])
@pytest.mark.parametrize("dense_mass", [False, True])
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """Sample a 1-D normal specified only through its potential function.

    The potential is the negative log-density (up to a constant) of a normal
    with mean 1.0 and std 0.5. The empirical mean and std of the draws must
    match those values within 7%.
    """
    true_mean, true_std = 1.0, 0.5
    # SA gets a much larger warmup/sampling budget than the other kernels.
    if kernel_cls is SA:
        num_warmup, num_samples = 100000, 100000
    else:
        num_warmup, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    shared_mcmc_kwargs = dict(
        num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    if kernel_cls in [AIES, ESS]:
        # AIES/ESS run 10 vectorized chains, with one initial value per chain.
        # They are constructed without ``dense_mass``.
        num_chains = 10
        kernel = kernel_cls(potential_fn=potential_fn)
        init_params = random.normal(random.key(1), (num_chains,))
        mcmc = MCMC(
            kernel,
            num_chains=num_chains,
            chain_method="vectorized",
            **shared_mcmc_kwargs,
        )
    else:
        kernel_kwargs = {"potential_fn": potential_fn, "dense_mass": dense_mass}
        if kernel_cls not in [SA, BarkerMH]:
            # Only HMC/NUTS are given a trajectory length.
            kernel_kwargs["trajectory_length"] = 8
        kernel = kernel_cls(**kernel_kwargs)
        init_params = jnp.array(0.0)
        mcmc = MCMC(kernel, **shared_mcmc_kwargs)

    mcmc.run(random.key(0), init_params=init_params)
    mcmc.print_summary()
    draws = mcmc.get_samples()
    assert_allclose(jnp.mean(draws), true_mean, rtol=0.07)
    assert_allclose(jnp.std(draws), true_std, rtol=0.07)

    if "JAX_ENABLE_X64" in os.environ:
        assert draws.dtype == jnp.float64
@pytest.mark.parametrize("regularize", [True, False])
def test_correlated_mvn(regularize):
    """NUTS with a dense mass matrix on a correlated 5-D zero-mean Gaussian.

    Checks the overall sample mean and the mean absolute error of the
    empirical covariance against the true covariance.
    """
    # This requires dense mass matrix estimation.
    D = 5
    num_warmup, num_samples = 5000, 8000
    true_mean = 0.0
    # Lower-triangular factor of the target covariance.
    factor = jnp.tril(
        0.5 * jnp.fliplr(jnp.eye(D))
        + 0.1 * jnp.exp(random.normal(random.key(0), shape=(D, D)))
    )
    true_cov = jnp.dot(factor, factor.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    kernel = NUTS(
        potential_fn=potential_fn, dense_mass=True, regularize_mass_matrix=regularize
    )
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.key(0), init_params=jnp.zeros(D))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    cov_abs_error = np.sum(np.abs(np.cov(samples.T) - true_cov))
    assert cov_abs_error / D**2 < 0.02
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS])
def test_logistic_regression_x64(kernel_cls):
    """Recover logistic-regression coefficients [1, 2, 3] with each kernel.

    Also checks that the deterministic ``logits`` site is collected with
    shape ``(num_samples, N)``.
    """
    if kernel_cls in [AIES, ESS] and "CI" in os.environ:
        pytest.skip("reduce time for CI.")

    N, dim = 3000, 3
    key1, key2, key3 = random.split(random.key(0), 3)
    data = random.normal(key1, (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    labels = dist.Bernoulli(logits=jnp.sum(true_coefs * data, axis=-1)).sample(key2)

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls in [AIES, ESS]:
        num_chains = 16 if kernel_cls is AIES else 10
        samples_each_chain = 8000
        num_warmup = 10_000
        # ``num_samples`` is the total across chains, used by the shape assert.
        num_samples = samples_each_chain * num_chains
        mcmc = MCMC(
            kernel_cls(model),
            num_warmup=num_warmup,
            num_samples=samples_each_chain,
            progress_bar=False,
            num_chains=num_chains,
            chain_method="vectorized",
        )
    else:
        if kernel_cls is SA:
            num_warmup, num_samples = 100000, 100000
            kernel = SA(model=model, adapt_state_size=9)
        elif kernel_cls is BarkerMH:
            num_warmup, num_samples = 2000, 12000
            kernel = BarkerMH(model=model)
        else:
            num_warmup, num_samples = 1000, 8000
            kernel = kernel_cls(
                model=model, trajectory_length=8, find_heuristic_step_size=True
            )
        mcmc = MCMC(
            kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
        )

    mcmc.run(key3, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.4)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
@pytest.mark.parametrize("forward_mode_differentiation", [True, False])
def test_uniform_normal(forward_mode_differentiation):
    """NUTS on a reparameterized uniform-scale model, including warmup reuse.

    Exercises ``warmup(collect_warmup=True)``, a following ``run``, and a
    second ``run`` that resumes from ``last_state``.
    """
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), AffineTransform(0, alpha)
                ),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.key(0), (1000,))
    sampler = MCMC(
        NUTS(model=model, forward_mode_differentiation=forward_mode_differentiation),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Warmup phase: draws are collected and the post-warmup state is stored.
    sampler.warmup(random.key(2), data, collect_warmup=True)
    assert sampler.post_warmup_state is not None
    warmup_samples = sampler.get_samples()

    # Sampling phase.
    sampler.run(random.key(3), data)
    samples = sampler.get_samples()
    assert len(warmup_samples["loc"]) == num_warmup
    assert len(samples["loc"]) == num_samples
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)

    # Continue from where the previous run stopped.
    sampler.post_warmup_state = sampler.last_state
    sampler.run(random.key(3), data)
    samples = sampler.get_samples()
    assert len(samples["loc"]) == num_samples
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)
@pytest.mark.parametrize("max_tree_depth", [10, (5, 10)])
def test_improper_normal(max_tree_depth):
    """NUTS with a masked base distribution for ``loc``.

    ``max_tree_depth`` is given either as an int or as a pair.
    """
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            masked_base = dist.Uniform(0, 1).mask(False)
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(masked_base, AffineTransform(0, alpha)),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.key(0), (1000,))
    sampler = MCMC(
        NUTS(model=model, max_tree_depth=max_tree_depth),
        num_warmup=1000,
        num_samples=1000,
    )
    sampler.run(random.key(0), data)
    loc_draws = sampler.get_samples()["loc"]
    assert_allclose(jnp.mean(loc_draws, 0), true_coef, atol=0.007)
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS])
def test_beta_bernoulli_x64(kernel_cls):
    """Recover Bernoulli probabilities [0.9, 0.1] under a Beta(1.1, 1.1) prior."""
    if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ:
        pytest.skip("The test is flaky on CI x64.")
    if kernel_cls is ESS and "CI" in os.environ:
        pytest.skip("reduce time for CI.")

    if kernel_cls is SA:
        num_warmup, num_samples = 100000, 100000
    else:
        num_warmup, num_samples = 500, 20000

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.key(1), (1000,))

    mcmc_kwargs = dict(
        num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    if kernel_cls in [AIES, ESS]:
        # AIES/ESS are run with 10 vectorized chains.
        kernel = kernel_cls(model=model)
        mcmc_kwargs.update(num_chains=10, chain_method="vectorized")
    elif kernel_cls is SA:
        kernel = SA(model=model)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, **mcmc_kwargs)

    mcmc.run(random.key(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, BarkerMH])
@pytest.mark.parametrize("dense_mass", [False, True])
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Recover categorical probabilities [0.1, 0.6, 0.3] under a flat Dirichlet prior."""
    num_warmup, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.key(1), (2000,))

    # BarkerMH takes no trajectory length; HMC/NUTS use 1.0.
    if kernel_cls is not BarkerMH:
        kernel = kernel_cls(model, trajectory_length=1.0, dense_mass=dense_mass)
    else:
        kernel = BarkerMH(model=model, dense_mass=dense_mass)

    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.key(2), data)
    p_draws = mcmc.get_samples()["p_latent"]
    assert_allclose(jnp.mean(p_draws, 0), true_probs, atol=0.02)

    if "JAX_ENABLE_X64" in os.environ:
        assert p_draws.dtype == jnp.float64
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, BarkerMH])
@pytest.mark.parametrize("rho", [-0.7, 0.8])
def test_dense_mass(kernel_cls, rho):
    """Check the adapted dense mass matrix against a known 2-D covariance.

    The inverse of ``mass_matrix_sqrt @ mass_matrix_sqrt.T`` taken from the
    final adaptation state must match the target covariance, and the sample
    moments must be consistent with it.
    """
    num_warmup, num_samples = 20000, 10000
    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
        )

    is_hmc_family = kernel_cls is HMC or kernel_cls is NUTS
    if is_hmc_family:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.key(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if is_hmc_family:
        # For HMC/NUTS the value is a mapping indexed by a tuple of site names.
        mass_matrix_sqrt = mass_matrix_sqrt[("x",)]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    first, second = samples[:, 0], samples[:, 1]
    assert_allclose(jnp.mean(first), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(second), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(first * second), jnp.array(rho), atol=0.20)
    assert_allclose(jnp.var(samples, axis=0), jnp.array([10.0, 0.1]), rtol=0.20)
def test_change_point_x64():
    """Poisson change-point model: the posterior mode of the switch index is 44."""
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    if sys.version_info.minor == 9:
        pytest.skip("Skip test on Python 3.9")

    num_warmup, num_samples = 1000, 3000

    def model(data):
        alpha = 1 / jnp.mean(data.astype(np.float32))
        lambda1 = numpyro.sample("lambda1", dist.Exponential(alpha))
        lambda2 = numpyro.sample("lambda2", dist.Exponential(alpha))
        tau = numpyro.sample("tau", dist.Uniform(0, 1))
        # Rate is lambda1 before the (fractional) change point and lambda2 after.
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample("obs", dist.Poisson(lambda12), obs=data)

    # fmt: off
    count_data = jnp.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6, 19, 12, 22,
        12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
        15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 1, 20, 12, 35, 17, 23, 17, 4, 2,
        31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22])
    # fmt: on

    init_strategy = init_to_value(values={"lambda1": 1, "lambda2": 72})
    mcmc = MCMC(
        NUTS(model=model, init_strategy=init_strategy),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )
    mcmc.run(random.key(4), count_data)
    samples = mcmc.get_samples()

    # Convert the fractional change point into an integer index and find its mode.
    tau_posterior = (samples["tau"] * len(count_data)).astype(jnp.int32)
    tau_values, counts = np.unique(tau_posterior, return_counts=True)
    mode = tau_values[jnp.argmax(counts)]
    assert mode == 44

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["lambda1"].dtype == jnp.float64
        assert samples["lambda2"].dtype == jnp.float64
        assert samples["tau"].dtype == jnp.float64
# NOTE: the parameter values must be real booleans. The strings "True" and
# "False" are both truthy, so with strings the ``probs`` branch of the model
# below was never exercised.
@pytest.mark.parametrize("with_logits", [True, False])
def test_binomial_stable_x64(with_logits):
    """NUTS stays numerically stable for a Binomial with a huge total count.

    The likelihood is parameterized either through ``logits`` or through
    ``probs`` depending on ``with_logits``. The posterior mean of ``p`` must
    match the empirical rate ``x / n`` within 5%.
    """
    if "CI" in os.environ and "JAX_ENABLE_X64" in os.environ:
        pytest.skip("The test is flaky on CI x64.")
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    num_warmup, num_samples = 200, 200

    def model(data):
        p = numpyro.sample("p", dist.Beta(1.0, 1.0))
        if with_logits:
            logits = logit(p)
            numpyro.sample(
                "obs", dist.Binomial(data["n"], logits=logits), obs=data["x"]
            )
        else:
            numpyro.sample("obs", dist.Binomial(data["n"], probs=p), obs=data["x"])

    data = {"n": 5000000, "x": 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.key(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p"], 0), data["x"] / data["n"], rtol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p"].dtype == jnp.float64
def test_improper_prior():
    """Recover mean and std of normal data under masked/improper priors."""
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        positive_improper = dist.ImproperUniform(dist.constraints.positive, (), ())
        std = numpyro.sample("std", positive_improper)
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.key(1), (2000,))
    sampler = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples)
    # Warmup and sampling are invoked separately, with the same key.
    sampler.warmup(random.key(2), data)
    sampler.run(random.key(2), data)

    posterior = sampler.get_samples()
    assert_allclose(jnp.mean(posterior["mean"]), true_mean, rtol=0.05)
    assert_allclose(jnp.mean(posterior["std"]), true_std, rtol=0.05)
def test_mcmc_progbar():
    """Results must agree with and without the progress bar for equal keys.

    A sampler using the default progress-bar setting and one built with
    ``progress_bar=False`` are compared. A bare ``run`` with a different key
    sequence must differ. With identical warmup/run keys, both the samples
    and the post-warmup state must match.
    """
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        std = numpyro.sample("std", dist.LogNormal(0, 1).mask(False))
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    def check_trees_close(left, right):
        # ``assert_allclose`` raises on the first mismatching pair of leaves.
        jax.tree.all(
            jax.tree.map(
                partial(assert_allclose, atol=1e-4, rtol=1e-4),
                left,
                right,
            )
        )

    def unwrap_keys(state):
        # Replace PRNG key leaves with their raw key data so they can be compared.
        return jax.tree.map(
            lambda x: random.key_data(x) if is_prng_key(x) else x,
            state,
        )

    data = dist.Normal(true_mean, true_std).sample(random.key(1), (2000,))
    kernel = NUTS(model=model)

    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.key(2), data)
    mcmc.run(random.key(3), data)

    mcmc1 = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc1.run(random.key(2), data)

    # Different key usage: the draws are expected to disagree.
    with pytest.raises(AssertionError):
        check_trees_close(mcmc1.get_samples(), mcmc.get_samples())

    # Same key sequence as ``mcmc``: draws and post-warmup state must agree.
    mcmc1.warmup(random.key(2), data)
    mcmc1.run(random.key(3), data)
    check_trees_close(mcmc1.get_samples(), mcmc.get_samples())
    check_trees_close(
        unwrap_keys(mcmc1.post_warmup_state),
        unwrap_keys(mcmc.post_warmup_state),
    )
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS])
@pytest.mark.parametrize("adapt_step_size", [True, False])
def test_diverging(kernel_cls, adapt_step_size):
    """Count divergences when starting from a step size of 10.0.

    With step-size adaptation, the total number of divergences is bounded by
    the warmup length. Without it, every transition must diverge.
    """
    data = random.normal(random.key(0), (1000,))

    def model(data):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(
        model, step_size=10.0, adapt_step_size=adapt_step_size, adapt_mass_matrix=False
    )
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.warmup(random.key(1), data, extra_fields=["diverging"], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()["diverging"].sum()
    mcmc.run(random.key(2), data, extra_fields=["diverging"])
    sampling_divergences = mcmc.get_extra_fields()["diverging"].sum()
    num_divergences = warmup_divergences + sampling_divergences

    if not adapt_step_size:
        assert_allclose(num_divergences, num_warmup + num_samples)
    else:
        assert num_divergences <= num_warmup
def test_prior_with_sample_shape():
    """A site drawn with ``sample_shape`` keeps that shape in the MCMC output."""
    data = {
        "J": 8,
        "y": jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample("mu", dist.Normal(0, 5))
        tau = numpyro.sample("tau", dist.HalfCauchy(5))
        theta = numpyro.sample("theta", dist.Normal(mu, tau), sample_shape=(data["J"],))
        numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])

    num_samples = 500
    sampler = MCMC(NUTS(schools_model), num_warmup=500, num_samples=num_samples)
    sampler.run(random.key(0))
    theta_draws = sampler.get_samples()["theta"]
    assert theta_draws.shape == (num_samples, data["J"])
@pytest.mark.parametrize("num_chains", [1, 2])
@pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
@pytest.mark.parametrize("progress_bar", [True, False])
@pytest.mark.filterwarnings("ignore:There are not enough devices:UserWarning")
def test_empty_model(num_chains, chain_method, progress_bar):
    """A model without any sample site yields an empty sample dict."""

    def model():
        pass

    mcmc_options = dict(
        num_warmup=10,
        num_samples=10,
        num_chains=num_chains,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )
    sampler = MCMC(NUTS(model), **mcmc_options)
    sampler.run(random.key(0))
    assert sampler.get_samples() == {}
@pytest.mark.parametrize("use_init_params", [False, True])
@pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
@pytest.mark.skipif(
    "XLA_FLAGS" not in os.environ,
    reason="without this mark, we have duplicated tests in Travis",
)
def test_chain(use_init_params, chain_method):
    """Two-chain logistic regression under each chain method.

    Verifies flat and chain-grouped sample shapes and the recovered
    coefficients, optionally passing per-chain initial parameters.
    """
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.key(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.key(1))

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        numpyro.deterministic("logits", logits)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(
        NUTS(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    )
    # The chain method is set on the instance after construction.
    mcmc.chain_method = chain_method

    if use_init_params:
        # One row of ones per chain.
        init_params = {
            "coefs": jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)
        }
    else:
        init_params = None
    mcmc.run(random.key(2), labels, init_params=init_params)

    samples_flat = mcmc.get_samples()
    assert samples_flat["coefs"].shape[0] == num_chains * num_samples
    grouped = mcmc.get_samples(group_by_chain=True)
    assert grouped["coefs"].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(samples_flat["coefs"], 0), true_coefs, atol=0.21)

    # test if reshape works
    device_get(samples_flat["coefs"].reshape(-1))
@pytest.mark.parametrize("kernel_cls", [HMC, NUTS])
@pytest.mark.parametrize(
    "chain_method",
    [
        pytest.param(
            "parallel",
            marks=pytest.mark.xfail(reason="jit+pmap does not work in CPU yet"),
        ),
        "sequential",
        "vectorized",
    ],
)
@pytest.mark.skipif(
    "CI" in os.environ, reason="Compiling time the whole sampling process is slow."
)
def test_chain_inside_jit(kernel_cls, chain_method):
    """Build and run a two-chain MCMC entirely inside a jitted function."""
    # NB: this feature is useful for consensus MC.
    # Caution: compiling time will be slow (~ 90s)
    if chain_method == "parallel" and jax.device_count() == 1:
        pytest.skip("parallel method requires device_count greater than 1.")

    num_warmup, num_samples = 100, 2000
    # These are the settings that are currently supported as jit arguments.
    rng_key = random.key(2)
    step_size = 1.0
    target_accept_prob = 0.8
    trajectory_length = 1.0
    # Not supported yet:
    #   + adapt_step_size
    #   + adapt_mass_matrix
    #   + max_tree_depth
    #   + num_warmup
    #   + num_samples

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    @jit
    def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
        kernel = kernel_cls(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            target_accept_prob=target_accept_prob,
        )
        sampler = MCMC(
            kernel,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=2,
            chain_method=chain_method,
            progress_bar=False,
        )
        sampler.run(rng_key, data)
        return sampler.get_samples()

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.key(1), (2000,))
    samples = get_samples(
        rng_key, data, step_size, trajectory_length, target_accept_prob
    )
    posterior_mean = jnp.mean(samples["p_latent"], 0)
    assert_allclose(posterior_mean, true_probs, atol=0.02)
@pytest.mark.parametrize("chain_method", ["sequential", "parallel", "vectorized"])
@pytest.mark.parametrize("compile_args", [False, True])
@pytest.mark.skipif(
    "CI" in os.environ, reason="Compiling time the whole sampling process is slow."
)
def test_chain_jit_args_smoke(chain_method, compile_args):
    """Smoke test: warmup and run on one dataset, then run again on another."""

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    first_probs = jnp.array([0.1, 0.6, 0.3])
    second_probs = jnp.array([0.2, 0.4, 0.4])
    data1 = dist.Categorical(first_probs).sample(random.key(1), (50,))
    data2 = dist.Categorical(second_probs).sample(random.key(1), (50,))

    sampler = MCMC(
        NUTS(model),
        num_warmup=2,
        num_samples=5,
        num_chains=2,
        chain_method=chain_method,
        jit_model_args=compile_args,
    )
    sampler.warmup(random.key(0), data1)
    sampler.run(random.key(1), data1)
    # this should be fast if jit_model_args=True
    sampler.run(random.key(2), data2)
def test_extra_fields():
    """Requested extra fields are collected with shape ``(chains, samples)``."""

    def model():
        numpyro.sample("x", dist.Normal(0, 1), sample_shape=(5,))

    requested = ("num_steps", "adapt_state.step_size")
    sampler = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
    sampler.run(random.key(0), extra_fields=requested)

    grouped_samples = sampler.get_samples(group_by_chain=True)
    assert grouped_samples["x"].shape == (1, 1000, 5)

    stats = sampler.get_extra_fields(group_by_chain=True)
    assert "num_steps" in stats
    assert stats["num_steps"].shape == (1, 1000)
    assert "adapt_state.step_size" in stats
    assert stats["adapt_state.step_size"].shape == (1, 1000)
@pytest.mark.parametrize("algo", ["HMC", "NUTS"])
def test_functional_beta_bernoulli_x64(algo):
    """Beta-Bernoulli inference through the functional ``hmc`` interface."""
    num_warmup, num_samples = 410, 100

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.key(1), (1000, 2))
    init_params, potential_fn, constrain_fn, _ = initialize_model(
        random.key(2), model, model_args=(data,)
    )
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1.0, num_warmup=num_warmup)

    def to_constrained(state):
        # Map the state's ``z`` field through ``constrain_fn`` before collecting.
        return constrain_fn(state.z)

    samples = fori_collect(
        0, num_samples, sample_kernel, hmc_state, transform=to_constrained
    )
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64
@pytest.mark.parametrize("algo", ["HMC", "NUTS"])
@pytest.mark.parametrize("map_fn", [vmap, pmap])
@pytest.mark.skipif(
    "XLA_FLAGS" not in os.environ,
    reason="without this mark, we have duplicated tests in Travis",
)
def test_functional_map(algo, map_fn):
    """Run two chains of the functional ``hmc`` API under ``vmap`` or ``pmap``."""
    if map_fn is pmap and jax.device_count() == 1:
        pytest.skip("pmap test requires device_count greater than 1.")

    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0.0, -1.0])
    rng_keys = random.split(random.key(0), 2)

    def init_one_chain(init_param, rng_key):
        return init_kernel(
            init_param, trajectory_length=9, num_warmup=num_warmup, rng_key=rng_key
        )

    def collect_one_chain(hmc_state):
        return fori_collect(
            0,
            num_samples,
            sample_kernel,
            hmc_state,
            transform=lambda x: x.z,
            progbar=False,
        )

    # Map initialization and collection over the two chains.
    init_states = map_fn(init_one_chain)(init_params, rng_keys)
    chain_samples = map_fn(collect_one_chain)(init_states)

    assert_allclose(
        jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06
    )
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2), rtol=0.06)
@pytest.mark.parametrize("jit_args", [False, True])
@pytest.mark.parametrize("shape", [50, 100])
def test_reuse_mcmc_run(jit_args, shape):
    """The same MCMC object can be run again on a second dataset.

    The result of the second run must reflect the second dataset, which is
    centered at -3.
    """
    y1 = np.random.normal(3, 0.1, (100,))
    y2 = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.0))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    sampler = MCMC(
        NUTS(model), num_warmup=300, num_samples=500, jit_model_args=jit_args
    )
    # First run, on the data centered at 3.
    sampler.run(random.key(32), y1)
    # Re-run on new data - should be much faster.
    sampler.run(random.key(32), y2)

    mu_draws = sampler.get_samples()["mu"]
    assert_allclose(mu_draws.mean(), -3.0, atol=0.1)
@pytest.mark.parametrize("jit_args", [False, True])
def test_model_with_multiple_exec_paths(jit_args):
    """Which sites are sampled depends on which arguments are ``None``.

    The same MCMC object is run three times with different argument
    combinations. Each run must return exactly the sites that were active.
    """

    def model(a=None, b=None, z=None):
        int_term = numpyro.sample("a", dist.Normal(0.0, 0.2))
        x_term, y_term = 0.0, 0.0
        if a is not None:
            x = numpyro.sample("x", dist.HalfNormal(0.5))
            x_term = a * x
        if b is not None:
            y = numpyro.sample("y", dist.HalfNormal(0.5))
            y_term = b * y
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        mu = int_term + x_term + y_term
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=z)

    a = jnp.exp(np.random.randn(10))
    b = jnp.exp(np.random.randn(10))
    z = np.random.randn(10)

    sampler = MCMC(NUTS(model), num_warmup=20, num_samples=10, jit_model_args=jit_args)

    # Only the ``x`` path is active.
    sampler.run(random.key(1), a, b=None, z=z)
    assert set(sampler.get_samples()) == {"a", "x", "sigma"}

    # Only the ``y`` path is active.
    sampler.run(random.key(2), a=None, b=b, z=z)
    assert set(sampler.get_samples()) == {"a", "y", "sigma"}

    # Both paths are active.
    sampler.run(random.key(3), a=a, b=b, z=z)
    assert set(sampler.get_samples()) == {"a", "x", "y", "sigma"}
def test_mcmc_inside_jit_no_tracer_leak():
    """Regression test for https://github.com/pyro-ppl/numpyro/issues/2000"""
    from numpyro.infer.mcmc import _collect_and_postprocess
    from numpyro.util import fori_collect

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)

    @jit
    def get_samples(rng_key, data):
        sampler = MCMC(
            HMC(model, step_size=1.0, trajectory_length=1.0, target_accept_prob=0.8),
            num_warmup=5,
            num_samples=10,
            num_chains=1,
            chain_method="sequential",
            progress_bar=False,
        )
        sampler.run(rng_key, data)
        return sampler.get_samples()

    probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(probs).sample(random.key(1), (100,))
    samples = get_samples(random.key(2), data)
    assert "p_latent" in samples

    # Verify no traced values leaked into module-level caches
    for cached_fn in [_collect_and_postprocess, fori_collect]:
        # Functions without a ``_cache`` attribute are treated as having an empty cache.
        cache = getattr(cached_fn, "_cache", {})
        cache_entries = list(cache.keys()) + list(cache.values())
        for leaf in jax.tree.leaves(cache_entries):
            assert not isinstance(leaf, jax.core.Tracer), (
                f"Tracer leaked into {cached_fn.__name__}._cache"
            )
def test_reuse_mcmc_run_stable_partial_identity():
    """Regression test: repeated run() calls must reuse the same partial object.

    When pmap is implemented via jit(shard_map) (JAX >= 0.8.0), jit caches
    by function identity. Building a fresh functools.partial on every run()
    call triggers a new trace + XLA compilation whose artifacts are never
    freed, so memory grows without bound in long-running services.
    """

    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    sampler = MCMC(
        NUTS(model),
        num_warmup=5,
        num_samples=5,
        num_chains=1,
        progress_bar=False,
    )

    sampler.run(random.key(0))
    partial_after_first_run = sampler._partial_map_fn
    assert partial_after_first_run is not None

    sampler.run(random.key(1))
    assert sampler._partial_map_fn is partial_after_first_run, (
        "_partial_map_fn must be the same object across run() calls "
        "to avoid pmap/jit recompilation leaks"
    )
@pytest.mark.parametrize("num_chains", [1, 2])
@pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
@pytest.mark.parametrize("progress_bar", [True, False])
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    """``_compile`` + ``warmup`` + ``run`` reproduces a single ``run`` call.

    For multiple chains, the first chain must additionally match a
    single-chain run started from the first split of the key.
    """

    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    if num_chains == 1 and chain_method in ["sequential", "vectorized"]:
        pytest.skip("duplicated test")
    if num_chains > 1 and chain_method == "parallel":
        pytest.skip("duplicated test")

    rng_key = random.key(0)
    num_samples = 10
    mcmc = MCMC(
        NUTS(model),
        num_warmup=10,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    # Reference result: one combined warmup+sampling call.
    mcmc.run(rng_key)
    expected_samples = mcmc.get_samples()["x"]

    mcmc._compile(rng_key)
    # no delay after compiling
    mcmc.warmup(rng_key)
    # Continue sampling with the key stored in the state left by warmup.
    mcmc.run(mcmc.last_state.rng_key)
    actual_samples = mcmc.get_samples()["x"]
    assert_allclose(actual_samples, expected_samples)

    # test for reproducible
    if num_chains > 1:
        mcmc = MCMC(
            NUTS(model),
            num_warmup=10,
            num_samples=num_samples,
            num_chains=1,
            progress_bar=progress_bar,
        )
        rng_key = random.split(rng_key)[0]
        mcmc.run(rng_key)
        first_chain_samples = mcmc.get_samples()["x"]
        assert_allclose(actual_samples[:num_samples], first_chain_samples, atol=1e-5)
@pytest.mark.parametrize("dense_mass", [True, False])
def test_get_proposal_loc_and_scale(dense_mass):
    """Compare ``_get_proposal_loc_and_scale`` with explicit leave-one-out statistics.

    The expected values are the mean and scale of the ten samples with each
    of the first nine rows removed in turn.
    """
    N = 10
    dim = 3
    samples = random.normal(random.key(0), (N, dim))

    def scale_of(points):
        # Dense: Cholesky factor of the biased covariance; diagonal: per-dim std.
        if dense_mass:
            return jnp.linalg.cholesky(jnp.cov(points, rowvar=False, bias=True))
        return jnp.std(points, 0)

    head, new_sample = samples[:-1], samples[-1]
    loc = jnp.mean(head, 0)
    scale = scale_of(head)
    actual_loc, actual_scale = _get_proposal_loc_and_scale(
        head, loc, scale, new_sample
    )

    leave_one_out = [np.delete(samples, i, axis=0) for i in range(N - 1)]
    expected_loc = jnp.stack([jnp.mean(subset, 0) for subset in leave_one_out])
    expected_scale = jnp.stack([scale_of(subset) for subset in leave_one_out])

    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.3)
@pytest.mark.parametrize("shape", [(4,), (3, 2)])
@pytest.mark.parametrize("idx", [0, 1, 2])
def test_numpy_delete(shape, idx):
    """``_numpy_delete`` agrees with ``np.delete`` along axis 0."""
    values = random.normal(random.key(0), shape)
    reference = np.delete(values, idx, axis=0)
    assert_allclose(_numpy_delete(values, idx), reference)
@pytest.mark.parametrize("batch_shape", [(), (4,)])
def test_trivial_dirichlet(batch_shape):
    """A Dirichlet with a single component always samples the value 1."""

    def model():
        x = numpyro.sample("x", dist.Dirichlet(jnp.ones(1)).expand(batch_shape))
        return numpyro.sample("y", dist.Normal(x, 1), obs=2)

    num_samples = 10
    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
    sampler.run(random.key(0))

    # because event_shape of x is (1,), x should only take value 1
    expected = jnp.ones((num_samples,) + batch_shape + (1,))
    assert_allclose(sampler.get_samples()["x"], expected)
def test_forward_mode_differentiation():
def model():
x = numpyro.sample("x", dist.Normal(0, 1))