Skip to content

Commit 4dcac7f

Browse files
committed
Add test for new bootstrapping method
1 parent 1f9e96c commit 4dcac7f

File tree

1 file changed

+290
-0
lines changed

1 file changed

+290
-0
lines changed
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "770901da-bd6a-4e00-b935-f888f5038fdc",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import xarray as xr\n",
11+
"import matplotlib.pyplot as plt\n",
12+
"from pathlib import Path\n",
13+
"import numpy as np\n",
14+
"import scipy.stats as sts\n",
15+
"import json\n",
16+
"import random\n",
17+
"from functools import partial\n",
18+
"import multiprocessing as mp\n",
19+
"\n",
20+
"from dask.distributed import Client"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "fc2afdae-6b08-4212-b4a4-95ff22ded4e7",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
def ks_all_times(data, ens_ids):
    """Perform a two-sample K-S test at every time step.

    Parameters
    ----------
    data : xr.DataArray
        Array with ``exp``, ``ens`` and ``time`` dimensions; experiment 0
        is compared against experiment 1.
    ens_ids : sequence of two sequences of int
        Ensemble-member indices to select from experiment 0 and
        experiment 1 respectively.

    Returns
    -------
    xr.DataArray
        p-values of the two-sample K-S test, one per time step, with the
        ``time`` coordinate copied from ``data``.

    """
    # Select the requested ensemble members from each experiment.
    data_1 = data.isel(exp=0, ens=ens_ids[0])
    data_2 = data.isel(exp=1, ens=ens_ids[1])

    # Vectorise over the time axis (transpose puts time first); each call
    # compares the two ensembles at a single time. The statistic is
    # discarded; only the p-values are kept.
    ks_test = np.vectorize(sts.mstats.ks_2samp, signature="(n),(n)->(),()")
    _, ks_pval = ks_test(data_1.T, data_2.T)

    return xr.DataArray(
        data=ks_pval, dims=("time",), coords={"time": data.time}
    )
55+
# Module-level vectorised two-sample K-S test: maps scipy's ks_2samp over
# the leading axes of two equally-shaped sample arrays.
# NOTE(review): duplicates the vectorize built inside ks_all_times.
ks_test_vec = np.vectorize(sts.mstats.ks_2samp, signature="(n),(n)->(),()")


def ks_vec(data_1, data_2):
    """Apply the vectorised two-sample K-S test to two sample arrays."""
    result = ks_test_vec(data_1, data_2)
    return result
58+
"\n",
59+
def randomise_new(ens_min, ens_max, ens_size, with_repl=False, ncases=2):
    """Draw random ensemble-member index sets for bootstrapping.

    Parameters
    ----------
    ens_min, ens_max : int
        Inclusive range of valid ensemble-member indices.
    ens_size : int
        Number of members to draw per case.
    with_repl : bool, optional
        Draw with replacement (classic bootstrap) when True; without
        replacement otherwise. Default False.
    ncases : int, optional
        Number of independent index sets to draw. Default 2.

    Returns
    -------
    list of list of int
        ``ncases`` lists of ``ens_size`` member indices each.

    """
    # range() is already ordered; no need to sort it.
    ens_idx = list(range(ens_min, ens_max + 1))
    if not with_repl:
        # Without replacement the pool must be at least as large as the
        # draw (the original `>` wrongly rejected a same-size draw).
        assert len(ens_idx) >= ens_size, (
            "ENSEMBLE SIZE MUST NOT EXCEED ENSEMBLE RANGE"
        )
        selected = [random.sample(ens_idx, ens_size) for _ in range(ncases)]
    else:
        # With replacement any pool size is valid (the canonical bootstrap
        # draws ens_size == pool size), so no size check applies.
        selected = [
            [random.randint(ens_min, ens_max) for _ in range(ens_size)]
            for _ in range(ncases)
        ]
    return selected
73+
"\n",
74+
"\n",
75+
def rolling_mean_data(data, period_len=12, time_var="time"):
    """Return ``data`` smoothed with a rolling mean along ``time_var``.

    Positions made incomplete by the window are dropped rather than
    returned as NaN.
    """
    window = {time_var: period_len}
    smoothed = data.rolling(**window).mean()
    return smoothed.dropna(time_var)
78+
"\n",
79+
"\n",
80+
def ks_bootstrap(idx, data):
    """Run the per-time K-S test on every variable of ``data``.

    Parameters
    ----------
    idx : sequence of two sequences of int
        Ensemble-member indices for the two experiments.
    data : xr.Dataset
        Dataset of [exp, ens, time] variables.

    Returns
    -------
    xr.Dataset
        K-S p-values per variable and time step.

    """
    # Dataset.apply was deprecated in xarray 0.14 in favour of Dataset.map.
    return data.map(ks_all_times, ens_ids=idx)
