Skip to content

Commit e16467c

Browse files
committed
Merge branch 'main' into dev
2 parents f2e094d + 0802bbf commit e16467c

File tree

4 files changed

+235
-744
lines changed

4 files changed

+235
-744
lines changed
File renamed without changes.

examples/sampling/test_samp.ipynb

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "2c795b88-378b-489d-85bb-3e1786930b4a",
6+
"metadata": {},
7+
"source": [
8+
"# Test sampling algorithms"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": null,
14+
"id": "eafc9aac-719c-434f-8816-7aae59f171db",
15+
"metadata": {
16+
"execution": {
17+
"iopub.execute_input": "2025-05-08T16:21:49.363904Z",
18+
"iopub.status.busy": "2025-05-08T16:21:49.363456Z",
19+
"iopub.status.idle": "2025-05-08T16:21:50.671986Z",
20+
"shell.execute_reply": "2025-05-08T16:21:50.671739Z"
21+
}
22+
},
23+
"outputs": [],
24+
"source": [
25+
"import matplotlib.pyplot as plt\n",
26+
"import numpy as np\n",
27+
"\n",
28+
"import ment"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": null,
34+
"id": "b2fda8d3",
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"plt.style.use(\"../style.mplstyle\")"
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"id": "e5922304-a9c9-47a7-b09f-7ae4f4845d2d",
44+
"metadata": {},
45+
"source": [
46+
"## Create distribution"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"id": "3b8a7941-17c2-446d-aa36-6a3cb7308de2",
53+
"metadata": {
54+
"execution": {
55+
"iopub.execute_input": "2025-05-08T16:21:50.676745Z",
56+
"iopub.status.busy": "2025-05-08T16:21:50.676655Z",
57+
"iopub.status.idle": "2025-05-08T16:21:52.023604Z",
58+
"shell.execute_reply": "2025-05-08T16:21:52.023261Z"
59+
}
60+
},
61+
"outputs": [],
62+
"source": [
63+
"class RingDistribution:\n",
64+
" def __init__(self) -> None:\n",
65+
" self.ndim = 2\n",
66+
"\n",
67+
" def prob(self, x: np.ndarray) -> np.ndarray:\n",
68+
" x1 = x[:, 0]\n",
69+
" x2 = x[:, 1]\n",
70+
" log_prob = np.sin(np.pi * x1) - 2.0 * (x1**2 + x2**2 - 2.0) ** 2\n",
71+
" return np.exp(log_prob)\n",
72+
"\n",
73+
" def prob_grid(\n",
74+
" self, shape: tuple[int], limits: list[tuple[float, float]]\n",
75+
" ) -> tuple[np.ndarray, list[np.ndarray]]:\n",
76+
" edges = [\n",
77+
" np.linspace(limits[i][0], limits[i][1], shape[i] + 1)\n",
78+
" for i in range(self.ndim)\n",
79+
" ]\n",
80+
" coords = [0.5 * (e[:-1] + e[1:]) for e in edges]\n",
81+
" points = np.stack(\n",
82+
" [c.ravel() for c in np.meshgrid(*coords, indexing=\"ij\")], axis=-1\n",
83+
" )\n",
84+
" values = self.prob(points)\n",
85+
" values = values.reshape(shape)\n",
86+
" return values, coords"
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"id": "fe88acb4-1a21-4565-81f2-e41de80586bc",
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"ndim = 2\n",
97+
"xmax = 3.0\n",
98+
"dist = RingDistribution()"
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": null,
104+
"id": "e083ac94",
105+
"metadata": {},
106+
"outputs": [],
107+
"source": [
108+
"grid_limits = 2 * [(-xmax, xmax)]\n",
109+
"grid_shape = (128, 128)\n",
110+
"grid_values, grid_coords = dist.prob_grid(grid_shape, grid_limits)\n",
111+
"\n",
112+
"fig, ax = plt.subplots(figsize=(2.5, 2.5))\n",
113+
"ax.pcolormesh(grid_coords[0], grid_coords[1], grid_values.T)\n",
114+
"plt.show()"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"id": "4a72fdfd",
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"def plot_samples(x_pred: np.ndarray) -> tuple:\n",
125+
" fig, axs = plt.subplots(ncols=2, figsize=(5.0, 2.75), sharex=True, sharey=True)\n",
126+
" hist, edges = np.histogramdd(x_pred, bins=80, range=grid_limits)\n",
127+
" axs[0].pcolormesh(edges[0], edges[1], hist.T)\n",
128+
" axs[1].pcolormesh(grid_coords[0], grid_coords[1], grid_values.T)\n",
129+
" axs[0].set_title(\"PRED\", fontsize=\"medium\")\n",
130+
" axs[1].set_title(\"TRUE\", fontsize=\"medium\")\n",
131+
" return fig, axs"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": null,
137+
"id": "f5ff2f9d",
138+
"metadata": {},
139+
"outputs": [],
140+
"source": [
141+
"def evaluate_sampler(sampler, size: int = 100_000):\n",
142+
" x_pred = sampler(dist.prob, size=size)\n",
143+
" return plot_samples(x_pred)"
144+
]
145+
},
146+
{
147+
"cell_type": "markdown",
148+
"id": "18c50faa",
149+
"metadata": {},
150+
"source": [
151+
"## Grid sampler"
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": null,
157+
"id": "09414cdc",
158+
"metadata": {},
159+
"outputs": [],
160+
"source": [
161+
"sampler = ment.GridSampler(grid_limits=grid_limits, grid_shape=grid_shape)\n",
162+
"\n",
163+
"evaluate_sampler(sampler);"
164+
]
165+
},
166+
{
167+
"cell_type": "markdown",
168+
"id": "1cfb7494",
169+
"metadata": {},
170+
"source": [
171+
"## Metropolis-Hastings sampler"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": null,
177+
"id": "37e61a8a",
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"chains = 4\n",
182+
"proposal_cov = np.identity(ndim) * 0.25\n",
183+
"start_loc = np.zeros(ndim)\n",
184+
"start_cov = np.identity(ndim) * 0.25\n",
185+
"start_point = np.random.multivariate_normal(start_loc, start_cov, size=chains)\n",
186+
"\n",
187+
"sampler = ment.samp.MetropolisHastingsSampler(\n",
188+
" ndim=ndim,\n",
189+
" proposal_cov=proposal_cov,\n",
190+
" start=start_point,\n",
191+
" chains=chains,\n",
192+
" burnin=0,\n",
193+
")\n",
194+
"evaluate_sampler(sampler);\n"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"id": "00ab115a",
201+
"metadata": {},
202+
"outputs": [],
203+
"source": []
204+
}
205+
],
206+
"metadata": {
207+
"kernelspec": {
208+
"display_name": "ment",
209+
"language": "python",
210+
"name": "python3"
211+
},
212+
"language_info": {
213+
"codemirror_mode": {
214+
"name": "ipython",
215+
"version": 3
216+
},
217+
"file_extension": ".py",
218+
"mimetype": "text/x-python",
219+
"name": "python",
220+
"nbconvert_exporter": "python",
221+
"pygments_lexer": "ipython3",
222+
"version": "3.13.9"
223+
}
224+
},
225+
"nbformat": 4,
226+
"nbformat_minor": 5
227+
}

examples/style.mplstyle

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
axes.linewidth: 1.25
2+
axes.titlesize: "medium"
3+
figure.constrained_layout.use: True
4+
xtick.minor.visible: True
5+
ytick.minor.visible: True
6+
7+
savefig.dpi: 300
8+
savefig.format: "png"

examples/tests/test_samp.ipynb

Lines changed: 0 additions & 744 deletions
This file was deleted.

0 commit comments

Comments
 (0)