Skip to content

Commit 809c404

Browse files
committed
Changed `generate_syllable_mapping` to count syllable frequency by instances (runs of consecutive frames) rather than by individual frames, which makes it consistent with `plot_syllable_frequencies` and `reindex_syllables_in_checkpoint`.
1 parent 873bc73 commit 809c404

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

keypoint_moseq/util.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,7 +1342,7 @@ def check_video_paths(video_paths, keys):
13421342
raise ValueError("\n\n".join(error_messages))
13431343

13441344

1345-
def generate_syllable_mapping(results: dict, syllable_grouping: list[list[int]]) -> dict[int, int]:
1345+
def generate_syllable_mapping(results: dict, syllable_grouping: list[list[int]], runlength: bool = True) -> dict[int, int]:
13461346
"""
13471347
Create a mapping from old syllable indexes to new syllable indexes such that each group of
13481348
syllables in `syllable_grouping` is mapped to a single index. All syllables not included in
@@ -1358,6 +1358,11 @@ def generate_syllable_mapping(results: dict, syllable_grouping: list[list[int]])
13581358
syllable_grouping: list[list[int]]
13591359
List of lists representing groups of syllables that should be mapped to a single index.
13601360
1361+
runlength: bool, default=True
1362+
If True, frequencies are quantified using the number of non-consecutive
1363+
occurrences of each syllable. If False, frequency is quantified by
1364+
total number of frames.
1365+
13611366
Returns
13621367
-------
13631368
mapping: dict[int, int]
@@ -1371,30 +1376,27 @@ def generate_syllable_mapping(results: dict, syllable_grouping: list[list[int]])
13711376
>>> print(mapping)
13721377
>>> # {0: 0, 1: 0, 2: 1, 3: 2, 4: 3, 5: 1, 6: 1}
13731378
"""
1374-
# Count the number of times each syllable is used
1375-
syllable_counts = np.zeros(max(max(v["syllable"]) for v in results.values()) + 1, dtype=int)
1376-
for v in results.values():
1377-
unique, counts = np.unique(v["syllable"], return_counts=True)
1378-
syllable_counts[unique] += counts
1379+
syllables = {k: res["syllable"] for k, res in results.items()}
1380+
syllable_frequencies = get_frequencies(syllables, runlength=runlength)
13791381

13801382
# Get a list of all syllables that are in a group
13811383
syllables_to_group = [s for group in syllable_grouping for s in group]
13821384

1383-
# Count the total number of times a group of syllables is used
1384-
all_counts = []
1385+
# Calculate the total frequency for each group of syllables
1386+
all_frequencies = []
13851387
for group in syllable_grouping:
1386-
group_count = sum(syllable_counts[s] for s in group)
1387-
all_counts.append((group_count, group))
1388+
group_frequency = sum(syllable_frequencies[s] for s in group)
1389+
all_frequencies.append((group_frequency, group))
13881390

1389-
# Count the number of times a single syllable is used
1390-
for syllable in range(len(syllable_counts)):
1391+
# Add individual syllables not in any group
1392+
for syllable in range(len(syllable_frequencies)):
13911393
if syllable not in syllables_to_group:
1392-
all_counts.append((syllable_counts[syllable], [syllable]))
1394+
all_frequencies.append((syllable_frequencies[syllable], [syllable]))
13931395

1394-
all_counts.sort(reverse=True)
1396+
all_frequencies.sort(reverse=True)
13951397

13961398
mapping = {}
1397-
for i, (_, syllables) in enumerate(all_counts):
1399+
for i, (_, syllables) in enumerate(all_frequencies):
13981400
for syllable in syllables:
13991401
mapping[syllable] = i
14001402

0 commit comments

Comments
 (0)