|
| 1 | +# Copyright 2023 Johns Hopkins (authors: Amir Hussein) |
| 2 | + |
| 3 | +from typing import Callable, Dict, List, Union |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch.utils.data.dataloader import default_collate |
| 7 | + |
| 8 | +from lhotse.cut import CutSet |
| 9 | +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures |
| 10 | +from lhotse.dataset.speech_recognition import validate_for_asr |
| 11 | +from lhotse.utils import compute_num_frames, ifnone |
| 12 | +from lhotse.workarounds import Hdf5MemoryIssueFix |
| 13 | + |
| 14 | + |
class K2Speech2TextTranslationDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech translation task using k2 library.

    This dataset expects to be queried with a :class:`CutSet` (a mini-batch of cuts),
    for which it loads features and automatically collates/batches them.

    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SimpleCutSampler` sampler.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': {
                'sequence_idx': Tensor[int] of shape (S,)
                'text': List[str] of len S      (source-language transcript)
                'tgt_text': List[str] of len S  (read from ``supervision.custom["translated_text"]``)

                # For feature input strategies
                'start_frame': Tensor[int] of shape (S,)
                'num_frames': Tensor[int] of shape (S,)

                # For audio input strategies
                'start_sample': Tensor[int] of shape (S,)
                'num_samples': Tensor[int] of shape (S,)

                # Optionally, when return_cuts=True
                'cut': List[AnyCut] of len S

                # Only when every supervision carries a "word" alignment
                'word': List[List[str]] of len S
                'word_start': List[List[int]] of len S
                'word_end': List[List[int]] of len S
            }
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features

    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        """
        K2 Speech2TextTranslation Dataset constructor.

        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features. Each transform is invoked as
            ``tnfm(inputs, supervision_segments=segments)``.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        """
        super().__init__()
        # Initialize the fields
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Collate the given cuts into a single batch.

        :param cuts: The cuts that make up this batch.
        :return: A dict with the ``"inputs"`` and ``"supervisions"`` entries
            described in the class docstring.
        :raises ValueError: When word alignments are present but the frame shift
            can be found neither in the cuts nor in the input strategy.
        """
        validate_for_asr(cuts)
        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        # The interval tensors are stacked column-wise so that each row describes
        # one supervision segment.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            # default_collate turns the list of per-supervision dicts into
            # a single dict of lists.
            "supervisions": default_collate(
                [
                    {
                        "text": supervision.text,
                        "tgt_text": supervision.custom["translated_text"],
                    }
                    for cut in cuts
                    for supervision in cut.supervisions
                ]
            ),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            # Repeat each cut once per supervision so the list lines up with
            # the other per-supervision fields.
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for _ in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment
            for c in cuts
            for s in c.supervisions
        )
        if has_word_alignments:
            batch["supervisions"].update(self._collect_word_alignments(cuts))

        return batch

    def _collect_word_alignments(self, cuts: CutSet) -> Dict[str, list]:
        """
        Gather per-supervision word symbols and their start/end positions.

        Expects every supervision in ``cuts`` to have a "word" alignment.

        :param cuts: The cuts of the current batch.
        :return: A dict with "word", "word_start" and "word_end" lists, each
            holding one inner list per supervision.
        :raises ValueError: When the frame shift is available neither from the
            first cut nor from ``self.input_strategy.extractor``.
        """
        # TODO: might need to refactor BatchIO API to move the following conditional logic
        #       into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
        #       that returns either num_frames or num_samples depending on the strategy).
        first_cut = cuts[0]
        frame_shift = first_cut.frame_shift
        sampling_rate = first_cut.sampling_rate
        if frame_shift is None:
            try:
                frame_shift = self.input_strategy.extractor.frame_shift
            except AttributeError as err:
                # Chain the cause so the missing attribute stays visible in the traceback.
                raise ValueError(
                    "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
                ) from err

        def to_frames(timestamp):
            # Convert an alignment timestamp with the batch-wide frame shift and sampling rate.
            return compute_num_frames(
                timestamp,
                frame_shift=frame_shift,
                sampling_rate=sampling_rate,
            )

        words, starts, ends = [], [], []
        for c in cuts:
            for s in c.supervisions:
                word_items = s.alignment["word"]
                words.append([aliword.symbol for aliword in word_items])
                starts.append([to_frames(aliword.start) for aliword in word_items])
                ends.append([to_frames(aliword.end) for aliword in word_items])
        return {"word": words, "word_start": starts, "word_end": ends}
0 commit comments