diff --git a/lhotse/cut/mono.py b/lhotse/cut/mono.py index 75bbe23f7..0fa3f5abd 100644 --- a/lhotse/cut/mono.py +++ b/lhotse/cut/mono.py @@ -389,3 +389,13 @@ def from_dict(data: dict) -> "MonoCut": recording=recording, supervisions=[SupervisionSegment.from_dict(s) for s in supervision_infos], ) + + def __lt__(self, other): + token_length = len(self.supervisions[0].custom['tokens']['text']) + other_token_length = len(other.supervisions[0].custom['tokens']['text']) + if token_length - other_token_length >= 3: + return True + elif token_length - other_token_length <= -3: + return False + else: + return self.duration > other.duration diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py index ca43b36fc..6e33e5f5f 100644 --- a/lhotse/dataset/sampling/base.py +++ b/lhotse/dataset/sampling/base.py @@ -437,6 +437,7 @@ class TimeConstraint(SamplingConstraint): current: Union[int, Seconds] = 0 num_cuts: int = 0 longest_seen: Union[int, float] = 0 + longest_seen_text: Union[int, float] = 0 quadratic_duration: Optional[Seconds] = None def __post_init__(self) -> None: @@ -454,18 +455,23 @@ def add(self, example: Cut) -> None: selecting the right property from the input ``cut`` object. """ if self.max_duration is not None: - duration = self._maybe_apply_quadratic_correction(example.duration) - self.current += duration + duration = self._maybe_apply_quadratic_correction(example.duration, quadratic_expand_value=1.5) + try: + text_duration = self._maybe_apply_quadratic_correction(len(example.supervisions[0].custom['tokens']['text']) * 0.046) + except: + text_duration = 0 + self.current += duration + text_duration self.longest_seen = max(self.longest_seen, duration) + self.longest_seen_text = max(self.longest_seen_text, text_duration) self.num_cuts += 1 - def _maybe_apply_quadratic_correction(self, duration: Seconds) -> Seconds: + def _maybe_apply_quadratic_correction(self, duration: Seconds, quadratic_expand_value: Seconds = 1.0) -> Seconds: if self.quadratic_duration is None: return duration # For the quadratic complexity case, we add a term that accounts for # extra memory occupied by the model. The 1/quadratic_duration term causes # the effective duration to be doubled when it's equal to quadratic_duration. - return duration + (duration**2) / self.quadratic_duration + return duration + ((duration * quadratic_expand_value)**2) / self.quadratic_duration def exceeded(self) -> bool: """Is the constraint exceeded or not.""" @@ -473,7 +479,7 @@ def exceeded(self) -> bool: return True if self.max_duration is None: return False - effective_duration = self.num_cuts * self.longest_seen + effective_duration = self.num_cuts * (self.longest_seen + self.longest_seen_text) return effective_duration > self.max_duration def close_to_exceeding(self) -> bool: @@ -487,7 +493,7 @@ def close_to_exceeding(self) -> bool: return True if self.max_duration is not None: - effective_duration = (self.num_cuts + 1) * self.longest_seen + effective_duration = (self.num_cuts + 1) * (self.longest_seen + self.longest_seen_text) return effective_duration > self.max_duration return False @@ -499,6 +505,7 @@ def reset(self) -> None: self.current = 0 self.num_cuts = 0 self.longest_seen = 0 + self.longest_seen_text = 0 def measure_length(self, example: Cut) -> float: return example.duration @@ -512,6 +519,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.current = state_dict.pop("current") self.num_cuts = state_dict.pop("num_cuts") self.longest_seen = state_dict.pop("longest_seen", 0) + self.longest_seen_text = state_dict.pop("longest_seen_text", 0) self.quadratic_duration = state_dict.pop("quadratic_duration", None) # backward compatibility state_dict.pop("strict", None) @@ -537,6 +545,7 @@ def __add__(self, other: "TimeConstraint") -> "TimeConstraint": current=self.current + other.current, num_cuts=self.num_cuts + other.num_cuts, longest_seen=max(self.longest_seen, other.longest_seen), + longest_seen_text=max(self.longest_seen_text, other.longest_seen_text), quadratic_duration=self.quadratic_duration, ) diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index cf4da23f2..69f354529 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -18,6 +18,7 @@ Tuple, Union, ) +import heapq import numpy as np import torch @@ -458,8 +459,8 @@ def __init__( ) # Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`). - self.buckets: List[Deque[Union[Cut, Tuple[Cut, ...]]]] = [ - deque() for _ in range(len(duration_bins) + 1) + self.buckets: List[List[Union[Cut, Tuple[Cut, ...], Tuple[int, Cut]]]] = [ + [] for _ in range(len(duration_bins) + 1) ] def __iter__(self) -> Generator[CutSet, None, None]: @@ -485,6 +486,10 @@ def __iter__(self) -> Generator[CutSet, None, None]: maybe_shuffled = pick_at_random( maybe_shuffled, rng=self.rng, out_indexes_used=indexes_used ) + else: + maybe_shuffled = pick_at_min_heap( + maybe_shuffled + ) # Sample one batch from that bucket and yield it to the caller. batcher = DurationBatcher( maybe_shuffled, @@ -504,9 +509,8 @@ def __iter__(self) -> Generator[CutSet, None, None]: for idx in indexes_used: del sampling_bucket[idx] else: - # No shuffling, remove first N - for _ in range(batch_size): - sampling_bucket.popleft() + # already pop by heappop + pass # Fetch new cuts and add them to appropriate buckets. self._collect_cuts_in_buckets(batch_size) except StopIteration: @@ -609,7 +613,7 @@ def _collect_cuts_in_buckets(self, n_cuts: int) -> None: cuts[0] if isinstance(cuts, tuple) else cuts ) bucket_idx = bisect_right(self.duration_bins, duration) - self.buckets[bucket_idx].append(cuts) + heapq.heappush(self.buckets[bucket_idx], cuts) except StopIteration: pass @@ -629,6 +633,12 @@ def pick_at_random( out_indexes_used.append(idx) yield bucket[idx] +def pick_at_min_heap( + bucket: Sequence[Union[Cut, Tuple[Cut, ...]]], +) -> Generator[Union[Cut, Tuple[Cut, ...]], None, None]: + while bucket: + yield heapq.heappop(bucket) + class BucketsDontHaveEnoughData(Exception): pass diff --git a/lhotse/shar/readers/lazy.py b/lhotse/shar/readers/lazy.py index c1b0b74a3..3558866d0 100644 --- a/lhotse/shar/readers/lazy.py +++ b/lhotse/shar/readers/lazy.py @@ -239,7 +239,7 @@ def __iter__(self): map_fns = self._maybe_shuffle_shards(map_fns) map_fns = self._maybe_split_for_dataloading(map_fns) - for shard, cut_map_fn in zip(shards, map_fns): + for shard_idx, (shard, cut_map_fn) in enumerate(zip(shards, map_fns)): # Iterate over cuts for the current shard cuts = LazyManifestIterator(shard["cuts"]) @@ -249,33 +249,59 @@ def __iter__(self): } # Open every tarfile/jsonl so it's ready for streaming - field_iters = { - field: TarIterator(path) - if extension_contains(".tar", path) - else _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field) - for field, path in field_paths.items() - } + field_iters = {} + for field, path in field_paths.items(): + if path is None: + pass + elif extension_contains(".tar", path): + field_iters[field] = TarIterator(path) + else: + for item, pseudo_path in _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field): + field_iters[field] = TarIterator(pseudo_path) + # field_iters = { + # field: TarIterator(path) + # if extension_contains(".tar", path) + # else _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field) + # for field, path in field_paths.items() + # } # *field_data contains all fields for a single cut (recording, features, array, etc.) - for cut, *field_data in zip(cuts, *field_iters.values()): - for (field, (maybe_manifest, data_path)) in zip( + try: + for cut in self.generate_shard_cuts(cuts, field_iters): + cut.shard_origin = shard["cuts"] + cut.shar_epoch = self.epoch + if cut_map_fn is not None: + cut = cut_map_fn(cut) + yield cut + except OSError as e: + print("WARNING(LazySharIterator): OSError: {}".format(e)) + if shard_idx % 100 == 0: + print("INFO(LazySharIterator): EPOCH({}) Shard {}/{} done".format(self.epoch, shard_idx+1, len(shards))) + + for cut in self.generate_customized_cuts(): + yield cut + + self.epoch += 1 + + + def generate_customized_cuts(self): + for cut in []: + yield cut + + + def generate_shard_cuts(self, cuts, field_iters): + for cut, *field_data in zip(cuts, *field_iters.values()): + for (field, (maybe_manifest, data_path)) in zip( field_iters.keys(), field_data, - ): - if maybe_manifest is None: - continue # No value available for the current field for this cut. - assert ( + ): + if maybe_manifest is None: + continue # No value available for the current field for this cut. + assert ( str(data_path.parent / data_path.stem) == cut.id - ), f"Mismatched IDs: cut ID is '{cut.id}' but found data with name '{data_path}' fsor field {field}" - setattr(cut, field, maybe_manifest) - - cut.shard_origin = shard["cuts"] - cut.shar_epoch = self.epoch - if cut_map_fn is not None: - cut = cut_map_fn(cut) - yield cut - - self.epoch += 1 + ), f"Mismatched IDs: cut ID is '{cut.id}' but found data with name '{data_path}' fsor field {field}" + setattr(cut, field, maybe_manifest) + yield cut def __len__(self) -> int: if self._len is None: diff --git a/lhotse/shar/readers/tar.py b/lhotse/shar/readers/tar.py index a6c2ecf86..49ae7060c 100644 --- a/lhotse/shar/readers/tar.py +++ b/lhotse/shar/readers/tar.py @@ -56,7 +56,11 @@ def iterate_tarfile_pairwise( if len(result) == 2: yield tuple(result) result = [] - result.append(parse_tarinfo(tarinfo, tar_file)) + try: + result.append(parse_tarinfo(tarinfo, tar_file)) + except: + print("Failed to parse tar info of tar file {}".format(tar_file)) + return if len(result) == 2: yield tuple(result)