@@ -1123,9 +1123,8 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
11231123 decoder_results [session_id ]['results' ][aa ]['shift' ][nunits ]= {}
11241124 decoder_results [session_id ]['results' ][aa ]['no_shift' ][nunits ]= {}
11251125 for rr in range (n_repeats ):
1126- if n_repeats > 1 :
1127- decoder_results [session_id ]['results' ][aa ]['shift' ][nunits ][rr ]= {}
1128- decoder_results [session_id ]['results' ][aa ]['no_shift' ][nunits ][rr ]= {}
1126+ decoder_results [session_id ]['results' ][aa ]['shift' ][nunits ][rr ]= {}
1127+ decoder_results [session_id ]['results' ][aa ]['no_shift' ][nunits ][rr ]= {}
11291128
11301129 if input_data_type == 'spikes' :
11311130 if nunits == 'all' :
@@ -1248,6 +1247,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
12481247 n_repeats = 25
12491248
12501249 all_bal_acc = {}
1250+ all_trials_bal_acc = {}
12511251
12521252 linear_shift_dict = {
12531253 'session_id' :[],
@@ -1276,10 +1276,11 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
12761276 linear_shift_dict ['null_accuracy_median_' + str (nu )]= []
12771277 linear_shift_dict ['null_accuracy_std_' + str (nu )]= []
12781278 linear_shift_dict ['p_value_' + str (nu )]= []
1279+ linear_shift_dict ['true_accuracy_all_trials_no_shift_' + str (nu )]= []
12791280
12801281 #loop through sessions
12811282 for file in files :
1282- try :
1283+ # try:
12831284 decoder_results = pickle .loads (upath .UPath (file ).read_bytes ())
12841285 session_id = str (list (decoder_results .keys ())[0 ])
12851286 session_info = npc_lims .get_session_info (session_id )
@@ -1298,6 +1299,7 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
12981299 continue
12991300
13001301 all_bal_acc [session_id ]= {}
1302+ all_trials_bal_acc [session_id ]= {}
13011303
13021304 nunits = decoder_results [session_id ]['n_units' ]
13031305 if nunits != nunits_global :
@@ -1330,28 +1332,33 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
13301332 for aa in areas :
13311333 if aa in decoder_results [session_id ]['results' ]:
13321334 all_bal_acc [session_id ][aa ]= {}
1335+ all_trials_bal_acc [session_id ][aa ]= {}
13331336 ### ADD LOOP FOR NUNITS ###
13341337 for nu in nunits :
13351338 if nu not in decoder_results [session_id ]['results' ][aa ]['shift' ].keys ():
13361339 continue
13371340 all_bal_acc [session_id ][aa ][nu ]= []
1341+ all_trials_bal_acc [session_id ][aa ][nu ]= []
13381342 for rr in range (n_repeats ):
13391343 if rr in decoder_results [session_id ]['results' ][aa ]['shift' ][nu ].keys ():
13401344 temp_bal_acc = []
1341- temp_bal_acc_all_trials = []
13421345 # else:
13431346 # print('n repeats invalid: '+str(rr))
13441347 # continue
13451348 for sh in half_shift_inds :
13461349 if sh in list (decoder_results [session_id ]['results' ][aa ]['shift' ][nu ][rr ].keys ()):
13471350 temp_bal_acc .append (decoder_results [session_id ]['results' ][aa ]['shift' ][nu ][rr ][sh ]['balanced_accuracy_test' ])
1348- if sh == 0 :
1349- temp_bal_acc_all_trials .append (decoder_results [session_id ]['results' ][aa ]['no_shift' ][nu ][rr ]['balanced_accuracy_test' ])
1351+
13501352 if len (temp_bal_acc )> 0 :
13511353 all_bal_acc [session_id ][aa ][nu ].append (np .array (temp_bal_acc ))
1354+
1355+ all_trials_bal_acc [session_id ][aa ][nu ].append (decoder_results [session_id ]['results' ][aa ]['no_shift' ][nu ][rr ]['balanced_accuracy_test' ])
1356+
13521357 all_bal_acc [session_id ][aa ][nu ]= np .vstack (all_bal_acc [session_id ][aa ][nu ])
13531358 all_bal_acc [session_id ][aa ][nu ]= np .nanmean (all_bal_acc [session_id ][aa ][nu ],axis = 0 )
13541359
1360+ all_trials_bal_acc [session_id ][aa ][nu ]= np .nanmean (all_trials_bal_acc [session_id ][aa ][nu ])
1361+
13551362 if type (aa )== str :
13561363 if '_probe' in aa :
13571364 area_name = aa .split ('_probe' )[0 ]
@@ -1391,6 +1398,12 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
13911398 linear_shift_dict ['null_accuracy_std_' + str (nu )].append (np .nan )
13921399 linear_shift_dict ['p_value_' + str (nu )].append (np .nan )
13931400
1401+ if nu in all_trials_bal_acc [session_id ][aa ].keys ():
1402+ true_accuracy = all_trials_bal_acc [session_id ][aa ][nu ]
1403+ linear_shift_dict ['true_accuracy_all_trials_no_shift_' + str (nu )].append (true_accuracy )
1404+ else :
1405+ linear_shift_dict ['true_accuracy_all_trials_no_shift_' + str (nu )].append (np .nan )
1406+
13941407 #make big dict/dataframe for this:
13951408 #save true decoding, mean/median null decoding, and p value for each area/probe
13961409 linear_shift_dict ['session_id' ].append (session_id )
@@ -1414,10 +1427,10 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
14141427 linear_shift_dict ['probe' ].append (np .nan )
14151428
14161429 print (aa + ' done' )
1417- except Exception as e :
1418- print (e )
1419- print ('error with session: ' + session_id )
1420- continue
1430+ # except Exception as e:
1431+ # print(e)
1432+ # print('error with session: '+session_id)
1433+ # continue
14211434
14221435
14231436 linear_shift_df = pd .DataFrame (linear_shift_dict )
0 commit comments