82+
"\n",
83+
"\n",
84+
def cvm_2samp(data_x, data_y):
    """Two-sample Cramer-von Mises test; return only the p-value.

    Parameters
    ----------
    data_x, data_y : array_like
        One-dimensional samples to compare.

    Returns
    -------
    float
        p-value of the two-sample test.

    """
    # The scipy result object is not vectorize-friendly, so unwrap the
    # p-value here and vectorize this scalar wrapper instead.
    _res = sts.cramervonmises_2samp(data_x, data_y)
    return _res.pvalue


# Vectorised form: maps the scalar test over the leading axes of two
# equally-shaped sample arrays, yielding an array of p-values.
cvm_test_vec = np.vectorize(cvm_2samp, signature="(n),(n)->()")
91+
"\n",
92+
"\n",
93+
def cvm_all_times(data_c, ens_ids):
    """Two-sample Cramer-von Mises p-values at every time step.

    Selects the requested ensemble members from each experiment in
    ``data_c`` and compares them time step by time step.
    """
    exp0 = data_c.isel(exp=0, ens=ens_ids[0])
    exp1 = data_c.isel(exp=1, ens=ens_ids[1])

    # Transpose so time leads; the vectorised test then maps over it,
    # comparing the two ensembles one time step at a time.
    pvals = cvm_test_vec(exp0.T, exp1.T)

    return xr.DataArray(
        data=pvals, dims=("time",), coords={"time": data_c.time}
    )
104+
"\n",
105+
def cvm_bootstrap(idx, data):
    """Run the per-time Cramer-von Mises test on every variable of ``data``.

    ``idx`` holds the ensemble-member indices for the two experiments.
    """
    # Dataset.apply was deprecated in xarray 0.14 in favour of Dataset.map.
    return data.map(cvm_all_times, ens_ids=idx)
107+
"\n",
108+
def anderson_pval(data_1, data_2):
    """k-sample Anderson-Darling p-value for two samples.

    The p-value comes from a permutation method with 1000 resamples.
    NOTE(review): the permutation method is unseeded, so repeated runs can
    give slightly different p-values — consider a seeded rng for
    reproducibility.
    """
    try:
        _res = sts.anderson_ksamp([data_1, data_2], method=sts.PermutationMethod(n_resamples=1000))
    except ValueError:
        # Treat a ValueError from anderson_ksamp (presumably degenerate
        # input — verify) as "no evidence of difference".
        return 1.
    return _res.pvalue

# Vectorised form: maps the scalar test over the leading axes of two
# equally-shaped sample arrays.
anderson_test_vec = np.vectorize(anderson_pval, signature="(n),(n)->()")
116+
"\n",
117+
def anderson_all_times(data, ens_ids):
    """Anderson-Darling p-values at every time step for two experiments."""
    sample_a = data.isel(exp=0, ens=ens_ids[0])
    sample_b = data.isel(exp=1, ens=ens_ids[1])

    # Time becomes the leading axis so the vectorised test iterates over it.
    pvals = anderson_test_vec(sample_a.T, sample_b.T)

    return xr.DataArray(
        data=pvals, dims=("time",), coords={"time": data.time}
    )
