Skip to content

Commit 40632b4

Browse files
committed
add roi stats functions
1 parent 0dc800b commit 40632b4

File tree

1 file changed

+203
-0
lines changed

1 file changed

+203
-0
lines changed
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from matplotlib_venn import venn3
4+
5+
from phoneme_segmentation.config import *
6+
7+
def EV_calc():
8+
valid_all = {subj: np.load(f"{BOLD_VALID_DIR}/{subj}_valid.npz")["wheretheressmoke"] for subj in SUBJECTS_ALL}
9+
for subj in SUBJECTS_ALL:
10+
print(subj)
11+
print(valid_all[subj].shape)
12+
13+
14+
EV_all = {k: explainable_variance(v) for k, v in valid_all.items()}
15+
print([v.shape for _, v in EV_all.items()])
16+
17+
cci.dict2cloud(f"{S3_EV_DIR}", EV)
18+
return EV_all
19+
20+
def sig_vox_threshold(perf_all:dict,
21+
perm_feature:str,
22+
models:list):
23+
24+
'''
25+
perf_all: output of reorg_perf_raw
26+
perm_feature: toSemAll(for all other models, particularly thirdOrder vs semantic comparison) or thirdOrder (for VP analysis)
27+
models: list of models: for toSemAll threshold (MODELS_ALL + ["powspec", "numPhns"]): baseline, powspec, numPhns, firstOrder, secondOrder, thirdOrder, toSemAll, semantic
28+
for thirdOrder threshold (MODELS_VP): single, diphone, triphone, singleDiphn, singleTri, DiphoneTri, SingDiTri
29+
'''
30+
31+
perf_sig_all = {}
32+
perf_sig_EVcorrected_all = {}
33+
EV_all = {}
34+
perm_res_all = {}
35+
36+
for subj in SUBJECTS_ALL:
37+
38+
EV_all[subj] = cci.download_raw_array(f"{S3_EV_DIR}/{subj}")
39+
40+
## load in perm res
41+
perm_res_all[f"{subj}_{perm_feature}"] = np.load(f"{LOCAL_PERM_RES}/{subj}_{perm_feature}_perm.npz")["perm"]
42+
perf_perm_tmp = cci.download_raw_array(f"{S3_MODEL_RAW_SUMMARY_ROOT_DIR}/{subj}_{perm_feature}")
43+
44+
45+
for m in models:
46+
print(f"{subj}_{m}")
47+
perf_sig_all[f"{subj}_{m}"], perf_sig_EVcorrected_all[f"{subj}_{m}"] = calc_stats(perf_all[f"{subj}_{m}"], perf_perm_tmp, perm_res_all[f"{subj}_{perm_feature}"], EV_all[subj])
48+
print(perf_sig_all[f"{subj}_{m}"].shape)
49+
print(perf_sig_EVcorrected_all[f"{subj}_{m}"].shape)
50+
51+
cci.upload_raw_array(f"{S3_MODEL_SIG_SUMMARY_ROOT_DIR}/{subj}_{m}", perf_sig_all[f"{subj}_{m}"])
52+
cci.upload_raw_array(f"{S3_MODEL_SIG_EVcorrected_SUMMARY_ROOT_DIR}/{subj}_{m}", perf_sig_EVcorrected_all[f"{subj}_{m}"])
53+
54+
return perf_sig_all, perf_sig_EVcorrected_all
55+
56+
57+
def plot_venn_diagram(perf:dict,
58+
plot_idx=[0, 1, 3, 2, 4, 5, 6]):
59+
'''
60+
perf: dict: key: subj; value: [n_features, n_vox]
61+
'''
62+
## quickly viz venn diagram
63+
true_area_all = {}
64+
true_area_arr = np.zeros((len(SUBJECTS_ALL), len(MODELS_VP)))
65+
for subj_i, subj in enumerate(SUBJECTS_ALL):
66+
true_area = np.nan_to_num(perf[subj]).mean(1)
67+
true_area /= true_area.sum()
68+
true_area_all[subj] = true_area
69+
true_area_arr[subj_i, :] = true_area
70+
print(MODELS_VP)
71+
print(true_area)
72+
true_area_plot = [true_area[i] for i in plot_idx]
73+
print(true_area_plot)
74+
75+
# Make the diagram
76+
plt.title(subj)
77+
venn3(subsets = [round(i, 2) for i in true_area_plot])
78+
plt.show()
79+
80+
def extract_roi_hemi_perf(pycortex_info
81+
subj:str,
82+
roi:str,
83+
hemi:str,
84+
perf):
85+
if hemi == "left":
86+
hemi_code = 0
87+
else:
88+
hemi_code = 1
89+
90+
roi_mask = cortex.utils.get_roi_masks(**pycortex_info[subj], roi_list=[roi], gm_sampler='cortical-conservative', return_dict=True)[roi]
91+
wholeBrain_mask = cortex.get_cortical_mask(**pycortex_info[subj], type = "thick")
92+
hemi_mask = cortex.utils.get_hemi_masks(**pycortex_info[subj], type='nearest')[hemi_code]
93+
94+
roi_hem_mask = np.zeros(np.sum(wholeBrain_mask)).astype(bool)
95+
h = 0
96+
for n_i, n in enumerate(wholeBrain_mask.flatten()):
97+
if n == True:
98+
if roi_mask.flatten()[n_i]!=0 and hemi_mask.flatten()[n_i] == True:
99+
roi_hem_mask[h] = True
100+
else:
101+
roi_hem_mask[h] = False
102+
h = h + 1
103+
104+
perf_roi_hemi = perf[roi_hem_mask]
105+
106+
roi_hem_mask_plot = np.full(roi_mask.shape, False)
107+
for x in range(roi_mask.shape[0]):
108+
for y in range(roi_mask.shape[1]):
109+
for z in range(roi_mask.shape[2]):
110+
if wholeBrain_mask[x,y,z]== True and hemi_mask[x,y,z]== True and roi_mask[x,y,z]!=0:
111+
roi_hem_mask_plot[x,y,z] = True
112+
113+
return perf_roi_hemi, roi_hem_mask_plot, roi_hem_mask
114+
115+
def extract_LTCunique(subj, perf, h):
116+
117+
_,STS_maskPlot, STS_maskPerf = extract_roi_hemi_perf(subj, "STS", h, perf)
118+
_,STG_maskPlot, STG_maskPerf = extract_roi_hemi_perf(subj, "STG", h, perf)
119+
_,AC_maskPlot, AC_maskPerf = extract_roi_hemi_perf(subj, "AC", h, perf)
120+
_,LTC_maskPlot, LTC_maskPerf = extract_roi_hemi_perf(subj, "LTC", h, perf)
121+
122+
print ("obtain new mask")
123+
mask_union_forPlot_tmp = np.array([a or b or c for a, b, c in zip(STS_maskPlot.flatten(), STG_maskPlot.flatten(), AC_maskPlot.flatten())]).reshape((STS_maskPlot.shape))
124+
mask_union_forPerf_tmp = np.array([a or b or c for a, b, c in zip(STS_maskPerf, STG_maskPerf, AC_maskPerf)])
125+
126+
mask_forPlot = []
127+
for a_i, a in enumerate(mask_union_forPlot_tmp.flatten()):
128+
if a == False and LTC_maskPlot.flatten()[a_i] == True:
129+
mask_forPlot.append(True)
130+
else:
131+
mask_forPlot.append(False)
132+
mask_exclusion_new = np.array(mask_forPlot)
133+
134+
mask_forPerf = []
135+
for a_i, a in enumerate(mask_union_forPerf_tmp):
136+
if a == False and LTC_maskPerf[a_i] == True:
137+
mask_forPerf.append(True)
138+
else:
139+
mask_forPerf.append(False)
140+
mask_forPerf = np.array(mask_forPerf)
141+
142+
perf_exclusion_new = perf[mask_forPerf]
143+
144+
return perf_exclusion_new, mask_exclusion_new
145+
146+
def extract_FCunique(subj, perf, h):
147+
_,Broca_maskPlot, Broca_maskPerf = extract_roi_hemi_perf(subj, "Broca", h, perf)
148+
_,sPMv_maskPlot, sPMv_maskPerf = extract_roi_hemi_perf(subj, "sPMv", h, perf)
149+
_,FC_maskPlot, FC_maskPerf = extract_roi_hemi_perf(subj, "FC", h, perf)
150+
151+
print ("obtain new mask")
152+
mask_union_forPlot_tmp = np.array([a or b for a, b in zip(Broca_maskPlot.flatten(), sPMv_maskPlot.flatten())])
153+
mask_union_forPerf_tmp = np.array([a or b for a, b in zip(Broca_maskPerf, sPMv_maskPerf)])
154+
155+
mask_forPlot = []
156+
for a_i, a in enumerate(mask_union_forPlot_tmp.flatten()):
157+
if a == False and FC_maskPlot.flatten()[a_i] == True:
158+
mask_forPlot.append(True)
159+
else:
160+
mask_forPlot.append(False)
161+
mask_exclusion_new = np.array(mask_forPlot)
162+
mask_forPerf = []
163+
for a_i, a in enumerate(mask_union_forPerf_tmp):
164+
if a == False and FC_maskPerf[a_i] == True:
165+
mask_forPerf.append(True)
166+
else:
167+
mask_forPerf.append(False)
168+
mask_forPerf = np.array(mask_forPerf)
169+
170+
perf_exclusion_new = perf[mask_forPerf]
171+
172+
return perf_exclusion_new, mask_exclusion_new
173+
174+
def extract_ACunique(subj, perf, h):
175+
176+
_,STS_maskPlot, STS_maskPerf = extract_roi_hemi_perf(subj, "STS", h, perf)
177+
_,STG_maskPlot, STG_maskPerf = extract_roi_hemi_perf(subj, "STG", h, perf)
178+
_,AC_maskPlot, AC_maskPerf = extract_roi_hemi_perf(subj, "AC", h, perf)
179+
180+
print ("obtain new mask")
181+
mask_union_forPlot_tmp = np.array([a or b for a, b in zip(STS_maskPlot.flatten(), STG_maskPlot.flatten())]).reshape((STS_maskPlot.shape))
182+
mask_union_forPerf_tmp = np.array([a or b for a, b in zip(STS_maskPerf, STG_maskPerf)])
183+
184+
mask_forPlot = []
185+
for a_i, a in enumerate(mask_union_forPlot_tmp.flatten()):
186+
if a == False and AC_maskPlot.flatten()[a_i] == True:
187+
mask_forPlot.append(True)
188+
else:
189+
mask_forPlot.append(False)
190+
mask_exclusion_new = np.array(mask_forPlot)
191+
192+
mask_forPerf = []
193+
for a_i, a in enumerate(mask_union_forPerf_tmp):
194+
if a == False and AC_maskPerf[a_i] == True:
195+
mask_forPerf.append(True)
196+
else:
197+
mask_forPerf.append(False)
198+
mask_forPerf = np.array(mask_forPerf)
199+
200+
perf_exclusion_new = perf[mask_forPerf]
201+
202+
return perf_exclusion_new, mask_exclusion_new
203+

0 commit comments

Comments
 (0)