|
56 | 56 | "import inspect\n", |
57 | 57 | "import numpy as np\n", |
58 | 58 | "import matplotlib.pyplot as plt\n", |
59 | | - "from cuqi.sampler import MH, CWMH, ULA, MALA, NUTS\n", |
| 59 | + "from cuqi.experimental.mcmc import MH, CWMH, ULA, MALA, NUTS\n", |
60 | 60 | "import time\n", |
61 | 61 | "import scipy.stats as sps\n", |
62 | | - "from scipy.stats import gaussian_kde" |
| 62 | + "from scipy.stats import gaussian_kde\n", |
| 63 | + "cuqi.config.DISABLE_PROGRESS_BAR = True" |
63 | 64 | ] |
64 | 65 | }, |
65 | 66 | { |
|
256 | 257 | "metadata": {}, |
257 | 258 | "outputs": [], |
258 | 259 | "source": [ |
259 | | - "MH_sampler = MH(target_donut, scale=scale, x0=np.array([0,0]))" |
| 260 | + "MH_sampler = MH(target_donut, scale=scale, initial_point=np.array([0,0]))" |
260 | 261 | ] |
261 | 262 | }, |
262 | 263 | { |
|
278 | 279 | }, |
279 | 280 | "outputs": [], |
280 | 281 | "source": [ |
281 | | - "MH_fixed_samples = MH_sampler.sample(Ns, Nb)" |
| 282 | + "MH_sampler.sample(Ns+Nb)\n", |
| 283 | + "MH_fixed_samples = MH_sampler.get_samples().burnthin(Nb)" |
282 | 284 | ] |
283 | 285 | }, |
284 | 286 | { |
|
297 | 299 | "outputs": [], |
298 | 300 | "source": [ |
299 | 301 | "plot_pdf_2D(target_donut, -4, 4, -4, 4)\n", |
| 302 | + "\n", |
300 | 303 | "MH_fixed_samples.plot_pair(ax=plt.gca())" |
301 | 304 | ] |
302 | 305 | }, |
|
370 | 373 | "source": [ |
371 | 374 | "Ns = 8500\n", |
372 | 375 | "Nb = 1500\n", |
373 | | - "MH_adapted_samples = MH_sampler.sample_adapt(Ns, Nb)" |
| 376 | + "MH_sampler.warmup(Nb)\n", |
| 377 | + "MH_sampler.sample(Ns)\n", |
| 378 | + "MH_adapted_samples = MH_sampler.get_samples().burnthin(Nb)" |
374 | 379 | ] |
375 | 380 | }, |
376 | 381 | { |
|
523 | 528 | }, |
524 | 529 | "outputs": [], |
525 | 530 | "source": [ |
526 | | - "MH_sampler = MH(target_poisson, scale = scale, x0=np.ones(target_poisson.dim))\n", |
527 | | - "MH_samples = MH_sampler.sample_adapt(Ns, Nb)" |
| 531 | + "MH_sampler = MH(target_poisson, scale = scale, initial_point=np.ones(target_poisson.dim))\n", |
| 532 | + "MH_sampler.warmup(Nb)\n", |
| 533 | + "MH_sampler.sample(Ns)\n", |
| 534 | + "MH_samples = MH_sampler.get_samples().burnthin(Nb)" |
528 | 535 | ] |
529 | 536 | }, |
530 | 537 | { |
|
546 | 553 | }, |
547 | 554 | "outputs": [], |
548 | 555 | "source": [ |
549 | | - "CWMH_sampler = CWMH(target_poisson, scale = scale, x0=np.ones(target_poisson.dim))\n", |
550 | | - "CWMH_samples = CWMH_sampler.sample_adapt(Ns, Nb)" |
| 556 | + "CWMH_sampler = CWMH(target_poisson, scale = scale, initial_point=np.ones(target_poisson.dim))\n", |
| 557 | + "CWMH_sampler.warmup(Nb)\n", |
| 558 | + "CWMH_sampler.sample(Ns)\n", |
| 559 | + "CWMH_samples = CWMH_sampler.get_samples().burnthin(Nb)\n" |
551 | 560 | ] |
552 | 561 | }, |
553 | 562 | { |
|
667 | 676 | "metadata": {}, |
668 | 677 | "outputs": [], |
669 | 678 | "source": [ |
670 | | - "ULA_sampler = ULA(target=target_donut, scale=0.065, x0=np.array([0,0]))" |
| 679 | + "ULA_sampler = ULA(target=target_donut, scale=0.065, initial_point=np.array([0,0]))" |
671 | 680 | ] |
672 | 681 | }, |
673 | 682 | { |
|
690 | 699 | "outputs": [], |
691 | 700 | "source": [ |
692 | 701 | "Ns = 1000\n", |
693 | | - "ULA_samples = ULA_sampler.sample(Ns)" |
| 702 | + "ULA_sampler.sample(Ns)\n", |
| 703 | + "ULA_samples = ULA_sampler.get_samples()" |
694 | 704 | ] |
695 | 705 | }, |
696 | 706 | { |
|
828 | 838 | }, |
829 | 839 | "outputs": [], |
830 | 840 | "source": [ |
831 | | - "MALA_uni = MALA(x_uni, scale=1, x0=0)\n", |
| 841 | + "MALA_uni = MALA(x_uni, scale=1, initial_point=0)\n", |
832 | 842 | "\n", |
833 | 843 | "Ns = 40000\n", |
834 | | - "ULA_samples_uni = MALA_uni.sample(Ns)" |
| 844 | + "MALA_uni.sample(Ns)\n", |
| 845 | + "ULA_samples_uni = MALA_uni.get_samples()" |
835 | 846 | ] |
836 | 847 | }, |
837 | 848 | { |
|
978 | 989 | "metadata": {}, |
979 | 990 | "outputs": [], |
980 | 991 | "source": [ |
981 | | - "NUTS_donut = NUTS(target=target_donut, x0=np.array([0,0]))" |
| 992 | + "NUTS_donut = NUTS(target=target_donut, initial_point=np.array([0,0]))" |
982 | 993 | ] |
983 | 994 | }, |
984 | 995 | { |
|
1002 | 1013 | "source": [ |
1003 | 1014 | "Ns = 100\n", |
1004 | 1015 | "Nb = 10\n", |
1005 | | - "NUTS_donuts_samples = NUTS_donut.sample(Ns, Nb)" |
| 1016 | + "NUTS_donut.warmup(Nb)\n", |
| 1017 | + "NUTS_donut.sample(Ns)\n", |
| 1018 | + "NUTS_donuts_samples = NUTS_donut.get_samples().burnthin(Nb)" |
1006 | 1019 | ] |
1007 | 1020 | }, |
1008 | 1021 | { |
|
1098 | 1111 | "notebook_metadata_filter": "-all" |
1099 | 1112 | }, |
1100 | 1113 | "kernelspec": { |
1101 | | - "display_name": "Python 3", |
| 1114 | + "display_name": "fenicsproject", |
1102 | 1115 | "language": "python", |
1103 | 1116 | "name": "python3" |
1104 | 1117 | }, |
|
0 commit comments