Skip to content

Commit 943aed8

Browse files
committed
black formatting
1 parent c474d27 commit 943aed8

File tree

7 files changed

+357
-143
lines changed

7 files changed

+357
-143
lines changed

keypoint_moseq/analysis.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def get_syllable_names(project_dir, model_name, syllable_ixs):
4646

4747
for ix in syllable_ixs:
4848
if len(syll_info_df[syll_info_df.syllable == ix].label.values[0]) > 0:
49-
labels[ix] = f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
49+
labels[ix] = (
50+
f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
51+
)
5052
names = [labels[ix] for ix in syllable_ixs]
5153
return names
5254

@@ -214,14 +216,17 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
214216
np.concatenate(
215217
(
216218
[0],
217-
np.sqrt(np.square(np.diff(v["centroid"], axis=0)).sum(axis=1)) * fps,
219+
np.sqrt(np.square(np.diff(v["centroid"], axis=0)).sum(axis=1))
220+
* fps,
218221
)
219222
)
220223
)
221224

222225
if index_data is not None:
223226
# find the group for each recording from index data
224-
s_group.append([index_data[index_data["name"] == k]["group"].values[0]] * n_frame)
227+
s_group.append(
228+
[index_data[index_data["name"] == k]["group"].values[0]] * n_frame
229+
)
225230
else:
226231
# no index data
227232
s_group.append(["default"] * n_frame)
@@ -236,8 +241,12 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
236241
heading.append(recording_heading)
237242

238243
# compute angular velocity (radian per second)
239-
gaussian_smoothed_heading = filter_angle(recording_heading, size=3, method="gaussian")
240-
angular_velocity.append(np.concatenate(([0], np.diff(gaussian_smoothed_heading) * fps)))
244+
gaussian_smoothed_heading = filter_angle(
245+
recording_heading, size=3, method="gaussian"
246+
)
247+
angular_velocity.append(
248+
np.concatenate(([0], np.diff(gaussian_smoothed_heading) * fps))
249+
)
241250

