Skip to content

Commit f26dd86

Browse files
committed
Add script to plot ensemble size power plots
1 parent 66b1aca commit f26dd86

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

detclim/ens_size_plots.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import xarray as xr
8+
from statsmodels.stats import multitest as smm
9+
10+
import detclim
11+
12+
13+
def correct_pvals(pvals, method: str = "fdr_bh", alpha: float = 0.05):
    """Apply a multiple-testing correction along the second-to-last axis.

    Every 1-D slice taken along axis ``-2`` of ``pvals`` (axis 3 of the 5-D
    arrays this script works with) is treated as one family of hypotheses
    and passed to ``statsmodels.stats.multitest.multipletests``.  The
    corrected p-values are written back to the same positions, so the
    result has exactly the same shape and axis order as the input.

    Parameters
    ----------
    pvals : array_like
        Uncorrected p-values with at least two dimensions.
    method : str
        Correction method name forwarded to ``multipletests``.
    alpha : float
        Family-wise error rate / FDR level forwarded to ``multipletests``.

    Returns
    -------
    numpy.ndarray
        Corrected p-values, same shape as ``pvals``.

    Raises
    ------
    numpy.exceptions.AxisError
        If ``pvals`` has fewer than two dimensions.
    """
    pvals = np.asarray(pvals)

    # np.apply_along_axis refuses arrays with a zero-length iteration
    # dimension; there is nothing to correct in that case anyway.
    if pvals.size == 0:
        return np.empty(pvals.shape, dtype=float)

    def _correct_family(family):
        # Element [1] of the multipletests result holds the corrected p-values.
        return smm.multipletests(
            pvals=family,
            alpha=alpha,
            method=method,
            is_sorted=False,
        )[1]

    # NOTE: the previous implementation collected the corrected vectors in a
    # list and reshaped it directly to ``pvals.shape``.  That list is laid out
    # with the corrected axis *last*, so the reshape silently scrambled values
    # whenever the last two axes had different lengths.  apply_along_axis puts
    # each corrected vector back on the axis it came from.
    return np.apply_along_axis(_correct_family, -2, pvals)
28+
29+
30+
def load_data(param, pcts, esizes):
    """Open the bootstrap output for every ensemble size and perturbation.

    For each entry of ``esizes`` the files for all ``pcts`` are opened as one
    multi-file dataset concatenated along a new ``pct`` dimension; the
    per-size datasets are then concatenated along a new ``esize`` dimension.

    Parameters
    ----------
    param : str
        Name of the perturbed parameter, used to build the file names.
    pcts : sequence
        Perturbation percentages; ``0`` selects the ``ctl`` file.
    esizes : sequence
        Ensemble sizes, used to build the file names.

    Returns
    -------
    xarray.Dataset
        Combined dataset with ``pct`` and ``esize`` set to the given values.
    """
    bootstrap_dir = Path("bootstrap_data")

    # The case label does not depend on the ensemble size, so build it once.
    # A percentage of 0 maps to the "ctl" file name.
    case_names = ["ctl" if pct == 0 else f"{param}-{pct}p0pct" for pct in pcts]

    per_esize = []
    for esize in esizes:
        files = [
            bootstrap_dir
            / f"bootstrap_output.1year_12avg_ts{esize}.ctl_{case}_n1000.nc"
            for case in case_names
        ]
        dataset = xr.open_mfdataset(files, combine="nested", concat_dim="pct")
        # Label the freshly created dimension with the actual percentages.
        dataset["pct"] = pcts
        per_esize.append(dataset)

    combined = xr.concat(per_esize, dim="esize")
    combined["esize"] = esizes
    return combined