125+
"\n",
126+
" \n",
127+
def anderson_bootstrap(idx, data):
    """Run the per-time Anderson-Darling test on every variable of ``data``.

    ``idx`` holds the ensemble-member indices for the two experiments.
    """
    # Dataset.apply was deprecated in xarray 0.14 in favour of Dataset.map.
    return data.map(anderson_all_times, ens_ids=idx)
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": null,
134+
"id": "c855932b-f28b-426b-964f-d193d26453b6",
135+
"metadata": {},
136+
"outputs": [],
137+
"source": [
138+
"%%time\n",
139+
"scratch = Path(\"/home/mikek/Code/2025-09-16.F2010.ne30pg2_r05_oECv3_aavgs\")\n",
140+
"in_dirs = sorted(scratch.glob(\"*\"))\n",
141+
"_ds_ctl = xr.open_mfdataset(\n",
142+
" sorted(in_dirs[1].glob(\"*.nc\")), combine=\"nested\", concat_dim=\"ens\"\n",
143+
")\n",
144+
"\n",
145+
"_ds_exp = xr.open_mfdataset(\n",
146+
" sorted(in_dirs[0].glob(\"*.nc\")), combine=\"nested\", concat_dim=\"ens\"\n",
147+
")\n",
148+
"\n",
149+
"_ds_all = xr.concat([_ds_ctl, _ds_exp], dim=\"exp\")\n",
150+
"dvars = json.loads(\n",
151+
" open(\"../new_vars.json\", \"r\", encoding=\"utf-8\").read()\n",
152+
")[\"default\"]\n",
153+
"\n",
154+
"_ds_all_mean = _ds_all[dvars].map(rolling_mean_data, period_len=12).load()\n",
155+
"_emin = _ds_all_mean.ens.values.min()\n",
156+
"_emax = _ds_all_mean.ens.values.max()"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": null,
162+
"id": "60cf23e0-2b84-473c-82a0-a05fbffb2a9f",
163+
"metadata": {},
164+
"outputs": [],
165+
"source": [
166+
# Number of bootstrap iterations and ensemble members drawn per iteration.
ninst = 100
ens_size = 20
# Each entry holds two index lists: members drawn for experiment 0 and
# experiment 1 (ncases=2).
ens_sel = [randomise_new(_emin, _emax, ens_size=ens_size, ncases=2) for _ in range(ninst)]
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": null,
174+
"id": "358e3df8-a194-4cb5-a1dd-866e97ca87e0",
175+
"metadata": {},
176+
"outputs": [],
177+
"source": [
178+
"%%time\n",
179+
"ks_bootsrap_part = partial(ks_bootstrap, data=_ds_all_mean[dvars])\n",
180+
"with mp.Pool(16) as pool:\n",
181+
" pvals_out_ks = xr.concat(pool.map(ks_bootsrap_part, ens_sel), dim=\"iter\")"
182+
]
183+
},
184+
{
185+
"cell_type": "code",
186+
"execution_count": null,
187+
"id": "6d037a0a-c29f-46ce-a46a-4e7027c6ee9b",
188+
"metadata": {},
189+
"outputs": [],
190+
"source": [
191+
"%%time\n",
192+
"anderson_bootstrap_part = partial(anderson_bootstrap, data=_ds_all_mean[dvars])\n",
193+
"with mp.Pool(16) as pool:\n",
194+
" pvals_out_anderson = xr.concat(pool.map(anderson_bootstrap_part, ens_sel), dim=\"iter\")\n"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"id": "6e154409-cbc1-4a2d-ae99-0e7440606280",
201+
"metadata": {},
202+
"outputs": [],
203+
"source": [
204+
"%%time\n",
205+
"cvm_bootstrap_part = partial(cvm_bootstrap, data=_ds_all_mean[dvars])\n",
206+
"with mp.Pool(16) as pool:\n",
207+
" pvals_out_cvm = xr.concat(pool.map(cvm_bootstrap_part, ens_sel), dim=\"iter\")"
208+
]
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": null,
213+
"id": "4599831c-1bee-41e6-bb95-0279ff133033",
214+
"metadata": {},
215+
"outputs": [],
216+
"source": [
217+
# Collect, for each test, the p-values at one fixed time (time=2) as a
# [variable, iteration] array.
pvals_all = {
    "ks": np.array([pvals_out_ks.isel(time=2)[_var].values for _var in pvals_out_ks.data_vars]),
    "cvm": np.array([pvals_out_cvm.isel(time=2)[_var].values for _var in pvals_out_cvm.data_vars]),
    "anderson": np.array([pvals_out_anderson.isel(time=2)[_var].values for _var in pvals_out_anderson.data_vars]),
}

fig, axes = plt.subplots(1, 3, figsize=(12, 5), sharey=True)
for ax, (test_name, pvals) in zip(axes, pvals_all.items()):
    # Order the p-values along the variable axis (in-place sort) so each
    # iteration plots as a monotone curve.
    pvals.sort(axis=0)
    ax.semilogy(pvals, color="grey", lw=0.5)
    ax.semilogy(pvals.mean(axis=1), color="k")
    ax.axhline(0.05, ls="--", color="green")
    ax.set_title(test_name)
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": null,
236+
"id": "065536c6-edc8-4b09-b624-1257792a38b5",
237+
"metadata": {},
238+
"outputs": [],
239+
"source": [
240+
# Per test: count, for every bootstrap iteration (column), how many
# variables fall below the 0.05 significance level.
nreject = {
    mode: [(column < 0.05).sum() for column in pvals.T]
    for mode, pvals in pvals_all.items()
}
244+
]
245+
},
246+
{
247+
"cell_type": "code",
248+
"execution_count": null,
249+
"id": "c8ab0556-e334-45f9-81c4-421622d0ba91",
250+
"metadata": {},
251+
"outputs": [],
252+
"source": [
253+
# Histogram the per-iteration rejection counts for each test, side by side.
plt.figure(figsize=(12, 5))
for panel, (mode, counts) in enumerate(nreject.items(), start=1):
    plt.subplot(1, 3, panel)
    plt.hist(counts, bins=15, edgecolor="k")
    plt.title(mode)
258+
]
259+
},
260+
{
261+
"cell_type": "code",
262+
"execution_count": null,
263+
"id": "0c6d4bfc-5699-429e-9a84-313399b6d19b",
264+
"metadata": {},
265+
"outputs": [],
266+
"source": []
267+
}
268+
],
269+
"metadata": {
270+
"kernelspec": {
271+
"display_name": "Python 3 (ipykernel)",
272+
"language": "python",
273+
"name": "python3"
274+
},
275+
"language_info": {
276+
"codemirror_mode": {
277+
"name": "ipython",
278+
"version": 3
279+
},
280+
"file_extension": ".py",
281+
"mimetype": "text/x-python",
282+
"name": "python",
283+
"nbconvert_exporter": "python",
284+
"pygments_lexer": "ipython3",
285+
"version": "3.13.7"
286+
}
287+
},
288+
"nbformat": 4,
289+
"nbformat_minor": 5
290+
}

0 commit comments

Comments
 (0)