Skip to content

Commit 6a9222c

Browse files
committed
add all trials accuracy to summary table; bugfix
1 parent d0bcaa4 commit 6a9222c

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/dynamic_routing_analysis/decoding_utils.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,9 +1123,8 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
11231123
decoder_results[session_id]['results'][aa]['shift'][nunits]={}
11241124
decoder_results[session_id]['results'][aa]['no_shift'][nunits]={}
11251125
for rr in range(n_repeats):
1126-
if n_repeats>1:
1127-
decoder_results[session_id]['results'][aa]['shift'][nunits][rr]={}
1128-
decoder_results[session_id]['results'][aa]['no_shift'][nunits][rr]={}
1126+
decoder_results[session_id]['results'][aa]['shift'][nunits][rr]={}
1127+
decoder_results[session_id]['results'][aa]['no_shift'][nunits][rr]={}
11291128

11301129
if input_data_type=='spikes':
11311130
if nunits=='all':
@@ -1248,6 +1247,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
12481247
n_repeats=25
12491248

12501249
all_bal_acc={}
1250+
all_trials_bal_acc={}
12511251

12521252
linear_shift_dict={
12531253
'session_id':[],
@@ -1276,10 +1276,11 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
12761276
linear_shift_dict['null_accuracy_median_'+str(nu)]=[]
12771277
linear_shift_dict['null_accuracy_std_'+str(nu)]=[]
12781278
linear_shift_dict['p_value_'+str(nu)]=[]
1279+
linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)]=[]
12791280

12801281
#loop through sessions
12811282
for file in files:
1282-
try:
1283+
# try:
12831284
decoder_results=pickle.loads(upath.UPath(file).read_bytes())
12841285
session_id=str(list(decoder_results.keys())[0])
12851286
session_info=npc_lims.get_session_info(session_id)
@@ -1298,6 +1299,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
12981299
continue
12991300

13001301
all_bal_acc[session_id]={}
1302+
all_trials_bal_acc[session_id]={}
13011303

13021304
nunits=decoder_results[session_id]['n_units']
13031305
if nunits!=nunits_global:
@@ -1330,28 +1332,33 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
13301332
for aa in areas:
13311333
if aa in decoder_results[session_id]['results']:
13321334
all_bal_acc[session_id][aa]={}
1335+
all_trials_bal_acc[session_id][aa]={}
13331336
### ADD LOOP FOR NUNITS ###
13341337
for nu in nunits:
13351338
if nu not in decoder_results[session_id]['results'][aa]['shift'].keys():
13361339
continue
13371340
all_bal_acc[session_id][aa][nu]=[]
1341+
all_trials_bal_acc[session_id][aa][nu]=[]
13381342
for rr in range(n_repeats):
13391343
if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys():
13401344
temp_bal_acc=[]
1341-
temp_bal_acc_all_trials=[]
13421345
# else:
13431346
# print('n repeats invalid: '+str(rr))
13441347
# continue
13451348
for sh in half_shift_inds:
13461349
if sh in list(decoder_results[session_id]['results'][aa]['shift'][nu][rr].keys()):
13471350
temp_bal_acc.append(decoder_results[session_id]['results'][aa]['shift'][nu][rr][sh]['balanced_accuracy_test'])
1348-
if sh==0:
1349-
temp_bal_acc_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test'])
1351+
13501352
if len(temp_bal_acc)>0:
13511353
all_bal_acc[session_id][aa][nu].append(np.array(temp_bal_acc))
1354+
1355+
all_trials_bal_acc[session_id][aa][nu].append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test'])
1356+
13521357
all_bal_acc[session_id][aa][nu]=np.vstack(all_bal_acc[session_id][aa][nu])
13531358
all_bal_acc[session_id][aa][nu]=np.nanmean(all_bal_acc[session_id][aa][nu],axis=0)
13541359

1360+
all_trials_bal_acc[session_id][aa][nu]=np.nanmean(all_trials_bal_acc[session_id][aa][nu])
1361+
13551362
if type(aa)==str:
13561363
if '_probe' in aa:
13571364
area_name=aa.split('_probe')[0]
@@ -1391,6 +1398,12 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
13911398
linear_shift_dict['null_accuracy_std_'+str(nu)].append(np.nan)
13921399
linear_shift_dict['p_value_'+str(nu)].append(np.nan)
13931400

1401+
if nu in all_trials_bal_acc[session_id][aa].keys():
1402+
true_accuracy=all_trials_bal_acc[session_id][aa][nu]
1403+
linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)].append(true_accuracy)
1404+
else:
1405+
linear_shift_dict['true_accuracy_all_trials_no_shift_'+str(nu)].append(np.nan)
1406+
13941407
#make big dict/dataframe for this:
13951408
#save true decoding, mean/median null decoding, and p value for each area/probe
13961409
linear_shift_dict['session_id'].append(session_id)
@@ -1414,10 +1427,10 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
14141427
linear_shift_dict['probe'].append(np.nan)
14151428

14161429
print(aa+' done')
1417-
except Exception as e:
1418-
print(e)
1419-
print('error with session: '+session_id)
1420-
continue
1430+
# except Exception as e:
1431+
# print(e)
1432+
# print('error with session: '+session_id)
1433+
# continue
14211434

14221435

14231436
linear_shift_df=pd.DataFrame(linear_shift_dict)

0 commit comments

Comments
 (0)