Skip to content

Commit c80fc07

Browse files
Speech to text translation utilizing 3-way data (#1099)
1 parent aa073f6 commit c80fc07

6 files changed

Lines changed: 626 additions & 0 deletions

File tree

docs/corpus.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ a CLI tool that create the manifests given a corpus directory.
123123
- :func:`lhotse.recipes.prepare_himia`
124124
* - ICSI
125125
- :func:`lhotse.recipes.prepare_icsi`
126+
* - IWSLT22_Ta
127+
- :func:`lhotse.recipes.prepare_iwslt22_ta`
126128
* - KeSpeech
127129
- :func:`lhotse.recipes.prepare_kespeech`
128130
* - L2 Arctic

lhotse/bin/modes/recipes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .hifitts import *
4141
from .himia import *
4242
from .icsi import *
43+
from .iwslt22_ta import *
4344
from .kespeech import *
4445
from .l2_arctic import *
4546
from .libricss import *
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Optional, Sequence, Union
2+
3+
import click
4+
5+
from lhotse.bin.modes import prepare
6+
from lhotse.recipes.iwslt22_ta import prepare_iwslt22_ta
7+
from lhotse.utils import Pathlike
8+
9+
10+
@prepare.command(context_settings=dict(show_default=True))
11+
@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True))
12+
@click.argument("splits", type=click.Path(exists=True, dir_okay=True))
13+
@click.argument("output_dir", type=click.Path())
14+
@click.option(
15+
"-j",
16+
"--num-jobs",
17+
type=int,
18+
default=1,
19+
help="How many threads to use (can give good speed-ups with slow disks).",
20+
)
21+
@click.option(
22+
"--normalize-text",
23+
default=False,
24+
help="Whether to perform additional text cleaning and normalization from https://aclanthology.org/2022.iwslt-1.29.pdf.",
25+
)
26+
@click.option(
27+
"--langs",
28+
default="",
29+
help="Comma-separated list of language abbreviations for source and target languages",
30+
)
31+
def iwslt22_ta(
32+
corpus_dir: Pathlike,
33+
splits: Pathlike,
34+
output_dir: Pathlike,
35+
normalize_text: bool,
36+
langs: str,
37+
num_jobs: int,
38+
):
39+
"""
40+
IWSLT_2022 data preparation.
41+
\b
42+
This is conversational telephone speech collected as 8kHz-sampled data.
43+
The catalog number LDC2022E01 corresponds to the train, dev, and test1
44+
splits of the iwslt2022 shared task.
45+
To obtaining this data your institution needs to have an LDC subscription.
46+
You also should download the predined splits with
47+
git clone https://github.com/kevinduh/iwslt22-dialect.git
48+
"""
49+
langs_list = langs.split(",")
50+
prepare_iwslt22_ta(
51+
corpus_dir,
52+
splits,
53+
output_dir=output_dir,
54+
num_jobs=num_jobs,
55+
clean=normalize_text,
56+
langs=langs_list,
57+
)
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)
2+
3+
from typing import Callable, Dict, List, Union
4+
5+
import torch
6+
from torch.utils.data.dataloader import default_collate
7+
8+
from lhotse.cut import CutSet
9+
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
10+
from lhotse.dataset.speech_recognition import validate_for_asr
11+
from lhotse.utils import compute_num_frames, ifnone
12+
from lhotse.workarounds import Hdf5MemoryIssueFix
13+
14+
15+
class K2Speech2TextTranslationDataset(torch.utils.data.Dataset):
16+
"""
17+
The PyTorch Dataset for the speech translation task using k2 library.
18+
19+
This dataset expects to be queried with lists of cut IDs,
20+
for which it loads features and automatically collates/batches them.
21+
22+
To use it with a PyTorch DataLoader, set ``batch_size=None``
23+
and provide a :class:`SimpleCutSampler` sampler.
24+
25+
Each item in this dataset is a dict of:
26+
27+
.. code-block::
28+
29+
{
30+
'inputs': float tensor with shape determined by :attr:`input_strategy`:
31+
- single-channel:
32+
- features: (B, T, F)
33+
- audio: (B, T)
34+
- multi-channel: currently not supported
35+
'supervisions': [
36+
{
37+
'sequence_idx': Tensor[int] of shape (S,)
38+
'src_text': List[str] of len S
39+
'tgt_text': List[str] of len S
40+
41+
# For feature input strategies
42+
'start_frame': Tensor[int] of shape (S,)
43+
'num_frames': Tensor[int] of shape (S,)
44+
45+
# For audio input strategies
46+
'start_sample': Tensor[int] of shape (S,)
47+
'num_samples': Tensor[int] of shape (S,)
48+
49+
# Optionally, when return_cuts=True
50+
'cut': List[AnyCut] of len S
51+
}
52+
]
53+
}
54+
55+
Dimension symbols legend:
56+
* ``B`` - batch size (number of Cuts)
57+
* ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
58+
* ``T`` - number of frames of the longest Cut
59+
* ``F`` - number of features
60+
61+
The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
62+
"""
63+
64+
def __init__(
65+
self,
66+
return_cuts: bool = False,
67+
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
68+
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
69+
input_strategy: BatchIO = PrecomputedFeatures(),
70+
):
71+
"""
72+
K2 Speech2TextTranslation IterableDataset constructor.
73+
74+
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
75+
objects used to create that batch.
76+
:param cut_transforms: A list of transforms to be applied on each sampled batch,
77+
before converting cuts to an input representation (audio/features).
78+
Examples: cut concatenation, noise cuts mixing, etc.
79+
:param input_transforms: A list of transforms to be applied on each sampled batch,
80+
after the cuts are converted to audio/features.
81+
Examples: normalization, SpecAugment, etc.
82+
:param input_strategy: Converts cuts into a collated batch of audio/features.
83+
By default, reads pre-computed features from disk.
84+
"""
85+
super().__init__()
86+
# Initialize the fields
87+
self.return_cuts = return_cuts
88+
self.cut_transforms = ifnone(cut_transforms, [])
89+
self.input_transforms = ifnone(input_transforms, [])
90+
self.input_strategy = input_strategy
91+
92+
# This attribute is a workaround to constantly growing HDF5 memory
93+
# throughout the epoch. It regularly closes open file handles to
94+
# reset the internal HDF5 caches.
95+
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
96+
97+
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
98+
"""
99+
Return a new batch, with the batch size automatically determined using the constraints
100+
of max_frames and max_cuts.
101+
"""
102+
validate_for_asr(cuts)
103+
self.hdf5_fix.update()
104+
105+
# Sort the cuts by duration so that the first one determines the batch time dimensions.
106+
cuts = cuts.sort_by_duration(ascending=False)
107+
108+
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
109+
# the supervision boundaries.
110+
for tnfm in self.cut_transforms:
111+
cuts = tnfm(cuts)
112+
113+
# Sort the cuts again after transforms
114+
cuts = cuts.sort_by_duration(ascending=False)
115+
116+
# Get a tensor with batched feature matrices, shape (B, T, F)
117+
# Collation performs auto-padding, if necessary.
118+
input_tpl = self.input_strategy(cuts)
119+
if len(input_tpl) == 3:
120+
# An input strategy with fault tolerant audio reading mode.
121+
# "cuts" may be a subset of the original "cuts" variable,
122+
# that only has cuts for which we succesfully read the audio.
123+
inputs, _, cuts = input_tpl
124+
else:
125+
inputs, _ = input_tpl
126+
127+
# Get a dict of tensors that encode the positional information about supervisions
128+
# in the batch of feature matrices. The tensors are named "sequence_idx",
129+
# "start_frame/sample" and "num_frames/samples".
130+
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
131+
132+
# Apply all available transforms on the inputs, i.e. either audio or features.
133+
# This could be feature extraction, global MVN, SpecAugment, etc.
134+
segments = torch.stack(list(supervision_intervals.values()), dim=1)
135+
for tnfm in self.input_transforms:
136+
inputs = tnfm(inputs, supervision_segments=segments)
137+
batch = {
138+
"inputs": inputs,
139+
"supervisions": default_collate(
140+
[
141+
{
142+
"text": supervision.text,
143+
"tgt_text": supervision.custom["translated_text"],
144+
}
145+
for sequence_idx, cut in enumerate(cuts)
146+
for supervision in cut.supervisions
147+
]
148+
),
149+
}
150+
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
151+
batch["supervisions"].update(supervision_intervals)
152+
if self.return_cuts:
153+
batch["supervisions"]["cut"] = [
154+
cut for cut in cuts for sup in cut.supervisions
155+
]
156+
157+
has_word_alignments = all(
158+
s.alignment is not None and "word" in s.alignment
159+
for c in cuts
160+
for s in c.supervisions
161+
)
162+
if has_word_alignments:
163+
# TODO: might need to refactor BatchIO API to move the following conditional logic
164+
# into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
165+
# that returns either num_frames or num_samples depending on the strategy).
166+
words, starts, ends = [], [], []
167+
frame_shift = cuts[0].frame_shift
168+
sampling_rate = cuts[0].sampling_rate
169+
if frame_shift is None:
170+
try:
171+
frame_shift = self.input_strategy.extractor.frame_shift
172+
except AttributeError:
173+
raise ValueError(
174+
"Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
175+
)
176+
for c in cuts:
177+
for s in c.supervisions:
178+
words.append([aliword.symbol for aliword in s.alignment["word"]])
179+
starts.append(
180+
[
181+
compute_num_frames(
182+
aliword.start,
183+
frame_shift=frame_shift,
184+
sampling_rate=sampling_rate,
185+
)
186+
for aliword in s.alignment["word"]
187+
]
188+
)
189+
ends.append(
190+
[
191+
compute_num_frames(
192+
aliword.end,
193+
frame_shift=frame_shift,
194+
sampling_rate=sampling_rate,
195+
)
196+
for aliword in s.alignment["word"]
197+
]
198+
)
199+
batch["supervisions"]["word"] = words
200+
batch["supervisions"]["word_start"] = starts
201+
batch["supervisions"]["word_end"] = ends
202+
203+
return batch

lhotse/recipes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .hifitts import download_hifitts, prepare_hifitts
3838
from .himia import download_himia, prepare_himia
3939
from .icsi import download_icsi, prepare_icsi
40+
from .iwslt22_ta import prepare_iwslt22_ta
4041
from .kespeech import prepare_kespeech
4142
from .l2_arctic import prepare_l2_arctic
4243
from .libricss import download_libricss, prepare_libricss

0 commit comments

Comments
 (0)