53+
54+
55+
def main(param: str = "clubb_c1", ext: str = "png"):
    """Plot rejection fractions against perturbation size and ensemble size.

    Loads the bootstrap p-values for ``param``, derives the fraction of
    bootstrap iterations that are rejected with and without multiple-testing
    correction, and writes three figures to the working directory:

    * ``plt_enssize_power_{param}.{ext}``: one panel per statistical test,
      rejection fraction versus perturbation percentage (log y-axis).
    * ``plt_enssize_power_{param}_bypct.{ext}``: one panel per perturbation
      percentage, rejection fraction versus ensemble size.
    * ``plt_enssize_power_{param}_single.{ext}``: a single panel for the
      percentage listed in ``pct_single``.

    Parameters
    ----------
    param : str
        Perturbed parameter; must be a key of the ``pcts`` table below.
    ext : str
        File extension (and therefore format) of the saved figures.
    """
    pcts = {"clubb_c1": [0, 1, 3, 5, 10], "effgw_oro": [0, 1, 5, 10, 20, 30, 40, 50]}
    pct_single = {"clubb_c1": 5, "effgw_oro": 30}
    esizes = [15, 20, 25, 30, 35, 40, 45, 50, 55, 60]
    alpha = 0.05

    colors = {"ks": "C0", "mw": "C1", "cvm": "C2", "wsr": "C3"}
    stests = ["ks", "cvm", "mw"]  # "wsr" is currently disabled
    # The first percentage (0) is the control case; the rest are perturbed.
    test_pcts = pcts[param][1:]

    data_out = load_data(param, pcts[param], esizes)

    # Copy with the p-values of every plotted test replaced by corrected ones.
    data_out_cr = data_out.copy()
    for stest in stests:
        data_out_cr[stest] = (
            data_out[stest].dims,
            correct_pvals(data_out[stest].values, alpha=alpha),
        )

    # Control threshold: 95th percentile (over iterations) of the number of
    # variables with p < alpha in the control case.
    ctl_thr = (data_out.sel(pct=0) < alpha).sum(dim="vars").quantile(q=0.95, dim="iter")

    n_iter = data_out.sizes["iter"]

    # Uncorrected: an iteration counts as rejected when more variables fall
    # below alpha than the control threshold.  Only time index 2 is plotted.
    failed_tests = (
        ((data_out.sel(pct=test_pcts) < alpha).sum(dim="vars") > ctl_thr)
        .sum(dim="iter")
        .isel(time=2)
    ) / n_iter

    # Corrected: an iteration counts as rejected when at least one variable
    # has a corrected p-value below alpha.
    failed_tests_cr = (
        ((data_out_cr.sel(pct=test_pcts) < alpha).sum(dim="vars") > 0)
        .sum(dim="iter")
        .isel(time=2)
    ) / n_iter

    # --- Figure 1: one panel per statistical test -------------------------
    # One column per test actually plotted, so no empty panel is left over.
    fig, axis = plt.subplots(1, len(stests), figsize=(16, 7), squeeze=False)
    axis = axis.flatten()
    for idx, stest in enumerate(stests):
        failed_tests[stest].plot.line(x="pct", ax=axis[idx])
        failed_tests_cr[stest].plot.line(
            x="pct",
            ax=axis[idx],
            ls="--",
        )
        axis[idx].set_title(stest)
        axis[idx].set_yscale("log")
        axis[idx].grid(visible=True, ls="--", color="grey")

    fig.tight_layout()
    fig.savefig(f"plt_enssize_power_{param}.{ext}")
    plt.close(fig)

    # --- Figure 2: one panel per perturbation percentage ------------------
    # Two rows with just enough columns (4 panels -> 2x2, 7 panels -> 2x4).
    n_panels = len(test_pcts)
    n_cols = max(1, (n_panels + 1) // 2)
    fig, axis = plt.subplots(2, n_cols, figsize=(10, 6), squeeze=False)

    axis = axis.flatten()
    for idx, pct in enumerate(test_pcts):
        for stest in stests:
            failed_tests[stest].sel(pct=pct).plot(
                x="esize",
                ax=axis[idx],
                label=detclim.STESTS[stest],
                color=colors[stest],
            )
            failed_tests_cr[stest].sel(pct=pct).plot(
                x="esize",
                ax=axis[idx],
                label=f"{detclim.STESTS[stest]} BH-FDR",
                ls="--",
                color=colors[stest],
            )
        axis[idx].set_title(f"{pct}% change")
        axis[idx].grid(visible=True, ls="--", color="grey")

    # Hide grid cells that received no data.
    for unused in axis[n_panels:]:
        unused.set_visible(False)

    fig.tight_layout()
    # Attach the legend to the last *populated* panel.  ``plt.legend()`` would
    # use the last axes created, which is an empty one when the grid has more
    # cells than panels, yielding a legend with no entries.
    if n_panels:
        axis[n_panels - 1].legend()
    fig.savefig(f"plt_enssize_power_{param}_bypct.{ext}")
    plt.close(fig)

    # --- Figure 3: single perturbation percentage -------------------------
    # Figure size is given in centimetres and converted to inches.
    fig, axis = plt.subplots(1, 1, figsize=(12.5 / 2.54, 10 / 2.54), dpi=120)

    pct = pct_single[param]
    for stest in stests:
        failed_tests[stest].sel(pct=pct).plot(
            x="esize", ax=axis, label=detclim.STESTS[stest], color=colors[stest]
        )
        # Corrected curves share the colour and are left out of the legend.
        failed_tests_cr[stest].sel(pct=pct).plot(
            x="esize", ax=axis, ls="--", color=colors[stest]
        )
    axis.set_xlabel("Sub-ensemble size")
    axis.set_ylabel("Fraction of rejected tests")
    axis.set_title(f"{pct}% change in {param}")
    axis.grid(visible=True, ls="--", color="grey")

    fig.tight_layout()
    axis.legend()
    fig.savefig(f"plt_enssize_power_{param}_single.{ext}")
    # Close explicitly: main() is called once per parameter and pyplot keeps
    # every figure alive until it is closed.
    plt.close(fig)
150+
151+
152+
if __name__ == "__main__":
    # Script entry point: write PDF figures for each perturbed parameter.
    for _param in ("clubb_c1", "effgw_oro"):
        main(param=_param, ext="pdf")

0 commit comments

Comments
 (0)