Skip to content

Commit aa073f6

Browse files
authored
Some additional options for supervision-related methods (#1115)
* add transform attribute for MixedCut * add mix_first option in normalize_loudness * handle the case when mix is called on MixedCut with existing transforms * add test for mixing with transformed MixedCut * enhancements and bug fixes * small changes in some cutset methods * small fix in error message * return word alignments from ami recipe * add word alignments for ICSI * remove unwanted whitespace * fix IHM preparation * remove words with zero or negative duration * ensure word alignments respect segment boundary * add save-to-wav option for icsi * add test for mixing cut with recording * style fix * add data prep for voxpopuli * small changes in recipes * changes for max segment duration * remove extra code * made suggested changes * apply change to multi custom merge func * remove old code * fix failing tests * add tests for trim to alignments with max segment duration * add tests for merge supervisions
1 parent d97b853 commit aa073f6

12 files changed

Lines changed: 267 additions & 81 deletions

File tree

lhotse/bin/modes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
from .manipulation import *
77
from .recipes import *
88
from .shar import *
9+
from .supervision import *
910
from .validate import *
1011
from .workflows import *

lhotse/bin/modes/supervision.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import click
2+
from tqdm import tqdm
3+
4+
from lhotse.bin.modes.cli_base import cli
5+
from lhotse.serialization import load_manifest_lazy_or_eager
6+
from lhotse.supervision import SupervisionSet
7+
from lhotse.utils import Pathlike
8+
9+
10+
@cli.group()
def supervision():
    """Commands related to manipulating supervision manifests."""
    # Intentionally empty: this function only serves as the click group under
    # which the supervision sub-commands are registered. The docstring above is
    # what click displays as the group's help text.
14+
15+
16+
@supervision.command()
@click.argument("in_supervision_manifest", type=click.Path(allow_dash=True))
@click.argument("out_supervision_manifest", type=click.Path(allow_dash=True))
@click.option(
    "--ctm-file",
    type=click.Path(exists=True, dir_okay=False),
    # Without a CTM file there is nothing to add; fail with a usage error
    # instead of passing ``None`` down to the SupervisionSet method.
    required=True,
    help="CTM file containing alignments to add.",
)
@click.option(
    "--alignment-type",
    type=str,
    default="word",
    help="Type of alignment to add (default = `word`).",
)
@click.option(
    "--match-channel/--no-match-channel",
    default=False,
    help="Whether to match channel between CTM and SupervisionSegment (default = False).",
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    default=False,
    help="Whether to print verbose output.",
)
def with_alignment_from_ctm(
    in_supervision_manifest: Pathlike,
    out_supervision_manifest: Pathlike,
    ctm_file: Pathlike,
    alignment_type: str,
    match_channel: bool,
    verbose: bool,
):
    """
    Add alignments from CTM file to the supervision set.

    The resulting supervisions, with AlignmentItem objects added to the segments,
    are written to the output supervision manifest (overwriting it if it exists).

    :param in_supervision_manifest: Path to input supervision manifest.
    :param out_supervision_manifest: Path to output supervision manifest.
    :param ctm_file: Path to CTM file.
    :param alignment_type: Alignment type (optional, default = `word`).
    :param match_channel: if True, also match channel between CTM and SupervisionSegment
    :param verbose: Whether to print verbose output.
    """
    supervisions = load_manifest_lazy_or_eager(in_supervision_manifest, SupervisionSet)
    aligned = supervisions.with_alignment_from_ctm(
        ctm_file=ctm_file,
        type=alignment_type,
        match_channel=match_channel,
        verbose=verbose,
    )
    with SupervisionSet.open_writer(out_supervision_manifest, overwrite=True) as writer:
        # Show a progress bar only when the user asked for verbose output.
        segments = tqdm(aligned, desc="Writing supervisions") if verbose else aligned
        for segment in segments:
            writer.write(segment)

lhotse/cut/base.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,16 @@ def trim_to_alignments(
510510
self,
511511
type: str,
512512
max_pause: Optional[Seconds] = None,
513+
max_segment_duration: Optional[Seconds] = None,
513514
delimiter: str = " ",
514515
keep_all_channels: bool = False,
515516
) -> "CutSet": # noqa: F821
516517
"""
517518
Splits the current :class:`.Cut` into its constituent alignment items (:class:`.AlignmentItem`).
518519
These cuts have identical start times and durations as the alignment item. Additionally,
519520
the `max_pause` option can be used to merge alignment items that are separated by a pause
520-
shorter than `max_pause`.
521+
shorter than `max_pause`. If `max_segment_duration` is specified, we will keep merging
522+
consecutive segments as long as the duration of the merged segment does not exceed `max_segment_duration`.
521523
522524
For the case of a multi-channel cut with multiple alignments, we can either trim
523525
while respecting the supervision channels (in which case output cut has the same channels
@@ -531,6 +533,21 @@ def trim_to_alignments(
531533
.. hint:: If a MultiCut is trimmed and the resulting trimmed cut contains a single channel,
532534
we convert it to a MonoCut.
533535
536+
.. hint:: If you have a Cut with multiple supervision segments and you want to trim it to
537+
the word-level alignment, you can use the :meth:`.Cut.merge_supervisions` method
538+
first to merge the supervisions into a single one, followed by the
539+
:meth:`.Cut.trim_to_alignments` method. For example::
540+
541+
>>> cut = cut.merge_supervisions(type='word', delimiter=' ')
542+
>>> cut = cut.trim_to_alignments(type='word', max_pause=1.0)
543+
544+
.. hint:: The above technique can also be used to segment long cuts into roughly equal
545+
duration segments, while respecting alignment boundaries. For example, to split a
546+
Cut into 10s segments, you can do::
547+
548+
>>> cut = cut.merge_supervisions(type='word', delimiter=' ')
549+
>>> cut = cut.trim_to_alignments(type='word', max_pause=10.0, max_segment_duration=10.0)
550+
534551
:param type: The type of the alignment to trim to (e.g. "word").
535552
:param max_pause: The maximum pause allowed between the alignments to merge them. If ``None``,
536553
no merging will be performed. [default: None]
@@ -546,6 +563,10 @@ def trim_to_alignments(
546563
# Set to a negative value so that no merging is performed.
547564
max_pause = -1.0
548565

566+
if max_segment_duration is None:
567+
# Set to the cut duration so that resulting segments are always smaller.
568+
max_segment_duration = self.duration
569+
549570
# For the implementation, we first create new supervisions for the cut, and then
550571
# use the `trim_to_supervisions` method to do the actual trimming.
551572
new_supervisions = []
@@ -561,14 +582,20 @@ def trim_to_alignments(
561582
# Merge the alignments if needed. We also keep track of the indices of the
562583
# merged alignments in the original list. This is needed to create the
563584
# `alignment` field in the new supervisions.
585+
# NOTE: We use the `AlignmentItem` class here for convenience --- the merged
586+
# alignments are not actual alignment items, but rather just a way to keep
587+
# track of merged segments.
564588
merged_alignments = [(alignments[0], [0])]
565589
for i, item in enumerate(alignments[1:]):
566590
# If alignment item is blank, skip it. Sometimes, blank alignment items
567591
# are used to denote pauses in the utterance.
568592
if item.symbol.strip() == "":
569593
continue
570594
prev_item, prev_indices = merged_alignments[-1]
571-
if item.start - prev_item.end <= max_pause:
595+
if (
596+
item.start - prev_item.end <= max_pause
597+
and item.end - prev_item.start <= max_segment_duration
598+
):
572599
new_item = AlignmentItem(
573600
symbol=delimiter.join([prev_item.symbol, item.symbol]),
574601
start=prev_item.start,

lhotse/cut/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,7 @@ def filter_supervisions(
10501050
@abstractmethod
10511051
def merge_supervisions(
10521052
self,
1053+
merge_policy: str = "delimiter",
10531054
custom_merge_fn: Optional[Callable[[str, Iterable[Any]], Any]] = None,
10541055
**kwargs,
10551056
) -> "DataCut":

lhotse/cut/mixed.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,37 +1341,46 @@ def map_supervisions(
13411341
return new_mixed_cut
13421342

13431343
def merge_supervisions(
1344-
self, custom_merge_fn: Optional[Callable[[str, Iterable[Any]], Any]] = None
1344+
self,
1345+
merge_policy: str = "delimiter",
1346+
custom_merge_fn: Optional[Callable[[str, Iterable[Any]], Any]] = None,
13451347
) -> "MixedCut":
13461348
"""
13471349
Return a copy of the cut that has all of its supervisions merged into
13481350
a single segment.
13491351
13501352
The new start is the start of the earliest supervision, and the new duration
1351-
is a minimum spanning duration for all the supervisions.
1352-
1353-
The text fields are concatenated with a whitespace, and all other string fields
1354-
(including IDs) are prefixed with "cat#" and concatenated with a hash symbol "#".
1355-
This is also applied to ``custom`` fields. Fields with a ``None`` value are omitted.
1353+
is a minimum spanning duration for all the supervisions. The text fields are
1354+
concatenated with a whitespace.
13561355
13571356
.. note:: If you're using individual tracks of a mixed cut, note that this transform
13581357
drops all the supervisions in individual tracks and assigns the merged supervision
13591358
in the first :class:`.DataCut` found in ``self.tracks``.
13601359
1360+
:param merge_policy: one of "keep_first" or "delimiter". If "keep_first", we
1361+
keep only the first segment's field value, otherwise all string fields
1362+
(including IDs) are prefixed with "cat#" and concatenated with a hash symbol "#".
1363+
This is also applied to ``custom`` fields. Fields with a ``None`` value are omitted.
13611364
:param custom_merge_fn: a function that will be called to merge custom fields values.
13621365
We expect ``custom_merge_fn`` to handle all possible custom keys.
13631366
When not provided, we will treat all custom values as strings.
13641367
It will be called roughly like:
13651368
``custom_merge_fn(custom_key, [s.custom[custom_key] for s in sups])``
13661369
"""
1370+
merge_func_ = partial(
1371+
merge_items_with_delimiter,
1372+
delimiter="#",
1373+
return_first=(merge_policy == "keep_first"),
1374+
)
1375+
13671376
# "m" stands for merged in variable names below
13681377

13691378
if custom_merge_fn is not None:
13701379
# Merge custom fields with the user-provided function.
13711380
merge_custom = custom_merge_fn
13721381
else:
13731382
# Merge the string representations of custom fields.
1374-
merge_custom = lambda k, vs: merge_items_with_delimiter(map(str, vs))
1383+
merge_custom = lambda k, vs: merge_func_(map(str, vs))
13751384

13761385
sups = sorted(self.supervisions, key=lambda s: s.start)
13771386

@@ -1400,18 +1409,18 @@ def merge_supervisions(
14001409
)
14011410

14021411
msup = SupervisionSegment(
1403-
id=merge_items_with_delimiter(s.id for s in sups),
1412+
id=merge_func_(s.id for s in sups),
14041413
# Make the merged recording_id a mix of the recording_ids.
1405-
recording_id=merge_items_with_delimiter(s.recording_id for s in sups),
1414+
recording_id=merge_func_(s.recording_id for s in sups),
14061415
start=mstart,
14071416
duration=mduration,
14081417
# Hardcode -1 to indicate no specific channel, as the supervisions might have
14091418
# come from different channels in their original recordings.
14101419
channel=-1,
14111420
text=" ".join(s.text for s in sups if s.text),
1412-
speaker=merge_items_with_delimiter(s.speaker for s in sups if s.speaker),
1413-
language=merge_items_with_delimiter(s.language for s in sups if s.language),
1414-
gender=merge_items_with_delimiter(s.gender for s in sups if s.gender),
1421+
speaker=merge_func_(s.speaker for s in sups if s.speaker),
1422+
language=merge_func_(s.language for s in sups if s.language),
1423+
gender=merge_func_(s.gender for s in sups if s.gender),
14151424
custom={
14161425
k: merge_custom(
14171426
k,

lhotse/cut/mono.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import math
33
import warnings
44
from dataclasses import dataclass
5-
from functools import reduce
5+
from functools import partial, reduce
66
from operator import add
77
from typing import Any, Callable, Iterable, List, Optional
88

@@ -194,33 +194,42 @@ def reverb_rir(
194194
)
195195

196196
def merge_supervisions(
197-
self, custom_merge_fn: Optional[Callable[[str, Iterable[Any]], Any]] = None
197+
self,
198+
merge_policy: str = "delimiter",
199+
custom_merge_fn: Optional[Callable[[str, Iterable[Any]], Any]] = None,
198200
) -> "MonoCut":
199201
"""
200202
Return a copy of the cut that has all of its supervisions merged into
201203
a single segment.
202204
203205
The new start is the start of the earliest supervision, and the new duration
204-
is a minimum spanning duration for all the supervisions.
205-
206-
The text fields are concatenated with a whitespace, and all other string fields
207-
(including IDs) are prefixed with "cat#" and concatenated with a hash symbol "#".
208-
This is also applied to ``custom`` fields. Fields with a ``None`` value are omitted.
206+
is a minimum spanning duration for all the supervisions. The text fields of
207+
all segments are concatenated with a whitespace.
209208
209+
:param merge_policy: one of "keep_first" or "delimiter". If "keep_first", we
210+
keep only the first segment's field value, otherwise all string fields
211+
(including IDs) are prefixed with "cat#" and concatenated with a hash symbol "#".
212+
This is also applied to ``custom`` fields. Fields with a ``None`` value are omitted.
210213
:param custom_merge_fn: a function that will be called to merge custom fields values.
211214
We expect ``custom_merge_fn`` to handle all possible custom keys.
212215
When not provided, we will treat all custom values as strings.
213216
It will be called roughly like:
214217
``custom_merge_fn(custom_key, [s.custom[custom_key] for s in sups])``
215218
"""
219+
merge_func_ = partial(
220+
merge_items_with_delimiter,
221+
delimiter="#",
222+
return_first=(merge_policy == "keep_first"),
223+
)
224+
216225
# "m" stands for merged in variable names below
217226

218227
if custom_merge_fn is not None:
219228
# Merge custom fields with the user-provided function.
220229
merge_custom = custom_merge_fn
221230
else:
222231
# Merge the string representations of custom fields.
223-
merge_custom = lambda k, vs: merge_items_with_delimiter(map(str, vs))
232+
merge_custom = lambda k, vs: merge_func_(map(str, vs))
224233

225234
sups = sorted(self.supervisions, key=lambda s: s.start)
226235

@@ -249,15 +258,15 @@ def merge_supervisions(
249258
)
250259

251260
msup = SupervisionSegment(
252-
id=merge_items_with_delimiter(s.id for s in sups),
261+
id=merge_func_(s.id for s in sups),
253262
recording_id=sups[0].recording_id,
254263
start=mstart,
255264
duration=mduration,
256265
channel=sups[0].channel,
257266
text=" ".join(s.text for s in sups if s.text),
258-
speaker=merge_items_with_delimiter(s.speaker for s in sups if s.speaker),
259-
language=merge_items_with_delimiter(s.language for s in sups if s.language),
260-
gender=merge_items_with_delimiter(s.gender for s in sups if s.gender),
267+
speaker=merge_func_(s.speaker for s in sups if s.speaker),
268+
language=merge_func_(s.language for s in sups if s.language),
269+
gender=merge_func_(s.gender for s in sups if s.gender),
261270
custom={
262271
k: merge_custom(
263272
k,

0 commit comments

Comments
 (0)