@@ -46,7 +46,9 @@ def get_syllable_names(project_dir, model_name, syllable_ixs):
4646
4747 for ix in syllable_ixs :
4848 if len (syll_info_df [syll_info_df .syllable == ix ].label .values [0 ]) > 0 :
49- labels [ix ] = f"{ ix } ({ syll_info_df [syll_info_df .syllable == ix ].label .values [0 ]} )"
49+ labels [ix ] = (
50+ f"{ ix } ({ syll_info_df [syll_info_df .syllable == ix ].label .values [0 ]} )"
51+ )
5052 names = [labels [ix ] for ix in syllable_ixs ]
5153 return names
5254
@@ -214,14 +216,17 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
214216 np .concatenate (
215217 (
216218 [0 ],
217- np .sqrt (np .square (np .diff (v ["centroid" ], axis = 0 )).sum (axis = 1 )) * fps ,
219+ np .sqrt (np .square (np .diff (v ["centroid" ], axis = 0 )).sum (axis = 1 ))
220+ * fps ,
218221 )
219222 )
220223 )
221224
222225 if index_data is not None :
223226 # find the group for each recording from index data
224- s_group .append ([index_data [index_data ["name" ] == k ]["group" ].values [0 ]] * n_frame )
227+ s_group .append (
228+ [index_data [index_data ["name" ] == k ]["group" ].values [0 ]] * n_frame
229+ )
225230 else :
226231 # no index data
227232 s_group .append (["default" ] * n_frame )
@@ -236,8 +241,12 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
236241 heading .append (recording_heading )
237242
238243 # compute angular velocity (radian per second)
239- gaussian_smoothed_heading = filter_angle (recording_heading , size = 3 , method = "gaussian" )
240- angular_velocity .append (np .concatenate (([0 ], np .diff (gaussian_smoothed_heading ) * fps )))
244+ gaussian_smoothed_heading = filter_angle (
245+ recording_heading , size = 3 , method = "gaussian"
246+ )
247+ angular_velocity .append (
248+ np .concatenate (([0 ], np .diff (gaussian_smoothed_heading ) * fps ))
249+ )
241250
242251 # add syllable data
243252 syllables .append (v ["syllable" ])
@@ -367,7 +376,9 @@ def compute_stats_df(
367376def generate_syll_info (project_dir , model_name , syll_info_path ):
368377 # parse model results
369378 model_results = load_results (project_dir , model_name )
370- unique_sylls = np .unique (np .concatenate ([file ["syllable" ] for file in model_results .values ()]))
379+ unique_sylls = np .unique (
380+ np .concatenate ([file ["syllable" ] for file in model_results .values ()])
381+ )
371382 # construct the syllable dictionary
372383 # in the non interactive version there won't be any group info
373384 syll_info_df = pd .DataFrame (
@@ -428,8 +439,12 @@ def label_syllables(project_dir, model_name, moseq_df):
428439 # load syll_info
429440 syll_info_df = pd .read_csv (syll_info_path , index_col = False ).fillna ("" )
430441 # split into with movie and without movie
431- syll_info_df_with_movie = syll_info_df [syll_info_df .movie_path .str .contains (".mp4" )].copy ()
432- syll_info_df_without_movie = syll_info_df [~ syll_info_df .movie_path .str .contains (".mp4" )].copy ()
442+ syll_info_df_with_movie = syll_info_df [
443+ syll_info_df .movie_path .str .contains (".mp4" )
444+ ].copy ()
445+ syll_info_df_without_movie = syll_info_df [
446+ ~ syll_info_df .movie_path .str .contains (".mp4" )
447+ ].copy ()
433448
434449 # create select widget only include the ones with a movie
435450 select = pn .widgets .Select (
@@ -520,7 +535,9 @@ def b(event, save=True):
520535 button .on_click (b )
521536
522537 # bind everything together
523- return pn .Row (pn .Column (select , ivideo ), pn .Column (summary_table , pn .Column (button )))
538+ return pn .Row (
539+ pn .Column (select , ivideo ), pn .Column (summary_table , pn .Column (button ))
540+ )
524541
525542
526543def get_tie_correction (x , N_m ):
@@ -605,7 +622,10 @@ def run_manual_KW_test(
605622 # get square of sums for each group
606623 ssbn = np .zeros ((n_perm , N_s ))
607624 for i in range (num_groups ):
608- ssbn += perm_ranks [:, cum_group_idx [i ] : cum_group_idx [i + 1 ]].sum (1 ) ** 2 / n_per_group [i ]
625+ ssbn += (
626+ perm_ranks [:, cum_group_idx [i ] : cum_group_idx [i + 1 ]].sum (1 ) ** 2
627+ / n_per_group [i ]
628+ )
609629
610630 # h-statistic
611631 h_all = 12.0 / (N_m * (N_m + 1 )) * ssbn - 3 * (N_m + 1 )
@@ -616,7 +636,9 @@ def run_manual_KW_test(
616636 p_i = np .random .randint (n_perm )
617637 s_i = np .random .randint (N_s )
618638 kr = stats .kruskal (
619- * np .array_split (merged_usages_all [perm [p_i , :], s_i ], np .cumsum (n_per_group [:- 1 ]))
639+ * np .array_split (
640+ merged_usages_all [perm [p_i , :], s_i ], np .cumsum (n_per_group [:- 1 ])
641+ )
620642 )
621643 assert (kr .statistic == h_all [p_i , s_i ]) & (
622644 kr .pvalue == p_vals [p_i , s_i ]
@@ -671,7 +693,8 @@ def dunns_z_test_permute_within_group_pairs(
671693
672694 ranks_perm = real_ranks [(is_i | is_j )][rnd .rand (n_perm , n_mice ).argsort (- 1 )]
673695 diff = np .abs (
674- ranks_perm [:, : is_i .sum (), :].mean (1 ) - ranks_perm [:, is_i .sum () :, :].mean (1 )
696+ ranks_perm [:, : is_i .sum (), :].mean (1 )
697+ - ranks_perm [:, is_i .sum () :, :].mean (1 )
675698 )
676699 B = 1.0 / vc .loc [i_n ] + 1.0 / vc .loc [j_n ]
677700
@@ -732,7 +755,9 @@ def compute_pvalues_for_group_pairs(
732755
733756 p_vals_allperm = {}
734757 for pair in combinations (group_names , 2 ):
735- p_vals_allperm [pair ] = ((null_zs [pair ] > real_zs_within_group [pair ]).sum (0 ) + 1 ) / n_perm
758+ p_vals_allperm [pair ] = (
759+ (null_zs [pair ] > real_zs_within_group [pair ]).sum (0 ) + 1
760+ ) / n_perm
736761
737762 # summarize into df
738763 df_pval = pd .DataFrame (p_vals_allperm )
@@ -782,7 +807,9 @@ def run_kruskal(
782807 rnd = np .random .RandomState (seed = seed )
783808 # get grouped mean data
784809 grouped_data = (
785- stats_df .pivot_table (index = ["group" , "name" ], columns = "syllable" , values = statistic )
810+ stats_df .pivot_table (
811+ index = ["group" , "name" ], columns = "syllable" , values = statistic
812+ )
786813 .replace (np .nan , 0 )
787814 .reset_index ()
788815 )
@@ -813,7 +840,9 @@ def run_kruskal(
813840 # find the real k_real
814841 df_k_real = pd .DataFrame (
815842 [
816- stats .kruskal (* np .array_split (syllable_data [:, s_i ], np .cumsum (n_per_group [:- 1 ])))
843+ stats .kruskal (
844+ * np .array_split (syllable_data [:, s_i ], np .cumsum (n_per_group [:- 1 ]))
845+ )
817846 for s_i in range (N_s )
818847 ]
819848 )
@@ -851,7 +880,9 @@ def run_kruskal(
851880 df_z = pd .DataFrame (real_zs_within_group )
852881 df_z .index = df_z .index .set_names ("syllable" )
853882 dunn_results_df = df_z .reset_index ().melt (id_vars = [("syllable" , "" )])
854- dunn_results_df .rename (columns = {"variable_0" : "group1" , "variable_1" : "group2" }, inplace = True )
883+ dunn_results_df .rename (
884+ columns = {"variable_0" : "group1" , "variable_1" : "group2" }, inplace = True
885+ )
855886
856887 # Get intersecting significant syllables between
857888 intersect_sig_syllables = {}
@@ -864,7 +895,9 @@ def run_kruskal(
864895
865896
866897# frequency plot stuff
867- def sort_syllables_by_stat_difference (stats_df , ctrl_group , exp_group , stat = "frequency" ):
898+ def sort_syllables_by_stat_difference (
899+ stats_df , ctrl_group , exp_group , stat = "frequency"
900+ ):
868901 """Sort syllables by the difference in the stat between the control and
869902 experimental group.
870903
@@ -997,7 +1030,9 @@ def _validate_and_order_syll_stats_params(
9971030 raise ValueError (
9981031 f"Attempting to sort by { stat } differences, but { ctrl_group } or { exp_group } not in { groups } ."
9991032 )
1000- ordering = sort_syllables_by_stat_difference (complete_df , ctrl_group , exp_group , stat = stat )
1033+ ordering = sort_syllables_by_stat_difference (
1034+ complete_df , ctrl_group , exp_group , stat = stat
1035+ )
10011036 if colors is None :
10021037 colors = []
10031038 if len (colors ) == 0 or len (colors ) != len (groups ):
@@ -1146,7 +1181,9 @@ def plot_syll_stats_with_sem(
11461181 markings .append (np .where (ordering == s )[0 ])
11471182 if len (markings ) > 0 :
11481183 markings = np .concatenate (markings )
1149- plt .scatter (markings , [init_y ] * len (markings ), color = "r" , marker = "*" )
1184+ plt .scatter (
1185+ markings , [init_y ] * len (markings ), color = "r" , marker = "*"
1186+ )
11501187 plt .text (
11511188 plt .xlim ()[1 ],
11521189 init_y ,
@@ -1309,7 +1346,9 @@ def get_transition_matrix(
13091346 # Get syllable transitions
13101347 transitions = get_transitions (v )[0 ]
13111348
1312- trans_mat = n_gram_transition_matrix (transitions , n = 2 , max_label = max_syllable )
1349+ trans_mat = n_gram_transition_matrix (
1350+ transitions , n = 2 , max_label = max_syllable
1351+ )
13131352 init_matrix .append (trans_mat )
13141353
13151354 init_matrix = np .sum (init_matrix , axis = 0 ) + smoothing
@@ -1322,7 +1361,8 @@ def get_transition_matrix(
13221361 transitions = get_transitions (v )[0 ]
13231362
13241363 trans_mat = (
1325- n_gram_transition_matrix (transitions , n = 2 , max_label = max_syllable ) + smoothing
1364+ n_gram_transition_matrix (transitions , n = 2 , max_label = max_syllable )
1365+ + smoothing
13261366 )
13271367
13281368 # Normalize matrix
@@ -1368,9 +1408,9 @@ def get_group_trans_mats(labels, label_group, group, syll_include, normalize="bi
13681408 # Get recordings to include in trans_mat
13691409 # subset only syllable included
13701410 trans_mats .append (
1371- get_transition_matrix (use_labels , normalize = normalize , combine = True )[syll_include , :][
1372- :, syll_include
1373- ]
1411+ get_transition_matrix (use_labels , normalize = normalize , combine = True )[
1412+ syll_include , :
1413+ ][:, syll_include ]
13741414 )
13751415
13761416 # Getting frequency information for node scaling
@@ -1447,7 +1487,9 @@ def visualize_transition_bigram(
14471487 save_analysis_figure (fig , plot_name , project_dir , model_name , save_dir )
14481488
14491489
1450- def generate_transition_matrices (project_dir , model_name , normalize = "bigram" , min_frequency = 0.005 ):
1490+ def generate_transition_matrices (
1491+ project_dir , model_name , normalize = "bigram" , min_frequency = 0.005
1492+ ):
14511493 """Generate the transition matrices for each recording.
14521494
14531495 Parameters
@@ -1543,7 +1585,9 @@ def plot_transition_graph_group(
15431585 nodelist = G .nodes ()
15441586 # normalize the usage values
15451587 sum_usages = sum (usages [i ])
1546- normalized_usages = np .array ([u / sum_usages for u in usages [i ]]) * node_scaling + 1000
1588+ normalized_usages = (
1589+ np .array ([u / sum_usages for u in usages [i ]]) * node_scaling + 1000
1590+ )
15471591 nx .draw_networkx_nodes (
15481592 G ,
15491593 pos ,
@@ -1631,7 +1675,9 @@ def plot_transition_graph_difference(
16311675 # left tm minus right tm
16321676 tm_diff = trans_mats [left_ind ] - trans_mats [right_ind ]
16331677 # left usage minus right usage
1634- usages_diff = np .array (list (usages [left_ind ])) - np .array (list (usages [right_ind ]))
1678+ usages_diff = np .array (list (usages [left_ind ])) - np .array (
1679+ list (usages [right_ind ])
1680+ )
16351681 normlized_usg_abs_diff = (
16361682 np .abs (usages_diff ) / np .abs (usages_diff ).sum ()
16371683 ) * node_scaling + 500
0 commit comments