242251
# add syllable data
243252
syllables.append(v["syllable"])
@@ -367,7 +376,9 @@ def compute_stats_df(
367376
def generate_syll_info(project_dir, model_name, syll_info_path):
368377
# parse model results
369378
model_results = load_results(project_dir, model_name)
370-
unique_sylls = np.unique(np.concatenate([file["syllable"] for file in model_results.values()]))
379+
unique_sylls = np.unique(
380+
np.concatenate([file["syllable"] for file in model_results.values()])
381+
)
371382
# construct the syllable dictionary
372383
# in the non interactive version there won't be any group info
373384
syll_info_df = pd.DataFrame(
@@ -428,8 +439,12 @@ def label_syllables(project_dir, model_name, moseq_df):
428439
# load syll_info
429440
syll_info_df = pd.read_csv(syll_info_path, index_col=False).fillna("")
430441
# split into with movie and without movie
431-
syll_info_df_with_movie = syll_info_df[syll_info_df.movie_path.str.contains(".mp4")].copy()
432-
syll_info_df_without_movie = syll_info_df[~syll_info_df.movie_path.str.contains(".mp4")].copy()
442+
syll_info_df_with_movie = syll_info_df[
443+
syll_info_df.movie_path.str.contains(".mp4")
444+
].copy()
445+
syll_info_df_without_movie = syll_info_df[
446+
~syll_info_df.movie_path.str.contains(".mp4")
447+
].copy()
433448

434449
# create select widget only include the ones with a movie
435450
select = pn.widgets.Select(
@@ -520,7 +535,9 @@ def b(event, save=True):
520535
button.on_click(b)
521536

522537
# bind everything together
523-
return pn.Row(pn.Column(select, ivideo), pn.Column(summary_table, pn.Column(button)))
538+
return pn.Row(
539+
pn.Column(select, ivideo), pn.Column(summary_table, pn.Column(button))
540+
)
524541

525542

526543
def get_tie_correction(x, N_m):
@@ -605,7 +622,10 @@ def run_manual_KW_test(
605622
# get square of sums for each group
606623
ssbn = np.zeros((n_perm, N_s))
607624
for i in range(num_groups):
608-
ssbn += perm_ranks[:, cum_group_idx[i] : cum_group_idx[i + 1]].sum(1) ** 2 / n_per_group[i]
625+
ssbn += (
626+
perm_ranks[:, cum_group_idx[i] : cum_group_idx[i + 1]].sum(1) ** 2
627+
/ n_per_group[i]
628+
)
609629

610630
# h-statistic
611631
h_all = 12.0 / (N_m * (N_m + 1)) * ssbn - 3 * (N_m + 1)
@@ -616,7 +636,9 @@ def run_manual_KW_test(
616636
p_i = np.random.randint(n_perm)
617637
s_i = np.random.randint(N_s)
618638
kr = stats.kruskal(
619-
*np.array_split(merged_usages_all[perm[p_i, :], s_i], np.cumsum(n_per_group[:-1]))
639+
*np.array_split(
640+
merged_usages_all[perm[p_i, :], s_i], np.cumsum(n_per_group[:-1])
641+
)
620642
)
621643
assert (kr.statistic == h_all[p_i, s_i]) & (
622644
kr.pvalue == p_vals[p_i, s_i]
@@ -671,7 +693,8 @@ def dunns_z_test_permute_within_group_pairs(
671693

672694
ranks_perm = real_ranks[(is_i | is_j)][rnd.rand(n_perm, n_mice).argsort(-1)]
673695
diff = np.abs(
674-
ranks_perm[:, : is_i.sum(), :].mean(1) - ranks_perm[:, is_i.sum() :, :].mean(1)
696+
ranks_perm[:, : is_i.sum(), :].mean(1)
697+
- ranks_perm[:, is_i.sum() :, :].mean(1)
675698
)
676699
B = 1.0 / vc.loc[i_n] + 1.0 / vc.loc[j_n]
677700

@@ -732,7 +755,9 @@ def compute_pvalues_for_group_pairs(
732755

733756
p_vals_allperm = {}
734757
for pair in combinations(group_names, 2):
735-
p_vals_allperm[pair] = ((null_zs[pair] > real_zs_within_group[pair]).sum(0) + 1) / n_perm
758+
p_vals_allperm[pair] = (
759+
(null_zs[pair] > real_zs_within_group[pair]).sum(0) + 1
760+
) / n_perm
736761

737762
# summarize into df
738763
df_pval = pd.DataFrame(p_vals_allperm)
@@ -782,7 +807,9 @@ def run_kruskal(
782807
rnd = np.random.RandomState(seed=seed)
783808
# get grouped mean data
784809
grouped_data = (
785-
stats_df.pivot_table(index=["group", "name"], columns="syllable", values=statistic)
810+
stats_df.pivot_table(
811+
index=["group", "name"], columns="syllable", values=statistic
812+
)
786813
.replace(np.nan, 0)
787814
.reset_index()
788815
)
@@ -813,7 +840,9 @@ def run_kruskal(
813840
# find the real k_real
814841
df_k_real = pd.DataFrame(
815842
[
816-
stats.kruskal(*np.array_split(syllable_data[:, s_i], np.cumsum(n_per_group[:-1])))
843+
stats.kruskal(
844+
*np.array_split(syllable_data[:, s_i], np.cumsum(n_per_group[:-1]))
845+
)
817846
for s_i in range(N_s)
818847
]
819848
)
@@ -851,7 +880,9 @@ def run_kruskal(
851880
df_z = pd.DataFrame(real_zs_within_group)
852881
df_z.index = df_z.index.set_names("syllable")
853882
dunn_results_df = df_z.reset_index().melt(id_vars=[("syllable", "")])
854-
dunn_results_df.rename(columns={"variable_0": "group1", "variable_1": "group2"}, inplace=True)
883+
dunn_results_df.rename(
884+
columns={"variable_0": "group1", "variable_1": "group2"}, inplace=True
885+
)
855886

856887
# Get intersecting significant syllables between
857888
intersect_sig_syllables = {}
@@ -864,7 +895,9 @@ def run_kruskal(
864895

865896

866897
# frequency plot stuff
867-
def sort_syllables_by_stat_difference(stats_df, ctrl_group, exp_group, stat="frequency"):
898+
def sort_syllables_by_stat_difference(
899+
stats_df, ctrl_group, exp_group, stat="frequency"
900+
):
868901
"""Sort syllables by the difference in the stat between the control and
869902
experimental group.
870903
@@ -997,7 +1030,9 @@ def _validate_and_order_syll_stats_params(
9971030
raise ValueError(
9981031
f"Attempting to sort by {stat} differences, but {ctrl_group} or {exp_group} not in {groups}."
9991032
)
1000-
ordering = sort_syllables_by_stat_difference(complete_df, ctrl_group, exp_group, stat=stat)
1033+
ordering = sort_syllables_by_stat_difference(
1034+
complete_df, ctrl_group, exp_group, stat=stat
1035+
)
10011036
if colors is None:
10021037
colors = []
10031038
if len(colors) == 0 or len(colors) != len(groups):
@@ -1146,7 +1181,9 @@ def plot_syll_stats_with_sem(
11461181
markings.append(np.where(ordering == s)[0])
11471182
if len(markings) > 0:
11481183
markings = np.concatenate(markings)
1149-
plt.scatter(markings, [init_y] * len(markings), color="r", marker="*")
1184+
plt.scatter(
1185+
markings, [init_y] * len(markings), color="r", marker="*"
1186+
)
11501187
plt.text(
11511188
plt.xlim()[1],
11521189
init_y,
@@ -1309,7 +1346,9 @@ def get_transition_matrix(
13091346
# Get syllable transitions
13101347
transitions = get_transitions(v)[0]
13111348

1312-
trans_mat = n_gram_transition_matrix(transitions, n=2, max_label=max_syllable)
1349+
trans_mat = n_gram_transition_matrix(
1350+
transitions, n=2, max_label=max_syllable
1351+
)
13131352
init_matrix.append(trans_mat)
13141353

13151354
init_matrix = np.sum(init_matrix, axis=0) + smoothing
@@ -1322,7 +1361,8 @@ def get_transition_matrix(
13221361
transitions = get_transitions(v)[0]
13231362

13241363
trans_mat = (
1325-
n_gram_transition_matrix(transitions, n=2, max_label=max_syllable) + smoothing
1364+
n_gram_transition_matrix(transitions, n=2, max_label=max_syllable)
1365+
+ smoothing
13261366
)
13271367

13281368
# Normalize matrix
@@ -1368,9 +1408,9 @@ def get_group_trans_mats(labels, label_group, group, syll_include, normalize="bi
13681408
# Get recordings to include in trans_mat
13691409
# subset only syllable included
13701410
trans_mats.append(
1371-
get_transition_matrix(use_labels, normalize=normalize, combine=True)[syll_include, :][
1372-
:, syll_include
1373-
]
1411+
get_transition_matrix(use_labels, normalize=normalize, combine=True)[
1412+
syll_include, :
1413+
][:, syll_include]
13741414
)
13751415

13761416
# Getting frequency information for node scaling
@@ -1447,7 +1487,9 @@ def visualize_transition_bigram(
14471487
save_analysis_figure(fig, plot_name, project_dir, model_name, save_dir)
14481488

14491489

1450-
def generate_transition_matrices(project_dir, model_name, normalize="bigram", min_frequency=0.005):
1490+
def generate_transition_matrices(
1491+
project_dir, model_name, normalize="bigram", min_frequency=0.005
1492+
):
14511493
"""Generate the transition matrices for each recording.
14521494
14531495
Parameters
@@ -1543,7 +1585,9 @@ def plot_transition_graph_group(
15431585
nodelist = G.nodes()
15441586
# normalize the usage values
15451587
sum_usages = sum(usages[i])
1546-
normalized_usages = np.array([u / sum_usages for u in usages[i]]) * node_scaling + 1000
1588+
normalized_usages = (
1589+
np.array([u / sum_usages for u in usages[i]]) * node_scaling + 1000
1590+
)
15471591
nx.draw_networkx_nodes(
15481592
G,
15491593
pos,
@@ -1631,7 +1675,9 @@ def plot_transition_graph_difference(
16311675
# left tm minus right tm
16321676
tm_diff = trans_mats[left_ind] - trans_mats[right_ind]
16331677
# left usage minus right usage
1634-
usages_diff = np.array(list(usages[left_ind])) - np.array(list(usages[right_ind]))
1678+
usages_diff = np.array(list(usages[left_ind])) - np.array(
1679+
list(usages[right_ind])
1680+
)
16351681
normlized_usg_abs_diff = (
16361682
np.abs(usages_diff) / np.abs(usages_diff).sum()
16371683
) * node_scaling + 500

0 commit comments

Comments
 (0)